1use std::sync::Arc;
40
41use base64::Engine;
42use chacha20poly1305::aead::{Aead, KeyInit};
43use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce};
44use dashmap::DashMap;
45use indexmap::IndexMap;
46use objectiveai_sdk::mcp::Connection;
47use rand::RngCore;
48
49use crate::session::Session;
50
51pub type SessionPayload = IndexMap<String, IndexMap<String, String>>;
58
59const VERSION: u8 = 0x01;
64const NONCE_LEN: usize = 24; const TAG_LEN: usize = 16; #[derive(Debug)]
70pub struct SessionManager {
71 sessions: DashMap<String, Arc<Session>>,
72 key: [u8; 32],
76}
77
78impl SessionManager {
79 pub fn new(key: [u8; 32]) -> Self {
80 Self {
81 sessions: DashMap::new(),
82 key,
83 }
84 }
85
86 pub fn with_ephemeral_key() -> Self {
91 let mut key = [0u8; 32];
92 rand::rng().fill_bytes(&mut key);
93 Self::new(key)
94 }
95
96 pub fn add(
109 &self,
110 connections_with_headers: Vec<(Connection, IndexMap<String, String>)>,
111 ) -> String {
112 let payload = build_payload(&connections_with_headers);
113 let id = encrypt_and_encode(&payload, &self.key);
114 let connections: Vec<Connection> =
115 connections_with_headers.into_iter().map(|(c, _)| c).collect();
116 let by_name = build_by_name_map(connections);
117 self.sessions
118 .insert(id.clone(), Arc::new(Session::new(by_name, payload)));
119 id
120 }
121
122 pub fn get(&self, session_id: &str) -> Option<Arc<Session>> {
125 self.sessions.get(session_id).map(|e| e.value().clone())
126 }
127
128 pub fn remove(&self, session_id: &str) -> Option<Arc<Session>> {
138 self.sessions.remove(session_id).map(|(_, session)| session)
139 }
140
141 pub fn decode_session_id(&self, id: &str) -> Option<SessionPayload> {
145 decode_with_key(id, &self.key)
146 }
147
148 pub fn mint_id(&self, payload: &SessionPayload) -> String {
156 encrypt_and_encode(payload, &self.key)
157 }
158}
159
160fn build_payload(
169 pairs: &[(Connection, IndexMap<String, String>)],
170) -> SessionPayload {
171 let mut url_entries: Vec<(String, IndexMap<String, String>)> = pairs
173 .iter()
174 .map(|(c, headers)| {
175 let mut sorted: Vec<(&str, &str)> = headers
176 .iter()
177 .map(|(k, v)| (k.as_str(), v.as_str()))
178 .collect();
179 sorted.sort_by(|a, b| a.0.cmp(b.0));
180 let inner: IndexMap<String, String> = sorted
181 .into_iter()
182 .map(|(k, v)| (k.to_string(), v.to_string()))
183 .collect();
184 (c.url.clone(), inner)
185 })
186 .collect();
187 url_entries.sort_by(|a, b| a.0.cmp(&b.0));
188
189 let mut payload: SessionPayload = IndexMap::with_capacity(url_entries.len());
190 for (url, headers) in url_entries {
191 payload.insert(url, headers);
192 }
193 payload
194}
195
196fn encrypt_and_encode(payload: &SessionPayload, key: &[u8; 32]) -> String {
211 let plaintext =
212 serde_json::to_vec(payload).expect("SessionPayload serializes");
213
214 let mut hasher = blake3::Hasher::new_keyed(key);
219 hasher.update(&plaintext);
220 let mut nonce_bytes = [0u8; NONCE_LEN];
221 nonce_bytes.copy_from_slice(&hasher.finalize().as_bytes()[..NONCE_LEN]);
222
223 let cipher = XChaCha20Poly1305::new(Key::from_slice(key));
224 let nonce = XNonce::from_slice(&nonce_bytes);
225 let ciphertext_with_tag = cipher
226 .encrypt(nonce, plaintext.as_ref())
227 .expect("XChaCha20-Poly1305 encrypt is infallible for valid key/nonce");
228
229 let mut envelope = Vec::with_capacity(1 + NONCE_LEN + ciphertext_with_tag.len());
230 envelope.push(VERSION);
231 envelope.extend_from_slice(&nonce_bytes);
232 envelope.extend_from_slice(&ciphertext_with_tag);
233 base62_encode_bytes(&envelope)
234}
235
236fn decode_with_key(id: &str, key: &[u8; 32]) -> Option<SessionPayload> {
238 let envelope = base62_decode_bytes(id)?;
239 if envelope.len() < 1 + NONCE_LEN + TAG_LEN {
240 return None;
241 }
242 if envelope[0] != VERSION {
243 return None;
244 }
245 let nonce = XNonce::from_slice(&envelope[1..1 + NONCE_LEN]);
246 let ciphertext = &envelope[1 + NONCE_LEN..];
247 let cipher = XChaCha20Poly1305::new(Key::from_slice(key));
248 let plaintext = cipher.decrypt(nonce, ciphertext).ok()?;
249 serde_json::from_slice(&plaintext).ok()
250}
251
252pub fn parse_key_env(s: &str) -> Result<Option<[u8; 32]>, String> {
255 let trimmed = s.trim();
256 if trimmed.is_empty() {
257 return Ok(None);
258 }
259 let decoded = base64::engine::general_purpose::STANDARD
260 .decode(trimmed)
261 .map_err(|e| format!("MCP_ENCRYPTION_KEY: not valid base64: {e}"))?;
262 let key: [u8; 32] = decoded.try_into().map_err(|got: Vec<u8>| {
263 format!(
264 "MCP_ENCRYPTION_KEY: expected 32 bytes after base64-decode, got {}",
265 got.len(),
266 )
267 })?;
268 Ok(Some(key))
269}
270
/// Encodes `bytes` as base62 using the alphabet `0-9a-z A-Z`.
///
/// Each leading `0x00` byte becomes one leading `'0'` character. The
/// remaining bytes are treated as a big-endian integer and written in
/// base 62, most significant digit first. An empty input encodes to an
/// empty string.
fn base62_encode_bytes(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 62] =
        b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";

    if bytes.is_empty() {
        return String::new();
    }

    let zero_prefix = bytes.iter().position(|&b| b != 0).unwrap_or(bytes.len());
    let mut dividend: Vec<u32> = bytes[zero_prefix..].iter().map(|&b| u32::from(b)).collect();
    let mut digits_lsb_first: Vec<u8> = Vec::with_capacity(bytes.len() * 2);

    // Repeatedly divide the base-256 number by 62; each remainder is the
    // next base-62 digit, least significant first.
    while !dividend.is_empty() {
        let mut carry: u32 = 0;
        let mut quotient: Vec<u32> = Vec::with_capacity(dividend.len());
        for &limb in &dividend {
            let value = carry * 256 + limb;
            carry = value % 62;
            let q = value / 62;
            // Skip leading zero quotient digits so `quotient` stays minimal.
            if q != 0 || !quotient.is_empty() {
                quotient.push(q);
            }
        }
        digits_lsb_first.push(carry as u8);
        dividend = quotient;
    }

    let zero_char = ALPHABET[0] as char;
    let mut encoded = String::with_capacity(zero_prefix + digits_lsb_first.len());
    encoded.extend(std::iter::repeat(zero_char).take(zero_prefix));
    encoded.extend(
        digits_lsb_first
            .iter()
            .rev()
            .map(|&d| ALPHABET[usize::from(d)] as char),
    );
    encoded
}
309
/// Decodes a string produced by `base62_encode_bytes` back into bytes.
///
/// Each leading `'0'` character maps to one leading `0x00` byte. The
/// remaining characters are read as a big-endian base-62 number and emitted
/// as its minimal big-endian byte representation. Returns `None` if any
/// character is outside `[0-9a-zA-Z]`. The empty string decodes to an empty
/// vector.
///
/// Session IDs reach this function through `decode_with_key`, so the input
/// may be attacker-controlled. Base conversion is inherently quadratic. To
/// keep the constant small, the value is accumulated five base-62 digits at
/// a time into 32-bit limbs, rather than divided by 256 once per output byte.
fn base62_decode_bytes(s: &str) -> Option<Vec<u8>> {
    /// Maps one base62 character to its digit value.
    fn digit(c: char) -> Option<u32> {
        match c {
            '0'..='9' => Some(c as u32 - '0' as u32),
            'a'..='z' => Some(c as u32 - 'a' as u32 + 10),
            'A'..='Z' => Some(c as u32 - 'A' as u32 + 36),
            _ => None,
        }
    }

    // 62^5 < 2^30, so `limb * 62^5 + carry` stays below 2^62 and fits in u64.
    const DIGITS_PER_STEP: usize = 5;

    if s.is_empty() {
        return Some(Vec::new());
    }
    let leading_zeros = s.chars().take_while(|c| *c == '0').count();

    // Validate every character up front (linear) so garbage is rejected
    // before any big-number work.
    let digits: Vec<u32> = s
        .chars()
        .skip(leading_zeros)
        .map(digit)
        .collect::<Option<_>>()?;

    // Little-endian base-2^32 limbs of the decoded value.
    let mut limbs: Vec<u32> = Vec::with_capacity(digits.len() / DIGITS_PER_STEP + 1);
    for chunk in digits.chunks(DIGITS_PER_STEP) {
        // value = value * 62^len(chunk) + chunk (read as base 62).
        let (scale, chunk_value) = chunk.iter().fold((1u64, 0u64), |(scale, value), &d| {
            (scale * 62, value * 62 + u64::from(d))
        });
        let mut carry = chunk_value;
        for limb in limbs.iter_mut() {
            let acc = u64::from(*limb) * scale + carry;
            *limb = acc as u32; // keep the low 32 bits; the rest carries
            carry = acc >> 32;
        }
        while carry != 0 {
            limbs.push(carry as u32);
            carry >>= 32;
        }
    }

    // Emit big-endian bytes, dropping the zero padding of the top limb so the
    // numeric part is minimal. Leading zeros come only from the '0' prefix.
    let mut out = vec![0u8; leading_zeros];
    let mut significant = false;
    for byte in limbs.iter().rev().flat_map(|limb| limb.to_be_bytes()) {
        significant |= byte != 0;
        if significant {
            out.push(byte);
        }
    }
    Some(out)
}
346
347fn build_by_name_map(
348 connections: Vec<Connection>,
349) -> IndexMap<String, Connection> {
350 let mut name_counts: std::collections::HashMap<String, usize> =
353 std::collections::HashMap::new();
354 for c in &connections {
355 *name_counts
356 .entry(c.initialize_result.server_info.name.clone())
357 .or_insert(0) += 1;
358 }
359 let mut by_name: IndexMap<String, Connection> =
360 IndexMap::with_capacity(connections.len());
361 for (idx, connection) in connections.into_iter().enumerate() {
362 let raw = connection.initialize_result.server_info.name.clone();
363 let key = if name_counts.get(&raw).copied().unwrap_or(0) > 1 {
364 format!("{raw}_{idx}")
365 } else {
366 raw
367 };
368 if by_name.contains_key(&key) {
369 tracing::warn!(
370 key = %key,
371 "two upstreams produce the same prefix after disambiguation; later upstream wins",
372 );
373 }
374 by_name.insert(key, connection);
375 }
376 by_name
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-upstream payload with several headers, shared by the crypto tests.
    fn sample_payload() -> SessionPayload {
        let mut p: SessionPayload = IndexMap::new();
        let mut h_a: IndexMap<String, String> = IndexMap::new();
        h_a.insert("Authorization".into(), "Bearer secret-A".into());
        h_a.insert("Mcp-Session-Id".into(), "sid-A".into());
        h_a.insert("X-Tenant".into(), "tenant-1".into());
        p.insert("https://upstream-a.example/mcp".into(), h_a);
        let mut h_b: IndexMap<String, String> = IndexMap::new();
        h_b.insert("Mcp-Session-Id".into(), "sid-B".into());
        p.insert("https://upstream-b.example/mcp".into(), h_b);
        p
    }

    #[test]
    fn base62_round_trip() {
        for sample in [
            &b""[..],
            &b"a"[..],
            &b"\x00\x01\x02"[..],
            &b"hello world"[..],
            br#"{"http://127.0.0.1:1234":"abc123"}"#,
            &(0..=255u16).map(|b| b as u8).collect::<Vec<_>>()[..],
        ] {
            let encoded = base62_encode_bytes(sample);
            assert!(encoded.bytes().all(|b| (0x21..=0x7E).contains(&b)));
            let decoded = base62_decode_bytes(&encoded).expect("decode");
            assert_eq!(decoded, sample, "round-trip failed for {sample:?}");
        }
    }

    #[test]
    fn base62_decode_rejects_non_alphabet_chars() {
        assert!(base62_decode_bytes("abc-def").is_none());
        assert!(base62_decode_bytes("0 1").is_none());
        assert!(base62_decode_bytes("\u{fc}").is_none());
    }

    #[test]
    fn encrypt_decrypt_round_trip() {
        let key = [0x42u8; 32];
        let payload = sample_payload();
        let id = encrypt_and_encode(&payload, &key);
        let decoded = decode_with_key(&id, &key).expect("decode under same key");
        assert_eq!(decoded, payload);
    }

    #[test]
    fn encryption_is_deterministic_per_key_and_payload() {
        let key = [0x42u8; 32];
        let payload = sample_payload();
        let id = encrypt_and_encode(&payload, &key);
        assert_eq!(id, encrypt_and_encode(&payload, &key));
        assert_ne!(id, encrypt_and_encode(&payload, &[0x43u8; 32]));

        let mut changed = payload.clone();
        changed.insert("https://upstream-c.example/mcp".into(), IndexMap::new());
        assert_ne!(id, encrypt_and_encode(&changed, &key));
    }

    #[test]
    fn decode_with_wrong_key_returns_none() {
        let key_a = [0x11u8; 32];
        let key_b = [0x22u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key_a);
        assert!(decode_with_key(&id, &key_b).is_none());
    }

    #[test]
    fn decode_tampered_ciphertext_returns_none() {
        let key = [0x42u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key);
        let mut envelope = base62_decode_bytes(&id).expect("valid base62");
        *envelope.last_mut().expect("non-empty envelope") ^= 0x01;
        let tampered = base62_encode_bytes(&envelope);
        assert_ne!(tampered, id);
        assert!(decode_with_key(&tampered, &key).is_none());
    }

    #[test]
    fn decode_unknown_version_returns_none() {
        let key = [0x42u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key);
        let mut envelope = base62_decode_bytes(&id).expect("valid base62");
        envelope[0] = VERSION.wrapping_add(1);
        let relabelled = base62_encode_bytes(&envelope);
        assert!(decode_with_key(&relabelled, &key).is_none());
    }

    #[test]
    fn decode_garbage_returns_none() {
        let key = [0x55u8; 32];
        assert!(decode_with_key("ABCdef123", &key).is_none());
        assert!(decode_with_key("", &key).is_none());
        assert!(decode_with_key("0", &key).is_none());
    }

    #[test]
    fn manager_mint_and_decode_round_trip() {
        let manager = SessionManager::new([0x07u8; 32]);
        let payload = sample_payload();
        let id = manager.mint_id(&payload);
        assert_eq!(manager.decode_session_id(&id), Some(payload));
        // Minting does not register a live session.
        assert!(manager.get(&id).is_none());

        let other = SessionManager::new([0x08u8; 32]);
        assert!(other.decode_session_id(&id).is_none());
    }

    #[test]
    fn payload_roundtrip_preserves_canonical_order() {
        // NOTE(review): this mirrors `build_payload`'s sorting instead of
        // calling it, presumably because a `Connection` cannot easily be
        // constructed here. Keep the two in sync.
        let conn_a_url = "https://b.example/mcp".to_string();
        let conn_b_url = "https://a.example/mcp".to_string();

        let mut h_unsorted: IndexMap<String, String> = IndexMap::new();
        h_unsorted.insert("Z-Header".into(), "z".into());
        h_unsorted.insert("Authorization".into(), "Bearer".into());

        let pairs_unsorted: Vec<(String, IndexMap<String, String>)> = vec![
            (conn_a_url.clone(), h_unsorted.clone()),
            (conn_b_url.clone(), h_unsorted.clone()),
        ];

        let mut payload: SessionPayload = IndexMap::new();
        let mut url_entries: Vec<(String, IndexMap<String, String>)> = pairs_unsorted
            .into_iter()
            .map(|(url, headers)| {
                let mut sorted: Vec<(&str, &str)> =
                    headers.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
                sorted.sort_by(|a, b| a.0.cmp(b.0));
                let inner: IndexMap<String, String> = sorted
                    .into_iter()
                    .map(|(k, v)| (k.to_string(), v.to_string()))
                    .collect();
                (url, inner)
            })
            .collect();
        url_entries.sort_by(|a, b| a.0.cmp(&b.0));
        for (u, h) in url_entries {
            payload.insert(u, h);
        }

        let urls: Vec<&String> = payload.keys().collect();
        assert_eq!(urls, vec![&conn_b_url, &conn_a_url]);
        let inner = &payload[&conn_b_url];
        let inner_keys: Vec<&String> = inner.keys().collect();
        assert_eq!(inner_keys, vec!["Authorization", "Z-Header"]);

        // `IndexMap` equality ignores order, so check explicitly that the
        // canonical order survives the encrypt/decode round trip.
        let key = [0x33u8; 32];
        let decoded =
            decode_with_key(&encrypt_and_encode(&payload, &key), &key).expect("decode");
        assert_eq!(
            decoded.keys().collect::<Vec<_>>(),
            payload.keys().collect::<Vec<_>>()
        );
        assert_eq!(decoded[&conn_b_url].keys().collect::<Vec<_>>(), inner_keys);
    }

    #[test]
    fn parse_key_env_round_trip() {
        let key = [0xAAu8; 32];
        let env = base64::engine::general_purpose::STANDARD.encode(key);
        let parsed = parse_key_env(&env).expect("parse").expect("Some");
        assert_eq!(parsed, key);

        assert!(parse_key_env("").unwrap().is_none());
        assert!(parse_key_env("   ").unwrap().is_none());
        assert!(parse_key_env("not-base64!@#").is_err());
        let short = base64::engine::general_purpose::STANDARD.encode(&[0u8; 16][..]);
        assert!(parse_key_env(&short).is_err());
    }
}