// aster_cli/recipes/extract_from_cli.rs

1use std::path::PathBuf;
2
3use anyhow::{anyhow, Result};
4use aster::recipe::SubRecipe;
5
6use crate::recipes::print_recipe::print_recipe_info;
7use crate::recipes::recipe::load_recipe;
8use crate::recipes::search_recipe::load_recipe_file;
9use crate::{
10    cli::{InputConfig, RecipeInfo},
11    session::SessionSettings,
12};
13
14pub fn extract_recipe_info_from_cli(
15    recipe_name: String,
16    params: Vec<(String, String)>,
17    additional_sub_recipes: Vec<String>,
18    quiet: bool,
19) -> Result<(InputConfig, RecipeInfo)> {
20    let recipe = load_recipe(&recipe_name, params.clone()).unwrap_or_else(|err| {
21        eprintln!("{}: {}", console::style("Error").red().bold(), err);
22        std::process::exit(1);
23    });
24    if !quiet {
25        print_recipe_info(&recipe, params);
26    }
27    let mut all_sub_recipes = recipe.sub_recipes.clone().unwrap_or_default();
28    if !additional_sub_recipes.is_empty() {
29        for sub_recipe_name in additional_sub_recipes {
30            match load_recipe_file(&sub_recipe_name) {
31                Ok(recipe_file) => {
32                    let name = extract_recipe_name(&sub_recipe_name);
33                    let recipe_file_path = recipe_file.file_path;
34                    let additional_sub_recipe = SubRecipe {
35                        path: recipe_file_path.to_string_lossy().to_string(),
36                        name,
37                        values: None,
38                        sequential_when_repeated: true,
39                        description: None,
40                    };
41                    all_sub_recipes.push(additional_sub_recipe);
42                }
43                Err(e) => {
44                    return Err(anyhow!(
45                        "Could not retrieve sub-recipe '{}': {}",
46                        sub_recipe_name,
47                        e
48                    ));
49                }
50            }
51        }
52    }
53    let input_config = InputConfig {
54        contents: recipe.prompt.filter(|s| !s.trim().is_empty()),
55        extensions_override: recipe.extensions,
56        additional_system_prompt: recipe.instructions,
57    };
58
59    let recipe_info = RecipeInfo {
60        session_settings: recipe.settings.map(|s| SessionSettings {
61            aster_provider: s.aster_provider,
62            aster_model: s.aster_model,
63            temperature: s.temperature,
64        }),
65        sub_recipes: Some(all_sub_recipes),
66        final_output_response: recipe.response,
67        retry_config: recipe.retry,
68    };
69
70    Ok((input_config, recipe_info))
71}
72
/// Derives the sub-recipe name for a recipe identifier given on the CLI.
///
/// - Identifiers containing `/` or `\` are treated as paths and reduced to the
///   file stem of their last component (`path/to/weekly.yaml` -> `weekly`).
///   When there is no usable stem (e.g. `dir/..`), `"unknown"` is returned.
/// - Bare file names ending in a recipe extension (`.yaml`, `.yml`, `.json`,
///   matched case-insensitively) lose that extension, so `weekly.yaml` and
///   `./weekly.yaml` produce the same name.
/// - Anything else (e.g. `weekly-updates`, `v1.2-release`) is returned as-is.
fn extract_recipe_name(recipe_identifier: &str) -> String {
    // Only these extensions are stripped from bare names; any other dot is
    // assumed to be part of the recipe name itself.
    const RECIPE_EXTENSIONS: [&str; 3] = ["yaml", "yml", "json"];

    let path = PathBuf::from(recipe_identifier);
    let stem = path.file_stem().and_then(|s| s.to_str());

    if recipe_identifier.contains('/') || recipe_identifier.contains('\\') {
        return stem.unwrap_or("unknown").to_string();
    }

    let has_recipe_extension = path
        .extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| {
            RECIPE_EXTENSIONS
                .iter()
                .any(|known| ext.eq_ignore_ascii_case(known))
        });

    match stem {
        Some(stem) if has_recipe_extension => stem.to_string(),
        _ => recipe_identifier.to_string(),
    }
}
86
#[cfg(test)]
mod tests {
    use std::path::{Path, PathBuf};

    use tempfile::TempDir;

    use super::*;

    /// Template parameters required by the recipe written by `create_recipe`.
    fn test_params() -> Vec<(String, String)> {
        vec![("name".to_string(), "my_value".to_string())]
    }

    /// Asserts every field derived from the recipe written by `create_recipe`,
    /// including its own `existing_sub_recipe` entry, which must be the first
    /// element of `recipe_info.sub_recipes`. Any extra CLI sub-recipes are
    /// left for the caller to check.
    fn assert_base_recipe_fields(
        input_config: &InputConfig,
        recipe_info: &RecipeInfo,
        recipe_path: &Path,
    ) {
        assert_eq!(input_config.contents, Some("test_prompt".to_string()));
        assert_eq!(
            input_config.additional_system_prompt,
            Some("test_instructions my_value".to_string())
        );
        assert!(input_config.extensions_override.is_none());

        let settings = recipe_info
            .session_settings
            .as_ref()
            .expect("recipe settings should be present");
        assert_eq!(settings.aster_provider, Some("test_provider".to_string()));
        assert_eq!(settings.aster_model, Some("test_model".to_string()));
        assert_eq!(settings.temperature, Some(0.7));

        let sub_recipes = recipe_info
            .sub_recipes
            .as_deref()
            .expect("sub-recipes should always be present");
        let existing = sub_recipes
            .first()
            .expect("the recipe's own sub-recipe should come first");
        let full_sub_recipe_path = recipe_path
            .parent()
            .unwrap()
            .join("existing_sub_recipe.yaml")
            .to_string_lossy()
            .to_string();
        assert_eq!(existing.path, full_sub_recipe_path);
        assert_eq!(existing.name, "existing_sub_recipe".to_string());
        assert!(existing.values.is_none());

        let response = recipe_info
            .final_output_response
            .as_ref()
            .expect("response schema should be present");
        assert_eq!(
            response.json_schema,
            Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "result": {"type": "string"}
                }
            }))
        );
    }

    #[test]
    fn test_extract_recipe_info_from_cli_basic() {
        let (_temp_dir, recipe_path) = create_recipe();
        let recipe_name = recipe_path.to_str().unwrap().to_string();

        let (input_config, recipe_info) =
            extract_recipe_info_from_cli(recipe_name, test_params(), Vec::new(), false).unwrap();

        assert_base_recipe_fields(&input_config, &recipe_info, &recipe_path);
        let sub_recipes = recipe_info
            .sub_recipes
            .expect("sub-recipes should always be present");
        assert_eq!(sub_recipes.len(), 1);
    }

    #[test]
    fn test_extract_recipe_info_from_cli_with_additional_sub_recipes() {
        let (temp_dir, recipe_path) = create_recipe();

        std::fs::create_dir_all(temp_dir.path().join("path/to")).unwrap();
        std::fs::create_dir_all(temp_dir.path().join("another")).unwrap();

        let sub_recipe1_path = temp_dir.path().join("path/to/sub_recipe1.yaml");
        let sub_recipe2_path = temp_dir.path().join("another/sub_recipe2.yaml");

        std::fs::write(&sub_recipe1_path, "title: Sub Recipe 1").unwrap();
        std::fs::write(&sub_recipe2_path, "title: Sub Recipe 2").unwrap();

        let recipe_name = recipe_path.to_str().unwrap().to_string();
        let additional_sub_recipes = vec![
            sub_recipe1_path.to_string_lossy().to_string(),
            sub_recipe2_path.to_string_lossy().to_string(),
        ];

        let (input_config, recipe_info) = extract_recipe_info_from_cli(
            recipe_name,
            test_params(),
            additional_sub_recipes,
            false,
        )
        .unwrap();

        assert_base_recipe_fields(&input_config, &recipe_info, &recipe_path);

        let sub_recipes = recipe_info
            .sub_recipes
            .expect("sub-recipes should always be present");
        assert_eq!(sub_recipes.len(), 3);

        // CLI-supplied sub-recipes follow the recipe's own, in CLI order, and
        // are stored with their canonical path.
        let expected_extras = [
            (&sub_recipe1_path, "sub_recipe1"),
            (&sub_recipe2_path, "sub_recipe2"),
        ];
        for (sub_recipe, (file, name)) in sub_recipes[1..].iter().zip(expected_extras) {
            assert_eq!(
                sub_recipe.path,
                file.canonicalize().unwrap().to_string_lossy().to_string()
            );
            assert_eq!(sub_recipe.name, name.to_string());
            assert!(sub_recipe.values.is_none());
        }
    }

    #[test]
    fn test_extract_recipe_name_handles_names_and_paths() {
        assert_eq!(extract_recipe_name("weekly-updates"), "weekly-updates");
        assert_eq!(
            extract_recipe_name("path/to/sub_recipe1.yaml"),
            "sub_recipe1"
        );
        assert_eq!(extract_recipe_name("a/.."), "unknown");
    }

    /// Writes `test_recipe.yaml` (which references `existing_sub_recipe.yaml`)
    /// and that sub-recipe into a fresh temp dir. Returns the dir guard (keep
    /// it alive for the test's duration) and the canonical recipe path.
    fn create_recipe() -> (TempDir, PathBuf) {
        let test_recipe_content = r#"
title: test_recipe
description: A test recipe
instructions: test_instructions {{name}}
prompt: test_prompt
parameters:
- key: name
  description: name
  input_type: string
  requirement: required
settings:
  aster_provider: test_provider
  aster_model: test_model
  temperature: 0.7
sub_recipes:
- path: existing_sub_recipe.yaml
  name: existing_sub_recipe
response:
  json_schema:
    type: object
    properties:
      result:
        type: string
"#;
        let sub_recipe_content = r#"
title: existing_sub_recipe
description: An existing sub recipe
instructions: sub recipe instructions
prompt: sub recipe prompt
"#;
        let temp_dir = tempfile::tempdir().unwrap();
        let recipe_path = temp_dir.path().join("test_recipe.yaml");
        let sub_recipe_path = temp_dir.path().join("existing_sub_recipe.yaml");

        std::fs::write(&recipe_path, test_recipe_content).unwrap();
        std::fs::write(&sub_recipe_path, sub_recipe_content).unwrap();
        let canonical_recipe_path = recipe_path.canonicalize().unwrap();
        (temp_dir, canonical_recipe_path)
    }
}