Skip to main content

agentshield/adapter/
mcp.rs

1use std::path::{Path, PathBuf};
2
3use crate::analysis::cross_file::apply_cross_file_sanitization;
4use crate::error::Result;
5use crate::ir::execution_surface::ExecutionSurface;
6use crate::ir::taint_builder::build_data_surface;
7use crate::ir::*;
8use crate::parser;
9
/// MCP Server adapter.
///
/// Detects MCP servers by looking for (first match wins):
/// - a `package.json` that mentions `@modelcontextprotocol/sdk` or `mcp-server`
/// - a `pyproject.toml` that mentions `mcp` anywhere in its text
/// - a top-level (non-recursive) `*.py` file containing `from mcp`,
///   `import mcp`, or `@server.tool`
/// - a `requirements.txt` line starting with `mcp`
///
/// NOTE: `mcp.json` / `mcp-config.json` manifests are not consulted during
/// detection.
pub struct McpAdapter;
17
18impl super::Adapter for McpAdapter {
19    fn framework(&self) -> Framework {
20        Framework::Mcp
21    }
22
23    fn detect(&self, root: &Path) -> bool {
24        // Check package.json for MCP SDK
25        let pkg_json = root.join("package.json");
26        if pkg_json.exists() {
27            if let Ok(content) = std::fs::read_to_string(&pkg_json) {
28                if content.contains("@modelcontextprotocol/sdk") || content.contains("mcp-server") {
29                    return true;
30                }
31            }
32        }
33
34        // Check pyproject.toml for mcp dependency
35        let pyproject = root.join("pyproject.toml");
36        if pyproject.exists() {
37            if let Ok(content) = std::fs::read_to_string(&pyproject) {
38                if content.contains("mcp") {
39                    return true;
40                }
41            }
42        }
43
44        // Check for Python files importing mcp
45        if let Ok(entries) = std::fs::read_dir(root) {
46            for entry in entries.flatten() {
47                let path = entry.path();
48                if path.extension().is_some_and(|e| e == "py") {
49                    if let Ok(content) = std::fs::read_to_string(&path) {
50                        if content.contains("from mcp")
51                            || content.contains("import mcp")
52                            || content.contains("@server.tool")
53                        {
54                            return true;
55                        }
56                    }
57                }
58            }
59        }
60
61        // Check requirements.txt
62        let requirements = root.join("requirements.txt");
63        if requirements.exists() {
64            if let Ok(content) = std::fs::read_to_string(&requirements) {
65                if content.lines().any(|l| l.trim().starts_with("mcp")) {
66                    return true;
67                }
68            }
69        }
70
71        false
72    }
73
74    fn load(&self, root: &Path, ignore_tests: bool) -> Result<Vec<ScanTarget>> {
75        let name = root
76            .file_name()
77            .map(|n| n.to_string_lossy().to_string())
78            .unwrap_or_else(|| "mcp-server".into());
79
80        let mut source_files = Vec::new();
81        let mut execution = ExecutionSurface::default();
82        let mut tools = Vec::new();
83
84        // Collect source files
85        collect_source_files(root, ignore_tests, &mut source_files)?;
86
87        // Phase 1: Parse each source file, collecting results for cross-file analysis.
88        let mut parsed_files: Vec<(PathBuf, parser::ParsedFile)> = Vec::new();
89        for sf in &source_files {
90            if let Some(parser) = parser::parser_for_language(sf.language) {
91                if let Ok(parsed) = parser.parse_file(&sf.path, &sf.content) {
92                    parsed_files.push((sf.path.clone(), parsed));
93                }
94            }
95        }
96
97        // Phase 2: Cross-file sanitizer-aware analysis — downgrade operations
98        // in functions that are only called with sanitized arguments.
99        apply_cross_file_sanitization(&mut parsed_files);
100
101        // Phase 3: Merge parsed results into execution surface.
102        for (_, parsed) in parsed_files {
103            execution.commands.extend(parsed.commands);
104            execution.file_operations.extend(parsed.file_operations);
105            execution
106                .network_operations
107                .extend(parsed.network_operations);
108            execution.env_accesses.extend(parsed.env_accesses);
109            execution.dynamic_exec.extend(parsed.dynamic_exec);
110        }
111
112        // Parse tool definitions from JSON if available
113        let tools_json = root.join("tools.json");
114        if tools_json.exists() {
115            if let Ok(content) = std::fs::read_to_string(&tools_json) {
116                if let Ok(value) = serde_json::from_str::<serde_json::Value>(&content) {
117                    tools = parser::json_schema::parse_tools_from_json(&value);
118                }
119            }
120        }
121
122        // Parse dependencies
123        let dependencies = parse_dependencies(root);
124
125        // Parse provenance from package.json or pyproject.toml
126        let provenance = parse_provenance(root);
127
128        let data = build_data_surface(&tools, &execution);
129
130        Ok(vec![ScanTarget {
131            name,
132            framework: Framework::Mcp,
133            root_path: root.to_path_buf(),
134            tools,
135            execution,
136            data,
137            dependencies,
138            provenance,
139            source_files,
140        }])
141    }
142}
143
/// Return `true` when `path` names a test file or sits inside a test directory.
///
/// The decision is purely name-based (the file system is never touched):
/// - any path component equal to `test`, `tests`, `__tests__` or `__pycache__`;
/// - the config files `conftest.py`, `pytest.ini`, `setup.cfg`,
///   `jest.config.*` and `vitest.config.*`;
/// - pytest-style modules `test_*.py`;
/// - `*.test.{ts,js,tsx,jsx,py}` and `*.spec.{ts,js,tsx,jsx}`.
///
/// Every component of `path` is inspected, so pass a path relative to the
/// project root if directories above the project must not count.
pub fn is_test_file(path: &Path) -> bool {
    const TEST_DIRS: [&str; 4] = ["test", "tests", "__tests__", "__pycache__"];
    const CONFIG_FILES: [&str; 3] = ["conftest.py", "pytest.ini", "setup.cfg"];
    const CONFIG_PREFIXES: [&str; 2] = ["jest.config.", "vitest.config."];
    const TEST_SUFFIXES: [&str; 9] = [
        ".test.ts",
        ".test.js",
        ".test.tsx",
        ".test.jsx",
        ".test.py",
        ".spec.ts",
        ".spec.js",
        ".spec.tsx",
        ".spec.jsx",
    ];

    let under_test_dir = path.components().any(|component| match component {
        std::path::Component::Normal(segment) => {
            let segment = segment.to_string_lossy();
            TEST_DIRS.iter().any(|dir| segment == *dir)
        }
        _ => false,
    });
    if under_test_dir {
        return true;
    }

    let Some(raw_name) = path.file_name() else {
        return false;
    };
    let lossy = raw_name.to_string_lossy();
    let name: &str = &lossy;

    CONFIG_FILES.contains(&name)
        || CONFIG_PREFIXES.iter().any(|&prefix| name.starts_with(prefix))
        || (name.starts_with("test_") && name.ends_with(".py"))
        || TEST_SUFFIXES.iter().any(|&suffix| name.ends_with(suffix))
}
203
204pub(super) fn collect_source_files(
205    root: &Path,
206    ignore_tests: bool,
207    files: &mut Vec<SourceFile>,
208) -> Result<()> {
209    let walker = ignore::WalkBuilder::new(root)
210        .hidden(true)
211        .git_ignore(true)
212        .max_depth(Some(5))
213        .build();
214
215    for entry in walker.flatten() {
216        let path = entry.path();
217        if !path.is_file() {
218            continue;
219        }
220
221        if ignore_tests && is_test_file(path) {
222            continue;
223        }
224
225        let ext = path
226            .extension()
227            .map(|e| e.to_string_lossy().to_string())
228            .unwrap_or_default();
229        let lang = Language::from_extension(&ext);
230
231        if matches!(lang, Language::Unknown) {
232            continue;
233        }
234
235        // Skip files larger than 1MB
236        let metadata = std::fs::metadata(path)?;
237        if metadata.len() > 1_048_576 {
238            continue;
239        }
240
241        if let Ok(content) = std::fs::read_to_string(path) {
242            let hash = format!(
243                "{:x}",
244                sha2::Digest::finalize(sha2::Sha256::new().chain_update(content.as_bytes()))
245            );
246            files.push(SourceFile {
247                path: path.to_path_buf(),
248                language: lang,
249                size_bytes: metadata.len(),
250                content_hash: hash,
251                content,
252            });
253        }
254    }
255
256    Ok(())
257}
258
259pub(super) fn parse_dependencies(root: &Path) -> dependency_surface::DependencySurface {
260    use crate::ir::dependency_surface::*;
261    let mut surface = DependencySurface::default();
262
263    // Parse requirements.txt as a dependency manifest (NOT a lockfile)
264    let req_file = root.join("requirements.txt");
265    if req_file.exists() {
266        if let Ok(content) = std::fs::read_to_string(&req_file) {
267            for (idx, line) in content.lines().enumerate() {
268                let line = line.trim();
269                if line.is_empty() || line.starts_with('#') || line.starts_with('-') {
270                    continue;
271                }
272                let (name, version) = if let Some(pos) = line.find("==") {
273                    (
274                        line[..pos].trim().to_string(),
275                        Some(line[pos + 2..].trim().to_string()),
276                    )
277                } else if let Some(pos) = line.find(">=") {
278                    (
279                        line[..pos].trim().to_string(),
280                        Some(line[pos..].trim().to_string()),
281                    )
282                } else {
283                    (line.to_string(), None)
284                };
285
286                surface.dependencies.push(Dependency {
287                    name,
288                    version_constraint: version,
289                    locked_version: None,
290                    locked_hash: None,
291                    registry: "pypi".into(),
292                    is_dev: false,
293                    location: Some(SourceLocation {
294                        file: req_file.clone(),
295                        line: idx + 1,
296                        column: 0,
297                        end_line: None,
298                        end_column: None,
299                    }),
300                });
301            }
302        }
303    }
304
305    // Check for actual Python lockfiles
306    for (filename, format) in [
307        ("Pipfile.lock", LockfileFormat::PipenvLock),
308        ("poetry.lock", LockfileFormat::PoetryLock),
309        ("uv.lock", LockfileFormat::UvLock),
310    ] {
311        let lock_path = root.join(filename);
312        if lock_path.exists() {
313            surface.lockfile = Some(LockfileInfo {
314                path: lock_path,
315                format,
316                all_pinned: true,
317                all_hashed: false,
318            });
319            break;
320        }
321    }
322
323    // Parse package.json dependencies
324    let pkg_json = root.join("package.json");
325    if pkg_json.exists() {
326        if let Ok(content) = std::fs::read_to_string(&pkg_json) {
327            if let Ok(value) = serde_json::from_str::<serde_json::Value>(&content) {
328                for (key, is_dev) in [("dependencies", false), ("devDependencies", true)] {
329                    if let Some(deps) = value.get(key).and_then(|v| v.as_object()) {
330                        for (name, version) in deps {
331                            let line = find_json_key_line(&content, name);
332                            surface.dependencies.push(Dependency {
333                                name: name.clone(),
334                                version_constraint: version.as_str().map(|s| s.to_string()),
335                                locked_version: None,
336                                locked_hash: None,
337                                registry: "npm".into(),
338                                is_dev,
339                                location: Some(SourceLocation {
340                                    file: pkg_json.clone(),
341                                    line,
342                                    column: 0,
343                                    end_line: None,
344                                    end_column: None,
345                                }),
346                            });
347                        }
348                    }
349                }
350            }
351        }
352
353        // Check for lockfile
354        let lock = root.join("package-lock.json");
355        if lock.exists() {
356            surface.lockfile = Some(LockfileInfo {
357                path: lock,
358                format: dependency_surface::LockfileFormat::NpmLock,
359                all_pinned: true,
360                all_hashed: false,
361            });
362        }
363    }
364
365    surface
366}
367
/// Return the 1-based number of the first line of `content` that contains
/// `key` wrapped in double quotes (e.g. `"react"`).
///
/// The match is a plain substring search, so a quoted value elsewhere in the
/// document can also match. When nothing matches, line 1 is returned.
fn find_json_key_line(content: &str, key: &str) -> usize {
    let quoted = format!("\"{key}\"");
    content
        .lines()
        .position(|row| row.contains(quoted.as_str()))
        .map_or(1, |zero_based| zero_based + 1)
}
379
380pub(super) fn parse_provenance(root: &Path) -> provenance_surface::ProvenanceSurface {
381    let mut prov = provenance_surface::ProvenanceSurface::default();
382
383    // From package.json
384    let pkg_json = root.join("package.json");
385    if pkg_json.exists() {
386        if let Ok(content) = std::fs::read_to_string(&pkg_json) {
387            if let Ok(value) = serde_json::from_str::<serde_json::Value>(&content) {
388                prov.author = value
389                    .get("author")
390                    .and_then(|v| v.as_str())
391                    .map(|s| s.to_string());
392                prov.repository = value
393                    .get("repository")
394                    .and_then(|v| v.get("url").or(Some(v)))
395                    .and_then(|v| v.as_str())
396                    .map(|s| s.to_string());
397                prov.license = value
398                    .get("license")
399                    .and_then(|v| v.as_str())
400                    .map(|s| s.to_string());
401            }
402        }
403    }
404
405    // From pyproject.toml
406    let pyproject = root.join("pyproject.toml");
407    if pyproject.exists() {
408        if let Ok(content) = std::fs::read_to_string(&pyproject) {
409            if let Ok(value) = content.parse::<toml::Value>() {
410                if let Some(project) = value.get("project") {
411                    prov.license = project
412                        .get("license")
413                        .and_then(|v| v.get("text").or(Some(v)))
414                        .and_then(|v| v.as_str())
415                        .map(|s| s.to_string());
416                    if let Some(authors) = project.get("authors").and_then(|v| v.as_array()) {
417                        if let Some(first) = authors.first() {
418                            prov.author = first
419                                .get("name")
420                                .and_then(|v| v.as_str())
421                                .map(|s| s.to_string());
422                        }
423                    }
424                }
425                if let Some(urls) = value.get("project").and_then(|p| p.get("urls")) {
426                    prov.repository = urls
427                        .get("Repository")
428                        .or(urls.get("repository"))
429                        .and_then(|v| v.as_str())
430                        .map(|s| s.to_string());
431                }
432            }
433        }
434    }
435
436    prov
437}
438
439use sha2::Digest;