Skip to main content

agentshield/adapter/
langchain.rs

1use std::path::{Path, PathBuf};
2
3use crate::analysis::cross_file::apply_cross_file_sanitization;
4use crate::error::Result;
5use crate::ir::taint_builder::build_data_surface;
6use crate::ir::*;
7use crate::parser;
8
/// LangChain / LangGraph framework adapter.
///
/// A project root is detected as LangChain when any of these holds (checked in order):
/// - `pyproject.toml` mentions `langchain` or `langgraph` anywhere in its text
/// - `requirements.txt` has a requirement line starting with `langchain` or `langgraph`
/// - a `langgraph.json` configuration file exists
/// - a `.py` file directly in the root or in `src/` (non-recursive) contains
///   `from langchain`, `import langchain`, `from langgraph`, or `import langgraph`
///   (so `from langchain_core ...` and other `langchain*` packages also match)
///
/// Loading only considers Python source files.
pub struct LangChainAdapter;
17
18impl super::Adapter for LangChainAdapter {
19    fn framework(&self) -> Framework {
20        Framework::LangChain
21    }
22
23    fn detect(&self, root: &Path) -> bool {
24        // Check pyproject.toml for langchain dependency
25        let pyproject = root.join("pyproject.toml");
26        if pyproject.exists() {
27            if let Ok(content) = std::fs::read_to_string(&pyproject) {
28                if content.contains("langchain") || content.contains("langgraph") {
29                    return true;
30                }
31            }
32        }
33
34        // Check requirements.txt for langchain/langgraph
35        let requirements = root.join("requirements.txt");
36        if requirements.exists() {
37            if let Ok(content) = std::fs::read_to_string(&requirements) {
38                if content.lines().any(|l| {
39                    let trimmed = l.trim();
40                    trimmed.starts_with("langchain") || trimmed.starts_with("langgraph")
41                }) {
42                    return true;
43                }
44            }
45        }
46
47        // Check for langgraph.json configuration file
48        if root.join("langgraph.json").exists() {
49            return true;
50        }
51
52        // Check Python files for langchain imports (top-level only)
53        if let Ok(entries) = std::fs::read_dir(root) {
54            for entry in entries.flatten() {
55                let path = entry.path();
56                if path.extension().is_some_and(|e| e == "py") {
57                    if let Ok(content) = std::fs::read_to_string(&path) {
58                        if content.contains("from langchain")
59                            || content.contains("import langchain")
60                            || content.contains("from langgraph")
61                            || content.contains("import langgraph")
62                        {
63                            return true;
64                        }
65                    }
66                }
67            }
68        }
69
70        // Also check src/ directory (common LangChain layout)
71        let src_dir = root.join("src");
72        if src_dir.is_dir() {
73            if let Ok(entries) = std::fs::read_dir(&src_dir) {
74                for entry in entries.flatten() {
75                    let path = entry.path();
76                    if path.extension().is_some_and(|e| e == "py") {
77                        if let Ok(content) = std::fs::read_to_string(&path) {
78                            if content.contains("from langchain")
79                                || content.contains("import langchain")
80                                || content.contains("from langgraph")
81                                || content.contains("import langgraph")
82                            {
83                                return true;
84                            }
85                        }
86                    }
87                }
88            }
89        }
90
91        false
92    }
93
94    fn load(&self, root: &Path, ignore_tests: bool) -> Result<Vec<ScanTarget>> {
95        let name = root
96            .file_name()
97            .map(|n| n.to_string_lossy().to_string())
98            .unwrap_or_else(|| "langchain-project".into());
99
100        let mut source_files = Vec::new();
101        let mut execution = execution_surface::ExecutionSurface::default();
102
103        // Phase 0: Collect source files (reuses MCP adapter's walker)
104        super::mcp::collect_source_files(root, ignore_tests, &mut source_files)?;
105
106        // Filter to Python-only (LangChain is a Python framework)
107        source_files.retain(|sf| matches!(sf.language, Language::Python));
108
109        // Phase 1: Parse each Python file
110        let mut parsed_files: Vec<(PathBuf, parser::ParsedFile)> = Vec::new();
111        for sf in &source_files {
112            if let Some(parser) = parser::parser_for_language(sf.language) {
113                if let Ok(parsed) = parser.parse_file(&sf.path, &sf.content) {
114                    parsed_files.push((sf.path.clone(), parsed));
115                }
116            }
117        }
118
119        // Phase 2: Cross-file sanitizer-aware analysis
120        apply_cross_file_sanitization(&mut parsed_files);
121
122        // Phase 3: Merge into execution surface
123        for (_, parsed) in parsed_files {
124            execution.commands.extend(parsed.commands);
125            execution.file_operations.extend(parsed.file_operations);
126            execution
127                .network_operations
128                .extend(parsed.network_operations);
129            execution.env_accesses.extend(parsed.env_accesses);
130            execution.dynamic_exec.extend(parsed.dynamic_exec);
131        }
132
133        // Parse dependencies from pyproject.toml / requirements.txt
134        let dependencies = super::mcp::parse_dependencies(root);
135
136        // Parse provenance from pyproject.toml
137        let provenance = super::mcp::parse_provenance(root);
138
139        let tools = vec![];
140        let data = build_data_surface(&tools, &execution);
141
142        Ok(vec![ScanTarget {
143            name,
144            framework: Framework::LangChain,
145            root_path: root.to_path_buf(),
146            tools,
147            execution,
148            data,
149            dependencies,
150            provenance,
151            source_files,
152        }])
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapter::Adapter;

    /// Path to the checked-in LangChain fixture project.
    fn langchain_fixture() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/langchain_project")
    }

    /// Creates an empty scratch directory under the system temp dir.
    ///
    /// The name combines `label` with the process id. Tests in one binary share a
    /// process, so every test must pass a distinct `label`.
    fn scratch_dir(label: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "agentshield-langchain-test-{label}-{}",
            std::process::id()
        ));
        // Clear leftovers from an earlier aborted run; absence is the normal case.
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("failed to create scratch directory");
        dir
    }

    /// Runs `detect` on `dir`, then deletes `dir` before returning the verdict,
    /// so a failing assertion in the caller does not leak the scratch directory.
    fn detect_and_cleanup(dir: PathBuf) -> bool {
        let detected = LangChainAdapter.detect(&dir);
        let _ = std::fs::remove_dir_all(&dir);
        detected
    }

    #[test]
    fn test_detect_langchain_via_pyproject() {
        let adapter = LangChainAdapter;
        assert!(adapter.detect(&langchain_fixture()));
    }

    #[test]
    fn test_detect_langchain_via_langgraph_json() {
        // The shared fixture also has a pyproject.toml, which would mask a broken
        // langgraph.json check, so use a directory containing only langgraph.json.
        let dir = scratch_dir("langgraph-json");
        std::fs::write(dir.join("langgraph.json"), "{}").unwrap();
        assert!(detect_and_cleanup(dir));
    }

    #[test]
    fn test_detect_langchain_via_requirements_txt() {
        let dir = scratch_dir("requirements");
        std::fs::write(
            dir.join("requirements.txt"),
            "requests==2.31.0\nlangchain-core>=0.2\n",
        )
        .unwrap();
        assert!(detect_and_cleanup(dir));
    }

    #[test]
    fn test_detect_langchain_via_src_import() {
        let dir = scratch_dir("src-import");
        std::fs::create_dir_all(dir.join("src")).unwrap();
        std::fs::write(
            dir.join("src").join("agent.py"),
            "from langgraph.graph import StateGraph\n",
        )
        .unwrap();
        assert!(detect_and_cleanup(dir));
    }

    #[test]
    fn test_detect_empty_dir_is_not_langchain() {
        let dir = scratch_dir("empty");
        assert!(!detect_and_cleanup(dir));
    }

    #[test]
    fn test_detect_non_langchain_project() {
        let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/mcp_servers/safe_calculator");
        let adapter = LangChainAdapter;
        assert!(!adapter.detect(&dir));
    }

    #[test]
    fn test_load_langchain_finds_cmd_injection() {
        let adapter = LangChainAdapter;
        let targets = adapter.load(&langchain_fixture(), false).unwrap();
        assert_eq!(targets.len(), 1);

        let target = &targets[0];
        assert_eq!(target.framework, Framework::LangChain);
        assert_eq!(target.name, "langchain_project");

        // Should find command injection in shell_tool.py
        assert!(
            !target.execution.commands.is_empty(),
            "expected command execution findings from shell_tool.py"
        );
        // Should find tainted command args
        assert!(
            target
                .execution
                .commands
                .iter()
                .any(|c| c.command_arg.is_tainted()),
            "expected tainted command source from subprocess.run with user input"
        );
    }

    #[test]
    fn test_load_langchain_finds_ssrf() {
        let adapter = LangChainAdapter;
        let targets = adapter.load(&langchain_fixture(), false).unwrap();
        let target = &targets[0];

        // Should find network operations in fetch_tool.py
        assert!(
            !target.execution.network_operations.is_empty(),
            "expected network operation findings from fetch_tool.py"
        );
    }

    #[test]
    fn test_load_langchain_only_python_files() {
        let adapter = LangChainAdapter;
        let targets = adapter.load(&langchain_fixture(), false).unwrap();
        let target = &targets[0];

        // All source files should be Python
        for sf in &target.source_files {
            assert_eq!(
                sf.language,
                Language::Python,
                "non-Python file found: {:?}",
                sf.path
            );
        }
    }
}