use crate::error::Result;
use crate::spec_parser::SpecParser;
use openapiv3::{OpenAPI, ReferenceOr};
use serde::Serialize;
use std::collections::HashMap;
use std::path::Path;
use super::custom::CustomConformanceConfig;
/// A single request-validation finding.
///
/// Serialized (via `Serialize`) into the `conformance-request-violations*.json`
/// reports written by this module.
#[derive(Debug, Serialize)]
pub struct RequestViolation {
    /// Name of the custom check, or the `check` field of a recorded request entry.
    pub check_name: String,
    /// HTTP method as written in the custom check, or upper-cased from the recorded request.
    pub method: String,
    /// The custom check's configured path (which may include a query string), or the
    /// full URL of the recorded request.
    pub path: String,
    /// Machine-readable category, e.g. `unknown_path`, `method_not_allowed`,
    /// `body_schema_violation`, `missing_required_header`, `query_value_mismatch`.
    pub violation_type: String,
    /// Human-readable description of the violation.
    pub message: String,
}
/// Validates every check in a custom-conformance config file against `spec`.
///
/// For each entry in `custom_checks` this:
/// 1. strips the query string from the check's path and resolves it to a spec path
///    (exactly, with `base_path` prepended, or via `{param}` templates); an
///    `unknown_path` violation is recorded when nothing matches;
/// 2. looks up the check's HTTP method (case-insensitively) on that path; a
///    `method_not_allowed` violation is recorded when the spec does not define it;
/// 3. for POST, PUT and PATCH, validates the check's body against the operation's
///    request body (see `validate_request_body`);
/// 4. checks required query and header parameters (see `validate_parameters`).
///
/// # Errors
///
/// Returns an error if `custom_checks_file` cannot be loaded.
pub fn validate_custom_checks(
    spec: &OpenAPI,
    custom_checks_file: &Path,
    base_path: Option<&str>,
) -> Result<Vec<RequestViolation>> {
    let config = CustomConformanceConfig::from_file(custom_checks_file)?;
    let mut violations = Vec::new();
    let spec_ops = build_spec_operation_map(spec);
    for check in &config.custom_checks {
        // Path matching ignores the query string; the full `check.path` is still handed to
        // `validate_parameters` so it can inspect query parameters.
        let check_path = check.path.split_once('?').map_or(check.path.as_str(), |(p, _)| p);
        let Some(spec_path) = find_matching_spec_path(check_path, &spec_ops, base_path) else {
            violations.push(RequestViolation {
                check_name: check.name.clone(),
                method: check.method.clone(),
                path: check.path.clone(),
                violation_type: "unknown_path".to_string(),
                message: format!(
                    "Path '{}' not found in OpenAPI spec (checked with base_path={:?})",
                    check_path, base_path
                ),
            });
            continue;
        };
        // `build_spec_operation_map` only records inline path items, so a `$ref` path item
        // is not expected here; skip it defensively.
        let Some(ReferenceOr::Item(path_item)) = spec.paths.paths.get(&spec_path) else {
            continue;
        };
        let method_lower = check.method.to_lowercase();
        let operation = match method_lower.as_str() {
            "get" => path_item.get.as_ref(),
            "post" => path_item.post.as_ref(),
            "put" => path_item.put.as_ref(),
            "delete" => path_item.delete.as_ref(),
            "patch" => path_item.patch.as_ref(),
            "head" => path_item.head.as_ref(),
            "options" => path_item.options.as_ref(),
            // TRACE is a valid OpenAPI operation; without this arm it was always
            // reported as `method_not_allowed`.
            "trace" => path_item.trace.as_ref(),
            _ => None,
        };
        let Some(operation) = operation else {
            violations.push(RequestViolation {
                check_name: check.name.clone(),
                method: check.method.clone(),
                path: check.path.clone(),
                violation_type: "method_not_allowed".to_string(),
                message: format!(
                    "Method '{}' not defined for path '{}' in the spec",
                    check.method, spec_path
                ),
            });
            continue;
        };
        if matches!(method_lower.as_str(), "post" | "put" | "patch") {
            validate_request_body(
                &check.name,
                &check.method,
                &check.path,
                check.body.as_deref(),
                operation,
                spec,
                &mut violations,
            );
        }
        validate_parameters(
            &check.name,
            &check.method,
            &check.path,
            check_path,
            &check.headers,
            operation,
            path_item,
            spec,
            &mut violations,
        );
    }
    Ok(violations)
}
/// Upper-case HTTP methods defined for each inline path item, keyed by spec path.
type SpecOperationMap = HashMap<String, Vec<String>>;

/// Builds a [`SpecOperationMap`] from `spec`.
///
/// Only inline path items are included; `$ref` path items are skipped. Every inline
/// path is inserted, even one that defines no operations (its method list is empty).
/// Methods are listed in the order GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS, TRACE.
fn build_spec_operation_map(spec: &OpenAPI) -> SpecOperationMap {
    spec.paths
        .paths
        .iter()
        .filter_map(|(path, item_ref)| {
            let ReferenceOr::Item(item) = item_ref else {
                return None;
            };
            let methods: Vec<String> = [
                ("GET", &item.get),
                ("POST", &item.post),
                ("PUT", &item.put),
                ("DELETE", &item.delete),
                ("PATCH", &item.patch),
                ("HEAD", &item.head),
                ("OPTIONS", &item.options),
                // TRACE was previously missing from the map.
                ("TRACE", &item.trace),
            ]
            .into_iter()
            .filter(|(_, operation)| operation.is_some())
            .map(|(method, _)| method.to_string())
            .collect();
            Some((path.clone(), methods))
        })
        .collect()
}
/// Resolves a concrete request path to the spec path key it corresponds to.
///
/// Lookup order:
/// 1. exact key match on `check_path`;
/// 2. exact key match on `base_path + check_path` (when a base path is given; trailing
///    `/` on the base path is trimmed);
/// 3. `{param}` template match against either form (see `path_matches_template`).
///
/// When several templates match (e.g. `/a/{x}/c` and `/a/b/{y}` for `/a/b/c`), the one
/// with the fewest `{param}` segments wins, with ties broken by lexicographic order, so
/// the result never depends on `HashMap` iteration order.
fn find_matching_spec_path(
    check_path: &str,
    spec_ops: &SpecOperationMap,
    base_path: Option<&str>,
) -> Option<String> {
    if spec_ops.contains_key(check_path) {
        return Some(check_path.to_string());
    }
    // Built once: needed by both the exact lookup and every template comparison.
    let with_base = base_path.map(|bp| format!("{}{}", bp.trim_end_matches('/'), check_path));
    if let Some(candidate) = &with_base {
        if spec_ops.contains_key(candidate) {
            return Some(candidate.clone());
        }
    }
    let placeholder_count = |template: &str| {
        template
            .split('/')
            .filter(|segment| segment.starts_with('{') && segment.ends_with('}'))
            .count()
    };
    spec_ops
        .keys()
        .filter(|template| {
            path_matches_template(check_path, template)
                || with_base
                    .as_deref()
                    .is_some_and(|candidate| path_matches_template(candidate, template))
        })
        .map(|template| (placeholder_count(template.as_str()), template))
        .min()
        .map(|(_, template)| template.clone())
}
/// Returns `true` when the concrete request path `concrete` matches the spec path `template`.
///
/// Both paths are split on `/` and must have the same number of segments. A template
/// segment of the form `{name}` matches any *non-empty* concrete segment (OpenAPI path
/// parameters are always required, so `/users/` does not match `/users/{id}`). Every other
/// segment must match exactly. Partial-segment templates such as `/files/{name}.json` are
/// not supported and are compared literally.
fn path_matches_template(concrete: &str, template: &str) -> bool {
    let concrete_segments: Vec<&str> = concrete.split('/').collect();
    let template_segments: Vec<&str> = template.split('/').collect();
    concrete_segments.len() == template_segments.len()
        && concrete_segments
            .iter()
            .zip(&template_segments)
            .all(|(segment, pattern)| {
                if pattern.starts_with('{') && pattern.ends_with('}') {
                    !segment.is_empty()
                } else {
                    segment == pattern
                }
            })
}
#[allow(clippy::too_many_arguments)]
fn validate_request_body(
check_name: &str,
method: &str,
path: &str,
body: Option<&str>,
operation: &openapiv3::Operation,
spec: &OpenAPI,
violations: &mut Vec<RequestViolation>,
) {
let request_body_ref = match &operation.request_body {
Some(rb) => rb,
None => {
return;
}
};
let request_body = match request_body_ref {
ReferenceOr::Item(rb) => rb,
ReferenceOr::Reference { reference } => {
let name = reference.strip_prefix("#/components/requestBodies/").unwrap_or(reference);
match spec.components.as_ref().and_then(|c| c.request_bodies.get(name)) {
Some(ReferenceOr::Item(rb)) => rb,
_ => return,
}
}
};
if request_body.required && body.is_none() {
violations.push(RequestViolation {
check_name: check_name.to_string(),
method: method.to_string(),
path: path.to_string(),
violation_type: "missing_required_body".to_string(),
message: "Spec requires a request body but none is provided in the check".to_string(),
});
return;
}
if let Some(body_str) = body {
let json_media = request_body.content.get("application/json").or_else(|| {
request_body.content.iter().find(|(k, _)| k.contains("json")).map(|(_, v)| v)
});
if let Some(media) = json_media {
if let Some(schema_ref) = &media.schema {
let root_schema = match schema_ref {
ReferenceOr::Item(s) => s.clone(),
ReferenceOr::Reference { reference } => {
let name =
reference.strip_prefix("#/components/schemas/").unwrap_or(reference);
match spec.components.as_ref().and_then(|c| c.schemas.get(name)) {
Some(ReferenceOr::Item(s)) => s.clone(),
_ => return,
}
}
};
match serde_json::from_str::<serde_json::Value>(body_str) {
Ok(body_value) => {
match mockforge_openapi::schema_ref_resolver::build_validator(
&root_schema,
spec,
) {
Ok(validator) => {
let errors: Vec<_> = validator.iter_errors(&body_value).collect();
for err in errors.iter().take(5) {
violations.push(RequestViolation {
check_name: check_name.to_string(),
method: method.to_string(),
path: path.to_string(),
violation_type: "body_schema_violation".to_string(),
message: format!(
"Request body schema violation at {}: {}",
err.instance_path, err
),
});
}
}
Err(_) => {
}
}
}
Err(e) => {
violations.push(RequestViolation {
check_name: check_name.to_string(),
method: method.to_string(),
path: path.to_string(),
violation_type: "body_not_json".to_string(),
message: format!("Request body is not valid JSON: {}", e),
});
}
}
}
}
}
}
/// Checks that a custom check supplies every required query and header parameter that
/// applies to `operation`.
///
/// - Applicable parameters are the path-item-level ones plus the operation-level ones.
///   An operation-level parameter replaces a path-item-level parameter with the same
///   name and location, as the OpenAPI spec prescribes. Unresolvable references are
///   skipped.
/// - Required query parameters must appear as a key in the query string of `path` (the
///   text after `check_path_no_query` and `?`, up to any `#`). Keys are compared exactly
///   and without percent-decoding; `?flag` and `?flag=` both count as present.
/// - Required header parameters must appear in `check_headers`, compared
///   case-insensitively. Header parameters named `Accept`, `Content-Type` or
///   `Authorization` are ignored, because the OpenAPI 3 spec says their definitions
///   SHALL be ignored.
/// - Path and cookie parameters are not checked.
// Nine inputs are needed to report violations with full check context.
#[allow(clippy::too_many_arguments)]
fn validate_parameters(
    check_name: &str,
    method: &str,
    path: &str,
    check_path_no_query: &str,
    check_headers: &HashMap<String, String>,
    operation: &openapiv3::Operation,
    path_item: &openapiv3::PathItem,
    spec: &OpenAPI,
    violations: &mut Vec<RequestViolation>,
) {
    // Header parameter names whose definitions the OpenAPI 3 spec says SHALL be ignored.
    const IGNORED_HEADERS: [&str; 3] = ["Accept", "Content-Type", "Authorization"];

    // `(location, name)` uniquely identifies a parameter for override purposes.
    fn parameter_key(param: &openapiv3::Parameter) -> (&'static str, &str) {
        match param {
            openapiv3::Parameter::Query { parameter_data, .. } => {
                ("query", parameter_data.name.as_str())
            }
            openapiv3::Parameter::Header { parameter_data, .. } => {
                ("header", parameter_data.name.as_str())
            }
            openapiv3::Parameter::Path { parameter_data, .. } => {
                ("path", parameter_data.name.as_str())
            }
            openapiv3::Parameter::Cookie { parameter_data, .. } => {
                ("cookie", parameter_data.name.as_str())
            }
        }
    }

    let mut effective: Vec<&openapiv3::Parameter> = Vec::new();
    let declared = path_item
        .parameters
        .iter()
        .chain(&operation.parameters)
        .filter_map(|p| resolve_parameter(p, spec));
    for param in declared {
        let key = parameter_key(param);
        match effective.iter().position(|seen| parameter_key(seen) == key) {
            Some(index) => effective[index] = param,
            None => effective.push(param),
        }
    }

    let query = path
        .strip_prefix(check_path_no_query)
        .and_then(|rest| rest.strip_prefix('?'))
        .unwrap_or("");
    let query = query.split_once('#').map_or(query, |(before, _)| before);
    // Exact key comparison: the old substring test found `id` inside `?user_id=1`.
    let has_query_key = |name: &str| {
        query
            .split('&')
            .any(|pair| pair.split_once('=').map_or(pair, |(key, _)| key) == name)
    };

    for param in effective {
        match param {
            openapiv3::Parameter::Query { parameter_data, .. } => {
                if parameter_data.required && !has_query_key(parameter_data.name.as_str()) {
                    violations.push(RequestViolation {
                        check_name: check_name.to_string(),
                        method: method.to_string(),
                        path: path.to_string(),
                        violation_type: "missing_required_query_param".to_string(),
                        message: format!(
                            "Required query parameter '{}' is missing",
                            parameter_data.name
                        ),
                    });
                }
            }
            openapiv3::Parameter::Header { parameter_data, .. } => {
                let name = &parameter_data.name;
                let ignored = IGNORED_HEADERS.iter().any(|h| h.eq_ignore_ascii_case(name));
                let present = check_headers.keys().any(|k| k.eq_ignore_ascii_case(name));
                if parameter_data.required && !ignored && !present {
                    violations.push(RequestViolation {
                        check_name: check_name.to_string(),
                        method: method.to_string(),
                        path: path.to_string(),
                        violation_type: "missing_required_header".to_string(),
                        message: format!("Required header parameter '{}' is missing", name),
                    });
                }
            }
            openapiv3::Parameter::Path { .. } | openapiv3::Parameter::Cookie { .. } => {}
        }
    }
}
/// Resolves a parameter that is either defined inline or referenced as
/// `#/components/parameters/<name>`.
///
/// Returns `None` for references outside `#/components/parameters/`, for names missing
/// from the spec's components, and for references that point at another reference
/// (chains are not followed).
fn resolve_parameter<'a>(
    param_ref: &'a ReferenceOr<openapiv3::Parameter>,
    spec: &'a OpenAPI,
) -> Option<&'a openapiv3::Parameter> {
    let reference = match param_ref {
        ReferenceOr::Item(inline) => return Some(inline),
        ReferenceOr::Reference { reference } => reference,
    };
    let component_name = reference.strip_prefix("#/components/parameters/")?;
    let components = spec.components.as_ref()?;
    if let Some(ReferenceOr::Item(parameter)) = components.parameters.get(component_name) {
        Some(parameter)
    } else {
        None
    }
}
/// Resolves `schema_ref` to a schema and serializes that schema to a JSON value.
///
/// Inline schemas are used as-is. `#/components/schemas/<name>` references are looked up
/// one level deep. Returns `None` for any other reference form, a missing component, a
/// reference that points at another reference, or a serialization failure.
// Not called from the code visible in this module; `dead_code` is allowed so the helper
// stays available. NOTE(review): confirm no caller needs it before removing.
#[allow(dead_code)]
fn resolve_schema_to_json(
    schema_ref: &ReferenceOr<openapiv3::Schema>,
    spec: &OpenAPI,
) -> Option<serde_json::Value> {
    let resolved = match schema_ref {
        ReferenceOr::Item(inline) => inline,
        ReferenceOr::Reference { reference } => {
            let key = reference.strip_prefix("#/components/schemas/")?;
            let ReferenceOr::Item(component) = spec.components.as_ref()?.schemas.get(key)? else {
                return None;
            };
            component
        }
    };
    serde_json::to_value(resolved).ok()
}
/// Runs custom-check request validation and writes any violations to
/// `<output_dir>/conformance-request-violations.json`.
///
/// Only the first file in `spec_files` is used. Returns `Ok(0)` without doing anything
/// when no custom checks file is given or `spec_files` is empty. Otherwise returns the
/// number of violations found. That count is returned even if writing the report fails;
/// such failures are logged as warnings. When no violations are found, no file is
/// written.
///
/// # Errors
///
/// Propagates errors from parsing the spec and loading the custom checks file.
pub async fn run_request_validation(
    spec_files: &[std::path::PathBuf],
    custom_checks_file: Option<&Path>,
    base_path: Option<&str>,
    output_dir: &Path,
) -> Result<usize> {
    let Some(custom_file) = custom_checks_file else {
        return Ok(0);
    };
    let Some(first_spec) = spec_files.first() else {
        return Ok(0);
    };
    let parser = SpecParser::from_file(first_spec).await?;
    let spec = parser.spec();
    let violations = validate_custom_checks(spec, custom_file, base_path)?;
    if violations.is_empty() {
        return Ok(0);
    }
    let report_path = output_dir.join("conformance-request-violations.json");
    match serde_json::to_string_pretty(&violations) {
        Ok(json) => match std::fs::write(&report_path, json) {
            Ok(()) => tracing::info!(
                "Found {} request validation violation(s), saved to {}",
                violations.len(),
                report_path.display()
            ),
            // Previously ignored, while "saved to" was still logged.
            Err(e) => tracing::warn!(
                "Found {} request validation violation(s), but failed to write {}: {}",
                violations.len(),
                report_path.display(),
                e
            ),
        },
        Err(e) => tracing::warn!(
            "Found {} request validation violation(s), but failed to serialize them: {}",
            violations.len(),
            e
        ),
    }
    Ok(violations.len())
}
/// Validates recorded conformance requests without stripping any base path.
///
/// Equivalent to calling `validate_emitted_requests_with_base_path` with no base path.
pub async fn validate_emitted_requests(
    spec_files: &[std::path::PathBuf],
    output_dir: &Path,
) -> Result<usize> {
    let no_base_path: Option<&str> = None;
    let emitted = validate_emitted_requests_with_base_path(spec_files, output_dir, no_base_path);
    emitted.await
}
pub async fn validate_emitted_requests_with_base_path(
spec_files: &[std::path::PathBuf],
output_dir: &Path,
base_path: Option<&str>,
) -> Result<usize> {
use serde_json::Value;
if spec_files.is_empty() {
return Ok(0);
}
let requests_path = output_dir.join("conformance-requests.json");
if !requests_path.exists() {
return Ok(0);
}
let bytes = match std::fs::read(&requests_path) {
Ok(b) => b,
Err(_) => return Ok(0),
};
let entries: Vec<Value> = match serde_json::from_slice(&bytes) {
Ok(v) => v,
Err(_) => return Ok(0),
};
if entries.is_empty() {
return Ok(0);
}
let parser = SpecParser::from_file(&spec_files[0]).await?;
let spec = parser.spec();
let spec_ops = build_spec_operation_map(spec);
let mut emitted_violations: Vec<RequestViolation> = Vec::new();
for entry in &entries {
let check = entry.get("check").and_then(|v| v.as_str()).unwrap_or("").to_string();
let req = match entry.get("request") {
Some(r) => r,
None => continue,
};
let method = req.get("method").and_then(|v| v.as_str()).unwrap_or("").to_uppercase();
let url = req.get("url").and_then(|v| v.as_str()).unwrap_or("").to_string();
if method.is_empty() || url.is_empty() {
continue;
}
let (path_only, query_string) = match url.find('?') {
Some(i) => (url[..i].to_string(), url[i + 1..].to_string()),
None => (url.clone(), String::new()),
};
let path_only = if let Some(stripped) = path_only.split_once("://") {
match stripped.1.find('/') {
Some(i) => stripped.1[i..].to_string(),
None => "/".to_string(),
}
} else {
path_only
};
let lookup_path = if let Some(bp) = base_path {
let bp = bp.trim_end_matches('/');
if !bp.is_empty() && path_only.starts_with(bp) {
let stripped = &path_only[bp.len()..];
if stripped.is_empty() {
"/".to_string()
} else {
stripped.to_string()
}
} else {
path_only.clone()
}
} else {
path_only.clone()
};
let spec_path = match find_matching_spec_path(&lookup_path, &spec_ops, None) {
Some(p) => p,
None => continue,
};
let path_item = match spec.paths.paths.get(&spec_path) {
Some(ReferenceOr::Item(item)) => item,
_ => continue,
};
let operation = match method.as_str() {
"GET" => path_item.get.as_ref(),
"POST" => path_item.post.as_ref(),
"PUT" => path_item.put.as_ref(),
"DELETE" => path_item.delete.as_ref(),
"PATCH" => path_item.patch.as_ref(),
"HEAD" => path_item.head.as_ref(),
"OPTIONS" => path_item.options.as_ref(),
_ => None,
};
let Some(operation) = operation else { continue };
let sent_query: HashMap<String, String> = query_string
.split('&')
.filter_map(|kv| {
let mut it = kv.splitn(2, '=');
let k = it.next()?.to_string();
let v = it.next().unwrap_or("").to_string();
if k.is_empty() {
None
} else {
Some((k, v))
}
})
.collect();
let path_params: HashMap<String, String> = {
let mut out = HashMap::new();
let concrete_parts: Vec<&str> = lookup_path.split('/').collect();
let template_parts: Vec<&str> = spec_path.split('/').collect();
if concrete_parts.len() == template_parts.len() {
for (c, t) in concrete_parts.iter().zip(template_parts.iter()) {
if t.starts_with('{') && t.ends_with('}') {
let name = &t[1..t.len() - 1];
out.insert(name.to_string(), (*c).to_string());
}
}
}
out
};
let mut all_params: Vec<&openapiv3::Parameter> = Vec::new();
for p in &path_item.parameters {
if let Some(param) = resolve_parameter(p, spec) {
all_params.push(param);
}
}
for p in &operation.parameters {
if let Some(param) = resolve_parameter(p, spec) {
all_params.push(param);
}
}
for param in &all_params {
let (loc_str, name, schema_ref) = match param {
openapiv3::Parameter::Query { parameter_data, .. } => {
let openapiv3::ParameterSchemaOrContent::Schema(sref) = ¶meter_data.format
else {
continue;
};
let Some(v) = sent_query.get(¶meter_data.name) else {
continue;
};
("query", ¶meter_data.name, (sref, v.clone()))
}
openapiv3::Parameter::Path { parameter_data, .. } => {
let openapiv3::ParameterSchemaOrContent::Schema(sref) = ¶meter_data.format
else {
continue;
};
let Some(v) = path_params.get(¶meter_data.name) else {
continue;
};
("path", ¶meter_data.name, (sref, v.clone()))
}
_ => continue,
};
let (schema_ref, value) = schema_ref;
let Some(schema) = schema_ref.as_item() else {
continue;
};
if let Some(msg) = check_value_against_schema(&value, schema) {
emitted_violations.push(RequestViolation {
check_name: check.clone(),
method: method.clone(),
path: url.clone(),
violation_type: format!("{}_value_mismatch", loc_str),
message: format!("{}.{}: {}", loc_str, name, msg),
});
}
}
let body_str = req.get("body").and_then(|v| v.as_str()).unwrap_or("");
if !body_str.is_empty() {
if let Ok(body_json) = serde_json::from_str::<serde_json::Value>(body_str) {
if let Some(req_body) = operation.request_body.as_ref().and_then(|r| r.as_item()) {
for (ct, media) in &req_body.content {
if !ct.contains("json") {
continue;
}
let Some(schema_ref) = &media.schema else {
continue;
};
let Some(schema) = schema_ref.as_item() else {
continue;
};
check_body_against_schema(
&check,
&method,
&url,
&body_json,
schema,
&mut emitted_violations,
);
}
}
}
}
}
let dst = output_dir.join("conformance-request-violations.json");
let mut all: Vec<Value> = if dst.exists() {
match std::fs::read(&dst) {
Ok(b) => serde_json::from_slice(&b).unwrap_or_default(),
Err(_) => Vec::new(),
}
} else {
Vec::new()
};
for v in &emitted_violations {
if let Ok(val) = serde_json::to_value(v) {
all.push(val);
}
}
if !all.is_empty() {
if let Ok(json) = serde_json::to_string_pretty(&all) {
let _ = std::fs::write(&dst, json);
tracing::info!(
"validate-requests: wrote {} entries to {} ({} from emitted requests)",
all.len(),
dst.display(),
emitted_violations.len()
);
}
}
let grouped_dst = output_dir.join("conformance-request-violations-by-request.json");
let grouped_value = group_violations_by_request(&all);
if let Ok(json) = serde_json::to_string_pretty(&grouped_value) {
let _ = std::fs::write(&grouped_dst, json);
}
Ok(emitted_violations.len())
}
/// Groups flat violation records by `(check_name, method, path)`.
///
/// Returns a JSON array with one object per distinct request, in first-seen order. Each
/// object has `check_name`, `method`, `path`, `violation_count`, and numbered entries
/// `violation_1`, `violation_2`, ..., each holding `{violation_type, message}`. Missing or
/// non-string fields are treated as empty strings.
fn group_violations_by_request(flat: &[serde_json::Value]) -> serde_json::Value {
    use serde_json::{Map, Value};

    let text = |record: &Value, field: &str| -> String {
        record.get(field).and_then(Value::as_str).unwrap_or("").to_owned()
    };

    // Groups in first-seen order; `slot_of` maps a request key to its index in `groups`.
    let mut groups: Vec<((String, String, String), Vec<(String, String)>)> = Vec::new();
    let mut slot_of: std::collections::HashMap<(String, String, String), usize> =
        std::collections::HashMap::new();
    for record in flat {
        let key = (text(record, "check_name"), text(record, "method"), text(record, "path"));
        let violation = (text(record, "violation_type"), text(record, "message"));
        let slot = *slot_of.entry(key.clone()).or_insert_with(|| {
            groups.push((key, Vec::new()));
            groups.len() - 1
        });
        groups[slot].1.push(violation);
    }

    let rows: Vec<Value> = groups
        .into_iter()
        .map(|((check, method, path), violations)| {
            let mut row = Map::new();
            row.insert("check_name".into(), Value::String(check));
            row.insert("method".into(), Value::String(method));
            row.insert("path".into(), Value::String(path));
            row.insert(
                "violation_count".into(),
                Value::Number(serde_json::Number::from(violations.len())),
            );
            for (n, (violation_type, message)) in violations.into_iter().enumerate() {
                let mut detail = Map::new();
                detail.insert("violation_type".into(), Value::String(violation_type));
                detail.insert("message".into(), Value::String(message));
                row.insert(format!("violation_{}", n + 1), Value::Object(detail));
            }
            Value::Object(row)
        })
        .collect();
    Value::Array(rows)
}
/// Checks a JSON request body against an inline, object-typed top-level schema.
///
/// Reports:
/// - `body_missing_required` for every `required` property absent from `body`;
/// - `body_value_mismatch` for every present property that does not satisfy its inline
///   schema. A mismatch means the value's JSON type is not the declared `string`,
///   `integer`, `number`, `boolean`, `object` or `array`, or a string value is not in the
///   property's `enum`.
///
/// Nothing is checked when `schema` is not an object-typed schema or `body` is not a JSON
/// object. `$ref` properties and `null` values are skipped.
fn check_body_against_schema(
    check: &str,
    method: &str,
    url: &str,
    body: &serde_json::Value,
    schema: &openapiv3::Schema,
    violations: &mut Vec<RequestViolation>,
) {
    use openapiv3::{SchemaKind, Type};

    // JSON-aware check for one property value. Unlike `check_value_against_schema` (meant
    // for string-serialized query/path values), a JSON string never satisfies
    // integer/number/boolean here: `"42"` is not a valid `type: integer` body value.
    fn property_mismatch(
        value: &serde_json::Value,
        schema: &openapiv3::Schema,
    ) -> Option<String> {
        use openapiv3::{SchemaKind, Type};
        let SchemaKind::Type(schema_type) = &schema.schema_kind else {
            return None;
        };
        if value.is_null() {
            return None;
        }
        let (expected, type_ok) = match schema_type {
            Type::String(_) => match value.as_str() {
                // Strings get the enum check shared with query/path values.
                Some(text) => return check_value_against_schema(text, schema),
                None => ("string", false),
            },
            Type::Integer(_) => (
                "integer",
                value.is_i64()
                    || value.is_u64()
                    || matches!(value.as_f64(), Some(f) if f.fract() == 0.0),
            ),
            Type::Number(_) => ("number", value.is_number()),
            Type::Boolean(_) => ("boolean", value.is_boolean()),
            Type::Object(_) => ("object", value.is_object()),
            Type::Array(_) => ("array", value.is_array()),
        };
        // `Value`'s Display renders strings quoted, matching the existing message style.
        (!type_ok).then(|| format!("value {} is not of type \"{}\"", value, expected))
    }

    let SchemaKind::Type(Type::Object(obj_type)) = &schema.schema_kind else {
        return;
    };
    let Some(body_obj) = body.as_object() else {
        return;
    };
    let mut report = |violation_type: &str, message: String| {
        violations.push(RequestViolation {
            check_name: check.to_string(),
            method: method.to_string(),
            path: url.to_string(),
            violation_type: violation_type.to_string(),
            message,
        });
    };
    for required in &obj_type.required {
        if !body_obj.contains_key(required) {
            report("body_missing_required", format!("body.{}: required field missing", required));
        }
    }
    for (prop_name, prop_ref) in &obj_type.properties {
        let Some(value) = body_obj.get(prop_name) else {
            continue;
        };
        let Some(prop_schema) = prop_ref.as_item() else {
            continue;
        };
        if let Some(msg) = property_mismatch(value, prop_schema) {
            report("body_value_mismatch", format!("body.{}: {}", prop_name, msg));
        }
    }
}
/// Checks a string-serialized value (e.g. a query or path parameter) against an inline
/// schema, returning a human-readable message on mismatch.
///
/// - `string`: when the schema has an `enum`, the value must equal one of its non-null
///   entries.
/// - `integer`: an optional `+`/`-` sign followed by one or more ASCII digits. Values
///   beyond the `i64` range are accepted, since an unformatted OpenAPI integer is
///   unbounded.
/// - `number`: must parse as a float and be written only with digits, sign, `.` and
///   exponent. This rejects `NaN`, `inf` and `infinity`, which Rust's float parser
///   accepts.
/// - `boolean`: exactly `true` or `false`.
///
/// Other schema kinds and types (object, array, composition) are not checked.
fn check_value_against_schema(value: &str, schema: &openapiv3::Schema) -> Option<String> {
    use openapiv3::{SchemaKind, Type};
    let SchemaKind::Type(schema_type) = &schema.schema_kind else {
        return None;
    };
    match schema_type {
        Type::String(string_type) => {
            // `null` enum entries (`None`) can never match a string value.
            let allowed = || string_type.enumeration.iter().flatten();
            if string_type.enumeration.is_empty() || allowed().any(|a| a == value) {
                return None;
            }
            let quoted: Vec<String> = allowed().map(|a| format!("\"{}\"", a)).collect();
            Some(format!("value \"{}\" is not one of {}", value, quoted.join(" or ")))
        }
        Type::Integer(_) => {
            let digits = value.strip_prefix(|c: char| c == '+' || c == '-').unwrap_or(value);
            let is_integer = !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit());
            (!is_integer).then(|| format!("value \"{}\" is not of type \"integer\"", value))
        }
        Type::Number(_) => {
            let is_number = value.parse::<f64>().is_ok()
                && value
                    .bytes()
                    .all(|b| b.is_ascii_digit() || matches!(b, b'+' | b'-' | b'.' | b'e' | b'E'));
            (!is_number).then(|| format!("value \"{}\" is not of type \"number\"", value))
        }
        Type::Boolean(_) => match value {
            "true" | "false" => None,
            _ => Some(format!("value \"{}\" is not of type \"boolean\"", value)),
        },
        _ => None,
    }
}