raxit-core 0.1.2

Core security scanning engine for AI agent applications
Documentation
//! AutoGen framework extractor
//!
//! Extracts agents and tools from Microsoft's AutoGen framework.
//!
//! AutoGen uses class-based agents (AssistantAgent, UserProxyAgent, ConversableAgent)
//! and @register_function decorators for tools.

use super::ExtractedAssets;
use crate::error::Result;
use crate::schema::{Agent, SourceLocation, Tool};
use std::path::Path;

/// Extract assets from an AutoGen file
///
/// Parses `path` with `crate::ast::parse_python_file` and walks the returned nodes once:
///
/// * `assignment` nodes whose text contains an agent constructor call
///   (`AssistantAgent(`, `UserProxyAgent(`, `ConversableAgent(`) become agents;
/// * `decorated_definition` nodes whose text contains `@register_function` become tools.
///
/// Nodes for which the per-node extractor returns `None` are skipped silently.
///
/// # Errors
///
/// Returns the error produced by `crate::ast::parse_python_file` when parsing fails.
pub fn extract(path: &Path) -> Result<ExtractedAssets> {
    // The parser receives the path itself, so the file is not read separately here;
    // the previous extra `read_to_string` result was never used.
    let nodes = crate::ast::parse_python_file(path)?;

    let mut assets = ExtractedAssets::default();

    // Agents and tools land in separate collections, so a single pass yields the same
    // per-collection ordering as two dedicated passes would.
    for node in &nodes {
        if node.kind == "assignment" && contains_agent_pattern(&node.text) {
            if let Some(agent) = extract_agent_from_assignment(node, path) {
                assets.agents.push(agent);
            }
        } else if node.kind == "decorated_definition" && node.text.contains("@register_function")
        {
            if let Some(tool) = extract_tool_from_decorated_function(node, path) {
                assets.tools.push(tool);
            }
        }
    }

    Ok(assets)
}

/// Builds an [`Agent`] from an assignment node whose text instantiates an AutoGen agent.
///
/// The agent class is the one whose constructor call (`Class(`) appears earliest in the
/// node text. This keeps a class name that is merely mentioned inside an argument (for
/// example in a `system_message` string) from overriding the class actually being
/// constructed. If no constructor call is present, the earliest bare class name is used.
///
/// Returns `None` when no assignment target can be extracted or when none of the known
/// class names (`AssistantAgent`, `UserProxyAgent`, `ConversableAgent`) occurs in the text.
fn extract_agent_from_assignment(node: &crate::ast::AstNode, path: &Path) -> Option<Agent> {
    const AGENT_CLASSES: [&str; 3] = ["AssistantAgent", "UserProxyAgent", "ConversableAgent"];

    // Extract variable name (agent variable)
    let var_name = extract_variable_name(&node.text)?;

    // Earliest occurrence of any known class name followed by `suffix`.
    let earliest = |suffix: &str| {
        AGENT_CLASSES
            .iter()
            .copied()
            .filter_map(|class| {
                node.text
                    .find(&format!("{class}{suffix}"))
                    .map(|pos| (pos, class))
            })
            .min_by_key(|&(pos, _)| pos)
            .map(|(_, class)| class)
    };

    // Prefer real constructor calls; fall back to a bare mention of the class name.
    let agent_type = earliest("(").or_else(|| earliest(""))?;

    let id = format!("autogen_agent_{}", node.start_line);

    // An explicit `name=` argument wins; otherwise the assignment target is used.
    let name = extract_parameter_value(&node.text, "name").unwrap_or(var_name);

    // Extract system_message as system prompt
    let system_prompt = extract_parameter_value(&node.text, "system_message");

    // Extract model from llm_config
    let model_id = extract_model_from_llm_config(&node.text);

    // Extract tools list
    let tool_ids = extract_tools_list(&node.text);

    Some(Agent {
        id,
        name,
        location: SourceLocation {
            file: path.to_string_lossy().to_string(),
            line: node.start_line,
            end_line: Some(node.end_line),
            function: None,
        },
        model_id,
        tool_ids,
        memory_id: None,
        system_prompt,
        result_type: Some(format!("AutoGen::{agent_type}")),
        deps_type: None,
    })
}

/// Builds a [`Tool`] from a decorated function definition node.
///
/// The tool name is the function name found in the node text; the description comes from
/// the first docstring found and the parameters from the function signature. Returns
/// `None` when no function name can be extracted.
fn extract_tool_from_decorated_function(node: &crate::ast::AstNode, path: &Path) -> Option<Tool> {
    let source = &node.text;

    // Without a function name there is nothing meaningful to report.
    let name = extract_function_name(source)?;

    let location = SourceLocation {
        file: path.to_string_lossy().into_owned(),
        line: node.start_line,
        end_line: Some(node.end_line),
        function: None,
    };

    let tool = Tool {
        // The id is derived from the line on which the node starts.
        id: format!("autogen_tool_{}", node.start_line),
        name,
        location,
        description: extract_docstring(source),
        parameters: extract_function_parameters(source),
        requires_context: false,
        tool_type: String::from("autogen_function"),
        data_flows: Vec::new(),
    };

    Some(tool)
}

/// Reports whether `text` contains a constructor call for one of the recognised
/// AutoGen agent classes (`AssistantAgent(`, `UserProxyAgent(`, `ConversableAgent(`).
fn contains_agent_pattern(text: &str) -> bool {
    const CONSTRUCTOR_CALLS: [&str; 3] =
        ["AssistantAgent(", "UserProxyAgent(", "ConversableAgent("];

    CONSTRUCTOR_CALLS.iter().any(|call| text.contains(*call))
}

/// Extracts the assignment target from `target = Class(...)`.
///
/// The assignment operator is the first `=` that is not nested inside brackets, so both
/// `x = Agent()` and `x=Agent()` are handled and keyword arguments such as
/// `Agent(name="a")` are never mistaken for the assignment. For annotated assignments
/// (`x: Agent = Agent()`) the annotation is dropped and `x` is returned.
///
/// When the target consists of several whitespace-separated words, the last one is
/// returned. Returns `None` if there is no top-level `=` or the target is empty.
fn extract_variable_name(text: &str) -> Option<String> {
    let mut depth = 0usize;
    let mut annotation_start: Option<usize> = None;
    let mut assign_idx: Option<usize> = None;

    for (idx, ch) in text.char_indices() {
        match ch {
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' => depth = depth.saturating_sub(1),
            // A top-level colon before the `=` starts a type annotation; colons inside
            // subscripts (`a[1:2] = ...`) are nested and therefore ignored.
            ':' if depth == 0 && annotation_start.is_none() => annotation_start = Some(idx),
            '=' if depth == 0 => {
                assign_idx = Some(idx);
                break;
            }
            _ => {}
        }
    }

    let assign_idx = assign_idx?;
    let target_end = annotation_start.unwrap_or(assign_idx);

    // Take the last word (variable name)
    text[..target_end]
        .split_whitespace()
        .last()
        .map(str::to_string)
}

/// Extracts the string-literal value of the keyword argument `param_name` from `text`.
///
/// Matching rules:
///
/// * `param_name` must be a whole identifier: `display_name=` does not match `name`.
/// * Whitespace around `=` is allowed (`name = "x"`); `==` comparisons are ignored.
/// * Only the first occurrence that qualifies as a keyword argument is considered. If its
///   value is not a plain string literal (a variable, an f-string, ...), `None` is
///   returned.
///
/// Supported literals are `"..."`, `'...'` and their triple-quoted forms. Triple-quoted
/// values are trimmed. Backslash-escaped quotes do not terminate a single-quoted-style
/// literal; escape sequences are returned exactly as written in the source (they are not
/// decoded). Unterminated literals yield `None`.
fn extract_parameter_value(text: &str, param_name: &str) -> Option<String> {
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';

    // Locate the text following `param_name =` for the first whole-identifier match.
    let value = text.match_indices(param_name).find_map(|(start, matched)| {
        let preceding = text[..start].chars().next_back();
        if matches!(preceding, Some(c) if is_ident_char(c)) {
            // Part of a longer identifier such as `display_name`.
            return None;
        }
        let after_name = text[start + matched.len()..].trim_start();
        let after_eq = after_name.strip_prefix('=')?;
        if after_eq.starts_with('=') {
            // `name == ...` is a comparison, not a keyword argument.
            return None;
        }
        Some(after_eq.trim_start())
    })?;

    // Check for string literal
    let quote = value.chars().next().filter(|c| *c == '"' || *c == '\'')?;

    // Triple-quoted (possibly multi-line) string
    let triple = quote.to_string().repeat(3);
    if let Some(body) = value.strip_prefix(triple.as_str()) {
        return body
            .find(triple.as_str())
            .map(|end| body[..end].trim().to_string());
    }

    // Regular string: stop at the first quote that is not backslash-escaped.
    let body = &value[quote.len_utf8()..];
    let mut escaped = false;
    for (idx, ch) in body.char_indices() {
        if escaped {
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch == quote {
            return Some(body[..idx].to_string());
        }
    }
    None
}

/// Pulls the model identifier out of an inline `llm_config={...}` argument.
///
/// Starting at the first `llm_config=`, the first `"model"` key is located (falling back
/// to `'model'` when no double-quoted key exists) and the quoted string after the
/// following colon is returned. Any missing piece, or an unquoted value, yields `None`.
fn extract_model_from_llm_config(text: &str) -> Option<String> {
    let config = &text[text.find("llm_config=")?..];

    // Double-quoted key takes precedence over the single-quoted spelling.
    let key_pos = config
        .find("\"model\"")
        .or_else(|| config.find("'model'"))?;

    // Everything after the first colon that follows the key is the candidate value.
    let (_, raw_value) = config[key_pos..].split_once(':')?;

    let mut chars = raw_value.trim_start().chars();
    let quote = chars.next().filter(|c| *c == '"' || *c == '\'')?;

    // `chars` now points just past the opening quote.
    let literal = chars.as_str();
    literal.find(quote).map(|end| literal[..end].to_string())
}

/// Collects the entries of an inline `tools=[...]` list argument.
///
/// The text after the first `tools=` must start (after optional whitespace) with `[`; the
/// list ends at the first `]`. Entries are split on commas, trimmed, and empty entries are
/// dropped. An empty vector is returned when no such list is found.
fn extract_tools_list(text: &str) -> Vec<String> {
    let list_body = text
        .split_once("tools=")
        .map(|(_, tail)| tail.trim_start())
        .and_then(|tail| tail.strip_prefix('['))
        .and_then(|body| body.split_once(']'))
        .map(|(items, _)| items);

    match list_body {
        Some(items) => items
            .split(',')
            .map(str::trim)
            .filter(|entry| !entry.is_empty())
            .map(str::to_string)
            .collect(),
        None => Vec::new(),
    }
}

/// Returns the identifier between the first `def ` and the next `(` in `text`, trimmed.
///
/// Yields `None` when either marker is missing.
fn extract_function_name(text: &str) -> Option<String> {
    let (_, after_def) = text.split_once("def ")?;
    let (name, _) = after_def.split_once('(')?;
    Some(name.trim().to_string())
}

/// Extracts the first non-empty triple-quoted string from `text`, trimmed.
///
/// Both `"""` and `'''` delimiters are recognised. The delimiter that appears first in
/// the text is tried first, so a `'''`-style docstring is not shadowed by an unrelated
/// `"""` string that appears later in the body. If the earlier candidate is empty or
/// unterminated, the other delimiter style is tried. Returns `None` when neither yields a
/// non-empty string.
fn extract_docstring(text: &str) -> Option<String> {
    const DELIMITERS: [&str; 2] = ["\"\"\"", "'''"];

    // Opening positions of each delimiter style, ordered by where they occur in the text.
    let mut openers: Vec<(usize, &str)> = DELIMITERS
        .iter()
        .copied()
        .filter_map(|delim| text.find(delim).map(|pos| (pos, delim)))
        .collect();
    openers.sort_unstable_by_key(|&(pos, _)| pos);

    openers.into_iter().find_map(|(pos, delim)| {
        let body = &text[pos + delim.len()..];
        let end = body.find(delim)?;
        let docstring = body[..end].trim();
        if docstring.is_empty() {
            None
        } else {
            Some(docstring.to_string())
        }
    })
}

/// Extracts annotated parameters from the first `def name(...)` signature in `text`.
///
/// Returns a map from parameter name to its type annotation (default values removed).
/// Parameters without an annotation are skipped. `None` is returned when there is no
/// signature or when no annotated parameter is found.
///
/// The signature is scanned with bracket and string-literal awareness, so annotations
/// such as `Annotated[str, "description, with commas"]` or `Callable[[int], str]` and
/// defaults such as `make(1, 2)` are kept intact rather than being cut at the first
/// `)`, `,`, `:` or `=` encountered.
fn extract_function_parameters(text: &str) -> Option<std::collections::HashMap<String, String>> {
    use std::collections::HashMap;

    /// Byte index of the first `target` that is outside all brackets and string literals.
    fn find_top_level(s: &str, target: char) -> Option<usize> {
        let mut depth = 0usize;
        let mut quote: Option<char> = None;
        let mut escaped = false;

        for (idx, ch) in s.char_indices() {
            if let Some(open_quote) = quote {
                // Inside a string literal: only look for its (unescaped) end.
                if escaped {
                    escaped = false;
                } else if ch == '\\' {
                    escaped = true;
                } else if ch == open_quote {
                    quote = None;
                }
                continue;
            }
            // Checked before depth bookkeeping so that a closing bracket can itself be
            // the target (used to find the signature's closing parenthesis).
            if depth == 0 && ch == target {
                return Some(idx);
            }
            match ch {
                '"' | '\'' => quote = Some(ch),
                '(' | '[' | '{' => depth += 1,
                ')' | ']' | '}' => depth = depth.saturating_sub(1),
                _ => {}
            }
        }
        None
    }

    let def_idx = text.find("def ")?;
    let signature = &text[def_idx..];
    let open_paren = signature.find('(')?;
    let after_paren = &signature[open_paren + 1..];
    let close_paren = find_top_level(after_paren, ')')?;

    let mut remaining = after_paren[..close_paren].trim();
    let mut params_map = HashMap::new();

    while !remaining.is_empty() {
        // Peel off one parameter at the next top-level comma.
        let (param, rest) = match find_top_level(remaining, ',') {
            Some(comma) => (&remaining[..comma], &remaining[comma + 1..]),
            None => (remaining, ""),
        };
        remaining = rest;

        // Remove default values if present; an annotation cannot contain a top-level `=`.
        let declared = match find_top_level(param, '=') {
            Some(eq_idx) => &param[..eq_idx],
            None => param,
        };

        if let Some(colon_idx) = find_top_level(declared, ':') {
            let name = declared[..colon_idx].trim();
            let type_str = declared[colon_idx + 1..].trim();
            params_map.insert(name.to_string(), type_str.to_string());
        }
    }

    if params_map.is_empty() {
        None
    } else {
        Some(params_map)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Each recognised constructor call is detected; unrelated classes are not.
    #[test]
    fn test_contains_agent_pattern() {
        let agent_lines = [
            "agent = AssistantAgent(name='test')",
            "proxy = UserProxyAgent(name='user')",
            "conv = ConversableAgent(name='conv')",
        ];
        for line in agent_lines.iter() {
            assert!(contains_agent_pattern(line));
        }
        assert!(!contains_agent_pattern("other = SomeClass()"));
    }

    /// The assignment target is returned as the variable name.
    #[test]
    fn test_extract_variable_name() {
        let found = extract_variable_name("coding_assistant = AssistantAgent(name='test')");
        assert_eq!(found.as_deref(), Some("coding_assistant"));
    }

    /// Double-quoted keyword arguments are read by name.
    #[test]
    fn test_extract_parameter_value() {
        let call = r#"Agent(name="coding_assistant", system_message="You are helpful")"#;

        assert_eq!(
            extract_parameter_value(call, "name").as_deref(),
            Some("coding_assistant")
        );
        assert_eq!(
            extract_parameter_value(call, "system_message").as_deref(),
            Some("You are helpful")
        );
    }

    /// The `"model"` entry of an inline `llm_config` dict is returned.
    #[test]
    fn test_extract_model_from_llm_config() {
        let call =
            r#"AssistantAgent(name="test", llm_config={"model": "gpt-4", "temperature": 0.7})"#;
        assert_eq!(extract_model_from_llm_config(call).as_deref(), Some("gpt-4"));
    }

    /// Entries of `tools=[...]` are returned in source order.
    #[test]
    fn test_extract_tools_list() {
        let call = "Agent(name='test', tools=[execute_code, run_tests], verbose=True)";
        assert_eq!(extract_tools_list(call), ["execute_code", "run_tests"]);
    }

    /// The identifier between `def ` and `(` is the function name.
    #[test]
    fn test_extract_function_name() {
        let found = extract_function_name("def execute_code(code: str) -> str:");
        assert_eq!(found.as_deref(), Some("execute_code"));
    }

    /// A triple-quoted docstring is returned without its delimiters.
    #[test]
    fn test_extract_docstring() {
        let source = "\ndef my_function():\n    \"\"\"This is a docstring.\"\"\"\n    pass\n";
        assert_eq!(
            extract_docstring(source).as_deref(),
            Some("This is a docstring.")
        );
    }

    /// Annotated parameters map to their types, with default values stripped.
    #[test]
    fn test_extract_function_parameters() {
        let signature = "def execute_code(code: str, timeout: int = 30) -> str:";
        let params = extract_function_parameters(signature)
            .expect("signature has annotated parameters");

        assert_eq!(params.get("code").map(String::as_str), Some("str"));
        assert_eq!(params.get("timeout").map(String::as_str), Some("int"));
    }
}