Skip to main content

mockforge_bench/conformance/
custom.rs

1//! Custom conformance test authoring via YAML
2//!
3//! Allows users to define additional conformance checks beyond the built-in
4//! OpenAPI 3.0.0 feature set. Custom checks are grouped under a "Custom"
5//! category in the conformance report.
6
7use crate::error::{BenchError, Result};
8use serde::Deserialize;
9use std::path::Path;
10
11/// Top-level YAML configuration for custom conformance checks
12#[derive(Debug, Deserialize)]
13pub struct CustomConformanceConfig {
14    /// List of custom checks to run
15    pub custom_checks: Vec<CustomCheck>,
16}
17
18/// A single custom conformance check
19#[derive(Debug, Deserialize)]
20pub struct CustomCheck {
21    /// Check name (should start with "custom:" for report aggregation)
22    pub name: String,
23    /// Request path (e.g., "/api/users")
24    pub path: String,
25    /// HTTP method (GET, POST, PUT, DELETE, etc.)
26    pub method: String,
27    /// Expected HTTP status code
28    pub expected_status: u16,
29    /// Optional request body (JSON string)
30    #[serde(default)]
31    pub body: Option<String>,
32    /// Optional expected response headers (name -> regex pattern)
33    #[serde(default)]
34    pub expected_headers: std::collections::HashMap<String, String>,
35    /// Optional expected body fields with type validation
36    #[serde(default)]
37    pub expected_body_fields: Vec<ExpectedBodyField>,
38    /// Optional request headers
39    #[serde(default)]
40    pub headers: std::collections::HashMap<String, String>,
41}
42
43/// Expected field in the response body with type checking
44#[derive(Debug, Deserialize)]
45pub struct ExpectedBodyField {
46    /// Field name in the JSON response
47    pub name: String,
48    /// Expected JSON type: "string", "integer", "number", "boolean", "array", "object"
49    #[serde(rename = "type")]
50    pub field_type: String,
51}
52
53impl CustomConformanceConfig {
54    /// Parse a custom conformance config from a YAML file
55    pub fn from_file(path: &Path) -> Result<Self> {
56        let content = std::fs::read_to_string(path).map_err(|e| {
57            BenchError::Other(format!(
58                "Failed to read custom conformance file '{}': {}",
59                path.display(),
60                e
61            ))
62        })?;
63        serde_yaml::from_str(&content).map_err(|e| {
64            BenchError::Other(format!(
65                "Failed to parse custom conformance YAML '{}': {}",
66                path.display(),
67                e
68            ))
69        })
70    }
71
72    /// Generate a k6 `group('Custom', ...)` block for all custom checks.
73    ///
74    /// `base_url` is the JS expression for the base URL (e.g., `"BASE_URL"`).
75    /// `custom_headers` are additional headers to inject into every request.
76    pub fn generate_k6_group(&self, base_url: &str, custom_headers: &[(String, String)]) -> String {
77        let mut script = String::with_capacity(4096);
78        script.push_str("  group('Custom', function () {\n");
79
80        for check in &self.custom_checks {
81            script.push_str("    {\n");
82
83            // Build headers object
84            let mut all_headers: Vec<(String, String)> = Vec::new();
85            // Add check-specific headers
86            for (k, v) in &check.headers {
87                all_headers.push((k.clone(), v.clone()));
88            }
89            // Add global custom headers (check-specific take priority)
90            for (k, v) in custom_headers {
91                if !check.headers.contains_key(k) {
92                    all_headers.push((k.clone(), v.clone()));
93                }
94            }
95            // If posting JSON body, add Content-Type
96            if check.body.is_some()
97                && !all_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("content-type"))
98            {
99                all_headers.push(("Content-Type".to_string(), "application/json".to_string()));
100            }
101
102            let headers_js = if all_headers.is_empty() {
103                "{}".to_string()
104            } else {
105                let entries: Vec<String> = all_headers
106                    .iter()
107                    .map(|(k, v)| format!("'{}': '{}'", k, v.replace('\'', "\\'")))
108                    .collect();
109                format!("{{ {} }}", entries.join(", "))
110            };
111
112            let method = check.method.to_uppercase();
113            let url = format!("${{{}}}{}", base_url, check.path);
114            let escaped_name = check.name.replace('\'', "\\'");
115
116            match method.as_str() {
117                "GET" | "HEAD" | "OPTIONS" | "DELETE" => {
118                    let k6_method = match method.as_str() {
119                        "DELETE" => "del",
120                        other => &other.to_lowercase(),
121                    };
122                    if all_headers.is_empty() {
123                        script
124                            .push_str(&format!("      let res = http.{}(`{}`);\n", k6_method, url));
125                    } else {
126                        script.push_str(&format!(
127                            "      let res = http.{}(`{}`, {{ headers: {} }});\n",
128                            k6_method, url, headers_js
129                        ));
130                    }
131                }
132                _ => {
133                    // POST, PUT, PATCH
134                    let k6_method = method.to_lowercase();
135                    let body_expr = match &check.body {
136                        Some(b) => format!("'{}'", b.replace('\'', "\\'")),
137                        None => "null".to_string(),
138                    };
139                    script.push_str(&format!(
140                        "      let res = http.{}(`{}`, {}, {{ headers: {} }});\n",
141                        k6_method, url, body_expr, headers_js
142                    ));
143                }
144            }
145
146            // Status check with failure detail capture
147            script.push_str(&format!(
148                "      {{ let ok = check(res, {{ '{}': (r) => r.status === {} }}); if (!ok) __captureFailure('{}', res, 'status === {}'); }}\n",
149                escaped_name, check.expected_status, escaped_name, check.expected_status
150            ));
151
152            // Header checks with failure detail capture.
153            // k6 canonicalizes response header names (e.g. `X-XSS-Protection` ->
154            // `X-Xss-Protection`), so match header names case-insensitively.
155            for (header_name, pattern) in &check.expected_headers {
156                let header_check_name = format!("{}:header:{}", escaped_name, header_name);
157                let escaped_pattern = pattern.replace('\\', "\\\\").replace('\'', "\\'");
158                let header_lower = header_name.to_lowercase();
159                script.push_str(&format!(
160                    "      {{ let ok = check(res, {{ '{}': (r) => {{ const _hk = Object.keys(r.headers || {{}}).find(k => k.toLowerCase() === '{}'); return new RegExp('{}').test(_hk ? r.headers[_hk] : ''); }} }}); if (!ok) __captureFailure('{}', res, 'header {} matches /{}/ '); }}\n",
161                    header_check_name,
162                    header_lower,
163                    escaped_pattern,
164                    header_check_name,
165                    header_name,
166                    escaped_pattern
167                ));
168            }
169
170            // Body field checks
171            for field in &check.expected_body_fields {
172                let field_check_name =
173                    format!("{}:body:{}:{}", escaped_name, field.name, field.field_type);
174                // Generate JS expression to access the field value, supporting
175                // nested paths like "results.name" and "items[].id"
176                let accessor = generate_field_accessor(&field.name);
177                let type_check = match field.field_type.as_str() {
178                    "string" => format!("typeof ({}) === 'string'", accessor),
179                    "integer" => format!("Number.isInteger({})", accessor),
180                    "number" => format!("typeof ({}) === 'number'", accessor),
181                    "boolean" => format!("typeof ({}) === 'boolean'", accessor),
182                    "array" => format!("Array.isArray({})", accessor),
183                    "object" => format!(
184                        "typeof ({}) === 'object' && !Array.isArray({})",
185                        accessor, accessor
186                    ),
187                    _ => format!("({}) !== undefined", accessor),
188                };
189                script.push_str(&format!(
190                    "      {{ let ok = check(res, {{ '{}': (r) => {{ try {{ return {}; }} catch(e) {{ return false; }} }} }}); if (!ok) __captureFailure('{}', res, 'body field {} is {}'); }}\n",
191                    field_check_name, type_check, field_check_name, field.name, field.field_type
192                ));
193            }
194
195            script.push_str("    }\n");
196        }
197
198        script.push_str("  });\n\n");
199        script
200    }
201}
202
/// Generate a JavaScript expression to access a field in a parsed JSON body.
///
/// `field_name` is split on `.`. Each segment becomes a bracketed property
/// access, and every trailing `[]` on a segment indexes the first element of
/// an array. A segment consisting only of `[]` (e.g. `"[].id"` for a
/// top-level array body) indexes without a property access. Keys are escaped
/// so they cannot break out of the single-quoted JS string.
///
/// - Simple key: `"name"` → `JSON.parse(r.body)['name']`
/// - Dot-notation: `"config.enabled"` → `JSON.parse(r.body)['config']['enabled']`
/// - Array bracket: `"items[].id"` → `JSON.parse(r.body)['items'][0]['id']`
/// - Nested arrays: `"grid[][]"` → `JSON.parse(r.body)['grid'][0][0]`
/// - Top-level array: `"[].id"` → `JSON.parse(r.body)[0]['id']`
fn generate_field_accessor(field_name: &str) -> String {
    /// Escape `key` for a single-quoted JS string literal (backslash first).
    fn escape_key(key: &str) -> String {
        let mut out = String::with_capacity(key.len());
        for c in key.chars() {
            match c {
                '\\' => out.push_str("\\\\"),
                '\'' => out.push_str("\\'"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                '\u{2028}' => out.push_str("\\u2028"),
                '\u{2029}' => out.push_str("\\u2029"),
                c if c.is_control() => out.push_str(&format!("\\u{:04x}", c as u32)),
                c => out.push(c),
            }
        }
        out
    }

    let mut expr = String::from("JSON.parse(r.body)");

    for segment in field_name.split('.') {
        // Peel off every trailing `[]`; each one descends into element 0.
        let mut key = segment;
        let mut depth = 0usize;
        while let Some(rest) = key.strip_suffix("[]") {
            key = rest;
            depth += 1;
        }

        // A bare `[]` segment indexes the current value directly; otherwise
        // (including the empty key with no `[]`) emit the property access.
        if depth == 0 || !key.is_empty() {
            expr.push_str("['");
            expr.push_str(&escape_key(key));
            expr.push_str("']");
        }
        for _ in 0..depth {
            expr.push_str("[0]");
        }
    }

    expr
}
225
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a check that asserts only the status code: no request headers,
    /// no body, and no header or body-field expectations.
    fn basic_check(name: &str, path: &str, method: &str, expected_status: u16) -> CustomCheck {
        CustomCheck {
            name: name.to_string(),
            path: path.to_string(),
            method: method.to_string(),
            expected_status,
            body: None,
            expected_headers: HashMap::new(),
            expected_body_fields: Vec::new(),
            headers: HashMap::new(),
        }
    }

    /// Build an expected body field from a field path and a JSON type name.
    fn body_field(name: &str, field_type: &str) -> ExpectedBodyField {
        ExpectedBodyField {
            name: name.to_string(),
            field_type: field_type.to_string(),
        }
    }

    /// Render the k6 group for `checks`, using `BASE_URL` as the base URL.
    fn render(checks: Vec<CustomCheck>, custom_headers: &[(String, String)]) -> String {
        CustomConformanceConfig {
            custom_checks: checks,
        }
        .generate_k6_group("BASE_URL", custom_headers)
    }

    #[test]
    fn test_parse_custom_yaml() {
        let yaml = r#"
custom_checks:
  - name: "custom:pets-returns-200"
    path: /pets
    method: GET
    expected_status: 200
  - name: "custom:create-product"
    path: /api/products
    method: POST
    expected_status: 201
    body: '{"sku": "TEST-001", "name": "Test"}'
    expected_body_fields:
      - name: id
        type: integer
    expected_headers:
      content-type: "application/json"
"#;
        let config: CustomConformanceConfig = serde_yaml::from_str(yaml).unwrap();
        let checks = &config.custom_checks;
        assert_eq!(checks.len(), 2);

        let first = &checks[0];
        assert_eq!(first.name, "custom:pets-returns-200");
        assert_eq!(first.expected_status, 200);

        let fields = &checks[1].expected_body_fields;
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].name, "id");
        assert_eq!(fields[0].field_type, "integer");
    }

    #[test]
    fn test_generate_k6_group_get() {
        let script = render(vec![basic_check("custom:test-get", "/api/test", "GET", 200)], &[]);
        assert!(script.contains("group('Custom'"));
        assert!(script.contains("http.get(`${BASE_URL}/api/test`)"));
        assert!(script.contains("'custom:test-get': (r) => r.status === 200"));
    }

    #[test]
    fn test_generate_k6_group_post_with_body() {
        let mut check = basic_check("custom:create", "/api/items", "POST", 201);
        check.body = Some(r#"{"name": "test"}"#.to_string());
        check.expected_body_fields.push(body_field("id", "integer"));

        let script = render(vec![check], &[]);
        assert!(script.contains("http.post("));
        assert!(script.contains("'custom:create': (r) => r.status === 201"));
        assert!(script.contains("custom:create:body:id:integer"));
        assert!(script.contains("Number.isInteger"));
    }

    #[test]
    fn test_generate_k6_group_with_header_checks() {
        let mut check = basic_check("custom:header-check", "/api/test", "GET", 200);
        check
            .expected_headers
            .insert("content-type".to_string(), "application/json".to_string());

        let script = render(vec![check], &[]);
        assert!(script.contains("custom:header-check:header:content-type"));
        assert!(script.contains("new RegExp('application/json')"));
    }

    #[test]
    fn test_generate_k6_group_with_custom_headers() {
        let custom_headers = [("Authorization".to_string(), "Bearer token123".to_string())];
        let script = render(
            vec![basic_check("custom:auth-test", "/api/secure", "GET", 200)],
            &custom_headers,
        );
        assert!(script.contains("'Authorization': 'Bearer token123'"));
    }

    #[test]
    fn test_failure_capture_emitted() {
        let mut check = basic_check("custom:capture-test", "/api/test", "GET", 200);
        check.expected_headers.insert("X-Rate-Limit".to_string(), ".*".to_string());
        check.expected_body_fields.push(body_field("id", "integer"));

        let script = render(vec![check], &[]);
        // Status, header and body-field checks must all report through
        // __captureFailure when they fail.
        assert!(
            script.contains("__captureFailure('custom:capture-test', res, 'status === 200')"),
            "Status check should emit __captureFailure"
        );
        assert!(
            script.contains("__captureFailure('custom:capture-test:header:X-Rate-Limit'"),
            "Header check should emit __captureFailure"
        );
        assert!(
            script.contains("__captureFailure('custom:capture-test:body:id:integer'"),
            "Body field check should emit __captureFailure"
        );
    }

    #[test]
    fn test_from_file_nonexistent() {
        let err = CustomConformanceConfig::from_file(Path::new("/nonexistent/file.yaml"))
            .expect_err("reading a missing file must fail")
            .to_string();
        assert!(err.contains("Failed to read custom conformance file"));
    }

    #[test]
    fn test_generate_k6_group_delete() {
        let script = render(
            vec![basic_check("custom:delete-item", "/api/items/1", "DELETE", 204)],
            &[],
        );
        assert!(script.contains("http.del("));
        assert!(script.contains("r.status === 204"));
    }

    #[test]
    fn test_field_accessor_simple() {
        assert_eq!(generate_field_accessor("name"), "JSON.parse(r.body)['name']");
    }

    #[test]
    fn test_field_accessor_nested_dot() {
        let expected = "JSON.parse(r.body)['config']['enabled']";
        assert_eq!(generate_field_accessor("config.enabled"), expected);
    }

    #[test]
    fn test_field_accessor_array_bracket() {
        let expected = "JSON.parse(r.body)['items'][0]['id']";
        assert_eq!(generate_field_accessor("items[].id"), expected);
    }

    #[test]
    fn test_field_accessor_deep_nested() {
        let expected = "JSON.parse(r.body)['a']['b']['c']";
        assert_eq!(generate_field_accessor("a.b.c"), expected);
    }

    #[test]
    fn test_generate_k6_nested_body_fields() {
        let mut check = basic_check("custom:nested", "/api/data", "GET", 200);
        check.expected_body_fields = vec![
            body_field("count", "integer"),
            body_field("results[].name", "string"),
        ];

        let script = render(vec![check], &[]);
        // A top-level field is a single bracket access...
        assert!(script.contains("JSON.parse(r.body)['count']"));
        // ...while `[]` steps into the first array element.
        assert!(script.contains("JSON.parse(r.body)['results'][0]['name']"));
    }
}