raxit-core 0.1.2

Core security scanning engine for AI agent applications
Documentation
//! PydanticAI framework extractor

use super::ExtractedAssets;
use crate::error::Result;
use crate::schema::{Agent, SourceLocation, Tool};
use std::path::Path;

/// Extract assets from a PydanticAI file
///
/// Agents are found as `call` nodes whose text contains `Agent(`. Tools are
/// `decorated_definition` nodes decorated with `@<var>.tool` / `@<var>.tool_plain`, where
/// `<var>` is an agent variable discovered in the same file (the literal `@agent.tool` form
/// is always accepted). Each agent's `tool_ids` lists its tools in source order.
///
/// # Errors
///
/// Returns an error if the file cannot be read or parsed.
pub fn extract(path: &Path) -> Result<ExtractedAssets> {
    // Read source code for pattern matching
    let source_code = std::fs::read_to_string(path)?;

    // Parse the file using tree-sitter
    let nodes = crate::ast::parse_python_file(path)?;

    let mut assets = ExtractedAssets::default();
    let mut agent_vars = std::collections::HashMap::new(); // agent_var -> agent_id
    // (tool_id, agent_id) pairs in source order; a Vec (not a HashMap) keeps the
    // resulting `tool_ids` ordering deterministic across runs.
    let mut tool_links: Vec<(String, String)> = Vec::new();

    // Pre-scan: map 1-based line number -> variable name for `<varname> = Agent(...)` lines.
    let source_agent_vars: std::collections::HashMap<usize, String> = source_code
        .lines()
        .enumerate()
        .filter_map(|(idx, line)| extract_agent_var_name(line).map(|name| (idx + 1, name)))
        .collect();

    // First pass: Extract agents and link to variable names
    for node in &nodes {
        if node.kind != "call" || !node.text.contains("Agent(") {
            continue;
        }
        if let Some(mut agent) = extract_agent_from_node(node, path) {
            // NOTE(review): assumes `node.start_line` is 1-based like the pre-scan keys.
            if let Some(var_name) = source_agent_vars.get(&(node.start_line as usize)) {
                // The assigned variable is the agent's name; the call text alone cannot
                // provide it (its `=` signs belong to keyword arguments).
                agent.name = var_name.clone();
                agent_vars.insert(var_name.clone(), agent.id.clone());
            }
            assets.agents.push(agent);
        }
    }

    // Second pass: Extract tools and link to agents
    for node in &nodes {
        if node.kind != "decorated_definition" {
            continue;
        }
        // Agent that owns the tool decorator, e.g. `@support_agent.tool` -> its agent id.
        let owner = extract_agent_from_decorator(&node.text)
            .and_then(|var| agent_vars.get(&var).cloned());
        // Accept tools on any known agent variable, plus the literal `@agent.tool` form
        // (kept so unresolved `agent` tools are still reported).
        if owner.is_none() && !node.text.contains("@agent.tool") {
            continue;
        }
        if let Some(tool) = extract_tool_from_node(node, path) {
            if let Some(agent_id) = owner {
                tool_links.push((tool.id.clone(), agent_id));
            }
            assets.tools.push(tool);
        }
    }

    // Third pass: Populate tool_ids in agents (source order)
    for agent in &mut assets.agents {
        let tool_ids: Vec<String> = tool_links
            .iter()
            .filter(|(_, owner)| owner == &agent.id)
            .map(|(tool_id, _)| tool_id.clone())
            .collect();
        agent.tool_ids = tool_ids;
    }

    Ok(assets)
}

/// Build an [`Agent`] record from a `call` node whose text contains `Agent(`.
///
/// The id is derived from the node's start line. Name, model, system prompt and deps type
/// are pulled from the call text by the sibling `extract_*` helpers. `tool_ids` starts
/// empty, and `memory_id` / `result_type` are not populated here. Always returns `Some`.
fn extract_agent_from_node(node: &crate::ast::AstNode, path: &Path) -> Option<Agent> {
    let call_text: &str = &node.text;

    let location = SourceLocation {
        file: path.to_string_lossy().to_string(),
        line: node.start_line,
        end_line: Some(node.end_line),
        function: None,
    };

    let agent = Agent {
        id: format!("agent_{}", node.start_line),
        name: extract_agent_name(call_text).unwrap_or_else(|| String::from("unnamed_agent")),
        location,
        model_id: extract_model_from_agent(call_text),
        tool_ids: Vec::new(),
        memory_id: None,
        system_prompt: extract_system_prompt(call_text),
        result_type: None,
        deps_type: extract_deps_type(call_text),
    };
    Some(agent)
}

/// Build a [`Tool`] record from a `decorated_definition` node carrying a tool decorator.
///
/// The tool is `"plain"` when one of its decorator lines ends in `.tool_plain` (optionally
/// followed by call arguments or a comment), whatever the agent variable is called.
/// Otherwise it is `"context"`. `requires_context` is set when `RunContext` appears anywhere
/// in the node text. The description is the first triple-quoted string, if any.
/// Always returns `Some`.
fn extract_tool_from_node(node: &crate::ast::AstNode, path: &Path) -> Option<Tool> {
    let id = format!("tool_{}", node.start_line);
    let name = extract_function_name(&node.text).unwrap_or_else(|| "unnamed_tool".to_string());

    // Extract description from docstring
    let description = extract_docstring(&node.text);

    // Inspect decorator lines only, so `@support_agent.tool_plain` is recognised just like
    // `@agent.tool_plain`; arguments and trailing comments are ignored.
    let is_plain = node.text.lines().map(str::trim).any(|line| {
        let head = line.split(&['(', '#'][..]).next().unwrap_or(line).trim_end();
        line.starts_with('@') && head.ends_with(".tool_plain")
    });

    Some(Tool {
        id,
        name,
        location: SourceLocation {
            file: path.to_string_lossy().to_string(),
            line: node.start_line,
            end_line: Some(node.end_line),
            function: None,
        },
        description,
        parameters: None,
        requires_context: node.text.contains("RunContext"),
        tool_type: if is_plain {
            "plain".to_string()
        } else {
            "context".to_string()
        },
        data_flows: Vec::new(),
    })
}

/// Extract the assignment target from text such as `agent = Agent(...)`.
///
/// Only an `=` that appears before the first `(` is treated as an assignment. Any `=` after
/// it belongs to a keyword argument like `system_prompt=`, so call text on its own
/// (`Agent('gpt-4', deps_type=D)`) yields `None`. A type annotation
/// (`agent: Agent[Deps] = ...`) is dropped.
///
/// Returns `None` in these cases:
/// - there is no assignment;
/// - the `=` is part of `==`;
/// - the target is not a dotted identifier.
fn extract_agent_name(text: &str) -> Option<String> {
    let head = text.split('(').next().unwrap_or(text);
    let (target, value) = head.split_once('=')?;
    if value.starts_with('=') {
        // `x == Agent(...)` is a comparison, not an assignment.
        return None;
    }
    let name = target.split(':').next().unwrap_or(target).trim();
    let is_name = !name.is_empty()
        && name
            .chars()
            .all(|c| c.is_alphanumeric() || c == '_' || c == '.');
    if is_name {
        Some(name.to_string())
    } else {
        None
    }
}

/// Extract the name of the first function defined in `text`.
///
/// Only a line that starts (after indentation) with `def ` or `async def ` counts as a
/// definition, so a `def ` substring inside decorator arguments or strings is ignored. The
/// name ends at the parameter list `(` or, for PEP 695 generics (`def f[T](...)`), at `[`.
/// Returns `None` when no such definition line is found.
fn extract_function_name(text: &str) -> Option<String> {
    text.lines().find_map(|line| {
        let line = line.trim_start();
        let rest = line
            .strip_prefix("async ")
            .map(str::trim_start)
            .unwrap_or(line);
        let rest = rest.strip_prefix("def ")?;
        let end = rest.find(&['(', '['][..])?;
        let name = rest[..end].trim();
        if name.is_empty() {
            None
        } else {
            Some(name.to_string())
        }
    })
}

/// Extract the model passed to an `Agent(...)` call.
///
/// The model comes from the first positional argument (`Agent('openai:gpt-4', ...)`, which
/// also works with no further arguments) or, when the call starts with keyword arguments,
/// from `model=...`. Surrounding quotes are stripped. A non-literal model (a variable or a
/// constructor call such as `OpenAIModel('gpt-4o')`) is returned as written. Returns `None`
/// when there is no `Agent(` call or no model argument.
fn extract_model_from_agent(text: &str) -> Option<String> {
    const CALL: &str = "Agent(";
    let start = text.find(CALL)?;
    let args = split_top_level_args(&text[start + CALL.len()..]);

    let raw = match args.first() {
        // Positional model, e.g. `Agent('openai:gpt-4')`.
        Some(&first) if keyword_argument(first).is_none() => first,
        // Keyword form, e.g. `Agent(model='openai:gpt-4', ...)`.
        _ => args.iter().find_map(|&arg| match keyword_argument(arg) {
            Some(("model", value)) => Some(value),
            _ => None,
        })?,
    };

    // Remove quotes if present
    let model = raw.trim().trim_matches('\'').trim_matches('"');
    if model.is_empty() {
        None
    } else {
        Some(model.to_string())
    }
}

/// Split the text following a call's opening `(` into trimmed top-level arguments, stopping
/// at the matching `)`.
///
/// Commas nested in brackets or string literals do not split. Empty segments (from a
/// trailing comma or an empty call) are dropped. Without a closing `)`, the text runs to
/// the end.
fn split_top_level_args(args: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut depth = 0usize;
    let mut quote: Option<char> = None;
    let mut escaped = false;
    let mut segment_start = 0;
    let mut end = args.len();

    for (i, c) in args.char_indices() {
        if let Some(q) = quote {
            // Inside a string literal only an unescaped matching quote ends it.
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == q {
                quote = None;
            }
            continue;
        }
        match c {
            '\'' | '"' => quote = Some(c),
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' if depth > 0 => depth -= 1,
            ')' => {
                end = i;
                break;
            }
            ',' if depth == 0 => {
                parts.push(args[segment_start..i].trim());
                segment_start = i + 1;
            }
            _ => {}
        }
    }

    parts.push(args[segment_start..end].trim());
    parts.retain(|part| !part.is_empty());
    parts
}

/// Split `name=value` into its keyword and trimmed value, or `None` for a positional
/// argument (including comparisons such as `a == b`).
fn keyword_argument(arg: &str) -> Option<(&str, &str)> {
    let (name, value) = arg.split_once('=')?;
    let name = name.trim();
    let is_identifier = matches!(name.chars().next(), Some(c) if c.is_alphabetic() || c == '_')
        && name.chars().all(|c| c.is_alphanumeric() || c == '_');
    if is_identifier && !value.starts_with('=') {
        Some((name, value.trim()))
    } else {
        None
    }
}

/// Extract a string-literal `system_prompt=` keyword argument.
///
/// Supports `'...'`, `"..."`, `'''...'''` and `"""..."""` literals, with optional
/// whitespace after `=`. Backslash escapes are skipped when looking for the closing quote.
/// The text is returned as written in the source, not unescaped. Returns `None` in these
/// cases:
/// - the keyword is absent;
/// - its value is not a string literal (e.g. a variable);
/// - the literal is unterminated.
fn extract_system_prompt(text: &str) -> Option<String> {
    const KEY: &str = "system_prompt=";
    // Require a word boundary so e.g. `custom_system_prompt=` is not mistaken for it.
    let (start, _) = text.match_indices(KEY).find(|&(idx, _)| {
        !matches!(text[..idx].chars().next_back(), Some(c) if c.is_alphanumeric() || c == '_')
    })?;
    let rest = text[start + KEY.len()..].trim_start();

    // Triple quotes must be tried first: `"""x"""` also starts with `"`.
    let delimiter = ["\"\"\"", "'''", "\"", "'"]
        .iter()
        .copied()
        .find(|delim| rest.starts_with(*delim))?;
    let body = &rest[delimiter.len()..];

    let mut chars = body.char_indices();
    while let Some((i, c)) = chars.next() {
        if c == '\\' {
            // Skip the escaped character so `\'` does not terminate a `'...'` literal.
            chars.next();
        } else if body[i..].starts_with(delimiter) {
            return Some(body[..i].to_string());
        }
    }
    None
}

/// Extract the expression passed as the `deps_type=` keyword argument.
///
/// The value runs until a top-level `,` or the call's closing `)`. Commas and parentheses
/// nested inside brackets (e.g. `dict[str, int]`) are kept. Returns `None` when the keyword
/// is absent or its value is empty.
fn extract_deps_type(text: &str) -> Option<String> {
    const KEY: &str = "deps_type=";
    // Require a word boundary so a longer keyword ending in `deps_type=` is not matched.
    let (start, _) = text.match_indices(KEY).find(|&(idx, _)| {
        !matches!(text[..idx].chars().next_back(), Some(c) if c.is_alphanumeric() || c == '_')
    })?;
    let rest = &text[start + KEY.len()..];

    let mut depth = 0usize;
    let end = rest
        .char_indices()
        .find_map(|(i, c)| match c {
            '(' | '[' | '{' => {
                depth += 1;
                None
            }
            ')' | ']' | '}' if depth > 0 => {
                depth -= 1;
                None
            }
            ',' | ')' if depth == 0 => Some(i),
            _ => None,
        })
        .unwrap_or(rest.len());

    let deps = rest[..end].trim();
    if deps.is_empty() {
        None
    } else {
        Some(deps.to_string())
    }
}

/// Extract the first non-empty triple-quoted string (`"""..."""` or `'''...'''`) in `text`.
///
/// Candidates are tried in order of where their opening delimiter first appears. A `'''`
/// docstring is therefore not overshadowed by a later `"""` string in the function body.
/// Returns the trimmed content, or `None` if no non-empty, terminated triple-quoted string
/// is found.
fn extract_docstring(text: &str) -> Option<String> {
    let mut candidates: Vec<(usize, &str)> = ["\"\"\"", "'''"]
        .iter()
        .filter_map(|&quote| text.find(quote).map(|pos| (pos, quote)))
        .collect();
    // Earliest delimiter first; the two delimiters can never start at the same index.
    candidates.sort_unstable();

    candidates.into_iter().find_map(|(start, quote)| {
        let content = &text[start + quote.len()..];
        let end = content.find(quote)?;
        let docstring = content[..end].trim();
        if docstring.is_empty() {
            None
        } else {
            Some(docstring.to_string())
        }
    })
}

/// Extract agent variable name from assignment like "agent = Agent(...)"
///
/// Scans each line for ` = Agent(` and returns the assignment target of the first match.
/// For an annotated assignment (`agent: Agent[Deps, str] = Agent(...)`) the name before the
/// annotation colon is used. Otherwise the last whitespace-separated word before ` = ` is
/// used (e.g. `if ready: agent = Agent(...)` gives `agent`).
fn extract_agent_var_name(text: &str) -> Option<String> {
    text.lines().find_map(|line| {
        let (before, _) = line.split_once(" = Agent(")?;

        // Annotated target: a single dotted identifier in front of the last `:`.
        if let Some((target, _annotation)) = before.rsplit_once(':') {
            let target = target.trim();
            let is_identifier = !target.is_empty()
                && target
                    .chars()
                    .all(|c| c.is_alphanumeric() || c == '_' || c == '.');
            if is_identifier {
                return Some(target.to_string());
            }
        }

        // Take the last word (variable name)
        before.split_whitespace().last().map(str::to_string)
    })
}

/// Extract agent variable name from decorator like "@agent.tool"
///
/// Returns the receiver of the first decorator line whose final attribute is exactly `tool`
/// or `tool_plain`. Call arguments (`@agent.tool(retries=2)`) and trailing comments are
/// ignored. Dotted receivers are kept whole (`@self.agent.tool` gives `self.agent`), and
/// lookalikes such as `@x.toolkit` do not match.
fn extract_agent_from_decorator(text: &str) -> Option<String> {
    text.lines().find_map(|line| {
        let decorator = line.trim().strip_prefix('@')?;
        let target = decorator
            .split(&['(', '#'][..])
            .next()
            .unwrap_or(decorator)
            .trim();
        let (var_name, method) = target.rsplit_once('.')?;
        let var_name = var_name.trim();
        if matches!(method.trim(), "tool" | "tool_plain") && !var_name.is_empty() {
            Some(var_name.to_string())
        } else {
            None
        }
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build the `Some(String)` value the extractors return on success.
    fn owned(value: &str) -> Option<String> {
        Some(value.to_owned())
    }

    #[test]
    fn test_extract_function_name() {
        assert_eq!(
            extract_function_name("def my_tool(ctx: RunContext):"),
            owned("my_tool")
        );
    }

    #[test]
    fn test_extract_model_from_agent() {
        assert_eq!(
            extract_model_from_agent("agent = Agent('openai:gpt-4', deps_type=AgentDeps)"),
            owned("openai:gpt-4")
        );
    }

    #[test]
    fn test_extract_system_prompt() {
        assert_eq!(
            extract_system_prompt("Agent('gpt-4', system_prompt='You are helpful')"),
            owned("You are helpful")
        );
    }

    #[test]
    fn test_extract_deps_type() {
        assert_eq!(
            extract_deps_type("Agent('gpt-4', deps_type=AgentDependencies)"),
            owned("AgentDependencies")
        );
    }

    #[test]
    fn test_extract_docstring() {
        let source = "def my_func():\n    \"\"\"This is a docstring.\"\"\"\n    pass";
        assert_eq!(extract_docstring(source), owned("This is a docstring."));
    }

    #[test]
    fn test_extract_agent_var_name() {
        let cases = [
            ("agent = Agent('openai:gpt-4')", "agent"),
            (
                "my_agent = Agent('openai:gpt-4', system_prompt='hello')",
                "my_agent",
            ),
        ];
        for (line, expected) in cases.iter() {
            assert_eq!(extract_agent_var_name(line), owned(expected));
        }
    }

    #[test]
    fn test_extract_agent_from_decorator() {
        let cases = [
            ("@agent.tool\nasync def search_web(ctx):", "agent"),
            ("@my_agent.tool_plain\ndef format_text(text: str):", "my_agent"),
        ];
        for (definition, expected) in cases.iter() {
            assert_eq!(extract_agent_from_decorator(definition), owned(expected));
        }
    }
}