Skip to main content

objectiveai_mcp_proxy/
session_manager.rs

1//! Session registry.
2//!
3//! Maps session ids to [`Session`]s. A proxy session id is the base62
4//! encoding of an authenticated-encrypted, versioned envelope wrapping a
5//! JSON-serialized `IndexMap<upstream_url, IndexMap<header_name,
6//! header_value>>`. Each upstream's value is the full set of HTTP
7//! headers needed to reconnect: `Mcp-Session-Id`, `Authorization`, plus
8//! any custom headers the original initialize request supplied.
9//!
10//! Stable encoding: URLs sort alphabetically; headers within each
11//! per-URL map sort alphabetically, AND the AEAD nonce is derived
12//! deterministically from a BLAKE3 keyed hash of the canonical
13//! plaintext. So two requests with the same `{url → {header → value}}`
14//! content always encode to the *same* base62 id, byte-for-byte. That
15//! lets `handle_initialize`'s alive-in-memory branch hand the original
16//! id straight back to the caller — re-minting was previously producing
17//! a fresh ciphertext (random nonce) that didn't match any key in
18//! `state.sessions`, so the agent's next POST 404'd. Same payload now
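//! always lives at the same id. The Poly1305 tag is computed under a
//! one-time key derived from the AEAD key and the nonce, so tampering
//! with the nonce or the ciphertext produces a decryption failure. The
//! version byte is NOT passed to the AEAD as associated data; instead the
//! decoder explicitly rejects any version byte other than the current one.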
22//!
23//! Wire format (pre-base62):
24//! ```text
25//! [ 1B version (0x01) | 24B XChaCha20 nonce | ciphertext... | 16B Poly1305 tag ]
26//! ```
27//!
28//! Encryption uses one 256-bit key threaded in via
29//! [`SessionManager::new`]. Operators rotate by setting the new key in
30//! `MCP_ENCRYPTION_KEY` and restarting the proxy — every outstanding
31//! session id minted under the old key becomes invalid (a 401 on
32//! resume), which forces clients to re-initialize.
33//!
34//! All per-session dispatch (list, call, read) lives on [`Session`]
35//! itself; this file only cares about computing/minting ids, packing
36//! connections + their canonical headers into a [`Session`], and
37//! looking sessions back up.
38
39use std::sync::Arc;
40
41use base64::Engine;
42use chacha20poly1305::aead::{Aead, KeyInit};
43use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce};
44use dashmap::DashMap;
45use indexmap::IndexMap;
46use objectiveai_sdk::mcp::Connection;
47use rand::RngCore;
48
49use crate::session::Session;
50
51/// Per-session encoded payload: `URL → header_map`. The header_map is
52/// the full set of HTTP headers used to reconnect this upstream
53/// (`Mcp-Session-Id`, `Authorization`, custom `X-*`). The session id is
54/// uniform with every other header — there's no separate session-id
55/// field. URLs sort alphabetically when encoding for stable ids; the
56/// per-URL header map sorts the same way.
57pub type SessionPayload = IndexMap<String, IndexMap<String, String>>;
58
/// Envelope format version, stored as the first byte of every decoded
/// id. Bumping it lets a future envelope shape be told apart from older
/// ids that would otherwise still decrypt under the same key; decoders
/// that see any other value return `None`.
const VERSION: u8 = 1;
/// Length in bytes of the XChaCha20-Poly1305 extended nonce.
const NONCE_LEN: usize = 24;
/// Length in bytes of the Poly1305 tag. `Aead::encrypt` appends it to
/// the ciphertext itself; this constant is only used for length checks.
const TAG_LEN: usize = 16;
66
67/// Maps a session id to its [`Session`] state. Owns the encryption
68/// key used for minting and decoding ids.
69#[derive(Debug)]
70pub struct SessionManager {
71    sessions: DashMap<String, Arc<Session>>,
72    /// 256-bit AEAD key. Sessions minted under one key cannot be
73    /// decrypted by another — to rotate, set a new key on the
74    /// proxy and restart it; outstanding ids become 401s.
75    key: [u8; 32],
76}
77
78impl SessionManager {
79    pub fn new(key: [u8; 32]) -> Self {
80        Self {
81            sessions: DashMap::new(),
82            key,
83        }
84    }
85
86    /// Build a manager with a fresh random 256-bit key. Sessions
87    /// minted by the resulting manager only decode within the same
88    /// process — useful for tests and for operators who haven't yet
89    /// configured `MCP_ENCRYPTION_KEY`.
90    pub fn with_ephemeral_key() -> Self {
91        let mut key = [0u8; 32];
92        rand::rng().fill_bytes(&mut key);
93        Self::new(key)
94    }
95
96    /// Register a session whose id is computed from the per-upstream
97    /// header set. `connections_with_headers` carries each upstream's
98    /// live `Connection` plus the canonical header map that was used
99    /// to open it — `extra_headers` ∪ `Authorization` (if present)
100    /// ∪ `Mcp-Session-Id` (always — that's the upstream sid the
101    /// proxy must replay on resume).
102    ///
103    /// Returns the encoded session id. If the same upstream set is
104    /// re-registered with byte-identical headers, the returned id is
105    /// byte-identical too (modulo the random AEAD nonce, which makes
106    /// the ciphertext different each time — intentional, prevents
107    /// payload-recognition attacks).
108    pub fn add(
109        &self,
110        connections_with_headers: Vec<(Connection, IndexMap<String, String>)>,
111    ) -> String {
112        let payload = build_payload(&connections_with_headers);
113        let id = encrypt_and_encode(&payload, &self.key);
114        let connections: Vec<Connection> =
115            connections_with_headers.into_iter().map(|(c, _)| c).collect();
116        let by_name = build_by_name_map(connections);
117        self.sessions
118            .insert(id.clone(), Arc::new(Session::new(by_name, payload)));
119        id
120    }
121
122    /// Cheap clone-out of a [`Session`] — never holds a DashMap guard
123    /// across the await boundary.
124    pub fn get(&self, session_id: &str) -> Option<Arc<Session>> {
125        self.sessions.get(session_id).map(|e| e.value().clone())
126    }
127
128    /// Remove a session from the registry. Returns `Some(_)` if a session
129    /// was present, `None` if the id was unknown.
130    ///
131    /// Once every `Arc<Session>` to the removed session has dropped, the
132    /// session's `IndexMap<String, Connection>` drops, every `Connection`'s
133    /// `Drop` fires its upstream's wakeup signal, and each upstream's
134    /// listener task wakes to re-check liveness. The listener sees
135    /// `Arc::strong_count == 1` (only itself) and exits, which drops the
136    /// inner state and closes the upstream HTTP session.
137    pub fn remove(&self, session_id: &str) -> Option<Arc<Session>> {
138        self.sessions.remove(session_id).map(|(_, session)| session)
139    }
140
141    /// Decrypt an incoming session id back into the URL → header_map
142    /// payload it encodes. `None` on any decode failure (bad base62,
143    /// unknown version, AEAD failure, bad JSON, wrong shape).
144    pub fn decode_session_id(&self, id: &str) -> Option<SessionPayload> {
145        decode_with_key(id, &self.key)
146    }
147
148    /// Re-mint the encoded id for a payload that's already canonical —
149    /// used by the alive-in-memory branch in `handle_initialize` to
150    /// hand back the same id the client sent (the encrypt step
151    /// produces a different ciphertext each call due to the random
152    /// nonce, so technically the caller will see a fresh id, not the
153    /// byte-equal old one; the new id decrypts to the same payload
154    /// either way).
155    pub fn mint_id(&self, payload: &SessionPayload) -> String {
156        encrypt_and_encode(payload, &self.key)
157    }
158}
159
160/// Build a canonical (url-sorted, header-sorted) `SessionPayload`
161/// from a list of `(Connection, raw_header_map)` pairs.
162///
163/// The raw header map is normalized:
164///   - keys lowercased? **No** — HTTP headers are case-insensitive on
165///     the wire but we keep the casing the upstream sees. Sorting is
166///     done case-sensitively on the bytes; deterministic regardless.
167///   - sorted alphabetically.
168fn build_payload(
169    pairs: &[(Connection, IndexMap<String, String>)],
170) -> SessionPayload {
171    // Collect (url, sorted headers) pairs, then sort by URL.
172    let mut url_entries: Vec<(String, IndexMap<String, String>)> = pairs
173        .iter()
174        .map(|(c, headers)| {
175            let mut sorted: Vec<(&str, &str)> = headers
176                .iter()
177                .map(|(k, v)| (k.as_str(), v.as_str()))
178                .collect();
179            sorted.sort_by(|a, b| a.0.cmp(b.0));
180            let inner: IndexMap<String, String> = sorted
181                .into_iter()
182                .map(|(k, v)| (k.to_string(), v.to_string()))
183                .collect();
184            (c.url.clone(), inner)
185        })
186        .collect();
187    url_entries.sort_by(|a, b| a.0.cmp(&b.0));
188
189    let mut payload: SessionPayload = IndexMap::with_capacity(url_entries.len());
190    for (url, headers) in url_entries {
191        payload.insert(url, headers);
192    }
193    payload
194}
195
196/// JSON-serialize the payload, AEAD-encrypt with a *deterministic*
197/// nonce derived from a BLAKE3 keyed hash of the plaintext, prepend
198/// version + nonce, base62-encode the whole envelope.
199///
200/// Why deterministic: `handle_initialize`'s alive-in-memory branch
201/// needs to mint the same id the caller already holds, so the id
202/// remains a key in `state.sessions`. With a random nonce we'd
203/// generate a different ciphertext for the same payload every time.
204///
205/// Safety of nonce reuse: AEAD only breaks under nonce reuse when the
206/// SAME nonce is paired with TWO DIFFERENT plaintexts. Here the nonce
207/// is a function of the plaintext (and key), so distinct plaintexts
208/// get distinct nonces; identical plaintexts get identical nonces and
209/// identical ciphertexts, which is exactly what we want.
210fn encrypt_and_encode(payload: &SessionPayload, key: &[u8; 32]) -> String {
211    let plaintext =
212        serde_json::to_vec(payload).expect("SessionPayload serializes");
213
214    // Derive the 24-byte XChaCha20 nonce from BLAKE3(key, plaintext).
215    // BLAKE3's keyed hash is a PRF, so this is indistinguishable from
216    // random for any attacker who doesn't know `key`, but it's stable
217    // for the (key, plaintext) pair.
218    let mut hasher = blake3::Hasher::new_keyed(key);
219    hasher.update(&plaintext);
220    let mut nonce_bytes = [0u8; NONCE_LEN];
221    nonce_bytes.copy_from_slice(&hasher.finalize().as_bytes()[..NONCE_LEN]);
222
223    let cipher = XChaCha20Poly1305::new(Key::from_slice(key));
224    let nonce = XNonce::from_slice(&nonce_bytes);
225    let ciphertext_with_tag = cipher
226        .encrypt(nonce, plaintext.as_ref())
227        .expect("XChaCha20-Poly1305 encrypt is infallible for valid key/nonce");
228
229    let mut envelope = Vec::with_capacity(1 + NONCE_LEN + ciphertext_with_tag.len());
230    envelope.push(VERSION);
231    envelope.extend_from_slice(&nonce_bytes);
232    envelope.extend_from_slice(&ciphertext_with_tag);
233    base62_encode_bytes(&envelope)
234}
235
236/// Reverse of [`encrypt_and_encode`]. AEAD failure → `None`.
237fn decode_with_key(id: &str, key: &[u8; 32]) -> Option<SessionPayload> {
238    let envelope = base62_decode_bytes(id)?;
239    if envelope.len() < 1 + NONCE_LEN + TAG_LEN {
240        return None;
241    }
242    if envelope[0] != VERSION {
243        return None;
244    }
245    let nonce = XNonce::from_slice(&envelope[1..1 + NONCE_LEN]);
246    let ciphertext = &envelope[1 + NONCE_LEN..];
247    let cipher = XChaCha20Poly1305::new(Key::from_slice(key));
248    let plaintext = cipher.decrypt(nonce, ciphertext).ok()?;
249    serde_json::from_slice(&plaintext).ok()
250}
251
252/// Parse an `MCP_ENCRYPTION_KEY` env-var value: a single base64-encoded
253/// 32-byte key. Empty string → `None`. Malformed → `Err`.
254pub fn parse_key_env(s: &str) -> Result<Option<[u8; 32]>, String> {
255    let trimmed = s.trim();
256    if trimmed.is_empty() {
257        return Ok(None);
258    }
259    let decoded = base64::engine::general_purpose::STANDARD
260        .decode(trimmed)
261        .map_err(|e| format!("MCP_ENCRYPTION_KEY: not valid base64: {e}"))?;
262    let key: [u8; 32] = decoded.try_into().map_err(|got: Vec<u8>| {
263        format!(
264            "MCP_ENCRYPTION_KEY: expected 32 bytes after base64-decode, got {}",
265            got.len(),
266        )
267    })?;
268    Ok(Some(key))
269}
270
/// Encode an arbitrary byte string as base62 (`0-9`, `a-z`, `A-Z`, in
/// that digit order).
///
/// The off-the-shelf `base62` crate only handles `u128`, and the session
/// envelope is variable-length, so this is a hand-rolled big-integer
/// conversion. The input is read as a big-endian unsigned integer and
/// printed in base 62. Each leading `0x00` byte is emitted as one `'0'`
/// digit so that it survives the round trip.
fn base62_encode_bytes(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 62] =
        b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";

    if bytes.is_empty() {
        return String::new();
    }

    let zero_prefix = bytes.iter().position(|&b| b != 0).unwrap_or(bytes.len());

    // Base-256 limbs (most significant first) of the value after the zero
    // prefix. Each pass divides them by 62 in place and yields one base-62
    // digit (least significant first). `start` skips limbs that have
    // become leading zeros.
    let mut limbs: Vec<u32> = bytes[zero_prefix..].iter().map(|&b| u32::from(b)).collect();
    let mut start = 0;
    let mut rev_digits: Vec<u8> = Vec::with_capacity(bytes.len() * 2);
    while start < limbs.len() {
        let mut carry = 0u32;
        for limb in &mut limbs[start..] {
            let acc = carry * 256 + *limb;
            *limb = acc / 62;
            carry = acc % 62;
        }
        rev_digits.push(carry as u8);
        while start < limbs.len() && limbs[start] == 0 {
            start += 1;
        }
    }

    let mut encoded = String::with_capacity(zero_prefix + rev_digits.len());
    encoded.extend(std::iter::repeat(ALPHABET[0] as char).take(zero_prefix));
    encoded.extend(rev_digits.iter().rev().map(|&d| ALPHABET[usize::from(d)] as char));
    encoded
}
309
/// Inverse of `base62_encode_bytes`: parse a base62 string (`0-9`,
/// `a-z`, `A-Z`) back into bytes.
///
/// Each leading `'0'` becomes one leading `0x00` byte; the remaining
/// digits are converted as a big-endian integer. Returns `None` if any
/// character is outside the base62 alphabet. An empty string decodes to
/// an empty byte vector.
fn base62_decode_bytes(s: &str) -> Option<Vec<u8>> {
    if s.is_empty() {
        return Some(Vec::new());
    }

    let zero_prefix = s.chars().take_while(|&c| c == '0').count();

    // Base-62 limbs (most significant first), rejecting foreign characters.
    let mut limbs: Vec<u32> = s
        .chars()
        .skip(zero_prefix)
        .map(|c| match c {
            '0'..='9' => Some(c as u32 - '0' as u32),
            'a'..='z' => Some(c as u32 - 'a' as u32 + 10),
            'A'..='Z' => Some(c as u32 - 'A' as u32 + 36),
            _ => None,
        })
        .collect::<Option<Vec<u32>>>()?;

    // Repeatedly divide by 256 in place; each remainder is the next byte,
    // least significant first. `start` skips limbs that became zero.
    let mut start = 0;
    let mut rev_bytes: Vec<u8> = Vec::new();
    while start < limbs.len() {
        let mut carry = 0u32;
        for limb in &mut limbs[start..] {
            let acc = carry * 62 + *limb;
            *limb = acc / 256;
            carry = acc % 256;
        }
        rev_bytes.push(carry as u8);
        while start < limbs.len() && limbs[start] == 0 {
            start += 1;
        }
    }

    let mut decoded = vec![0u8; zero_prefix];
    decoded.extend(rev_bytes.into_iter().rev());
    Some(decoded)
}
346
347fn build_by_name_map(
348    connections: Vec<Connection>,
349) -> IndexMap<String, Connection> {
350    // First pass: which names are duplicated? Anything that shows up
351    // more than once in the input gets the `_<index>` suffix.
352    let mut name_counts: std::collections::HashMap<String, usize> =
353        std::collections::HashMap::new();
354    for c in &connections {
355        *name_counts
356            .entry(c.initialize_result.server_info.name.clone())
357            .or_insert(0) += 1;
358    }
359    let mut by_name: IndexMap<String, Connection> =
360        IndexMap::with_capacity(connections.len());
361    for (idx, connection) in connections.into_iter().enumerate() {
362        let raw = connection.initialize_result.server_info.name.clone();
363        let key = if name_counts.get(&raw).copied().unwrap_or(0) > 1 {
364            format!("{raw}_{idx}")
365        } else {
366            raw
367        };
368        if by_name.contains_key(&key) {
369            tracing::warn!(
370                key = %key,
371                "two upstreams produce the same prefix after disambiguation; later upstream wins",
372            );
373        }
374        by_name.insert(key, connection);
375    }
376    by_name
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_payload() -> SessionPayload {
        let mut p: SessionPayload = IndexMap::new();
        let mut h_a: IndexMap<String, String> = IndexMap::new();
        h_a.insert("Authorization".into(), "Bearer secret-A".into());
        h_a.insert("Mcp-Session-Id".into(), "sid-A".into());
        h_a.insert("X-Tenant".into(), "tenant-1".into());
        p.insert("https://upstream-a.example/mcp".into(), h_a);
        let mut h_b: IndexMap<String, String> = IndexMap::new();
        h_b.insert("Mcp-Session-Id".into(), "sid-B".into());
        p.insert("https://upstream-b.example/mcp".into(), h_b);
        p
    }

    #[test]
    fn base62_round_trip() {
        for sample in [
            &b""[..],
            &b"a"[..],
            &b"\x00\x01\x02"[..],
            &b"hello world"[..],
            br#"{"http://127.0.0.1:1234":"abc123"}"#,
            &(0..=255u16).map(|b| b as u8).collect::<Vec<_>>()[..],
        ] {
            let encoded = base62_encode_bytes(sample);
            // Output must stay inside the 62-symbol alphabet, not merely
            // be printable ASCII.
            assert!(encoded.bytes().all(|b| b.is_ascii_alphanumeric()));
            let decoded = base62_decode_bytes(&encoded).expect("decode");
            assert_eq!(decoded, sample, "round-trip failed for {sample:?}");
        }
    }

    #[test]
    fn encrypt_decrypt_round_trip() {
        let key = [0x42u8; 32];
        let payload = sample_payload();
        let id = encrypt_and_encode(&payload, &key);
        let decoded = decode_with_key(&id, &key).expect("decode under same key");
        assert_eq!(decoded, payload);
    }

    #[test]
    fn encoding_is_deterministic() {
        // The whole design (re-minting the id a client already holds)
        // depends on identical payloads encoding to identical ids.
        let key = [0x42u8; 32];
        let payload = sample_payload();
        assert_eq!(
            encrypt_and_encode(&payload, &key),
            encrypt_and_encode(&payload, &key),
        );
    }

    #[test]
    fn distinct_payloads_get_distinct_ids() {
        let key = [0x42u8; 32];
        let a = sample_payload();
        let mut b = sample_payload();
        b.get_mut("https://upstream-b.example/mcp")
            .expect("sample has upstream b")
            .insert("Mcp-Session-Id".into(), "sid-C".into());
        assert_ne!(encrypt_and_encode(&a, &key), encrypt_and_encode(&b, &key));
    }

    #[test]
    fn decode_with_wrong_key_returns_none() {
        let key_a = [0x11u8; 32];
        let key_b = [0x22u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key_a);
        assert!(decode_with_key(&id, &key_b).is_none());
    }

    #[test]
    fn tampered_id_returns_none() {
        let key = [0x42u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key);
        // Changing the least-significant digit alters the trailing tag
        // bytes, which must fail authentication.
        let mut chars: Vec<char> = id.chars().collect();
        let last = chars.last_mut().expect("non-empty id");
        *last = if *last == '0' { '1' } else { '0' };
        let tampered: String = chars.into_iter().collect();
        assert_ne!(tampered, id);
        assert!(decode_with_key(&tampered, &key).is_none());
    }

    #[test]
    fn decode_garbage_returns_none() {
        let key = [0x55u8; 32];
        // Random base62 string, certainly not a valid envelope.
        assert!(decode_with_key("ABCdef123", &key).is_none());
        // Empty.
        assert!(decode_with_key("", &key).is_none());
        // Too short to even hold version + nonce + tag.
        assert!(decode_with_key("0", &key).is_none());
        // Characters outside the base62 alphabet.
        assert!(decode_with_key("not-base62!", &key).is_none());
    }

    #[test]
    fn manager_mint_id_is_stable_and_decodes() {
        let manager = SessionManager::new([0x07u8; 32]);
        let payload = sample_payload();
        let id = manager.mint_id(&payload);
        assert_eq!(manager.mint_id(&payload), id);

        let decoded = manager.decode_session_id(&id).expect("decode");
        assert_eq!(decoded, payload);
        // Re-minting a decoded payload must reproduce the original id
        // byte-for-byte; this is what `handle_initialize` relies on.
        assert_eq!(manager.mint_id(&decoded), id);

        // Minting alone never registers a session.
        assert!(manager.get(&id).is_none());
        assert!(manager.remove(&id).is_none());
    }

    #[test]
    fn payload_roundtrip_preserves_canonical_order() {
        // Build "the same" payload with shuffled URL and header order;
        // after canonicalization they should be equal byte-for-byte.
        let conn_a_url = "https://b.example/mcp".to_string();
        let conn_b_url = "https://a.example/mcp".to_string();

        let mut h_unsorted: IndexMap<String, String> = IndexMap::new();
        h_unsorted.insert("Z-Header".into(), "z".into());
        h_unsorted.insert("Authorization".into(), "Bearer".into());

        // We can't easily synthesize Connection without spinning a
        // real server, so this mirrors build_payload's canonicalization
        // inline rather than calling it. Keep the two in sync.
        let pairs_unsorted: Vec<(String, IndexMap<String, String>)> = vec![
            (conn_a_url.clone(), h_unsorted.clone()),
            (conn_b_url.clone(), h_unsorted.clone()),
        ];

        let mut payload: SessionPayload = IndexMap::new();
        let mut url_entries: Vec<(String, IndexMap<String, String>)> = pairs_unsorted
            .into_iter()
            .map(|(url, headers)| {
                let mut sorted: Vec<(&str, &str)> =
                    headers.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
                sorted.sort_by(|a, b| a.0.cmp(b.0));
                let inner: IndexMap<String, String> = sorted
                    .into_iter()
                    .map(|(k, v)| (k.to_string(), v.to_string()))
                    .collect();
                (url, inner)
            })
            .collect();
        url_entries.sort_by(|a, b| a.0.cmp(&b.0));
        for (u, h) in url_entries {
            payload.insert(u, h);
        }

        let urls: Vec<&String> = payload.keys().collect();
        assert_eq!(urls, vec![&conn_b_url, &conn_a_url]); // a.example before b.example
        let inner = &payload[&conn_b_url];
        let inner_keys: Vec<&String> = inner.keys().collect();
        assert_eq!(inner_keys, vec!["Authorization", "Z-Header"]); // alphabetical
    }

    #[test]
    fn parse_key_env_round_trip() {
        let key = [0xAAu8; 32];
        let env = base64::engine::general_purpose::STANDARD.encode(key);
        let parsed = parse_key_env(&env).expect("parse").expect("Some");
        assert_eq!(parsed, key);

        assert!(parse_key_env("").unwrap().is_none());
        assert!(parse_key_env("   ").unwrap().is_none());
        assert!(parse_key_env("not-base64!@#").is_err());
        // Wrong-length payload (16 bytes after b64 decode):
        let short =
            base64::engine::general_purpose::STANDARD.encode(&[0u8; 16][..]);
        assert!(parse_key_env(&short).is_err());
    }
}