use std::cell::RefCell;
use std::collections::HashSet;
use std::rc::Rc;
use crate::{constraints::ConstraintResults, error::Result, value::FlexValue};
pub const DEFAULT_MAX_DEPTH: usize = 100;
#[derive(Debug)]
struct SharedState {
visited_for_strict: HashSet<(String, FlexValue)>,
visited_for_lenient: HashSet<(String, FlexValue)>,
constraints: ConstraintResults,
transformations: Vec<crate::value::Transformation>,
}
#[derive(Debug, Clone)]
pub struct CoercionContext {
shared: Rc<RefCell<SharedState>>,
depth: usize,
max_depth: usize,
scope: Vec<String>,
}
impl CoercionContext {
pub fn new() -> Self {
Self {
shared: Rc::new(RefCell::new(SharedState {
visited_for_strict: HashSet::new(),
visited_for_lenient: HashSet::new(),
constraints: ConstraintResults::new(),
transformations: Vec::new(),
})),
depth: 0,
max_depth: DEFAULT_MAX_DEPTH,
scope: vec!["<root>".to_string()],
}
}
pub fn with_max_depth(max_depth: usize) -> Self {
Self {
shared: Rc::new(RefCell::new(SharedState {
visited_for_strict: HashSet::new(),
visited_for_lenient: HashSet::new(),
constraints: ConstraintResults::new(),
transformations: Vec::new(),
})),
depth: 0,
max_depth,
scope: vec!["<root>".to_string()],
}
}
pub fn enter_scope(&self, name: &str) -> Self {
let mut new_scope = self.scope.clone();
new_scope.push(name.to_string());
Self {
shared: Rc::clone(&self.shared),
depth: self.depth,
max_depth: self.max_depth,
scope: new_scope,
}
}
pub fn scope_path(&self) -> String {
self.scope.join(".")
}
pub fn scope(&self) -> &[String] {
&self.scope
}
pub fn add_constraint(&mut self, result: crate::constraints::ConstraintResult) {
self.shared.borrow_mut().constraints.add(result);
}
pub fn constraints(&self) -> ConstraintResults {
self.shared.borrow().constraints.clone()
}
pub fn all_asserts_passed(&self) -> bool {
self.shared.borrow().constraints.all_asserts_passed()
}
pub fn failing_asserts(&self) -> Vec<crate::constraints::ConstraintResult> {
self.shared
.borrow()
.constraints
.failing_asserts()
.into_iter()
.cloned()
.collect()
}
pub fn add_transformation(&mut self, transformation: crate::value::Transformation) {
self.shared
.borrow_mut()
.transformations
.push(transformation);
}
pub fn transformations(&self) -> Vec<crate::value::Transformation> {
self.shared.borrow().transformations.clone()
}
pub fn take_transformations(&mut self) -> Vec<crate::value::Transformation> {
std::mem::take(&mut self.shared.borrow_mut().transformations)
}
pub fn check_can_enter_strict(&self, type_name: &str, value: &FlexValue) -> Result<()> {
if self.depth >= self.max_depth {
return Err(crate::error::ParseError::DeserializeFailed(
crate::error::DeserializeError::DepthLimitExceeded {
depth: self.depth,
max_depth: self.max_depth,
},
));
}
let pair = (type_name.to_string(), value.clone());
if self.shared.borrow().visited_for_strict.contains(&pair) {
return Err(crate::error::ParseError::DeserializeFailed(
crate::error::DeserializeError::CircularReference {
type_name: type_name.to_string(),
},
));
}
Ok(())
}
pub fn with_visited_strict(&self, type_name: &str, value: &FlexValue) -> Self {
self.shared
.borrow_mut()
.visited_for_strict
.insert((type_name.to_string(), value.clone()));
Self {
shared: Rc::clone(&self.shared),
depth: self.depth + 1,
max_depth: self.max_depth,
scope: self.scope.clone(),
}
}
pub fn check_can_enter_lenient(&self, type_name: &str, value: &FlexValue) -> Result<()> {
if self.depth >= self.max_depth {
return Err(crate::error::ParseError::DeserializeFailed(
crate::error::DeserializeError::DepthLimitExceeded {
depth: self.depth,
max_depth: self.max_depth,
},
));
}
let pair = (type_name.to_string(), value.clone());
if self.shared.borrow().visited_for_lenient.contains(&pair) {
return Err(crate::error::ParseError::DeserializeFailed(
crate::error::DeserializeError::CircularReference {
type_name: type_name.to_string(),
},
));
}
Ok(())
}
pub fn with_visited_lenient(&self, type_name: &str, value: &FlexValue) -> Self {
self.shared
.borrow_mut()
.visited_for_lenient
.insert((type_name.to_string(), value.clone()));
Self {
shared: Rc::clone(&self.shared),
depth: self.depth + 1,
max_depth: self.max_depth,
scope: self.scope.clone(),
}
}
pub const fn depth(&self) -> usize {
self.depth
}
}
impl Default for CoercionContext {
fn default() -> Self {
Self::new()
}
}
/// Conversion of a [`FlexValue`] into `Self`, using a [`CoercionContext`] for
/// depth/cycle tracking and shared constraint/transformation state.
pub trait LlmDeserialize: Sized {
    /// Optional attempt at conversion that reports failure as `None` instead of an error.
    ///
    /// The default implementation always returns `None`; implementors may override it.
    /// NOTE(review): presumably a strict fast path tried before `deserialize` — confirm
    /// against callers.
    fn try_deserialize(_value: &FlexValue, _ctx: &mut CoercionContext) -> Option<Self> {
        Option::None
    }

    /// Converts `value` into `Self`.
    ///
    /// # Errors
    /// Implementation-defined; returns the crate's `Result` error type.
    fn deserialize(value: &FlexValue, ctx: &mut CoercionContext) -> Result<Self>;

    /// Name identifying this type; defaults to the compiler-provided type name.
    fn type_name() -> &'static str {
        core::any::type_name::<Self>()
    }
}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::value::{FlexValue, Source};

    /// Three nested entries exhaust a limit of 3, so a fourth is rejected.
    #[test]
    fn test_context_depth_limit() {
        let mut current = CoercionContext::with_max_depth(3);
        for (index, name) in ["T1", "T2", "T3"].iter().enumerate() {
            let value = FlexValue::new(json!(index + 1), Source::Direct);
            current = current.with_visited_strict(name, &value);
            assert_eq!(current.depth(), index + 1);
        }
        let fourth = FlexValue::new(json!(4), Source::Direct);
        assert!(current.check_can_enter_strict("T4", &fourth).is_err());
    }

    /// Re-entering the same type for the same value in strict mode is rejected.
    #[test]
    fn test_context_circular_detection_strict() {
        let node = FlexValue::new(json!({"recursive": true}), Source::Direct);
        let inside = CoercionContext::new().with_visited_strict("Node", &node);
        assert!(inside.check_can_enter_strict("Node", &node).is_err());
    }

    /// Re-entering the same type for the same value in lenient mode is rejected.
    #[test]
    fn test_context_circular_detection_lenient() {
        let node = FlexValue::new(json!({"recursive": true}), Source::Direct);
        let inside = CoercionContext::new().with_visited_lenient("Node", &node);
        assert!(inside.check_can_enter_lenient("Node", &node).is_err());
    }

    /// Deriving a child leaves the parent's depth untouched.
    #[test]
    fn test_context_cloning() {
        let root = CoercionContext::new();
        let value = FlexValue::new(json!(1), Source::Direct);
        let first_child = root.with_visited_strict("T", &value);
        assert_eq!(first_child.depth(), 1);
        assert_eq!(root.depth(), 0);
        let second_child = root.with_visited_strict("T", &value);
        assert_eq!(second_child.depth(), 1);
    }

    /// Strict and lenient entries each add one level of depth.
    #[test]
    fn test_separate_strict_lenient_tracking() {
        let root = CoercionContext::new();
        let value = FlexValue::new(json!(1), Source::Direct);
        let depths = [
            root.with_visited_strict("T", &value).depth(),
            root.with_visited_lenient("T", &value).depth(),
        ];
        assert_eq!(depths, [1, 1]);
    }
}