Skip to main content

aster/recipe/
mod.rs

1use anyhow::Result;
2use serde_json::Value;
3use std::collections::HashMap;
4use std::fmt;
5use std::path::Path;
6
7use crate::agents::extension::ExtensionConfig;
8use crate::agents::types::RetryConfig;
9use crate::recipe::read_recipe_file_content::read_recipe_file;
10use crate::recipe::yaml_format_utils::reformat_fields_with_multiline_values;
11use crate::utils::contains_unicode_tags;
12use serde::de::Deserializer;
13use serde::{Deserialize, Serialize};
14use utoipa::ToSchema;
15
16pub mod build_recipe;
17pub mod local_recipes;
18pub mod read_recipe_file_content;
19mod recipe_extension_adapter;
20pub mod template_recipe;
21pub mod validate_recipe;
22pub mod yaml_format_utils;
23
/// Key of the built-in recipe parameter named `recipe_dir`.
pub const BUILT_IN_RECIPE_DIR_PARAM: &'static str = "recipe_dir";

/// File extensions (without the leading dot) accepted for recipe files.
pub const RECIPE_FILE_EXTENSIONS: &'static [&'static str] = &["yaml", "json"];
26
/// Default recipe file-format version, `1.0.0`.
///
/// Used as the serde default for `Recipe::version` when the field is omitted
/// and as the initial version in `Recipe::builder`.
fn default_version() -> String {
    String::from("1.0.0")
}
30
// A recipe definition as read from YAML/JSON (see `Recipe::from_content`)
// and written back out with `Recipe::to_yaml`.
#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
pub struct Recipe {
    // Required fields
    // (`version` may be omitted on input; serde then fills it with
    // `default_version()`, i.e. "1.0.0".)
    #[serde(default = "default_version")]
    pub version: String, // version of the file format, sem ver

    pub title: String, // short title of the recipe

    pub description: String, // a longer description of the recipe

    // Optional fields
    // Note: at least one of instructions or prompt need to be set.
    // This is enforced by `RecipeBuilder::build` only; `Recipe::from_content`
    // does not check it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>, // the instructions for the model

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>, // the prompt to start the session with

    // Deserialized through `recipe_extension_adapter` rather than
    // `ExtensionConfig`'s own `Deserialize` impl; missing -> `None`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        default,
        deserialize_with = "recipe_extension_adapter::deserialize_recipe_extensions"
    )]
    pub extensions: Option<Vec<ExtensionConfig>>, // a list of extensions to enable

    #[serde(skip_serializing_if = "Option::is_none")]
    pub settings: Option<Settings>, // settings for the recipe

    #[serde(skip_serializing_if = "Option::is_none")]
    pub activities: Option<Vec<String>>, // the activity pills that show up when loading the recipe

    #[serde(skip_serializing_if = "Option::is_none")]
    pub author: Option<Author>, // any additional author information

    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Vec<RecipeParameter>>, // any additional parameters for the recipe

    #[serde(skip_serializing_if = "Option::is_none")]
    pub response: Option<Response>, // response configuration including JSON schema

    #[serde(skip_serializing_if = "Option::is_none")]
    pub sub_recipes: Option<Vec<SubRecipe>>, // sub-recipes for the recipe

    // retry configuration; `Recipe::from_content` rejects the recipe when
    // `RetryConfig::validate` fails
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry: Option<RetryConfig>,
}
77
// Optional author information attached to a recipe; both fields are
// omitted from serialized output when `None`.
#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
pub struct Author {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub contact: Option<String>, // creator/contact information of the recipe

    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<String>, // any additional metadata for the author
}
86
// Optional model settings carried by a recipe; each field is omitted from
// serialized output when `None`.
// NOTE(review): presumably these override the session's configured
// provider/model/temperature — confirm with the code that consumes them.
#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
pub struct Settings {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aster_provider: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub aster_model: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
}
98
// Response configuration for a recipe.
#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
pub struct Response {
    // Arbitrary JSON value; presumably a JSON Schema describing the expected
    // structured output (the module tests use an object schema here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<serde_json::Value>,
}
104
105#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
106pub struct SubRecipe {
107    pub name: String,
108    pub path: String,
109    #[serde(default, deserialize_with = "deserialize_value_map_as_string")]
110    pub values: Option<HashMap<String, String>>,
111    #[serde(default)]
112    pub sequential_when_repeated: bool,
113    #[serde(skip_serializing_if = "Option::is_none")]
114    pub description: Option<String>,
115}
116
117fn deserialize_value_map_as_string<'de, D>(
118    deserializer: D,
119) -> Result<Option<HashMap<String, String>>, D::Error>
120where
121    D: Deserializer<'de>,
122{
123    // First, try to deserialize a map of values
124    let opt_raw: Option<HashMap<String, Value>> = Option::deserialize(deserializer)?;
125
126    match opt_raw {
127        Some(raw_map) => {
128            let mut result = HashMap::new();
129            for (k, v) in raw_map {
130                let s = match v {
131                    Value::String(s) => s,
132                    _ => serde_json::to_string(&v).map_err(serde::de::Error::custom)?,
133                };
134                result.insert(k, s);
135            }
136            Ok(Some(result))
137        }
138        None => Ok(None),
139    }
140}
141
142#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
143#[serde(rename_all = "snake_case")]
144pub enum RecipeParameterRequirement {
145    Required,
146    Optional,
147    UserPrompt,
148}
149
150impl fmt::Display for RecipeParameterRequirement {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        write!(
153            f,
154            "{}",
155            serde_json::to_string(self).unwrap().trim_matches('"')
156        )
157    }
158}
159
160#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
161#[serde(rename_all = "snake_case")]
162pub enum RecipeParameterInputType {
163    String,
164    Number,
165    Boolean,
166    Date,
167    /// File parameter that imports content from a file path.
168    /// Cannot have default values to prevent importing sensitive user files.
169    File,
170    Select,
171}
172
173impl fmt::Display for RecipeParameterInputType {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        write!(
176            f,
177            "{}",
178            serde_json::to_string(self).unwrap().trim_matches('"')
179        )
180    }
181}
182
// A parameter declared by a recipe.
#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
pub struct RecipeParameter {
    // parameter name
    pub key: String,
    // kind of value accepted
    pub input_type: RecipeParameterInputType,
    // whether a value must be supplied
    pub requirement: RecipeParameterRequirement,
    pub description: String,
    // Fallback value, stored as a string regardless of `input_type`.
    // NOTE(review): `File` parameters must not have one (see the enum docs);
    // presumably enforced in validate_recipe — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<String>,
    // Allowed choices; presumably only meaningful for
    // `RecipeParameterInputType::Select` — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<Vec<String>>,
}
194
195/// Builder for creating Recipe instances
196pub struct RecipeBuilder {
197    // Required fields with default values
198    version: String,
199    title: Option<String>,
200    description: Option<String>,
201    instructions: Option<String>,
202
203    // Optional fields
204    prompt: Option<String>,
205    extensions: Option<Vec<ExtensionConfig>>,
206    settings: Option<Settings>,
207    activities: Option<Vec<String>>,
208    author: Option<Author>,
209    parameters: Option<Vec<RecipeParameter>>,
210    response: Option<Response>,
211    sub_recipes: Option<Vec<SubRecipe>>,
212    retry: Option<RetryConfig>,
213}
214
215impl Recipe {
216    /// Returns true if harmful content is detected in instructions, prompt, or activities fields
217    pub fn check_for_security_warnings(&self) -> bool {
218        if [self.instructions.as_deref(), self.prompt.as_deref()]
219            .iter()
220            .flatten()
221            .any(|&field| contains_unicode_tags(field))
222        {
223            return true;
224        }
225
226        if let Some(activities) = &self.activities {
227            return activities
228                .iter()
229                .any(|activity| contains_unicode_tags(activity));
230        }
231
232        false
233    }
234
235    pub fn to_yaml(&self) -> Result<String> {
236        let recipe_yaml = serde_yaml::to_string(self)
237            .map_err(|err| anyhow::anyhow!("Failed to serialize recipe: {}", err))?;
238        let formatted_recipe_yaml =
239            reformat_fields_with_multiline_values(&recipe_yaml, &["prompt", "instructions"]);
240        Ok(formatted_recipe_yaml)
241    }
242
243    pub fn builder() -> RecipeBuilder {
244        RecipeBuilder {
245            version: default_version(),
246            title: None,
247            description: None,
248            instructions: None,
249            prompt: None,
250            extensions: None,
251            settings: None,
252            activities: None,
253            author: None,
254            parameters: None,
255            response: None,
256            sub_recipes: None,
257            retry: None,
258        }
259    }
260
261    pub fn from_file_path(file_path: &Path) -> Result<Self> {
262        let file = read_recipe_file(file_path)?;
263        Self::from_content(&file.content)
264    }
265
266    pub fn from_content(content: &str) -> Result<Self> {
267        let recipe: Recipe = match serde_yaml::from_str::<serde_yaml::Value>(content) {
268            Ok(yaml_value) => {
269                if let Some(nested_recipe) = yaml_value.get("recipe") {
270                    serde_yaml::from_value(nested_recipe.clone())
271                        .map_err(|e| anyhow::anyhow!("Failed to parse nested recipe: {}", e))?
272                } else {
273                    serde_yaml::from_str(content)
274                        .map_err(|e| anyhow::anyhow!("Failed to parse recipe: {}", e))?
275                }
276            }
277            Err(_) => serde_yaml::from_str(content)
278                .map_err(|e| anyhow::anyhow!("Failed to parse recipe: {}", e))?,
279        };
280
281        if let Some(ref retry_config) = recipe.retry {
282            if let Err(validation_error) = retry_config.validate() {
283                return Err(anyhow::anyhow!(
284                    "Invalid retry configuration: {}",
285                    validation_error
286                ));
287            }
288        }
289
290        Ok(recipe)
291    }
292}
293
294impl RecipeBuilder {
295    pub fn version(mut self, version: impl Into<String>) -> Self {
296        self.version = version.into();
297        self
298    }
299
300    pub fn title(mut self, title: impl Into<String>) -> Self {
301        self.title = Some(title.into());
302        self
303    }
304
305    pub fn description(mut self, description: impl Into<String>) -> Self {
306        self.description = Some(description.into());
307        self
308    }
309
310    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
311        self.instructions = Some(instructions.into());
312        self
313    }
314
315    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
316        self.prompt = Some(prompt.into());
317        self
318    }
319
320    pub fn extensions(mut self, extensions: Vec<ExtensionConfig>) -> Self {
321        self.extensions = Some(extensions);
322        self
323    }
324
325    pub fn settings(mut self, settings: Settings) -> Self {
326        self.settings = Some(settings);
327        self
328    }
329
330    pub fn activities(mut self, activities: Vec<String>) -> Self {
331        self.activities = Some(activities);
332        self
333    }
334
335    pub fn author(mut self, author: Author) -> Self {
336        self.author = Some(author);
337        self
338    }
339
340    pub fn parameters(mut self, parameters: Vec<RecipeParameter>) -> Self {
341        self.parameters = Some(parameters);
342        self
343    }
344
345    pub fn response(mut self, response: Response) -> Self {
346        self.response = Some(response);
347        self
348    }
349
350    pub fn sub_recipes(mut self, sub_recipes: Vec<SubRecipe>) -> Self {
351        self.sub_recipes = Some(sub_recipes);
352        self
353    }
354
355    pub fn retry(mut self, retry: RetryConfig) -> Self {
356        self.retry = Some(retry);
357        self
358    }
359
360    pub fn build(self) -> Result<Recipe, &'static str> {
361        let title = self.title.ok_or("Title is required")?;
362        let description = self.description.ok_or("Description is required")?;
363
364        if self.instructions.is_none() && self.prompt.is_none() {
365            return Err("At least one of 'prompt' or 'instructions' is required");
366        }
367
368        Ok(Recipe {
369            version: self.version,
370            title,
371            description,
372            instructions: self.instructions,
373            prompt: self.prompt,
374            extensions: self.extensions,
375            settings: self.settings,
376            activities: self.activities,
377            author: self.author,
378            parameters: self.parameters,
379            response: self.response,
380            sub_recipes: self.sub_recipes,
381            retry: self.retry,
382        })
383    }
384}
385
// Unit tests for recipe parsing, the builder, parameter enum names and the
// security-warning check.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_content_with_json() {
        let content = r#"{
            "version": "1.0.0",
            "title": "Test Recipe",
            "description": "A test recipe",
            "prompt": "Test prompt",
            "instructions": "Test instructions",
            "extensions": [
                {
                    "type": "stdio",
                    "name": "test_extension",
                    "cmd": "test_cmd",
                    "args": ["arg1", "arg2"],
                    "timeout": 300,
                    "description": "Test extension"
                }
            ],
            "parameters": [
                {
                    "key": "test_param",
                    "input_type": "string",
                    "requirement": "required",
                    "description": "A test parameter"
                }
            ],
            "response": {
                "json_schema": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string"
                        },
                        "age": {
                            "type": "number"
                        }
                    },
                    "required": ["name"]
                }
            },
            "sub_recipes": [
                {
                    "name": "test_sub_recipe",
                    "path": "test_sub_recipe.yaml",
                    "values": {
                        "sub_recipe_param": "sub_recipe_value"
                    }
                }
            ]
        }"#;

        let recipe = Recipe::from_content(content).unwrap();
        assert_eq!(recipe.version, "1.0.0");
        assert_eq!(recipe.title, "Test Recipe");
        assert_eq!(recipe.description, "A test recipe");
        assert_eq!(recipe.instructions, Some("Test instructions".to_string()));
        assert_eq!(recipe.prompt, Some("Test prompt".to_string()));

        assert!(recipe.extensions.is_some());
        let extensions = recipe.extensions.unwrap();
        assert_eq!(extensions.len(), 1);

        assert!(recipe.parameters.is_some());
        let parameters = recipe.parameters.unwrap();
        assert_eq!(parameters.len(), 1);
        assert_eq!(parameters[0].key, "test_param");
        assert!(matches!(
            parameters[0].input_type,
            RecipeParameterInputType::String
        ));
        assert!(matches!(
            parameters[0].requirement,
            RecipeParameterRequirement::Required
        ));

        assert!(recipe.response.is_some());
        let response = recipe.response.unwrap();
        assert!(response.json_schema.is_some());
        let json_schema = response.json_schema.unwrap();
        assert_eq!(json_schema["type"], "object");
        assert!(json_schema["properties"].is_object());
        assert_eq!(json_schema["properties"]["name"]["type"], "string");
        assert_eq!(json_schema["properties"]["age"]["type"], "number");
        assert_eq!(json_schema["required"], serde_json::json!(["name"]));

        assert!(recipe.sub_recipes.is_some());
        let sub_recipes = recipe.sub_recipes.unwrap();
        assert_eq!(sub_recipes.len(), 1);
        assert_eq!(sub_recipes[0].name, "test_sub_recipe");
        assert_eq!(sub_recipes[0].path, "test_sub_recipe.yaml");
        assert_eq!(
            sub_recipes[0].values,
            Some(HashMap::from([(
                "sub_recipe_param".to_string(),
                "sub_recipe_value".to_string()
            )]))
        );
    }

    #[test]
    fn test_from_content_with_yaml() {
        let content = r#"version: 1.0.0
title: Test Recipe
description: A test recipe
prompt: Test prompt
instructions: Test instructions
extensions:
  - type: stdio
    name: test_extension
    cmd: test_cmd
    args: [arg1, arg2]
    timeout: 300
    description: Test extension
parameters:
  - key: test_param
    input_type: string
    requirement: required
    description: A test parameter
response:
  json_schema:
    type: object
    properties:
      name:
        type: string
      age:
        type: number
    required:
      - name
sub_recipes:
  - name: test_sub_recipe
    path: test_sub_recipe.yaml
    values:
      sub_recipe_param: sub_recipe_value"#;

        let recipe = Recipe::from_content(content).unwrap();
        assert_eq!(recipe.version, "1.0.0");
        assert_eq!(recipe.title, "Test Recipe");
        assert_eq!(recipe.description, "A test recipe");
        assert_eq!(recipe.instructions, Some("Test instructions".to_string()));
        assert_eq!(recipe.prompt, Some("Test prompt".to_string()));

        assert!(recipe.extensions.is_some());
        let extensions = recipe.extensions.unwrap();
        assert_eq!(extensions.len(), 1);

        assert!(recipe.parameters.is_some());
        let parameters = recipe.parameters.unwrap();
        assert_eq!(parameters.len(), 1);
        assert_eq!(parameters[0].key, "test_param");
        assert!(matches!(
            parameters[0].input_type,
            RecipeParameterInputType::String
        ));
        assert!(matches!(
            parameters[0].requirement,
            RecipeParameterRequirement::Required
        ));

        assert!(recipe.response.is_some());
        let response = recipe.response.unwrap();
        assert!(response.json_schema.is_some());
        let json_schema = response.json_schema.unwrap();
        assert_eq!(json_schema["type"], "object");
        assert!(json_schema["properties"].is_object());
        assert_eq!(json_schema["properties"]["name"]["type"], "string");
        assert_eq!(json_schema["properties"]["age"]["type"], "number");
        assert_eq!(json_schema["required"], serde_json::json!(["name"]));

        assert!(recipe.sub_recipes.is_some());
        let sub_recipes = recipe.sub_recipes.unwrap();
        assert_eq!(sub_recipes.len(), 1);
        assert_eq!(sub_recipes[0].name, "test_sub_recipe");
        assert_eq!(sub_recipes[0].path, "test_sub_recipe.yaml");
        assert_eq!(
            sub_recipes[0].values,
            Some(HashMap::from([(
                "sub_recipe_param".to_string(),
                "sub_recipe_value".to_string()
            )]))
        );
    }

    #[test]
    fn test_from_content_invalid_json() {
        let content = "{ invalid json }";

        let result = Recipe::from_content(content);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_content_missing_required_fields() {
        let content = r#"{
            "version": "1.0.0",
            "description": "A test recipe"
        }"#;

        let result = Recipe::from_content(content);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_content_defaults_version_when_missing() {
        let content = r#"{"title": "T", "description": "D", "instructions": "I"}"#;

        let recipe = Recipe::from_content(content).unwrap();
        assert_eq!(recipe.version, "1.0.0");
    }

    #[test]
    fn test_from_content_with_author() {
        let content = r#"{
            "version": "1.0.0",
            "title": "Test Recipe",
            "description": "A test recipe",
            "instructions": "Test instructions",
            "author": {
                "contact": "test@example.com"
            }
        }"#;

        let recipe = Recipe::from_content(content).unwrap();

        assert!(recipe.author.is_some());
        let author = recipe.author.unwrap();
        assert_eq!(author.contact, Some("test@example.com".to_string()));
    }

    #[test]
    fn test_inline_python_extension() {
        let content = r#"{
            "version": "1.0.0",
            "title": "Test Recipe",
            "description": "A test recipe",
            "instructions": "Test instructions",
            "extensions": [
                {
                    "type": "inline_python",
                    "name": "test_python",
                    "code": "print('hello world')",
                    "timeout": 300,
                    "description": "Test python extension",
                    "dependencies": ["numpy", "matplotlib"]
                }
            ]
        }"#;

        let recipe = Recipe::from_content(content).unwrap();

        assert!(recipe.extensions.is_some());
        let extensions = recipe.extensions.unwrap();
        assert_eq!(extensions.len(), 1);

        match &extensions[0] {
            ExtensionConfig::InlinePython {
                name,
                code,
                description,
                timeout,
                dependencies,
                ..
            } => {
                assert_eq!(name, "test_python");
                assert_eq!(code, "print('hello world')");
                assert_eq!(description, "Test python extension");
                assert_eq!(timeout, &Some(300));
                assert!(dependencies.is_some());
                let deps = dependencies.as_ref().unwrap();
                assert!(deps.contains(&"numpy".to_string()));
                assert!(deps.contains(&"matplotlib".to_string()));
            }
            _ => panic!("Expected InlinePython extension"),
        }
    }

    #[test]
    fn test_from_content_with_activities() {
        let content = r#"{
            "version": "1.0.0",
            "title": "Test Recipe",
            "description": "A test recipe",
            "instructions": "Test instructions",
            "activities": ["activity1", "activity2"]
        }"#;

        let recipe = Recipe::from_content(content).unwrap();

        assert!(recipe.activities.is_some());
        let activities = recipe.activities.unwrap();
        assert_eq!(activities, vec!["activity1", "activity2"]);
    }

    #[test]
    fn test_from_content_with_nested_recipe_yaml() {
        let content = r#"name: test_recipe
recipe:
  title: Nested Recipe Test
  description: A test recipe with nested structure
  instructions: Test instructions for nested recipe
  activities:
    - Test activity 1
    - Test activity 2
  prompt: Test prompt
  extensions: []
isGlobal: true"#;

        let recipe = Recipe::from_content(content).unwrap();
        assert_eq!(recipe.title, "Nested Recipe Test");
        assert_eq!(recipe.description, "A test recipe with nested structure");
        assert_eq!(
            recipe.instructions,
            Some("Test instructions for nested recipe".to_string())
        );
        assert_eq!(recipe.prompt, Some("Test prompt".to_string()));
        assert!(recipe.activities.is_some());
        let activities = recipe.activities.unwrap();
        assert_eq!(activities, vec!["Test activity 1", "Test activity 2"]);
        assert!(recipe.extensions.is_some());
        let extensions = recipe.extensions.unwrap();
        assert_eq!(extensions.len(), 0);
    }

    #[test]
    fn test_sub_recipe_values_non_strings_become_json_text() {
        let content = r#"title: Sub Recipe Values
description: Non-string sub-recipe values
instructions: Test instructions
sub_recipes:
  - name: child
    path: child.yaml
    values:
      text: hello
      count: 3
      flag: true
      list: [1, 2]"#;

        let recipe = Recipe::from_content(content).unwrap();
        let sub_recipes = recipe.sub_recipes.unwrap();
        assert_eq!(sub_recipes.len(), 1);
        assert!(!sub_recipes[0].sequential_when_repeated);
        assert_eq!(
            sub_recipes[0].values,
            Some(HashMap::from([
                ("text".to_string(), "hello".to_string()),
                ("count".to_string(), "3".to_string()),
                ("flag".to_string(), "true".to_string()),
                ("list".to_string(), "[1,2]".to_string()),
            ]))
        );
    }

    #[test]
    fn test_check_for_security_warnings() {
        let clean_recipe = Recipe {
            version: "1.0.0".to_string(),
            title: "Test".to_string(),
            description: "Test".to_string(),
            instructions: Some("clean instructions".to_string()),
            prompt: Some("clean prompt".to_string()),
            extensions: None,
            settings: None,
            activities: Some(vec!["clean activity 1".to_string()]),
            author: None,
            parameters: None,
            response: None,
            sub_recipes: None,
            retry: None,
        };

        assert!(!clean_recipe.check_for_security_warnings());

        // Each case starts from the clean recipe so that a hit in one field
        // cannot mask a missed detection in another.

        // Malicious activities
        let mut recipe = clean_recipe.clone();
        recipe.activities = Some(vec![
            "clean activity".to_string(),
            format!("malicious{}activity", '\u{E0041}'),
        ]);
        assert!(recipe.check_for_security_warnings());

        // Malicious instructions
        let mut recipe = clean_recipe.clone();
        recipe.instructions = Some(format!("instructions{}", '\u{E0041}'));
        assert!(recipe.check_for_security_warnings());

        // Malicious prompt
        let mut recipe = clean_recipe;
        recipe.prompt = Some(format!("prompt{}", '\u{E0042}'));
        assert!(recipe.check_for_security_warnings());
    }

    #[test]
    fn test_from_content_with_null_description() {
        let content = r#"{
            "version": "1.0.0",
            "title": "Test Recipe",
            "description": "A test recipe",
            "instructions": "Test instructions",
            "extensions": [
                {
                    "type": "stdio",
                    "name": "test_extension",
                    "cmd": "test_cmd",
                    "args": [],
                    "timeout": 300,
                    "description": null
                }
            ]
        }"#;

        let recipe = Recipe::from_content(content).unwrap();

        assert!(recipe.extensions.is_some());
        let extensions = recipe.extensions.unwrap();
        assert_eq!(extensions.len(), 1);

        if let ExtensionConfig::Stdio {
            name, description, ..
        } = &extensions[0]
        {
            assert_eq!(name, "test_extension");
            assert_eq!(description, "");
        } else {
            panic!("Expected Stdio extension");
        }
    }

    #[test]
    fn test_parameter_enum_display_matches_serde_names() {
        let requirements = [
            (RecipeParameterRequirement::Required, "required"),
            (RecipeParameterRequirement::Optional, "optional"),
            (RecipeParameterRequirement::UserPrompt, "user_prompt"),
        ];
        for (requirement, expected) in requirements.iter() {
            assert_eq!(requirement.to_string(), *expected);
            assert_eq!(
                serde_json::to_string(requirement).unwrap(),
                format!("\"{}\"", expected)
            );
        }

        let input_types = [
            (RecipeParameterInputType::String, "string"),
            (RecipeParameterInputType::Number, "number"),
            (RecipeParameterInputType::Boolean, "boolean"),
            (RecipeParameterInputType::Date, "date"),
            (RecipeParameterInputType::File, "file"),
            (RecipeParameterInputType::Select, "select"),
        ];
        for (input_type, expected) in input_types.iter() {
            assert_eq!(input_type.to_string(), *expected);
            assert_eq!(
                serde_json::to_string(input_type).unwrap(),
                format!("\"{}\"", expected)
            );
        }
    }

    #[test]
    fn test_builder_rejects_missing_required_fields() {
        let err = Recipe::builder()
            .description("d")
            .prompt("p")
            .build()
            .unwrap_err();
        assert_eq!(err, "Title is required");

        let err = Recipe::builder().title("t").prompt("p").build().unwrap_err();
        assert_eq!(err, "Description is required");

        let err = Recipe::builder()
            .title("t")
            .description("d")
            .build()
            .unwrap_err();
        assert_eq!(err, "At least one of 'prompt' or 'instructions' is required");
    }

    #[test]
    fn test_builder_builds_with_default_version() {
        let recipe = Recipe::builder()
            .title("Built Recipe")
            .description("Built by the builder")
            .instructions("Do the thing")
            .build()
            .unwrap();

        assert_eq!(recipe.version, "1.0.0");
        assert_eq!(recipe.title, "Built Recipe");
        assert_eq!(recipe.description, "Built by the builder");
        assert_eq!(recipe.instructions, Some("Do the thing".to_string()));
        assert!(recipe.prompt.is_none());
        assert!(recipe.extensions.is_none());
    }
}