use crate::{
deserializer::traits::{CoercionContext, LlmDeserialize},
error::{DeserializeError, ParseError, Result},
value::{FlexValue, Transformation},
};
/// One successful attempt at deserializing a value as a union variant.
///
/// Candidates are ranked by `score` (lower is better); the recorded
/// transformations are used to break ties between equal scores.
#[derive(Clone, Debug)]
pub struct UnionMatch<T> {
    /// The variant value, already converted into the union type `T`.
    pub value: T,
    /// Ranking score: `0` for a strict match, otherwise the summed penalty
    /// of the transformations the attempt needed.
    pub score: u32,
    /// Transformations recorded by the context used for this attempt.
    pub transformations: Vec<Transformation>,
}
/// Deserializer that resolves a value into a union type `T` by trying each
/// candidate variant type and keeping the best match.
///
/// The struct carries no data; `T` is only tracked at the type level.
pub struct UnionDeserializer<T> {
    /// Zero-sized marker tying the deserializer to its target type.
    _phantom: std::marker::PhantomData<T>,
}
impl<T> UnionDeserializer<T> {
    /// Creates a deserializer for the union target type `T`.
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }

    /// Tries both variant types against `value` and returns every candidate
    /// that matched, in variant order (`V1` before `V2`).
    ///
    /// Matching happens in two passes:
    ///
    /// 1. **Strict** – `try_deserialize` on each variant; hits get score `0`.
    /// 2. **Lenient** – only when the strict pass produced nothing,
    ///    `deserialize` on each variant; hits are scored by the summed
    ///    penalty of the transformations they required.
    ///
    /// Every attempt runs against its own clone of `ctx`, so a failed or
    /// losing attempt leaves nothing behind in the caller's context.
    pub fn try_all<V1, V2>(
        &self,
        value: &FlexValue,
        ctx: &mut CoercionContext,
    ) -> Vec<UnionMatch<T>>
    where
        V1: LlmDeserialize + Into<T>,
        V2: LlmDeserialize + Into<T>,
    {
        Self::try_all_indexed::<V1, V2>(value, ctx)
            .into_iter()
            .map(|(_, candidate)| candidate)
            .collect()
    }

    /// Deserializes `value` as the union `T`, choosing the best of `V1`/`V2`.
    ///
    /// Candidates are ordered by ascending score; equal scores are resolved
    /// by `apply_union_heuristics`, and remaining ties keep variant order
    /// because the sort is stable. On success a `Transformation::UnionMatch`
    /// carrying the position of the winning variant (`0` for `V1`, `1` for
    /// `V2`) is added to `ctx`, followed by the winner's own transformations.
    ///
    /// # Errors
    ///
    /// Returns `ParseError::DeserializeFailed` when neither variant matches.
    pub fn deserialize<V1, V2>(&self, value: &FlexValue, ctx: &mut CoercionContext) -> Result<T>
    where
        V1: LlmDeserialize + Into<T>,
        V2: LlmDeserialize + Into<T>,
    {
        let mut candidates = Self::try_all_indexed::<V1, V2>(value, ctx);

        // Sorting zero or one candidates is a no-op, so a single code path
        // covers the "exactly one match" and "several matches" cases alike.
        candidates.sort_by(|(_, a), (_, b)| {
            a.score
                .cmp(&b.score)
                .then_with(|| apply_union_heuristics(a, b))
        });

        let (index, winning_match) = candidates.into_iter().next().ok_or_else(|| {
            ParseError::DeserializeFailed(DeserializeError::Custom(
                "No union variant matched".to_string(),
            ))
        })?;

        // Record which variant actually won rather than a fixed position.
        ctx.add_transformation(Transformation::UnionMatch {
            index,
            candidates: vec![
                std::any::type_name::<V1>().to_string(),
                std::any::type_name::<V2>().to_string(),
            ],
        });
        for transformation in winning_match.transformations {
            ctx.add_transformation(transformation);
        }
        Ok(winning_match.value)
    }

    /// Runs the strict pass and, if it finds nothing, the lenient pass,
    /// tagging each candidate with its variant position (`0` = `V1`,
    /// `1` = `V2`).
    fn try_all_indexed<V1, V2>(
        value: &FlexValue,
        ctx: &CoercionContext,
    ) -> Vec<(usize, UnionMatch<T>)>
    where
        V1: LlmDeserialize + Into<T>,
        V2: LlmDeserialize + Into<T>,
    {
        let mut found = Vec::new();
        found.extend(Self::strict_candidate::<V1>(value, ctx).map(|m| (0, m)));
        found.extend(Self::strict_candidate::<V2>(value, ctx).map(|m| (1, m)));

        // Coercing attempts are only worth trying when nothing matched as-is.
        if found.is_empty() {
            found.extend(Self::lenient_candidate::<V1>(value, ctx).map(|m| (0, m)));
            found.extend(Self::lenient_candidate::<V2>(value, ctx).map(|m| (1, m)));
        }
        found
    }

    /// Attempts a strict match for `V` on a private clone of `ctx`.
    fn strict_candidate<V>(value: &FlexValue, ctx: &CoercionContext) -> Option<UnionMatch<T>>
    where
        V: LlmDeserialize + Into<T>,
    {
        let mut scratch = ctx.clone();
        let parsed = V::try_deserialize(value, &mut scratch)?;
        Some(UnionMatch {
            value: parsed.into(),
            score: 0,
            transformations: scratch.transformations().to_vec(),
        })
    }

    /// Attempts a lenient match for `V` on a private clone of `ctx`, scoring
    /// it by the transformations that clone recorded.
    fn lenient_candidate<V>(value: &FlexValue, ctx: &CoercionContext) -> Option<UnionMatch<T>>
    where
        V: LlmDeserialize + Into<T>,
    {
        let mut scratch = ctx.clone();
        let parsed = V::deserialize(value, &mut scratch).ok()?;
        let score = calculate_score_from_context(&scratch);
        Some(UnionMatch {
            value: parsed.into(),
            score,
            transformations: scratch.transformations().to_vec(),
        })
    }
}
impl<T> Default for UnionDeserializer<T> {
fn default() -> Self {
Self::new()
}
}
/// Adds up the penalty of every transformation recorded in `ctx`.
///
/// A context with no transformations scores `0`.
fn calculate_score_from_context(ctx: &CoercionContext) -> u32 {
    let mut total: u32 = 0;
    for transformation in ctx.transformations().iter() {
        total += transformation.penalty();
    }
    total
}
fn apply_union_heuristics<T>(a: &UnionMatch<T>, b: &UnionMatch<T>) -> std::cmp::Ordering {
let a_is_list = is_list_transformation(&a.transformations);
let b_is_list = is_list_transformation(&b.transformations);
if a_is_list && b_is_list {
let a_has_single_to_array = a
.transformations
.iter()
.any(|t| matches!(t, Transformation::SingleToArray));
let b_has_single_to_array = b
.transformations
.iter()
.any(|t| matches!(t, Transformation::SingleToArray));
match (a_has_single_to_array, b_has_single_to_array) {
(true, false) => return std::cmp::Ordering::Greater, (false, true) => return std::cmp::Ordering::Less, _ => {}
}
let a_has_markdown = a
.transformations
.iter()
.any(|t| matches!(t, Transformation::ObjectFromMarkdown { .. }));
let b_has_markdown = b
.transformations
.iter()
.any(|t| matches!(t, Transformation::ObjectFromMarkdown { .. }));
match (a_has_markdown, b_has_markdown) {
(true, false) => return std::cmp::Ordering::Greater, (false, true) => return std::cmp::Ordering::Less, _ => {}
}
let a_error_count = count_array_errors(&a.transformations);
let b_error_count = count_array_errors(&b.transformations);
match a_error_count.cmp(&b_error_count) {
std::cmp::Ordering::Equal => {}
ordering => return ordering,
}
}
let a_is_implied_single = a
.transformations
.iter()
.any(|t| matches!(t, Transformation::ImpliedKey { .. }));
let b_is_implied_single = b
.transformations
.iter()
.any(|t| matches!(t, Transformation::ImpliedKey { .. }));
match (a_is_implied_single, b_is_implied_single) {
(true, false) => return std::cmp::Ordering::Greater, (false, true) => return std::cmp::Ordering::Less, _ => {}
}
let a_is_all_defaults = is_all_defaults(&a.transformations);
let b_is_all_defaults = is_all_defaults(&b.transformations);
match (a_is_all_defaults, b_is_all_defaults) {
(true, false) => return std::cmp::Ordering::Greater, (false, true) => return std::cmp::Ordering::Less, _ => {}
}
let a_is_json_to_string = a
.transformations
.iter()
.any(|t| matches!(t, Transformation::JsonToString { .. }));
let b_is_json_to_string = b
.transformations
.iter()
.any(|t| matches!(t, Transformation::JsonToString { .. }));
match (a_is_json_to_string, b_is_json_to_string) {
(true, false) => return std::cmp::Ordering::Greater, (false, true) => return std::cmp::Ordering::Less, _ => {}
}
std::cmp::Ordering::Equal
}
/// Reports whether any transformation indicates the value was handled as a
/// list: either a single value was wrapped into an array, or an array item
/// failed to parse.
fn is_list_transformation(transformations: &[Transformation]) -> bool {
    for transformation in transformations {
        match transformation {
            Transformation::SingleToArray | Transformation::ArrayItemParseError { .. } => {
                return true;
            }
            _ => {}
        }
    }
    false
}
/// Counts the `ArrayItemParseError` entries among `transformations`.
fn count_array_errors(transformations: &[Transformation]) -> usize {
    let mut errors = 0usize;
    for transformation in transformations {
        if let Transformation::ArrayItemParseError { .. } = transformation {
            errors += 1;
        }
    }
    errors
}
/// Reports whether a candidate was produced purely from inserted defaults.
///
/// True when at least one `DefaultValueInserted` is present and none of the
/// transformations that act on supplied data (`StringToNumber`, `FloatToInt`,
/// `FieldNameCaseChanged`) appear. An empty list is not "all defaults".
fn is_all_defaults(transformations: &[Transformation]) -> bool {
    let mut saw_default = false;
    for transformation in transformations {
        match transformation {
            // Any of these means a real input value was involved.
            Transformation::StringToNumber { .. }
            | Transformation::FloatToInt { .. }
            | Transformation::FieldNameCaseChanged { .. } => return false,
            Transformation::DefaultValueInserted { .. } => saw_default = true,
            _ => {}
        }
    }
    saw_default
}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::value::Source;

    /// Test variant that accepts JSON strings only and never coerces.
    #[derive(Debug, Clone, PartialEq)]
    struct StringVariant(String);

    impl LlmDeserialize for StringVariant {
        fn try_deserialize(value: &FlexValue, _ctx: &mut CoercionContext) -> Option<Self> {
            if let serde_json::Value::String(text) = &value.value {
                Some(StringVariant(text.clone()))
            } else {
                None
            }
        }

        fn deserialize(value: &FlexValue, ctx: &mut CoercionContext) -> Result<Self> {
            match Self::try_deserialize(value, ctx) {
                Some(parsed) => Ok(parsed),
                None => Err(ParseError::DeserializeFailed(
                    DeserializeError::type_mismatch("string", "non-string"),
                )),
            }
        }
    }

    /// Test variant that strictly accepts JSON integers and leniently also
    /// parses integer-looking strings.
    #[derive(Debug, Clone, PartialEq)]
    struct IntVariant(i64);

    impl LlmDeserialize for IntVariant {
        fn try_deserialize(value: &FlexValue, _ctx: &mut CoercionContext) -> Option<Self> {
            if let serde_json::Value::Number(number) = &value.value {
                number.as_i64().map(IntVariant)
            } else {
                None
            }
        }

        fn deserialize(value: &FlexValue, _ctx: &mut CoercionContext) -> Result<Self> {
            match &value.value {
                serde_json::Value::Number(number) => match number.as_i64() {
                    Some(int) => Ok(IntVariant(int)),
                    None => Err(ParseError::DeserializeFailed(
                        DeserializeError::type_mismatch("integer", "non-integer number"),
                    )),
                },
                serde_json::Value::String(text) => match text.parse::<i64>() {
                    Ok(int) => Ok(IntVariant(int)),
                    Err(_) => Err(ParseError::DeserializeFailed(
                        DeserializeError::type_mismatch("integer", "unparseable string"),
                    )),
                },
                _ => Err(ParseError::DeserializeFailed(
                    DeserializeError::type_mismatch("integer", "non-numeric"),
                )),
            }
        }
    }

    /// Union under test: either a string or an integer.
    #[derive(Debug, Clone, PartialEq)]
    enum StringOrInt {
        String(StringVariant),
        Int(IntVariant),
    }

    impl From<StringVariant> for StringOrInt {
        fn from(v: StringVariant) -> Self {
            Self::String(v)
        }
    }

    impl From<IntVariant> for StringOrInt {
        fn from(v: IntVariant) -> Self {
            Self::Int(v)
        }
    }

    /// Runs the union deserializer over `input` with a fresh context.
    fn parse_union(input: serde_json::Value) -> Result<StringOrInt> {
        let value = FlexValue::new(input, Source::Direct);
        let mut ctx = CoercionContext::new();
        UnionDeserializer::<StringOrInt>::new()
            .deserialize::<StringVariant, IntVariant>(&value, &mut ctx)
    }

    #[test]
    fn test_union_string_match() {
        let outcome = parse_union(json!("hello"));
        assert!(matches!(outcome, Ok(StringOrInt::String(_))));
    }

    #[test]
    fn test_union_int_match() {
        let outcome = parse_union(json!(42));
        assert!(matches!(outcome, Ok(StringOrInt::Int(_))));
    }

    #[test]
    fn test_union_string_coercion_to_int() {
        // Despite the name, a numeric-looking string matches `StringVariant`
        // strictly, so the lenient string-to-integer pass is never reached
        // and the string variant is expected to win.
        let outcome = parse_union(json!("42"));
        assert!(matches!(outcome, Ok(StringOrInt::String(_))));
    }

    #[test]
    fn test_union_no_match() {
        let outcome = parse_union(json!(null));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_calculate_score() {
        use crate::scoring::score_candidate;

        let mut candidate = FlexValue::new(json!("42"), Source::Direct);
        let untouched_score = score_candidate(&candidate);
        assert_eq!(untouched_score, 0);

        candidate.add_transformation(Transformation::StringToNumber {
            original: "42".to_string(),
        });
        let coerced_score = score_candidate(&candidate);
        assert!(coerced_score > 0);
    }

    #[test]
    fn test_is_all_defaults() {
        let default_age = || Transformation::DefaultValueInserted {
            field: "age".to_string(),
        };

        let only_defaults = [default_age()];
        assert!(is_all_defaults(&only_defaults));

        let defaults_plus_coercion = [
            default_age(),
            Transformation::StringToNumber {
                original: "42".to_string(),
            },
        ];
        assert!(!is_all_defaults(&defaults_plus_coercion));
    }
}