use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use echo_core::error::ReactError;
use echo_core::llm::types::Message;
fn deep_merge_values(
target: &mut serde_json::Map<String, Value>,
source: &serde_json::Map<String, Value>,
) {
for (k, v) in source {
if let Some(existing) = target.get_mut(k)
&& let (Some(obj_a), Some(obj_b)) = (existing.as_object_mut(), v.as_object())
{
deep_merge_values(obj_a, obj_b);
continue;
}
target.insert(k.clone(), v.clone());
}
}
/// Result alias for every fallible `SharedState` operation.
pub type StateResult<T> = std::result::Result<T, StateError>;

/// Errors produced by `SharedState` operations.
#[derive(Debug)]
pub enum StateError {
    /// A value could not be serialized to JSON; carries the serde error text.
    Serialize(String),
    /// The internal `RwLock` was poisoned; carries the poison error text.
    LockPoisoned(String),
}

impl std::fmt::Display for StateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The user-facing labels are Chinese: "serialization failed" / "lock poisoned".
        let (label, detail) = match self {
            Self::Serialize(detail) => ("序列化失败", detail),
            Self::LockPoisoned(detail) => ("锁中毒", detail),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for StateError {}
impl From<StateError> for ReactError {
fn from(e: StateError) -> Self {
ReactError::Other(e.to_string())
}
}
/// Cheaply clonable handle to a mutable key/value + message state.
///
/// Cloning a `SharedState` clones the `Arc`, so every clone observes and mutates the same
/// underlying [`StateInner`]. Use `SharedState::fork` to obtain an independent copy.
#[derive(Clone)]
pub struct SharedState {
    inner: Arc<RwLock<StateInner>>,
}

/// The data guarded by a [`SharedState`]; this is what snapshots (de)serialize.
///
/// The container-level `#[serde(default)]` lets input that omits any field deserialize,
/// with missing fields taken from `StateInner::default()`. Previously only `current_node`
/// was optional, so e.g. `{"values": {}}` was rejected for lacking `messages`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct StateInner {
    /// Arbitrary JSON values keyed by name.
    pub values: HashMap<String, Value>,
    /// Messages, in the order they were pushed.
    pub messages: Vec<Message>,
    /// Node name last recorded via `SharedState::set_current_node`, if any.
    pub current_node: Option<String>,
}
impl SharedState {
    /// Creates an empty state with no values, no messages and no current node.
    pub fn new() -> Self {
        Self::from_inner(StateInner::default())
    }

    /// Builds a state pre-populated with `values`; messages start empty and no node is current.
    pub fn from_values(values: HashMap<String, Value>) -> Self {
        Self::from_inner(StateInner {
            values,
            ..StateInner::default()
        })
    }

    /// Restores a state from a JSON snapshot such as the one produced by [`Self::snapshot`].
    ///
    /// The restored state owns a fresh lock and is not shared with any existing state.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `snapshot` does not parse as a [`StateInner`].
    pub fn from_snapshot(snapshot: &str) -> std::result::Result<Self, serde_json::Error> {
        serde_json::from_str::<StateInner>(snapshot).map(Self::from_inner)
    }

    /// Serializes `value` to JSON and stores it under `key`, replacing any previous value.
    ///
    /// # Errors
    ///
    /// * [`StateError::Serialize`] if `value` cannot be converted to JSON.
    /// * [`StateError::LockPoisoned`] if the lock is poisoned.
    pub fn set<T: Serialize>(&self, key: impl Into<String>, value: T) -> StateResult<()> {
        let key = key.into();
        // Serialize before taking the lock so user `Serialize` code never runs under it.
        let value =
            serde_json::to_value(value).map_err(|e| StateError::Serialize(e.to_string()))?;
        self.write_inner()?.values.insert(key, value);
        Ok(())
    }

    /// Like [`Self::set`] but discards the error; returns `Some(())` on success.
    pub fn set_best_effort<T: Serialize>(&self, key: impl Into<String>, value: T) -> Option<()> {
        self.set(key, value).ok()
    }

    /// Reads `key` and deserializes it as `T`.
    ///
    /// Returns `None` if the key is missing, the lock is poisoned, or the stored JSON does
    /// not deserialize as `T`.
    pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
        let inner = self.inner.read().ok()?;
        // Deserialize straight from the borrowed `&Value`; no deep clone needed.
        inner.values.get(key).and_then(|v| T::deserialize(v).ok())
    }

    /// Returns a clone of the raw JSON stored under `key`, or `None` if absent or poisoned.
    pub fn get_raw(&self, key: &str) -> Option<Value> {
        self.inner.read().ok()?.values.get(key).cloned()
    }

    /// Whether `key` is present; `false` if the lock is poisoned.
    pub fn contains(&self, key: &str) -> bool {
        self.inner
            .read()
            .is_ok_and(|inner| inner.values.contains_key(key))
    }

    /// Removes `key` and returns its value; `None` if absent or the lock is poisoned.
    pub fn remove(&self, key: &str) -> Option<Value> {
        self.inner.write().ok()?.values.remove(key)
    }

    /// All stored keys in arbitrary order; empty if the lock is poisoned.
    pub fn keys(&self) -> Vec<String> {
        self.inner
            .read()
            .map(|inner| inner.values.keys().cloned().collect())
            .unwrap_or_default()
    }

    /// Appends `msg` to the message log.
    ///
    /// # Errors
    ///
    /// Returns [`StateError::LockPoisoned`] if the lock is poisoned.
    pub fn push_message(&self, msg: Message) -> StateResult<()> {
        self.write_inner()?.messages.push(msg);
        Ok(())
    }

    /// A copy of the message log; empty if the lock is poisoned.
    pub fn messages(&self) -> Vec<Message> {
        self.inner
            .read()
            .map(|inner| inner.messages.clone())
            .unwrap_or_default()
    }

    /// Number of logged messages; `0` if the lock is poisoned.
    pub fn message_count(&self) -> usize {
        self.inner.read().map_or(0, |inner| inner.messages.len())
    }

    /// Removes every message from the log.
    ///
    /// # Errors
    ///
    /// Returns [`StateError::LockPoisoned`] if the lock is poisoned.
    pub fn clear_messages(&self) -> StateResult<()> {
        self.write_inner()?.messages.clear();
        Ok(())
    }

    /// Records `node` as the current node. Best-effort: silently ignored if poisoned.
    pub(crate) fn set_current_node(&self, node: impl Into<String>) {
        if let Ok(mut inner) = self.inner.write() {
            inner.current_node = Some(node.into());
        }
    }

    /// The last node recorded by `set_current_node`, or `None` (also when poisoned).
    pub fn current_node(&self) -> Option<String> {
        self.inner
            .read()
            .ok()
            .and_then(|inner| inner.current_node.clone())
    }

    /// Deep-copies the state into a new, independent `SharedState` with its own lock.
    ///
    /// # Errors
    ///
    /// Returns [`StateError::LockPoisoned`] if the lock is poisoned.
    pub fn fork(&self) -> StateResult<Self> {
        let guard = self.read_inner()?;
        let copy = StateInner::clone(&guard);
        drop(guard);
        Ok(Self::from_inner(copy))
    }

    /// Serializes the whole state (values, messages, current node) as pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// * [`StateError::LockPoisoned`] if the lock is poisoned.
    /// * [`StateError::Serialize`] if serialization fails.
    pub fn snapshot(&self) -> StateResult<String> {
        let inner = self.read_inner()?;
        serde_json::to_string_pretty(&*inner).map_err(|e| StateError::Serialize(e.to_string()))
    }

    /// [`Self::snapshot`] that panics instead of returning an error.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::snapshot`] fails.
    pub fn snapshot_unwrap(&self) -> String {
        self.snapshot()
            .unwrap_or_else(|e| panic!("SharedState::snapshot_unwrap: {e}"))
    }

    /// Serializes the whole state into a `serde_json::Value`.
    ///
    /// # Errors
    ///
    /// * [`StateError::LockPoisoned`] if the lock is poisoned.
    /// * [`StateError::Serialize`] if serialization fails.
    pub fn to_json_value(&self) -> StateResult<serde_json::Value> {
        let inner = self.read_inner()?;
        serde_json::to_value(&*inner).map_err(|e| StateError::Serialize(e.to_string()))
    }

    /// [`Self::to_json_value`] that panics instead of returning an error.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::to_json_value`] fails.
    pub fn to_json(&self) -> serde_json::Value {
        self.to_json_value()
            .unwrap_or_else(|e| panic!("SharedState::to_json: {e}"))
    }

    /// Restores a state from a JSON value such as the one produced by [`Self::to_json`].
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `json` does not deserialize as a [`StateInner`].
    pub fn from_json(json: &serde_json::Value) -> std::result::Result<Self, serde_json::Error> {
        // `&Value` is itself a deserializer, so the input need not be cloned.
        StateInner::deserialize(json).map(Self::from_inner)
    }

    /// Copies every key of `other` that `self` does not already have; existing keys win.
    ///
    /// Only `values` are merged; the messages and current node of `self` are untouched.
    /// `other`'s values are cloned and its lock released *before* `self` is write-locked.
    /// Holding both locks would deadlock when `other` shares `self`'s `Arc`
    /// (`s.merge(&s.clone())`), or when two threads merge in opposite directions.
    ///
    /// # Errors
    ///
    /// Returns [`StateError::LockPoisoned`] if either lock is poisoned.
    pub fn merge(&self, other: &SharedState) -> StateResult<()> {
        let incoming = other.cloned_values()?;
        let mut inner = self.write_inner()?;
        for (key, value) in incoming {
            inner.values.entry(key).or_insert(value);
        }
        Ok(())
    }

    /// Copies every value of `other` into `self`, overwriting keys present in both.
    ///
    /// Only `values` are merged. Locking follows the same deadlock-free scheme as
    /// [`Self::merge`].
    ///
    /// # Errors
    ///
    /// Returns [`StateError::LockPoisoned`] if either lock is poisoned.
    pub fn merge_overwrite(&self, other: &SharedState) -> StateResult<()> {
        let incoming = other.cloned_values()?;
        self.write_inner()?.values.extend(incoming);
        Ok(())
    }

    /// Merges `other`'s values into `self`, recursing into JSON objects.
    ///
    /// For a key present in both where both values are JSON objects, the objects are merged
    /// recursively (`other` wins on leaf conflicts). Otherwise `other`'s value replaces or is
    /// inserted into `self`. Locking follows the same deadlock-free scheme as [`Self::merge`].
    ///
    /// # Errors
    ///
    /// Returns [`StateError::LockPoisoned`] if either lock is poisoned.
    pub fn deep_merge(&self, other: &SharedState) -> StateResult<()> {
        let incoming = other.cloned_values()?;
        let mut inner = self.write_inner()?;
        for (key, other_val) in incoming {
            // Merge in place rather than cloning the existing object and re-inserting it.
            if let (Some(Value::Object(dst)), Value::Object(src)) =
                (inner.values.get_mut(&key), &other_val)
            {
                deep_merge_values(dst, src);
                continue;
            }
            inner.values.insert(key, other_val);
        }
        Ok(())
    }

    /// Wraps an already-built [`StateInner`] in a fresh, unshared lock.
    fn from_inner(inner: StateInner) -> Self {
        Self {
            inner: Arc::new(RwLock::new(inner)),
        }
    }

    /// Acquires the read lock, mapping poisoning to [`StateError::LockPoisoned`].
    fn read_inner(&self) -> StateResult<std::sync::RwLockReadGuard<'_, StateInner>> {
        self.inner
            .read()
            .map_err(|e| StateError::LockPoisoned(e.to_string()))
    }

    /// Acquires the write lock, mapping poisoning to [`StateError::LockPoisoned`].
    fn write_inner(&self) -> StateResult<std::sync::RwLockWriteGuard<'_, StateInner>> {
        self.inner
            .write()
            .map_err(|e| StateError::LockPoisoned(e.to_string()))
    }

    /// Clones the `values` map under the read lock, which is released before returning.
    fn cloned_values(&self) -> StateResult<HashMap<String, Value>> {
        let guard = self.read_inner()?;
        let values = guard.values.clone();
        drop(guard);
        Ok(values)
    }
}
impl Default for SharedState {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for SharedState {
    /// Prints a summary (key names, message count, current node) instead of full values.
    /// If the lock is poisoned, only an `error` field is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SharedState");
        if let Ok(guard) = self.inner.read() {
            let key_list: Vec<&String> = guard.values.keys().collect();
            builder
                .field("keys", &key_list)
                .field("messages", &guard.messages.len())
                .field("current_node", &guard.current_node);
        } else {
            builder.field("error", &"lock poisoned");
        }
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_set_get_typed() {
        let state = SharedState::new();
        state.set("count", 42i64).unwrap();
        state.set("name", "echo").unwrap();
        state.set("tags", vec!["a", "b"]).unwrap();
        assert_eq!(state.get::<i64>("count"), Some(42));
        assert_eq!(state.get::<String>("name"), Some("echo".to_string()));
        assert_eq!(
            state.get::<Vec<String>>("tags"),
            Some(vec!["a".to_string(), "b".to_string()])
        );
        assert_eq!(state.get::<i64>("missing"), None);
    }

    #[test]
    fn test_contains_remove() {
        let state = SharedState::new();
        state.set("x", 1).unwrap();
        assert!(state.contains("x"));
        assert!(!state.contains("y"));
        // `remove` hands back the stored JSON, and `None` once it is gone.
        assert_eq!(state.remove("x"), Some(serde_json::json!(1)));
        assert!(!state.contains("x"));
        assert_eq!(state.remove("x"), None);
    }

    #[test]
    fn test_messages() {
        let state = SharedState::new();
        state
            .push_message(Message::user("hello".to_string()))
            .unwrap();
        state
            .push_message(Message::assistant("hi".to_string()))
            .unwrap();
        assert_eq!(state.message_count(), 2);
        let msgs = state.messages();
        assert_eq!(msgs[0].role, "user");
        assert_eq!(msgs[1].role, "assistant");
        state.clear_messages().unwrap();
        assert_eq!(state.message_count(), 0);
    }

    #[test]
    fn test_snapshot_restore() {
        let state = SharedState::new();
        state.set("key", "value").unwrap();
        state
            .push_message(Message::user("hello".to_string()))
            .unwrap();
        let snap = state.snapshot().unwrap();
        let restored = SharedState::from_snapshot(&snap).unwrap();
        assert_eq!(restored.get::<String>("key"), Some("value".to_string()));
        assert_eq!(restored.message_count(), 1);
    }

    #[test]
    fn test_json_round_trip() {
        let state = SharedState::new();
        state.set("k", "v").unwrap();
        state
            .push_message(Message::user("hello".to_string()))
            .unwrap();
        let json = state.to_json();
        let restored = SharedState::from_json(&json).unwrap();
        assert_eq!(restored.get::<String>("k"), Some("v".to_string()));
        assert_eq!(restored.message_count(), 1);
    }

    #[test]
    fn test_merge() {
        let a = SharedState::new();
        a.set("x", 1).unwrap();
        a.set("shared", "from_a").unwrap();
        let b = SharedState::new();
        b.set("y", 2).unwrap();
        b.set("shared", "from_b").unwrap();
        a.merge(&b).unwrap();
        assert_eq!(a.get::<i64>("x"), Some(1));
        assert_eq!(a.get::<i64>("y"), Some(2));
        // Plain `merge` keeps the existing value on conflict.
        assert_eq!(a.get::<String>("shared"), Some("from_a".to_string()));
    }

    #[test]
    fn test_merge_overwrite() {
        let a = SharedState::new();
        a.set("shared", "from_a").unwrap();
        let b = SharedState::new();
        b.set("shared", "from_b").unwrap();
        a.merge_overwrite(&b).unwrap();
        assert_eq!(a.get::<String>("shared"), Some("from_b".to_string()));
    }

    #[test]
    fn test_deep_merge_nested_objects() {
        let a = SharedState::new();
        a.set("cfg", serde_json::json!({"x": 1, "nested": {"a": 1}}))
            .unwrap();
        a.set("scalar", 1).unwrap();
        let b = SharedState::new();
        b.set("cfg", serde_json::json!({"y": 2, "nested": {"b": 2}}))
            .unwrap();
        b.set("scalar", 2).unwrap();
        b.set("only_b", true).unwrap();
        a.deep_merge(&b).unwrap();
        assert_eq!(
            a.get_raw("cfg"),
            Some(serde_json::json!({"x": 1, "y": 2, "nested": {"a": 1, "b": 2}}))
        );
        // Non-object values are overwritten, new keys are inserted.
        assert_eq!(a.get::<i64>("scalar"), Some(2));
        assert_eq!(a.get::<bool>("only_b"), Some(true));
    }

    #[test]
    fn test_clone_shares_state() {
        let state = SharedState::new();
        let cloned = state.clone();
        state.set("x", 42).unwrap();
        assert_eq!(cloned.get::<i64>("x"), Some(42));
    }

    #[test]
    fn test_fork_is_independent() {
        let original = SharedState::new();
        original.set("x", 1).unwrap();
        let forked = original.fork().unwrap();
        forked.set("x", 2).unwrap();
        assert_eq!(original.get::<i64>("x"), Some(1));
        assert_eq!(forked.get::<i64>("x"), Some(2));
    }

    #[test]
    fn test_current_node() {
        let state = SharedState::new();
        assert_eq!(state.current_node(), None);
        state.set_current_node("n1");
        assert_eq!(state.current_node(), Some("n1".to_string()));
    }

    #[test]
    fn test_from_values() {
        let mut vals = HashMap::new();
        vals.insert("key".to_string(), serde_json::json!("value"));
        let state = SharedState::from_values(vals);
        assert_eq!(state.get::<String>("key"), Some("value".to_string()));
    }
}