use serde_json::Value;
use crate::{
deserializer::struct_coercer::{remove_accents, strip_punctuation},
error::{DeserializeError, ParseError, Result},
value::FlexValue,
};
/// Identifies which matching stage produced a [`MatchResult`].
///
/// The variants are declared in the same order in which
/// `EnumMatcher::match_string_detailed` attempts the stages; the first stage
/// that finds a variant wins.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatchStrategy {
/// The trimmed input equals one of a variant's match strings verbatim.
Exact,
/// The input equals a match string once both have been passed through
/// `remove_accents`.
Unaccented,
/// The input equals a match string once both have been passed through
/// `strip_punctuation`.
PunctuationStripped,
/// The input equals a match string once both have been punctuation-stripped
/// and lowercased.
CaseInsensitive,
/// On the punctuation-stripped, lowercased forms, a match string occurs
/// inside the input, or the input occurs inside a match string.
Substring,
/// On the punctuation-stripped, lowercased forms, the closest match string
/// by Levenshtein distance is within roughly a third of the input's length.
Levenshtein,
}
/// Outcome of a successful `EnumMatcher::match_string_detailed` call.
#[derive(Debug, Clone)]
pub struct MatchResult {
/// Name of the variant the input was resolved to.
pub variant: String,
/// The matching stage that produced this result.
pub strategy: MatchStrategy,
/// Edit distance between the normalized input and the matched string when
/// `strategy` is [`MatchStrategy::Levenshtein`]; `0` for every other strategy.
pub distance: usize,
}
/// One enum variant that free-form text can be resolved to.
#[derive(Debug, Clone)]
pub struct EnumVariant {
    /// Variant name; this is the value reported when the variant is matched.
    pub name: String,
    /// Optional human-readable description, also accepted as input text.
    pub description: Option<String>,
}

impl EnumVariant {
    /// Builds a variant that has a name and no description.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            description: None,
        }
    }

    /// Sets the description (replacing any earlier one) and hands the variant
    /// back for further chaining.
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }

    /// Lists every string that should resolve to this variant.
    ///
    /// The name always comes first. When the description contains something
    /// other than whitespace, the description itself and the combined
    /// `"<name>: <description>"` form follow, in that order.
    fn match_strings(&self) -> Vec<String> {
        let mut strings = vec![self.name.clone()];
        let usable_description = self
            .description
            .as_deref()
            .filter(|text| !text.trim().is_empty());
        if let Some(text) = usable_description {
            strings.push(text.to_owned());
            strings.push(format!("{}: {}", self.name, text));
        }
        strings
    }
}
/// Resolves free-form text to one of a fixed set of [`EnumVariant`]s, trying
/// progressively looser comparisons (see [`MatchStrategy`]).
#[derive(Debug, Clone)]
pub struct EnumMatcher {
// Candidate variants in registration order; the matching helpers scan them
// in this order.
variants: Vec<EnumVariant>,
}
impl EnumMatcher {
pub fn new() -> Self {
Self {
variants: Vec::new(),
}
}
pub fn variant(mut self, variant: EnumVariant) -> Self {
self.variants.push(variant);
self
}
pub fn match_string(&self, input: &str) -> Result<String> {
self.match_string_detailed(input)
.map(|result| result.variant)
}
pub fn match_string_detailed(&self, input: &str) -> Result<MatchResult> {
let input = input.trim();
let candidates: Vec<(&str, Vec<String>)> = self
.variants
.iter()
.map(|v| (v.name.as_str(), v.match_strings()))
.collect();
if let Some(matched) = self.try_exact_match(input, &candidates) {
return Ok(MatchResult {
variant: matched.to_string(),
strategy: MatchStrategy::Exact,
distance: 0,
});
}
if let Some(matched) = self.try_unaccented_match(input, &candidates) {
return Ok(MatchResult {
variant: matched.to_string(),
strategy: MatchStrategy::Unaccented,
distance: 0,
});
}
let stripped_input = strip_punctuation(input);
let stripped_candidates: Vec<(&str, Vec<String>)> = candidates
.iter()
.map(|(name, values)| {
let stripped_values = values.iter().map(|v| strip_punctuation(v)).collect();
(*name, stripped_values)
})
.collect();
if let Some(matched) = self.try_exact_match(&stripped_input, &stripped_candidates) {
return Ok(MatchResult {
variant: matched.to_string(),
strategy: MatchStrategy::PunctuationStripped,
distance: 0,
});
}
let lowercase_input = stripped_input.to_lowercase();
let lowercase_candidates: Vec<(&str, Vec<String>)> = stripped_candidates
.iter()
.map(|(name, values)| {
let lowercase_values = values.iter().map(|v| v.to_lowercase()).collect();
(*name, lowercase_values)
})
.collect();
if let Some(matched) = self.try_exact_match(&lowercase_input, &lowercase_candidates) {
return Ok(MatchResult {
variant: matched.to_string(),
strategy: MatchStrategy::CaseInsensitive,
distance: 0,
});
}
if let Some(matched) = self.try_substring_match(&lowercase_input, &lowercase_candidates) {
return Ok(MatchResult {
variant: matched.to_string(),
strategy: MatchStrategy::Substring,
distance: 0,
});
}
if let Some((matched, distance)) =
self.try_edit_distance_match(&lowercase_input, &lowercase_candidates)
{
return Ok(MatchResult {
variant: matched.to_string(),
strategy: MatchStrategy::Levenshtein,
distance,
});
}
let suggestion = self.find_closest_suggestion(&lowercase_input, &lowercase_candidates);
Err(ParseError::DeserializeFailed(
DeserializeError::UnknownVariant {
enum_name: "enum".to_string(),
variant: input.to_string(),
suggestion,
},
))
}
fn find_closest_suggestion(
&self,
input: &str,
lowercase_candidates: &[(&str, Vec<String>)],
) -> Option<String> {
let mut best_match: Option<(&str, usize)> = None;
let max_distance = (input.len() / 2).max(2);
for (variant_name, match_strings) in lowercase_candidates {
for match_str in match_strings {
let distance = levenshtein_distance(input, match_str);
if distance <= max_distance {
if let Some((_, best_dist)) = best_match {
if distance < best_dist {
best_match = Some((variant_name, distance));
}
} else {
best_match = Some((variant_name, distance));
}
}
}
}
best_match.map(|(name, _)| name.to_string())
}
fn try_exact_match<'a>(
&self,
input: &str,
candidates: &'a [(&'a str, Vec<String>)],
) -> Option<&'a str> {
for (variant_name, match_strings) in candidates {
if match_strings.iter().any(|s| s == input) {
return Some(variant_name);
}
}
None
}
fn try_unaccented_match<'a>(
&self,
input: &str,
candidates: &'a [(&'a str, Vec<String>)],
) -> Option<&'a str> {
let unaccented_input = remove_accents(input);
for (variant_name, match_strings) in candidates {
if match_strings
.iter()
.any(|s| remove_accents(s) == unaccented_input)
{
return Some(variant_name);
}
}
None
}
fn try_substring_match<'a>(
&self,
input: &str,
candidates: &'a [(&'a str, Vec<String>)],
) -> Option<&'a str> {
let mut all_matches: Vec<(usize, usize, usize, &'a str)> = Vec::new();
for (variant_name, match_strings) in candidates {
for match_str in match_strings {
for (start_idx, _) in input.match_indices(match_str.as_str()) {
let end_idx = start_idx + match_str.len();
all_matches.push((start_idx, end_idx, match_str.len(), variant_name));
}
}
}
if !all_matches.is_empty() {
all_matches.sort_by(|a, b| b.2.cmp(&a.2));
return Some(all_matches[0].3);
}
let mut reverse_matches: Vec<(&'a str, usize)> = Vec::new();
for (variant_name, match_strings) in candidates {
for match_str in match_strings {
if match_str.contains(input) {
reverse_matches.push((variant_name, match_str.len()));
}
}
}
if !reverse_matches.is_empty() {
reverse_matches.sort_by(|a, b| a.1.cmp(&b.1));
return Some(reverse_matches[0].0);
}
None
}
fn try_edit_distance_match<'a>(
&self,
input: &str,
candidates: &'a [(&'a str, Vec<String>)],
) -> Option<(&'a str, usize)> {
let mut best_match: Option<&'a str> = None;
let mut best_distance = usize::MAX;
for (variant_name, match_strings) in candidates {
for match_str in match_strings {
let distance = levenshtein_distance(input, match_str);
if distance < best_distance {
best_distance = distance;
best_match = Some(variant_name);
}
}
}
let threshold = if input.is_empty() { 0 } else { input.len() / 3 };
if best_distance <= threshold {
best_match.map(|m| (m, best_distance))
} else {
None
}
}
}
impl Default for EnumMatcher {
fn default() -> Self {
Self::new()
}
}
/// Computes the Levenshtein edit distance between `s1` and `s2`.
///
/// The distance is the minimum number of single-`char` insertions, deletions
/// and substitutions that turn one string into the other. Comparison is per
/// `char` (Unicode scalar value), not per byte.
///
/// Only two rows of the dynamic-programming table are kept, so extra memory is
/// proportional to the length of `s2`.
pub fn levenshtein_distance(s1: &str, s2: &str) -> usize {
    let s2_chars: Vec<char> = s2.chars().collect();
    if s1.is_empty() {
        return s2_chars.len();
    }
    if s2_chars.is_empty() {
        return s1.chars().count();
    }

    // `prev[j]` is the distance between the first `i` chars of `s1` and the
    // first `j` chars of `s2`; row 0 is the cost of inserting `j` chars.
    let mut prev: Vec<usize> = (0..=s2_chars.len()).collect();
    let mut curr: Vec<usize> = vec![0; s2_chars.len() + 1];

    for (i, c1) in s1.chars().enumerate() {
        // Column 0: delete all `i + 1` chars of `s1` seen so far.
        curr[0] = i + 1;
        for (j, c2) in s2_chars.iter().enumerate() {
            let substitution = prev[j] + usize::from(c1 != *c2);
            let deletion = prev[j + 1] + 1;
            let insertion = curr[j] + 1;
            curr[j + 1] = substitution.min(deletion).min(insertion);
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    // After the final swap the last computed row lives in `prev`.
    prev[s2_chars.len()]
}
/// Resolves a JSON value to an enum variant name using `matcher`.
///
/// Strings are matched as they are; numbers and booleans are matched through
/// their `to_string()` rendering. Every other kind of JSON value is rejected
/// with a `"string"` / `"non-string"` type-mismatch error.
pub fn match_enum_variant(value: &FlexValue, matcher: &EnumMatcher) -> Result<String> {
    // Borrow string payloads; only scalars that need rendering allocate.
    let text: std::borrow::Cow<'_, str> = match &value.value {
        Value::String(text) => text.as_str().into(),
        Value::Number(number) => number.to_string().into(),
        Value::Bool(flag) => flag.to_string().into(),
        _ => {
            return Err(ParseError::DeserializeFailed(
                DeserializeError::type_mismatch("string", "non-string"),
            ));
        }
    };
    matcher.match_string(&text)
}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::value::Source;

    /// Builds a matcher whose variants have the given names and no descriptions.
    fn matcher_of(names: &[&str]) -> EnumMatcher {
        names.iter().fold(EnumMatcher::new(), |matcher, name| {
            matcher.variant(EnumVariant::new(*name))
        })
    }

    /// Asserts that every `(input, expected variant)` pair resolves as stated.
    fn assert_matches(matcher: &EnumMatcher, cases: &[(&str, &str)]) {
        for (input, expected) in cases {
            assert_eq!(
                matcher.match_string(input).unwrap(),
                *expected,
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn test_levenshtein_distance() {
        let cases = [
            ("", "", 0),
            ("hello", "hello", 0),
            ("hello", "hallo", 1),
            ("hello", "help", 2),
            ("kitten", "sitting", 3),
            ("saturday", "sunday", 3),
        ];
        for (left, right, expected) in cases.iter() {
            assert_eq!(levenshtein_distance(left, right), *expected);
        }
    }

    #[test]
    fn test_enum_matcher_exact_match() {
        let matcher = matcher_of(&["Success", "Error", "Pending"]);
        assert_matches(
            &matcher,
            &[
                ("Success", "Success"),
                ("Error", "Error"),
                ("Pending", "Pending"),
            ],
        );
    }

    #[test]
    fn test_enum_matcher_case_insensitive() {
        let matcher = matcher_of(&["Success", "Error"]);
        assert_matches(
            &matcher,
            &[
                ("success", "Success"),
                ("SUCCESS", "Success"),
                ("error", "Error"),
                ("ERROR", "Error"),
            ],
        );
    }

    #[test]
    fn test_enum_matcher_with_description() {
        let matcher = EnumMatcher::new()
            .variant(EnumVariant::new("Active").with_description("Currently active"))
            .variant(EnumVariant::new("Inactive").with_description("Not active"));
        assert_matches(
            &matcher,
            &[
                ("Active", "Active"),
                ("Currently active", "Active"),
                ("Active: Currently active", "Active"),
            ],
        );
    }

    #[test]
    fn test_enum_matcher_punctuation_stripping() {
        let matcher = matcher_of(&["InProgress", "Completed"]);
        assert_matches(
            &matcher,
            &[("In-Progress", "InProgress"), ("in_progress", "InProgress")],
        );
    }

    #[test]
    fn test_enum_matcher_substring() {
        let matcher = matcher_of(&["Processing", "Completed"]);
        assert_matches(
            &matcher,
            &[
                ("Currently Processing", "Processing"),
                ("Task Completed successfully", "Completed"),
            ],
        );
    }

    #[test]
    fn test_enum_matcher_edit_distance() {
        let matcher = matcher_of(&["Success", "Failure"]);
        assert_matches(
            &matcher,
            &[
                ("Succes", "Success"),
                ("Sucess", "Success"),
                ("Failur", "Failure"),
            ],
        );
    }

    #[test]
    fn test_enum_matcher_no_match() {
        let matcher = matcher_of(&["Success", "Error"]);
        assert!(matcher.match_string("RandomValue").is_err());
    }

    #[test]
    fn test_enum_matcher_accents() {
        let matcher = matcher_of(&["Café", "Naïve"]);
        assert_matches(&matcher, &[("Cafe", "Café"), ("Naive", "Naïve")]);
    }

    #[test]
    fn test_match_enum_variant_from_flex_value() {
        let matcher = matcher_of(&["Success", "Error"]);
        for raw in ["Success", "success"].iter() {
            let value = FlexValue::new(json!(*raw), Source::Direct);
            assert_eq!(match_enum_variant(&value, &matcher).unwrap(), "Success");
        }
    }

    #[test]
    fn test_match_string_detailed_strategies() {
        let matcher = matcher_of(&["Success", "Error", "Naïve"]);

        // Every stage before the edit-distance one reports a distance of zero.
        let zero_distance_cases = [
            ("Success", "Success", MatchStrategy::Exact),
            ("success", "Success", MatchStrategy::CaseInsensitive),
            ("Naive", "Naïve", MatchStrategy::Unaccented),
            ("Suc.cess", "Success", MatchStrategy::PunctuationStripped),
            ("Succ", "Success", MatchStrategy::Substring),
        ];
        for (input, variant, strategy) in zero_distance_cases.iter() {
            let outcome = matcher.match_string_detailed(input).unwrap();
            assert_eq!(outcome.variant, *variant);
            assert_eq!(outcome.strategy, *strategy);
            assert_eq!(outcome.distance, 0);
        }

        let fuzzy = matcher.match_string_detailed("Succss").unwrap();
        assert_eq!(fuzzy.variant, "Success");
        assert_eq!(fuzzy.strategy, MatchStrategy::Levenshtein);
        assert!(fuzzy.distance > 0);
    }
}