aster_cli/recipes/
extract_from_cli.rs1use std::path::PathBuf;
2
3use anyhow::{anyhow, Result};
4use aster::recipe::SubRecipe;
5
6use crate::recipes::print_recipe::print_recipe_info;
7use crate::recipes::recipe::load_recipe;
8use crate::recipes::search_recipe::load_recipe_file;
9use crate::{
10 cli::{InputConfig, RecipeInfo},
11 session::SessionSettings,
12};
13
14pub fn extract_recipe_info_from_cli(
15 recipe_name: String,
16 params: Vec<(String, String)>,
17 additional_sub_recipes: Vec<String>,
18 quiet: bool,
19) -> Result<(InputConfig, RecipeInfo)> {
20 let recipe = load_recipe(&recipe_name, params.clone()).unwrap_or_else(|err| {
21 eprintln!("{}: {}", console::style("Error").red().bold(), err);
22 std::process::exit(1);
23 });
24 if !quiet {
25 print_recipe_info(&recipe, params);
26 }
27 let mut all_sub_recipes = recipe.sub_recipes.clone().unwrap_or_default();
28 if !additional_sub_recipes.is_empty() {
29 for sub_recipe_name in additional_sub_recipes {
30 match load_recipe_file(&sub_recipe_name) {
31 Ok(recipe_file) => {
32 let name = extract_recipe_name(&sub_recipe_name);
33 let recipe_file_path = recipe_file.file_path;
34 let additional_sub_recipe = SubRecipe {
35 path: recipe_file_path.to_string_lossy().to_string(),
36 name,
37 values: None,
38 sequential_when_repeated: true,
39 description: None,
40 };
41 all_sub_recipes.push(additional_sub_recipe);
42 }
43 Err(e) => {
44 return Err(anyhow!(
45 "Could not retrieve sub-recipe '{}': {}",
46 sub_recipe_name,
47 e
48 ));
49 }
50 }
51 }
52 }
53 let input_config = InputConfig {
54 contents: recipe.prompt.filter(|s| !s.trim().is_empty()),
55 extensions_override: recipe.extensions,
56 additional_system_prompt: recipe.instructions,
57 };
58
59 let recipe_info = RecipeInfo {
60 session_settings: recipe.settings.map(|s| SessionSettings {
61 aster_provider: s.aster_provider,
62 aster_model: s.aster_model,
63 temperature: s.temperature,
64 }),
65 sub_recipes: Some(all_sub_recipes),
66 final_output_response: recipe.response,
67 retry_config: recipe.retry,
68 };
69
70 Ok((input_config, recipe_info))
71}
72
/// Derives the name of a sub-recipe passed on the command line.
///
/// Identifiers that look like paths (they contain `/` or `\`) are reduced to
/// their file stem, e.g. `path/to/review.yaml` becomes `review`. Any other
/// identifier is returned unchanged, so bare names containing dots keep them.
/// A path-like identifier with no file stem (such as `/`) yields `"unknown"`.
fn extract_recipe_name(recipe_identifier: &str) -> String {
    if !(recipe_identifier.contains('/') || recipe_identifier.contains('\\')) {
        return recipe_identifier.to_string();
    }
    // `\` is only a path separator on Windows. Normalise it so Windows-style
    // identifiers are split identically on every platform; on Windows `/` is
    // already a separator, so this changes nothing there.
    let normalized = recipe_identifier.replace('\\', "/");
    PathBuf::from(normalized)
        .file_stem()
        .and_then(|stem| stem.to_str())
        .unwrap_or("unknown")
        .to_string()
}
86
#[cfg(test)]
mod tests {
    use std::path::{Path, PathBuf};

    use tempfile::TempDir;

    use super::*;

    /// Top-level recipe written by `create_recipe`: one required `name`
    /// parameter, explicit session settings, one sub-recipe referenced by a
    /// path relative to the recipe file, and a JSON response schema.
    const TEST_RECIPE_CONTENT: &str = r#"
title: test_recipe
description: A test recipe
instructions: test_instructions {{name}}
prompt: test_prompt
parameters:
- key: name
  description: name
  input_type: string
  requirement: required
settings:
  aster_provider: test_provider
  aster_model: test_model
  temperature: 0.7
sub_recipes:
- path: existing_sub_recipe.yaml
  name: existing_sub_recipe
response:
  json_schema:
    type: object
    properties:
      result:
        type: string
"#;

    /// Sub-recipe referenced by `TEST_RECIPE_CONTENT`.
    const EXISTING_SUB_RECIPE_CONTENT: &str = r#"
title: existing_sub_recipe
description: An existing sub recipe
instructions: sub recipe instructions
prompt: sub recipe prompt
"#;

    /// Parameters satisfying the required `name` parameter of the test recipe.
    fn test_params() -> Vec<(String, String)> {
        vec![("name".to_string(), "my_value".to_string())]
    }

    /// Checks every field that originates from `TEST_RECIPE_CONTENT` and the
    /// total number of sub-recipes. The recipe's own sub-recipe must be first.
    fn assert_base_recipe_fields(
        input_config: &InputConfig,
        recipe_info: &RecipeInfo,
        recipe_path: &Path,
        expected_sub_recipe_count: usize,
    ) {
        assert_eq!(input_config.contents.as_deref(), Some("test_prompt"));
        assert_eq!(
            input_config.additional_system_prompt.as_deref(),
            Some("test_instructions my_value")
        );
        assert!(input_config.extensions_override.is_none());

        let settings = recipe_info
            .session_settings
            .as_ref()
            .expect("recipe settings should be carried over");
        assert_eq!(settings.aster_provider.as_deref(), Some("test_provider"));
        assert_eq!(settings.aster_model.as_deref(), Some("test_model"));
        assert_eq!(settings.temperature, Some(0.7));

        let sub_recipes = recipe_info
            .sub_recipes
            .as_deref()
            .expect("sub_recipes should always be populated");
        assert_eq!(sub_recipes.len(), expected_sub_recipe_count);
        // The declared sub-recipe path is resolved relative to the recipe file.
        let full_sub_recipe_path = recipe_path
            .parent()
            .unwrap()
            .join("existing_sub_recipe.yaml")
            .to_string_lossy()
            .to_string();
        assert_eq!(sub_recipes[0].path, full_sub_recipe_path);
        assert_eq!(sub_recipes[0].name, "existing_sub_recipe");
        assert!(sub_recipes[0].values.is_none());

        let response = recipe_info
            .final_output_response
            .as_ref()
            .expect("response schema should be carried over");
        assert_eq!(
            response.json_schema,
            Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "result": {"type": "string"}
                }
            }))
        );
    }

    #[test]
    fn test_extract_recipe_info_from_cli_basic() {
        let (_temp_dir, recipe_path) = create_recipe();
        let recipe_name = recipe_path.to_str().unwrap().to_string();

        let (input_config, recipe_info) =
            extract_recipe_info_from_cli(recipe_name, test_params(), Vec::new(), false).unwrap();

        assert_base_recipe_fields(&input_config, &recipe_info, &recipe_path, 1);
    }

    #[test]
    fn test_extract_recipe_info_from_cli_with_additional_sub_recipes() {
        let (temp_dir, recipe_path) = create_recipe();

        let sub_recipe1_path = temp_dir.path().join("path/to/sub_recipe1.yaml");
        let sub_recipe2_path = temp_dir.path().join("another/sub_recipe2.yaml");
        for (path, title) in [
            (&sub_recipe1_path, "Sub Recipe 1"),
            (&sub_recipe2_path, "Sub Recipe 2"),
        ] {
            std::fs::create_dir_all(path.parent().unwrap()).unwrap();
            std::fs::write(path, format!("title: {}", title)).unwrap();
        }

        let recipe_name = recipe_path.to_str().unwrap().to_string();
        let additional_sub_recipes = vec![
            sub_recipe1_path.to_string_lossy().to_string(),
            sub_recipe2_path.to_string_lossy().to_string(),
        ];

        let (input_config, recipe_info) = extract_recipe_info_from_cli(
            recipe_name,
            test_params(),
            additional_sub_recipes,
            false,
        )
        .unwrap();

        assert_base_recipe_fields(&input_config, &recipe_info, &recipe_path, 3);

        // CLI-supplied sub-recipes follow the declared one, in argument order.
        let sub_recipes = recipe_info.sub_recipes.as_deref().unwrap();
        for (sub_recipe, path, name) in [
            (&sub_recipes[1], &sub_recipe1_path, "sub_recipe1"),
            (&sub_recipes[2], &sub_recipe2_path, "sub_recipe2"),
        ] {
            assert_eq!(
                sub_recipe.path,
                path.canonicalize().unwrap().to_string_lossy().to_string()
            );
            assert_eq!(sub_recipe.name, name);
            assert!(sub_recipe.values.is_none());
            assert!(sub_recipe.description.is_none());
            assert!(sub_recipe.sequential_when_repeated);
        }
    }

    #[test]
    fn test_extract_recipe_name() {
        assert_eq!(extract_recipe_name("path/to/sub_recipe1.yaml"), "sub_recipe1");
        assert_eq!(extract_recipe_name("sub_recipe1"), "sub_recipe1");
    }

    /// Writes the test recipe and its sub-recipe into a fresh temp dir and
    /// returns the dir guard (keep it alive for the test's duration) together
    /// with the canonical recipe path.
    fn create_recipe() -> (TempDir, PathBuf) {
        let temp_dir = tempfile::tempdir().unwrap();
        let recipe_path = temp_dir.path().join("test_recipe.yaml");
        let sub_recipe_path = temp_dir.path().join("existing_sub_recipe.yaml");

        std::fs::write(&recipe_path, TEST_RECIPE_CONTENT).unwrap();
        std::fs::write(&sub_recipe_path, EXISTING_SUB_RECIPE_CONTENT).unwrap();
        // Canonicalize so the expected sub-recipe paths derived from it match
        // the canonical paths the loader reports (temp dirs may sit behind
        // symlinks).
        let canonical_recipe_path = recipe_path.canonicalize().unwrap();
        (temp_dir, canonical_recipe_path)
    }
}