use std::any::{Any, TypeId, type_name};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use rmcp::model::JsonObject;
use schemars::JsonSchema;
use schemars::Schema;
use schemars::generate::SchemaSettings;
use serde_json::Value;
/// Generated schemas keyed by the schema'd type and the `pinned` flag.
type SchemaCache = HashMap<(TypeId, bool), Arc<JsonObject>>;

/// Locks the process-wide schema cache.
///
/// The cache is a single `static` shared by every thread, rather than a per-thread
/// copy. Each `(type, pinned)` pair is therefore generated once per process (or a few
/// times if threads race on the very first request), and all callers share one `Arc`.
/// A poisoned lock is recovered from rather than propagated. This is safe because
/// entries are inserted only after they are fully built, so the map can never be
/// observed half-updated.
fn lock_schema_cache() -> std::sync::MutexGuard<'static, SchemaCache> {
    static CACHE: std::sync::OnceLock<Mutex<SchemaCache>> = std::sync::OnceLock::new();
    CACHE
        .get_or_init(|| Mutex::new(HashMap::new()))
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Returns the JSON Schema for `T`, generated once and then served from a cache.
///
/// When `pinned` is `true`, the root `database` property and its `required` entry
/// are stripped (via `strip_root_database`). Results are cached per
/// `(TypeId, pinned)` pair, so repeated calls for the same pair return the same `Arc`.
///
/// # Panics
///
/// Panics if schema generation for `T` does not produce a JSON object.
#[must_use]
pub fn input_schema<T: JsonSchema + Any>(pinned: bool) -> Arc<JsonObject> {
    let key = (TypeId::of::<T>(), pinned);
    if let Some(cached) = lock_schema_cache().get(&key) {
        return Arc::clone(cached);
    }
    // Generate without holding the lock. Generation can be slow and can panic,
    // and neither should stall other callers or poison the shared cache.
    let built = Arc::new(build::<T>(pinned));
    // If another thread inserted first, keep its entry so every caller shares one `Arc`.
    let mut cache = lock_schema_cache();
    let shared = cache.entry(key).or_insert(built);
    Arc::clone(shared)
}
/// Returns the cached, unpinned schema for `T` for use as a tool output schema.
///
/// The schema is the same one `input_schema::<T>(false)` produces. Output schemas
/// must describe a JSON object, so the root `type` keyword is validated here.
///
/// # Panics
///
/// Panics if the root `type` keyword is missing or is anything other than the
/// string `"object"`. Also propagates any panic from schema generation.
#[must_use]
pub fn output_schema<T: JsonSchema + Any>() -> Arc<JsonObject> {
    let schema = input_schema::<T>(false);
    let root_type = schema.get("type");
    let is_object = matches!(root_type, Some(Value::String(t)) if t == "object");
    assert!(
        is_object,
        "Invalid output schema for type `{}`: root `type` must be \"object\", got {:?}",
        type_name::<T>(),
        root_type,
    );
    schema
}
fn build<T: JsonSchema>(pinned: bool) -> JsonObject {
let value = SchemaSettings::draft2020_12()
.with(|s| {
s.inline_subschemas = true;
s.meta_schema = None;
})
.with_transform(strip_root_metadata)
.with_transform(move |schema: &mut Schema| {
if pinned {
strip_root_database(schema);
}
})
.into_generator()
.into_root_schema_for::<T>()
.to_value();
let Value::Object(object) = value else {
panic!("schema for `{}` did not produce a JSON object", type_name::<T>());
};
object
}
/// Removes the root-level `title` and `description` keywords from `schema`.
///
/// Only the top-level object is touched, so nested property descriptions are kept.
/// A schema that is not a JSON object is left unchanged.
fn strip_root_metadata(schema: &mut Schema) {
    let Some(root) = schema.as_object_mut() else {
        return;
    };
    for keyword in ["title", "description"] {
        root.remove(keyword);
    }
}
/// Removes the `database` property from a root object schema.
///
/// Drops `database` from the root `properties` map and from the root `required`
/// array. If that leaves `required` empty, the `required` key itself is removed.
/// A schema that is not a JSON object is left unchanged.
fn strip_root_database(schema: &mut Schema) {
    const FIELD: &str = "database";
    if let Some(root) = schema.as_object_mut() {
        if let Some(props) = root.get_mut("properties").and_then(Value::as_object_mut) {
            props.remove(FIELD);
        }
        let required_now_empty = match root.get_mut("required") {
            Some(Value::Array(names)) => {
                names.retain(|name| name.as_str() != Some(FIELD));
                names.is_empty()
            }
            _ => false,
        };
        if required_now_empty {
            root.remove("required");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{build, input_schema, output_schema};
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;

    // The root carries an explicit title/description that must be stripped.
    // The property-level doc comment on `name` must survive generation.
    #[derive(Deserialize, Serialize, JsonSchema)]
    #[schemars(title = "FixtureTitle", description = "Fixture root description.")]
    struct Fixture {
        /// Doc-comment kept on the property — must survive schema generation.
        name: String,
        nested: Nested,
    }

    #[derive(Deserialize, Serialize, JsonSchema)]
    struct Nested {
        value: u32,
    }

    #[derive(Deserialize, Serialize, JsonSchema)]
    struct OtherFixture {
        value: u32,
    }

    /// Recursively reports whether `key` appears as an object key anywhere in `value`.
    fn contains_key(value: &Value, key: &str) -> bool {
        match value {
            Value::Object(map) => map.contains_key(key) || map.values().any(|v| contains_key(v, key)),
            Value::Array(items) => items.iter().any(|v| contains_key(v, key)),
            _ => false,
        }
    }

    // `database` is optional here, so it never appears in `required`.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct PinnedFixture {
        query: String,
        #[serde(default)]
        database: Option<String>,
    }

    // `database` is the only field and it is required, so pinning must empty
    // the `required` list and then drop the key entirely.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct RequiredDatabaseFixture {
        database: String,
    }

    #[test]
    fn input_schema_strips_dollar_schema_title_and_description() {
        let schema = input_schema::<Fixture>(false);
        assert!(!schema.contains_key("$schema"), "root $schema not stripped: {schema:?}");
        assert!(!schema.contains_key("title"), "root title not stripped: {schema:?}");
        assert!(
            !schema.contains_key("description"),
            "root description not stripped: {schema:?}"
        );
        assert_eq!(schema.get("type"), Some(&Value::String("object".into())));
    }

    #[test]
    fn input_schema_inlines_nested_subschemas() {
        let schema = input_schema::<Fixture>(false);
        let value = Value::Object((*schema).clone());
        assert!(!contains_key(&value, "$defs"), "$defs not inlined: {value}");
        assert!(!contains_key(&value, "$ref"), "$ref not inlined: {value}");
    }

    #[test]
    fn input_schema_caches_by_type_and_pinned() {
        let first = input_schema::<Fixture>(false);
        let second = input_schema::<Fixture>(false);
        assert!(
            std::sync::Arc::ptr_eq(&first, &second),
            "same (type, pinned) should return cached Arc"
        );
        let other = input_schema::<OtherFixture>(false);
        assert!(
            !std::sync::Arc::ptr_eq(&first, &other),
            "different types must not share cache entry"
        );
        let pinned = input_schema::<PinnedFixture>(true);
        let unpinned = input_schema::<PinnedFixture>(false);
        assert!(
            !std::sync::Arc::ptr_eq(&pinned, &unpinned),
            "same type with different pinned flags must not share cache entry"
        );
    }

    #[test]
    fn output_schema_accepts_object_root() {
        let schema = output_schema::<Fixture>();
        assert_eq!(schema.get("type"), Some(&Value::String("object".into())));
        let again = output_schema::<Fixture>();
        assert!(std::sync::Arc::ptr_eq(&schema, &again));
    }

    #[test]
    #[should_panic(expected = "root `type` must be \"object\"")]
    fn output_schema_panics_on_non_object_root() {
        let _ = output_schema::<u32>();
    }

    #[test]
    fn build_preserves_properties() {
        let schema = build::<Fixture>(false);
        let properties = schema
            .get("properties")
            .and_then(Value::as_object)
            .expect("properties survive generation");
        assert!(properties.contains_key("name"));
        assert!(properties.contains_key("nested"));
        let name = properties.get("name").and_then(Value::as_object).unwrap();
        assert_eq!(
            name.get("description").and_then(Value::as_str),
            Some("Doc-comment kept on the property — must survive schema generation."),
            "per-property descriptions must survive schema generation"
        );
    }

    #[test]
    fn pinned_input_schema_strips_database_property_and_required() {
        let unpinned = input_schema::<PinnedFixture>(false);
        let unpinned_props = unpinned.get("properties").and_then(Value::as_object).unwrap();
        assert!(unpinned_props.contains_key("database"));
        assert!(unpinned_props.contains_key("query"));
        let pinned = input_schema::<PinnedFixture>(true);
        let pinned_props = pinned.get("properties").and_then(Value::as_object).unwrap();
        assert!(
            !pinned_props.contains_key("database"),
            "pinned schema must drop `database`: {pinned:?}"
        );
        assert!(
            pinned_props.contains_key("query"),
            "pinned schema preserves other props"
        );
        let required = pinned.get("required").and_then(Value::as_array);
        if let Some(required) = required {
            assert!(
                !required.iter().any(|v| v.as_str() == Some("database")),
                "pinned schema must drop `database` from required: {required:?}"
            );
        }
    }

    #[test]
    fn pinned_input_schema_drops_emptied_required_list() {
        let unpinned = input_schema::<RequiredDatabaseFixture>(false);
        let unpinned_required = unpinned
            .get("required")
            .and_then(Value::as_array)
            .expect("non-optional `database` must be required when unpinned");
        assert!(
            unpinned_required.iter().any(|v| v.as_str() == Some("database")),
            "unpinned schema must list `database` as required: {unpinned:?}"
        );

        let pinned = input_schema::<RequiredDatabaseFixture>(true);
        assert!(
            !pinned.contains_key("required"),
            "emptied `required` list must be removed: {pinned:?}"
        );
        let pinned_props = pinned
            .get("properties")
            .and_then(Value::as_object)
            .expect("properties key survives pinning");
        assert!(
            pinned_props.is_empty(),
            "pinned schema must drop `database`: {pinned:?}"
        );
    }
}