use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub thread_id: Option<String>,
pub checkpoint_id: Option<String>,
pub recursion_limit: usize,
pub metadata: HashMap<String, serde_json::Value>,
pub tags: Vec<String>,
}
impl Default for Config {
fn default() -> Self {
Self {
thread_id: None,
checkpoint_id: None,
recursion_limit: 25,
metadata: HashMap::new(),
tags: Vec::new(),
}
}
}
impl Config {
    /// Creates a configuration with default settings; equivalent to [`Config::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the thread id, replacing any previously set value.
    #[must_use]
    pub fn with_thread_id(mut self, thread_id: impl Into<String>) -> Self {
        self.thread_id = Some(thread_id.into());
        self
    }

    /// Sets the checkpoint id, replacing any previously set value.
    #[must_use]
    pub fn with_checkpoint_id(mut self, checkpoint_id: impl Into<String>) -> Self {
        self.checkpoint_id = Some(checkpoint_id.into());
        self
    }

    /// Sets the recursion limit. The value is stored as given; no validation
    /// (for example rejecting `0`) is performed here.
    #[must_use]
    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
        self.recursion_limit = limit;
        self
    }

    /// Inserts a metadata entry.
    ///
    /// An existing entry under the same key is overwritten.
    #[must_use]
    pub fn with_metadata(
        mut self,
        key: impl Into<String>,
        value: impl Into<serde_json::Value>,
    ) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Appends a tag; duplicates are not filtered out.
    #[must_use]
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Returns the thread id, generating it first if needed.
    ///
    /// When no id is set, a random v4 UUID string is generated and stored.
    /// Later calls return the same id.
    pub fn ensure_thread_id(&mut self) -> &str {
        // `get_or_insert_with` only runs the generator when the slot is empty,
        // avoiding the separate `is_none()` check and the `unwrap()` panic path.
        self.thread_id
            .get_or_insert_with(|| uuid::Uuid::new_v4().to_string())
            .as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder method stores exactly what it was given.
    #[test]
    fn test_config_builder() {
        let config = Config::new()
            .with_thread_id("test-thread")
            .with_checkpoint_id("checkpoint-1")
            .with_recursion_limit(100)
            .with_metadata("key", "value")
            .with_tag("test");
        assert_eq!(config.thread_id.as_deref(), Some("test-thread"));
        assert_eq!(config.checkpoint_id.as_deref(), Some("checkpoint-1"));
        assert_eq!(config.recursion_limit, 100);
        assert_eq!(config.metadata.len(), 1);
        assert_eq!(
            config.metadata.get("key"),
            Some(&serde_json::Value::String("value".to_string()))
        );
        assert_eq!(config.tags, vec!["test".to_string()]);
    }

    /// Metadata keys are overwritten; tags are appended, duplicates included.
    #[test]
    fn test_builder_overwrites_metadata_and_appends_tags() {
        let config = Config::new()
            .with_metadata("key", "first")
            .with_metadata("key", "second")
            .with_tag("a")
            .with_tag("a");
        assert_eq!(config.metadata.len(), 1);
        assert_eq!(
            config.metadata.get("key"),
            Some(&serde_json::Value::String("second".to_string()))
        );
        assert_eq!(config.tags, vec!["a".to_string(), "a".to_string()]);
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.recursion_limit, 25);
        assert!(config.thread_id.is_none());
        assert!(config.checkpoint_id.is_none());
        assert!(config.metadata.is_empty());
        assert!(config.tags.is_empty());
    }

    /// A missing id is generated once, stored, and reused on later calls.
    #[test]
    fn test_ensure_thread_id() {
        let mut config = Config::new();
        assert!(config.thread_id.is_none());
        let thread_id = config.ensure_thread_id().to_string();
        assert!(!thread_id.is_empty());
        assert!(uuid::Uuid::parse_str(&thread_id).is_ok());
        assert_eq!(config.thread_id.as_deref(), Some(thread_id.as_str()));
        let thread_id2 = config.ensure_thread_id().to_string();
        assert_eq!(thread_id, thread_id2);
    }

    /// An explicitly set id is returned unchanged, never replaced.
    #[test]
    fn test_ensure_thread_id_keeps_existing() {
        let mut config = Config::new().with_thread_id("existing");
        assert_eq!(config.ensure_thread_id(), "existing");
        assert_eq!(config.thread_id.as_deref(), Some("existing"));
    }

    /// All fields survive a JSON round trip.
    #[test]
    fn test_config_serialization() {
        let config = Config::new()
            .with_thread_id("test")
            .with_checkpoint_id("cp")
            .with_recursion_limit(50)
            .with_metadata("n", 1)
            .with_tag("t");
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: Config = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.thread_id, config.thread_id);
        assert_eq!(deserialized.checkpoint_id, config.checkpoint_id);
        assert_eq!(deserialized.recursion_limit, config.recursion_limit);
        assert_eq!(deserialized.metadata, config.metadata);
        assert_eq!(deserialized.tags, config.tags);
    }
}