Skip to main content

potato_agent/agents/session.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6/// Shared mutable state across agents in a run.
7/// Clone is cheap (Arc clone). Per-request in Axum handlers.
8#[derive(Debug, Clone)]
9pub struct SessionState {
10    inner: Arc<RwLock<HashMap<String, Value>>>,
11}
12
13impl Default for SessionState {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl SessionState {
20    pub fn new() -> Self {
21        Self {
22            inner: Arc::new(RwLock::new(HashMap::new())),
23        }
24    }
25
26    pub fn get(&self, key: &str) -> Option<Value> {
27        self.inner
28            .read()
29            .unwrap_or_else(|e| e.into_inner())
30            .get(key)
31            .cloned()
32    }
33
34    pub fn set(&self, key: impl Into<String>, value: Value) {
35        self.inner
36            .write()
37            .unwrap_or_else(|e| e.into_inner())
38            .insert(key.into(), value);
39    }
40
41    pub fn remove(&self, key: &str) -> Option<Value> {
42        self.inner
43            .write()
44            .unwrap_or_else(|e| e.into_inner())
45            .remove(key)
46    }
47
48    pub fn snapshot(&self) -> HashMap<String, Value> {
49        self.inner.read().unwrap_or_else(|e| e.into_inner()).clone()
50    }
51
52    /// Merge another snapshot into this session (later values win).
53    pub fn merge(&self, other: HashMap<String, Value>) {
54        let mut lock = self.inner.write().unwrap_or_else(|e| e.into_inner());
55        for (k, v) in other {
56            lock.insert(k, v);
57        }
58    }
59
60    /// Merge a child snapshot into this session, skipping `__`-prefixed system keys.
61    /// Use this when merging child-agent sessions to prevent children from overwriting
62    /// system keys such as `__ancestor_ids`.
63    pub fn merge_user_data(&self, other: HashMap<String, Value>) {
64        let mut lock = self.inner.write().unwrap_or_else(|e| e.into_inner());
65        for (k, v) in other {
66            if !k.starts_with("__") {
67                lock.insert(k, v);
68            }
69        }
70    }
71
72    // ── ancestor tracking (circular call prevention) ──────────────────────
73
74    const ANCESTOR_KEY: &'static str = "__ancestor_ids";
75
76    pub fn push_ancestor(&self, agent_id: &str) {
77        let mut lock = self.inner.write().unwrap_or_else(|e| e.into_inner());
78        let entry = lock
79            .entry(Self::ANCESTOR_KEY.to_string())
80            .or_insert_with(|| Value::Array(vec![]));
81        if let Value::Array(arr) = entry {
82            arr.push(Value::String(agent_id.to_string()));
83        }
84    }
85
86    pub fn pop_ancestor(&self) {
87        let mut lock = self.inner.write().unwrap_or_else(|e| e.into_inner());
88        if let Some(Value::Array(arr)) = lock.get_mut(Self::ANCESTOR_KEY) {
89            arr.pop();
90        }
91    }
92
93    pub fn is_ancestor(&self, agent_id: &str) -> bool {
94        let lock = self.inner.read().unwrap_or_else(|e| e.into_inner());
95        if let Some(Value::Array(arr)) = lock.get(Self::ANCESTOR_KEY) {
96            arr.iter().any(|v| v.as_str() == Some(agent_id))
97        } else {
98            false
99        }
100    }
101}
102
103/// Serializable snapshot of session state — for storage between HTTP requests.
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct SessionSnapshot(pub HashMap<String, Value>);
106
107impl From<&SessionState> for SessionSnapshot {
108    fn from(s: &SessionState) -> Self {
109        Self(s.snapshot())
110    }
111}
112
113impl From<SessionSnapshot> for SessionState {
114    fn from(snap: SessionSnapshot) -> Self {
115        let s = Self::new();
116        s.merge(snap.0);
117        s
118    }
119}