potato_agent/agents/
session.rs1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6#[derive(Debug, Clone)]
9pub struct SessionState {
10 inner: Arc<RwLock<HashMap<String, Value>>>,
11}
12
13impl Default for SessionState {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl SessionState {
20 pub fn new() -> Self {
21 Self {
22 inner: Arc::new(RwLock::new(HashMap::new())),
23 }
24 }
25
26 pub fn get(&self, key: &str) -> Option<Value> {
27 self.inner
28 .read()
29 .unwrap_or_else(|e| e.into_inner())
30 .get(key)
31 .cloned()
32 }
33
34 pub fn set(&self, key: impl Into<String>, value: Value) {
35 self.inner
36 .write()
37 .unwrap_or_else(|e| e.into_inner())
38 .insert(key.into(), value);
39 }
40
41 pub fn remove(&self, key: &str) -> Option<Value> {
42 self.inner
43 .write()
44 .unwrap_or_else(|e| e.into_inner())
45 .remove(key)
46 }
47
48 pub fn snapshot(&self) -> HashMap<String, Value> {
49 self.inner.read().unwrap_or_else(|e| e.into_inner()).clone()
50 }
51
52 pub fn merge(&self, other: HashMap<String, Value>) {
54 let mut lock = self.inner.write().unwrap_or_else(|e| e.into_inner());
55 for (k, v) in other {
56 lock.insert(k, v);
57 }
58 }
59
60 pub fn merge_user_data(&self, other: HashMap<String, Value>) {
64 let mut lock = self.inner.write().unwrap_or_else(|e| e.into_inner());
65 for (k, v) in other {
66 if !k.starts_with("__") {
67 lock.insert(k, v);
68 }
69 }
70 }
71
72 const ANCESTOR_KEY: &'static str = "__ancestor_ids";
75
76 pub fn push_ancestor(&self, agent_id: &str) {
77 let mut lock = self.inner.write().unwrap_or_else(|e| e.into_inner());
78 let entry = lock
79 .entry(Self::ANCESTOR_KEY.to_string())
80 .or_insert_with(|| Value::Array(vec![]));
81 if let Value::Array(arr) = entry {
82 arr.push(Value::String(agent_id.to_string()));
83 }
84 }
85
86 pub fn pop_ancestor(&self) {
87 let mut lock = self.inner.write().unwrap_or_else(|e| e.into_inner());
88 if let Some(Value::Array(arr)) = lock.get_mut(Self::ANCESTOR_KEY) {
89 arr.pop();
90 }
91 }
92
93 pub fn is_ancestor(&self, agent_id: &str) -> bool {
94 let lock = self.inner.read().unwrap_or_else(|e| e.into_inner());
95 if let Some(Value::Array(arr)) = lock.get(Self::ANCESTOR_KEY) {
96 arr.iter().any(|v| v.as_str() == Some(agent_id))
97 } else {
98 false
99 }
100 }
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct SessionSnapshot(pub HashMap<String, Value>);
106
107impl From<&SessionState> for SessionSnapshot {
108 fn from(s: &SessionState) -> Self {
109 Self(s.snapshot())
110 }
111}
112
113impl From<SessionSnapshot> for SessionState {
114 fn from(snap: SessionSnapshot) -> Self {
115 let s = Self::new();
116 s.merge(snap.0);
117 s
118 }
119}