Skip to main content

mockforge_bench/conformance/
custom.rs

1//! Custom conformance test authoring via YAML
2//!
3//! Allows users to define additional conformance checks beyond the built-in
4//! OpenAPI 3.0.0 feature set. Custom checks are grouped under a "Custom"
5//! category in the conformance report.
6
7use crate::error::{BenchError, Result};
8use serde::Deserialize;
9use std::path::Path;
10
11/// Top-level YAML configuration for custom conformance checks
12#[derive(Debug, Deserialize)]
13pub struct CustomConformanceConfig {
14    /// List of custom checks to run
15    pub custom_checks: Vec<CustomCheck>,
16}
17
18/// A single custom conformance check
19#[derive(Debug, Deserialize)]
20pub struct CustomCheck {
21    /// Check name (should start with "custom:" for report aggregation)
22    pub name: String,
23    /// Request path (e.g., "/api/users")
24    pub path: String,
25    /// HTTP method (GET, POST, PUT, DELETE, etc.)
26    pub method: String,
27    /// Expected HTTP status code
28    pub expected_status: u16,
29    /// Optional request body (JSON string)
30    #[serde(default)]
31    pub body: Option<String>,
32    /// Optional expected response headers (name -> regex pattern)
33    #[serde(default)]
34    pub expected_headers: std::collections::HashMap<String, String>,
35    /// Optional expected body fields with type validation
36    #[serde(default)]
37    pub expected_body_fields: Vec<ExpectedBodyField>,
38    /// Optional request headers
39    #[serde(default)]
40    pub headers: std::collections::HashMap<String, String>,
41}
42
43/// Expected field in the response body with type checking
44#[derive(Debug, Deserialize)]
45pub struct ExpectedBodyField {
46    /// Field name in the JSON response
47    pub name: String,
48    /// Expected JSON type: "string", "integer", "number", "boolean", "array", "object"
49    #[serde(rename = "type")]
50    pub field_type: String,
51}
52
53impl CustomConformanceConfig {
54    /// Parse a custom conformance config from a YAML file
55    pub fn from_file(path: &Path) -> Result<Self> {
56        let content = std::fs::read_to_string(path).map_err(|e| {
57            BenchError::Other(format!(
58                "Failed to read custom conformance file '{}': {}",
59                path.display(),
60                e
61            ))
62        })?;
63        serde_yaml::from_str(&content).map_err(|e| {
64            BenchError::Other(format!(
65                "Failed to parse custom conformance YAML '{}': {}",
66                path.display(),
67                e
68            ))
69        })
70    }
71
72    /// Generate a k6 `group('Custom', ...)` block for all custom checks.
73    ///
74    /// `base_url` is the JS expression for the base URL (e.g., `"BASE_URL"`).
75    /// `custom_headers` are additional headers to inject into every request.
76    pub fn generate_k6_group(&self, base_url: &str, custom_headers: &[(String, String)]) -> String {
77        self.generate_k6_group_with_options(base_url, custom_headers, false)
78    }
79
80    /// Generate a k6 `group('Custom', ...)` block for all custom checks.
81    /// When `export_requests` is true, emits `__captureExchange` calls after each request.
82    pub fn generate_k6_group_with_options(
83        &self,
84        base_url: &str,
85        custom_headers: &[(String, String)],
86        export_requests: bool,
87    ) -> String {
88        let mut script = String::with_capacity(4096);
89        script.push_str("  group('Custom', function () {\n");
90
91        for check in &self.custom_checks {
92            script.push_str("    {\n");
93
94            // Build headers object
95            let mut all_headers: Vec<(String, String)> = Vec::new();
96            // Add check-specific headers
97            for (k, v) in &check.headers {
98                all_headers.push((k.clone(), v.clone()));
99            }
100            // Add global custom headers (check-specific take priority)
101            for (k, v) in custom_headers {
102                if !check.headers.contains_key(k) {
103                    all_headers.push((k.clone(), v.clone()));
104                }
105            }
106            // If posting JSON body, add Content-Type
107            if check.body.is_some()
108                && !all_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("content-type"))
109            {
110                all_headers.push(("Content-Type".to_string(), "application/json".to_string()));
111            }
112
113            let headers_js = if all_headers.is_empty() {
114                "{}".to_string()
115            } else {
116                let entries: Vec<String> = all_headers
117                    .iter()
118                    .map(|(k, v)| format!("'{}': '{}'", k, v.replace('\'', "\\'")))
119                    .collect();
120                format!("{{ {} }}", entries.join(", "))
121            };
122
123            let method = check.method.to_uppercase();
124            let url = format!("${{{}}}{}", base_url, check.path);
125            let escaped_name = check.name.replace('\'', "\\'");
126
127            match method.as_str() {
128                "GET" | "HEAD" | "OPTIONS" | "DELETE" => {
129                    let k6_method = match method.as_str() {
130                        "DELETE" => "del",
131                        other => &other.to_lowercase(),
132                    };
133                    if all_headers.is_empty() {
134                        script
135                            .push_str(&format!("      let res = http.{}(`{}`);\n", k6_method, url));
136                    } else {
137                        script.push_str(&format!(
138                            "      let res = http.{}(`{}`, {{ headers: {} }});\n",
139                            k6_method, url, headers_js
140                        ));
141                    }
142                }
143                _ => {
144                    // POST, PUT, PATCH
145                    let k6_method = method.to_lowercase();
146                    let body_expr = match &check.body {
147                        Some(b) => format!("'{}'", b.replace('\'', "\\'")),
148                        None => "null".to_string(),
149                    };
150                    script.push_str(&format!(
151                        "      let res = http.{}(`{}`, {}, {{ headers: {} }});\n",
152                        k6_method, url, body_expr, headers_js
153                    ));
154                }
155            }
156
157            // Capture request/response when --export-requests is enabled
158            if export_requests {
159                script.push_str(&format!(
160                    "      if (typeof __captureExchange === 'function') __captureExchange('{}', res);\n",
161                    escaped_name
162                ));
163            }
164
165            // Status check with failure detail capture
166            script.push_str(&format!(
167                "      {{ let ok = check(res, {{ '{}': (r) => r.status === {} }}); if (!ok) __captureFailure('{}', res, 'status === {}'); }}\n",
168                escaped_name, check.expected_status, escaped_name, check.expected_status
169            ));
170
171            // Header checks with failure detail capture.
172            // k6 canonicalizes response header names (e.g. `X-XSS-Protection` ->
173            // `X-Xss-Protection`), so match header names case-insensitively.
174            for (header_name, pattern) in &check.expected_headers {
175                let header_check_name = format!("{}:header:{}", escaped_name, header_name);
176                let escaped_pattern = pattern.replace('\\', "\\\\").replace('\'', "\\'");
177                let header_lower = header_name.to_lowercase();
178                script.push_str(&format!(
179                    "      {{ let ok = check(res, {{ '{}': (r) => {{ const _hk = Object.keys(r.headers || {{}}).find(k => k.toLowerCase() === '{}'); return new RegExp('{}').test(_hk ? r.headers[_hk] : ''); }} }}); if (!ok) __captureFailure('{}', res, 'header {} matches /{}/ '); }}\n",
180                    header_check_name,
181                    header_lower,
182                    escaped_pattern,
183                    header_check_name,
184                    header_name,
185                    escaped_pattern
186                ));
187            }
188
189            // Body field checks
190            for field in &check.expected_body_fields {
191                let field_check_name =
192                    format!("{}:body:{}:{}", escaped_name, field.name, field.field_type);
193                // Generate JS expression to access the field value, supporting
194                // nested paths like "results.name" and "items[].id"
195                let accessor = generate_field_accessor(&field.name);
196                let type_check = match field.field_type.as_str() {
197                    "string" => format!("typeof ({}) === 'string'", accessor),
198                    "integer" => format!("Number.isInteger({})", accessor),
199                    "number" => format!("typeof ({}) === 'number'", accessor),
200                    "boolean" => format!("typeof ({}) === 'boolean'", accessor),
201                    "array" => format!("Array.isArray({})", accessor),
202                    "object" => format!(
203                        "typeof ({}) === 'object' && !Array.isArray({})",
204                        accessor, accessor
205                    ),
206                    _ => format!("({}) !== undefined", accessor),
207                };
208                script.push_str(&format!(
209                    "      {{ let ok = check(res, {{ '{}': (r) => {{ try {{ return {}; }} catch(e) {{ return false; }} }} }}); if (!ok) __captureFailure('{}', res, 'body field {} is {}'); }}\n",
210                    field_check_name, type_check, field_check_name, field.name, field.field_type
211                ));
212            }
213
214            script.push_str("    }\n");
215        }
216
217        script.push_str("  });\n\n");
218        script
219    }
220}
221
/// Generate a JavaScript expression to access a field in a parsed JSON body.
///
/// The path is split on `.`. Each segment becomes a bracketed key lookup, and
/// every trailing `[]` on a segment indexes the first element of an array:
/// - Simple key: `"name"` → `JSON.parse(r.body)['name']`
/// - Dot-notation: `"config.enabled"` → `JSON.parse(r.body)['config']['enabled']`
/// - Array bracket: `"items[].id"` → `JSON.parse(r.body)['items'][0]['id']`
/// - Nested arrays: `"grid[][]"` → `JSON.parse(r.body)['grid'][0][0]`
/// - Top-level array: `"[].id"` → `JSON.parse(r.body)[0]['id']`
///
/// Keys are emitted as escaped single-quoted JS strings, so names containing
/// `'` or `\` cannot break out of the generated expression.
fn generate_field_accessor(field_name: &str) -> String {
    /// Render `key` as a single-quoted JavaScript string literal.
    fn quote_key(key: &str) -> String {
        let mut quoted = String::with_capacity(key.len() + 2);
        quoted.push('\'');
        for c in key.chars() {
            match c {
                '\\' => quoted.push_str("\\\\"),
                '\'' => quoted.push_str("\\'"),
                c if c.is_control() || c == '\u{2028}' || c == '\u{2029}' => {
                    quoted.push_str(&format!("\\u{:04x}", c as u32))
                }
                c => quoted.push(c),
            }
        }
        quoted.push('\'');
        quoted
    }

    let mut expr = String::from("JSON.parse(r.body)");
    for segment in field_name.split('.') {
        // Peel off every trailing `[]`; each one selects element 0 of an array.
        let mut key = segment;
        let mut depth = 0;
        while let Some(rest) = key.strip_suffix("[]") {
            key = rest;
            depth += 1;
        }
        // A bare `[]` segment (e.g. a top-level array) has no key to look up.
        if !(key.is_empty() && depth > 0) {
            expr.push('[');
            expr.push_str(&quote_key(key));
            expr.push(']');
        }
        expr.push_str(&"[0]".repeat(depth));
    }

    expr
}
244
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A check with no body, request headers or response expectations.
    fn bare_check(name: &str, path: &str, method: &str, expected_status: u16) -> CustomCheck {
        CustomCheck {
            name: name.to_owned(),
            path: path.to_owned(),
            method: method.to_owned(),
            expected_status,
            body: None,
            expected_headers: HashMap::new(),
            expected_body_fields: Vec::new(),
            headers: HashMap::new(),
        }
    }

    /// An expected body field of the given JSON type.
    fn body_field(name: &str, field_type: &str) -> ExpectedBodyField {
        ExpectedBodyField {
            name: name.to_owned(),
            field_type: field_type.to_owned(),
        }
    }

    /// Render the k6 group for a single check, without global headers.
    fn render(check: CustomCheck) -> String {
        let config = CustomConformanceConfig {
            custom_checks: vec![check],
        };
        config.generate_k6_group("BASE_URL", &[])
    }

    #[test]
    fn test_parse_custom_yaml() {
        let yaml = r#"
custom_checks:
  - name: "custom:pets-returns-200"
    path: /pets
    method: GET
    expected_status: 200
  - name: "custom:create-product"
    path: /api/products
    method: POST
    expected_status: 201
    body: '{"sku": "TEST-001", "name": "Test"}'
    expected_body_fields:
      - name: id
        type: integer
    expected_headers:
      content-type: "application/json"
"#;
        let config: CustomConformanceConfig = serde_yaml::from_str(yaml).unwrap();
        let checks = &config.custom_checks;
        assert_eq!(checks.len(), 2);

        let pets = &checks[0];
        assert_eq!(pets.name, "custom:pets-returns-200");
        assert_eq!(pets.expected_status, 200);

        let fields = &checks[1].expected_body_fields;
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].name, "id");
        assert_eq!(fields[0].field_type, "integer");
    }

    #[test]
    fn test_generate_k6_group_get() {
        let script = render(bare_check("custom:test-get", "/api/test", "GET", 200));
        assert!(script.contains("group('Custom'"));
        assert!(script.contains("http.get(`${BASE_URL}/api/test`)"));
        assert!(script.contains("'custom:test-get': (r) => r.status === 200"));
    }

    #[test]
    fn test_generate_k6_group_post_with_body() {
        let mut check = bare_check("custom:create", "/api/items", "POST", 201);
        check.body = Some(r#"{"name": "test"}"#.to_owned());
        check.expected_body_fields.push(body_field("id", "integer"));

        let script = render(check);
        assert!(script.contains("http.post("));
        assert!(script.contains("'custom:create': (r) => r.status === 201"));
        assert!(script.contains("custom:create:body:id:integer"));
        assert!(script.contains("Number.isInteger"));
    }

    #[test]
    fn test_generate_k6_group_with_header_checks() {
        let mut check = bare_check("custom:header-check", "/api/test", "GET", 200);
        check
            .expected_headers
            .insert("content-type".to_owned(), "application/json".to_owned());

        let script = render(check);
        assert!(script.contains("custom:header-check:header:content-type"));
        assert!(script.contains("new RegExp('application/json')"));
    }

    #[test]
    fn test_generate_k6_group_with_custom_headers() {
        let config = CustomConformanceConfig {
            custom_checks: vec![bare_check("custom:auth-test", "/api/secure", "GET", 200)],
        };
        let custom_headers = [("Authorization".to_owned(), "Bearer token123".to_owned())];

        let script = config.generate_k6_group("BASE_URL", &custom_headers);
        assert!(script.contains("'Authorization': 'Bearer token123'"));
    }

    #[test]
    fn test_failure_capture_emitted() {
        let mut check = bare_check("custom:capture-test", "/api/test", "GET", 200);
        check
            .expected_headers
            .insert("X-Rate-Limit".to_owned(), ".*".to_owned());
        check.expected_body_fields.push(body_field("id", "integer"));

        let script = render(check);
        // Status check should call __captureFailure on failure
        assert!(
            script.contains("__captureFailure('custom:capture-test', res, 'status === 200')"),
            "Status check should emit __captureFailure"
        );
        // Header check should call __captureFailure on failure
        assert!(
            script.contains("__captureFailure('custom:capture-test:header:X-Rate-Limit'"),
            "Header check should emit __captureFailure"
        );
        // Body field check should call __captureFailure on failure
        assert!(
            script.contains("__captureFailure('custom:capture-test:body:id:integer'"),
            "Body field check should emit __captureFailure"
        );
    }

    #[test]
    fn test_from_file_nonexistent() {
        let err = CustomConformanceConfig::from_file(Path::new("/nonexistent/file.yaml"))
            .expect_err("a missing file must be reported as an error")
            .to_string();
        assert!(err.contains("Failed to read custom conformance file"));
    }

    #[test]
    fn test_generate_k6_group_delete() {
        let script = render(bare_check("custom:delete-item", "/api/items/1", "DELETE", 204));
        assert!(script.contains("http.del("));
        assert!(script.contains("r.status === 204"));
    }

    #[test]
    fn test_field_accessor_simple() {
        let expr = generate_field_accessor("name");
        assert_eq!(expr, "JSON.parse(r.body)['name']");
    }

    #[test]
    fn test_field_accessor_nested_dot() {
        let expr = generate_field_accessor("config.enabled");
        assert_eq!(expr, "JSON.parse(r.body)['config']['enabled']");
    }

    #[test]
    fn test_field_accessor_array_bracket() {
        let expr = generate_field_accessor("items[].id");
        assert_eq!(expr, "JSON.parse(r.body)['items'][0]['id']");
    }

    #[test]
    fn test_field_accessor_deep_nested() {
        let expr = generate_field_accessor("a.b.c");
        assert_eq!(expr, "JSON.parse(r.body)['a']['b']['c']");
    }

    #[test]
    fn test_generate_k6_nested_body_fields() {
        let mut check = bare_check("custom:nested", "/api/data", "GET", 200);
        check.expected_body_fields = vec![
            body_field("count", "integer"),
            body_field("results[].name", "string"),
        ];

        let script = render(check);
        // Simple field should use direct bracket access
        assert!(script.contains("JSON.parse(r.body)['count']"));
        // Nested array field should use [0] for array traversal
        assert!(script.contains("JSON.parse(r.body)['results'][0]['name']"));
    }
}
487}