// mockforge_bench/conformance/custom.rs

//! Custom conformance test authoring via YAML.
//!
//! Allows users to define additional conformance checks beyond the built-in
//! OpenAPI 3.0.0 feature set. Custom checks are grouped under a "Custom"
//! category in the conformance report.
6
7use crate::error::{BenchError, Result};
8use serde::Deserialize;
9use std::path::Path;
10
11/// Top-level YAML configuration for custom conformance checks
12#[derive(Debug, Deserialize)]
13pub struct CustomConformanceConfig {
14    /// List of custom checks to run
15    pub custom_checks: Vec<CustomCheck>,
16    /// Round 38 (#79) — Srikanth on 0.3.182. Repeat the entire
17    /// `custom_checks` sequence N times so a "log in, do work,
18    /// log out" chain can be exercised under load. The
19    /// `${var:...}` / `${cookie:...}` substitution context is
20    /// reset at the start of each iteration; values captured in
21    /// iteration K are NOT visible to iteration K+1. Defaults to 1.
22    #[serde(default = "default_iterations")]
23    pub chain_iterations: u32,
24}
25
/// Serde fallback for `chain_iterations`: run the chain exactly once.
fn default_iterations() -> u32 {
    const SINGLE_PASS: u32 = 1;
    SINGLE_PASS
}
29
30/// A single custom conformance check
31#[derive(Debug, Deserialize)]
32pub struct CustomCheck {
33    /// Check name (should start with "custom:" for report aggregation)
34    pub name: String,
35    /// Request path (e.g., "/api/users")
36    pub path: String,
37    /// HTTP method (GET, POST, PUT, DELETE, etc.)
38    pub method: String,
39    /// Expected HTTP status code
40    pub expected_status: u16,
41    /// Optional request body (JSON string)
42    #[serde(default)]
43    pub body: Option<String>,
44    /// Optional expected response headers (name -> regex pattern)
45    #[serde(default)]
46    pub expected_headers: std::collections::HashMap<String, String>,
47    /// Optional expected body fields with type validation
48    #[serde(default)]
49    pub expected_body_fields: Vec<ExpectedBodyField>,
50    /// Optional request headers
51    #[serde(default)]
52    pub headers: std::collections::HashMap<String, String>,
53
54    /// Round 38 (#79) — file upload support. When set, the request
55    /// is sent as `multipart/form-data` with one part per file. Each
56    /// file's bytes come from a local path (so the YAML can name a
57    /// `.exe`, `.jpg`, `.json`, `.docx`, `.xml`, etc. without
58    /// embedding the bytes). `body` wins over `upload`/`uploads`.
59    #[serde(default)]
60    pub upload: Option<UploadFile>,
61    #[serde(default)]
62    pub uploads: Vec<UploadFile>,
63
64    /// Round 38 (#79) — capture values from the response into the
65    /// chain context so subsequent checks can reference them via
66    /// `${var:NAME}`, `${cookie:NAME}`, `${header:NAME}` in path /
67    /// headers / body.
68    #[serde(default)]
69    pub extract: ExtractRules,
70
71    /// Round 38 (#79) — repeat the check N times within an
72    /// iteration. `mode: parallel` fires N concurrent requests
73    /// (Srikanth's Sequence 1: "Use that cookie and csrf token in 16
74    /// subsequent requests that should be sent in parallel").
75    /// `mode: sequential` runs them one after another (Sequence 2).
76    #[serde(default)]
77    pub repeat: Repeat,
78}
79
80/// Expected field in the response body with type checking
81#[derive(Debug, Deserialize)]
82pub struct ExpectedBodyField {
83    /// Field name in the JSON response
84    pub name: String,
85    /// Expected JSON type: "string", "integer", "number", "boolean", "array", "object"
86    #[serde(rename = "type")]
87    pub field_type: String,
88}
89
90/// Round 38 (#79) — a single file to upload as a multipart form part.
91#[derive(Debug, Clone, Deserialize)]
92pub struct UploadFile {
93    /// Local path to the file; bytes are read at request time.
94    pub path: String,
95    /// `Content-Type` for this part. Common values:
96    /// `application/octet-stream`, `image/jpeg`, `application/json`,
97    /// `application/xml`.
98    #[serde(default = "default_upload_content_type")]
99    pub content_type: String,
100    /// Multipart form field name. Defaults to `"file"`.
101    #[serde(default = "default_upload_field_name")]
102    pub field_name: String,
103    /// Filename announced to the server. Defaults to the basename
104    /// of `path`.
105    #[serde(default)]
106    pub filename: Option<String>,
107}
108
/// Serde fallback for `UploadFile::content_type`: generic binary data.
fn default_upload_content_type() -> String {
    String::from("application/octet-stream")
}
/// Serde fallback for `UploadFile::field_name`.
fn default_upload_field_name() -> String {
    "file".to_owned()
}
115
116/// Round 38 (#79) — what to capture from a check's response.
117#[derive(Debug, Clone, Default, Deserialize)]
118pub struct ExtractRules {
119    /// Cookie names to capture from `Set-Cookie`. Stored under
120    /// `${cookie:NAME}`.
121    #[serde(default)]
122    pub cookies: Vec<String>,
123    /// Response headers to capture (var_name -> header_name). Header
124    /// name is case-insensitive. Stored under `${var:VAR_NAME}`.
125    #[serde(default)]
126    pub headers: std::collections::HashMap<String, String>,
127    /// JSON body fields by simple dotted path. Stored under
128    /// `${var:VAR_NAME}`.
129    #[serde(default)]
130    pub body_fields: std::collections::HashMap<String, String>,
131}
132
133impl ExtractRules {
134    pub fn is_empty(&self) -> bool {
135        self.cookies.is_empty() && self.headers.is_empty() && self.body_fields.is_empty()
136    }
137}
138
139/// Round 38 (#79) — repeat semantics for a single custom check.
140#[derive(Debug, Clone, Deserialize)]
141pub struct Repeat {
142    #[serde(default = "default_repeat_count")]
143    pub count: u32,
144    #[serde(default)]
145    pub mode: RepeatMode,
146}
147
148impl Default for Repeat {
149    fn default() -> Self {
150        Self {
151            count: 1,
152            mode: RepeatMode::default(),
153        }
154    }
155}
156
157impl Repeat {
158    pub fn is_default(&self) -> bool {
159        self.count == 1 && matches!(self.mode, RepeatMode::Sequential)
160    }
161}
162
/// Serde fallback for `Repeat::count`: fire the check once.
fn default_repeat_count() -> u32 {
    1_u32
}
166
167/// Round 38 (#79) — sequential vs parallel repeat.
168#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
169#[serde(rename_all = "lowercase")]
170pub enum RepeatMode {
171    #[default]
172    Sequential,
173    Parallel,
174}
175
176impl CustomConformanceConfig {
177    /// Parse a custom conformance config from a YAML file
178    pub fn from_file(path: &Path) -> Result<Self> {
179        let content = std::fs::read_to_string(path).map_err(|e| {
180            BenchError::Other(format!(
181                "Failed to read custom conformance file '{}': {}",
182                path.display(),
183                e
184            ))
185        })?;
186        serde_yaml::from_str(&content).map_err(|e| {
187            BenchError::Other(format!(
188                "Failed to parse custom conformance YAML '{}': {}",
189                path.display(),
190                e
191            ))
192        })
193    }
194
195    /// Generate a k6 `group('Custom', ...)` block for all custom checks.
196    ///
197    /// `base_url` is the JS expression for the base URL (e.g., `"BASE_URL"`).
198    /// `custom_headers` are additional headers to inject into every request.
199    pub fn generate_k6_group(&self, base_url: &str, custom_headers: &[(String, String)]) -> String {
200        self.generate_k6_group_with_options(base_url, custom_headers, false)
201    }
202
203    /// Generate a k6 `group('Custom', ...)` block for all custom checks.
204    /// When `export_requests` is true, emits `__captureExchange` calls after each request.
205    pub fn generate_k6_group_with_options(
206        &self,
207        base_url: &str,
208        custom_headers: &[(String, String)],
209        export_requests: bool,
210    ) -> String {
211        let mut script = String::with_capacity(4096);
212        script.push_str("  group('Custom', function () {\n");
213
214        for check in &self.custom_checks {
215            script.push_str("    {\n");
216
217            // Build headers object
218            let mut all_headers: Vec<(String, String)> = Vec::new();
219            // Add check-specific headers
220            for (k, v) in &check.headers {
221                all_headers.push((k.clone(), v.clone()));
222            }
223            // Add global custom headers (check-specific take priority)
224            for (k, v) in custom_headers {
225                if !check.headers.contains_key(k) {
226                    all_headers.push((k.clone(), v.clone()));
227                }
228            }
229            // If posting JSON body, add Content-Type
230            if check.body.is_some()
231                && !all_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("content-type"))
232            {
233                all_headers.push(("Content-Type".to_string(), "application/json".to_string()));
234            }
235
236            let headers_js = if all_headers.is_empty() {
237                "{}".to_string()
238            } else {
239                let entries: Vec<String> = all_headers
240                    .iter()
241                    .map(|(k, v)| format!("'{}': '{}'", k, v.replace('\'', "\\'")))
242                    .collect();
243                format!("{{ {} }}", entries.join(", "))
244            };
245
246            let method = check.method.to_uppercase();
247            let url = format!("${{{}}}{}", base_url, check.path);
248            let escaped_name = check.name.replace('\'', "\\'");
249
250            match method.as_str() {
251                "GET" | "HEAD" | "OPTIONS" | "DELETE" => {
252                    let k6_method = match method.as_str() {
253                        "DELETE" => "del",
254                        other => &other.to_lowercase(),
255                    };
256                    if all_headers.is_empty() {
257                        script
258                            .push_str(&format!("      let res = http.{}(`{}`);\n", k6_method, url));
259                    } else {
260                        script.push_str(&format!(
261                            "      let res = http.{}(`{}`, {{ headers: {} }});\n",
262                            k6_method, url, headers_js
263                        ));
264                    }
265                }
266                _ => {
267                    // POST, PUT, PATCH
268                    let k6_method = method.to_lowercase();
269                    let body_expr = match &check.body {
270                        Some(b) => format!(
271                            "'{}'",
272                            b.replace('\\', "\\\\")
273                                .replace('\'', "\\'")
274                                .replace('\n', "\\n")
275                                .replace('\r', "\\r")
276                                .replace('\t', "\\t")
277                        ),
278                        None => "null".to_string(),
279                    };
280                    script.push_str(&format!(
281                        "      let res = http.{}(`{}`, {}, {{ headers: {} }});\n",
282                        k6_method, url, body_expr, headers_js
283                    ));
284                }
285            }
286
287            // Capture request/response when --export-requests is enabled
288            if export_requests {
289                script.push_str(&format!(
290                    "      if (typeof __captureExchange === 'function') __captureExchange('{}', res);\n",
291                    escaped_name
292                ));
293            }
294
295            // Status check with failure detail capture
296            script.push_str(&format!(
297                "      {{ let ok = check(res, {{ '{}': (r) => r.status === {} }}); if (!ok) __captureFailure('{}', res, 'status === {}'); }}\n",
298                escaped_name, check.expected_status, escaped_name, check.expected_status
299            ));
300
301            // Header checks with failure detail capture.
302            // k6 canonicalizes response header names (e.g. `X-XSS-Protection` ->
303            // `X-Xss-Protection`), so match header names case-insensitively.
304            for (header_name, pattern) in &check.expected_headers {
305                let header_check_name = format!("{}:header:{}", escaped_name, header_name);
306                let escaped_pattern = pattern.replace('\\', "\\\\").replace('\'', "\\'");
307                let header_lower = header_name.to_lowercase();
308                script.push_str(&format!(
309                    "      {{ let ok = check(res, {{ '{}': (r) => {{ const _hk = Object.keys(r.headers || {{}}).find(k => k.toLowerCase() === '{}'); return new RegExp('{}').test(_hk ? r.headers[_hk] : ''); }} }}); if (!ok) __captureFailure('{}', res, 'header {} matches /{}/ '); }}\n",
310                    header_check_name,
311                    header_lower,
312                    escaped_pattern,
313                    header_check_name,
314                    header_name,
315                    escaped_pattern
316                ));
317            }
318
319            // Body field checks
320            for field in &check.expected_body_fields {
321                let field_check_name =
322                    format!("{}:body:{}:{}", escaped_name, field.name, field.field_type);
323                // Generate JS expression to access the field value, supporting
324                // nested paths like "results.name" and "items[].id"
325                let accessor = generate_field_accessor(&field.name);
326                let type_check = match field.field_type.as_str() {
327                    "string" => format!("typeof ({}) === 'string'", accessor),
328                    "integer" => format!("Number.isInteger({})", accessor),
329                    "number" => format!("typeof ({}) === 'number'", accessor),
330                    "boolean" => format!("typeof ({}) === 'boolean'", accessor),
331                    "array" => format!("Array.isArray({})", accessor),
332                    "object" => format!(
333                        "typeof ({}) === 'object' && !Array.isArray({})",
334                        accessor, accessor
335                    ),
336                    _ => format!("({}) !== undefined", accessor),
337                };
338                script.push_str(&format!(
339                    "      {{ let ok = check(res, {{ '{}': (r) => {{ try {{ return {}; }} catch(e) {{ return false; }} }} }}); if (!ok) __captureFailure('{}', res, 'body field {} is {}'); }}\n",
340                    field_check_name, type_check, field_check_name, field.name, field.field_type
341                ));
342            }
343
344            script.push_str("    }\n");
345        }
346
347        script.push_str("  });\n\n");
348        script
349    }
350}
351
/// Generate a JavaScript expression that reads a field from the parsed JSON
/// response body (`r.body`).
///
/// The path is split on `.`, and each segment becomes a bracketed property
/// access. Each trailing `[]` on a segment indexes the first element of an
/// array, so nested arrays can be written as `[][]`. A segment that is only
/// `[]` indexes the current value directly, which allows paths into a
/// top-level array response.
///
/// - Simple key: `"name"` → `JSON.parse(r.body)['name']`
/// - Dot-notation: `"config.enabled"` → `JSON.parse(r.body)['config']['enabled']`
/// - Array bracket: `"items[].id"` → `JSON.parse(r.body)['items'][0]['id']`
/// - Nested arrays: `"grid[][].v"` → `JSON.parse(r.body)['grid'][0][0]['v']`
/// - Top-level array: `"[].id"` → `JSON.parse(r.body)[0]['id']`
///
/// Keys are escaped, so quotes or backslashes in a field name cannot break
/// out of the generated string literal.
fn generate_field_accessor(field_name: &str) -> String {
    /// Escape a property name for a single-quoted JS string literal.
    fn quote_key(key: &str) -> String {
        let mut out = String::with_capacity(key.len());
        for c in key.chars() {
            match c {
                '\\' => out.push_str("\\\\"),
                '\'' => out.push_str("\\'"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                '\u{2028}' => out.push_str("\\u2028"),
                '\u{2029}' => out.push_str("\\u2029"),
                other => out.push(other),
            }
        }
        out
    }

    let mut expr = String::from("JSON.parse(r.body)");
    for segment in field_name.split('.') {
        // Peel off every trailing `[]`; each one indexes the first element.
        let mut key = segment;
        let mut array_depth = 0;
        while let Some(rest) = key.strip_suffix("[]") {
            key = rest;
            array_depth += 1;
        }
        // A bare `[]` segment skips the property access entirely. An empty
        // segment without `[]` keeps the literal `['']` access, as before.
        if !key.is_empty() || array_depth == 0 {
            expr.push_str(&format!("['{}']", quote_key(key)));
        }
        expr.push_str(&"[0]".repeat(array_depth));
    }
    expr
}
374
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A check with the given basics and every optional field empty/default.
    fn make_check(name: &str, path: &str, method: &str, expected_status: u16) -> CustomCheck {
        CustomCheck {
            name: name.to_string(),
            path: path.to_string(),
            method: method.to_string(),
            expected_status,
            body: None,
            expected_headers: HashMap::new(),
            expected_body_fields: Vec::new(),
            headers: HashMap::new(),
            upload: None,
            uploads: Vec::new(),
            extract: ExtractRules::default(),
            repeat: Repeat::default(),
        }
    }

    /// A single-iteration config containing just `check`.
    fn config_of(check: CustomCheck) -> CustomConformanceConfig {
        CustomConformanceConfig {
            custom_checks: vec![check],
            chain_iterations: 1,
        }
    }

    /// An expected body field named `name` with JSON type `field_type`.
    fn field(name: &str, field_type: &str) -> ExpectedBodyField {
        ExpectedBodyField {
            name: name.to_string(),
            field_type: field_type.to_string(),
        }
    }

    /// Render `check` alone, with base URL `BASE_URL` and no global headers.
    fn render(check: CustomCheck) -> String {
        config_of(check).generate_k6_group("BASE_URL", &[])
    }

    #[test]
    fn test_parse_custom_yaml() {
        let yaml = r#"
custom_checks:
  - name: "custom:pets-returns-200"
    path: /pets
    method: GET
    expected_status: 200
  - name: "custom:create-product"
    path: /api/products
    method: POST
    expected_status: 201
    body: '{"sku": "TEST-001", "name": "Test"}'
    expected_body_fields:
      - name: id
        type: integer
    expected_headers:
      content-type: "application/json"
"#;
        let config: CustomConformanceConfig = serde_yaml::from_str(yaml).unwrap();
        let checks = &config.custom_checks;
        assert_eq!(checks.len(), 2);

        let (first, second) = (&checks[0], &checks[1]);
        assert_eq!(first.name, "custom:pets-returns-200");
        assert_eq!(first.expected_status, 200);
        assert_eq!(second.expected_body_fields.len(), 1);

        let id_field = &second.expected_body_fields[0];
        assert_eq!(id_field.name, "id");
        assert_eq!(id_field.field_type, "integer");
    }

    #[test]
    fn test_generate_k6_group_get() {
        let script = render(make_check("custom:test-get", "/api/test", "GET", 200));
        assert!(script.contains("group('Custom'"));
        assert!(script.contains("http.get(`${BASE_URL}/api/test`)"));
        assert!(script.contains("'custom:test-get': (r) => r.status === 200"));
    }

    #[test]
    fn test_generate_k6_group_post_with_body() {
        let mut check = make_check("custom:create", "/api/items", "POST", 201);
        check.body = Some(r#"{"name": "test"}"#.to_string());
        check.expected_body_fields.push(field("id", "integer"));

        let script = render(check);
        for needle in [
            "http.post(",
            "'custom:create': (r) => r.status === 201",
            "custom:create:body:id:integer",
            "Number.isInteger",
        ] {
            assert!(script.contains(needle), "missing `{}` in generated script", needle);
        }
    }

    #[test]
    fn test_generate_k6_group_with_header_checks() {
        let mut check = make_check("custom:header-check", "/api/test", "GET", 200);
        check
            .expected_headers
            .insert("content-type".to_string(), "application/json".to_string());

        let script = render(check);
        assert!(script.contains("custom:header-check:header:content-type"));
        assert!(script.contains("new RegExp('application/json')"));
    }

    #[test]
    fn test_generate_k6_group_with_custom_headers() {
        let global = [("Authorization".to_string(), "Bearer token123".to_string())];
        let script = config_of(make_check("custom:auth-test", "/api/secure", "GET", 200))
            .generate_k6_group("BASE_URL", &global);
        assert!(script.contains("'Authorization': 'Bearer token123'"));
    }

    #[test]
    fn test_failure_capture_emitted() {
        let mut check = make_check("custom:capture-test", "/api/test", "GET", 200);
        check.expected_headers.insert("X-Rate-Limit".to_string(), ".*".to_string());
        check.expected_body_fields.push(field("id", "integer"));

        let script = render(check);
        // Every kind of check must report its failure details.
        assert!(
            script.contains("__captureFailure('custom:capture-test', res, 'status === 200')"),
            "Status check should emit __captureFailure"
        );
        assert!(
            script.contains("__captureFailure('custom:capture-test:header:X-Rate-Limit'"),
            "Header check should emit __captureFailure"
        );
        assert!(
            script.contains("__captureFailure('custom:capture-test:body:id:integer'"),
            "Body field check should emit __captureFailure"
        );
    }

    #[test]
    fn test_from_file_nonexistent() {
        let err = CustomConformanceConfig::from_file(Path::new("/nonexistent/file.yaml"))
            .expect_err("reading a missing file must fail")
            .to_string();
        assert!(err.contains("Failed to read custom conformance file"));
    }

    #[test]
    fn test_generate_k6_group_delete() {
        let script = render(make_check("custom:delete-item", "/api/items/1", "DELETE", 204));
        assert!(script.contains("http.del("));
        assert!(script.contains("r.status === 204"));
    }

    #[test]
    fn test_field_accessor_simple() {
        assert_eq!(generate_field_accessor("name"), "JSON.parse(r.body)['name']");
    }

    #[test]
    fn test_field_accessor_nested_dot() {
        let accessor = generate_field_accessor("config.enabled");
        assert_eq!(accessor, "JSON.parse(r.body)['config']['enabled']");
    }

    #[test]
    fn test_field_accessor_array_bracket() {
        let accessor = generate_field_accessor("items[].id");
        assert_eq!(accessor, "JSON.parse(r.body)['items'][0]['id']");
    }

    #[test]
    fn test_field_accessor_deep_nested() {
        let accessor = generate_field_accessor("a.b.c");
        assert_eq!(accessor, "JSON.parse(r.body)['a']['b']['c']");
    }

    #[test]
    fn test_generate_k6_nested_body_fields() {
        let mut check = make_check("custom:nested", "/api/data", "GET", 200);
        check.expected_body_fields = vec![field("count", "integer"), field("results[].name", "string")];

        let script = render(check);
        // A plain key is a single bracket access.
        assert!(script.contains("JSON.parse(r.body)['count']"));
        // `[]` traverses into the first array element.
        assert!(script.contains("JSON.parse(r.body)['results'][0]['name']"));
    }
}