use std::sync::OnceLock;
use schemars::{schema_for, JsonSchema};
use serde_json::Value;
use crate::cli::git::PrContent;
use crate::data::amendments::AmendmentFile;
use crate::data::check::AiCheckResponse;
/// Returns the JSON Schema for `T`, generating it at most once per `slot`.
///
/// On the first call for a given `slot`, the schema is produced with
/// `schema_for!` and converted into a [`serde_json::Value`]. Every later call
/// returns a reference to that same cached value.
///
/// If the conversion to `Value` fails, `Value::Null` (the `Default` for
/// `Value`) is cached rather than panicking. Callers that require an object
/// schema must check for that themselves.
fn schema_value<T: JsonSchema>(slot: &'static OnceLock<Value>) -> &'static Value {
    slot.get_or_init(|| serde_json::to_value(schema_for!(T)).unwrap_or_default())
}
/// JSON Schema describing [`AmendmentFile`].
///
/// Generated on first use and cached for the rest of the process, so repeated
/// calls return the same `&'static` reference.
pub fn amendment_file_schema() -> &'static Value {
    static AMENDMENT_FILE_CACHE: OnceLock<Value> = OnceLock::new();
    schema_value::<AmendmentFile>(&AMENDMENT_FILE_CACHE)
}
/// JSON Schema describing [`PrContent`].
///
/// Generated on first use and cached for the rest of the process, so repeated
/// calls return the same `&'static` reference.
pub fn pr_content_schema() -> &'static Value {
    static PR_CONTENT_CACHE: OnceLock<Value> = OnceLock::new();
    schema_value::<PrContent>(&PR_CONTENT_CACHE)
}
/// JSON Schema describing [`AiCheckResponse`].
///
/// Generated on first use and cached for the rest of the process, so repeated
/// calls return the same `&'static` reference.
pub fn check_response_schema() -> &'static Value {
    static CHECK_RESPONSE_CACHE: OnceLock<Value> = OnceLock::new();
    schema_value::<AiCheckResponse>(&CHECK_RESPONSE_CACHE)
}
/// Signature shared by every cached schema accessor in this module.
#[cfg(test)]
type SchemaGetter = fn() -> &'static Value;

/// Registry of every schema accessor, each paired with a short label.
///
/// Tests iterate this table instead of naming each accessor individually, so
/// an accessor added here is covered by every table-driven test.
#[cfg(test)]
const ALL_SCHEMAS: &[(&str, SchemaGetter)] = &[
    // Label, accessor
    ("amendment_file", amendment_file_schema),
    ("pr_content", pr_content_schema),
    ("check_response", check_response_schema),
];
#[cfg(test)]
// In tests a failed `expect` is simply a test failure, which is the intent.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// JSON Schema keywords whose value maps arbitrary names to subschemas.
    ///
    /// The map itself is not a schema: its keys are user-chosen names, such as
    /// struct field names. Only its values are walked as schemas.
    const SCHEMA_MAP_KEYWORDS: &[&str] = &[
        "properties",
        "patternProperties",
        "definitions",
        "$defs",
        "dependentSchemas",
        "dependencies",
    ];

    /// JSON Schema keywords whose value is instance data, not a schema.
    const DATA_KEYWORDS: &[&str] = &["const", "default", "enum", "examples"];

    #[test]
    fn all_schemas_are_non_null_objects() {
        for (name, getter) in ALL_SCHEMAS {
            let value = getter();
            assert!(
                value.is_object(),
                "{name} schema should serialize to an object: {value}"
            );
        }
    }

    #[test]
    fn schemas_are_cached() {
        // Driven by `ALL_SCHEMAS` so newly registered schemas are covered too.
        for (name, getter) in ALL_SCHEMAS {
            let first = getter();
            let second = getter();
            assert!(
                std::ptr::eq(first, second),
                "{name}_schema should return the same OnceLock value"
            );
        }
    }

    #[test]
    fn schemas_enforce_strict_objects() {
        for (name, getter) in ALL_SCHEMAS {
            let value = getter();
            assert_strict_objects(value, name);
        }
    }

    /// Asserts that every object subschema of `value` that declares
    /// `properties` also sets `additionalProperties: false`.
    ///
    /// `name` labels the schema in failure messages.
    fn assert_strict_objects(value: &Value, name: &str) {
        assert_strict_objects_at(value, name, "#");
    }

    /// Schema-aware recursive walk behind [`assert_strict_objects`].
    ///
    /// `path` is a slash-separated location rooted at `#`, used only in failure
    /// messages. Keys are not pointer-escaped. Keyword maps (see
    /// `SCHEMA_MAP_KEYWORDS`) have only their values walked, so a field named
    /// e.g. `properties` is not mistaken for a schema. Data keywords (see
    /// `DATA_KEYWORDS`) are skipped entirely.
    fn assert_strict_objects_at(value: &Value, name: &str, path: &str) {
        match value {
            Value::Object(map) => {
                if map.contains_key("properties") {
                    assert!(
                        map.get("additionalProperties") == Some(&Value::Bool(false)),
                        "{name}: object subschema at `{path}` missing `additionalProperties: false`: {value}"
                    );
                }
                for (key, child) in map {
                    let key = key.as_str();
                    if DATA_KEYWORDS.contains(&key) {
                        continue;
                    }
                    let child_path = format!("{path}/{key}");
                    match child {
                        Value::Object(entries) if SCHEMA_MAP_KEYWORDS.contains(&key) => {
                            for (entry, sub) in entries {
                                assert_strict_objects_at(
                                    sub,
                                    name,
                                    &format!("{child_path}/{entry}"),
                                );
                            }
                        }
                        _ => assert_strict_objects_at(child, name, &child_path),
                    }
                }
            }
            Value::Array(items) => {
                for (index, item) in items.iter().enumerate() {
                    assert_strict_objects_at(item, name, &format!("{path}/{index}"));
                }
            }
            _ => {}
        }
    }

    /// Names listed in the root-level `required` array of `schema`.
    ///
    /// Returns `None` when the schema has no such array. Non-string entries are
    /// ignored.
    fn required_fields(schema: &Value) -> Option<Vec<&str>> {
        let required = schema.get("required")?.as_array()?;
        Some(required.iter().filter_map(Value::as_str).collect())
    }

    #[test]
    fn pr_content_requires_title_and_description() {
        let names = required_fields(pr_content_schema())
            .expect("pr_content schema should have a `required` array");
        for field in ["title", "description"] {
            assert!(
                names.contains(&field),
                "missing {field} in required: {names:?}"
            );
        }
    }

    #[test]
    fn amendment_file_required_fields() {
        let names = required_fields(amendment_file_schema())
            .expect("amendment_file schema should have a `required` array");
        assert_eq!(names, vec!["amendments"]);
    }

    #[test]
    fn check_response_required_fields() {
        let names = required_fields(check_response_schema())
            .expect("check_response schema should have a `required` array");
        assert_eq!(names, vec!["checks"]);
    }
}