Skip to main content

construct/mcp_server/
session.rs

//! Session state for the MCP daemon.
//!
//! Each connected client (one external CLI process) creates a session via
//! `POST /session`, receives a `{ session_id, token }` pair, and sends both
//! values as headers on every JSON-RPC call. Each session records its own cwd,
//! an optional label, and a created-at stamp; an allowed-tools filter is
//! future scaffolding and is not stored yet.
//! Storage is a `tokio::sync::RwLock<HashMap<...>>` — Send + Sync, with no
//! extra dependencies.
8//!
9//! ## Session-wide progress broadcast (M4)
10//!
11//! In addition to the per-request SSE stream that `/mcp` tools/call uses to
12//! ship progress events back to the caller, we also publish each event onto
13//! a per-session `tokio::sync::broadcast` channel. Any subscriber holding
14//! the session token can tap that stream via `GET /session/<id>/events` to
15//! observe *every* Construct tool's progress for that session in real time —
16//! which is how the V2 Code tab surfaces "what Construct is doing right now"
17//! while an external CLI is mid tools/call.
18//!
19//! Broadcast capacity is small (64). Slow consumers simply miss frames
20//! (broadcast::Receiver returns `Lagged`); progress events are advisory,
21//! never load-bearing for correctness.
22
23use chrono::{DateTime, Utc};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use std::path::PathBuf;
27use std::sync::Arc;
28use tokio::sync::{RwLock, broadcast};
29use uuid::Uuid;
30
/// Buffer size of each session's progress `broadcast` channel.
///
/// Sized for a short burst of progress events without holding much memory.
/// A subscriber that falls too far behind receives `RecvError::Lagged` and
/// skips ahead — acceptable, because progress is advisory, not a log.
const BROADCAST_CAPACITY: usize = 64;
35
36/// Session-wide progress event published to any subscriber of
37/// `/session/<id>/events`. Mirrors the per-request `notifications/progress`
38/// payload with additional `tool` + `timestamp` fields so subscribers can
39/// render "Notion — 4/10 at 10:20:33" without having to correlate tokens
40/// back to the originating request.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ProgressEvent {
43    pub token: u64,
44    pub progress: u64,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub total: Option<u64>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub message: Option<String>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub tool: Option<String>,
51    /// RFC3339 timestamp (UTC) — frontend renders this as "Ns ago".
52    pub timestamp: String,
53}
54
55impl ProgressEvent {
56    /// Convenience constructor using the current wall clock.
57    pub fn new(
58        token: u64,
59        progress: u64,
60        total: Option<u64>,
61        message: Option<String>,
62        tool: Option<String>,
63    ) -> Self {
64        Self {
65            token,
66            progress,
67            total,
68            message,
69            tool,
70            timestamp: Utc::now().to_rfc3339(),
71        }
72    }
73}
74
75#[derive(Debug, Clone)]
76pub struct SessionState {
77    pub id: String,
78    pub token: String,
79    pub cwd: PathBuf,
80    pub label: Option<String>,
81    pub created_at: DateTime<Utc>,
82    /// Sender half of the per-session progress broadcast. Clone to subscribe.
83    pub events: broadcast::Sender<ProgressEvent>,
84}
85
86#[derive(Debug, Default)]
87pub struct SessionStore {
88    // session_id -> SessionState
89    inner: RwLock<HashMap<String, SessionState>>,
90}
91
92impl SessionStore {
93    pub fn new() -> Self {
94        Self::default()
95    }
96
97    pub async fn create(&self, cwd: PathBuf, label: Option<String>) -> SessionState {
98        let id = Uuid::new_v4().to_string();
99        let token = Uuid::new_v4().simple().to_string();
100        let (events_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
101        let state = SessionState {
102            id: id.clone(),
103            token,
104            cwd,
105            label,
106            created_at: Utc::now(),
107            events: events_tx,
108        };
109        self.inner.write().await.insert(id, state.clone());
110        state
111    }
112
113    /// Return the session iff the `(session_id, token)` pair matches one on file.
114    pub async fn authenticate(&self, session_id: &str, token: &str) -> Option<SessionState> {
115        let guard = self.inner.read().await;
116        guard
117            .get(session_id)
118            .filter(|s| constant_time_eq(s.token.as_bytes(), token.as_bytes()))
119            .cloned()
120    }
121
122    /// Look up a session's broadcast sender by id (no auth). Used by the
123    /// `/session/<id>/events` handler after it has independently verified
124    /// the bearer token via `authenticate`.
125    pub async fn event_sender(&self, session_id: &str) -> Option<broadcast::Sender<ProgressEvent>> {
126        let guard = self.inner.read().await;
127        guard.get(session_id).map(|s| s.events.clone())
128    }
129
130    pub async fn len(&self) -> usize {
131        self.inner.read().await.len()
132    }
133}
134
/// Reference-counted handle to a `SessionStore`. Cloning it bumps the `Arc`
/// refcount instead of copying the sessions, so one store can be shared by
/// several owners.
pub type SharedSessionStore = Arc<SessionStore>;
136
/// Byte-slice equality that does not short-circuit on the first mismatch.
///
/// Every byte pair is XOR-ed and OR-ed into an accumulator, so the loop
/// length depends only on the input length, not on where the slices differ.
/// Slices of different length are rejected immediately. The length is not
/// treated as secret; tokens minted by `SessionStore::create` are all
/// fixed-length `Uuid::simple()` strings.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (&x, &y)| acc | (x ^ y)) == 0
}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn create_and_authenticate_happy_path() {
        let store = SessionStore::new();
        let sess = store
            .create(PathBuf::from("/tmp"), Some("test".into()))
            .await;
        assert_eq!(store.len().await, 1);
        let found = store
            .authenticate(&sess.id, &sess.token)
            .await
            .expect("matching (id, token) pair must authenticate");
        assert_eq!(found.id, sess.id);
        assert_eq!(found.cwd, PathBuf::from("/tmp"));
        assert_eq!(found.label.as_deref(), Some("test"));
    }

    #[tokio::test]
    async fn create_mints_distinct_ids_and_tokens() {
        let store = SessionStore::new();
        let a = store.create(PathBuf::from("/a"), None).await;
        let b = store.create(PathBuf::from("/b"), None).await;
        assert_eq!(store.len().await, 2);
        assert_ne!(a.id, b.id);
        assert_ne!(a.token, b.token);
        // A token only unlocks the session that minted it. Both tokens have
        // the same length, so this exercises the byte-by-byte comparison
        // rather than the length guard.
        assert!(store.authenticate(&a.id, &b.token).await.is_none());
        assert!(store.authenticate(&b.id, &a.token).await.is_none());
    }

    #[tokio::test]
    async fn authenticate_rejects_wrong_token() {
        let store = SessionStore::new();
        let sess = store.create(PathBuf::from("/tmp"), None).await;
        let found = store.authenticate(&sess.id, "not-the-token").await;
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn authenticate_rejects_unknown_session() {
        let store = SessionStore::new();
        let _ = store.create(PathBuf::from("/tmp"), None).await;
        let found = store.authenticate("not-a-session-id", "anything").await;
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn session_broadcast_delivers_published_event() {
        let store = SessionStore::new();
        let sess = store.create(PathBuf::from("/tmp"), None).await;

        // Subscribe BEFORE sending — broadcast drops messages with no live
        // receivers.
        let mut rx = sess.events.subscribe();

        let ev = ProgressEvent::new(7, 1, Some(3), Some("hello".into()), Some("notion".into()));
        sess.events.send(ev.clone()).expect("send ok");

        let got = rx.recv().await.expect("recv ok");
        assert_eq!(got.token, 7);
        assert_eq!(got.progress, 1);
        assert_eq!(got.total, Some(3));
        assert_eq!(got.message.as_deref(), Some("hello"));
        assert_eq!(got.tool.as_deref(), Some("notion"));
    }

    #[tokio::test]
    async fn event_sender_lookup_returns_same_channel() {
        let store = SessionStore::new();
        let sess = store.create(PathBuf::from("/tmp"), None).await;
        let tx = store.event_sender(&sess.id).await.expect("sender present");
        let mut rx = tx.subscribe();
        let ev = ProgressEvent::new(1, 1, None, None, None);
        sess.events.send(ev).expect("send ok");
        let got = rx.recv().await.expect("recv ok");
        assert_eq!(got.token, 1);
    }

    #[tokio::test]
    async fn event_sender_unknown_session_returns_none() {
        let store = SessionStore::new();
        assert!(store.event_sender("nope").await.is_none());
    }

    #[tokio::test]
    async fn broadcast_with_no_subscribers_returns_err_that_callers_ignore() {
        // `send` with no live subscribers returns Err(SendError). The daemon's
        // progress sink ignores that return value by design; this test pins
        // the shape so it is not mistaken for a fatal condition.
        let store = SessionStore::new();
        let sess = store.create(PathBuf::from("/tmp"), None).await;
        let res = sess.events.send(ProgressEvent::new(0, 0, None, None, None));
        assert!(res.is_err());
    }

    #[test]
    fn constant_time_eq_agrees_with_slice_equality() {
        let cases: [(&[u8], &[u8]); 6] = [
            (b"", b""),
            (b"abc", b"abc"),
            (b"abc", b"abd"),
            (b"abc", b"xbc"),
            (b"abc", b"ab"),
            (b"", b"x"),
        ];
        for &(a, b) in &cases {
            assert_eq!(constant_time_eq(a, b), a == b, "a={a:?} b={b:?}");
        }
    }
}