Skip to main content

composio_sdk/wizard/
generator.rs

//! Wizard instruction generator for AI agents
//!
//! This module builds Markdown wizard instructions for AI agents from
//! Composio Skills content: an overview, critical Tool Router rules,
//! session and authentication patterns, and optional toolkit-specific
//! guidance.

7use super::skills::{Impact, Rule, SkillsExtractor, SkillsError};
8
9/// Generator for wizard instructions
10#[derive(Debug, Clone)]
11pub struct WizardInstructionGenerator {
12    skills: SkillsExtractor,
13}
14
15impl WizardInstructionGenerator {
16    /// Create a new wizard instruction generator
17    ///
18    /// # Arguments
19    ///
20    /// * `skills` - SkillsExtractor instance for accessing Skills content
21    ///
22    /// # Example
23    ///
24    /// ```no_run
25    /// use composio_sdk::wizard::{SkillsExtractor, WizardInstructionGenerator};
26    ///
27    /// let skills = SkillsExtractor::new("vendor/skills/skills/composio");
28    /// let generator = WizardInstructionGenerator::new(skills);
29    /// ```
30    pub fn new(skills: SkillsExtractor) -> Self {
31        Self { skills }
32    }
33
34    /// Generate comprehensive Composio wizard instructions
35    ///
36    /// Generates a complete set of instructions including:
37    /// - Overview from AGENTS.md
38    /// - Critical Tool Router rules
39    /// - Session management patterns
40    /// - Authentication patterns
41    /// - Correct and incorrect examples
42    ///
43    /// # Arguments
44    ///
45    /// * `toolkit` - Optional toolkit name for context-aware instructions
46    ///
47    /// # Returns
48    ///
49    /// A formatted markdown string with wizard instructions
50    ///
51    /// # Example
52    ///
53    /// ```no_run
54    /// use composio_sdk::wizard::{SkillsExtractor, WizardInstructionGenerator};
55    ///
56    /// let skills = SkillsExtractor::new("vendor/skills/skills/composio");
57    /// let generator = WizardInstructionGenerator::new(skills);
58    /// let instructions = generator.generate_composio_instructions(Some("github")).unwrap();
59    /// println!("{}", instructions);
60    /// ```
61    pub fn generate_composio_instructions(&self, toolkit: Option<&str>) -> Result<String, SkillsError> {
62        let mut output = String::new();
63
64        // Add header
65        output.push_str("# Composio Wizard Instructions\n\n");
66
67        if let Some(tk) = toolkit {
68            output.push_str(&format!("**Context:** Using toolkit `{}`\n\n", tk));
69        }
70
71        // Add overview section from AGENTS.md
72        output.push_str(&self.generate_overview_section()?);
73
74        // Add critical Tool Router rules
75        output.push_str(&self.generate_critical_rules_section()?);
76
77        // Add session management patterns
78        output.push_str(&self.generate_session_management_section()?);
79
80        // Add authentication patterns
81        output.push_str(&self.generate_authentication_section()?);
82
83        // Add toolkit-specific guidance if provided
84        if let Some(tk) = toolkit {
85            output.push_str(&self.generate_toolkit_specific_section(tk)?);
86        }
87
88        Ok(output)
89    }
90
91    /// Generate overview section from AGENTS.md
92    fn generate_overview_section(&self) -> Result<String, SkillsError> {
93        let mut section = String::new();
94
95        section.push_str("## Overview\n\n");
96
97        // Verify path before attempting to read
98        self.skills.verify_path()?;
99
100        // Get consolidated content from AGENTS.md
101        match self.skills.get_consolidated_content() {
102            Ok(content) => {
103                // Extract first few paragraphs as overview (limit to ~500 chars)
104                let lines: Vec<&str> = content.lines().collect();
105                let mut overview = String::new();
106                let mut char_count = 0;
107
108                for line in lines.iter().take(50) {
109                    if line.starts_with('#') && !overview.is_empty() {
110                        break; // Stop at first heading after content
111                    }
112                    if !line.trim().is_empty() {
113                        overview.push_str(line);
114                        overview.push('\n');
115                        char_count += line.len();
116
117                        if char_count > 500 {
118                            break;
119                        }
120                    }
121                }
122
123                if !overview.is_empty() {
124                    section.push_str(&overview);
125                    section.push_str("\n\n");
126                } else {
127                    section.push_str("Composio provides a comprehensive platform for connecting AI agents to external services.\n\n");
128                }
129            }
130            Err(_) => {
131                // Fallback if AGENTS.md is not available
132                section.push_str("Composio provides a comprehensive platform for connecting AI agents to external services.\n");
133                section.push_str("Use sessions for user-scoped tool execution, meta tools for discovery, and proper authentication patterns.\n\n");
134            }
135        }
136
137        Ok(section)
138    }
139
140    /// Generate critical rules section
141    fn generate_critical_rules_section(&self) -> Result<String, SkillsError> {
142        let mut section = String::new();
143
144        section.push_str("## Critical Rules\n\n");
145        section.push_str("These rules are **CRITICAL** and must be followed to ensure correct behavior:\n\n");
146
147        // Get all Tool Router rules
148        let rules = self.skills.get_tool_router_rules()?;
149
150        // Filter for critical impact
151        let critical_rules: Vec<&Rule> = rules
152            .iter()
153            .filter(|r| r.impact == Impact::Critical)
154            .collect();
155
156        if critical_rules.is_empty() {
157            section.push_str("*No critical rules found. Ensure Skills repository is properly configured.*\n\n");
158        } else {
159            for (i, rule) in critical_rules.iter().enumerate() {
160                section.push_str(&format!("### {}.{} {}\n\n", i + 1, " ", rule.title));
161                section.push_str(&self.format_rule(rule));
162                section.push('\n');
163            }
164        }
165
166        Ok(section)
167    }
168
169    /// Generate session management patterns section
170    fn generate_session_management_section(&self) -> Result<String, SkillsError> {
171        let mut section = String::new();
172
173        section.push_str("## Session Management Patterns\n\n");
174
175        // Get rules tagged with "sessions"
176        let session_rules = self.skills.get_rules_by_tag("sessions")?;
177
178        if session_rules.is_empty() {
179            // Fallback content
180            section.push_str("**Best Practices:**\n\n");
181            section.push_str("- Always create sessions with a valid user_id\n");
182            section.push_str("- Never use \"default\" as a user_id in production\n");
183            section.push_str("- Sessions are immutable - create new sessions when config changes\n");
184            section.push_str("- Use session.tools() to get meta tools for the agent\n\n");
185        } else {
186            for rule in session_rules.iter() {
187                section.push_str(&format!("### {}\n\n", rule.title));
188                section.push_str(&self.format_rule(rule));
189                section.push('\n');
190            }
191        }
192
193        Ok(section)
194    }
195
196    /// Generate authentication patterns section
197    fn generate_authentication_section(&self) -> Result<String, SkillsError> {
198        let mut section = String::new();
199
200        section.push_str("## Authentication Patterns\n\n");
201
202        // Get rules tagged with "authentication"
203        let auth_rules = self.skills.get_rules_by_tag("authentication")?;
204
205        if auth_rules.is_empty() {
206            // Fallback content
207            section.push_str("**Best Practices:**\n\n");
208            section.push_str("- Use in-chat authentication (manage_connections=true) for dynamic auth\n");
209            section.push_str("- Use manual authentication (session.authorize()) for pre-onboarding\n");
210            section.push_str("- Check connection status before executing tools\n");
211            section.push_str("- Handle OAuth redirects with callback URLs\n\n");
212        } else {
213            for rule in auth_rules.iter() {
214                section.push_str(&format!("### {}\n\n", rule.title));
215                section.push_str(&self.format_rule(rule));
216                section.push('\n');
217            }
218        }
219
220        Ok(section)
221    }
222
223    /// Generate toolkit-specific guidance section
224    fn generate_toolkit_specific_section(&self, toolkit: &str) -> Result<String, SkillsError> {
225        let mut section = String::new();
226
227        section.push_str(&format!("## Toolkit-Specific Guidance: {}\n\n", toolkit));
228
229        // Get rules tagged with the toolkit name
230        let toolkit_rules = self.skills.get_rules_by_tag(toolkit)?;
231
232        if toolkit_rules.is_empty() {
233            section.push_str(&format!(
234                "*No specific rules found for toolkit `{}`. Use general best practices.*\n\n",
235                toolkit
236            ));
237        } else {
238            for rule in toolkit_rules.iter() {
239                section.push_str(&format!("### {}\n\n", rule.title));
240                section.push_str(&self.format_rule(rule));
241                section.push('\n');
242            }
243        }
244
245        Ok(section)
246    }
247
248    /// Format a rule with description and examples
249    ///
250    /// Formats a rule with:
251    /// - Description
252    /// - Correct examples (✅)
253    /// - Incorrect examples (❌)
254    ///
255    /// # Arguments
256    ///
257    /// * `rule` - The rule to format
258    ///
259    /// # Returns
260    ///
261    /// A formatted markdown string
262    fn format_rule(&self, rule: &Rule) -> String {
263        let mut output = String::new();
264
265        // Add description
266        if !rule.description.is_empty() {
267            output.push_str(&format!("**Description:** {}\n\n", rule.description));
268        }
269
270        // Add impact level
271        output.push_str(&format!("**Impact:** {}\n\n", rule.impact.as_str()));
272
273        // Add tags if present
274        if !rule.tags.is_empty() {
275            output.push_str(&format!("**Tags:** {}\n\n", rule.tags.join(", ")));
276        }
277
278        // Add correct examples
279        if !rule.correct_examples.is_empty() {
280            output.push_str("✅ **Correct Examples:**\n\n");
281            for example in &rule.correct_examples {
282                output.push_str("```\n");
283                output.push_str(example);
284                output.push_str("\n```\n\n");
285            }
286        }
287
288        // Add incorrect examples
289        if !rule.incorrect_examples.is_empty() {
290            output.push_str("❌ **Incorrect Examples:**\n\n");
291            for example in &rule.incorrect_examples {
292                output.push_str("```\n");
293                output.push_str(example);
294                output.push_str("\n```\n\n");
295            }
296        }
297
298        output
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a generator rooted at the vendored Skills checkout.
    fn create_test_generator() -> WizardInstructionGenerator {
        let skills = SkillsExtractor::new("vendor/skills/skills/composio");
        WizardInstructionGenerator::new(skills)
    }

    #[test]
    fn test_generator_creation() {
        let generator = create_test_generator();
        // Exercise the derived Clone and Debug impls rather than only checking the value's size.
        let cloned = generator.clone();
        assert!(format!("{:?}", cloned).starts_with("WizardInstructionGenerator"));
    }

    #[test]
    fn test_format_rule() {
        let generator = create_test_generator();

        let rule = Rule {
            title: "Test Rule".to_string(),
            impact: Impact::Critical,
            description: "A test rule for formatting".to_string(),
            tags: vec!["test".to_string(), "example".to_string()],
            content: "Test content".to_string(),
            correct_examples: vec!["let x = 1;".to_string()],
            incorrect_examples: vec!["let x = \"default\";".to_string()],
        };

        let formatted = generator.format_rule(&rule);

        assert!(formatted.contains("**Description:**"));
        assert!(formatted.contains("**Impact:** CRITICAL"));
        assert!(formatted.contains("**Tags:** test, example"));
        assert!(formatted.contains("✅ **Correct Examples:**"));
        assert!(formatted.contains("❌ **Incorrect Examples:**"));
        assert!(formatted.contains("let x = 1;"));
        assert!(formatted.contains("let x = \"default\";"));
    }

    #[test]
    fn test_format_rule_layout() {
        let generator = create_test_generator();

        let rule = Rule {
            title: "Layout".to_string(),
            impact: Impact::Critical,
            description: "D".to_string(),
            tags: vec!["a".to_string()],
            content: String::new(),
            correct_examples: vec!["x".to_string()],
            incorrect_examples: Vec::new(),
        };

        assert_eq!(
            generator.format_rule(&rule),
            "**Description:** D\n\n**Impact:** CRITICAL\n\n**Tags:** a\n\n✅ **Correct Examples:**\n\n```\nx\n```\n\n"
        );
    }

    #[test]
    fn test_format_rule_minimal() {
        let generator = create_test_generator();

        let rule = Rule {
            title: "Minimal Rule".to_string(),
            impact: Impact::Low,
            description: String::new(),
            tags: Vec::new(),
            content: String::new(),
            correct_examples: Vec::new(),
            incorrect_examples: Vec::new(),
        };

        let formatted = generator.format_rule(&rule);

        assert!(formatted.contains("**Impact:** LOW"));
        assert!(!formatted.contains("**Description:**"));
        assert!(!formatted.contains("**Tags:**"));
        assert!(!formatted.contains("✅"));
        assert!(!formatted.contains("❌"));
    }

    #[test]
    #[ignore] // Requires Skills repository to be present
    fn test_generate_composio_instructions() {
        let generator = create_test_generator();

        // With the repository present, generation must succeed; an error is a test failure.
        let content = generator
            .generate_composio_instructions(None)
            .expect("instruction generation should succeed when the Skills repository is present");

        assert!(content.contains("# Composio Wizard Instructions"));
        assert!(content.contains("## Overview"));
        assert!(content.contains("## Critical Rules"));
        assert!(content.contains("## Session Management Patterns"));
        assert!(content.contains("## Authentication Patterns"));
    }

    #[test]
    #[ignore] // Requires Skills repository to be present
    fn test_generate_with_toolkit() {
        let generator = create_test_generator();

        let content = generator
            .generate_composio_instructions(Some("github"))
            .expect("instruction generation should succeed when the Skills repository is present");

        assert!(content.contains("**Context:** Using toolkit `github`"));
        assert!(content.contains("## Toolkit-Specific Guidance: github"));
    }

    #[test]
    fn test_generate_overview_section_fallback() {
        let generator = create_test_generator();

        // Falls back to generic text if AGENTS.md cannot be read.
        let content = generator
            .generate_overview_section()
            .expect("overview section should be generated");

        assert!(content.contains("## Overview"));
        assert!(content.contains("Composio"));
    }

    #[test]
    fn test_generate_critical_rules_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_critical_rules_section()
            .expect("critical rules section should be generated");

        assert!(content.contains("## Critical Rules"));
    }

    #[test]
    fn test_generate_session_management_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_session_management_section()
            .expect("session management section should be generated");

        assert!(content.contains("## Session Management Patterns"));
    }

    #[test]
    fn test_generate_authentication_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_authentication_section()
            .expect("authentication section should be generated");

        assert!(content.contains("## Authentication Patterns"));
    }

    #[test]
    fn test_generate_toolkit_specific_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_toolkit_specific_section("github")
            .expect("toolkit-specific section should be generated");

        assert!(content.contains("## Toolkit-Specific Guidance: github"));
    }
}