//! LangGraph framework extractor.
//!
//! Part of `raxit-core` 0.1.1, the core security scanning engine for AI agent
//! applications.

use super::ExtractedAssets;
use crate::error::Result;
use crate::schema::{Agent, SourceLocation, Tool};
use std::path::Path;

/// Extract assets from a LangGraph file.
///
/// The file is parsed with `crate::ast::parse_python_file` and the resulting
/// nodes are scanned in two passes:
///
/// 1. Every `call` node whose text contains `add_node` and that
///    `extract_add_node_call` can parse becomes an [`Agent`] named
///    `StateGraph node: <name>`, with the handler expression recorded in
///    `location.function`.
/// 2. Every `function_definition` node whose text mentions `AgentState` or
///    `state:` becomes a [`Tool`] with `tool_type` set to `"state_handler"`.
///    This is a textual heuristic; it is not cross-checked against the
///    handlers registered through `add_node`.
///
/// Edges (`add_edge` / `add_conditional_edges`) are not extracted yet.
///
/// # Errors
///
/// Returns an error if the file cannot be read as UTF-8 text or if
/// `crate::ast::parse_python_file` fails.
pub fn extract(path: &Path) -> Result<ExtractedAssets> {
    // The contents are not used below, but the read is kept so that I/O and
    // UTF-8 errors are still reported through `?` before parsing starts.
    let _source_code = std::fs::read_to_string(path)?;

    let nodes = crate::ast::parse_python_file(path)?;

    // Converted once and cloned into each asset's location.
    let file = path.to_string_lossy().to_string();
    let mut assets = ExtractedAssets::default();

    // First pass: one agent per `workflow.add_node("name", handler)` call.
    for node in &nodes {
        if node.kind != "call" || !node.text.contains("add_node") {
            continue;
        }
        if let Some((node_name, function_name)) = extract_add_node_call(&node.text) {
            assets.agents.push(Agent {
                id: format!("node_{}", node.start_line),
                name: format!("StateGraph node: {node_name}"),
                location: SourceLocation {
                    file: file.clone(),
                    line: node.start_line,
                    end_line: Some(node.end_line),
                    function: Some(function_name),
                },
                model_id: None,
                tool_ids: Vec::new(),
                memory_id: None,
                system_prompt: None,
                result_type: None,
                deps_type: Some("AgentState".to_string()),
            });
        }
    }

    // Second pass: one tool per function that looks like a state handler,
    // judged only by the presence of a state parameter / type hint in its text.
    for node in &nodes {
        if node.kind != "function_definition" {
            continue;
        }
        let mentions_state = node.text.contains("AgentState") || node.text.contains("state:");
        if !mentions_state {
            continue;
        }
        if let Some(func_name) = extract_function_name(&node.text) {
            assets.tools.push(Tool {
                id: format!("tool_{}", node.start_line),
                name: func_name,
                location: SourceLocation {
                    file: file.clone(),
                    line: node.start_line,
                    end_line: Some(node.end_line),
                    function: None,
                },
                description: extract_docstring(&node.text),
                parameters: None,
                requires_context: true,
                tool_type: "state_handler".to_string(),
                data_flows: Vec::new(),
            });
        }
    }

    // TODO: Extract `add_edge` / `add_conditional_edges` calls into a
    // workflow graph structure. Nothing is recorded for them at the moment.

    Ok(assets)
}

/// Parse the first `add_node("name", handler)` call found in `text`.
///
/// Returns `(node_name, handler_source)` where `node_name` is the content of
/// the first string literal after `add_node(` and `handler_source` is the
/// text of the second argument, trimmed.
///
/// Returns `None` when `text` has no `add_node(` call, when no quoted node
/// name is present, when there is no second argument, or when the second
/// argument is empty or unterminated.
///
/// Parsing is textual: brackets inside the handler expression are balanced,
/// but string literals inside it are not interpreted.
fn extract_add_node_call(text: &str) -> Option<(String, String)> {
    const CALL_PREFIX: &str = "add_node(";

    let call_start = text.find(CALL_PREFIX)?;
    let args = &text[call_start + CALL_PREFIX.len()..];

    // Use whichever quote character appears first, so a single-quoted name is
    // not skipped in favour of a double quote that occurs later in the call.
    let quote_start = args.find(|c: char| c == '"' || c == '\'')?;
    let quote = if args[quote_start..].starts_with('"') {
        '"'
    } else {
        '\''
    };

    // Both quote characters are a single byte in UTF-8, so `+ 1` is a valid
    // byte offset even when non-ASCII text precedes the quote.
    let after_quote = &args[quote_start + 1..];
    let name_len = after_quote.find(quote)?;
    let node_name = &after_quote[..name_len];

    // Look for the separating comma only after the closing quote, so commas
    // inside the node name are not mistaken for the argument separator.
    let after_name = &after_quote[name_len + 1..];
    let comma = after_name.find(',')?;
    let handler_src = after_name[comma + 1..].trim_start();

    // The handler ends at the first top-level `,` or `)`. Tracking bracket
    // depth keeps expressions such as `ToolNode(tools)` intact.
    let mut depth = 0usize;
    let mut handler_end = None;
    for (idx, ch) in handler_src.char_indices() {
        match ch {
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' if depth > 0 => depth -= 1,
            ',' | ')' if depth == 0 => {
                handler_end = Some(idx);
                break;
            }
            _ => {}
        }
    }

    let function_name = handler_src[..handler_end?].trim();
    if function_name.is_empty() {
        return None;
    }

    Some((node_name.to_string(), function_name.to_string()))
}

/// Return the identifier between the first `def ` and the next `(` in `text`.
///
/// The result is trimmed of surrounding whitespace. `None` is returned when
/// `text` has no `def ` keyword or no opening parenthesis after it.
fn extract_function_name(text: &str) -> Option<String> {
    let (_, after_keyword) = text.split_once("def ")?;
    let (raw_name, _) = after_keyword.split_once('(')?;
    Some(raw_name.trim().to_string())
}

/// Return the contents of the first triple-quoted string in `text`, trimmed.
///
/// Both `"""` and `'''` delimiters are recognised; whichever opens first in
/// `text` is used, and the string ends at the next occurrence of that same
/// delimiter. Returns `None` when no triple-quoted string is present or when
/// the first one is never closed.
///
/// The search is textual: the string is not checked to be the first statement
/// of the function body.
fn extract_docstring(text: &str) -> Option<String> {
    const DELIMITERS: [&str; 2] = ["\"\"\"", "'''"];

    // Pick the delimiter that opens earliest, so a `'''` docstring is not
    // skipped in favour of a `"""` literal further down the function body.
    let (start, delimiter) = DELIMITERS
        .iter()
        .filter_map(|&delim| text.find(delim).map(|pos| (pos, delim)))
        .min_by_key(|&(pos, _)| pos)?;

    let body = &text[start + delimiter.len()..];
    let end = body.find(delimiter)?;
    Some(body[..end].trim().to_string())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A double-quoted node name and a bare handler identifier are both recovered.
    #[test]
    fn test_extract_add_node_call() {
        let expected = Some((String::from("input"), String::from("input_node")));
        let parsed = extract_add_node_call(r#"workflow.add_node("input", input_node)"#);
        assert_eq!(parsed, expected);
    }

    /// The name is taken from between `def ` and the opening parenthesis.
    #[test]
    fn test_extract_function_name() {
        let signature = "def input_node(state: AgentState) -> AgentState:";
        assert_eq!(
            extract_function_name(signature).as_deref(),
            Some("input_node")
        );
    }

    /// The text inside the triple quotes is returned without the delimiters.
    #[test]
    fn test_extract_docstring() {
        let source = "\ndef my_function():\n    \"\"\"This is a docstring.\"\"\"\n    pass\n";
        assert_eq!(
            extract_docstring(source).as_deref(),
            Some("This is a docstring.")
        );
    }
}