Skip to main content

pmcp_code_mode/
graphql.rs

1//! GraphQL-specific validation for Code Mode.
2
use crate::types::{
    CodeType, Complexity, SecurityAnalysis, SecurityIssue, SecurityIssueType, ValidationError,
};
use graphql_parser::query::{
    Definition, Document, FragmentDefinition, OperationDefinition, Selection, SelectionSet,
    TypeCondition,
};
use std::collections::{HashMap, HashSet};
8
/// Information extracted from a parsed GraphQL query.
///
/// Returned by [`GraphQLValidator::validate`] and consumed by
/// [`GraphQLValidator::analyze_security`] and [`GraphQLValidator::to_code_type`].
#[derive(Debug, Clone, Default)]
pub struct GraphQLQueryInfo {
    /// The operation type (query, mutation, subscription).
    /// Anonymous shorthand documents (`{ ... }`) are reported as
    /// [`GraphQLOperationType::Query`].
    pub operation_type: GraphQLOperationType,

    /// Name of the operation, or `None` for anonymous operations.
    pub operation_name: Option<String>,

    /// Names of the fields found at depth 1 (the operation's top-level fields),
    /// in traversal order; duplicates are kept.
    pub root_fields: Vec<String>,

    /// Type names associated with the query. This is heuristic, not schema-based:
    /// it holds a type name guessed from each root field name (e.g. "users" ->
    /// "User") plus the type conditions of fragments.
    pub types_accessed: HashSet<String>,

    /// Names of all fields selected in the analysed selection sets, deduplicated.
    /// The underlying field name is recorded; aliases are ignored.
    pub fields_accessed: HashSet<String>,

    /// Whether the operation declares any variables.
    pub has_variables: bool,

    /// Names of the declared variables, without the `$` prefix.
    pub variable_names: Vec<String>,

    /// Deepest selection-set nesting level seen, with the operation's root
    /// selection set at level 1. A field's selection set counts as one more level
    /// even when it is empty, so `{ users { id } }` has depth 3.
    pub max_depth: usize,

    /// Whether the document defines at least one named fragment
    /// (`fragment Name on Type { ... }`).
    pub has_fragments: bool,

    /// Fragment names, collected both from fragment definitions and from
    /// fragment spreads; a name can therefore appear more than once.
    pub fragment_names: Vec<String>,

    /// Whether any visited field name starts with `__` (introspection fields
    /// such as `__schema`, `__type` or `__typename`).
    pub has_introspection: bool,
}
40    pub fragment_names: Vec<String>,
41
42    /// Whether the query contains introspection fields (__schema, __type)
43    pub has_introspection: bool,
44}
45
46/// GraphQL operation type.
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
48pub enum GraphQLOperationType {
49    #[default]
50    Query,
51    Mutation,
52    Subscription,
53}
54
55impl GraphQLOperationType {
56    /// Whether this operation is read-only.
57    pub fn is_read_only(&self) -> bool {
58        matches!(self, GraphQLOperationType::Query)
59    }
60}
61
62/// GraphQL query validator.
63pub struct GraphQLValidator {
64    /// Known sensitive field patterns (e.g., "password", "ssn", "creditCard")
65    sensitive_fields: Vec<String>,
66
67    /// Maximum allowed query depth
68    max_depth: usize,
69
70    /// Maximum allowed complexity
71    max_complexity: usize,
72}
73
74impl Default for GraphQLValidator {
75    fn default() -> Self {
76        Self {
77            sensitive_fields: vec![
78                "password".into(),
79                "ssn".into(),
80                "socialSecurityNumber".into(),
81                "creditCard".into(),
82                "creditCardNumber".into(),
83                "apiKey".into(),
84                "secret".into(),
85                "token".into(),
86            ],
87            max_depth: 10,
88            max_complexity: 100,
89        }
90    }
91}
92
93impl GraphQLValidator {
94    /// Create a new validator with custom settings.
95    pub fn new(sensitive_fields: Vec<String>, max_depth: usize, max_complexity: usize) -> Self {
96        Self {
97            sensitive_fields,
98            max_depth,
99            max_complexity,
100        }
101    }
102
103    /// Parse and validate a GraphQL query.
104    pub fn validate(&self, query: &str) -> Result<GraphQLQueryInfo, ValidationError> {
105        // Parse the query
106        let document = graphql_parser::parse_query::<&str>(query).map_err(|e| {
107            ValidationError::ParseError {
108                message: e.to_string(),
109                line: 0,
110                column: 0,
111            }
112        })?;
113
114        // Extract query information
115        let info = self.extract_query_info(&document)?;
116
117        // Validate depth
118        if info.max_depth > self.max_depth {
119            return Err(ValidationError::SecurityError {
120                message: format!(
121                    "Query depth {} exceeds maximum allowed depth {}",
122                    info.max_depth, self.max_depth
123                ),
124                issue: SecurityIssueType::DeepNesting,
125            });
126        }
127
128        Ok(info)
129    }
130
131    /// Extract information from a parsed document.
132    fn extract_query_info<'a>(
133        &self,
134        document: &'a Document<'a, &'a str>,
135    ) -> Result<GraphQLQueryInfo, ValidationError> {
136        let mut info = GraphQLQueryInfo::default();
137        let mut found_operation = false;
138
139        for definition in &document.definitions {
140            match definition {
141                Definition::Operation(op) => {
142                    if found_operation {
143                        return Err(ValidationError::ParseError {
144                            message: "Multiple operations not supported".into(),
145                            line: 0,
146                            column: 0,
147                        });
148                    }
149                    found_operation = true;
150                    self.extract_operation_info(op, &mut info)?;
151                }
152                Definition::Fragment(frag) => {
153                    info.has_fragments = true;
154                    info.fragment_names.push(frag.name.to_string());
155                }
156            }
157        }
158
159        if !found_operation {
160            return Err(ValidationError::ParseError {
161                message: "No operation found in query".into(),
162                line: 0,
163                column: 0,
164            });
165        }
166
167        Ok(info)
168    }
169
170    /// Extract information from an operation.
171    fn extract_operation_info<'a>(
172        &self,
173        op: &'a OperationDefinition<'a, &'a str>,
174        info: &mut GraphQLQueryInfo,
175    ) -> Result<(), ValidationError> {
176        match op {
177            OperationDefinition::Query(q) => {
178                info.operation_type = GraphQLOperationType::Query;
179                info.operation_name = q.name.map(|s| s.to_string());
180                info.has_variables = !q.variable_definitions.is_empty();
181                info.variable_names = q
182                    .variable_definitions
183                    .iter()
184                    .map(|v| v.name.to_string())
185                    .collect();
186                self.extract_selection_set(&q.selection_set, info, 1)?;
187            }
188            OperationDefinition::Mutation(m) => {
189                info.operation_type = GraphQLOperationType::Mutation;
190                info.operation_name = m.name.map(|s| s.to_string());
191                info.has_variables = !m.variable_definitions.is_empty();
192                info.variable_names = m
193                    .variable_definitions
194                    .iter()
195                    .map(|v| v.name.to_string())
196                    .collect();
197                self.extract_selection_set(&m.selection_set, info, 1)?;
198            }
199            OperationDefinition::Subscription(s) => {
200                info.operation_type = GraphQLOperationType::Subscription;
201                info.operation_name = s.name.map(|s| s.to_string());
202                info.has_variables = !s.variable_definitions.is_empty();
203                info.variable_names = s
204                    .variable_definitions
205                    .iter()
206                    .map(|v| v.name.to_string())
207                    .collect();
208                self.extract_selection_set(&s.selection_set, info, 1)?;
209            }
210            OperationDefinition::SelectionSet(ss) => {
211                // Anonymous query
212                info.operation_type = GraphQLOperationType::Query;
213                self.extract_selection_set(ss, info, 1)?;
214            }
215        }
216        Ok(())
217    }
218
219    /// Extract information from a selection set.
220    fn extract_selection_set<'a>(
221        &self,
222        selection_set: &'a SelectionSet<'a, &'a str>,
223        info: &mut GraphQLQueryInfo,
224        depth: usize,
225    ) -> Result<(), ValidationError> {
226        info.max_depth = info.max_depth.max(depth);
227
228        for selection in &selection_set.items {
229            match selection {
230                Selection::Field(field) => {
231                    let field_name = field.name.to_string();
232
233                    // Check for introspection fields
234                    if field_name.starts_with("__") {
235                        info.has_introspection = true;
236                    }
237
238                    // Track root fields
239                    if depth == 1 {
240                        info.root_fields.push(field_name.clone());
241                    }
242
243                    // Track all fields
244                    info.fields_accessed.insert(field_name.clone());
245
246                    // Check for type name hints in field name (e.g., "users" -> "User")
247                    // This is a heuristic - real implementation would use schema
248                    if depth == 1 {
249                        let type_name = field_name_to_type(&field_name);
250                        info.types_accessed.insert(type_name);
251                    }
252
253                    // Recurse into nested selections
254                    self.extract_selection_set(&field.selection_set, info, depth + 1)?;
255                }
256                Selection::FragmentSpread(spread) => {
257                    info.fragment_names.push(spread.fragment_name.to_string());
258                }
259                Selection::InlineFragment(inline) => {
260                    if let Some(type_cond) = &inline.type_condition {
261                        info.types_accessed.insert(type_cond.to_string());
262                    }
263                    self.extract_selection_set(&inline.selection_set, info, depth + 1)?;
264                }
265            }
266        }
267
268        Ok(())
269    }
270
271    /// Perform security analysis on query info.
272    pub fn analyze_security(&self, info: &GraphQLQueryInfo) -> SecurityAnalysis {
273        let mut analysis = SecurityAnalysis {
274            is_read_only: info.operation_type.is_read_only(),
275            tables_accessed: info.types_accessed.clone(),
276            fields_accessed: info.fields_accessed.clone(),
277            has_aggregation: false,
278            has_subqueries: info.max_depth > 3,
279            estimated_complexity: self.estimate_complexity(info),
280            potential_issues: Vec::new(),
281            estimated_rows: None,
282        };
283
284        // Check for sensitive fields
285        for field in &info.fields_accessed {
286            let field_lower = field.to_lowercase();
287            if self
288                .sensitive_fields
289                .iter()
290                .any(|s| field_lower.contains(&s.to_lowercase()))
291            {
292                analysis.potential_issues.push(SecurityIssue::new(
293                    SecurityIssueType::SensitiveFields,
294                    format!("Query accesses potentially sensitive field: {}", field),
295                ));
296            }
297        }
298
299        // Check for deep nesting
300        if info.max_depth > 5 {
301            analysis.potential_issues.push(SecurityIssue::new(
302                SecurityIssueType::DeepNesting,
303                format!("Query has deep nesting (depth: {})", info.max_depth),
304            ));
305        }
306
307        // Check for high complexity
308        if matches!(analysis.estimated_complexity, Complexity::High) {
309            analysis.potential_issues.push(SecurityIssue::new(
310                SecurityIssueType::HighComplexity,
311                "Query has high complexity",
312            ));
313        }
314
315        analysis
316    }
317
318    /// Estimate query complexity.
319    fn estimate_complexity(&self, info: &GraphQLQueryInfo) -> Complexity {
320        let field_count = info.fields_accessed.len();
321        let type_count = info.types_accessed.len();
322        let depth = info.max_depth;
323
324        // Simple heuristic based on fields, types, and depth
325        let complexity_score = field_count + (type_count * 2) + (depth * depth);
326
327        if complexity_score > self.max_complexity {
328            Complexity::High
329        } else if complexity_score > self.max_complexity / 2 {
330            Complexity::Medium
331        } else {
332            Complexity::Low
333        }
334    }
335
336    /// Convert query info to CodeType.
337    pub fn to_code_type(&self, info: &GraphQLQueryInfo) -> CodeType {
338        match info.operation_type {
339            GraphQLOperationType::Query => CodeType::GraphQLQuery,
340            GraphQLOperationType::Mutation => CodeType::GraphQLMutation,
341            GraphQLOperationType::Subscription => CodeType::GraphQLQuery, // Treat as query for now
342        }
343    }
344}
345
346/// Convert a field name to a probable type name.
347///
348/// e.g., "users" -> "User", "orderItems" -> "OrderItem"
349pub(crate) fn field_name_to_type(field_name: &str) -> String {
350    // Remove trailing 's' for plurals and capitalize
351    let singular = if field_name.ends_with('s') && field_name.len() > 1 {
352        &field_name[..field_name.len() - 1]
353    } else {
354        field_name
355    };
356
357    // Capitalize first letter
358    let mut chars: Vec<char> = singular.chars().collect();
359    if let Some(first) = chars.first_mut() {
360        *first = first.to_uppercase().next().unwrap_or(*first);
361    }
362    chars.into_iter().collect()
363}
364
365#[cfg(test)]
366mod tests {
367    use super::*;
368
369    #[test]
370    fn test_simple_query_parsing() {
371        let validator = GraphQLValidator::default();
372        let query = "query { users { id name email } }";
373
374        let info = validator.validate(query).unwrap();
375
376        assert_eq!(info.operation_type, GraphQLOperationType::Query);
377        assert!(info.root_fields.contains(&"users".to_string()));
378        assert!(info.fields_accessed.contains("id"));
379        assert!(info.fields_accessed.contains("name"));
380        assert!(info.fields_accessed.contains("email"));
381    }
382
383    #[test]
384    fn test_mutation_detection() {
385        let validator = GraphQLValidator::default();
386        let query = "mutation { createUser(name: \"test\") { id } }";
387
388        let info = validator.validate(query).unwrap();
389
390        assert_eq!(info.operation_type, GraphQLOperationType::Mutation);
391        assert!(!info.operation_type.is_read_only());
392    }
393
394    #[test]
395    fn test_nested_query() {
396        let validator = GraphQLValidator::default();
397        let query = r#"
398            query {
399                users {
400                    id
401                    orders {
402                        id
403                        items {
404                            product {
405                                name
406                            }
407                        }
408                    }
409                }
410            }
411        "#;
412
413        let info = validator.validate(query).unwrap();
414
415        assert!(info.max_depth >= 4);
416    }
417
418    #[test]
419    fn test_sensitive_field_detection() {
420        let validator = GraphQLValidator::default();
421        let query = "query { users { id name password } }";
422
423        let info = validator.validate(query).unwrap();
424        let analysis = validator.analyze_security(&info);
425
426        assert!(analysis
427            .potential_issues
428            .iter()
429            .any(|i| matches!(i.issue_type, SecurityIssueType::SensitiveFields)));
430    }
431
432    #[test]
433    fn test_variables() {
434        let validator = GraphQLValidator::default();
435        let query = "query GetUser($id: ID!) { user(id: $id) { name } }";
436
437        let info = validator.validate(query).unwrap();
438
439        assert!(info.has_variables);
440        assert!(info.variable_names.contains(&"id".to_string()));
441        assert_eq!(info.operation_name, Some("GetUser".to_string()));
442    }
443
444    #[test]
445    fn test_field_name_to_type() {
446        assert_eq!(field_name_to_type("users"), "User");
447        assert_eq!(field_name_to_type("orderItems"), "OrderItem");
448        assert_eq!(field_name_to_type("user"), "User");
449    }
450}