use std::collections::{HashMap, HashSet};

use graphql_parser::query::{
    Definition, Document, FragmentDefinition, OperationDefinition, Selection, SelectionSet,
    TypeCondition,
};

use crate::types::{
    CodeType, Complexity, SecurityAnalysis, SecurityIssue, SecurityIssueType, ValidationError,
};
8
9#[derive(Debug, Clone, Default)]
11pub struct GraphQLQueryInfo {
12 pub operation_type: GraphQLOperationType,
14
15 pub operation_name: Option<String>,
17
18 pub root_fields: Vec<String>,
20
21 pub types_accessed: HashSet<String>,
23
24 pub fields_accessed: HashSet<String>,
26
27 pub has_variables: bool,
29
30 pub variable_names: Vec<String>,
32
33 pub max_depth: usize,
35
36 pub has_fragments: bool,
38
39 pub fragment_names: Vec<String>,
41
42 pub has_introspection: bool,
44}
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
48pub enum GraphQLOperationType {
49 #[default]
50 Query,
51 Mutation,
52 Subscription,
53}
54
55impl GraphQLOperationType {
56 pub fn is_read_only(&self) -> bool {
58 matches!(self, GraphQLOperationType::Query)
59 }
60}
61
/// Parses GraphQL operations and flags structural or data-access risks.
pub struct GraphQLValidator {
    /// Substrings that, matched case-insensitively against field names, mark a field as
    /// sensitive.
    sensitive_fields: Vec<String>,
    /// Largest query depth `validate` accepts.
    max_depth: usize,
    /// Complexity score above which a query is rated `Complexity::High`.
    max_complexity: usize,
}

impl Default for GraphQLValidator {
    /// Builds a validator with a built-in list of common sensitive field names, a maximum
    /// depth of 10 and a complexity budget of 100.
    fn default() -> Self {
        const DEFAULT_SENSITIVE_FIELDS: [&str; 8] = [
            "password",
            "ssn",
            "socialSecurityNumber",
            "creditCard",
            "creditCardNumber",
            "apiKey",
            "secret",
            "token",
        ];

        let sensitive_fields = DEFAULT_SENSITIVE_FIELDS
            .iter()
            .copied()
            .map(String::from)
            .collect();

        Self {
            sensitive_fields,
            max_depth: 10,
            max_complexity: 100,
        }
    }
}
92
93impl GraphQLValidator {
94 pub fn new(sensitive_fields: Vec<String>, max_depth: usize, max_complexity: usize) -> Self {
96 Self {
97 sensitive_fields,
98 max_depth,
99 max_complexity,
100 }
101 }
102
103 pub fn validate(&self, query: &str) -> Result<GraphQLQueryInfo, ValidationError> {
105 let document = graphql_parser::parse_query::<&str>(query).map_err(|e| {
107 ValidationError::ParseError {
108 message: e.to_string(),
109 line: 0,
110 column: 0,
111 }
112 })?;
113
114 let info = self.extract_query_info(&document)?;
116
117 if info.max_depth > self.max_depth {
119 return Err(ValidationError::SecurityError {
120 message: format!(
121 "Query depth {} exceeds maximum allowed depth {}",
122 info.max_depth, self.max_depth
123 ),
124 issue: SecurityIssueType::DeepNesting,
125 });
126 }
127
128 Ok(info)
129 }
130
131 fn extract_query_info<'a>(
133 &self,
134 document: &'a Document<'a, &'a str>,
135 ) -> Result<GraphQLQueryInfo, ValidationError> {
136 let mut info = GraphQLQueryInfo::default();
137 let mut found_operation = false;
138
139 for definition in &document.definitions {
140 match definition {
141 Definition::Operation(op) => {
142 if found_operation {
143 return Err(ValidationError::ParseError {
144 message: "Multiple operations not supported".into(),
145 line: 0,
146 column: 0,
147 });
148 }
149 found_operation = true;
150 self.extract_operation_info(op, &mut info)?;
151 }
152 Definition::Fragment(frag) => {
153 info.has_fragments = true;
154 info.fragment_names.push(frag.name.to_string());
155 }
156 }
157 }
158
159 if !found_operation {
160 return Err(ValidationError::ParseError {
161 message: "No operation found in query".into(),
162 line: 0,
163 column: 0,
164 });
165 }
166
167 Ok(info)
168 }
169
170 fn extract_operation_info<'a>(
172 &self,
173 op: &'a OperationDefinition<'a, &'a str>,
174 info: &mut GraphQLQueryInfo,
175 ) -> Result<(), ValidationError> {
176 match op {
177 OperationDefinition::Query(q) => {
178 info.operation_type = GraphQLOperationType::Query;
179 info.operation_name = q.name.map(|s| s.to_string());
180 info.has_variables = !q.variable_definitions.is_empty();
181 info.variable_names = q
182 .variable_definitions
183 .iter()
184 .map(|v| v.name.to_string())
185 .collect();
186 self.extract_selection_set(&q.selection_set, info, 1)?;
187 }
188 OperationDefinition::Mutation(m) => {
189 info.operation_type = GraphQLOperationType::Mutation;
190 info.operation_name = m.name.map(|s| s.to_string());
191 info.has_variables = !m.variable_definitions.is_empty();
192 info.variable_names = m
193 .variable_definitions
194 .iter()
195 .map(|v| v.name.to_string())
196 .collect();
197 self.extract_selection_set(&m.selection_set, info, 1)?;
198 }
199 OperationDefinition::Subscription(s) => {
200 info.operation_type = GraphQLOperationType::Subscription;
201 info.operation_name = s.name.map(|s| s.to_string());
202 info.has_variables = !s.variable_definitions.is_empty();
203 info.variable_names = s
204 .variable_definitions
205 .iter()
206 .map(|v| v.name.to_string())
207 .collect();
208 self.extract_selection_set(&s.selection_set, info, 1)?;
209 }
210 OperationDefinition::SelectionSet(ss) => {
211 info.operation_type = GraphQLOperationType::Query;
213 self.extract_selection_set(ss, info, 1)?;
214 }
215 }
216 Ok(())
217 }
218
219 fn extract_selection_set<'a>(
221 &self,
222 selection_set: &'a SelectionSet<'a, &'a str>,
223 info: &mut GraphQLQueryInfo,
224 depth: usize,
225 ) -> Result<(), ValidationError> {
226 info.max_depth = info.max_depth.max(depth);
227
228 for selection in &selection_set.items {
229 match selection {
230 Selection::Field(field) => {
231 let field_name = field.name.to_string();
232
233 if field_name.starts_with("__") {
235 info.has_introspection = true;
236 }
237
238 if depth == 1 {
240 info.root_fields.push(field_name.clone());
241 }
242
243 info.fields_accessed.insert(field_name.clone());
245
246 if depth == 1 {
249 let type_name = field_name_to_type(&field_name);
250 info.types_accessed.insert(type_name);
251 }
252
253 self.extract_selection_set(&field.selection_set, info, depth + 1)?;
255 }
256 Selection::FragmentSpread(spread) => {
257 info.fragment_names.push(spread.fragment_name.to_string());
258 }
259 Selection::InlineFragment(inline) => {
260 if let Some(type_cond) = &inline.type_condition {
261 info.types_accessed.insert(type_cond.to_string());
262 }
263 self.extract_selection_set(&inline.selection_set, info, depth + 1)?;
264 }
265 }
266 }
267
268 Ok(())
269 }
270
271 pub fn analyze_security(&self, info: &GraphQLQueryInfo) -> SecurityAnalysis {
273 let mut analysis = SecurityAnalysis {
274 is_read_only: info.operation_type.is_read_only(),
275 tables_accessed: info.types_accessed.clone(),
276 fields_accessed: info.fields_accessed.clone(),
277 has_aggregation: false,
278 has_subqueries: info.max_depth > 3,
279 estimated_complexity: self.estimate_complexity(info),
280 potential_issues: Vec::new(),
281 estimated_rows: None,
282 };
283
284 for field in &info.fields_accessed {
286 let field_lower = field.to_lowercase();
287 if self
288 .sensitive_fields
289 .iter()
290 .any(|s| field_lower.contains(&s.to_lowercase()))
291 {
292 analysis.potential_issues.push(SecurityIssue::new(
293 SecurityIssueType::SensitiveFields,
294 format!("Query accesses potentially sensitive field: {}", field),
295 ));
296 }
297 }
298
299 if info.max_depth > 5 {
301 analysis.potential_issues.push(SecurityIssue::new(
302 SecurityIssueType::DeepNesting,
303 format!("Query has deep nesting (depth: {})", info.max_depth),
304 ));
305 }
306
307 if matches!(analysis.estimated_complexity, Complexity::High) {
309 analysis.potential_issues.push(SecurityIssue::new(
310 SecurityIssueType::HighComplexity,
311 "Query has high complexity",
312 ));
313 }
314
315 analysis
316 }
317
318 fn estimate_complexity(&self, info: &GraphQLQueryInfo) -> Complexity {
320 let field_count = info.fields_accessed.len();
321 let type_count = info.types_accessed.len();
322 let depth = info.max_depth;
323
324 let complexity_score = field_count + (type_count * 2) + (depth * depth);
326
327 if complexity_score > self.max_complexity {
328 Complexity::High
329 } else if complexity_score > self.max_complexity / 2 {
330 Complexity::Medium
331 } else {
332 Complexity::Low
333 }
334 }
335
336 pub fn to_code_type(&self, info: &GraphQLQueryInfo) -> CodeType {
338 match info.operation_type {
339 GraphQLOperationType::Query => CodeType::GraphQLQuery,
340 GraphQLOperationType::Mutation => CodeType::GraphQLMutation,
341 GraphQLOperationType::Subscription => CodeType::GraphQLQuery, }
343 }
344}
345
/// Guesses a GraphQL type name from a field name.
///
/// The guess is built in two steps:
/// * A single trailing `s` is dropped, unless it is the whole name.
/// * The first character is upper-cased.
///
/// So `users` becomes `User` and `orderItems` becomes `OrderItem`. This is a naive
/// heuristic: `status` becomes `Statu`.
pub(crate) fn field_name_to_type(field_name: &str) -> String {
    // Drop one trailing `s`, but never reduce the name to nothing.
    let stem = match field_name.strip_suffix('s') {
        Some(rest) if !rest.is_empty() => rest,
        _ => field_name,
    };

    let mut tail = stem.chars();
    let mut type_name = String::with_capacity(stem.len());
    if let Some(head) = tail.next() {
        // Only the first char of the upper-case mapping is kept.
        type_name.push(head.to_uppercase().next().unwrap_or(head));
    }
    type_name.push_str(tail.as_str());
    type_name
}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates `query` with a default validator, panicking if validation fails.
    fn parse(query: &str) -> GraphQLQueryInfo {
        GraphQLValidator::default().validate(query).unwrap()
    }

    #[test]
    fn test_simple_query_parsing() {
        let info = parse("query { users { id name email } }");

        assert_eq!(info.operation_type, GraphQLOperationType::Query);
        assert!(info.root_fields.iter().any(|f| f == "users"));
        for expected in &["id", "name", "email"] {
            assert!(info.fields_accessed.contains(*expected));
        }
    }

    #[test]
    fn test_mutation_detection() {
        let info = parse("mutation { createUser(name: \"test\") { id } }");

        assert_eq!(info.operation_type, GraphQLOperationType::Mutation);
        assert!(!info.operation_type.is_read_only());
    }

    #[test]
    fn test_nested_query() {
        let info = parse(
            r#"
        query {
            users {
                id
                orders {
                    id
                    items {
                        product {
                            name
                        }
                    }
                }
            }
        }
        "#,
        );

        assert!(info.max_depth >= 4);
    }

    #[test]
    fn test_sensitive_field_detection() {
        let validator = GraphQLValidator::default();
        let info = validator
            .validate("query { users { id name password } }")
            .unwrap();

        let flagged = validator
            .analyze_security(&info)
            .potential_issues
            .iter()
            .any(|issue| matches!(issue.issue_type, SecurityIssueType::SensitiveFields));
        assert!(flagged);
    }

    #[test]
    fn test_variables() {
        let info = parse("query GetUser($id: ID!) { user(id: $id) { name } }");

        assert!(info.has_variables);
        assert!(info.variable_names.iter().any(|v| v == "id"));
        assert_eq!(info.operation_name.as_deref(), Some("GetUser"));
    }

    #[test]
    fn test_field_name_to_type() {
        let cases = [("users", "User"), ("orderItems", "OrderItem"), ("user", "User")];
        for (field, expected) in &cases {
            assert_eq!(field_name_to_type(field), *expected);
        }
    }
}
450}