Skip to main content

composio_sdk/wizard/
generator.rs

1//! Wizard instruction generator for AI agents
2//!
3//! This module generates comprehensive wizard instructions for AI agents
4//! using Composio Skills content. Instructions include best practices,
5//! critical rules, and context-aware guidance.
6
7use super::skills::{Impact, Rule, SkillsExtractor, SkillsError};
8
9/// Generator for wizard instructions
10#[derive(Debug, Clone)]
11pub struct WizardInstructionGenerator {
12    skills: SkillsExtractor,
13}
14
15impl WizardInstructionGenerator {
16    /// Create a new wizard instruction generator
17    ///
18    /// # Arguments
19    ///
20    /// * `skills` - SkillsExtractor instance for accessing Skills content
21    ///
22    /// # Example
23    ///
24    /// ```no_run
25    /// use composio_sdk::wizard::{SkillsExtractor, WizardInstructionGenerator};
26    ///
27    /// let skills_path = concat!(env!("CARGO_MANIFEST_DIR"), "/skills");
28    /// let skills = SkillsExtractor::new(skills_path);
29    /// let generator = WizardInstructionGenerator::new(skills);
30    /// ```
31    pub fn new(skills: SkillsExtractor) -> Self {
32        Self { skills }
33    }
34
35    /// Generate comprehensive Composio wizard instructions
36    ///
37    /// Generates a complete set of instructions including:
38    /// - Overview from AGENTS.md
39    /// - Critical Tool Router rules
40    /// - Session management patterns
41    /// - Authentication patterns
42    /// - Correct and incorrect examples
43    ///
44    /// # Arguments
45    ///
46    /// * `toolkit` - Optional toolkit name for context-aware instructions
47    ///
48    /// # Returns
49    ///
50    /// A formatted markdown string with wizard instructions
51    ///
52    /// # Example
53    ///
54    /// ```no_run
55    /// use composio_sdk::wizard::{SkillsExtractor, WizardInstructionGenerator};
56    ///
57    /// let skills_path = concat!(env!("CARGO_MANIFEST_DIR"), "/skills");
58    /// let skills = SkillsExtractor::new(skills_path);
59    /// let generator = WizardInstructionGenerator::new(skills);
60    /// let instructions = generator.generate_composio_instructions(Some("github")).unwrap();
61    /// println!("{}", instructions);
62    /// ```
63    pub fn generate_composio_instructions(&self, toolkit: Option<&str>) -> Result<String, SkillsError> {
64        let mut output = String::new();
65
66        // Add header
67        output.push_str("# Composio Wizard Instructions\n\n");
68
69        if let Some(tk) = toolkit {
70            output.push_str(&format!("**Context:** Using toolkit `{}`\n\n", tk));
71        }
72
73        // Add overview section from AGENTS.md
74        output.push_str(&self.generate_overview_section()?);
75
76        // Add critical Tool Router rules
77        output.push_str(&self.generate_critical_rules_section()?);
78
79        // Add session management patterns
80        output.push_str(&self.generate_session_management_section()?);
81
82        // Add authentication patterns
83        output.push_str(&self.generate_authentication_section()?);
84
85        // Add toolkit-specific guidance if provided
86        if let Some(tk) = toolkit {
87            output.push_str(&self.generate_toolkit_specific_section(tk)?);
88        }
89
90        Ok(output)
91    }
92
93    /// Generate overview section from AGENTS.md
94    fn generate_overview_section(&self) -> Result<String, SkillsError> {
95        let mut section = String::new();
96
97        section.push_str("## Overview\n\n");
98
99        // Verify path before attempting to read
100        self.skills.verify_path()?;
101
102        // Get consolidated content from AGENTS.md
103        match self.skills.get_consolidated_content() {
104            Ok(content) => {
105                // Extract first few paragraphs as overview (limit to ~500 chars)
106                let lines: Vec<&str> = content.lines().collect();
107                let mut overview = String::new();
108                let mut char_count = 0;
109
110                for line in lines.iter().take(50) {
111                    if line.starts_with('#') && !overview.is_empty() {
112                        break; // Stop at first heading after content
113                    }
114                    if !line.trim().is_empty() {
115                        overview.push_str(line);
116                        overview.push('\n');
117                        char_count += line.len();
118
119                        if char_count > 500 {
120                            break;
121                        }
122                    }
123                }
124
125                if !overview.is_empty() {
126                    section.push_str(&overview);
127                    section.push_str("\n\n");
128                } else {
129                    section.push_str("Composio provides a comprehensive platform for connecting AI agents to external services.\n\n");
130                }
131            }
132            Err(_) => {
133                // Fallback if AGENTS.md is not available
134                section.push_str("Composio provides a comprehensive platform for connecting AI agents to external services.\n");
135                section.push_str("Use sessions for user-scoped tool execution, meta tools for discovery, and proper authentication patterns.\n\n");
136            }
137        }
138
139        Ok(section)
140    }
141
142    /// Generate critical rules section
143    fn generate_critical_rules_section(&self) -> Result<String, SkillsError> {
144        let mut section = String::new();
145
146        section.push_str("## Critical Rules\n\n");
147        section.push_str("These rules are **CRITICAL** and must be followed to ensure correct behavior:\n\n");
148
149        // Get all Tool Router rules
150        let rules = self.skills.get_tool_router_rules()?;
151
152        // Filter for critical impact
153        let critical_rules: Vec<&Rule> = rules
154            .iter()
155            .filter(|r| r.impact == Impact::Critical)
156            .collect();
157
158        if critical_rules.is_empty() {
159            section.push_str("*No critical rules found. Ensure Skills repository is properly configured.*\n\n");
160        } else {
161            for (i, rule) in critical_rules.iter().enumerate() {
162                section.push_str(&format!("### {}.{} {}\n\n", i + 1, " ", rule.title));
163                section.push_str(&self.format_rule(rule));
164                section.push('\n');
165            }
166        }
167
168        Ok(section)
169    }
170
171    /// Generate session management patterns section
172    fn generate_session_management_section(&self) -> Result<String, SkillsError> {
173        let mut section = String::new();
174
175        section.push_str("## Session Management Patterns\n\n");
176
177        // Get rules tagged with "sessions"
178        let session_rules = self.skills.get_rules_by_tag("sessions")?;
179
180        if session_rules.is_empty() {
181            // Fallback content
182            section.push_str("**Best Practices:**\n\n");
183            section.push_str("- Always create sessions with a valid user_id\n");
184            section.push_str("- Never use \"default\" as a user_id in production\n");
185            section.push_str("- Sessions are immutable - create new sessions when config changes\n");
186            section.push_str("- Use session.tools() to get meta tools for the agent\n\n");
187        } else {
188            for rule in session_rules.iter() {
189                section.push_str(&format!("### {}\n\n", rule.title));
190                section.push_str(&self.format_rule(rule));
191                section.push('\n');
192            }
193        }
194
195        Ok(section)
196    }
197
198    /// Generate authentication patterns section
199    fn generate_authentication_section(&self) -> Result<String, SkillsError> {
200        let mut section = String::new();
201
202        section.push_str("## Authentication Patterns\n\n");
203
204        // Get rules tagged with "authentication"
205        let auth_rules = self.skills.get_rules_by_tag("authentication")?;
206
207        if auth_rules.is_empty() {
208            // Fallback content
209            section.push_str("**Best Practices:**\n\n");
210            section.push_str("- Use in-chat authentication (manage_connections=true) for dynamic auth\n");
211            section.push_str("- Use manual authentication (session.authorize()) for pre-onboarding\n");
212            section.push_str("- Check connection status before executing tools\n");
213            section.push_str("- Handle OAuth redirects with callback URLs\n\n");
214        } else {
215            for rule in auth_rules.iter() {
216                section.push_str(&format!("### {}\n\n", rule.title));
217                section.push_str(&self.format_rule(rule));
218                section.push('\n');
219            }
220        }
221
222        Ok(section)
223    }
224
225    /// Generate toolkit-specific guidance section
226    fn generate_toolkit_specific_section(&self, toolkit: &str) -> Result<String, SkillsError> {
227        let mut section = String::new();
228
229        section.push_str(&format!("## Toolkit-Specific Guidance: {}\n\n", toolkit));
230
231        // Get rules tagged with the toolkit name
232        let toolkit_rules = self.skills.get_rules_by_tag(toolkit)?;
233
234        if toolkit_rules.is_empty() {
235            section.push_str(&format!(
236                "*No specific rules found for toolkit `{}`. Use general best practices.*\n\n",
237                toolkit
238            ));
239        } else {
240            for rule in toolkit_rules.iter() {
241                section.push_str(&format!("### {}\n\n", rule.title));
242                section.push_str(&self.format_rule(rule));
243                section.push('\n');
244            }
245        }
246
247        Ok(section)
248    }
249
250    /// Format a rule with description and examples
251    ///
252    /// Formats a rule with:
253    /// - Description
254    /// - Correct examples (✅)
255    /// - Incorrect examples (❌)
256    ///
257    /// # Arguments
258    ///
259    /// * `rule` - The rule to format
260    ///
261    /// # Returns
262    ///
263    /// A formatted markdown string
264    fn format_rule(&self, rule: &Rule) -> String {
265        let mut output = String::new();
266
267        // Add description
268        if !rule.description.is_empty() {
269            output.push_str(&format!("**Description:** {}\n\n", rule.description));
270        }
271
272        // Add impact level
273        output.push_str(&format!("**Impact:** {}\n\n", rule.impact.as_str()));
274
275        // Add tags if present
276        if !rule.tags.is_empty() {
277            output.push_str(&format!("**Tags:** {}\n\n", rule.tags.join(", ")));
278        }
279
280        // Add correct examples
281        if !rule.correct_examples.is_empty() {
282            output.push_str("✅ **Correct Examples:**\n\n");
283            for example in &rule.correct_examples {
284                output.push_str("```\n");
285                output.push_str(example);
286                output.push_str("\n```\n\n");
287            }
288        }
289
290        // Add incorrect examples
291        if !rule.incorrect_examples.is_empty() {
292            output.push_str("❌ **Incorrect Examples:**\n\n");
293            for example in &rule.incorrect_examples {
294                output.push_str("```\n");
295                output.push_str(example);
296                output.push_str("\n```\n\n");
297            }
298        }
299
300        output
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a generator pointed at the crate's bundled `skills` directory.
    fn create_test_generator() -> WizardInstructionGenerator {
        let skills_path = concat!(env!("CARGO_MANIFEST_DIR"), "/skills");
        let skills = SkillsExtractor::new(skills_path);
        WizardInstructionGenerator::new(skills)
    }

    #[test]
    fn test_generator_creation() {
        let generator = create_test_generator();
        // Exercise the derived Clone and Debug impls on the public type.
        let cloned = generator.clone();
        assert!(format!("{:?}", cloned).starts_with("WizardInstructionGenerator"));
    }

    #[test]
    fn test_format_rule() {
        let generator = create_test_generator();

        let rule = Rule {
            title: "Test Rule".to_string(),
            impact: Impact::Critical,
            description: "A test rule for formatting".to_string(),
            tags: vec!["test".to_string(), "example".to_string()],
            content: "Test content".to_string(),
            correct_examples: vec!["let x = 1;".to_string()],
            incorrect_examples: vec!["let x = \"default\";".to_string()],
        };

        let formatted = generator.format_rule(&rule);

        assert!(formatted.contains("**Description:**"));
        assert!(formatted.contains("**Impact:** CRITICAL"));
        assert!(formatted.contains("**Tags:** test, example"));
        assert!(formatted.contains("✅ **Correct Examples:**"));
        assert!(formatted.contains("❌ **Incorrect Examples:**"));
        // Examples must be wrapped in fenced code blocks.
        assert!(formatted.contains("```\nlet x = 1;\n```"));
        assert!(formatted.contains("```\nlet x = \"default\";\n```"));
    }

    #[test]
    fn test_format_rule_minimal() {
        let generator = create_test_generator();

        let rule = Rule {
            title: "Minimal Rule".to_string(),
            impact: Impact::Low,
            description: String::new(),
            tags: Vec::new(),
            content: String::new(),
            correct_examples: Vec::new(),
            incorrect_examples: Vec::new(),
        };

        let formatted = generator.format_rule(&rule);

        assert!(formatted.contains("**Impact:** LOW"));
        assert!(!formatted.contains("**Description:**"));
        assert!(!formatted.contains("**Tags:**"));
        assert!(!formatted.contains("✅"));
        assert!(!formatted.contains("❌"));
    }

    #[test]
    #[ignore] // Requires Skills repository to be present
    fn test_generate_composio_instructions() {
        let generator = create_test_generator();

        // Fail loudly instead of silently passing when generation errors.
        let content = generator
            .generate_composio_instructions(None)
            .expect("instruction generation should succeed with a Skills repository");

        assert!(content.contains("# Composio Wizard Instructions"));
        assert!(content.contains("## Overview"));
        assert!(content.contains("## Critical Rules"));
        assert!(content.contains("## Session Management Patterns"));
        assert!(content.contains("## Authentication Patterns"));
        assert!(!content.contains("## Toolkit-Specific Guidance"));
    }

    #[test]
    #[ignore] // Requires Skills repository to be present
    fn test_generate_with_toolkit() {
        let generator = create_test_generator();

        let content = generator
            .generate_composio_instructions(Some("github"))
            .expect("instruction generation should succeed with a Skills repository");

        assert!(content.contains("**Context:** Using toolkit `github`"));
        assert!(content.contains("## Toolkit-Specific Guidance: github"));
    }

    #[test]
    fn test_generate_overview_section_fallback() {
        let generator = create_test_generator();

        // This should use fallback content if AGENTS.md is not available
        let content = generator
            .generate_overview_section()
            .expect("overview section should be generated");

        assert!(content.contains("## Overview"));
        assert!(content.contains("Composio"));
    }

    #[test]
    fn test_generate_critical_rules_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_critical_rules_section()
            .expect("critical rules section should be generated");

        assert!(content.contains("## Critical Rules"));
    }

    #[test]
    fn test_generate_session_management_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_session_management_section()
            .expect("session management section should be generated");

        assert!(content.contains("## Session Management Patterns"));
    }

    #[test]
    fn test_generate_authentication_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_authentication_section()
            .expect("authentication section should be generated");

        assert!(content.contains("## Authentication Patterns"));
    }

    #[test]
    fn test_generate_toolkit_specific_section() {
        let generator = create_test_generator();

        let content = generator
            .generate_toolkit_specific_section("github")
            .expect("toolkit section should be generated");

        assert!(content.contains("## Toolkit-Specific Guidance: github"));
    }
}