Skip to main content

pmcp_code_mode/
graphql.rs

1//! GraphQL-specific validation for Code Mode.
2
3use crate::types::{
4    CodeType, Complexity, SecurityAnalysis, SecurityIssue, SecurityIssueType, ValidationError,
5};
6use graphql_parser::query::{Definition, Document, OperationDefinition, Selection, SelectionSet};
7use std::collections::HashSet;
8
9/// Information extracted from a parsed GraphQL query.
10#[derive(Debug, Clone, Default)]
11pub struct GraphQLQueryInfo {
12    /// The operation type (query, mutation, subscription)
13    pub operation_type: GraphQLOperationType,
14
15    /// Name of the operation (if named)
16    pub operation_name: Option<String>,
17
18    /// Root fields being queried
19    pub root_fields: Vec<String>,
20
21    /// All types accessed in the query
22    pub types_accessed: HashSet<String>,
23
24    /// All fields accessed in the query
25    pub fields_accessed: HashSet<String>,
26
27    /// Whether the query has variables
28    pub has_variables: bool,
29
30    /// Variable names
31    pub variable_names: Vec<String>,
32
33    /// Maximum depth of the query
34    pub max_depth: usize,
35
36    /// Whether query has fragments
37    pub has_fragments: bool,
38
39    /// Fragment names used
40    pub fragment_names: Vec<String>,
41
42    /// Whether the query contains introspection fields (__schema, __type)
43    pub has_introspection: bool,
44}
45
46/// GraphQL operation type.
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
48pub enum GraphQLOperationType {
49    #[default]
50    Query,
51    Mutation,
52    Subscription,
53}
54
55impl GraphQLOperationType {
56    /// Whether this operation is read-only.
57    pub fn is_read_only(&self) -> bool {
58        matches!(self, GraphQLOperationType::Query)
59    }
60}
61
/// Parses GraphQL queries and checks them against configurable limits.
pub struct GraphQLValidator {
    /// Substrings that mark a field name as sensitive
    /// (e.g., "password", "ssn", "creditCard").
    sensitive_fields: Vec<String>,

    /// Deepest nesting level a query may reach before it is rejected.
    max_depth: usize,

    /// Score threshold used when classifying a query's estimated complexity.
    max_complexity: usize,
}
73
74impl Default for GraphQLValidator {
75    fn default() -> Self {
76        Self {
77            sensitive_fields: vec![
78                "password".into(),
79                "ssn".into(),
80                "socialSecurityNumber".into(),
81                "creditCard".into(),
82                "creditCardNumber".into(),
83                "apiKey".into(),
84                "secret".into(),
85                "token".into(),
86            ],
87            max_depth: 10,
88            max_complexity: 100,
89        }
90    }
91}
92
93impl GraphQLValidator {
94    /// Create a new validator with custom settings.
95    pub fn new(sensitive_fields: Vec<String>, max_depth: usize, max_complexity: usize) -> Self {
96        Self {
97            sensitive_fields: sensitive_fields
98                .into_iter()
99                .map(|s| s.to_lowercase())
100                .collect(),
101            max_depth,
102            max_complexity,
103        }
104    }
105
106    /// Parse and validate a GraphQL query.
107    pub fn validate(&self, query: &str) -> Result<GraphQLQueryInfo, ValidationError> {
108        // Parse the query
109        let document = graphql_parser::parse_query::<&str>(query).map_err(|e| {
110            ValidationError::ParseError {
111                message: e.to_string(),
112                line: 0,
113                column: 0,
114            }
115        })?;
116
117        // Extract query information
118        let info = self.extract_query_info(&document)?;
119
120        // Validate depth
121        if info.max_depth > self.max_depth {
122            return Err(ValidationError::SecurityError {
123                message: format!(
124                    "Query depth {} exceeds maximum allowed depth {}",
125                    info.max_depth, self.max_depth
126                ),
127                issue: SecurityIssueType::DeepNesting,
128            });
129        }
130
131        Ok(info)
132    }
133
134    /// Extract information from a parsed document.
135    fn extract_query_info<'a>(
136        &self,
137        document: &'a Document<'a, &'a str>,
138    ) -> Result<GraphQLQueryInfo, ValidationError> {
139        let mut info = GraphQLQueryInfo::default();
140        let mut found_operation = false;
141
142        for definition in &document.definitions {
143            match definition {
144                Definition::Operation(op) => {
145                    if found_operation {
146                        return Err(ValidationError::ParseError {
147                            message: "Multiple operations not supported".into(),
148                            line: 0,
149                            column: 0,
150                        });
151                    }
152                    found_operation = true;
153                    self.extract_operation_info(op, &mut info)?;
154                },
155                Definition::Fragment(frag) => {
156                    info.has_fragments = true;
157                    info.fragment_names.push(frag.name.to_string());
158                },
159            }
160        }
161
162        if !found_operation {
163            return Err(ValidationError::ParseError {
164                message: "No operation found in query".into(),
165                line: 0,
166                column: 0,
167            });
168        }
169
170        Ok(info)
171    }
172
173    /// Extract information from an operation.
174    fn extract_operation_info<'a>(
175        &self,
176        op: &'a OperationDefinition<'a, &'a str>,
177        info: &mut GraphQLQueryInfo,
178    ) -> Result<(), ValidationError> {
179        match op {
180            OperationDefinition::Query(q) => {
181                info.operation_type = GraphQLOperationType::Query;
182                info.operation_name = q.name.map(|s| s.to_string());
183                info.has_variables = !q.variable_definitions.is_empty();
184                info.variable_names = q
185                    .variable_definitions
186                    .iter()
187                    .map(|v| v.name.to_string())
188                    .collect();
189                self.extract_selection_set(&q.selection_set, info, 1)?;
190            },
191            OperationDefinition::Mutation(m) => {
192                info.operation_type = GraphQLOperationType::Mutation;
193                info.operation_name = m.name.map(|s| s.to_string());
194                info.has_variables = !m.variable_definitions.is_empty();
195                info.variable_names = m
196                    .variable_definitions
197                    .iter()
198                    .map(|v| v.name.to_string())
199                    .collect();
200                self.extract_selection_set(&m.selection_set, info, 1)?;
201            },
202            OperationDefinition::Subscription(s) => {
203                info.operation_type = GraphQLOperationType::Subscription;
204                info.operation_name = s.name.map(|s| s.to_string());
205                info.has_variables = !s.variable_definitions.is_empty();
206                info.variable_names = s
207                    .variable_definitions
208                    .iter()
209                    .map(|v| v.name.to_string())
210                    .collect();
211                self.extract_selection_set(&s.selection_set, info, 1)?;
212            },
213            OperationDefinition::SelectionSet(ss) => {
214                // Anonymous query
215                info.operation_type = GraphQLOperationType::Query;
216                self.extract_selection_set(ss, info, 1)?;
217            },
218        }
219        Ok(())
220    }
221
222    /// Extract information from a selection set.
223    fn extract_selection_set<'a>(
224        &self,
225        selection_set: &'a SelectionSet<'a, &'a str>,
226        info: &mut GraphQLQueryInfo,
227        depth: usize,
228    ) -> Result<(), ValidationError> {
229        info.max_depth = info.max_depth.max(depth);
230
231        for selection in &selection_set.items {
232            match selection {
233                Selection::Field(field) => {
234                    let field_name = field.name.to_string();
235
236                    // Check for introspection fields
237                    if field_name.starts_with("__") {
238                        info.has_introspection = true;
239                    }
240
241                    // Track root fields
242                    if depth == 1 {
243                        info.root_fields.push(field_name.clone());
244                    }
245
246                    // Track all fields
247                    info.fields_accessed.insert(field_name.clone());
248
249                    // Check for type name hints in field name (e.g., "users" -> "User")
250                    // This is a heuristic - real implementation would use schema
251                    if depth == 1 {
252                        let type_name = field_name_to_type(&field_name);
253                        info.types_accessed.insert(type_name);
254                    }
255
256                    // Recurse into nested selections
257                    self.extract_selection_set(&field.selection_set, info, depth + 1)?;
258                },
259                Selection::FragmentSpread(spread) => {
260                    info.fragment_names.push(spread.fragment_name.to_string());
261                },
262                Selection::InlineFragment(inline) => {
263                    if let Some(type_cond) = &inline.type_condition {
264                        info.types_accessed.insert(type_cond.to_string());
265                    }
266                    self.extract_selection_set(&inline.selection_set, info, depth + 1)?;
267                },
268            }
269        }
270
271        Ok(())
272    }
273
274    /// Perform security analysis on query info.
275    pub fn analyze_security(&self, info: &GraphQLQueryInfo) -> SecurityAnalysis {
276        let mut analysis = SecurityAnalysis {
277            is_read_only: info.operation_type.is_read_only(),
278            tables_accessed: info.types_accessed.clone(),
279            fields_accessed: info.fields_accessed.clone(),
280            has_aggregation: false,
281            has_subqueries: info.max_depth > 3,
282            estimated_complexity: self.estimate_complexity(info),
283            potential_issues: Vec::new(),
284            estimated_rows: None,
285        };
286
287        // Check for sensitive fields
288        for field in &info.fields_accessed {
289            let field_lower = field.to_lowercase();
290            if self
291                .sensitive_fields
292                .iter()
293                .any(|s| field_lower.contains(s))
294            {
295                analysis.potential_issues.push(SecurityIssue::new(
296                    SecurityIssueType::SensitiveFields,
297                    format!("Query accesses potentially sensitive field: {}", field),
298                ));
299            }
300        }
301
302        // Check for deep nesting
303        if info.max_depth > 5 {
304            analysis.potential_issues.push(SecurityIssue::new(
305                SecurityIssueType::DeepNesting,
306                format!("Query has deep nesting (depth: {})", info.max_depth),
307            ));
308        }
309
310        // Check for high complexity
311        if matches!(analysis.estimated_complexity, Complexity::High) {
312            analysis.potential_issues.push(SecurityIssue::new(
313                SecurityIssueType::HighComplexity,
314                "Query has high complexity",
315            ));
316        }
317
318        analysis
319    }
320
321    /// Estimate query complexity.
322    fn estimate_complexity(&self, info: &GraphQLQueryInfo) -> Complexity {
323        let field_count = info.fields_accessed.len();
324        let type_count = info.types_accessed.len();
325        let depth = info.max_depth;
326
327        // Simple heuristic based on fields, types, and depth
328        let complexity_score = field_count + (type_count * 2) + (depth * depth);
329
330        if complexity_score > self.max_complexity {
331            Complexity::High
332        } else if complexity_score > self.max_complexity / 2 {
333            Complexity::Medium
334        } else {
335            Complexity::Low
336        }
337    }
338
339    /// Convert query info to CodeType.
340    pub fn to_code_type(&self, info: &GraphQLQueryInfo) -> CodeType {
341        match info.operation_type {
342            GraphQLOperationType::Query => CodeType::GraphQLQuery,
343            GraphQLOperationType::Mutation => CodeType::GraphQLMutation,
344            GraphQLOperationType::Subscription => CodeType::GraphQLQuery, // Treat as query for now
345        }
346    }
347}
348
/// Convert a field name to a probable type name.
///
/// A single trailing `s` is dropped as a naive plural and the first character
/// is upper-cased, e.g. "users" -> "User", "orderItems" -> "OrderItem".
/// Names ending in a double `s` are treated as singular and left intact
/// ("address" -> "Address"), as is the one-character name "s". An empty
/// name yields an empty string.
///
/// This is a naming heuristic only; it does not consult a schema.
pub(crate) fn field_name_to_type(field_name: &str) -> String {
    let singular = match field_name.strip_suffix('s') {
        // Strip only when something remains and the stem does not itself end
        // in 's' (which would mean the name ended in "ss").
        Some(stem) if !stem.is_empty() && !stem.ends_with('s') => stem,
        _ => field_name,
    };

    let mut rest = singular.chars();
    match rest.next() {
        // `to_uppercase` can yield more than one char, hence the chain.
        Some(first) => first.to_uppercase().chain(rest).collect(),
        None => String::new(),
    }
}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `query` through a default validator, panicking if it is rejected.
    fn parse(query: &str) -> GraphQLQueryInfo {
        GraphQLValidator::default().validate(query).unwrap()
    }

    /// A plain query reports its kind, root field and every selected field.
    #[test]
    fn test_simple_query_parsing() {
        let info = parse("query { users { id name email } }");

        assert_eq!(info.operation_type, GraphQLOperationType::Query);
        assert!(info.root_fields.contains(&"users".to_string()));
        for expected in ["id", "name", "email"] {
            assert!(info.fields_accessed.contains(expected));
        }
    }

    /// Mutations are recognised and are not read-only.
    #[test]
    fn test_mutation_detection() {
        let info = parse("mutation { createUser(name: \"test\") { id } }");

        assert_eq!(info.operation_type, GraphQLOperationType::Mutation);
        assert!(!info.operation_type.is_read_only());
    }

    /// Nested selections raise the reported depth.
    #[test]
    fn test_nested_query() {
        let info = parse(
            r#"
            query {
                users {
                    id
                    orders {
                        id
                        items {
                            product {
                                name
                            }
                        }
                    }
                }
            }
        "#,
        );

        assert!(info.max_depth >= 4);
    }

    /// Selecting `password` produces a sensitive-field issue.
    #[test]
    fn test_sensitive_field_detection() {
        let validator = GraphQLValidator::default();
        let info = validator
            .validate("query { users { id name password } }")
            .unwrap();

        let issues = validator.analyze_security(&info).potential_issues;
        let flagged = issues
            .iter()
            .any(|i| matches!(i.issue_type, SecurityIssueType::SensitiveFields));

        assert!(flagged);
    }

    /// Declared variables and the operation name are captured.
    #[test]
    fn test_variables() {
        let info = parse("query GetUser($id: ID!) { user(id: $id) { name } }");

        assert!(info.has_variables);
        assert!(info.variable_names.contains(&"id".to_string()));
        assert_eq!(info.operation_name, Some("GetUser".to_string()));
    }

    /// The type-name heuristic singularises and capitalises field names.
    #[test]
    fn test_field_name_to_type() {
        let cases = [
            ("users", "User"),
            ("orderItems", "OrderItem"),
            ("user", "User"),
        ];

        for (field, expected) in cases {
            assert_eq!(field_name_to_type(field), expected);
        }
    }
}