// agentshield/adapter/langchain.rs

1use std::path::{Path, PathBuf};
2
3use crate::analysis::cross_file::apply_cross_file_sanitization;
4use crate::config::ScanPathFilter;
5use crate::error::Result;
6use crate::ir::taint_builder::build_data_surface;
7use crate::ir::*;
8use crate::parser;
9
/// LangChain framework adapter.
///
/// A directory is treated as a LangChain / LangGraph project when any of the
/// following holds:
/// - `pyproject.toml` mentions `langchain` or `langgraph` anywhere in its text
/// - `requirements.txt` has a line that starts with `langchain` or `langgraph`
/// - a `langgraph.json` configuration file exists in the root
/// - a `.py` file directly inside the root or directly inside `src/` contains
///   `from langchain`, `import langchain`, `from langgraph` or
///   `import langgraph` (substring match, so `langchain_core` and similar
///   packages are covered as well)
///
/// Sub-directories other than `src/` are not inspected during detection.
pub struct LangChainAdapter;
18
19impl super::Adapter for LangChainAdapter {
20    fn framework(&self) -> Framework {
21        Framework::LangChain
22    }
23
24    fn detect(&self, root: &Path) -> bool {
25        // Check pyproject.toml for langchain dependency
26        let pyproject = root.join("pyproject.toml");
27        if pyproject.exists() {
28            if let Ok(content) = std::fs::read_to_string(&pyproject) {
29                if content.contains("langchain") || content.contains("langgraph") {
30                    return true;
31                }
32            }
33        }
34
35        // Check requirements.txt for langchain/langgraph
36        let requirements = root.join("requirements.txt");
37        if requirements.exists() {
38            if let Ok(content) = std::fs::read_to_string(&requirements) {
39                if content.lines().any(|l| {
40                    let trimmed = l.trim();
41                    trimmed.starts_with("langchain") || trimmed.starts_with("langgraph")
42                }) {
43                    return true;
44                }
45            }
46        }
47
48        // Check for langgraph.json configuration file
49        if root.join("langgraph.json").exists() {
50            return true;
51        }
52
53        // Check Python files for langchain imports (top-level only)
54        if let Ok(entries) = std::fs::read_dir(root) {
55            for entry in entries.flatten() {
56                let path = entry.path();
57                if path.extension().is_some_and(|e| e == "py") {
58                    if let Ok(content) = std::fs::read_to_string(&path) {
59                        if content.contains("from langchain")
60                            || content.contains("import langchain")
61                            || content.contains("from langgraph")
62                            || content.contains("import langgraph")
63                        {
64                            return true;
65                        }
66                    }
67                }
68            }
69        }
70
71        // Also check src/ directory (common LangChain layout)
72        let src_dir = root.join("src");
73        if src_dir.is_dir() {
74            if let Ok(entries) = std::fs::read_dir(&src_dir) {
75                for entry in entries.flatten() {
76                    let path = entry.path();
77                    if path.extension().is_some_and(|e| e == "py") {
78                        if let Ok(content) = std::fs::read_to_string(&path) {
79                            if content.contains("from langchain")
80                                || content.contains("import langchain")
81                                || content.contains("from langgraph")
82                                || content.contains("import langgraph")
83                            {
84                                return true;
85                            }
86                        }
87                    }
88                }
89            }
90        }
91
92        false
93    }
94
95    fn load(&self, root: &Path, ignore_tests: bool) -> Result<Vec<ScanTarget>> {
96        let filter = ScanPathFilter::for_ignore_tests(ignore_tests);
97        self.load_with_filter(root, &filter)
98    }
99
100    fn load_with_filter(&self, root: &Path, filter: &ScanPathFilter) -> Result<Vec<ScanTarget>> {
101        let name = root
102            .file_name()
103            .map(|n| n.to_string_lossy().to_string())
104            .unwrap_or_else(|| "langchain-project".into());
105
106        let mut source_files = Vec::new();
107        let mut execution = execution_surface::ExecutionSurface::default();
108
109        // Phase 0: Collect source files (reuses MCP adapter's walker)
110        super::mcp::collect_source_files_with_filter(root, filter, &mut source_files)?;
111
112        // Filter to Python-only (LangChain is a Python framework)
113        source_files.retain(|sf| matches!(sf.language, Language::Python));
114
115        // Phase 1: Parse each Python file
116        let mut parsed_files: Vec<(PathBuf, parser::ParsedFile)> = Vec::new();
117        for sf in &source_files {
118            if let Some(parser) = parser::parser_for_language(sf.language) {
119                if let Ok(parsed) = parser.parse_file(&sf.path, &sf.content) {
120                    parsed_files.push((sf.path.clone(), parsed));
121                }
122            }
123        }
124
125        // Phase 2: Cross-file sanitizer-aware analysis
126        apply_cross_file_sanitization(&mut parsed_files);
127
128        // Phase 3: Merge into execution surface
129        for (_, parsed) in parsed_files {
130            execution.commands.extend(parsed.commands);
131            execution.file_operations.extend(parsed.file_operations);
132            execution
133                .network_operations
134                .extend(parsed.network_operations);
135            execution.env_accesses.extend(parsed.env_accesses);
136            execution.dynamic_exec.extend(parsed.dynamic_exec);
137        }
138
139        // Parse dependencies from pyproject.toml / requirements.txt
140        let dependencies = super::mcp::parse_dependencies(root, filter);
141
142        // Parse provenance from pyproject.toml
143        let provenance = super::mcp::parse_provenance(root, filter);
144
145        let tools = vec![];
146        let data = build_data_surface(&tools, &execution);
147
148        Ok(vec![ScanTarget {
149            name,
150            framework: Framework::LangChain,
151            root_path: root.to_path_buf(),
152            tools,
153            execution,
154            data,
155            dependencies,
156            provenance,
157            source_files,
158        }])
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapter::Adapter;

    /// Path of the shared LangChain fixture project used by most tests.
    fn langchain_fixture() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/langchain_project")
    }

    /// Loads the fixture project without ignoring tests.
    fn load_fixture_targets() -> Vec<ScanTarget> {
        LangChainAdapter
            .load(&langchain_fixture(), false)
            .expect("fixture project should load")
    }

    #[test]
    fn test_detect_langchain_via_pyproject() {
        let adapter = LangChainAdapter;
        assert!(adapter.detect(&langchain_fixture()));
    }

    #[test]
    fn test_detect_langchain_via_langgraph_json() {
        // Use a scratch directory holding nothing but `langgraph.json`, so
        // that no other detection rule can produce the positive result.
        let dir = std::env::temp_dir().join(format!(
            "agentshield-langgraph-json-only-{}",
            std::process::id()
        ));
        // Clear leftovers from an earlier aborted run; absence is fine.
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("create scratch project dir");
        std::fs::write(dir.join("langgraph.json"), "{}").expect("write langgraph.json");

        let detected = LangChainAdapter.detect(&dir);

        // Clean up before asserting so a failure does not leak the directory.
        let _ = std::fs::remove_dir_all(&dir);
        assert!(
            detected,
            "expected langgraph.json alone to trigger detection"
        );
    }

    #[test]
    fn test_detect_non_langchain_project() {
        let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/mcp_servers/safe_calculator");
        let adapter = LangChainAdapter;
        assert!(!adapter.detect(&dir));
    }

    #[test]
    fn test_load_langchain_finds_cmd_injection() {
        let targets = load_fixture_targets();
        assert_eq!(targets.len(), 1);

        let target = &targets[0];
        assert_eq!(target.framework, Framework::LangChain);
        assert_eq!(target.name, "langchain_project");

        // Should find command injection in shell_tool.py
        assert!(
            !target.execution.commands.is_empty(),
            "expected command execution findings from shell_tool.py"
        );
        // Should find tainted command args
        assert!(
            target
                .execution
                .commands
                .iter()
                .any(|c| c.command_arg.is_tainted()),
            "expected tainted command source from subprocess.run with user input"
        );
    }

    #[test]
    fn test_load_langchain_finds_ssrf() {
        let targets = load_fixture_targets();
        let target = &targets[0];

        // Should find network operations in fetch_tool.py
        assert!(
            !target.execution.network_operations.is_empty(),
            "expected network operation findings from fetch_tool.py"
        );
    }

    #[test]
    fn test_load_langchain_only_python_files() {
        let targets = load_fixture_targets();
        let target = &targets[0];

        // All source files should be Python
        for sf in &target.source_files {
            assert_eq!(
                sf.language,
                Language::Python,
                "non-Python file found: {:?}",
                sf.path
            );
        }
    }
}