composio_sdk/wizard/
generator.rs1use super::skills::{Impact, Rule, SkillsExtractor, SkillsError};
8
9#[derive(Debug, Clone)]
11pub struct WizardInstructionGenerator {
12 skills: SkillsExtractor,
13}
14
15impl WizardInstructionGenerator {
16 pub fn new(skills: SkillsExtractor) -> Self {
32 Self { skills }
33 }
34
35 pub fn generate_composio_instructions(&self, toolkit: Option<&str>) -> Result<String, SkillsError> {
64 let mut output = String::new();
65
66 output.push_str("# Composio Wizard Instructions\n\n");
68
69 if let Some(tk) = toolkit {
70 output.push_str(&format!("**Context:** Using toolkit `{}`\n\n", tk));
71 }
72
73 output.push_str(&self.generate_overview_section()?);
75
76 output.push_str(&self.generate_critical_rules_section()?);
78
79 output.push_str(&self.generate_session_management_section()?);
81
82 output.push_str(&self.generate_authentication_section()?);
84
85 if let Some(tk) = toolkit {
87 output.push_str(&self.generate_toolkit_specific_section(tk)?);
88 }
89
90 Ok(output)
91 }
92
93 fn generate_overview_section(&self) -> Result<String, SkillsError> {
95 let mut section = String::new();
96
97 section.push_str("## Overview\n\n");
98
99 self.skills.verify_path()?;
101
102 match self.skills.get_consolidated_content() {
104 Ok(content) => {
105 let lines: Vec<&str> = content.lines().collect();
107 let mut overview = String::new();
108 let mut char_count = 0;
109
110 for line in lines.iter().take(50) {
111 if line.starts_with('#') && !overview.is_empty() {
112 break; }
114 if !line.trim().is_empty() {
115 overview.push_str(line);
116 overview.push('\n');
117 char_count += line.len();
118
119 if char_count > 500 {
120 break;
121 }
122 }
123 }
124
125 if !overview.is_empty() {
126 section.push_str(&overview);
127 section.push_str("\n\n");
128 } else {
129 section.push_str("Composio provides a comprehensive platform for connecting AI agents to external services.\n\n");
130 }
131 }
132 Err(_) => {
133 section.push_str("Composio provides a comprehensive platform for connecting AI agents to external services.\n");
135 section.push_str("Use sessions for user-scoped tool execution, meta tools for discovery, and proper authentication patterns.\n\n");
136 }
137 }
138
139 Ok(section)
140 }
141
142 fn generate_critical_rules_section(&self) -> Result<String, SkillsError> {
144 let mut section = String::new();
145
146 section.push_str("## Critical Rules\n\n");
147 section.push_str("These rules are **CRITICAL** and must be followed to ensure correct behavior:\n\n");
148
149 let rules = self.skills.get_tool_router_rules()?;
151
152 let critical_rules: Vec<&Rule> = rules
154 .iter()
155 .filter(|r| r.impact == Impact::Critical)
156 .collect();
157
158 if critical_rules.is_empty() {
159 section.push_str("*No critical rules found. Ensure Skills repository is properly configured.*\n\n");
160 } else {
161 for (i, rule) in critical_rules.iter().enumerate() {
162 section.push_str(&format!("### {}.{} {}\n\n", i + 1, " ", rule.title));
163 section.push_str(&self.format_rule(rule));
164 section.push('\n');
165 }
166 }
167
168 Ok(section)
169 }
170
171 fn generate_session_management_section(&self) -> Result<String, SkillsError> {
173 let mut section = String::new();
174
175 section.push_str("## Session Management Patterns\n\n");
176
177 let session_rules = self.skills.get_rules_by_tag("sessions")?;
179
180 if session_rules.is_empty() {
181 section.push_str("**Best Practices:**\n\n");
183 section.push_str("- Always create sessions with a valid user_id\n");
184 section.push_str("- Never use \"default\" as a user_id in production\n");
185 section.push_str("- Sessions are immutable - create new sessions when config changes\n");
186 section.push_str("- Use session.tools() to get meta tools for the agent\n\n");
187 } else {
188 for rule in session_rules.iter() {
189 section.push_str(&format!("### {}\n\n", rule.title));
190 section.push_str(&self.format_rule(rule));
191 section.push('\n');
192 }
193 }
194
195 Ok(section)
196 }
197
198 fn generate_authentication_section(&self) -> Result<String, SkillsError> {
200 let mut section = String::new();
201
202 section.push_str("## Authentication Patterns\n\n");
203
204 let auth_rules = self.skills.get_rules_by_tag("authentication")?;
206
207 if auth_rules.is_empty() {
208 section.push_str("**Best Practices:**\n\n");
210 section.push_str("- Use in-chat authentication (manage_connections=true) for dynamic auth\n");
211 section.push_str("- Use manual authentication (session.authorize()) for pre-onboarding\n");
212 section.push_str("- Check connection status before executing tools\n");
213 section.push_str("- Handle OAuth redirects with callback URLs\n\n");
214 } else {
215 for rule in auth_rules.iter() {
216 section.push_str(&format!("### {}\n\n", rule.title));
217 section.push_str(&self.format_rule(rule));
218 section.push('\n');
219 }
220 }
221
222 Ok(section)
223 }
224
225 fn generate_toolkit_specific_section(&self, toolkit: &str) -> Result<String, SkillsError> {
227 let mut section = String::new();
228
229 section.push_str(&format!("## Toolkit-Specific Guidance: {}\n\n", toolkit));
230
231 let toolkit_rules = self.skills.get_rules_by_tag(toolkit)?;
233
234 if toolkit_rules.is_empty() {
235 section.push_str(&format!(
236 "*No specific rules found for toolkit `{}`. Use general best practices.*\n\n",
237 toolkit
238 ));
239 } else {
240 for rule in toolkit_rules.iter() {
241 section.push_str(&format!("### {}\n\n", rule.title));
242 section.push_str(&self.format_rule(rule));
243 section.push('\n');
244 }
245 }
246
247 Ok(section)
248 }
249
250 fn format_rule(&self, rule: &Rule) -> String {
265 let mut output = String::new();
266
267 if !rule.description.is_empty() {
269 output.push_str(&format!("**Description:** {}\n\n", rule.description));
270 }
271
272 output.push_str(&format!("**Impact:** {}\n\n", rule.impact.as_str()));
274
275 if !rule.tags.is_empty() {
277 output.push_str(&format!("**Tags:** {}\n\n", rule.tags.join(", ")));
278 }
279
280 if !rule.correct_examples.is_empty() {
282 output.push_str("✅ **Correct Examples:**\n\n");
283 for example in &rule.correct_examples {
284 output.push_str("```\n");
285 output.push_str(example);
286 output.push_str("\n```\n\n");
287 }
288 }
289
290 if !rule.incorrect_examples.is_empty() {
292 output.push_str("❌ **Incorrect Examples:**\n\n");
293 for example in &rule.incorrect_examples {
294 output.push_str("```\n");
295 output.push_str(example);
296 output.push_str("\n```\n\n");
297 }
298 }
299
300 output
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a generator backed by the crate's `skills` directory.
    fn create_test_generator() -> WizardInstructionGenerator {
        let skills_path = concat!(env!("CARGO_MANIFEST_DIR"), "/skills");
        let skills = SkillsExtractor::new(skills_path);
        WizardInstructionGenerator::new(skills)
    }

    #[test]
    fn test_generator_creation() {
        let generator = create_test_generator();
        // Exercise the derived `Clone` and `Debug` impls.
        let cloned = generator.clone();
        assert!(format!("{:?}", cloned).starts_with("WizardInstructionGenerator"));
    }

    #[test]
    fn test_format_rule() {
        let generator = create_test_generator();

        let rule = Rule {
            title: "Test Rule".to_string(),
            impact: Impact::Critical,
            description: "A test rule for formatting".to_string(),
            tags: vec!["test".to_string(), "example".to_string()],
            content: "Test content".to_string(),
            correct_examples: vec!["let x = 1;".to_string()],
            incorrect_examples: vec!["let x = \"default\";".to_string()],
        };

        let formatted = generator.format_rule(&rule);

        assert!(formatted.contains("**Description:** A test rule for formatting"));
        assert!(formatted.contains("**Impact:** CRITICAL"));
        assert!(formatted.contains("**Tags:** test, example"));
        assert!(formatted.contains("✅ **Correct Examples:**"));
        assert!(formatted.contains("❌ **Incorrect Examples:**"));
        // Examples must be wrapped in fenced code blocks.
        assert!(formatted.contains("```\nlet x = 1;\n```"));
        assert!(formatted.contains("```\nlet x = \"default\";\n```"));
    }

    #[test]
    fn test_format_rule_minimal() {
        let generator = create_test_generator();

        let rule = Rule {
            title: "Minimal Rule".to_string(),
            impact: Impact::Low,
            description: String::new(),
            tags: Vec::new(),
            content: String::new(),
            correct_examples: Vec::new(),
            incorrect_examples: Vec::new(),
        };

        let formatted = generator.format_rule(&rule);

        assert!(formatted.contains("**Impact:** LOW"));
        assert!(!formatted.contains("**Description:**"));
        assert!(!formatted.contains("**Tags:**"));
        assert!(!formatted.contains("✅"));
        assert!(!formatted.contains("❌"));
        assert!(!formatted.contains("```"));
    }

    #[test]
    #[ignore = "requires a populated Skills repository"]
    fn test_generate_composio_instructions() {
        let generator = create_test_generator();

        // Fail loudly instead of silently passing when generation errors.
        let content = generator
            .generate_composio_instructions(None)
            .expect("instruction generation should succeed");

        assert!(content.contains("# Composio Wizard Instructions"));
        assert!(content.contains("## Overview"));
        assert!(content.contains("## Critical Rules"));
        assert!(content.contains("## Session Management Patterns"));
        assert!(content.contains("## Authentication Patterns"));
        assert!(!content.contains("**Context:**"));
        assert!(!content.contains("## Toolkit-Specific Guidance"));
    }

    #[test]
    #[ignore = "requires a populated Skills repository"]
    fn test_generate_with_toolkit() {
        let generator = create_test_generator();

        let content = generator
            .generate_composio_instructions(Some("github"))
            .expect("instruction generation should succeed");

        assert!(content.contains("**Context:** Using toolkit `github`"));
        assert!(content.contains("## Toolkit-Specific Guidance: github"));
    }

    #[test]
    fn test_generate_overview_section_fallback() {
        let generator = create_test_generator();

        let content = generator
            .generate_overview_section()
            .expect("overview section should be generated");

        assert!(content.contains("## Overview"));
        assert!(content.contains("Composio"));
    }

    #[test]
    fn test_generate_critical_rules_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_critical_rules_section()
            .expect("critical rules section should be generated");

        assert!(content.contains("## Critical Rules"));
    }

    #[test]
    fn test_generate_session_management_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_session_management_section()
            .expect("session management section should be generated");

        assert!(content.contains("## Session Management Patterns"));
    }

    #[test]
    fn test_generate_authentication_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_authentication_section()
            .expect("authentication section should be generated");

        assert!(content.contains("## Authentication Patterns"));
    }

    #[test]
    fn test_generate_toolkit_specific_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_toolkit_specific_section("github")
            .expect("toolkit section should be generated");

        assert!(content.contains("## Toolkit-Specific Guidance: github"));
    }
}
453}