use std::borrow::Cow;
use schemars::{JsonSchema, Schema, SchemaGenerator};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Deserializer};
/// Newtype around a parameter type `T`.
///
/// The `Deserialize` impl for this wrapper in this file decodes `T` and, when
/// that fails, extends the error message with the field names taken from
/// `T`'s JSON schema and, if one is close enough, a "did you mean" suggestion.
/// The `JsonSchema` impl forwards to `T`.
pub(crate) struct Lenient<T>(pub T);
/// Largest Levenshtein distance (inclusive) between a provided key and an
/// expected field name at which `did_you_mean` still offers the field as a
/// suggestion.
const NEAR_MISS_MAX_DISTANCE: usize = 2;
/// Schema passthrough: every `JsonSchema` item on `Lenient<T>` forwards to the
/// corresponding item on `T`, so the wrapper reports the same schema as `T`.
impl<T: JsonSchema> JsonSchema for Lenient<T> {
    /// Forwards to `T`'s inlining preference.
    fn inline_schema() -> bool {
        <T as JsonSchema>::inline_schema()
    }

    /// Forwards to `T`'s schema name.
    fn schema_name() -> Cow<'static, str> {
        <T as JsonSchema>::schema_name()
    }

    /// Forwards to `T`'s schema id.
    fn schema_id() -> Cow<'static, str> {
        <T as JsonSchema>::schema_id()
    }

    /// Forwards schema generation to `T`.
    fn json_schema(generator: &mut SchemaGenerator) -> Schema {
        <T as JsonSchema>::json_schema(generator)
    }

    // NOTE(review): the two `_schemars_private_*` items below are named as
    // private schemars hooks; they are forwarded verbatim to `T`. Presumably
    // this keeps `Option`-typed inner values treated as optional — confirm
    // against the schemars version in use.
    fn _schemars_private_non_optional_json_schema(generator: &mut SchemaGenerator) -> Schema {
        <T as JsonSchema>::_schemars_private_non_optional_json_schema(generator)
    }

    fn _schemars_private_is_option() -> bool {
        <T as JsonSchema>::_schemars_private_is_option()
    }
}
impl<'de, T> Deserialize<'de> for Lenient<T>
where
T: DeserializeOwned + JsonSchema,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Error as _;
let value = serde_json::Value::deserialize(deserializer)?;
match serde_json::from_value::<T>(value.clone()) {
Ok(inner) => Ok(Lenient(inner)),
Err(error) => Err(D::Error::custom(build_message::<T>(&error, &value))),
}
}
}
/// Renders the error text for a failed decode of `T`.
///
/// The result is the `Display` form of `error`, followed by
/// `". expected fields: …"` when `T`'s schema lists any properties, followed by
/// ``". did you mean `…`?"`` when one of the keys in `provided` is a near miss
/// for an expected field.
fn build_message<T: JsonSchema>(error: &serde_json::Error, provided: &serde_json::Value) -> String {
    let (expected, required) = expected_fields::<T>();

    // Each optional fragment is empty when it does not apply, so the final
    // concatenation needs no further branching.
    let fields_fragment = if expected.is_empty() {
        String::new()
    } else {
        format!(". expected fields: {}", join_fields(&expected, &required))
    };
    let suggestion_fragment = did_you_mean(provided, &expected)
        .map(|suggestion| format!(". did you mean `{suggestion}`?"))
        .unwrap_or_default();

    format!("{error}{fields_fragment}{suggestion_fragment}")
}
/// Reads the field names out of `T`'s generated JSON schema.
///
/// Returns `(properties, required)`: the keys of the root `"properties"`
/// object and the string entries of the root `"required"` array. Either list
/// is empty when the corresponding entry is absent or has another JSON type;
/// both are empty when the root schema is not an object.
fn expected_fields<T: JsonSchema>() -> (Vec<String>, Vec<String>) {
    let schema = schemars::schema_for!(T);
    let Some(root) = schema.as_object() else {
        return (Vec::new(), Vec::new());
    };

    let mut property_names = Vec::new();
    if let Some(map) = root.get("properties").and_then(|value| value.as_object()) {
        property_names.extend(map.keys().cloned());
    }

    let mut required_names = Vec::new();
    if let Some(entries) = root.get("required").and_then(|value| value.as_array()) {
        // Non-string entries are skipped rather than treated as an error.
        required_names.extend(
            entries
                .iter()
                .filter_map(|entry| entry.as_str())
                .map(str::to_owned),
        );
    }

    (property_names, required_names)
}
/// Formats `expected` as a comma-separated list of backtick-quoted names.
///
/// A name that also appears in `required` is followed by ` (required)`.
/// The order of `expected` is preserved; an empty slice yields an empty string.
fn join_fields(expected: &[String], required: &[String]) -> String {
    let mut rendered = String::new();
    for (index, field) in expected.iter().enumerate() {
        if index > 0 {
            rendered.push_str(", ");
        }
        rendered.push('`');
        rendered.push_str(field);
        rendered.push('`');
        if required.contains(field) {
            rendered.push_str(" (required)");
        }
    }
    rendered
}
/// Picks the expected field that best matches a mistyped key in `provided`.
///
/// Only keys of `provided` that are not themselves expected fields are
/// considered. Each is compared with every expected field by Levenshtein
/// distance, and pairs further apart than `NEAR_MISS_MAX_DISTANCE` are
/// discarded. The field with the smallest distance is returned; on a tie the
/// first pair encountered (key order, then `expected` order) wins.
///
/// Returns `None` when `provided` is not a JSON object or no pair is close
/// enough.
fn did_you_mean(provided: &serde_json::Value, expected: &[String]) -> Option<String> {
    let object = provided.as_object()?;
    object
        .keys()
        .filter(|key| !expected.contains(*key))
        .flat_map(|key| {
            expected
                .iter()
                .map(move |candidate| (levenshtein(key, candidate), candidate))
        })
        .filter(|(distance, _)| *distance <= NEAR_MISS_MAX_DISTANCE)
        // `min_by_key` returns the first of several equal minima, matching a
        // "replace only when strictly closer" scan.
        .min_by_key(|(distance, _)| *distance)
        .map(|(_, candidate)| candidate.clone())
}
/// Computes the Levenshtein edit distance between `a` and `b`.
///
/// The distance is counted in `char`s (Unicode scalar values), with unit cost
/// for insertion, deletion and substitution. Only a single row of the
/// dynamic-programming table is kept and updated in place.
fn levenshtein(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();

    // `row[c]` holds the distance between the processed prefix of `source`
    // and the first `c` chars of `target`.
    let mut row: Vec<usize> = (0..=target.len()).collect();

    for (consumed, &source_char) in source.iter().enumerate() {
        // Value of the cell diagonally up-left of the one being computed.
        let mut top_left = row[0];
        row[0] = consumed + 1;

        for (col, &target_char) in target.iter().enumerate() {
            let above = row[col + 1];
            let left = row[col];
            let substitute = if source_char == target_char {
                top_left
            } else {
                top_left + 1
            };

            // `above` becomes the diagonal for the next column before the
            // cell is overwritten.
            top_left = above;
            row[col + 1] = substitute.min(left + 1).min(above + 1);
        }
    }

    row[target.len()]
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::types::WorkspaceGrepParams;

    /// The wrapper type exercised by the deserialization tests.
    type LenientGrep = Lenient<WorkspaceGrepParams>;

    /// Decodes `value` as `LenientGrep` and returns the rendered error.
    ///
    /// Panics if decoding succeeds. A `let … else` is used rather than
    /// `expect_err` because `Lenient` does not implement `Debug`.
    fn failure_message(value: serde_json::Value) -> String {
        let Err(error) = serde_json::from_value::<LenientGrep>(value) else {
            panic!("missing `pattern` must fail");
        };
        error.to_string()
    }

    #[test]
    fn levenshtein_matches_known_distances() {
        let cases = [
            ("pattern", "pattern", 0),
            ("patern", "pattern", 1),
            ("pattrn", "pattern", 1),
            ("", "abc", 3),
        ];
        for (left, right, distance) in cases {
            assert_eq!(levenshtein(left, right), distance);
        }
    }

    #[test]
    fn valid_params_deserialize_to_inner() {
        let Lenient(params) =
            serde_json::from_value::<LenientGrep>(serde_json::json!({ "pattern": "x" }))
                .expect("valid params should deserialize");
        assert_eq!(params.pattern, "x");
        assert!(params.include_context);
        assert_eq!(params.limit, None);
    }

    #[test]
    fn missing_required_field_names_expected_and_suggests() {
        // `patern` is one edit away from the real field name.
        let message = failure_message(serde_json::json!({ "patern": "x" }));
        assert!(
            message.contains("pattern"),
            "message should name the expected `pattern` field, got: {message}"
        );
        assert!(
            message.contains("did you mean `pattern`?"),
            "message should suggest the near-miss field, got: {message}"
        );
    }

    #[test]
    fn aliased_param_name_deserializes_through_lenient() {
        let Lenient(params) =
            serde_json::from_value::<LenientGrep>(serde_json::json!({ "query": "needle" }))
                .expect("aliased `query` should deserialize");
        assert_eq!(params.pattern, "needle");
    }

    #[test]
    fn missing_required_field_without_near_miss_still_lists_fields() {
        let message = failure_message(serde_json::json!({ "completely_unrelated_key": "x" }));
        assert!(
            message.contains("`pattern` (required)"),
            "message should list `pattern` as required, got: {message}"
        );
    }
}