Skip to main content

pmcp_code_mode/
javascript.rs

1//! JavaScript-specific validation for Code Mode (OpenAPI servers).
2//!
3//! This module validates JavaScript code generated by LLMs for REST API interactions.
4//! It uses SWC (Speedy Web Compiler) for production-grade parsing and enforces a safe
5//! subset of JavaScript that prevents malicious operations while enabling powerful
6//! API orchestration.
7//!
8//! ## Safe Subset
9//!
10//! Allowed:
11//! - async/await for API calls
12//! - api.get(), api.post(), api.put(), api.delete(), api.patch() calls
13//! - const/let variable declarations
14//! - Arrow functions for callbacks (including block bodies with nested callbacks)
15//! - Array methods: map, filter, reduce, find, some, every, slice
16//! - Object destructuring and spread
17//! - Template literals (string interpolation)
18//! - Bounded for...of loops (with .slice() limits)
19//! - if/else conditionals
20//! - try/catch for error handling
21//! - Logical operators (&&, ||)
22//!
23//! Blocked:
24//! - import/export statements
25//! - eval(), Function(), new Function()
26//! - while/do-while loops (unbounded)
27//! - Regular function declarations (only arrow functions)
28//! - new keyword (except specific built-ins)
29//! - this keyword
30//! - class declarations
31//! - Generators/iterators
32//! - with statement
33//! - delete operator
34//! - Prototype manipulation
35
36use crate::types::{
37    CodeLocation, CodeType, Complexity, SecurityAnalysis, SecurityIssue, SecurityIssueType,
38    ValidationError,
39};
40use std::collections::HashSet;
41use swc_common::{sync::Lrc, SourceMap, Span};
42use swc_ecma_ast::*;
43use swc_ecma_parser::{lexer::Lexer, Parser, StringInput, Syntax};
44use swc_ecma_visit::{Visit, VisitWith};
45
/// HTTP verbs that generated code may invoke through the `api` object.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HttpMethod {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `PUT`
    Put,
    /// `DELETE`
    Delete,
    /// `PATCH`
    Patch,
    /// `HEAD`
    Head,
    /// `OPTIONS`
    Options,
}

impl HttpMethod {
    /// Returns `true` for the verbs this module treats as read-only:
    /// `GET`, `HEAD` and `OPTIONS`.
    pub fn is_read_only(&self) -> bool {
        match self {
            HttpMethod::Get | HttpMethod::Head | HttpMethod::Options => true,
            HttpMethod::Post | HttpMethod::Put | HttpMethod::Delete | HttpMethod::Patch => false,
        }
    }

    /// Looks up a verb by name, ignoring letter case.
    ///
    /// Returns `None` when `s` is not one of the seven supported verbs.
    // Why: existing public API surface — renaming to satisfy `should_implement_trait`
    // would be a breaking change. Returning Option (vs Result for FromStr) is intentional
    // since unknown methods map to None, not a typed error.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        const BY_NAME: [(&str, HttpMethod); 7] = [
            ("get", HttpMethod::Get),
            ("post", HttpMethod::Post),
            ("put", HttpMethod::Put),
            ("delete", HttpMethod::Delete),
            ("patch", HttpMethod::Patch),
            ("head", HttpMethod::Head),
            ("options", HttpMethod::Options),
        ];

        BY_NAME
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, method)| method)
    }
}
85
86/// An API call extracted from the JavaScript code.
87#[derive(Debug, Clone)]
88pub struct ApiCall {
89    /// The HTTP method
90    pub method: HttpMethod,
91    /// The path template (may contain interpolations)
92    pub path: String,
93    /// Whether the path is dynamic (contains template expressions)
94    pub is_dynamic_path: bool,
95    /// Line number in the source
96    pub line: u32,
97    /// Column number in the source
98    pub column: u32,
99}
100
/// Output type a script declares through an `@returns` annotation.
#[derive(Clone, Debug, Default)]
pub struct OutputDeclaration {
    /// Set when an `@returns` annotation was located in the script.
    pub has_declaration: bool,

    /// Raw annotation type text, e.g.
    /// "{ users: Array<{ id: string, name: string }> }".
    pub type_string: Option<String>,

    /// Field names pulled out of the type text; used for field blocklist checks.
    pub declared_fields: HashSet<String>,

    /// Set when the type text contains a spread (`...`), which may leak fields.
    pub has_spread_risk: bool,
}
116
117/// Information extracted from parsed JavaScript code.
118#[derive(Debug, Clone, Default)]
119pub struct JavaScriptCodeInfo {
120    /// All API calls in the code
121    pub api_calls: Vec<ApiCall>,
122
123    /// Whether the code is read-only (only GET/HEAD/OPTIONS calls)
124    pub is_read_only: bool,
125
126    /// All endpoints accessed
127    pub endpoints_accessed: HashSet<String>,
128
129    /// All HTTP methods used
130    pub methods_used: HashSet<String>,
131
132    /// Whether the code uses async/await
133    pub uses_async: bool,
134
135    /// Variable names declared
136    pub variable_names: Vec<String>,
137
138    /// Maximum nesting depth
139    pub max_depth: usize,
140
141    /// Number of for...of loops
142    pub loop_count: usize,
143
144    /// Whether all loops are bounded (use .slice())
145    pub all_loops_bounded: bool,
146
147    /// Policy violations found during parsing
148    pub violations: Vec<SafetyViolation>,
149
150    /// Total number of statements
151    pub statement_count: usize,
152
153    /// Output declaration from @returns annotation
154    pub output_declaration: OutputDeclaration,
155
156    /// Whether the script contains spread operators that could leak fields
157    pub has_output_spread_risk: bool,
158}
159
160/// A safety violation found during JavaScript validation.
161#[derive(Debug, Clone)]
162pub struct SafetyViolation {
163    /// Type of violation
164    pub violation_type: SafetyViolationType,
165    /// Human-readable message
166    pub message: String,
167    /// Location in source
168    pub location: Option<CodeLocation>,
169}
170
/// Categories of constructs the JavaScript safety policy can report.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SafetyViolationType {
    /// An `import` or `export` statement.
    ImportExport,
    /// A call to `eval()` or `Function()`.
    DynamicCodeExecution,
    /// A `while` or `do-while` loop, which has no iteration bound.
    UnboundedLoop,
    /// A regular (non-arrow) function declaration.
    FunctionDeclaration,
    /// A `try`/`catch` statement.
    TryCatch,
    /// Use of `new` outside the permitted cases.
    NewKeyword,
    /// Use of `this`.
    ThisKeyword,
    /// A class declaration.
    ClassDeclaration,
    /// A generator function.
    Generator,
    /// A `with` statement.
    WithStatement,
    /// The `delete` operator.
    DeleteOperator,
    /// Manipulation of prototypes.
    PrototypeManipulation,
    /// A `for` loop whose iterable is not limited with `.slice()`.
    UnboundedForLoop,
    /// A call on `api` that is not a recognised method or operation.
    UnknownApiCall,
}
203
/// Validator for LLM-generated JavaScript in OpenAPI Code Mode.
pub struct JavaScriptValidator {
    /// Path fragments considered sensitive (e.g., "/admin", "/internal").
    sensitive_paths: Vec<String>,

    /// Upper bound on block nesting depth.
    max_depth: usize,

    /// Upper bound on the number of API calls.
    max_api_calls: usize,

    /// Upper bound on the number of loops.
    max_loops: usize,

    /// Upper bound on the number of statements.
    max_statements: usize,

    /// Permitted SDK operation names (camelCase). A non-empty set switches the
    /// validator into SDK mode: listed operations are accepted and any other
    /// `api.<name>()` call is reported as an `UnknownApiCall` violation.
    sdk_operations: HashSet<String>,
}
225
226impl Default for JavaScriptValidator {
227    fn default() -> Self {
228        Self {
229            sensitive_paths: vec![
230                "/admin".into(),
231                "/internal".into(),
232                "/debug".into(),
233                "/metrics".into(),
234                "/health".into(),
235            ],
236            max_depth: 10,
237            max_api_calls: 50,
238            max_loops: 10,
239            max_statements: 100,
240            sdk_operations: HashSet::new(),
241        }
242    }
243}
244
/// Returns `true` when `word` is a TypeScript/JSDoc type keyword or built-in
/// type name rather than a field name. The comparison is case-sensitive.
fn is_type_keyword(word: &str) -> bool {
    const TYPE_KEYWORDS: &[&str] = &[
        "string",
        "number",
        "boolean",
        "null",
        "undefined",
        "void",
        "any",
        "never",
        "object",
        "Array",
        "Promise",
        "Record",
        "Map",
        "Set",
        "Date",
        "type",
        "interface",
    ];

    TYPE_KEYWORDS.contains(&word)
}
268
269impl JavaScriptValidator {
270    /// Create a new validator with custom settings.
271    pub fn new(
272        sensitive_paths: Vec<String>,
273        max_depth: usize,
274        max_api_calls: usize,
275        max_loops: usize,
276        max_statements: usize,
277    ) -> Self {
278        Self {
279            sensitive_paths,
280            max_depth,
281            max_api_calls,
282            max_loops,
283            max_statements,
284            sdk_operations: HashSet::new(),
285        }
286    }
287
288    /// Set the allowed SDK operation names. When non-empty, the validator operates in SDK mode:
289    /// known SDK operations are accepted, unknown ones generate `UnknownApiCall` violations.
290    pub fn with_sdk_operations(mut self, operations: HashSet<String>) -> Self {
291        self.sdk_operations = operations;
292        self
293    }
294
295    /// Parse @returns annotation from code comments.
296    ///
297    /// Supports formats:
298    /// - `// @returns { type }` (single-line comment)
299    /// - `/// @returns { type }` (triple-slash comment)
300    /// - `/** @returns { type } */` (JSDoc comment)
301    fn parse_returns_annotation(code: &str) -> OutputDeclaration {
302        let mut declaration = OutputDeclaration::default();
303
304        // Try to find @returns annotation in comments
305        for line in code.lines() {
306            let trimmed = line.trim();
307
308            // Check for triple-slash comment: /// @returns { ... }
309            // Must check before double-slash since /// starts with //
310            if let Some(rest) = trimmed.strip_prefix("///") {
311                if let Some(returns_content) = Self::extract_returns_content(rest) {
312                    declaration.has_declaration = true;
313                    declaration.type_string = Some(returns_content.clone());
314                    declaration.declared_fields = Self::extract_fields_from_type(&returns_content);
315                    declaration.has_spread_risk = returns_content.contains("...");
316                    return declaration;
317                }
318            }
319            // Check for double-slash comment: // @returns { ... }
320            else if let Some(rest) = trimmed.strip_prefix("//") {
321                if let Some(returns_content) = Self::extract_returns_content(rest) {
322                    declaration.has_declaration = true;
323                    declaration.type_string = Some(returns_content.clone());
324                    declaration.declared_fields = Self::extract_fields_from_type(&returns_content);
325                    declaration.has_spread_risk = returns_content.contains("...");
326                    return declaration;
327                }
328            }
329
330            // Check for JSDoc comment start: /** @returns { ... } */ or /** ... @returns ... */
331            if trimmed.starts_with("/**") || trimmed.starts_with("*") {
332                let content = trimmed
333                    .trim_start_matches("/**")
334                    .trim_start_matches('*')
335                    .trim_end_matches("*/")
336                    .trim();
337
338                if let Some(returns_content) = Self::extract_returns_content(content) {
339                    declaration.has_declaration = true;
340                    declaration.type_string = Some(returns_content.clone());
341                    declaration.declared_fields = Self::extract_fields_from_type(&returns_content);
342                    declaration.has_spread_risk = returns_content.contains("...");
343                    return declaration;
344                }
345            }
346        }
347
348        declaration
349    }
350
351    /// Extract the content after @returns.
352    fn extract_returns_content(text: &str) -> Option<String> {
353        let text = text.trim();
354
355        // Look for @returns or @return
356        let returns_pos = text.find("@returns").or_else(|| text.find("@return"))?;
357
358        // Extract everything after the tag
359        let after_tag = &text[returns_pos..];
360        let content_start = after_tag.find(['{', '('])?;
361
362        // Find the matching closing bracket
363        let chars: Vec<char> = after_tag[content_start..].chars().collect();
364        let open_char = chars[0];
365        let close_char = if open_char == '{' { '}' } else { ')' };
366
367        let mut depth = 0;
368        let mut end_pos = 0;
369
370        for (i, c) in chars.iter().enumerate() {
371            if *c == open_char {
372                depth += 1;
373            } else if *c == close_char {
374                depth -= 1;
375                if depth == 0 {
376                    end_pos = i + 1;
377                    break;
378                }
379            }
380        }
381
382        if end_pos > 0 {
383            Some(after_tag[content_start..content_start + end_pos].to_string())
384        } else {
385            // If no closing bracket found, take the rest of the line
386            Some(after_tag[content_start..].trim().to_string())
387        }
388    }
389
390    /// Extract field names from a type declaration string.
391    ///
392    /// This is a simple parser that extracts identifiers that appear
393    /// before colons, which are typically field names in object types.
394    fn extract_fields_from_type(type_string: &str) -> HashSet<String> {
395        let mut fields = HashSet::new();
396
397        // Simple regex-free parser: find "fieldName:" patterns
398        let chars: Vec<char> = type_string.chars().collect();
399        let mut current_word = String::new();
400        let mut in_word = false;
401
402        for c in chars.iter() {
403            if c.is_alphanumeric() || *c == '_' {
404                current_word.push(*c);
405                in_word = true;
406            } else {
407                if in_word && *c == ':' {
408                    // This is a field name
409                    if !current_word.is_empty()
410                        && !is_type_keyword(&current_word)
411                        && !current_word.chars().next().unwrap().is_ascii_uppercase()
412                    {
413                        fields.insert(current_word.clone());
414                    }
415                }
416                current_word.clear();
417                in_word = false;
418            }
419        }
420
421        fields
422    }
423
424    /// Check if declared output fields contain any blocked fields.
425    pub fn check_output_against_blocklist(
426        declaration: &OutputDeclaration,
427        blocked_fields: &HashSet<String>,
428    ) -> Vec<String> {
429        let mut violations = Vec::new();
430
431        for field in &declaration.declared_fields {
432            // Check exact match
433            if blocked_fields.contains(field) {
434                violations.push(format!("Output declares blocked field: {}", field));
435                continue;
436            }
437
438            // Check wildcard patterns like *.fieldName
439            for blocked in blocked_fields {
440                if let Some(pattern) = blocked.strip_prefix("*.") {
441                    if field == pattern {
442                        violations.push(format!(
443                            "Output declares blocked field pattern: {}",
444                            blocked
445                        ));
446                    }
447                }
448            }
449        }
450
451        violations
452    }
453
454    /// Parse and validate JavaScript code.
455    pub fn validate(&self, code: &str) -> Result<JavaScriptCodeInfo, ValidationError> {
456        // Parse the code using SWC
457        let cm: Lrc<SourceMap> = Default::default();
458
459        // Create a source file from the code - BytesStr accepts String
460        let fm = cm.new_source_file(
461            swc_common::FileName::Custom("code.js".into()).into(),
462            code.to_string(),
463        );
464
465        let lexer = Lexer::new(
466            Syntax::Es(Default::default()),
467            EsVersion::Es2022,
468            StringInput::from(&*fm),
469            None,
470        );
471
472        let mut parser = Parser::new_from(lexer);
473
474        let module = parser
475            .parse_module()
476            .map_err(|e| ValidationError::ParseError {
477                message: format!("JavaScript parse error: {:?}", e.into_kind()),
478                line: 0,
479                column: 0,
480            })?;
481
482        // Visit the AST to extract information and check safety
483        let mut visitor = SafetyVisitor::new(&cm).with_sdk_operations(self.sdk_operations.clone());
484        module.visit_with(&mut visitor);
485
486        let mut info = visitor.into_info();
487
488        // Parse @returns annotation from code comments
489        info.output_declaration = Self::parse_returns_annotation(code);
490
491        // Validate constraints
492        if info.api_calls.len() > self.max_api_calls {
493            return Err(ValidationError::SecurityError {
494                message: format!(
495                    "Too many API calls: {} (max: {})",
496                    info.api_calls.len(),
497                    self.max_api_calls
498                ),
499                issue: SecurityIssueType::HighComplexity,
500            });
501        }
502
503        if info.max_depth > self.max_depth {
504            return Err(ValidationError::SecurityError {
505                message: format!(
506                    "Code nesting depth {} exceeds maximum {}",
507                    info.max_depth, self.max_depth
508                ),
509                issue: SecurityIssueType::DeepNesting,
510            });
511        }
512
513        if info.loop_count > self.max_loops {
514            return Err(ValidationError::SecurityError {
515                message: format!(
516                    "Too many loops: {} (max: {})",
517                    info.loop_count, self.max_loops
518                ),
519                issue: SecurityIssueType::HighComplexity,
520            });
521        }
522
523        if info.statement_count > self.max_statements {
524            return Err(ValidationError::SecurityError {
525                message: format!(
526                    "Too many statements: {} (max: {})",
527                    info.statement_count, self.max_statements
528                ),
529                issue: SecurityIssueType::HighComplexity,
530            });
531        }
532
533        // Check for safety violations
534        if !info.violations.is_empty() {
535            let first = &info.violations[0];
536            return Err(ValidationError::SecurityError {
537                message: first.message.clone(),
538                issue: violation_to_security_issue(first.violation_type),
539            });
540        }
541
542        // Determine if read-only based on API calls
543        info.is_read_only = info.api_calls.iter().all(|c| c.method.is_read_only());
544
545        Ok(info)
546    }
547
548    /// Perform security analysis on code info.
549    pub fn analyze_security(&self, info: &JavaScriptCodeInfo) -> SecurityAnalysis {
550        let mut analysis = SecurityAnalysis {
551            is_read_only: info.is_read_only,
552            tables_accessed: info.endpoints_accessed.clone(),
553            fields_accessed: HashSet::new(),
554            has_aggregation: false,
555            has_subqueries: info.max_depth > 3,
556            estimated_complexity: self.estimate_complexity(info),
557            potential_issues: Vec::new(),
558            estimated_rows: None,
559        };
560
561        // Check for sensitive endpoints
562        for endpoint in &info.endpoints_accessed {
563            let endpoint_lower = endpoint.to_lowercase();
564            if self
565                .sensitive_paths
566                .iter()
567                .any(|s| endpoint_lower.contains(&s.to_lowercase()))
568            {
569                analysis.potential_issues.push(SecurityIssue::new(
570                    SecurityIssueType::SensitiveFields,
571                    format!("Code accesses potentially sensitive endpoint: {}", endpoint),
572                ));
573            }
574        }
575
576        // Check for dynamic paths (potential injection)
577        for call in &info.api_calls {
578            if call.is_dynamic_path {
579                analysis.potential_issues.push(
580                    SecurityIssue::new(
581                        SecurityIssueType::DynamicTableName,
582                        format!(
583                            "API call at line {} uses dynamic path interpolation",
584                            call.line
585                        ),
586                    )
587                    .with_location(CodeLocation {
588                        line: call.line,
589                        column: call.column,
590                    }),
591                );
592            }
593        }
594
595        // Check for deep nesting
596        if info.max_depth > 5 {
597            analysis.potential_issues.push(SecurityIssue::new(
598                SecurityIssueType::DeepNesting,
599                format!("Code has deep nesting (depth: {})", info.max_depth),
600            ));
601        }
602
603        // Check for unbounded loops
604        if !info.all_loops_bounded && info.loop_count > 0 {
605            analysis.potential_issues.push(SecurityIssue::new(
606                SecurityIssueType::UnboundedQuery,
607                "Code contains for...of loops without .slice() bounds",
608            ));
609        }
610
611        // Check for high complexity
612        if matches!(analysis.estimated_complexity, Complexity::High) {
613            analysis.potential_issues.push(SecurityIssue::new(
614                SecurityIssueType::HighComplexity,
615                "Code has high complexity",
616            ));
617        }
618
619        analysis
620    }
621
622    /// Estimate code complexity.
623    fn estimate_complexity(&self, info: &JavaScriptCodeInfo) -> Complexity {
624        let api_count = info.api_calls.len();
625        let loop_count = info.loop_count;
626        let depth = info.max_depth;
627        let statement_count = info.statement_count;
628
629        // Simple heuristic
630        let complexity_score = api_count * 3 + loop_count * 5 + depth * 2 + statement_count;
631
632        if complexity_score > 100 {
633            Complexity::High
634        } else if complexity_score > 50 {
635            Complexity::Medium
636        } else {
637            Complexity::Low
638        }
639    }
640
641    /// Convert code info to CodeType.
642    pub fn to_code_type(&self, info: &JavaScriptCodeInfo) -> CodeType {
643        if info.is_read_only {
644            CodeType::RestGet
645        } else {
646            CodeType::RestMutation
647        }
648    }
649}
650
651/// AST visitor for extracting info and checking safety.
652struct SafetyVisitor {
653    source_map: Lrc<SourceMap>,
654    api_calls: Vec<ApiCall>,
655    violations: Vec<SafetyViolation>,
656    variable_names: Vec<String>,
657    endpoints_accessed: HashSet<String>,
658    methods_used: HashSet<String>,
659    uses_async: bool,
660    current_depth: usize,
661    max_depth: usize,
662    loop_count: usize,
663    bounded_loops: usize,
664    statement_count: usize,
665    /// Whether the code uses spread operators in return values (potential field leakage)
666    has_spread_in_return: bool,
667    /// Whether we're currently inside a return statement
668    in_return_context: bool,
669    /// Allowed SDK operation names (camelCase). When non-empty, SDK mode validation is active.
670    sdk_operations: HashSet<String>,
671}
672
673impl SafetyVisitor {
674    fn new(source_map: &Lrc<SourceMap>) -> Self {
675        Self {
676            source_map: source_map.clone(),
677            api_calls: Vec::new(),
678            violations: Vec::new(),
679            variable_names: Vec::new(),
680            endpoints_accessed: HashSet::new(),
681            methods_used: HashSet::new(),
682            uses_async: false,
683            current_depth: 0,
684            max_depth: 0,
685            loop_count: 0,
686            bounded_loops: 0,
687            statement_count: 0,
688            has_spread_in_return: false,
689            in_return_context: false,
690            sdk_operations: HashSet::new(),
691        }
692    }
693
694    fn with_sdk_operations(mut self, operations: HashSet<String>) -> Self {
695        self.sdk_operations = operations;
696        self
697    }
698
699    fn into_info(self) -> JavaScriptCodeInfo {
700        JavaScriptCodeInfo {
701            api_calls: self.api_calls,
702            is_read_only: false, // Set later based on API calls
703            endpoints_accessed: self.endpoints_accessed,
704            methods_used: self.methods_used,
705            uses_async: self.uses_async,
706            variable_names: self.variable_names,
707            max_depth: self.max_depth,
708            loop_count: self.loop_count,
709            all_loops_bounded: self.loop_count == 0 || self.bounded_loops == self.loop_count,
710            violations: self.violations,
711            statement_count: self.statement_count,
712            output_declaration: OutputDeclaration::default(), // Set later by validator
713            has_output_spread_risk: self.has_spread_in_return,
714        }
715    }
716
717    fn span_to_location(&self, span: Span) -> CodeLocation {
718        let loc = self.source_map.lookup_char_pos(span.lo);
719        CodeLocation {
720            line: loc.line as u32,
721            column: loc.col_display as u32,
722        }
723    }
724
725    fn add_violation(&mut self, violation_type: SafetyViolationType, message: &str, span: Span) {
726        self.violations.push(SafetyViolation {
727            violation_type,
728            message: message.into(),
729            location: Some(self.span_to_location(span)),
730        });
731    }
732
733    fn check_api_call(&mut self, call: &CallExpr) {
734        // Check for api.method() pattern
735        if let Callee::Expr(expr) = &call.callee {
736            if let Expr::Member(member) = &**expr {
737                if let Expr::Ident(obj) = &*member.obj {
738                    if obj.sym.as_ref() == "api" {
739                        if let MemberProp::Ident(method_ident) = &member.prop {
740                            let method_name = method_ident.sym.as_ref();
741
742                            if !self.sdk_operations.is_empty() {
743                                // SDK mode: validate against allowed operation names
744                                if self.sdk_operations.contains(method_name) {
745                                    self.methods_used.insert(method_name.to_string());
746                                    self.endpoints_accessed
747                                        .insert(format!("sdk:{}", method_name));
748                                    // No path extraction needed for SDK calls
749                                } else {
750                                    self.add_violation(
751                                        SafetyViolationType::UnknownApiCall,
752                                        &format!(
753                                            "Unknown SDK operation: api.{}(). Check the code mode schema resource for available operations.",
754                                            method_name
755                                        ),
756                                        call.span,
757                                    );
758                                }
759                                return;
760                            }
761
762                            // HTTP mode: validate against known HTTP methods
763                            if let Some(method) = HttpMethod::from_str(method_name) {
764                                self.methods_used.insert(method_name.to_uppercase());
765
766                                // Extract path from first argument
767                                let (path, is_dynamic) = if let Some(arg) = call.args.first() {
768                                    self.extract_path(&arg.expr)
769                                } else {
770                                    ("unknown".into(), false)
771                                };
772
773                                self.endpoints_accessed.insert(path.clone());
774
775                                let loc = self.span_to_location(call.span);
776                                self.api_calls.push(ApiCall {
777                                    method,
778                                    path,
779                                    is_dynamic_path: is_dynamic,
780                                    line: loc.line,
781                                    column: loc.column,
782                                });
783                            } else {
784                                self.add_violation(
785                                    SafetyViolationType::UnknownApiCall,
786                                    &format!("Unknown api method: api.{}()", method_name),
787                                    call.span,
788                                );
789                            }
790                        }
791                    }
792                }
793            }
794        }
795    }
796
797    fn extract_path(&self, expr: &Expr) -> (String, bool) {
798        match expr {
799            Expr::Lit(Lit::Str(s)) => {
800                // Convert Wtf8Atom to String using to_string_lossy (handles WTF-8 encoding)
801                (s.value.to_string_lossy().into_owned(), false)
802            },
803            Expr::Tpl(tpl) => {
804                // Template literal - extract static parts
805                let mut path = String::new();
806                for quasi in &tpl.quasis {
807                    // quasi.raw is an Atom (UTF-8), not Wtf8Atom
808                    path.push_str(quasi.raw.as_ref());
809                    if !tpl.exprs.is_empty() {
810                        path.push_str("{...}");
811                    }
812                }
813                (path, !tpl.exprs.is_empty())
814            },
815            _ => ("dynamic".into(), true),
816        }
817    }
818
819    fn check_for_bounded(&mut self, for_of: &ForOfStmt) -> bool {
820        // Check if the iterable uses .slice()
821        if let Expr::Call(call) = &*for_of.right {
822            if let Callee::Expr(callee) = &call.callee {
823                if let Expr::Member(member) = &**callee {
824                    if let MemberProp::Ident(ident) = &member.prop {
825                        if ident.sym.as_ref() == "slice" {
826                            return true;
827                        }
828                    }
829                }
830            }
831        }
832        false
833    }
834}
835
836impl Visit for SafetyVisitor {
837    // Track depth
838    fn visit_block_stmt(&mut self, n: &BlockStmt) {
839        self.current_depth += 1;
840        self.max_depth = self.max_depth.max(self.current_depth);
841        n.visit_children_with(self);
842        self.current_depth -= 1;
843    }
844
845    // Count statements
846    fn visit_stmt(&mut self, n: &Stmt) {
847        self.statement_count += 1;
848        n.visit_children_with(self);
849    }
850
851    // Check for import/export
852    fn visit_import_decl(&mut self, n: &ImportDecl) {
853        self.add_violation(
854            SafetyViolationType::ImportExport,
855            "import statements are not allowed",
856            n.span,
857        );
858    }
859
860    fn visit_export_decl(&mut self, n: &ExportDecl) {
861        self.add_violation(
862            SafetyViolationType::ImportExport,
863            "export statements are not allowed",
864            n.span,
865        );
866    }
867
868    fn visit_export_default_decl(&mut self, n: &ExportDefaultDecl) {
869        self.add_violation(
870            SafetyViolationType::ImportExport,
871            "export default is not allowed",
872            n.span,
873        );
874    }
875
876    fn visit_export_default_expr(&mut self, n: &ExportDefaultExpr) {
877        self.add_violation(
878            SafetyViolationType::ImportExport,
879            "export default is not allowed",
880            n.span,
881        );
882    }
883
884    // Check for eval/Function
885    fn visit_call_expr(&mut self, n: &CallExpr) {
886        // Check for eval() or Function()
887        if let Callee::Expr(callee) = &n.callee {
888            if let Expr::Ident(ident) = &**callee {
889                let name = ident.sym.as_ref();
890                if name == "eval" || name == "Function" {
891                    self.add_violation(
892                        SafetyViolationType::DynamicCodeExecution,
893                        &format!("{}() is not allowed", name),
894                        n.span,
895                    );
896                }
897            }
898        }
899
900        // Check for api.method() calls
901        self.check_api_call(n);
902
903        n.visit_children_with(self);
904    }
905
906    // Check for while loops
907    fn visit_while_stmt(&mut self, n: &WhileStmt) {
908        self.add_violation(
909            SafetyViolationType::UnboundedLoop,
910            "while loops are not allowed (use bounded for...of with .slice())",
911            n.span,
912        );
913        n.visit_children_with(self);
914    }
915
916    fn visit_do_while_stmt(&mut self, n: &DoWhileStmt) {
917        self.add_violation(
918            SafetyViolationType::UnboundedLoop,
919            "do-while loops are not allowed (use bounded for...of with .slice())",
920            n.span,
921        );
922        n.visit_children_with(self);
923    }
924
925    // Check for...of loops for bounds
926    fn visit_for_of_stmt(&mut self, n: &ForOfStmt) {
927        self.loop_count += 1;
928        if self.check_for_bounded(n) {
929            self.bounded_loops += 1;
930        }
931        n.visit_children_with(self);
932    }
933
934    // Check for regular for loops (allow but track)
935    fn visit_for_stmt(&mut self, n: &ForStmt) {
936        self.loop_count += 1;
937        // Regular for loops are bounded by definition
938        self.bounded_loops += 1;
939        n.visit_children_with(self);
940    }
941
942    // Check for function declarations (only arrow functions allowed)
943    fn visit_fn_decl(&mut self, n: &FnDecl) {
944        self.add_violation(
945            SafetyViolationType::FunctionDeclaration,
946            "function declarations are not allowed (use arrow functions)",
947            n.function.span,
948        );
949        n.visit_children_with(self);
950    }
951
952    // Allow try/catch - it's just control flow and doesn't pose a security risk.
953    // This enables scripts to gracefully handle API errors.
954    fn visit_try_stmt(&mut self, n: &TryStmt) {
955        // Visit children to validate the contents of try/catch blocks
956        n.visit_children_with(self);
957    }
958
959    // Check for new keyword
960    fn visit_new_expr(&mut self, n: &NewExpr) {
961        // Allow specific constructors like Date, URL, URLSearchParams
962        let allowed = if let Expr::Ident(ident) = &*n.callee {
963            matches!(
964                ident.sym.as_ref(),
965                "Date" | "URL" | "URLSearchParams" | "Map" | "Set" | "Array"
966            )
967        } else {
968            false
969        };
970
971        if !allowed {
972            self.add_violation(
973                SafetyViolationType::NewKeyword,
974                "new keyword is only allowed for Date, URL, URLSearchParams, Map, Set, Array",
975                n.span,
976            );
977        }
978        n.visit_children_with(self);
979    }
980
981    // Check for this keyword
982    fn visit_this_expr(&mut self, n: &ThisExpr) {
983        self.add_violation(
984            SafetyViolationType::ThisKeyword,
985            "'this' keyword is not allowed",
986            n.span,
987        );
988    }
989
990    // Check for class declarations
991    fn visit_class_decl(&mut self, n: &ClassDecl) {
992        self.add_violation(
993            SafetyViolationType::ClassDeclaration,
994            "class declarations are not allowed",
995            n.class.span,
996        );
997        n.visit_children_with(self);
998    }
999
1000    // Check for with statement
1001    fn visit_with_stmt(&mut self, n: &WithStmt) {
1002        self.add_violation(
1003            SafetyViolationType::WithStatement,
1004            "'with' statement is not allowed",
1005            n.span,
1006        );
1007        n.visit_children_with(self);
1008    }
1009
1010    // Track async
1011    fn visit_await_expr(&mut self, n: &AwaitExpr) {
1012        self.uses_async = true;
1013        n.visit_children_with(self);
1014    }
1015
1016    // Track variable declarations
1017    fn visit_var_decl(&mut self, n: &VarDecl) {
1018        for decl in &n.decls {
1019            if let Pat::Ident(ident) = &decl.name {
1020                self.variable_names.push(ident.id.sym.to_string());
1021            }
1022        }
1023        n.visit_children_with(self);
1024    }
1025
1026    // Check for generators
1027    fn visit_function(&mut self, n: &Function) {
1028        if n.is_generator {
1029            self.add_violation(
1030                SafetyViolationType::Generator,
1031                "generator functions are not allowed",
1032                n.span,
1033            );
1034        }
1035        n.visit_children_with(self);
1036    }
1037
1038    // Check for delete operator
1039    fn visit_unary_expr(&mut self, n: &UnaryExpr) {
1040        if n.op == UnaryOp::Delete {
1041            self.add_violation(
1042                SafetyViolationType::DeleteOperator,
1043                "'delete' operator is not allowed",
1044                n.span,
1045            );
1046        }
1047        n.visit_children_with(self);
1048    }
1049
1050    // Check for prototype manipulation
1051    fn visit_member_expr(&mut self, n: &MemberExpr) {
1052        if let MemberProp::Ident(ident) = &n.prop {
1053            let name = ident.sym.as_ref();
1054            if name == "__proto__" || name == "prototype" {
1055                self.add_violation(
1056                    SafetyViolationType::PrototypeManipulation,
1057                    "prototype manipulation is not allowed",
1058                    n.span,
1059                );
1060            }
1061        }
1062        n.visit_children_with(self);
1063    }
1064
1065    // Track return statements to detect spread operators in output
1066    fn visit_return_stmt(&mut self, n: &ReturnStmt) {
1067        self.in_return_context = true;
1068        n.visit_children_with(self);
1069        self.in_return_context = false;
1070    }
1071
1072    // Detect spread operators in return values (potential field leakage)
1073    fn visit_spread_element(&mut self, n: &SpreadElement) {
1074        if self.in_return_context {
1075            self.has_spread_in_return = true;
1076        }
1077        n.visit_children_with(self);
1078    }
1079}
1080
1081/// Convert a safety violation type to a security issue type.
1082fn violation_to_security_issue(violation: SafetyViolationType) -> SecurityIssueType {
1083    match violation {
1084        SafetyViolationType::DynamicCodeExecution => SecurityIssueType::PotentialInjection,
1085        SafetyViolationType::PrototypeManipulation => SecurityIssueType::PotentialInjection,
1086        SafetyViolationType::UnboundedLoop | SafetyViolationType::UnboundedForLoop => {
1087            SecurityIssueType::UnboundedQuery
1088        },
1089        _ => SecurityIssueType::HighComplexity,
1090    }
1091}
1092
#[cfg(test)]
mod tests {
    use super::*;

    /// A single GET with a literal path is read-only and records the endpoint.
    #[test]
    fn test_simple_api_call() {
        let js = JavaScriptValidator::default();
        let script = r#"
            const response = await api.get("/users");
            return response.data;
        "#;

        let report = js.validate(script).unwrap();
        assert!(report.is_read_only);
        assert_eq!(report.api_calls.len(), 1);
        assert_eq!(report.api_calls[0].method, HttpMethod::Get);
        assert!(report.endpoints_accessed.contains("/users"));
    }

    /// Two GETs are both recorded; the template-literal path is flagged dynamic.
    #[test]
    fn test_multiple_api_calls() {
        let js = JavaScriptValidator::default();
        let script = r#"
            const user = await api.get("/users/123");
            const orders = await api.get(`/users/${user.id}/orders`);
            return { user, orders };
        "#;

        let report = js.validate(script).unwrap();
        assert!(report.is_read_only);
        assert_eq!(report.api_calls.len(), 2);
        assert!(report.api_calls[1].is_dynamic_path);
    }

    /// A POST makes the script non-read-only.
    #[test]
    fn test_mutation_detection() {
        let js = JavaScriptValidator::default();
        let script = r#"
            const result = await api.post("/users", { name: "test" });
            return result;
        "#;

        let report = js.validate(script).unwrap();
        assert!(!report.is_read_only);
        assert_eq!(report.api_calls[0].method, HttpMethod::Post);
    }

    /// `eval` must fail validation.
    #[test]
    fn test_reject_eval() {
        let js = JavaScriptValidator::default();
        let script = r#"
            const result = eval("api.get('/users')");
        "#;

        assert!(js.validate(script).is_err());
    }

    /// `while` loops must fail validation.
    #[test]
    fn test_reject_while_loop() {
        let js = JavaScriptValidator::default();
        let script = r#"
            let i = 0;
            while (i < 10) {
                await api.get("/data");
                i++;
            }
        "#;

        assert!(js.validate(script).is_err());
    }

    /// A `for...of` over `.slice(0, 10)` is accepted and counted as bounded.
    #[test]
    fn test_allow_bounded_for_of() {
        let js = JavaScriptValidator::default();
        let script = r#"
            const results = [];
            for (const id of userIds.slice(0, 10)) {
                const user = await api.get(`/users/${id}`);
                results.push(user);
            }
            return results;
        "#;

        let report = js.validate(script).unwrap();
        assert!(report.all_loops_bounded);
        assert_eq!(report.loop_count, 1);
    }

    /// `import` statements must fail validation.
    #[test]
    fn test_reject_import() {
        let js = JavaScriptValidator::default();
        let script = r#"
            import axios from 'axios';
            const result = await api.get("/users");
        "#;

        assert!(js.validate(script).is_err());
    }

    /// Arrow-function callbacks produce no violations.
    #[test]
    fn test_allow_arrow_functions() {
        let js = JavaScriptValidator::default();
        let script = r#"
            const users = await api.get("/users");
            const names = users.data.map(u => u.name);
            return names;
        "#;

        let report = js.validate(script).unwrap();
        assert!(report.violations.is_empty());
    }

    /// `function` declarations must fail validation.
    #[test]
    fn test_reject_function_declaration() {
        let js = JavaScriptValidator::default();
        let script = r#"
            function fetchUser(id) {
                return api.get(`/users/${id}`);
            }
        "#;

        assert!(js.validate(script).is_err());
    }

    /// Accessing an admin endpoint is surfaced as a sensitive-fields issue.
    #[test]
    fn test_security_analysis_sensitive_endpoint() {
        let js = JavaScriptValidator::default();
        let script = r#"
            const config = await api.get("/admin/config");
            return config;
        "#;

        let report = js.validate(script).unwrap();
        let analysis = js.analyze_security(&report);

        let flagged_sensitive = analysis
            .potential_issues
            .iter()
            .any(|issue| matches!(issue.issue_type, SecurityIssueType::SensitiveFields));
        assert!(flagged_sensitive);
    }

    /// `/// @returns` annotations are parsed into declared output fields.
    #[test]
    fn test_parse_returns_annotation_triple_slash() {
        let js = JavaScriptValidator::default();
        let script = r#"
            /// @returns { users: Array<{ id: string, name: string }> }
            const users = await api.get("/users");
            return { users: users.map(u => ({ id: u.id, name: u.name })) };
        "#;

        let report = js.validate(script).unwrap();
        let declared = &report.output_declaration;
        assert!(declared.has_declaration);
        assert!(declared.declared_fields.contains("id"));
        assert!(declared.declared_fields.contains("name"));
        assert!(declared.declared_fields.contains("users"));
    }

    /// `// @returns` annotations are parsed into declared output fields.
    #[test]
    fn test_parse_returns_annotation_double_slash() {
        let js = JavaScriptValidator::default();
        let script = r#"
            // @returns { products: Array<{ id: string, name: string, price: number }> }
            const products = await api.get("/products");
            return { products: products.map(p => ({ id: p.id, name: p.name, price: p.price })) };
        "#;

        let report = js.validate(script).unwrap();
        let declared = &report.output_declaration;
        assert!(declared.has_declaration);
        assert!(declared.declared_fields.contains("id"));
        assert!(declared.declared_fields.contains("name"));
        assert!(declared.declared_fields.contains("price"));
        assert!(declared.declared_fields.contains("products"));
    }

    /// `/** @returns ... */` annotations are parsed into declared output fields.
    #[test]
    fn test_parse_returns_annotation_jsdoc() {
        let js = JavaScriptValidator::default();
        let script = r#"
            /** @returns { user: { id: string, email: string } } */
            const user = await api.get("/users/123");
            return { user: { id: user.id, email: user.email } };
        "#;

        let report = js.validate(script).unwrap();
        let declared = &report.output_declaration;
        assert!(declared.has_declaration);
        assert!(declared.declared_fields.contains("id"));
        assert!(declared.declared_fields.contains("email"));
        assert!(declared.declared_fields.contains("user"));
    }

    /// Without an annotation there is no declaration and no declared fields.
    #[test]
    fn test_no_returns_annotation() {
        let js = JavaScriptValidator::default();
        let script = r#"
            const users = await api.get("/users");
            return users;
        "#;

        let report = js.validate(script).unwrap();
        let declared = &report.output_declaration;
        assert!(!declared.has_declaration);
        assert!(declared.declared_fields.is_empty());
    }

    /// A spread inside the returned object is reported as an output spread risk.
    #[test]
    fn test_spread_operator_detection() {
        let js = JavaScriptValidator::default();
        let script = r#"
            const user = await api.get("/users/123");
            return { ...user, computed: "value" };
        "#;

        let report = js.validate(script).unwrap();
        assert!(report.has_output_spread_risk);
    }

    /// An explicit field list in the return carries no spread risk.
    #[test]
    fn test_no_spread_operator_in_return() {
        let js = JavaScriptValidator::default();
        let script = r#"
            const user = await api.get("/users/123");
            return { id: user.id, name: user.name };
        "#;

        let report = js.validate(script).unwrap();
        assert!(!report.has_output_spread_risk);
    }

    /// A declared field that is on the blocklist yields exactly one violation.
    #[test]
    fn test_check_output_against_blocklist() {
        let declared = OutputDeclaration {
            has_declaration: true,
            type_string: Some("{ id: string, ssn: string }".to_string()),
            declared_fields: ["id", "ssn"].iter().copied().map(String::from).collect(),
            has_spread_risk: false,
        };

        let blocklist: HashSet<String> = ["ssn", "password"]
            .iter()
            .copied()
            .map(String::from)
            .collect();

        let violations = JavaScriptValidator::check_output_against_blocklist(&declared, &blocklist);
        assert_eq!(violations.len(), 1);
        assert!(violations[0].contains("ssn"));
    }

    /// A `*.field` wildcard entry blocks the field wherever it is declared.
    #[test]
    fn test_check_output_against_wildcard_blocklist() {
        let declared = OutputDeclaration {
            has_declaration: true,
            type_string: Some("{ user: { id: string, salary: number } }".to_string()),
            declared_fields: ["user", "id", "salary"]
                .iter()
                .copied()
                .map(String::from)
                .collect(),
            has_spread_risk: false,
        };

        let blocklist: HashSet<String> = ["*.salary"].iter().copied().map(String::from).collect();

        let violations = JavaScriptValidator::check_output_against_blocklist(&declared, &blocklist);
        assert_eq!(violations.len(), 1);
        assert!(violations[0].contains("salary"));
    }
}