use crate::error::Result;
use crate::spec_parser::SpecParser;
use openapiv3::{OpenAPI, ReferenceOr};
use serde::Serialize;
use std::collections::HashMap;
use std::path::Path;
use super::custom::CustomConformanceConfig;
/// One request-side conformance problem found while validating a custom check or an
/// emitted request against the OpenAPI spec.
///
/// Instances are serialized into the `conformance-request-violations*.json` reports, so
/// the field names below appear verbatim as JSON keys.
#[derive(Debug, Serialize)]
pub struct RequestViolation {
/// Name of the custom check (or the label of the emitted request) that produced this.
pub check_name: String,
/// HTTP method of the offending request.
pub method: String,
/// Path (custom checks) or full URL (emitted requests) of the offending request.
pub path: String,
/// Machine-readable category, e.g. `unknown_path` or `missing_required_header`.
pub violation_type: String,
/// Human-readable description of the problem.
pub message: String,
}
/// Validate every custom check in `custom_checks_file` against `spec`.
///
/// The query string of a check's path is ignored when resolving it to a spec path (with
/// `base_path` optionally tried as a prefix). For each check this can produce:
/// - `unknown_path` when no spec path matches;
/// - `method_not_allowed` when the path exists but the check's method (GET, POST, PUT,
///   DELETE, PATCH, HEAD or OPTIONS, case-insensitive) is not defined on it;
/// - request-body violations, for POST/PUT/PATCH only;
/// - missing required query/header parameter violations.
///
/// Spec path items given as `$ref` are skipped without a violation.
///
/// # Errors
/// Returns an error if the custom checks file cannot be loaded.
pub fn validate_custom_checks(
    spec: &OpenAPI,
    custom_checks_file: &Path,
    base_path: Option<&str>,
) -> Result<Vec<RequestViolation>> {
    let config = CustomConformanceConfig::from_file(custom_checks_file)?;
    let operation_map = build_spec_operation_map(spec);
    let mut found: Vec<RequestViolation> = Vec::new();

    for check in &config.custom_checks {
        // Only the portion before the first `?` is matched against spec paths.
        let bare_path = check
            .path
            .split_once('?')
            .map_or(check.path.as_str(), |(before_query, _)| before_query);

        let Some(spec_path) = find_matching_spec_path(bare_path, &operation_map, base_path)
        else {
            found.push(RequestViolation {
                check_name: check.name.clone(),
                method: check.method.clone(),
                path: check.path.clone(),
                violation_type: "unknown_path".to_string(),
                message: format!(
                    "Path '{}' not found in OpenAPI spec (checked with base_path={:?})",
                    bare_path, base_path
                ),
            });
            continue;
        };

        // `$ref` path items are not resolved here.
        let Some(ReferenceOr::Item(path_item)) = spec.paths.paths.get(&spec_path) else {
            continue;
        };

        let verb = check.method.to_lowercase();
        let candidate = match verb.as_str() {
            "get" => path_item.get.as_ref(),
            "post" => path_item.post.as_ref(),
            "put" => path_item.put.as_ref(),
            "delete" => path_item.delete.as_ref(),
            "patch" => path_item.patch.as_ref(),
            "head" => path_item.head.as_ref(),
            "options" => path_item.options.as_ref(),
            _ => None,
        };
        let Some(operation) = candidate else {
            found.push(RequestViolation {
                check_name: check.name.clone(),
                method: check.method.clone(),
                path: check.path.clone(),
                violation_type: "method_not_allowed".to_string(),
                message: format!(
                    "Method '{}' not defined for path '{}' in the spec",
                    check.method, spec_path
                ),
            });
            continue;
        };

        // Only methods that conventionally carry a body get body validation.
        if verb == "post" || verb == "put" || verb == "patch" {
            validate_request_body(
                &check.name,
                &check.method,
                &check.path,
                check.body.as_deref(),
                operation,
                spec,
                &mut found,
            );
        }

        validate_parameters(
            &check.name,
            &check.method,
            &check.path,
            bare_path,
            &check.headers,
            operation,
            path_item,
            spec,
            &mut found,
        );
    }

    Ok(found)
}
type SpecOperationMap = HashMap<String, Vec<String>>;
fn build_spec_operation_map(spec: &OpenAPI) -> SpecOperationMap {
let mut map = HashMap::new();
for (path, item_ref) in &spec.paths.paths {
if let ReferenceOr::Item(item) = item_ref {
let mut methods = Vec::new();
if item.get.is_some() {
methods.push("GET".to_string());
}
if item.post.is_some() {
methods.push("POST".to_string());
}
if item.put.is_some() {
methods.push("PUT".to_string());
}
if item.delete.is_some() {
methods.push("DELETE".to_string());
}
if item.patch.is_some() {
methods.push("PATCH".to_string());
}
if item.head.is_some() {
methods.push("HEAD".to_string());
}
if item.options.is_some() {
methods.push("OPTIONS".to_string());
}
map.insert(path.clone(), methods);
}
}
map
}
/// Resolve a concrete request path to the spec path template that serves it.
///
/// Resolution order:
/// 1. a literal match of `check_path`;
/// 2. a literal match of `base_path` + `check_path` (trailing `/` of the base trimmed);
/// 3. a template match (e.g. `/users/42` against `/users/{id}`) of either form.
///
/// When several templates match in step 3, the most specific wins: the template with the
/// fewest `{param}` segments, with ties broken lexicographically. This follows the OpenAPI
/// rule that concrete paths take precedence over templated ones, and makes the result
/// independent of `HashMap` iteration order (previously the first template visited won,
/// which could differ between runs).
fn find_matching_spec_path(
    check_path: &str,
    spec_ops: &SpecOperationMap,
    base_path: Option<&str>,
) -> Option<String> {
    if spec_ops.contains_key(check_path) {
        return Some(check_path.to_string());
    }

    // Computed once and reused for the template scan below.
    let with_base: Option<String> =
        base_path.map(|bp| format!("{}{}", bp.trim_end_matches('/'), check_path));
    if let Some(candidate) = with_base.as_deref() {
        if spec_ops.contains_key(candidate) {
            return Some(candidate.to_string());
        }
    }

    let param_segments = |template: &str| {
        template
            .split('/')
            .filter(|seg| seg.starts_with('{') && seg.ends_with('}'))
            .count()
    };

    spec_ops
        .keys()
        .map(String::as_str)
        .filter(|spec_path| {
            path_matches_template(check_path, spec_path)
                || with_base
                    .as_deref()
                    .map_or(false, |candidate| path_matches_template(candidate, spec_path))
        })
        .min_by_key(|spec_path| (param_segments(*spec_path), *spec_path))
        .map(str::to_string)
}
/// Returns `true` when `concrete` matches the path `template` segment by segment.
///
/// A template segment of the form `{name}` matches any non-empty concrete segment.
/// OpenAPI path parameters are always required, so `/users/` must not match
/// `/users/{id}`. Every other segment must match literally, and paths with a different
/// number of `/`-separated segments never match.
fn path_matches_template(concrete: &str, template: &str) -> bool {
    let concrete_parts: Vec<&str> = concrete.split('/').collect();
    let template_parts: Vec<&str> = template.split('/').collect();
    concrete_parts.len() == template_parts.len()
        && concrete_parts
            .iter()
            .zip(&template_parts)
            .all(|(c, t)| {
                if t.starts_with('{') && t.ends_with('}') {
                    !c.is_empty()
                } else {
                    c == t
                }
            })
}
#[allow(clippy::too_many_arguments)]
fn validate_request_body(
check_name: &str,
method: &str,
path: &str,
body: Option<&str>,
operation: &openapiv3::Operation,
spec: &OpenAPI,
violations: &mut Vec<RequestViolation>,
) {
let request_body_ref = match &operation.request_body {
Some(rb) => rb,
None => {
return;
}
};
let request_body = match request_body_ref {
ReferenceOr::Item(rb) => rb,
ReferenceOr::Reference { reference } => {
let name = reference.strip_prefix("#/components/requestBodies/").unwrap_or(reference);
match spec.components.as_ref().and_then(|c| c.request_bodies.get(name)) {
Some(ReferenceOr::Item(rb)) => rb,
_ => return,
}
}
};
if request_body.required && body.is_none() {
violations.push(RequestViolation {
check_name: check_name.to_string(),
method: method.to_string(),
path: path.to_string(),
violation_type: "missing_required_body".to_string(),
message: "Spec requires a request body but none is provided in the check".to_string(),
});
return;
}
if let Some(body_str) = body {
let json_media = request_body.content.get("application/json").or_else(|| {
request_body.content.iter().find(|(k, _)| k.contains("json")).map(|(_, v)| v)
});
if let Some(media) = json_media {
if let Some(schema_ref) = &media.schema {
let root_schema = match schema_ref {
ReferenceOr::Item(s) => s.clone(),
ReferenceOr::Reference { reference } => {
let name =
reference.strip_prefix("#/components/schemas/").unwrap_or(reference);
match spec.components.as_ref().and_then(|c| c.schemas.get(name)) {
Some(ReferenceOr::Item(s)) => s.clone(),
_ => return,
}
}
};
match serde_json::from_str::<serde_json::Value>(body_str) {
Ok(body_value) => {
match mockforge_openapi::schema_ref_resolver::build_validator(
&root_schema,
spec,
) {
Ok(validator) => {
let errors: Vec<_> = validator.iter_errors(&body_value).collect();
for err in errors.iter().take(5) {
violations.push(RequestViolation {
check_name: check_name.to_string(),
method: method.to_string(),
path: path.to_string(),
violation_type: "body_schema_violation".to_string(),
message: format!(
"Request body schema violation at {}: {}",
err.instance_path, err
),
});
}
}
Err(_) => {
}
}
}
Err(e) => {
violations.push(RequestViolation {
check_name: check_name.to_string(),
method: method.to_string(),
path: path.to_string(),
violation_type: "body_not_json".to_string(),
message: format!("Request body is not valid JSON: {}", e),
});
}
}
}
}
}
}
/// Check that a custom check supplies every required query and header parameter the spec
/// declares for the operation.
///
/// Parameters come from both the path item and the operation. Per OpenAPI, an
/// operation-level parameter with the same name and location overrides the path-level one.
/// Path parameters are not checked here, since they are implied by the template having
/// matched. Cookie parameters are ignored, as are header parameters named `Accept`,
/// `Content-Type` or `Authorization`, which OpenAPI says SHALL be ignored.
///
/// Query parameters are found by exact key in the check's query string (the part of
/// `path` after `check_path_no_query` and `?`). A required `id` is therefore not satisfied
/// by `userid=1`, and a bare `?flag` counts as present. Header names compare
/// case-insensitively against `check_headers`.
#[allow(clippy::too_many_arguments)]
fn validate_parameters(
    check_name: &str,
    method: &str,
    path: &str,
    check_path_no_query: &str,
    check_headers: &HashMap<String, String>,
    operation: &openapiv3::Operation,
    path_item: &openapiv3::PathItem,
    spec: &OpenAPI,
    violations: &mut Vec<RequestViolation>,
) {
    /// `(location, name)`: the identity OpenAPI uses for operation-level overrides.
    fn param_identity(param: &openapiv3::Parameter) -> (&'static str, &str) {
        match param {
            openapiv3::Parameter::Query { parameter_data, .. } => {
                ("query", parameter_data.name.as_str())
            }
            openapiv3::Parameter::Header { parameter_data, .. } => {
                ("header", parameter_data.name.as_str())
            }
            openapiv3::Parameter::Path { parameter_data, .. } => {
                ("path", parameter_data.name.as_str())
            }
            openapiv3::Parameter::Cookie { parameter_data, .. } => {
                ("cookie", parameter_data.name.as_str())
            }
        }
    }

    /// Header parameter names the OpenAPI spec says must be ignored.
    const IGNORED_HEADER_PARAMS: [&str; 3] = ["Accept", "Content-Type", "Authorization"];

    let operation_params: Vec<&openapiv3::Parameter> = operation
        .parameters
        .iter()
        .filter_map(|p| resolve_parameter(p, spec))
        .collect();
    // Path-level parameters survive only if the operation does not redefine them.
    let path_level_params: Vec<&openapiv3::Parameter> = path_item
        .parameters
        .iter()
        .filter_map(|p| resolve_parameter(p, spec))
        .filter(|p| {
            let id = param_identity(p);
            !operation_params.iter().any(|op| param_identity(op) == id)
        })
        .collect();

    let query = path
        .strip_prefix(check_path_no_query)
        .and_then(|rest| rest.strip_prefix('?'))
        .unwrap_or("");
    let query_has_key = |name: &str| {
        query
            .split('&')
            .any(|pair| pair.split_once('=').map_or(pair, |(key, _)| key) == name)
    };

    for param in path_level_params.into_iter().chain(operation_params) {
        match param {
            openapiv3::Parameter::Query { parameter_data, .. } => {
                if parameter_data.required && !query_has_key(parameter_data.name.as_str()) {
                    violations.push(RequestViolation {
                        check_name: check_name.to_string(),
                        method: method.to_string(),
                        path: path.to_string(),
                        violation_type: "missing_required_query_param".to_string(),
                        message: format!(
                            "Required query parameter '{}' is missing",
                            parameter_data.name
                        ),
                    });
                }
            }
            openapiv3::Parameter::Header { parameter_data, .. } => {
                let name = parameter_data.name.as_str();
                let ignored = IGNORED_HEADER_PARAMS.iter().any(|h| h.eq_ignore_ascii_case(name));
                if parameter_data.required
                    && !ignored
                    && !check_headers.keys().any(|k| k.eq_ignore_ascii_case(name))
                {
                    violations.push(RequestViolation {
                        check_name: check_name.to_string(),
                        method: method.to_string(),
                        path: path.to_string(),
                        violation_type: "missing_required_header".to_string(),
                        message: format!(
                            "Required header parameter '{}' is missing",
                            parameter_data.name
                        ),
                    });
                }
            }
            openapiv3::Parameter::Path { .. } | openapiv3::Parameter::Cookie { .. } => {}
        }
    }
}
/// Return the inline parameter behind `param_ref`.
///
/// `$ref`s are followed one level, and only into `#/components/parameters/`. A reference
/// with another prefix, to a missing component, or to another reference yields `None`.
fn resolve_parameter<'a>(
    param_ref: &'a ReferenceOr<openapiv3::Parameter>,
    spec: &'a OpenAPI,
) -> Option<&'a openapiv3::Parameter> {
    let reference = match param_ref {
        ReferenceOr::Item(inline) => return Some(inline),
        ReferenceOr::Reference { reference } => reference,
    };
    let component_name = reference.strip_prefix("#/components/parameters/")?;
    let components = spec.components.as_ref()?;
    components.parameters.get(component_name)?.as_item()
}
// Currently unused within this module; kept as a helper.
#[allow(dead_code)]
/// Serialize the schema behind `schema_ref` to a JSON value.
///
/// `$ref`s are followed one level into `#/components/schemas/`. Any other prefix, a
/// missing component, a nested reference, or a serialization failure yields `None`.
fn resolve_schema_to_json(
    schema_ref: &ReferenceOr<openapiv3::Schema>,
    spec: &OpenAPI,
) -> Option<serde_json::Value> {
    let resolved = match schema_ref {
        ReferenceOr::Reference { reference } => {
            let schema_name = reference.strip_prefix("#/components/schemas/")?;
            spec.components.as_ref()?.schemas.get(schema_name)?.as_item()?
        }
        ReferenceOr::Item(inline) => inline,
    };
    serde_json::to_value(resolved).ok()
}
/// Run the custom request checks against the first spec and persist any violations.
///
/// Returns `Ok(0)` without doing anything when no custom checks file is given or
/// `spec_files` is empty. Otherwise the first spec is parsed, the custom checks are
/// validated, and, if there are violations, they are written as pretty JSON to
/// `output_dir/conformance-request-violations.json`.
///
/// Writing the report is best-effort. A serialization or write failure is logged as a
/// warning rather than returned, and the success message is only logged when the file was
/// actually written.
///
/// # Errors
/// Propagates failures to parse the spec or to load the custom checks file.
///
/// Returns the number of violations found.
pub async fn run_request_validation(
    spec_files: &[std::path::PathBuf],
    custom_checks_file: Option<&Path>,
    base_path: Option<&str>,
    output_dir: &Path,
) -> Result<usize> {
    let Some(custom_file) = custom_checks_file else {
        return Ok(0);
    };
    let Some(primary_spec) = spec_files.first() else {
        return Ok(0);
    };

    let parser = SpecParser::from_file(primary_spec).await?;
    let spec = parser.spec();
    let violations = validate_custom_checks(spec, custom_file, base_path)?;

    if !violations.is_empty() {
        let path = output_dir.join("conformance-request-violations.json");
        match serde_json::to_string_pretty(&violations) {
            Ok(json) => match std::fs::write(&path, json) {
                Ok(()) => tracing::info!(
                    "Found {} request validation violation(s), saved to {}",
                    violations.len(),
                    path.display()
                ),
                Err(e) => tracing::warn!(
                    "Found {} request validation violation(s) but failed to write {}: {}",
                    violations.len(),
                    path.display(),
                    e
                ),
            },
            Err(e) => tracing::warn!(
                "Failed to serialize {} request validation violation(s): {}",
                violations.len(),
                e
            ),
        }
    }

    Ok(violations.len())
}
/// Validate emitted requests in `output_dir` against the first spec, without stripping
/// any base-path prefix from the request URLs.
///
/// Equivalent to `validate_emitted_requests_with_base_path(spec_files, output_dir, None)`.
pub async fn validate_emitted_requests(
    spec_files: &[std::path::PathBuf],
    output_dir: &Path,
) -> Result<usize> {
    let no_base_path: Option<&str> = None;
    validate_emitted_requests_with_base_path(spec_files, output_dir, no_base_path).await
}
pub async fn validate_emitted_requests_with_base_path(
spec_files: &[std::path::PathBuf],
output_dir: &Path,
base_path: Option<&str>,
) -> Result<usize> {
use serde_json::Value;
if spec_files.is_empty() {
return Ok(0);
}
let requests_path = output_dir.join("conformance-requests.json");
let self_test_jsonl_path = output_dir.join("conformance-self-test-requests.jsonl");
let entries: Vec<Value> = if requests_path.exists() {
let bytes = match std::fs::read(&requests_path) {
Ok(b) => b,
Err(_) => return Ok(0),
};
match serde_json::from_slice(&bytes) {
Ok(v) => v,
Err(_) => return Ok(0),
}
} else if self_test_jsonl_path.exists() {
let bytes = match std::fs::read(&self_test_jsonl_path) {
Ok(b) => b,
Err(_) => return Ok(0),
};
let text = String::from_utf8_lossy(&bytes);
text.lines()
.filter(|l| !l.is_empty())
.filter_map(|l| serde_json::from_str::<Value>(l).ok())
.map(|case| {
let label = case.get("label").and_then(|v| v.as_str()).unwrap_or("").to_string();
let method = case.get("method").and_then(|v| v.as_str()).unwrap_or("").to_string();
let url = case.get("url").and_then(|v| v.as_str()).unwrap_or("").to_string();
let body = case.get("request_body").cloned().unwrap_or(Value::Null);
let mut req = serde_json::Map::new();
req.insert("method".into(), Value::String(method));
req.insert("url".into(), Value::String(url));
req.insert(
"body".into(),
match body {
Value::String(s) => Value::String(s),
Value::Null => Value::String(String::new()),
other => other,
},
);
let mut out = serde_json::Map::new();
out.insert("check".into(), Value::String(label));
out.insert("request".into(), Value::Object(req));
Value::Object(out)
})
.collect()
} else {
return Ok(0);
};
if entries.is_empty() {
return Ok(0);
}
let parser = SpecParser::from_file(&spec_files[0]).await?;
let spec = parser.spec();
let spec_ops = build_spec_operation_map(spec);
let mut emitted_violations: Vec<RequestViolation> = Vec::new();
for entry in &entries {
let check = entry.get("check").and_then(|v| v.as_str()).unwrap_or("").to_string();
let req = match entry.get("request") {
Some(r) => r,
None => continue,
};
let method = req.get("method").and_then(|v| v.as_str()).unwrap_or("").to_uppercase();
let url = req.get("url").and_then(|v| v.as_str()).unwrap_or("").to_string();
if method.is_empty() || url.is_empty() {
continue;
}
let (path_only, query_string) = match url.find('?') {
Some(i) => (url[..i].to_string(), url[i + 1..].to_string()),
None => (url.clone(), String::new()),
};
let path_only = if let Some(stripped) = path_only.split_once("://") {
match stripped.1.find('/') {
Some(i) => stripped.1[i..].to_string(),
None => "/".to_string(),
}
} else {
path_only
};
let lookup_path = if let Some(bp) = base_path {
let bp = bp.trim_end_matches('/');
if !bp.is_empty() && path_only.starts_with(bp) {
let stripped = &path_only[bp.len()..];
if stripped.is_empty() {
"/".to_string()
} else {
stripped.to_string()
}
} else {
path_only.clone()
}
} else {
path_only.clone()
};
let spec_path = match find_matching_spec_path(&lookup_path, &spec_ops, None) {
Some(p) => p,
None => continue,
};
let path_item = match spec.paths.paths.get(&spec_path) {
Some(ReferenceOr::Item(item)) => item,
_ => continue,
};
let operation = match method.as_str() {
"GET" => path_item.get.as_ref(),
"POST" => path_item.post.as_ref(),
"PUT" => path_item.put.as_ref(),
"DELETE" => path_item.delete.as_ref(),
"PATCH" => path_item.patch.as_ref(),
"HEAD" => path_item.head.as_ref(),
"OPTIONS" => path_item.options.as_ref(),
_ => None,
};
let Some(operation) = operation else { continue };
let sent_query: HashMap<String, String> = query_string
.split('&')
.filter_map(|kv| {
let mut it = kv.splitn(2, '=');
let k = it.next()?.to_string();
let v = it.next().unwrap_or("").to_string();
if k.is_empty() {
None
} else {
Some((k, v))
}
})
.collect();
let path_params: HashMap<String, String> = {
let mut out = HashMap::new();
let concrete_parts: Vec<&str> = lookup_path.split('/').collect();
let template_parts: Vec<&str> = spec_path.split('/').collect();
if concrete_parts.len() == template_parts.len() {
for (c, t) in concrete_parts.iter().zip(template_parts.iter()) {
if t.starts_with('{') && t.ends_with('}') {
let name = &t[1..t.len() - 1];
out.insert(name.to_string(), (*c).to_string());
}
}
}
out
};
let mut all_params: Vec<&openapiv3::Parameter> = Vec::new();
for p in &path_item.parameters {
if let Some(param) = resolve_parameter(p, spec) {
all_params.push(param);
}
}
for p in &operation.parameters {
if let Some(param) = resolve_parameter(p, spec) {
all_params.push(param);
}
}
for param in &all_params {
let (loc_str, name, schema_ref) = match param {
openapiv3::Parameter::Query { parameter_data, .. } => {
let openapiv3::ParameterSchemaOrContent::Schema(sref) = ¶meter_data.format
else {
continue;
};
let Some(v) = sent_query.get(¶meter_data.name) else {
continue;
};
("query", ¶meter_data.name, (sref, v.clone()))
}
openapiv3::Parameter::Path { parameter_data, .. } => {
let openapiv3::ParameterSchemaOrContent::Schema(sref) = ¶meter_data.format
else {
continue;
};
let Some(v) = path_params.get(¶meter_data.name) else {
continue;
};
("path", ¶meter_data.name, (sref, v.clone()))
}
_ => continue,
};
let (schema_ref, value) = schema_ref;
let Some(schema) = schema_ref.as_item() else {
continue;
};
if let Some(msg) = check_value_against_schema(&value, schema) {
emitted_violations.push(RequestViolation {
check_name: check.clone(),
method: method.clone(),
path: url.clone(),
violation_type: format!("{}_value_mismatch", loc_str),
message: format!("{}.{}: {}", loc_str, name, msg),
});
}
}
let body_str = req.get("body").and_then(|v| v.as_str()).unwrap_or("");
if !body_str.is_empty() {
if let Ok(body_json) = serde_json::from_str::<serde_json::Value>(body_str) {
if let Some(req_body) = operation.request_body.as_ref().and_then(|r| r.as_item()) {
for (ct, media) in &req_body.content {
if !ct.contains("json") {
continue;
}
let Some(schema_ref) = &media.schema else {
continue;
};
let Some(schema) = schema_ref.as_item() else {
continue;
};
check_body_against_schema(
&check,
&method,
&url,
&body_json,
schema,
&mut emitted_violations,
);
}
}
}
}
}
let dst = output_dir.join("conformance-request-violations.json");
let mut all: Vec<Value> = if dst.exists() {
match std::fs::read(&dst) {
Ok(b) => serde_json::from_slice(&b).unwrap_or_default(),
Err(_) => Vec::new(),
}
} else {
Vec::new()
};
for v in &emitted_violations {
if let Ok(val) = serde_json::to_value(v) {
all.push(val);
}
}
{
let mut seen: std::collections::HashSet<(String, String, String, String, String)> =
std::collections::HashSet::new();
all.retain(|v| {
let f = |k: &str| v.get(k).and_then(|x| x.as_str()).unwrap_or("").to_string();
seen.insert((
f("check_name"),
f("method"),
f("path"),
f("violation_type"),
f("message"),
))
});
}
if !all.is_empty() {
if let Ok(json) = serde_json::to_string_pretty(&all) {
let _ = std::fs::write(&dst, json);
tracing::info!(
"validate-requests: wrote {} entries to {} ({} from emitted requests)",
all.len(),
dst.display(),
emitted_violations.len()
);
}
}
let grouped_dst = output_dir.join("conformance-request-violations-by-request.json");
let grouped_value = group_violations_by_request(&all);
if let Ok(json) = serde_json::to_string_pretty(&grouped_value) {
let _ = std::fs::write(&grouped_dst, json);
}
let drill_dst = output_dir.join("conformance-request-violations-by-probe.json");
let drill_value = group_violations_by_probe(&all);
if let Ok(json) = serde_json::to_string_pretty(&drill_value) {
let _ = std::fs::write(&drill_dst, json);
}
Ok(emitted_violations.len())
}
/// Group flat violation records into one row per probe, where a probe is a
/// `(check_name, method, path)` triple.
///
/// Within a probe, identical `(violation_type, message)` pairs are kept once, in first-seen
/// order. Rows are ordered by method, then path, then check name. Each row carries
/// `check_name`, `method`, `path`, `violation_count` and `violation_1..N` objects.
fn group_violations_by_probe(flat: &[serde_json::Value]) -> serde_json::Value {
    use serde_json::{Map, Value};
    use std::collections::{BTreeMap, HashSet};

    // Keyed by (method, path, check) so that map order *is* the report order.
    let mut probes: BTreeMap<(String, String, String), (HashSet<String>, Vec<(String, String)>)> =
        BTreeMap::new();

    for violation in flat {
        let text = |key: &str| violation.get(key).and_then(|x| x.as_str()).unwrap_or("").to_string();
        let check = text("check_name");
        let method = text("method");
        let path = text("path");
        let vt = text("violation_type");
        let msg = text("message");

        let (seen, entries) = probes.entry((method, path, check)).or_default();
        if seen.insert(format!("{vt}\u{0}{msg}")) {
            entries.push((vt, msg));
        }
    }

    let rows: Vec<Value> = probes
        .into_iter()
        .map(|((method, path, check), (_, entries))| {
            let mut row = Map::new();
            row.insert("check_name".into(), Value::String(check));
            row.insert("method".into(), Value::String(method));
            row.insert("path".into(), Value::String(path));
            row.insert("violation_count".into(), Value::from(entries.len()));
            for (idx, (vt, msg)) in entries.into_iter().enumerate() {
                let mut detail = Map::new();
                detail.insert("violation_type".into(), Value::String(vt));
                detail.insert("message".into(), Value::String(msg));
                row.insert(format!("violation_{}", idx + 1), Value::Object(detail));
            }
            Value::Object(row)
        })
        .collect();

    Value::Array(rows)
}
/// Group flat violation records into one row per `(method, path)` request, in first-seen
/// order.
///
/// Each row lists every distinct non-empty check name that hit the request (`checks`). It
/// also picks a representative `check_name`: the first check whose name matches the
/// category of the first violation (`query_` → `param:query`, `body_` → `body:`,
/// `path_` → `param:path`, `header_` → `param:header`), else the first check. Identical
/// `(violation_type, message)` pairs are kept once, as `violation_1..N` plus
/// `violation_count`.
fn group_violations_by_request(flat: &[serde_json::Value]) -> serde_json::Value {
    use serde_json::{Map, Value};

    /// Everything collected for one `(method, path)` pair.
    #[derive(Default)]
    struct RequestGroup {
        method: String,
        path: String,
        checks: Vec<String>,
        seen_checks: std::collections::HashSet<String>,
        violations: Vec<(String, String)>,
        seen_violations: std::collections::HashSet<String>,
    }

    let mut slot_of: std::collections::HashMap<(String, String), usize> =
        std::collections::HashMap::new();
    let mut groups: Vec<RequestGroup> = Vec::new();

    for record in flat {
        let text = |key: &str| record.get(key).and_then(|x| x.as_str()).unwrap_or("").to_string();
        let check = text("check_name");
        let method = text("method");
        let path = text("path");
        let vt = text("violation_type");
        let msg = text("message");

        let slot = *slot_of.entry((method.clone(), path.clone())).or_insert_with(|| {
            groups.push(RequestGroup { method, path, ..RequestGroup::default() });
            groups.len() - 1
        });
        let group = &mut groups[slot];

        if !check.is_empty() && group.seen_checks.insert(check.clone()) {
            group.checks.push(check);
        }
        if group.seen_violations.insert(format!("{vt}\u{0}{msg}")) {
            group.violations.push((vt, msg));
        }
    }

    let rows: Vec<Value> = groups
        .into_iter()
        .map(|group| {
            let preferred_prefix = match group.violations.first().map(|(vt, _)| vt.as_str()) {
                Some(vt) if vt.starts_with("query_") => "param:query",
                Some(vt) if vt.starts_with("body_") => "body:",
                Some(vt) if vt.starts_with("path_") => "param:path",
                Some(vt) if vt.starts_with("header_") => "param:header",
                _ => "",
            };
            let best_check = group
                .checks
                .iter()
                .find(|c| !preferred_prefix.is_empty() && c.starts_with(preferred_prefix))
                .or_else(|| group.checks.first())
                .cloned()
                .unwrap_or_default();

            let mut row = Map::new();
            row.insert(
                "checks".into(),
                Value::Array(group.checks.iter().cloned().map(Value::String).collect()),
            );
            row.insert("check_name".into(), Value::String(best_check));
            row.insert("method".into(), Value::String(group.method));
            row.insert("path".into(), Value::String(group.path));
            row.insert("violation_count".into(), Value::from(group.violations.len()));
            for (i, (vt, msg)) in group.violations.into_iter().enumerate() {
                let mut entry = Map::new();
                entry.insert("violation_type".into(), Value::String(vt));
                entry.insert("message".into(), Value::String(msg));
                row.insert(format!("violation_{}", i + 1), Value::Object(entry));
            }
            Value::Object(row)
        })
        .collect();

    Value::Array(rows)
}
/// Shallow-check a JSON object body against an inline object schema.
///
/// Does nothing unless `schema` is an inline `type: object` schema and `body` is a JSON
/// object. Reports:
/// - `body_missing_required` for each `required` field absent from the body;
/// - `body_value_mismatch` for string-valued properties whose inline schema rejects them.
///
/// Non-string values and `$ref` property schemas are not inspected.
fn check_body_against_schema(
    check: &str,
    method: &str,
    url: &str,
    body: &serde_json::Value,
    schema: &openapiv3::Schema,
    violations: &mut Vec<RequestViolation>,
) {
    use openapiv3::{SchemaKind, Type};

    let (SchemaKind::Type(Type::Object(object_schema)), Some(fields)) =
        (&schema.schema_kind, body.as_object())
    else {
        return;
    };

    let violation = |violation_type: &str, message: String| RequestViolation {
        check_name: check.to_string(),
        method: method.to_string(),
        path: url.to_string(),
        violation_type: violation_type.to_string(),
        message,
    };

    let missing = object_schema
        .required
        .iter()
        .filter(|field| !fields.contains_key(field.as_str()))
        .map(|field| {
            violation("body_missing_required", format!("body.{}: required field missing", field))
        });
    violations.extend(missing);

    for (prop_name, prop_ref) in &object_schema.properties {
        let (Some(text), Some(prop_schema)) =
            (fields.get(prop_name).and_then(|v| v.as_str()), prop_ref.as_item())
        else {
            continue;
        };
        if let Some(msg) = check_value_against_schema(text, prop_schema) {
            violations.push(violation("body_value_mismatch", format!("body.{}: {}", prop_name, msg)));
        }
    }
}
/// Check a textual value (a query/path parameter or a string-valued body field) against a
/// primitive inline schema.
///
/// Returns a human-readable mismatch description, or `None` when the value is acceptable or
/// the schema is not a plain typed schema:
/// - `string`: if the schema declares an `enum`, the value must equal one of its non-null
///   members.
/// - `integer`: an optional `+`/`-` followed by one or more ASCII digits. Integers outside
///   the `i64` range (e.g. large `uint64` ids) are accepted, as JSON Schema allows.
/// - `number`: must parse as a float and must not be a word such as `NaN`, `inf` or
///   `infinity`. Rust's float parser accepts those, but they are not JSON numbers.
/// - `boolean`: exactly `true` or `false`.
///
/// Other schema kinds (objects, arrays, `oneOf`, …) are not checked.
fn check_value_against_schema(value: &str, schema: &openapiv3::Schema) -> Option<String> {
    use openapiv3::{SchemaKind, Type};

    let SchemaKind::Type(t) = &schema.schema_kind else {
        return None;
    };
    match t {
        Type::String(s) => {
            if s.enumeration.is_empty() {
                return None;
            }
            let allowed: Vec<&str> = s.enumeration.iter().flatten().map(String::as_str).collect();
            if allowed.contains(&value) {
                return None;
            }
            let quoted: Vec<String> = allowed.iter().map(|a| format!("\"{}\"", a)).collect();
            Some(format!("value \"{}\" is not one of {}", value, quoted.join(" or ")))
        }
        Type::Integer(_) => {
            let digits = value.strip_prefix(|c: char| c == '+' || c == '-').unwrap_or(value);
            let is_integer = !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit());
            if is_integer {
                None
            } else {
                Some(format!("value \"{}\" is not of type \"integer\"", value))
            }
        }
        Type::Number(_) => {
            // Only the exponent marker may be alphabetic in a JSON number.
            let has_word = value.chars().any(|c| c.is_ascii_alphabetic() && c != 'e' && c != 'E');
            if has_word || value.parse::<f64>().is_err() {
                Some(format!("value \"{}\" is not of type \"number\"", value))
            } else {
                None
            }
        }
        Type::Boolean(_) => match value {
            "true" | "false" => None,
            _ => Some(format!("value \"{}\" is not of type \"boolean\"", value)),
        },
        _ => None,
    }
}
// Unit tests for the report-grouping helpers. They operate on plain `serde_json` records
// shaped like serialized `RequestViolation`s, so no spec or filesystem setup is needed.
#[cfg(test)]
mod grouping_tests {
use super::{group_violations_by_probe, group_violations_by_request};
use serde_json::json;
/// Build one flat violation record using the same keys `RequestViolation` serializes to.
fn viol(check: &str, method: &str, path: &str, vt: &str, msg: &str) -> serde_json::Value {
json!({
"check_name": check,
"method": method,
"path": path,
"violation_type": vt,
"message": msg,
})
}
/// Violations from different checks against the same method + URL collapse into a single
/// by-request row whose `checks` list contains every contributing check.
#[test]
fn by_request_unions_all_checks_for_a_url() {
let path = "https://host/v1/organizations?alt=test-value&prettyPrint=test-value";
let flat = vec![
viol(
"request-body:type-mismatch:billingType",
"POST",
path,
"body_type_mismatch",
"body.billingType: expected string",
),
viol(
"owasp:ldap-injection",
"POST",
path,
"query_value_mismatch",
"query.alt: value \"test-value\" is not one of \"json\" or \"media\"",
),
viol(
"owasp:ldap-injection",
"POST",
path,
"query_value_mismatch",
"query.prettyPrint: value \"test-value\" is not of type \"boolean\"",
),
];
let out = group_violations_by_request(&flat);
let rows = out.as_array().expect("array");
assert_eq!(rows.len(), 1, "expected a single by-request row per URL");
let row = &rows[0];
assert_eq!(row["violation_count"], 3);
let checks: Vec<&str> =
row["checks"].as_array().unwrap().iter().map(|c| c.as_str().unwrap()).collect();
assert!(checks.contains(&"owasp:ldap-injection"), "owasp check must appear: {checks:?}");
assert!(
checks.iter().any(|c| c.starts_with("request-body:")),
"body check must appear: {checks:?}"
);
}
/// The same violation repeated across many iterations of one probe is reported once.
#[test]
fn by_probe_dedups_repeated_iterations() {
let path = "https://host/v1/organizations?alt=test-value";
let mut flat = Vec::new();
for _ in 0..22 {
flat.push(viol(
"owasp:ldap-injection",
"POST",
path,
"query_value_mismatch",
"query.alt: value \"test-value\" is not one of \"json\" or \"media\"",
));
}
let out = group_violations_by_probe(&flat);
let rows = out.as_array().expect("array");
assert_eq!(rows.len(), 1, "one probe row");
assert_eq!(rows[0]["violation_count"], 1, "22 identical iterations collapse to 1");
assert!(rows[0].get("violation_1").is_some());
assert!(rows[0].get("violation_2").is_none(), "no duplicate violation_2");
}
/// Repeated identical records for one request collapse to one violation and one check.
#[test]
fn by_request_dedups_repeated_iterations() {
let path = "https://host/v1/widgets";
let mut flat = Vec::new();
for _ in 0..22 {
flat.push(viol(
"request-body:type-mismatch:name",
"POST",
path,
"body_type_mismatch",
"body.name: expected string",
));
}
let out = group_violations_by_request(&flat);
let rows = out.as_array().unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0]["violation_count"], 1, "duplicate iterations collapse");
let checks = rows[0]["checks"].as_array().unwrap();
assert_eq!(checks.len(), 1, "the same check listed once");
}
/// Different method/URL pairs produce separate by-request rows.
#[test]
fn by_request_keeps_distinct_urls_separate() {
let flat = vec![
viol("c1", "POST", "https://host/a", "body_type_mismatch", "a"),
viol("c2", "GET", "https://host/b", "query_value_mismatch", "b"),
];
let out = group_violations_by_request(&flat);
assert_eq!(out.as_array().unwrap().len(), 2);
}
}