Skip to main content

objectiveai_mcp_proxy/
session_manager.rs

1//! Session registry.
2//!
3//! Maps session ids to [`Session`]s. A proxy session id is the base62
4//! encoding of an authenticated-encrypted, versioned envelope wrapping a
5//! JSON-serialized `IndexMap<upstream_url, IndexMap<header_name,
6//! header_value>>`. Each upstream's value is the full set of HTTP
7//! headers needed to reconnect: `Mcp-Session-Id`, `Authorization`, plus
8//! any custom headers the original initialize request supplied.
9//!
10//! Stable encoding: URLs sort alphabetically; headers within each
11//! per-URL map sort alphabetically, AND the AEAD nonce is derived
12//! deterministically from a BLAKE3 keyed hash of the canonical
13//! plaintext. So two requests with the same `{url → {header → value}}`
14//! content always encode to the *same* base62 id, byte-for-byte. That
15//! lets `handle_initialize`'s alive-in-memory branch hand the original
16//! id straight back to the caller — re-minting was previously producing
17//! a fresh ciphertext (random nonce) that didn't match any key in
18//! `state.sessions`, so the agent's next POST 404'd. Same payload now
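//! always lives at the same id. The Poly1305 tag authenticates the
//! nonce and ciphertext, so altering either fails decryption. The version
//! byte is not passed to the AEAD as associated data, but decoding rejects
//! any version other than the current one, so any tampering still produces
//! a decode failure.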
22//!
23//! Wire format (pre-base62):
24//! ```text
25//! [ 1B version (0x01) | 24B XChaCha20 nonce | ciphertext... | 16B Poly1305 tag ]
26//! ```
27//!
28//! Encryption uses one 256-bit key threaded in via
29//! [`SessionManager::new`]. Operators rotate by setting the new key in
30//! `MCP_ENCRYPTION_KEY` and restarting the proxy — every outstanding
31//! session id minted under the old key becomes invalid (a 401 on
32//! resume), which forces clients to re-initialize.
33//!
34//! All per-session dispatch (list, call, read) lives on [`Session`]
35//! itself; this file only cares about computing/minting ids, packing
36//! connections + their canonical headers into a [`Session`], and
37//! looking sessions back up.
38
39use std::sync::Arc;
40
41use base64::Engine;
42use chacha20poly1305::aead::{Aead, KeyInit};
43use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce};
44use dashmap::DashMap;
45use indexmap::IndexMap;
46use objectiveai_sdk::mcp::Connection;
47use rand::RngCore;
48
49use crate::session::Session;
50
51/// Per-session encoded payload.
52///
53/// `connections` is `URL → header_map`, where the header_map is the
54/// full set of HTTP headers used to reconnect that upstream
55/// (`Mcp-Session-Id`, `Authorization`, custom `X-*`). The session id is
56/// uniform with every other header — there's no separate session-id
57/// field. URLs sort alphabetically when encoding for stable ids; the
58/// per-URL header map sorts the same way.
59///
60/// Agent identity (`X-OBJECTIVEAI-AGENT-*`) and routing keys
61/// (`X-OBJECTIVEAI-RESPONSE-ID`, `X-OBJECTIVEAI-RESPONSE-IDS`) are
62/// NOT in this payload — they live on `Session::transient_headers`
63/// in memory only, re-extracted from the inbound HeaderMap on every
64/// `initialize` and full-replaced. The session id encodes only the
65/// upstream-connection topology that needs to survive proxy restart.
66#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
67pub struct SessionPayload {
68    pub connections: IndexMap<String, IndexMap<String, String>>,
69}
70
/// Envelope version byte, written as the first byte of every id.
///
/// Bumping it lets future shape changes be told apart from old ids that
/// still decrypt under the same key. Decoders reject any value they do not
/// recognize by returning `None`.
const VERSION: u8 = 0x01;

/// Length in bytes of an XChaCha20-Poly1305 nonce.
const NONCE_LEN: usize = 24;

/// Length in bytes of a Poly1305 tag. `Aead::encrypt` appends the tag
/// itself; this constant is used to reject envelopes too short to hold one.
const TAG_LEN: usize = 16;
78
79/// Maps a session id to its [`Session`] state. Owns the encryption
80/// key used for minting and decoding ids.
81#[derive(Debug)]
82pub struct SessionManager {
83    sessions: DashMap<String, Arc<Session>>,
84    /// 256-bit AEAD key. Sessions minted under one key cannot be
85    /// decrypted by another — to rotate, set a new key on the
86    /// proxy and restart it; outstanding ids become 401s.
87    key: [u8; 32],
88}
89
90impl SessionManager {
91    pub fn new(key: [u8; 32]) -> Self {
92        Self {
93            sessions: DashMap::new(),
94            key,
95        }
96    }
97
98    /// Build a manager with a fresh random 256-bit key. Sessions
99    /// minted by the resulting manager only decode within the same
100    /// process — useful for tests and for operators who haven't yet
101    /// configured `MCP_ENCRYPTION_KEY`.
102    pub fn with_ephemeral_key() -> Self {
103        let mut key = [0u8; 32];
104        rand::rng().fill_bytes(&mut key);
105        Self::new(key)
106    }
107
108    /// Register a session whose id is computed from the per-upstream
109    /// header set. `connections_with_headers` carries each upstream's
110    /// live `Connection` plus the canonical header map that was used
111    /// to open it — `extra_headers` ∪ `Authorization` (if present)
112    /// ∪ `Mcp-Session-Id` (always — that's the upstream sid the
113    /// proxy must replay on resume).
114    ///
115    /// Returns the encoded session id. If the same upstream set is
116    /// re-registered with byte-identical headers, the returned id is
117    /// byte-identical too (modulo the random AEAD nonce, which makes
118    /// the ciphertext different each time — intentional, prevents
119    /// payload-recognition attacks).
120    pub fn add(
121        &self,
122        connections_with_headers: Vec<(Connection, IndexMap<String, String>)>,
123    ) -> String {
124        let payload = build_payload(&connections_with_headers);
125        let id = encrypt_and_encode(&payload, &self.key);
126        let connections: Vec<Connection> =
127            connections_with_headers.into_iter().map(|(c, _)| c).collect();
128        let by_name = build_by_name_map(connections);
129        self.sessions
130            .insert(id.clone(), Arc::new(Session::new(by_name, payload)));
131        id
132    }
133
134    /// Cheap clone-out of a [`Session`] — never holds a DashMap guard
135    /// across the await boundary.
136    pub fn get(&self, session_id: &str) -> Option<Arc<Session>> {
137        self.sessions.get(session_id).map(|e| e.value().clone())
138    }
139
140    /// Remove a session from the registry. Returns `Some(_)` if a session
141    /// was present, `None` if the id was unknown.
142    ///
143    /// Once every `Arc<Session>` to the removed session has dropped, the
144    /// session's `IndexMap<String, Connection>` drops, every `Connection`'s
145    /// `Drop` fires its upstream's wakeup signal, and each upstream's
146    /// listener task wakes to re-check liveness. The listener sees
147    /// `Arc::strong_count == 1` (only itself) and exits, which drops the
148    /// inner state and closes the upstream HTTP session.
149    pub fn remove(&self, session_id: &str) -> Option<Arc<Session>> {
150        self.sessions.remove(session_id).map(|(_, session)| session)
151    }
152
153    /// Decrypt an incoming session id back into the URL → header_map
154    /// payload it encodes. `None` on any decode failure (bad base62,
155    /// unknown version, AEAD failure, bad JSON, wrong shape).
156    pub fn decode_session_id(&self, id: &str) -> Option<SessionPayload> {
157        decode_with_key(id, &self.key)
158    }
159
160    /// Re-mint the encoded id for a payload that's already canonical —
161    /// used by the alive-in-memory branch in `handle_initialize` to
162    /// hand back the same id the client sent (the encrypt step
163    /// produces a different ciphertext each call due to the random
164    /// nonce, so technically the caller will see a fresh id, not the
165    /// byte-equal old one; the new id decrypts to the same payload
166    /// either way).
167    pub fn mint_id(&self, payload: &SessionPayload) -> String {
168        encrypt_and_encode(payload, &self.key)
169    }
170}
171
172/// Build a canonical (url-sorted, header-sorted) `SessionPayload`
173/// from a list of `(Connection, raw_header_map)` pairs.
174///
175/// The raw header map is normalized:
176///   - keys lowercased? **No** — HTTP headers are case-insensitive on
177///     the wire but we keep the casing the upstream sees. Sorting is
178///     done case-sensitively on the bytes; deterministic regardless.
179///   - sorted alphabetically.
180fn build_payload(
181    pairs: &[(Connection, IndexMap<String, String>)],
182) -> SessionPayload {
183    // Collect (url, sorted headers) pairs, then sort by URL.
184    let mut url_entries: Vec<(String, IndexMap<String, String>)> = pairs
185        .iter()
186        .map(|(c, headers)| {
187            let mut sorted: Vec<(&str, &str)> = headers
188                .iter()
189                .map(|(k, v)| (k.as_str(), v.as_str()))
190                .collect();
191            sorted.sort_by(|a, b| a.0.cmp(b.0));
192            let inner: IndexMap<String, String> = sorted
193                .into_iter()
194                .map(|(k, v)| (k.to_string(), v.to_string()))
195                .collect();
196            (c.url.clone(), inner)
197        })
198        .collect();
199    url_entries.sort_by(|a, b| a.0.cmp(&b.0));
200
201    let mut connections: IndexMap<String, IndexMap<String, String>> =
202        IndexMap::with_capacity(url_entries.len());
203    for (url, headers) in url_entries {
204        connections.insert(url, headers);
205    }
206
207    SessionPayload { connections }
208}
209
210/// JSON-serialize the payload, AEAD-encrypt with a *deterministic*
211/// nonce derived from a BLAKE3 keyed hash of the plaintext, prepend
212/// version + nonce, base62-encode the whole envelope.
213///
214/// Why deterministic: `handle_initialize`'s alive-in-memory branch
215/// needs to mint the same id the caller already holds, so the id
216/// remains a key in `state.sessions`. With a random nonce we'd
217/// generate a different ciphertext for the same payload every time.
218///
219/// Safety of nonce reuse: AEAD only breaks under nonce reuse when the
220/// SAME nonce is paired with TWO DIFFERENT plaintexts. Here the nonce
221/// is a function of the plaintext (and key), so distinct plaintexts
222/// get distinct nonces; identical plaintexts get identical nonces and
223/// identical ciphertexts, which is exactly what we want.
224fn encrypt_and_encode(payload: &SessionPayload, key: &[u8; 32]) -> String {
225    let plaintext =
226        serde_json::to_vec(payload).expect("SessionPayload serializes");
227
228    // Derive the 24-byte XChaCha20 nonce from BLAKE3(key, plaintext).
229    // BLAKE3's keyed hash is a PRF, so this is indistinguishable from
230    // random for any attacker who doesn't know `key`, but it's stable
231    // for the (key, plaintext) pair.
232    let mut hasher = blake3::Hasher::new_keyed(key);
233    hasher.update(&plaintext);
234    let mut nonce_bytes = [0u8; NONCE_LEN];
235    nonce_bytes.copy_from_slice(&hasher.finalize().as_bytes()[..NONCE_LEN]);
236
237    let cipher = XChaCha20Poly1305::new(Key::from_slice(key));
238    let nonce = XNonce::from_slice(&nonce_bytes);
239    let ciphertext_with_tag = cipher
240        .encrypt(nonce, plaintext.as_ref())
241        .expect("XChaCha20-Poly1305 encrypt is infallible for valid key/nonce");
242
243    let mut envelope = Vec::with_capacity(1 + NONCE_LEN + ciphertext_with_tag.len());
244    envelope.push(VERSION);
245    envelope.extend_from_slice(&nonce_bytes);
246    envelope.extend_from_slice(&ciphertext_with_tag);
247    base62_encode_bytes(&envelope)
248}
249
250/// Reverse of [`encrypt_and_encode`]. AEAD failure → `None`.
251fn decode_with_key(id: &str, key: &[u8; 32]) -> Option<SessionPayload> {
252    let envelope = base62_decode_bytes(id)?;
253    if envelope.len() < 1 + NONCE_LEN + TAG_LEN {
254        return None;
255    }
256    if envelope[0] != VERSION {
257        return None;
258    }
259    let nonce = XNonce::from_slice(&envelope[1..1 + NONCE_LEN]);
260    let ciphertext = &envelope[1 + NONCE_LEN..];
261    let cipher = XChaCha20Poly1305::new(Key::from_slice(key));
262    let plaintext = cipher.decrypt(nonce, ciphertext).ok()?;
263    serde_json::from_slice(&plaintext).ok()
264}
265
266/// Parse an `MCP_ENCRYPTION_KEY` env-var value: a single base64-encoded
267/// 32-byte key. Empty string → `None`. Malformed → `Err`.
268pub fn parse_key_env(s: &str) -> Result<Option<[u8; 32]>, String> {
269    let trimmed = s.trim();
270    if trimmed.is_empty() {
271        return Ok(None);
272    }
273    let decoded = base64::engine::general_purpose::STANDARD
274        .decode(trimmed)
275        .map_err(|e| format!("MCP_ENCRYPTION_KEY: not valid base64: {e}"))?;
276    let key: [u8; 32] = decoded.try_into().map_err(|got: Vec<u8>| {
277        format!(
278            "MCP_ENCRYPTION_KEY: expected 32 bytes after base64-decode, got {}",
279            got.len(),
280        )
281    })?;
282    Ok(Some(key))
283}
284
/// Byte-level base62. The off-the-shelf `base62` crate only encodes
/// `u128`s, but the envelope is variable-length. So `bytes` is treated as
/// one big-endian unsigned integer and written with the digits
/// `0..9 a..z A..Z`, most significant first. Each leading zero byte is
/// emitted as a single `0` digit so it survives the round-trip through
/// `base62_decode_bytes`. Empty input encodes to the empty string.
fn base62_encode_bytes(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 62] =
        b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
    if bytes.is_empty() {
        return String::new();
    }

    let zero_prefix = bytes.iter().position(|&b| b != 0).unwrap_or(bytes.len());

    // Schoolbook long division by 62, done in place on base-256 limbs.
    // `head` marks the first limb that is still non-zero; limbs before it
    // have been fully consumed.
    let mut limbs: Vec<u32> = bytes[zero_prefix..].iter().map(|&b| u32::from(b)).collect();
    let mut head = 0;
    let mut digits_lsb_first: Vec<u8> = Vec::with_capacity(bytes.len() * 2);
    while head < limbs.len() {
        let mut remainder = 0u32;
        for limb in limbs[head..].iter_mut() {
            let acc = remainder * 256 + *limb;
            *limb = acc / 62;
            remainder = acc % 62;
        }
        digits_lsb_first.push(remainder as u8);
        while limbs.get(head) == Some(&0) {
            head += 1;
        }
    }

    let mut encoded = String::with_capacity(zero_prefix + digits_lsb_first.len());
    encoded.extend(std::iter::repeat(ALPHABET[0] as char).take(zero_prefix));
    encoded.extend(
        digits_lsb_first
            .iter()
            .rev()
            .map(|&d| ALPHABET[usize::from(d)] as char),
    );
    encoded
}
323
/// Inverse of `base62_encode_bytes`.
///
/// Each leading `'0'` character becomes one zero byte. The remaining
/// digits (`0..9 a..z A..Z`) are read as a big-endian base62 integer and
/// emitted as its minimal big-endian byte string. Returns `None` if any
/// character is outside the alphabet. The empty string decodes to an
/// empty vector.
fn base62_decode_bytes(s: &str) -> Option<Vec<u8>> {
    if s.is_empty() {
        return Some(Vec::new());
    }

    let zero_prefix = s.chars().take_while(|&c| c == '0').count();
    let mut limbs: Vec<u32> = s
        .chars()
        .skip(zero_prefix)
        .map(|c| match c {
            '0'..='9' => Some(c as u32 - '0' as u32),
            'a'..='z' => Some(c as u32 - 'a' as u32 + 10),
            'A'..='Z' => Some(c as u32 - 'A' as u32 + 36),
            _ => None,
        })
        .collect::<Option<Vec<u32>>>()?;

    // Long division by 256, in place on base-62 limbs; `head` skips the
    // limbs that have already dropped to zero.
    let mut bytes_lsb_first: Vec<u8> = Vec::new();
    let mut head = 0;
    while head < limbs.len() {
        let mut remainder = 0u32;
        for limb in limbs[head..].iter_mut() {
            let acc = remainder * 62 + *limb;
            *limb = acc / 256;
            remainder = acc % 256;
        }
        bytes_lsb_first.push(remainder as u8);
        while limbs.get(head) == Some(&0) {
            head += 1;
        }
    }

    let mut decoded = vec![0u8; zero_prefix];
    decoded.extend(bytes_lsb_first.into_iter().rev());
    Some(decoded)
}
360
361fn build_by_name_map(
362    connections: Vec<Connection>,
363) -> IndexMap<String, Connection> {
364    // First pass: which names are duplicated? Anything that shows up
365    // more than once in the input gets the `_<index>` suffix.
366    let mut name_counts: std::collections::HashMap<String, usize> =
367        std::collections::HashMap::new();
368    for c in &connections {
369        *name_counts
370            .entry(c.initialize_result.server_info.name.clone())
371            .or_insert(0) += 1;
372    }
373    let mut by_name: IndexMap<String, Connection> =
374        IndexMap::with_capacity(connections.len());
375    for (idx, connection) in connections.into_iter().enumerate() {
376        let raw = connection.initialize_result.server_info.name.clone();
377        let key = if name_counts.get(&raw).copied().unwrap_or(0) > 1 {
378            format!("{raw}_{idx}")
379        } else {
380            raw
381        };
382        if by_name.contains_key(&key) {
383            tracing::warn!(
384                key = %key,
385                "two upstreams produce the same prefix after disambiguation; later upstream wins",
386            );
387        }
388        by_name.insert(key, connection);
389    }
390    by_name
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_payload() -> SessionPayload {
        let mut connections: IndexMap<String, IndexMap<String, String>> = IndexMap::new();
        let mut h_a: IndexMap<String, String> = IndexMap::new();
        h_a.insert("Authorization".into(), "Bearer secret-A".into());
        h_a.insert("Mcp-Session-Id".into(), "sid-A".into());
        h_a.insert("X-Tenant".into(), "tenant-1".into());
        connections.insert("https://upstream-a.example/mcp".into(), h_a);
        let mut h_b: IndexMap<String, String> = IndexMap::new();
        h_b.insert("Mcp-Session-Id".into(), "sid-B".into());
        connections.insert("https://upstream-b.example/mcp".into(), h_b);
        SessionPayload { connections }
    }

    #[test]
    fn base62_round_trip() {
        for sample in [
            &b""[..],
            &b"a"[..],
            &b"\x00\x01\x02"[..],
            &b"hello world"[..],
            br#"{"http://127.0.0.1:1234":"abc123"}"#,
            &(0..=255u16).map(|b| b as u8).collect::<Vec<_>>()[..],
        ] {
            let encoded = base62_encode_bytes(sample);
            // Strictly alphanumeric ASCII: ids must be usable in headers
            // and URLs without any escaping.
            assert!(
                encoded.bytes().all(|b| b.is_ascii_alphanumeric()),
                "non-base62 output for {sample:?}: {encoded}",
            );
            let decoded = base62_decode_bytes(&encoded).expect("decode");
            assert_eq!(decoded, sample, "round-trip failed for {sample:?}");
        }
    }

    #[test]
    fn base62_rejects_non_alphabet_chars() {
        assert!(base62_decode_bytes("abc-def").is_none());
        assert!(base62_decode_bytes("0+").is_none());
    }

    #[test]
    fn encrypt_decrypt_round_trip() {
        let key = [0x42u8; 32];
        let payload = sample_payload();
        let id = encrypt_and_encode(&payload, &key);
        let decoded = decode_with_key(&id, &key).expect("decode under same key");
        assert_eq!(decoded, payload);
    }

    #[test]
    fn encrypt_is_deterministic_per_payload_and_key() {
        // The alive-in-memory branch of `handle_initialize` relies on the
        // same payload + key always minting the same id.
        let key = [0x42u8; 32];
        let payload = sample_payload();
        let first = encrypt_and_encode(&payload, &key);
        let second = encrypt_and_encode(&payload, &key);
        assert_eq!(first, second, "same payload + key must mint the same id");

        let mut other = payload.clone();
        other
            .connections
            .insert("https://upstream-c.example/mcp".into(), IndexMap::new());
        assert_ne!(encrypt_and_encode(&other, &key), first);
        assert_ne!(encrypt_and_encode(&payload, &[0x43u8; 32]), first);
    }

    #[test]
    fn decode_with_wrong_key_returns_none() {
        let key_a = [0x11u8; 32];
        let key_b = [0x22u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key_a);
        assert!(decode_with_key(&id, &key_b).is_none());
    }

    #[test]
    fn decode_rejects_tampered_envelope() {
        let key = [0x33u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key);
        let envelope = base62_decode_bytes(&id).expect("minted ids are valid base62");

        // Unknown version byte.
        let mut bad_version = envelope.clone();
        bad_version[0] ^= 0xFF;
        assert!(decode_with_key(&base62_encode_bytes(&bad_version), &key).is_none());

        // Flipped nonce bit.
        let mut bad_nonce = envelope.clone();
        bad_nonce[1] ^= 0x01;
        assert!(decode_with_key(&base62_encode_bytes(&bad_nonce), &key).is_none());

        // Flipped bit in the trailing Poly1305 tag.
        let mut bad_tag = envelope.clone();
        *bad_tag.last_mut().expect("non-empty envelope") ^= 0x01;
        assert!(decode_with_key(&base62_encode_bytes(&bad_tag), &key).is_none());

        // Untouched envelope still decodes.
        assert!(decode_with_key(&base62_encode_bytes(&envelope), &key).is_some());
    }

    #[test]
    fn decode_garbage_returns_none() {
        let key = [0x55u8; 32];
        // Random base62 string, certainly not a valid envelope.
        assert!(decode_with_key("ABCdef123", &key).is_none());
        // Empty.
        assert!(decode_with_key("", &key).is_none());
        // Too short to even hold version + nonce + tag.
        assert!(decode_with_key("0", &key).is_none());
    }

    #[test]
    fn payload_roundtrip_preserves_canonical_order() {
        // Build "the same" payload with shuffled URL and header order;
        // after `build_payload` they should be equal byte-for-byte.
        let conn_a_url = "https://b.example/mcp".to_string();
        let conn_b_url = "https://a.example/mcp".to_string();

        let mut h_unsorted: IndexMap<String, String> = IndexMap::new();
        h_unsorted.insert("Z-Header".into(), "z".into());
        h_unsorted.insert("Authorization".into(), "Bearer".into());

        // We can't easily synthesize Connection without spinning a
        // real server, so test build_payload's canonicalization
        // through an inline helper that mirrors what add() builds.
        let pairs_unsorted: Vec<(String, IndexMap<String, String>)> =
            vec![(conn_a_url.clone(), h_unsorted.clone()), (conn_b_url.clone(), h_unsorted.clone())];

        let mut connections: IndexMap<String, IndexMap<String, String>> = IndexMap::new();
        let mut url_entries: Vec<(String, IndexMap<String, String>)> = pairs_unsorted
            .into_iter()
            .map(|(url, headers)| {
                let mut sorted: Vec<(&str, &str)> =
                    headers.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
                sorted.sort_by(|a, b| a.0.cmp(b.0));
                let inner: IndexMap<String, String> = sorted
                    .into_iter()
                    .map(|(k, v)| (k.to_string(), v.to_string()))
                    .collect();
                (url, inner)
            })
            .collect();
        url_entries.sort_by(|a, b| a.0.cmp(&b.0));
        for (u, h) in url_entries {
            connections.insert(u, h);
        }
        let payload = SessionPayload { connections };

        let urls: Vec<&String> = payload.connections.keys().collect();
        assert_eq!(urls, vec![&conn_b_url, &conn_a_url]); // a.example before b.example
        let inner = &payload.connections[&conn_b_url];
        let inner_keys: Vec<&String> = inner.keys().collect();
        assert_eq!(inner_keys, vec!["Authorization", "Z-Header"]); // alphabetical

        // The order must also survive encrypt → decode. `IndexMap`'s
        // `PartialEq` ignores order, so compare the key sequences directly.
        let key = [0x77u8; 32];
        let decoded = decode_with_key(&encrypt_and_encode(&payload, &key), &key)
            .expect("decode under same key");
        let decoded_urls: Vec<&String> = decoded.connections.keys().collect();
        assert_eq!(decoded_urls, vec![&conn_b_url, &conn_a_url]);
        let decoded_inner: Vec<&String> = decoded.connections[&conn_b_url].keys().collect();
        assert_eq!(decoded_inner, vec!["Authorization", "Z-Header"]);
    }

    #[test]
    fn parse_key_env_round_trip() {
        let key = [0xAAu8; 32];
        let env = base64::engine::general_purpose::STANDARD.encode(key);
        let parsed = parse_key_env(&env).expect("parse").expect("Some");
        assert_eq!(parsed, key);

        assert!(parse_key_env("").unwrap().is_none());
        assert!(parse_key_env("   ").unwrap().is_none());
        assert!(parse_key_env("not-base64!@#").is_err());
        // Wrong-length payload (16 bytes after b64 decode):
        let short =
            base64::engine::general_purpose::STANDARD.encode(&[0u8; 16][..]);
        assert!(parse_key_env(&short).is_err());
    }
}