construct/mcp_server/
session.rs1use chrono::{DateTime, Utc};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use std::path::PathBuf;
27use std::sync::Arc;
28use tokio::sync::{RwLock, broadcast};
29use uuid::Uuid;
30
/// Number of slots in each session's progress broadcast channel.
///
/// Passed to `broadcast::channel` in `SessionStore::create`; a receiver that
/// falls more than this many events behind will observe a lag error.
const BROADCAST_CAPACITY: usize = 64;

36#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ProgressEvent {
43 pub token: u64,
44 pub progress: u64,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 pub total: Option<u64>,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 pub message: Option<String>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 pub tool: Option<String>,
51 pub timestamp: String,
53}
54
55impl ProgressEvent {
56 pub fn new(
58 token: u64,
59 progress: u64,
60 total: Option<u64>,
61 message: Option<String>,
62 tool: Option<String>,
63 ) -> Self {
64 Self {
65 token,
66 progress,
67 total,
68 message,
69 tool,
70 timestamp: Utc::now().to_rfc3339(),
71 }
72 }
73}
74
75#[derive(Debug, Clone)]
76pub struct SessionState {
77 pub id: String,
78 pub token: String,
79 pub cwd: PathBuf,
80 pub label: Option<String>,
81 pub created_at: DateTime<Utc>,
82 pub events: broadcast::Sender<ProgressEvent>,
84}
85
86#[derive(Debug, Default)]
87pub struct SessionStore {
88 inner: RwLock<HashMap<String, SessionState>>,
90}
91
92impl SessionStore {
93 pub fn new() -> Self {
94 Self::default()
95 }
96
97 pub async fn create(&self, cwd: PathBuf, label: Option<String>) -> SessionState {
98 let id = Uuid::new_v4().to_string();
99 let token = Uuid::new_v4().simple().to_string();
100 let (events_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
101 let state = SessionState {
102 id: id.clone(),
103 token,
104 cwd,
105 label,
106 created_at: Utc::now(),
107 events: events_tx,
108 };
109 self.inner.write().await.insert(id, state.clone());
110 state
111 }
112
113 pub async fn authenticate(&self, session_id: &str, token: &str) -> Option<SessionState> {
115 let guard = self.inner.read().await;
116 guard
117 .get(session_id)
118 .filter(|s| constant_time_eq(s.token.as_bytes(), token.as_bytes()))
119 .cloned()
120 }
121
122 pub async fn event_sender(&self, session_id: &str) -> Option<broadcast::Sender<ProgressEvent>> {
126 let guard = self.inner.read().await;
127 guard.get(session_id).map(|s| s.events.clone())
128 }
129
130 pub async fn len(&self) -> usize {
131 self.inner.read().await.len()
132 }
133}
134
135pub type SharedSessionStore = Arc<SessionStore>;
136
/// Compares two byte slices for equality without stopping at the first
/// differing byte.
///
/// Slices of unequal length are rejected up front, so timing can reveal the
/// length of the secret but is not intended to reveal its contents. For
/// equal lengths, the XOR of every byte pair is OR-accumulated, and the
/// result is zero only when all pairs match.
///
/// NOTE(review): this is best-effort. The compiler is not obligated to keep
/// the comparison branch-free; a vetted constant-time crate gives stronger
/// guarantees.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let acc = a.iter().zip(b).fold(0u8, |acc, (&x, &y)| acc | (x ^ y));
        acc == 0
    } else {
        false
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn create_and_authenticate_happy_path() {
        let store = SessionStore::new();
        let sess = store
            .create(PathBuf::from("/tmp"), Some("test".into()))
            .await;
        assert_eq!(store.len().await, 1);
        let found = store
            .authenticate(&sess.id, &sess.token)
            .await
            .expect("valid credentials must authenticate");
        assert_eq!(found.cwd, PathBuf::from("/tmp"));
        assert_eq!(found.label.as_deref(), Some("test"));
    }

    #[tokio::test]
    async fn authenticate_rejects_wrong_token() {
        let store = SessionStore::new();
        let sess = store.create(PathBuf::from("/tmp"), None).await;
        let found = store.authenticate(&sess.id, "not-the-token").await;
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn authenticate_rejects_same_length_forged_token() {
        // A wrong token of the right length exercises the byte comparison in
        // `constant_time_eq`, not just its length check.
        let store = SessionStore::new();
        let sess = store.create(PathBuf::from("/tmp"), None).await;
        let mut forged = sess.token.clone();
        let last = forged.pop().expect("token is non-empty");
        forged.push(if last == '0' { '1' } else { '0' });
        assert_eq!(forged.len(), sess.token.len());
        assert!(store.authenticate(&sess.id, &forged).await.is_none());
    }

    #[tokio::test]
    async fn authenticate_rejects_empty_token() {
        let store = SessionStore::new();
        let sess = store.create(PathBuf::from("/tmp"), None).await;
        assert!(store.authenticate(&sess.id, "").await.is_none());
    }

    #[tokio::test]
    async fn authenticate_rejects_unknown_session() {
        let store = SessionStore::new();
        let _ = store.create(PathBuf::from("/tmp"), None).await;
        let found = store.authenticate("not-a-session-id", "anything").await;
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn token_does_not_authenticate_other_session() {
        let store = SessionStore::new();
        let a = store.create(PathBuf::from("/tmp/a"), None).await;
        let b = store.create(PathBuf::from("/tmp/b"), None).await;
        assert_eq!(store.len().await, 2);
        assert_ne!(a.id, b.id);
        assert_ne!(a.token, b.token);
        assert!(store.authenticate(&a.id, &b.token).await.is_none());
        assert!(store.authenticate(&b.id, &a.token).await.is_none());
    }

    #[tokio::test]
    async fn session_broadcast_delivers_published_event() {
        let store = SessionStore::new();
        let sess = store.create(PathBuf::from("/tmp"), None).await;

        // Subscribe before sending: a broadcast receiver only sees messages
        // sent after it was created.
        let mut rx = sess.events.subscribe();

        let ev = ProgressEvent::new(7, 1, Some(3), Some("hello".into()), Some("notion".into()));
        sess.events.send(ev.clone()).expect("send ok");

        let got = rx.recv().await.expect("recv ok");
        assert_eq!(got.token, 7);
        assert_eq!(got.progress, 1);
        assert_eq!(got.total, Some(3));
        assert_eq!(got.message.as_deref(), Some("hello"));
        assert_eq!(got.tool.as_deref(), Some("notion"));
    }

    #[tokio::test]
    async fn event_sender_lookup_returns_same_channel() {
        let store = SessionStore::new();
        let sess = store.create(PathBuf::from("/tmp"), None).await;
        let tx = store.event_sender(&sess.id).await.expect("sender present");
        let mut rx = tx.subscribe();
        let ev = ProgressEvent::new(1, 1, None, None, None);
        sess.events.send(ev).expect("send ok");
        let got = rx.recv().await.expect("recv ok");
        assert_eq!(got.token, 1);
    }

    #[tokio::test]
    async fn event_sender_unknown_session_returns_none() {
        let store = SessionStore::new();
        assert!(store.event_sender("nope").await.is_none());
    }

    #[tokio::test]
    async fn broadcast_with_no_subscribers_returns_send_error() {
        // `broadcast::Sender::send` reports `Err` when nobody is listening.
        // Publishers must treat that as "no one to notify" rather than a
        // failure.
        let store = SessionStore::new();
        let sess = store.create(PathBuf::from("/tmp"), None).await;
        let res = sess.events.send(ProgressEvent::new(0, 0, None, None, None));
        assert!(res.is_err());
    }

    #[test]
    fn constant_time_eq_edge_cases() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(constant_time_eq(b"", b""));
    }
}