Skip to main content

mockforge_bench/conformance/
custom.rs

1//! Custom conformance test authoring via YAML
2//!
3//! Allows users to define additional conformance checks beyond the built-in
4//! OpenAPI 3.0.0 feature set. Custom checks are grouped under a "Custom"
5//! category in the conformance report.
6
7use crate::error::{BenchError, Result};
8use serde::Deserialize;
9use std::path::Path;
10
11/// Top-level YAML configuration for custom conformance checks
12#[derive(Debug, Deserialize)]
13pub struct CustomConformanceConfig {
14    /// List of custom checks to run
15    pub custom_checks: Vec<CustomCheck>,
16}
17
18/// A single custom conformance check
19#[derive(Debug, Deserialize)]
20pub struct CustomCheck {
21    /// Check name (should start with "custom:" for report aggregation)
22    pub name: String,
23    /// Request path (e.g., "/api/users")
24    pub path: String,
25    /// HTTP method (GET, POST, PUT, DELETE, etc.)
26    pub method: String,
27    /// Expected HTTP status code
28    pub expected_status: u16,
29    /// Optional request body (JSON string)
30    #[serde(default)]
31    pub body: Option<String>,
32    /// Optional expected response headers (name -> regex pattern)
33    #[serde(default)]
34    pub expected_headers: std::collections::HashMap<String, String>,
35    /// Optional expected body fields with type validation
36    #[serde(default)]
37    pub expected_body_fields: Vec<ExpectedBodyField>,
38    /// Optional request headers
39    #[serde(default)]
40    pub headers: std::collections::HashMap<String, String>,
41}
42
43/// Expected field in the response body with type checking
44#[derive(Debug, Deserialize)]
45pub struct ExpectedBodyField {
46    /// Field name in the JSON response
47    pub name: String,
48    /// Expected JSON type: "string", "integer", "number", "boolean", "array", "object"
49    #[serde(rename = "type")]
50    pub field_type: String,
51}
52
53impl CustomConformanceConfig {
54    /// Parse a custom conformance config from a YAML file
55    pub fn from_file(path: &Path) -> Result<Self> {
56        let content = std::fs::read_to_string(path).map_err(|e| {
57            BenchError::Other(format!(
58                "Failed to read custom conformance file '{}': {}",
59                path.display(),
60                e
61            ))
62        })?;
63        serde_yaml::from_str(&content).map_err(|e| {
64            BenchError::Other(format!(
65                "Failed to parse custom conformance YAML '{}': {}",
66                path.display(),
67                e
68            ))
69        })
70    }
71
72    /// Generate a k6 `group('Custom', ...)` block for all custom checks.
73    ///
74    /// `base_url` is the JS expression for the base URL (e.g., `"BASE_URL"`).
75    /// `custom_headers` are additional headers to inject into every request.
76    pub fn generate_k6_group(&self, base_url: &str, custom_headers: &[(String, String)]) -> String {
77        self.generate_k6_group_with_options(base_url, custom_headers, false)
78    }
79
80    /// Generate a k6 `group('Custom', ...)` block for all custom checks.
81    /// When `export_requests` is true, emits `__captureExchange` calls after each request.
82    pub fn generate_k6_group_with_options(
83        &self,
84        base_url: &str,
85        custom_headers: &[(String, String)],
86        export_requests: bool,
87    ) -> String {
88        let mut script = String::with_capacity(4096);
89        script.push_str("  group('Custom', function () {\n");
90
91        for check in &self.custom_checks {
92            script.push_str("    {\n");
93
94            // Build headers object
95            let mut all_headers: Vec<(String, String)> = Vec::new();
96            // Add check-specific headers
97            for (k, v) in &check.headers {
98                all_headers.push((k.clone(), v.clone()));
99            }
100            // Add global custom headers (check-specific take priority)
101            for (k, v) in custom_headers {
102                if !check.headers.contains_key(k) {
103                    all_headers.push((k.clone(), v.clone()));
104                }
105            }
106            // If posting JSON body, add Content-Type
107            if check.body.is_some()
108                && !all_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("content-type"))
109            {
110                all_headers.push(("Content-Type".to_string(), "application/json".to_string()));
111            }
112
113            let headers_js = if all_headers.is_empty() {
114                "{}".to_string()
115            } else {
116                let entries: Vec<String> = all_headers
117                    .iter()
118                    .map(|(k, v)| format!("'{}': '{}'", k, v.replace('\'', "\\'")))
119                    .collect();
120                format!("{{ {} }}", entries.join(", "))
121            };
122
123            let method = check.method.to_uppercase();
124            let url = format!("${{{}}}{}", base_url, check.path);
125            let escaped_name = check.name.replace('\'', "\\'");
126
127            match method.as_str() {
128                "GET" | "HEAD" | "OPTIONS" | "DELETE" => {
129                    let k6_method = match method.as_str() {
130                        "DELETE" => "del",
131                        other => &other.to_lowercase(),
132                    };
133                    if all_headers.is_empty() {
134                        script
135                            .push_str(&format!("      let res = http.{}(`{}`);\n", k6_method, url));
136                    } else {
137                        script.push_str(&format!(
138                            "      let res = http.{}(`{}`, {{ headers: {} }});\n",
139                            k6_method, url, headers_js
140                        ));
141                    }
142                }
143                _ => {
144                    // POST, PUT, PATCH
145                    let k6_method = method.to_lowercase();
146                    let body_expr = match &check.body {
147                        Some(b) => format!(
148                            "'{}'",
149                            b.replace('\\', "\\\\")
150                                .replace('\'', "\\'")
151                                .replace('\n', "\\n")
152                                .replace('\r', "\\r")
153                                .replace('\t', "\\t")
154                        ),
155                        None => "null".to_string(),
156                    };
157                    script.push_str(&format!(
158                        "      let res = http.{}(`{}`, {}, {{ headers: {} }});\n",
159                        k6_method, url, body_expr, headers_js
160                    ));
161                }
162            }
163
164            // Capture request/response when --export-requests is enabled
165            if export_requests {
166                script.push_str(&format!(
167                    "      if (typeof __captureExchange === 'function') __captureExchange('{}', res);\n",
168                    escaped_name
169                ));
170            }
171
172            // Status check with failure detail capture
173            script.push_str(&format!(
174                "      {{ let ok = check(res, {{ '{}': (r) => r.status === {} }}); if (!ok) __captureFailure('{}', res, 'status === {}'); }}\n",
175                escaped_name, check.expected_status, escaped_name, check.expected_status
176            ));
177
178            // Header checks with failure detail capture.
179            // k6 canonicalizes response header names (e.g. `X-XSS-Protection` ->
180            // `X-Xss-Protection`), so match header names case-insensitively.
181            for (header_name, pattern) in &check.expected_headers {
182                let header_check_name = format!("{}:header:{}", escaped_name, header_name);
183                let escaped_pattern = pattern.replace('\\', "\\\\").replace('\'', "\\'");
184                let header_lower = header_name.to_lowercase();
185                script.push_str(&format!(
186                    "      {{ let ok = check(res, {{ '{}': (r) => {{ const _hk = Object.keys(r.headers || {{}}).find(k => k.toLowerCase() === '{}'); return new RegExp('{}').test(_hk ? r.headers[_hk] : ''); }} }}); if (!ok) __captureFailure('{}', res, 'header {} matches /{}/ '); }}\n",
187                    header_check_name,
188                    header_lower,
189                    escaped_pattern,
190                    header_check_name,
191                    header_name,
192                    escaped_pattern
193                ));
194            }
195
196            // Body field checks
197            for field in &check.expected_body_fields {
198                let field_check_name =
199                    format!("{}:body:{}:{}", escaped_name, field.name, field.field_type);
200                // Generate JS expression to access the field value, supporting
201                // nested paths like "results.name" and "items[].id"
202                let accessor = generate_field_accessor(&field.name);
203                let type_check = match field.field_type.as_str() {
204                    "string" => format!("typeof ({}) === 'string'", accessor),
205                    "integer" => format!("Number.isInteger({})", accessor),
206                    "number" => format!("typeof ({}) === 'number'", accessor),
207                    "boolean" => format!("typeof ({}) === 'boolean'", accessor),
208                    "array" => format!("Array.isArray({})", accessor),
209                    "object" => format!(
210                        "typeof ({}) === 'object' && !Array.isArray({})",
211                        accessor, accessor
212                    ),
213                    _ => format!("({}) !== undefined", accessor),
214                };
215                script.push_str(&format!(
216                    "      {{ let ok = check(res, {{ '{}': (r) => {{ try {{ return {}; }} catch(e) {{ return false; }} }} }}); if (!ok) __captureFailure('{}', res, 'body field {} is {}'); }}\n",
217                    field_check_name, type_check, field_check_name, field.name, field.field_type
218                ));
219            }
220
221            script.push_str("    }\n");
222        }
223
224        script.push_str("  });\n\n");
225        script
226    }
227}
228
/// Generate a JavaScript expression to access a field in a parsed JSON body.
///
/// The path is split on `.`; each segment becomes a bracketed key lookup, and
/// every trailing `[]` on a segment indexes the first array element:
/// - Simple key: `"name"` → `JSON.parse(r.body)['name']`
/// - Dot-notation: `"config.enabled"` → `JSON.parse(r.body)['config']['enabled']`
/// - Array bracket: `"items[].id"` → `JSON.parse(r.body)['items'][0]['id']`
/// - Root array: `"[].id"` → `JSON.parse(r.body)[0]['id']`
/// - Nested arrays: `"grid[][]"` → `JSON.parse(r.body)['grid'][0][0]`
///
/// Keys are escaped so quotes or backslashes in a field name cannot break out
/// of the generated string literal.
fn generate_field_accessor(field_name: &str) -> String {
    /// Render `key` as a `['...']` lookup with JS single-quote escaping.
    fn quote_key(key: &str) -> String {
        let escaped = key
            .replace('\\', "\\\\")
            .replace('\'', "\\'")
            .replace('\n', "\\n")
            .replace('\r', "\\r");
        format!("['{}']", escaped)
    }

    let mut expr = String::from("JSON.parse(r.body)");
    for segment in field_name.split('.') {
        // Peel off every trailing `[]`; each one selects element 0.
        let mut key = segment;
        let mut array_depth = 0usize;
        while let Some(inner) = key.strip_suffix("[]") {
            key = inner;
            array_depth += 1;
        }
        // A bare `[]` segment (e.g. the root of "[].id") has no key to look up.
        if !key.is_empty() || array_depth == 0 {
            expr.push_str(&quote_key(key));
        }
        expr.push_str(&"[0]".repeat(array_depth));
    }
    expr
}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A check with no body, no request headers and no expectations beyond the
    /// status code; individual tests fill in only what they exercise.
    fn make_check(name: &str, path: &str, method: &str, expected_status: u16) -> CustomCheck {
        CustomCheck {
            name: name.to_owned(),
            path: path.to_owned(),
            method: method.to_owned(),
            expected_status,
            body: None,
            expected_headers: HashMap::new(),
            expected_body_fields: Vec::new(),
            headers: HashMap::new(),
        }
    }

    /// Shorthand constructor for an [`ExpectedBodyField`].
    fn body_field(name: &str, field_type: &str) -> ExpectedBodyField {
        ExpectedBodyField {
            name: name.to_owned(),
            field_type: field_type.to_owned(),
        }
    }

    /// Render a single-check config using `BASE_URL` as the base-URL expression.
    fn render(check: CustomCheck, custom_headers: &[(String, String)]) -> String {
        let config = CustomConformanceConfig {
            custom_checks: vec![check],
        };
        config.generate_k6_group("BASE_URL", custom_headers)
    }

    #[test]
    fn test_parse_custom_yaml() {
        let yaml = r#"
custom_checks:
  - name: "custom:pets-returns-200"
    path: /pets
    method: GET
    expected_status: 200
  - name: "custom:create-product"
    path: /api/products
    method: POST
    expected_status: 201
    body: '{"sku": "TEST-001", "name": "Test"}'
    expected_body_fields:
      - name: id
        type: integer
    expected_headers:
      content-type: "application/json"
"#;
        let config: CustomConformanceConfig = serde_yaml::from_str(yaml).unwrap();
        let checks = &config.custom_checks;
        assert_eq!(checks.len(), 2);

        let first = &checks[0];
        assert_eq!(first.name, "custom:pets-returns-200");
        assert_eq!(first.expected_status, 200);

        let fields = &checks[1].expected_body_fields;
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].name, "id");
        assert_eq!(fields[0].field_type, "integer");
    }

    #[test]
    fn test_generate_k6_group_get() {
        let script = render(make_check("custom:test-get", "/api/test", "GET", 200), &[]);
        assert!(script.contains("group('Custom'"));
        assert!(script.contains("http.get(`${BASE_URL}/api/test`)"));
        assert!(script.contains("'custom:test-get': (r) => r.status === 200"));
    }

    #[test]
    fn test_generate_k6_group_post_with_body() {
        let mut check = make_check("custom:create", "/api/items", "POST", 201);
        check.body = Some(r#"{"name": "test"}"#.to_owned());
        check.expected_body_fields.push(body_field("id", "integer"));

        let script = render(check, &[]);
        assert!(script.contains("http.post("));
        assert!(script.contains("'custom:create': (r) => r.status === 201"));
        assert!(script.contains("custom:create:body:id:integer"));
        assert!(script.contains("Number.isInteger"));
    }

    #[test]
    fn test_generate_k6_group_with_header_checks() {
        let mut check = make_check("custom:header-check", "/api/test", "GET", 200);
        check
            .expected_headers
            .insert("content-type".to_owned(), "application/json".to_owned());

        let script = render(check, &[]);
        assert!(script.contains("custom:header-check:header:content-type"));
        assert!(script.contains("new RegExp('application/json')"));
    }

    #[test]
    fn test_generate_k6_group_with_custom_headers() {
        let custom_headers = [("Authorization".to_owned(), "Bearer token123".to_owned())];
        let script = render(
            make_check("custom:auth-test", "/api/secure", "GET", 200),
            &custom_headers,
        );
        assert!(script.contains("'Authorization': 'Bearer token123'"));
    }

    #[test]
    fn test_failure_capture_emitted() {
        let mut check = make_check("custom:capture-test", "/api/test", "GET", 200);
        check
            .expected_headers
            .insert("X-Rate-Limit".to_owned(), ".*".to_owned());
        check.expected_body_fields.push(body_field("id", "integer"));

        let script = render(check, &[]);
        let expectations = [
            (
                "__captureFailure('custom:capture-test', res, 'status === 200')",
                "Status check should emit __captureFailure",
            ),
            (
                "__captureFailure('custom:capture-test:header:X-Rate-Limit'",
                "Header check should emit __captureFailure",
            ),
            (
                "__captureFailure('custom:capture-test:body:id:integer'",
                "Body field check should emit __captureFailure",
            ),
        ];
        for (needle, reason) in expectations {
            assert!(script.contains(needle), "{}", reason);
        }
    }

    #[test]
    fn test_from_file_nonexistent() {
        let err = CustomConformanceConfig::from_file(Path::new("/nonexistent/file.yaml"))
            .expect_err("reading a missing file must fail")
            .to_string();
        assert!(err.contains("Failed to read custom conformance file"));
    }

    #[test]
    fn test_generate_k6_group_delete() {
        let script = render(make_check("custom:delete-item", "/api/items/1", "DELETE", 204), &[]);
        assert!(script.contains("http.del("));
        assert!(script.contains("r.status === 204"));
    }

    #[test]
    fn test_field_accessor_simple() {
        assert_eq!(generate_field_accessor("name"), "JSON.parse(r.body)['name']");
    }

    #[test]
    fn test_field_accessor_nested_dot() {
        let expected = "JSON.parse(r.body)['config']['enabled']";
        assert_eq!(generate_field_accessor("config.enabled"), expected);
    }

    #[test]
    fn test_field_accessor_array_bracket() {
        let expected = "JSON.parse(r.body)['items'][0]['id']";
        assert_eq!(generate_field_accessor("items[].id"), expected);
    }

    #[test]
    fn test_field_accessor_deep_nested() {
        let expected = "JSON.parse(r.body)['a']['b']['c']";
        assert_eq!(generate_field_accessor("a.b.c"), expected);
    }

    #[test]
    fn test_generate_k6_nested_body_fields() {
        let mut check = make_check("custom:nested", "/api/data", "GET", 200);
        check.expected_body_fields.extend([
            body_field("count", "integer"),
            body_field("results[].name", "string"),
        ]);

        let script = render(check, &[]);
        // Simple field should use direct bracket access
        assert!(script.contains("JSON.parse(r.body)['count']"));
        // Nested array field should use [0] for array traversal
        assert!(script.contains("JSON.parse(r.body)['results'][0]['name']"));
    }
}