1use crate::types::{
4 CodeType, Complexity, SecurityAnalysis, SecurityIssue, SecurityIssueType, ValidationError,
5};
6use graphql_parser::query::{Definition, Document, OperationDefinition, Selection, SelectionSet};
7use std::collections::HashSet;
8
/// Facts collected about a single GraphQL operation while walking its parsed document.
///
/// A value starts out as `Default` and is filled in by the validator's extraction pass.
#[derive(Debug, Clone, Default)]
pub struct GraphQLQueryInfo {
    /// Kind of the operation found in the document.
    pub operation_type: GraphQLOperationType,

    /// Name of the operation, when the operation declares one.
    pub operation_name: Option<String>,

    /// Names of the fields visited at depth 1 of the operation, in document order.
    pub root_fields: Vec<String>,

    /// Type names guessed from root field names (via `field_name_to_type`) plus the
    /// entries recorded for inline-fragment type conditions.
    pub types_accessed: HashSet<String>,

    /// Name of every field visited by the selection-set walk, at any depth.
    pub fields_accessed: HashSet<String>,

    /// `true` when the operation declares at least one variable.
    pub has_variables: bool,

    /// Names of the variables declared by the operation.
    pub variable_names: Vec<String>,

    /// Largest depth value reached by the selection-set walk (root selection set = 1).
    pub max_depth: usize,

    /// `true` when the document contains at least one fragment definition.
    pub has_fragments: bool,

    /// Names of fragment definitions in the document and of fragment spreads
    /// encountered during the walk.
    pub fragment_names: Vec<String>,

    /// `true` when any visited field name starts with `__`.
    pub has_introspection: bool,
}
45
/// The kind of a GraphQL operation.
///
/// `Query` is the `Default` value; it is also what the validator assigns to the
/// anonymous shorthand form (a bare selection set).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum GraphQLOperationType {
    /// A `query` operation; the only kind reported as read-only by `is_read_only`.
    #[default]
    Query,
    /// A `mutation` operation.
    Mutation,
    /// A `subscription` operation.
    Subscription,
}
54
55impl GraphQLOperationType {
56 pub fn is_read_only(&self) -> bool {
58 matches!(self, GraphQLOperationType::Query)
59 }
60}
61
/// Validates GraphQL documents and produces a security analysis for them.
pub struct GraphQLValidator {
    /// Lowercase substrings that mark a field name as sensitive.
    ///
    /// `analyze_security` lowercases each field name before the substring test,
    /// so every entry stored here must already be lowercase to be able to match.
    sensitive_fields: Vec<String>,

    /// Maximum nesting depth `validate` accepts before rejecting a query.
    max_depth: usize,

    /// Complexity score above which a query is classified as high complexity.
    max_complexity: usize,
}

impl Default for GraphQLValidator {
    /// Builds a validator with a stock list of sensitive field names, a maximum
    /// depth of 10 and a maximum complexity score of 100.
    fn default() -> Self {
        // Written in their natural spelling for readability and normalized below.
        // Storing them verbatim (as before) meant the camelCase entries could never
        // match, because the comparison side is always lowercased.
        const DEFAULT_SENSITIVE_FIELDS: [&str; 8] = [
            "password",
            "ssn",
            "socialSecurityNumber",
            "creditCard",
            "creditCardNumber",
            "apiKey",
            "secret",
            "token",
        ];

        Self {
            sensitive_fields: DEFAULT_SENSITIVE_FIELDS
                .iter()
                .map(|name| name.to_lowercase())
                .collect(),
            max_depth: 10,
            max_complexity: 100,
        }
    }
}
92
93impl GraphQLValidator {
94 pub fn new(sensitive_fields: Vec<String>, max_depth: usize, max_complexity: usize) -> Self {
96 Self {
97 sensitive_fields: sensitive_fields
98 .into_iter()
99 .map(|s| s.to_lowercase())
100 .collect(),
101 max_depth,
102 max_complexity,
103 }
104 }
105
106 pub fn validate(&self, query: &str) -> Result<GraphQLQueryInfo, ValidationError> {
108 let document = graphql_parser::parse_query::<&str>(query).map_err(|e| {
110 ValidationError::ParseError {
111 message: e.to_string(),
112 line: 0,
113 column: 0,
114 }
115 })?;
116
117 let info = self.extract_query_info(&document)?;
119
120 if info.max_depth > self.max_depth {
122 return Err(ValidationError::SecurityError {
123 message: format!(
124 "Query depth {} exceeds maximum allowed depth {}",
125 info.max_depth, self.max_depth
126 ),
127 issue: SecurityIssueType::DeepNesting,
128 });
129 }
130
131 Ok(info)
132 }
133
134 fn extract_query_info<'a>(
136 &self,
137 document: &'a Document<'a, &'a str>,
138 ) -> Result<GraphQLQueryInfo, ValidationError> {
139 let mut info = GraphQLQueryInfo::default();
140 let mut found_operation = false;
141
142 for definition in &document.definitions {
143 match definition {
144 Definition::Operation(op) => {
145 if found_operation {
146 return Err(ValidationError::ParseError {
147 message: "Multiple operations not supported".into(),
148 line: 0,
149 column: 0,
150 });
151 }
152 found_operation = true;
153 self.extract_operation_info(op, &mut info)?;
154 },
155 Definition::Fragment(frag) => {
156 info.has_fragments = true;
157 info.fragment_names.push(frag.name.to_string());
158 },
159 }
160 }
161
162 if !found_operation {
163 return Err(ValidationError::ParseError {
164 message: "No operation found in query".into(),
165 line: 0,
166 column: 0,
167 });
168 }
169
170 Ok(info)
171 }
172
173 fn extract_operation_info<'a>(
175 &self,
176 op: &'a OperationDefinition<'a, &'a str>,
177 info: &mut GraphQLQueryInfo,
178 ) -> Result<(), ValidationError> {
179 match op {
180 OperationDefinition::Query(q) => {
181 info.operation_type = GraphQLOperationType::Query;
182 info.operation_name = q.name.map(|s| s.to_string());
183 info.has_variables = !q.variable_definitions.is_empty();
184 info.variable_names = q
185 .variable_definitions
186 .iter()
187 .map(|v| v.name.to_string())
188 .collect();
189 self.extract_selection_set(&q.selection_set, info, 1)?;
190 },
191 OperationDefinition::Mutation(m) => {
192 info.operation_type = GraphQLOperationType::Mutation;
193 info.operation_name = m.name.map(|s| s.to_string());
194 info.has_variables = !m.variable_definitions.is_empty();
195 info.variable_names = m
196 .variable_definitions
197 .iter()
198 .map(|v| v.name.to_string())
199 .collect();
200 self.extract_selection_set(&m.selection_set, info, 1)?;
201 },
202 OperationDefinition::Subscription(s) => {
203 info.operation_type = GraphQLOperationType::Subscription;
204 info.operation_name = s.name.map(|s| s.to_string());
205 info.has_variables = !s.variable_definitions.is_empty();
206 info.variable_names = s
207 .variable_definitions
208 .iter()
209 .map(|v| v.name.to_string())
210 .collect();
211 self.extract_selection_set(&s.selection_set, info, 1)?;
212 },
213 OperationDefinition::SelectionSet(ss) => {
214 info.operation_type = GraphQLOperationType::Query;
216 self.extract_selection_set(ss, info, 1)?;
217 },
218 }
219 Ok(())
220 }
221
222 fn extract_selection_set<'a>(
224 &self,
225 selection_set: &'a SelectionSet<'a, &'a str>,
226 info: &mut GraphQLQueryInfo,
227 depth: usize,
228 ) -> Result<(), ValidationError> {
229 info.max_depth = info.max_depth.max(depth);
230
231 for selection in &selection_set.items {
232 match selection {
233 Selection::Field(field) => {
234 let field_name = field.name.to_string();
235
236 if field_name.starts_with("__") {
238 info.has_introspection = true;
239 }
240
241 if depth == 1 {
243 info.root_fields.push(field_name.clone());
244 }
245
246 info.fields_accessed.insert(field_name.clone());
248
249 if depth == 1 {
252 let type_name = field_name_to_type(&field_name);
253 info.types_accessed.insert(type_name);
254 }
255
256 self.extract_selection_set(&field.selection_set, info, depth + 1)?;
258 },
259 Selection::FragmentSpread(spread) => {
260 info.fragment_names.push(spread.fragment_name.to_string());
261 },
262 Selection::InlineFragment(inline) => {
263 if let Some(type_cond) = &inline.type_condition {
264 info.types_accessed.insert(type_cond.to_string());
265 }
266 self.extract_selection_set(&inline.selection_set, info, depth + 1)?;
267 },
268 }
269 }
270
271 Ok(())
272 }
273
274 pub fn analyze_security(&self, info: &GraphQLQueryInfo) -> SecurityAnalysis {
276 let mut analysis = SecurityAnalysis {
277 is_read_only: info.operation_type.is_read_only(),
278 tables_accessed: info.types_accessed.clone(),
279 fields_accessed: info.fields_accessed.clone(),
280 has_aggregation: false,
281 has_subqueries: info.max_depth > 3,
282 estimated_complexity: self.estimate_complexity(info),
283 potential_issues: Vec::new(),
284 estimated_rows: None,
285 };
286
287 for field in &info.fields_accessed {
289 let field_lower = field.to_lowercase();
290 if self
291 .sensitive_fields
292 .iter()
293 .any(|s| field_lower.contains(s))
294 {
295 analysis.potential_issues.push(SecurityIssue::new(
296 SecurityIssueType::SensitiveFields,
297 format!("Query accesses potentially sensitive field: {}", field),
298 ));
299 }
300 }
301
302 if info.max_depth > 5 {
304 analysis.potential_issues.push(SecurityIssue::new(
305 SecurityIssueType::DeepNesting,
306 format!("Query has deep nesting (depth: {})", info.max_depth),
307 ));
308 }
309
310 if matches!(analysis.estimated_complexity, Complexity::High) {
312 analysis.potential_issues.push(SecurityIssue::new(
313 SecurityIssueType::HighComplexity,
314 "Query has high complexity",
315 ));
316 }
317
318 analysis
319 }
320
321 fn estimate_complexity(&self, info: &GraphQLQueryInfo) -> Complexity {
323 let field_count = info.fields_accessed.len();
324 let type_count = info.types_accessed.len();
325 let depth = info.max_depth;
326
327 let complexity_score = field_count + (type_count * 2) + (depth * depth);
329
330 if complexity_score > self.max_complexity {
331 Complexity::High
332 } else if complexity_score > self.max_complexity / 2 {
333 Complexity::Medium
334 } else {
335 Complexity::Low
336 }
337 }
338
339 pub fn to_code_type(&self, info: &GraphQLQueryInfo) -> CodeType {
341 match info.operation_type {
342 GraphQLOperationType::Query => CodeType::GraphQLQuery,
343 GraphQLOperationType::Mutation => CodeType::GraphQLMutation,
344 GraphQLOperationType::Subscription => CodeType::GraphQLQuery, }
346 }
347}
348
/// Guesses a type name from a field name: drops a plural `s` and capitalizes the
/// first character (`users` -> `User`, `orderItems` -> `OrderItem`).
///
/// This is a naming heuristic, not real singularization. A trailing `s` is kept
/// when it is the whole name or when it follows another `s`, since names such as
/// `address` or `class` are not plurals. An empty input yields an empty string.
pub(crate) fn field_name_to_type(field_name: &str) -> String {
    let singular = match field_name.strip_suffix('s') {
        Some(stem) if !stem.is_empty() && !stem.ends_with('s') => stem,
        _ => field_name,
    };

    let mut rest = singular.chars();
    match rest.next() {
        // `to_uppercase` may expand to several chars; chain keeps all of them.
        Some(first) => first.to_uppercase().chain(rest).collect(),
        None => String::new(),
    }
}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates `query` with the default validator, panicking if it is rejected.
    fn parse(query: &str) -> GraphQLQueryInfo {
        GraphQLValidator::default().validate(query).unwrap()
    }

    #[test]
    fn test_simple_query_parsing() {
        let info = parse("query { users { id name email } }");

        assert_eq!(info.operation_type, GraphQLOperationType::Query);
        assert!(info.root_fields.iter().any(|field| field == "users"));
        for expected in ["id", "name", "email"] {
            assert!(info.fields_accessed.contains(expected));
        }
    }

    #[test]
    fn test_mutation_detection() {
        let info = parse("mutation { createUser(name: \"test\") { id } }");

        assert_eq!(info.operation_type, GraphQLOperationType::Mutation);
        assert!(!info.operation_type.is_read_only());
    }

    #[test]
    fn test_nested_query() {
        let info = parse(
            r#"
            query {
                users {
                    id
                    orders {
                        id
                        items {
                            product {
                                name
                            }
                        }
                    }
                }
            }
        "#,
        );

        assert!(info.max_depth >= 4);
    }

    #[test]
    fn test_sensitive_field_detection() {
        let validator = GraphQLValidator::default();
        let info = validator
            .validate("query { users { id name password } }")
            .unwrap();

        let flagged = validator
            .analyze_security(&info)
            .potential_issues
            .iter()
            .any(|issue| matches!(issue.issue_type, SecurityIssueType::SensitiveFields));

        assert!(flagged);
    }

    #[test]
    fn test_variables() {
        let info = parse("query GetUser($id: ID!) { user(id: $id) { name } }");

        assert!(info.has_variables);
        assert!(info.variable_names.iter().any(|name| name == "id"));
        assert_eq!(info.operation_name, Some("GetUser".to_string()));
    }

    #[test]
    fn test_field_name_to_type() {
        let cases = [
            ("users", "User"),
            ("orderItems", "OrderItem"),
            ("user", "User"),
        ];

        for (field, expected) in cases {
            assert_eq!(field_name_to_type(field), expected);
        }
    }
}