raxit-core 0.1.1

Core security scanning engine for AI agent applications
Documentation
//! Swarm framework extractor
//!
//! Extracts agents and handoff functions from OpenAI's Swarm framework.
//!
//! Swarm is a lightweight multi-agent framework that uses simple Agent() instantiations
//! and handoff functions (functions that return other agents).

use super::ExtractedAssets;
use crate::error::Result;
use crate::schema::{Agent, SourceLocation, Tool};
use std::collections::{HashMap, HashSet};
use std::path::Path;

/// Extract agents, handoff functions and tool functions from a Swarm source file.
///
/// The file is parsed via `crate::ast::parse_python_file` and the resulting nodes are scanned
/// in three passes:
/// 1. `assignment` nodes containing `Agent(` become [`Agent`] entries; the assigned variable
///    name is remembered.
/// 2. `function_definition` nodes with a `return <agent variable>` line become tools with
///    `tool_type == "handoff"`.
/// 3. Every other `function_definition` with a docstring or parameters becomes a tool with
///    `tool_type == "function"`.
///
/// # Errors
///
/// Propagates any error returned by `crate::ast::parse_python_file`.
pub fn extract(path: &Path) -> Result<ExtractedAssets> {
    let nodes = crate::ast::parse_python_file(path)?;

    // Computed once; every asset points back into the same file.
    let file = path.to_string_lossy().to_string();
    let location = |line, end_line| SourceLocation {
        file: file.clone(),
        line,
        end_line: Some(end_line),
        function: None,
    };

    let mut assets = ExtractedAssets::default();
    let mut agent_vars = HashMap::new(); // var_name -> agent_id
    let mut handoff_functions = HashSet::new(); // names of functions that return agents

    // First pass: `var = Agent(...)` assignments.
    for node in nodes
        .iter()
        .filter(|n| n.kind == "assignment" && n.text.contains("Agent("))
    {
        if let Some((var_name, agent)) = extract_agent_from_assignment(node, path) {
            agent_vars.insert(var_name, agent.id.clone());
            assets.agents.push(agent);
        }
    }

    // Second pass: handoff functions, i.e. functions that `return` one of the agent variables.
    // This needs the complete `agent_vars` map, hence a separate pass.
    for node in nodes.iter().filter(|n| n.kind == "function_definition") {
        if let Some(func_name) = extract_function_name(&node.text) {
            if returns_agent_variable(&node.text, &agent_vars) {
                handoff_functions.insert(func_name.clone());
                assets.tools.push(Tool {
                    id: format!("swarm_handoff_{}", node.start_line),
                    name: func_name,
                    location: location(node.start_line, node.end_line),
                    description: Some(
                        extract_docstring(&node.text)
                            .unwrap_or_else(|| "Agent handoff function".to_string()),
                    ),
                    parameters: None,
                    requires_context: false,
                    tool_type: "handoff".to_string(),
                    data_flows: Vec::new(),
                });
            }
        }
    }

    // Third pass: remaining functions that look like tools (docstring or parameters).
    // NOTE(review): class methods are not filtered out here; whether they show up depends on
    // which nodes `parse_python_file` yields — confirm.
    for node in nodes.iter().filter(|n| n.kind == "function_definition") {
        if let Some(func_name) = extract_function_name(&node.text) {
            if handoff_functions.contains(&func_name) {
                continue;
            }
            if has_docstring(&node.text) || has_parameters(&node.text) {
                assets.tools.push(Tool {
                    id: format!("swarm_tool_{}", node.start_line),
                    name: func_name,
                    location: location(node.start_line, node.end_line),
                    description: extract_docstring(&node.text),
                    parameters: extract_function_parameters(&node.text),
                    requires_context: false,
                    tool_type: "function".to_string(),
                    data_flows: Vec::new(),
                });
            }
        }
    }

    Ok(assets)
}

/// Build an [`Agent`] from a `var = Agent(...)` assignment node.
///
/// Returns the assigned variable name together with the agent, or `None` when no variable
/// name can be extracted from the node text. Literal arguments are read from the call text:
/// - `name=` gives the agent name, falling back to the variable name.
/// - `instructions=` gives the system prompt.
/// - `model=` gives the model id.
/// - `functions=[...]` gives the tool ids.
fn extract_agent_from_assignment(
    node: &crate::ast::AstNode,
    path: &Path,
) -> Option<(String, Agent)> {
    let text: &str = &node.text;
    let var_name = extract_variable_name(text)?;

    let agent = Agent {
        id: format!("swarm_agent_{}", node.start_line),
        name: extract_parameter_value(text, "name").unwrap_or_else(|| var_name.clone()),
        location: SourceLocation {
            file: path.to_string_lossy().to_string(),
            line: node.start_line,
            end_line: Some(node.end_line),
            function: None,
        },
        model_id: extract_model(text),
        // Function names as written in the list, not `Tool::id` values.
        tool_ids: extract_functions_list(text),
        memory_id: None,
        system_prompt: extract_parameter_value(text, "instructions"),
        result_type: Some("Swarm::Agent".to_string()),
        deps_type: None,
    };

    Some((var_name, agent))
}

/// Return the target name of a Python assignment such as `var_name = Agent(...)`.
///
/// The assignment `=` is the first `=` outside brackets that is not part of `==`, `!=`, `<=`
/// or `>=`. Spaces around it are optional (`x=Agent()` works), and keyword arguments inside the
/// call are never mistaken for it.
///
/// After finding it:
/// - A type annotation on the target (`agent: Agent = Agent(...)`) is dropped.
/// - The last whitespace-separated word of what remains is returned.
///
/// Returns `None` if there is no assignment or the target is empty.
fn extract_variable_name(text: &str) -> Option<String> {
    let bytes = text.as_bytes();
    let mut depth = 0usize;
    let mut eq_idx = None;

    for (idx, &byte) in bytes.iter().enumerate() {
        match byte {
            b'(' | b'[' | b'{' => depth += 1,
            b')' | b']' | b'}' => depth = depth.saturating_sub(1),
            b'=' if depth == 0 => {
                let prev = idx.checked_sub(1).map(|p| bytes[p]);
                let next = bytes.get(idx + 1).copied();
                let is_comparison =
                    next == Some(b'=') || matches!(prev, Some(b'=' | b'!' | b'<' | b'>'));
                if !is_comparison {
                    eq_idx = Some(idx);
                    break;
                }
            }
            _ => {}
        }
    }

    // `idx` points at an ASCII `=`, so slicing there is on a char boundary.
    let target = text[..eq_idx?].trim();
    // Drop a type annotation: `agent: Agent` -> `agent`.
    let target = target.split(':').next().unwrap_or(target);
    target.split_whitespace().last().map(str::to_string)
}

/// Return the string literal passed as `param_name=...` in `text`.
///
/// Only the first occurrence of `param_name=` that is a whole keyword is considered. For
/// example, `name=` does not match inside `display_name=`. That occurrence must be followed
/// by one of:
/// - a triple-quoted string (`"""..."""` or `'''...'''`), whose contents are returned trimmed;
/// - a single-line literal (`"..."` or `'...'`), whose raw contents are returned.
///
/// In a single-line literal, backslash-escaped quotes do not end the string, and escapes are
/// kept verbatim. Returns `None` when the keyword is absent, its value is not a string
/// literal, or the literal is unterminated.
fn extract_parameter_value(text: &str, param_name: &str) -> Option<String> {
    let pattern = format!("{param_name}=");
    let start = text
        .match_indices(pattern.as_str())
        .map(|(idx, _)| idx)
        .find(|&idx| {
            !matches!(text[..idx].chars().next_back(), Some(c) if c.is_alphanumeric() || c == '_')
        })?;
    let rest = text[start + pattern.len()..].trim_start();

    let quote_char = rest.chars().next().filter(|&c| c == '"' || c == '\'')?;
    // The quote is ASCII, so byte offset 1 is a char boundary.
    let content = &rest[1..];

    if content.starts_with(quote_char) && content.chars().nth(1) == Some(quote_char) {
        // Triple-quoted string.
        let triple_quote = format!("{quote_char}{quote_char}{quote_char}");
        let after_open = &content[2..];
        let end = after_open.find(&triple_quote)?;
        return Some(after_open[..end].trim().to_string());
    }

    // Regular string: stop at the first quote that is not backslash-escaped.
    let mut escaped = false;
    for (idx, c) in content.char_indices() {
        if escaped {
            escaped = false;
        } else if c == '\\' {
            escaped = true;
        } else if c == quote_char {
            return Some(content[..idx].to_string());
        }
    }
    None
}

/// Return the entries of a `functions=[f1, f2, ...]` keyword argument, trimmed.
///
/// Only the first `functions=` that is a whole keyword counts; `helper_functions=` does not
/// match. The list runs up to the first `]`. `#` comments inside a multi-line list are ignored,
/// and empty entries (e.g. from a trailing comma) are dropped.
///
/// Returns an empty vector when there is no such list.
fn extract_functions_list(text: &str) -> Vec<String> {
    const KEY: &str = "functions=";

    let list_body = text
        .match_indices(KEY)
        .map(|(idx, _)| idx)
        .find(|&idx| {
            !matches!(text[..idx].chars().next_back(), Some(c) if c.is_alphanumeric() || c == '_')
        })
        .and_then(|start| text[start + KEY.len()..].trim_start().strip_prefix('['))
        .and_then(|list| list.find(']').map(|end| &list[..end]));

    match list_body {
        Some(body) => body
            .lines()
            // Drop trailing `# comment` text on each line of a multi-line list.
            .map(|line| line.split('#').next().unwrap_or(line))
            .flat_map(|line| line.split(','))
            .map(str::trim)
            .filter(|name| !name.is_empty())
            .map(str::to_string)
            .collect(),
        None => Vec::new(),
    }
}

/// Return the string literal passed as `model=` in an `Agent(...)` call, if any.
///
/// Swarm typically uses a default OpenAI model, so `None` is the common case.
fn extract_model(text: &str) -> Option<String> {
    const MODEL_PARAM: &str = "model";
    extract_parameter_value(text, MODEL_PARAM)
}

/// Return the name of the function defined in `text` (`def name(...)` or `async def name(...)`).
///
/// The name is the run of identifier characters following the first `def `. This means PEP 695
/// type parameters are not included: `def first[T](xs)` yields `first`. Returns `None` when
/// there is no `def ` or no identifier follows it.
fn extract_function_name(text: &str) -> Option<String> {
    let def_idx = text.find("def ")?;
    let name: String = text[def_idx + "def ".len()..]
        .trim_start()
        .chars()
        .take_while(|&c| c.is_alphanumeric() || c == '_')
        .collect();
    if name.is_empty() {
        None
    } else {
        Some(name)
    }
}

/// Whether any `return <value>` line in `text` returns one of the known agent variables.
///
/// A trailing `# comment` and one pair of enclosing parentheses around the value are ignored.
/// For example, `return (sales_agent)  # hand off` matches `sales_agent`.
fn returns_agent_variable(text: &str, agent_vars: &HashMap<String, String>) -> bool {
    text.lines()
        .filter_map(|line| line.trim().strip_prefix("return "))
        .map(|value| {
            let value = value.split('#').next().unwrap_or(value).trim();
            value
                .strip_prefix('(')
                .and_then(|inner| inner.strip_suffix(')'))
                .map_or(value, str::trim)
        })
        .any(|value| agent_vars.contains_key(value))
}

/// Return the trimmed contents of the first triple-quoted string in `text`.
///
/// Whichever of `"""` or `'''` occurs first is used. Returns `None` if there is no
/// triple-quoted string, if it is unterminated, or if its contents are blank.
/// NOTE(review): this does not verify that the string is the first statement of the function
/// body, so a non-docstring triple-quoted string can be picked up when no docstring exists.
fn extract_docstring(text: &str) -> Option<String> {
    let (start, quote) = ["\"\"\"", "'''"]
        .iter()
        .filter_map(|&quote| text.find(quote).map(|idx| (idx, quote)))
        .min_by_key(|&(idx, _)| idx)?;

    let after_open = &text[start + quote.len()..];
    let end = after_open.find(quote)?;
    let docstring = after_open[..end].trim();
    if docstring.is_empty() {
        None
    } else {
        Some(docstring.to_string())
    }
}

/// Whether `extract_docstring` finds a non-empty triple-quoted string in `text`.
fn has_docstring(text: &str) -> bool {
    let docstring = extract_docstring(text);
    docstring.is_some()
}

/// Whether the `def` signature in `text` declares any parameter other than a bare `self`.
///
/// The parameter list runs from the first `(` after `def ` to the next `)`. Blank entries,
/// such as the one left by a trailing comma in `def m(\n    self,\n)`, are ignored. Returns
/// `false` when no `def ...(...)` signature is found.
fn has_parameters(text: &str) -> bool {
    let params = text.find("def ").and_then(|def_idx| {
        let signature = &text[def_idx..];
        let after_paren = &signature[signature.find('(')? + 1..];
        let close_paren = after_paren.find(')')?;
        Some(&after_paren[..close_paren])
    });

    match params {
        Some(params) => params
            .split(',')
            .map(str::trim)
            .any(|param| !param.is_empty() && param != "self"),
        None => false,
    }
}

/// Extract the annotated parameters of a `def` signature as a `name -> type` map.
///
/// Behaviour:
/// - Only parameters with a type annotation are recorded, and a bare `self` is skipped.
/// - The recorded type runs up to any default value: `quantity: int = 1` maps to `"int"`.
/// - Commas, `=` and `)` nested inside brackets or string literals are not treated as
///   separators, so `dict[str, int]`, `Annotated[int, Field(ge=0)]` and `sep: str = ", "`
///   stay intact.
/// - `#` comments in a multi-line signature are ignored.
///
/// Returns `None` when there is no complete `def ...(...)` parameter list or no annotated
/// parameter.
fn extract_function_parameters(text: &str) -> Option<HashMap<String, String>> {
    use std::collections::HashMap;

    /// Copy `s` without `#` comments; a `#` inside a quoted string is kept.
    fn strip_comments(s: &str) -> String {
        let mut out = String::with_capacity(s.len());
        let mut quote: Option<char> = None;
        let mut escaped = false;
        let mut in_comment = false;
        for c in s.chars() {
            if in_comment {
                if c == '\n' {
                    in_comment = false;
                    out.push(c);
                }
                continue;
            }
            if let Some(q) = quote {
                if escaped {
                    escaped = false;
                } else if c == '\\' {
                    escaped = true;
                } else if c == q {
                    quote = None;
                }
            } else if c == '#' {
                in_comment = true;
                continue;
            } else if c == '"' || c == '\'' {
                quote = Some(c);
            }
            out.push(c);
        }
        out
    }

    /// Byte offset of the first char at bracket depth 0 (outside string literals) for which
    /// `is_stop` holds. An unmatched closing bracket is also offered to `is_stop`; this is how
    /// the `)` ending a parameter list is found.
    fn find_top_level(s: &str, is_stop: impl Fn(char) -> bool) -> Option<usize> {
        let mut depth = 0usize;
        let mut quote: Option<char> = None;
        let mut escaped = false;
        for (idx, c) in s.char_indices() {
            if let Some(q) = quote {
                if escaped {
                    escaped = false;
                } else if c == '\\' {
                    escaped = true;
                } else if c == q {
                    quote = None;
                }
                continue;
            }
            match c {
                '"' | '\'' => quote = Some(c),
                '(' | '[' | '{' => depth += 1,
                ')' | ']' | '}' if depth > 0 => depth -= 1,
                _ if depth == 0 && is_stop(c) => return Some(idx),
                _ => {}
            }
        }
        None
    }

    let def_idx = text.find("def ")?;
    let signature = &text[def_idx..];
    let open_paren = signature.find('(')?;
    let params_src = strip_comments(&signature[open_paren + 1..]);

    let mut params_map = HashMap::new();
    let mut rest = params_src.as_str();
    loop {
        // Next top-level `,` (another parameter follows) or `)` (end of the list).
        let end = find_top_level(rest, |c| c == ',' || c == ')')?;
        let param = rest[..end].trim();
        if param != "self" {
            // Annotated only if `:` comes before any top-level `=` (`x={"a": 1}` is not).
            let colon = find_top_level(param, |c| c == ':' || c == '=')
                .filter(|&idx| param[idx..].starts_with(':'));
            if let Some(colon) = colon {
                let name = param[..colon].trim();
                let annotation = &param[colon + 1..];
                // Drop a default value, but not `=` nested in e.g. `Field(ge=0)`.
                let type_end =
                    find_top_level(annotation, |c| c == '=').unwrap_or(annotation.len());
                params_map.insert(name.to_string(), annotation[..type_end].trim().to_string());
            }
        }
        if rest[end..].starts_with(')') {
            break;
        }
        rest = &rest[end + 1..];
    }

    if params_map.is_empty() {
        None
    } else {
        Some(params_map)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the `Option<String>` values the extractors return.
    fn owned(s: &str) -> Option<String> {
        Some(s.to_string())
    }

    #[test]
    fn test_extract_variable_name() {
        assert_eq!(
            extract_variable_name("sales_agent = Agent(name='Sales')"),
            owned("sales_agent")
        );
    }

    #[test]
    fn test_extract_parameter_value() {
        let call = r#"Agent(name="Sales Agent", instructions="Help customers")"#;
        assert_eq!(extract_parameter_value(call, "name"), owned("Sales Agent"));
        assert_eq!(
            extract_parameter_value(call, "instructions"),
            owned("Help customers")
        );
    }

    #[test]
    fn test_extract_functions_list() {
        let call =
            "Agent(name='Sales', functions=[create_order, check_inventory], tool_choice='auto')";
        assert_eq!(
            extract_functions_list(call),
            vec!["create_order", "check_inventory"]
        );
    }

    #[test]
    fn test_extract_function_name() {
        assert_eq!(
            extract_function_name("def transfer_to_sales():"),
            owned("transfer_to_sales")
        );
    }

    #[test]
    fn test_returns_agent_variable() {
        let agent_vars: HashMap<String, String> =
            std::iter::once(("sales_agent".to_string(), "agent_1".to_string())).collect();

        assert!(returns_agent_variable(
            "def transfer_to_sales():\n    return sales_agent",
            &agent_vars
        ));
        assert!(!returns_agent_variable(
            "def other_func():\n    return 'hello'",
            &agent_vars
        ));
    }

    #[test]
    fn test_has_parameters() {
        let with_params = "def create_order(product_id: str):";
        let without_params = "def transfer_to_sales():";
        assert!(has_parameters(with_params));
        assert!(!has_parameters(without_params));
    }

    #[test]
    fn test_extract_function_parameters() {
        let extracted = extract_function_parameters(
            "def create_order(product_id: str, quantity: int = 1) -> str:",
        );
        assert!(extracted.is_some());

        let extracted = extracted.unwrap();
        assert_eq!(extracted.get("product_id").map(String::as_str), Some("str"));
        assert_eq!(extracted.get("quantity").map(String::as_str), Some("int"));
    }
}