1use macp_core::session::Session;
2use std::collections::HashMap;
3use std::fs;
4use std::path::{Path, PathBuf};
5use tokio::sync::RwLock;
6
/// Serializable mirror of a session root, stored in `PersistedSession::roots`.
///
/// Converted from/to `macp_pb::pb::Root` by the `From` impls between
/// `Session` and `PersistedSession`; both fields are copied verbatim.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct PersistedRoot {
    /// Root URI (e.g. `"root://1"` in the tests).
    pub uri: String,
    /// Root display name.
    pub name: String,
}
12
/// On-disk (JSON) representation of a [`Session`], as written to `sessions.json`.
///
/// Fields marked `#[serde(default)]` may be absent in snapshots written by older
/// builds and fall back to their `Default` value (or `default_schema_version()`
/// for `schema_version`). Field declaration order determines the key order in
/// the pretty-printed JSON snapshot.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct PersistedSession {
    /// Snapshot schema version; records without it deserialize as version 2.
    #[serde(default = "default_schema_version")]
    pub schema_version: u32,
    pub session_id: String,
    pub state: macp_core::session::SessionState,
    /// Absolute expiry in Unix milliseconds (compared against
    /// `chrono::Utc::now().timestamp_millis()` by the registry).
    pub ttl_expiry: i64,
    /// Session TTL in milliseconds. Missing in older records (defaults to 0);
    /// when not positive, loading derives it as `ttl_expiry - started_at_unix_ms`.
    #[serde(default)]
    pub ttl_ms: i64,
    pub started_at_unix_ms: i64,
    /// Opaque resolution payload, if the session has one.
    pub resolution: Option<Vec<u8>>,
    /// Mode identifier (e.g. `"macp.mode.decision.v1"` in the tests).
    pub mode: String,
    /// Opaque serialized mode state.
    pub mode_state: Vec<u8>,
    pub participants: Vec<String>,
    /// Stored as a list; the live `Session` holds these as a set, so order is
    /// not meaningful.
    pub seen_message_ids: Vec<String>,
    pub intent: String,
    pub mode_version: String,
    pub configuration_version: String,
    pub policy_version: String,
    /// Empty string when absent from the record.
    #[serde(default)]
    pub context_id: String,
    /// Opaque extension payloads keyed by name; empty when absent.
    #[serde(default)]
    pub extensions: HashMap<String, Vec<u8>>,
    pub roots: Vec<PersistedRoot>,
    pub initiator_sender: String,
    /// `None` when absent from the record.
    #[serde(default)]
    pub policy_definition: Option<macp_core::policy::PolicyDefinition>,
    /// `None` when absent from the record.
    #[serde(default)]
    pub suspended_at_ms: Option<i64>,
    /// 0 when absent from the record.
    #[serde(default)]
    pub accumulated_suspended_ms: i64,
}
45
/// Schema version assumed for persisted records that carry no
/// `schema_version` field (referenced by `#[serde(default = ...)]`).
fn default_schema_version() -> u32 {
    const SCHEMA_V2: u32 = 2;
    SCHEMA_V2
}
49
50impl From<&Session> for PersistedSession {
51 fn from(session: &Session) -> Self {
52 Self {
53 schema_version: 2,
54 session_id: session.session_id.clone(),
55 state: session.state.clone(),
56 ttl_expiry: session.ttl_expiry,
57 ttl_ms: session.ttl_ms,
58 started_at_unix_ms: session.started_at_unix_ms,
59 resolution: session.resolution.clone(),
60 mode: session.mode.clone(),
61 mode_state: session.mode_state.clone(),
62 participants: session.participants.clone(),
63 seen_message_ids: session.seen_message_ids.iter().cloned().collect(),
64 intent: session.intent.clone(),
65 mode_version: session.mode_version.clone(),
66 configuration_version: session.configuration_version.clone(),
67 policy_version: session.policy_version.clone(),
68 context_id: session.context_id.clone(),
69 extensions: session.extensions.clone(),
70 roots: session
71 .roots
72 .iter()
73 .map(|root| PersistedRoot {
74 uri: root.uri.clone(),
75 name: root.name.clone(),
76 })
77 .collect(),
78 initiator_sender: session.initiator_sender.clone(),
79 policy_definition: session.policy_definition.clone(),
80 suspended_at_ms: session.suspended_at_ms,
81 accumulated_suspended_ms: session.accumulated_suspended_ms,
82 }
83 }
84}
85
86impl From<PersistedSession> for Session {
87 fn from(session: PersistedSession) -> Self {
88 let ttl_ms = if session.ttl_ms > 0 {
89 session.ttl_ms
90 } else {
91 session
93 .ttl_expiry
94 .saturating_sub(session.started_at_unix_ms)
95 };
96 Self {
97 session_id: session.session_id,
98 state: session.state,
99 ttl_expiry: session.ttl_expiry,
100 ttl_ms,
101 started_at_unix_ms: session.started_at_unix_ms,
102 resolution: session.resolution,
103 mode: session.mode,
104 mode_state: session.mode_state,
105 participants: session.participants,
106 seen_message_ids: session.seen_message_ids.into_iter().collect(),
107 intent: session.intent,
108 mode_version: session.mode_version,
109 configuration_version: session.configuration_version,
110 policy_version: session.policy_version,
111 context_id: session.context_id,
112 extensions: session.extensions,
113 roots: session
114 .roots
115 .into_iter()
116 .map(|root| macp_pb::pb::Root {
117 uri: root.uri,
118 name: root.name,
119 })
120 .collect(),
121 initiator_sender: session.initiator_sender,
122 participant_message_counts: std::collections::HashMap::new(),
123 participant_last_seen: std::collections::HashMap::new(),
124 policy_definition: session.policy_definition,
125 suspended_at_ms: session.suspended_at_ms,
126 accumulated_suspended_ms: session.accumulated_suspended_ms,
127 }
128 }
129}
130
/// Registry of sessions keyed by session id, optionally mirrored to a JSON
/// snapshot file on disk.
pub struct SessionRegistry {
    /// All known sessions, keyed by session id, behind an async (tokio) `RwLock`.
    pub sessions: RwLock<HashMap<String, Session>>,
    /// Path of the `sessions.json` snapshot; `None` means in-memory only.
    persistence_path: Option<PathBuf>,
}
135
136impl Default for SessionRegistry {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142impl SessionRegistry {
143 pub fn new() -> Self {
144 Self {
145 sessions: RwLock::new(HashMap::new()),
146 persistence_path: None,
147 }
148 }
149
150 pub fn with_persistence<P: AsRef<Path>>(dir: P) -> std::io::Result<Self> {
151 let dir = dir.as_ref().to_path_buf();
152 fs::create_dir_all(&dir)?;
153 let path = dir.join("sessions.json");
154 let sessions = Self::load_sessions(&path)?;
155 Ok(Self {
156 sessions: RwLock::new(sessions),
157 persistence_path: Some(path),
158 })
159 }
160
161 fn load_sessions(path: &Path) -> std::io::Result<HashMap<String, Session>> {
162 if !path.exists() {
163 return Ok(HashMap::new());
164 }
165 let bytes = fs::read(path)?;
166 let persisted: HashMap<String, PersistedSession> = match serde_json::from_slice(&bytes) {
167 Ok(v) => v,
168 Err(e) => {
169 eprintln!("warning: failed to deserialize sessions from {}: {e}; starting with empty state", path.display());
170 HashMap::new()
171 }
172 };
173 Ok(persisted
174 .into_iter()
175 .map(|(id, session)| (id, session.into()))
176 .collect())
177 }
178
179 fn persist_map(path: &Path, sessions: &HashMap<String, Session>) -> std::io::Result<()> {
180 let persisted: HashMap<String, PersistedSession> = sessions
181 .iter()
182 .map(|(id, session)| (id.clone(), PersistedSession::from(session)))
183 .collect();
184 let bytes = serde_json::to_vec_pretty(&persisted)?;
185 let tmp_path = path.with_extension("json.tmp");
186 fs::write(&tmp_path, bytes)?;
187 fs::rename(&tmp_path, path)
188 }
189
190 pub(crate) async fn persist_locked(
191 &self,
192 sessions: &HashMap<String, Session>,
193 ) -> std::io::Result<()> {
194 if let Some(path) = &self.persistence_path {
195 Self::persist_map(path, sessions)?;
196 }
197 Ok(())
198 }
199
200 pub async fn persist_snapshot(&self) -> std::io::Result<()> {
201 let guard = self.sessions.read().await;
202 self.persist_locked(&guard).await
203 }
204
205 pub async fn get_session(&self, session_id: &str) -> Option<Session> {
206 let guard = self.sessions.read().await;
207 guard.get(session_id).cloned()
208 }
209
210 pub async fn get_all_sessions(&self) -> Vec<Session> {
211 let guard = self.sessions.read().await;
212 guard.values().cloned().collect()
213 }
214
215 pub async fn insert_recovered_session(&self, session_id: String, session: Session) {
216 let mut guard = self.sessions.write().await;
217 guard.insert(session_id, session);
218 let _ = self.persist_locked(&guard).await;
219 }
220
221 pub async fn count_open_sessions_for_initiator(&self, sender: &str) -> usize {
222 let now = chrono::Utc::now().timestamp_millis();
223 let guard = self.sessions.read().await;
224 guard
225 .values()
226 .filter(|session| {
227 session.initiator_sender == sender
228 && session.state == macp_core::session::SessionState::Open
229 && now <= session.ttl_expiry
230 })
231 .count()
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use macp_core::session::{Session, SessionState};
    use std::collections::HashSet;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds an open session with small fixed timestamps
    /// (started 1 ms, expires 10 ms, ttl 9 ms) and a single root.
    fn sample_session(id: &str) -> Session {
        Session {
            session_id: id.into(),
            state: SessionState::Open,
            ttl_expiry: 10,
            ttl_ms: 9,
            started_at_unix_ms: 1,
            resolution: None,
            mode: "macp.mode.decision.v1".into(),
            mode_state: vec![1, 2, 3],
            participants: vec!["alice".into()],
            seen_message_ids: HashSet::from(["m1".into()]),
            intent: "intent".into(),
            mode_version: "1.0.0".into(),
            configuration_version: "cfg".into(),
            policy_version: "pol".into(),
            context_id: "test-ctx".to_string(),
            extensions: std::collections::HashMap::new(),
            roots: vec![macp_pb::pb::Root {
                uri: "root://1".into(),
                name: "r1".into(),
            }],
            initiator_sender: "alice".into(),
            participant_message_counts: std::collections::HashMap::new(),
            participant_last_seen: std::collections::HashMap::new(),
            policy_definition: None,
            suspended_at_ms: None,
            accumulated_suspended_ms: 0,
        }
    }

    #[tokio::test]
    async fn expired_sessions_not_counted_against_limit() {
        let registry = SessionRegistry::new();
        let now = chrono::Utc::now().timestamp_millis();

        let mut expired = sample_session("expired-s1");
        expired.initiator_sender = "agent://alice".into();
        expired.ttl_expiry = now - 1000;
        expired.state = SessionState::Open;
        registry
            .insert_recovered_session("expired-s1".into(), expired)
            .await;

        let count = registry
            .count_open_sessions_for_initiator("agent://alice")
            .await;
        assert_eq!(count, 0);

        let mut active = sample_session("active-s1");
        active.initiator_sender = "agent://alice".into();
        active.ttl_expiry = now + 60_000;
        active.state = SessionState::Open;
        registry
            .insert_recovered_session("active-s1".into(), active)
            .await;

        let count = registry
            .count_open_sessions_for_initiator("agent://alice")
            .await;
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn persistent_registry_round_trip() {
        let base = std::env::temp_dir().join(format!(
            "macp-registry-test-{}",
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));

        let registry = SessionRegistry::with_persistence(&base).unwrap();
        registry
            .insert_recovered_session("s1".into(), sample_session("s1"))
            .await;

        let reopened = SessionRegistry::with_persistence(&base).unwrap();
        let session = reopened.get_session("s1").await.unwrap();

        // Clean up before asserting so a failing assertion does not leak the
        // temp directory; `session` is an owned clone, independent of the files.
        let _ = std::fs::remove_dir_all(&base);

        assert_eq!(session.mode, "macp.mode.decision.v1");
        assert_eq!(session.mode_version, "1.0.0");
        assert!(session.seen_message_ids.contains("m1"));
        assert_eq!(session.ttl_ms, 9);
        assert_eq!(session.mode_state, vec![1, 2, 3]);
        assert_eq!(session.context_id, "test-ctx");
        assert_eq!(session.initiator_sender, "alice");
        assert_eq!(session.roots.len(), 1);
        assert_eq!(session.roots[0].uri, "root://1");
        assert_eq!(session.roots[0].name, "r1");
    }

    #[test]
    fn legacy_record_without_ttl_ms_derives_it_from_expiry() {
        // Older snapshots lack `ttl_ms` (serde defaults it to 0); loading must
        // reconstruct it as ttl_expiry - started_at_unix_ms = 10 - 1.
        let mut persisted = PersistedSession::from(&sample_session("legacy"));
        persisted.ttl_ms = 0;
        let restored: Session = persisted.into();
        assert_eq!(restored.ttl_ms, 9);
    }
}
328}