Skip to main content

macp_storage/
registry.rs

1use macp_core::session::Session;
2use std::collections::HashMap;
3use std::fs;
4use std::path::{Path, PathBuf};
5use tokio::sync::RwLock;
6
7#[derive(serde::Serialize, serde::Deserialize)]
8pub struct PersistedRoot {
9    pub uri: String,
10    pub name: String,
11}
12
/// On-disk form of a `Session`, serialized with serde into the JSON map that
/// `SessionRegistry` writes to `sessions.json`.
///
/// Fields carrying `#[serde(default)]` may be missing from older files and then
/// take their type's default value. Per-participant counters of `Session` are not
/// stored here; they are reset when converting back into a `Session`.
13#[derive(serde::Serialize, serde::Deserialize)]
14pub struct PersistedSession {
    /// Format version. `From<&Session>` always writes 2. When the field is absent
    /// on load, `default_schema_version()` fills in 2 as well.
    /// NOTE(review): files predating this field are therefore read as version 2 —
    /// confirm that is intended. Not consulted when converting back into a `Session`.
15    #[serde(default = "default_schema_version")]
16    pub schema_version: u32,
17    pub session_id: String,
18    pub state: macp_core::session::SessionState,
    /// Absolute expiry time. It is compared against `chrono::Utc::now().timestamp_millis()`
    /// by the registry, so presumably Unix milliseconds.
19    pub ttl_expiry: i64,
    /// TTL duration in milliseconds. It is 0 when absent; a non-positive value is replaced on
    /// load by `ttl_expiry - started_at_unix_ms` (saturating).
20    #[serde(default)]
21    pub ttl_ms: i64,
22    pub started_at_unix_ms: i64,
23    pub resolution: Option<Vec<u8>>,
24    pub mode: String,
    /// Raw bytes copied verbatim from `Session::mode_state`.
25    pub mode_state: Vec<u8>,
26    pub participants: Vec<String>,
    /// The session's seen-message-id set, flattened into a list. The order follows the
    /// in-memory set's iteration order and is rebuilt into a set on load.
27    pub seen_message_ids: Vec<String>,
28    pub intent: String,
29    pub mode_version: String,
30    pub configuration_version: String,
31    pub policy_version: String,
32    #[serde(default)]
33    pub context_id: String,
34    #[serde(default)]
35    pub extensions: HashMap<String, Vec<u8>>,
    /// Copies of the session's roots (uri and name only).
36    pub roots: Vec<PersistedRoot>,
    /// Sender that created the session; used by the registry to count open sessions
    /// per initiator.
37    pub initiator_sender: String,
38    #[serde(default)]
39    pub policy_definition: Option<macp_core::policy::PolicyDefinition>,
40    #[serde(default)]
41    pub suspended_at_ms: Option<i64>,
42    #[serde(default)]
43    pub accumulated_suspended_ms: i64,
44}
45
/// Value serde uses for `PersistedSession::schema_version` when the field is
/// missing from the stored JSON.
fn default_schema_version() -> u32 {
    const SCHEMA_V2: u32 = 2;
    SCHEMA_V2
}
49
50impl From<&Session> for PersistedSession {
51    fn from(session: &Session) -> Self {
52        Self {
53            schema_version: 2,
54            session_id: session.session_id.clone(),
55            state: session.state.clone(),
56            ttl_expiry: session.ttl_expiry,
57            ttl_ms: session.ttl_ms,
58            started_at_unix_ms: session.started_at_unix_ms,
59            resolution: session.resolution.clone(),
60            mode: session.mode.clone(),
61            mode_state: session.mode_state.clone(),
62            participants: session.participants.clone(),
63            seen_message_ids: session.seen_message_ids.iter().cloned().collect(),
64            intent: session.intent.clone(),
65            mode_version: session.mode_version.clone(),
66            configuration_version: session.configuration_version.clone(),
67            policy_version: session.policy_version.clone(),
68            context_id: session.context_id.clone(),
69            extensions: session.extensions.clone(),
70            roots: session
71                .roots
72                .iter()
73                .map(|root| PersistedRoot {
74                    uri: root.uri.clone(),
75                    name: root.name.clone(),
76                })
77                .collect(),
78            initiator_sender: session.initiator_sender.clone(),
79            policy_definition: session.policy_definition.clone(),
80            suspended_at_ms: session.suspended_at_ms,
81            accumulated_suspended_ms: session.accumulated_suspended_ms,
82        }
83    }
84}
85
86impl From<PersistedSession> for Session {
87    fn from(session: PersistedSession) -> Self {
88        let ttl_ms = if session.ttl_ms > 0 {
89            session.ttl_ms
90        } else {
91            // Backward compatibility: compute from absolute timestamps
92            session
93                .ttl_expiry
94                .saturating_sub(session.started_at_unix_ms)
95        };
96        Self {
97            session_id: session.session_id,
98            state: session.state,
99            ttl_expiry: session.ttl_expiry,
100            ttl_ms,
101            started_at_unix_ms: session.started_at_unix_ms,
102            resolution: session.resolution,
103            mode: session.mode,
104            mode_state: session.mode_state,
105            participants: session.participants,
106            seen_message_ids: session.seen_message_ids.into_iter().collect(),
107            intent: session.intent,
108            mode_version: session.mode_version,
109            configuration_version: session.configuration_version,
110            policy_version: session.policy_version,
111            context_id: session.context_id,
112            extensions: session.extensions,
113            roots: session
114                .roots
115                .into_iter()
116                .map(|root| macp_pb::pb::Root {
117                    uri: root.uri,
118                    name: root.name,
119                })
120                .collect(),
121            initiator_sender: session.initiator_sender,
122            participant_message_counts: std::collections::HashMap::new(),
123            participant_last_seen: std::collections::HashMap::new(),
124            policy_definition: session.policy_definition,
125            suspended_at_ms: session.suspended_at_ms,
126            accumulated_suspended_ms: session.accumulated_suspended_ms,
127        }
128    }
129}
130
131pub struct SessionRegistry {
132    pub sessions: RwLock<HashMap<String, Session>>,
133    persistence_path: Option<PathBuf>,
134}
135
136impl Default for SessionRegistry {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142impl SessionRegistry {
143    pub fn new() -> Self {
144        Self {
145            sessions: RwLock::new(HashMap::new()),
146            persistence_path: None,
147        }
148    }
149
150    pub fn with_persistence<P: AsRef<Path>>(dir: P) -> std::io::Result<Self> {
151        let dir = dir.as_ref().to_path_buf();
152        fs::create_dir_all(&dir)?;
153        let path = dir.join("sessions.json");
154        let sessions = Self::load_sessions(&path)?;
155        Ok(Self {
156            sessions: RwLock::new(sessions),
157            persistence_path: Some(path),
158        })
159    }
160
161    fn load_sessions(path: &Path) -> std::io::Result<HashMap<String, Session>> {
162        if !path.exists() {
163            return Ok(HashMap::new());
164        }
165        let bytes = fs::read(path)?;
166        let persisted: HashMap<String, PersistedSession> = match serde_json::from_slice(&bytes) {
167            Ok(v) => v,
168            Err(e) => {
169                eprintln!("warning: failed to deserialize sessions from {}: {e}; starting with empty state", path.display());
170                HashMap::new()
171            }
172        };
173        Ok(persisted
174            .into_iter()
175            .map(|(id, session)| (id, session.into()))
176            .collect())
177    }
178
179    fn persist_map(path: &Path, sessions: &HashMap<String, Session>) -> std::io::Result<()> {
180        let persisted: HashMap<String, PersistedSession> = sessions
181            .iter()
182            .map(|(id, session)| (id.clone(), PersistedSession::from(session)))
183            .collect();
184        let bytes = serde_json::to_vec_pretty(&persisted)?;
185        let tmp_path = path.with_extension("json.tmp");
186        fs::write(&tmp_path, bytes)?;
187        fs::rename(&tmp_path, path)
188    }
189
190    pub(crate) async fn persist_locked(
191        &self,
192        sessions: &HashMap<String, Session>,
193    ) -> std::io::Result<()> {
194        if let Some(path) = &self.persistence_path {
195            Self::persist_map(path, sessions)?;
196        }
197        Ok(())
198    }
199
200    pub async fn persist_snapshot(&self) -> std::io::Result<()> {
201        let guard = self.sessions.read().await;
202        self.persist_locked(&guard).await
203    }
204
205    pub async fn get_session(&self, session_id: &str) -> Option<Session> {
206        let guard = self.sessions.read().await;
207        guard.get(session_id).cloned()
208    }
209
210    pub async fn get_all_sessions(&self) -> Vec<Session> {
211        let guard = self.sessions.read().await;
212        guard.values().cloned().collect()
213    }
214
215    pub async fn insert_recovered_session(&self, session_id: String, session: Session) {
216        let mut guard = self.sessions.write().await;
217        guard.insert(session_id, session);
218        let _ = self.persist_locked(&guard).await;
219    }
220
221    pub async fn count_open_sessions_for_initiator(&self, sender: &str) -> usize {
222        let now = chrono::Utc::now().timestamp_millis();
223        let guard = self.sessions.read().await;
224        guard
225            .values()
226            .filter(|session| {
227                session.initiator_sender == sender
228                    && session.state == macp_core::session::SessionState::Open
229                    && now <= session.ttl_expiry
230            })
231            .count()
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use macp_core::session::{Session, SessionState};
    use std::collections::HashSet;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds an `Open` session initiated by "alice" with fixed small timestamps
    /// (`started_at_unix_ms = 1`, `ttl_expiry = 10`, `ttl_ms = 9`), one seen message
    /// id ("m1") and one root.
    fn sample_session(id: &str) -> Session {
        Session {
            session_id: id.into(),
            state: SessionState::Open,
            ttl_expiry: 10,
            ttl_ms: 9,
            started_at_unix_ms: 1,
            resolution: None,
            mode: "macp.mode.decision.v1".into(),
            mode_state: vec![1, 2, 3],
            participants: vec!["alice".into()],
            seen_message_ids: HashSet::from(["m1".into()]),
            intent: "intent".into(),
            mode_version: "1.0.0".into(),
            configuration_version: "cfg".into(),
            policy_version: "pol".into(),
            context_id: "test-ctx".to_string(),
            extensions: std::collections::HashMap::new(),
            roots: vec![macp_pb::pb::Root {
                uri: "root://1".into(),
                name: "r1".into(),
            }],
            initiator_sender: "alice".into(),
            participant_message_counts: std::collections::HashMap::new(),
            participant_last_seen: std::collections::HashMap::new(),
            policy_definition: None,
            suspended_at_ms: None,
            accumulated_suspended_ms: 0,
        }
    }

    /// Returns a path under the system temp dir that is unique per process and per call.
    /// The directory itself is not created.
    fn unique_temp_dir(prefix: &str) -> std::path::PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is after the Unix epoch")
            .as_nanos();
        std::env::temp_dir().join(format!("{prefix}-{}-{nanos}", std::process::id()))
    }

    #[tokio::test]
    async fn expired_sessions_not_counted_against_limit() {
        let registry = SessionRegistry::new();
        let now = chrono::Utc::now().timestamp_millis();
        // Insert a session with TTL already expired
        let mut expired = sample_session("expired-s1");
        expired.initiator_sender = "agent://alice".into();
        expired.ttl_expiry = now - 1000; // expired 1 second ago
        expired.state = SessionState::Open; // still Open but TTL is past
        registry
            .insert_recovered_session("expired-s1".into(), expired)
            .await;

        // Should not count the expired-but-open session
        let count = registry
            .count_open_sessions_for_initiator("agent://alice")
            .await;
        assert_eq!(count, 0);

        // Insert a session that is still valid
        let mut active = sample_session("active-s1");
        active.initiator_sender = "agent://alice".into();
        active.ttl_expiry = now + 60_000; // expires in 60s
        active.state = SessionState::Open;
        registry
            .insert_recovered_session("active-s1".into(), active)
            .await;

        let count = registry
            .count_open_sessions_for_initiator("agent://alice")
            .await;
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn persistent_registry_round_trip() {
        let base = unique_temp_dir("macp-registry-test");

        let registry = SessionRegistry::with_persistence(&base).unwrap();
        registry
            .insert_recovered_session("s1".into(), sample_session("s1"))
            .await;

        let reopened = SessionRegistry::with_persistence(&base).unwrap();
        let loaded = reopened.get_session("s1").await;

        // Clean up before asserting so a failing assertion does not leak the directory.
        let _ = std::fs::remove_dir_all(&base);

        let session = loaded.expect("session s1 should be reloaded from disk");
        assert_eq!(session.mode, "macp.mode.decision.v1");
        assert_eq!(session.mode_version, "1.0.0");
        assert!(session.seen_message_ids.contains("m1"));
        assert_eq!(session.ttl_ms, 9);
        assert_eq!(session.context_id, "test-ctx");
        assert_eq!(session.roots.len(), 1);
        assert_eq!(session.roots[0].uri, "root://1");
        assert_eq!(session.roots[0].name, "r1");
    }

    #[test]
    fn zero_ttl_ms_is_derived_from_absolute_timestamps() {
        // Records written before `ttl_ms` existed deserialize with ttl_ms = 0.
        let mut persisted = PersistedSession::from(&sample_session("legacy"));
        persisted.ttl_ms = 0;
        let session: Session = persisted.into();
        // ttl_expiry (10) - started_at_unix_ms (1)
        assert_eq!(session.ttl_ms, 9);
    }
}