1use std::sync::Arc;
40
41use base64::Engine;
42use chacha20poly1305::aead::{Aead, KeyInit};
43use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce};
44use dashmap::DashMap;
45use indexmap::IndexMap;
46use objectiveai_sdk::mcp::Connection;
47use rand::RngCore;
48
49use crate::session::Session;
50
51#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
67pub struct SessionPayload {
68 pub connections: IndexMap<String, IndexMap<String, String>>,
69}
70
71const VERSION: u8 = 0x01;
76const NONCE_LEN: usize = 24; const TAG_LEN: usize = 16; #[derive(Debug)]
82pub struct SessionManager {
83 sessions: DashMap<String, Arc<Session>>,
84 key: [u8; 32],
88}
89
90impl SessionManager {
91 pub fn new(key: [u8; 32]) -> Self {
92 Self {
93 sessions: DashMap::new(),
94 key,
95 }
96 }
97
98 pub fn with_ephemeral_key() -> Self {
103 let mut key = [0u8; 32];
104 rand::rng().fill_bytes(&mut key);
105 Self::new(key)
106 }
107
108 pub fn add(
121 &self,
122 connections_with_headers: Vec<(Connection, IndexMap<String, String>)>,
123 ) -> String {
124 let payload = build_payload(&connections_with_headers);
125 let id = encrypt_and_encode(&payload, &self.key);
126 let connections: Vec<Connection> =
127 connections_with_headers.into_iter().map(|(c, _)| c).collect();
128 let by_prefix = build_prefix_map(connections);
129 self.sessions
130 .insert(id.clone(), Arc::new(Session::new(by_prefix, payload)));
131 id
132 }
133
134 pub fn get(&self, session_id: &str) -> Option<Arc<Session>> {
137 self.sessions.get(session_id).map(|e| e.value().clone())
138 }
139
140 pub fn remove(&self, session_id: &str) -> Option<Arc<Session>> {
150 self.sessions.remove(session_id).map(|(_, session)| session)
151 }
152
153 pub fn decode_session_id(&self, id: &str) -> Option<SessionPayload> {
157 decode_with_key(id, &self.key)
158 }
159
160 pub fn mint_id(&self, payload: &SessionPayload) -> String {
168 encrypt_and_encode(payload, &self.key)
169 }
170}
171
172fn build_payload(
181 pairs: &[(Connection, IndexMap<String, String>)],
182) -> SessionPayload {
183 let mut url_entries: Vec<(String, IndexMap<String, String>)> = pairs
185 .iter()
186 .map(|(c, headers)| {
187 let mut sorted: Vec<(&str, &str)> = headers
188 .iter()
189 .map(|(k, v)| (k.as_str(), v.as_str()))
190 .collect();
191 sorted.sort_by(|a, b| a.0.cmp(b.0));
192 let inner: IndexMap<String, String> = sorted
193 .into_iter()
194 .map(|(k, v)| (k.to_string(), v.to_string()))
195 .collect();
196 (c.url.clone(), inner)
197 })
198 .collect();
199 url_entries.sort_by(|a, b| a.0.cmp(&b.0));
200
201 let mut connections: IndexMap<String, IndexMap<String, String>> =
202 IndexMap::with_capacity(url_entries.len());
203 for (url, headers) in url_entries {
204 connections.insert(url, headers);
205 }
206
207 SessionPayload { connections }
208}
209
210fn encrypt_and_encode(payload: &SessionPayload, key: &[u8; 32]) -> String {
225 let plaintext =
226 serde_json::to_vec(payload).expect("SessionPayload serializes");
227
228 let mut hasher = blake3::Hasher::new_keyed(key);
233 hasher.update(&plaintext);
234 let mut nonce_bytes = [0u8; NONCE_LEN];
235 nonce_bytes.copy_from_slice(&hasher.finalize().as_bytes()[..NONCE_LEN]);
236
237 let cipher = XChaCha20Poly1305::new(Key::from_slice(key));
238 let nonce = XNonce::from_slice(&nonce_bytes);
239 let ciphertext_with_tag = cipher
240 .encrypt(nonce, plaintext.as_ref())
241 .expect("XChaCha20-Poly1305 encrypt is infallible for valid key/nonce");
242
243 let mut envelope = Vec::with_capacity(1 + NONCE_LEN + ciphertext_with_tag.len());
244 envelope.push(VERSION);
245 envelope.extend_from_slice(&nonce_bytes);
246 envelope.extend_from_slice(&ciphertext_with_tag);
247 base62_encode_bytes(&envelope)
248}
249
250fn decode_with_key(id: &str, key: &[u8; 32]) -> Option<SessionPayload> {
252 let envelope = base62_decode_bytes(id)?;
253 if envelope.len() < 1 + NONCE_LEN + TAG_LEN {
254 return None;
255 }
256 if envelope[0] != VERSION {
257 return None;
258 }
259 let nonce = XNonce::from_slice(&envelope[1..1 + NONCE_LEN]);
260 let ciphertext = &envelope[1 + NONCE_LEN..];
261 let cipher = XChaCha20Poly1305::new(Key::from_slice(key));
262 let plaintext = cipher.decrypt(nonce, ciphertext).ok()?;
263 serde_json::from_slice(&plaintext).ok()
264}
265
266pub fn parse_key_env(s: &str) -> Result<Option<[u8; 32]>, String> {
269 let trimmed = s.trim();
270 if trimmed.is_empty() {
271 return Ok(None);
272 }
273 let decoded = base64::engine::general_purpose::STANDARD
274 .decode(trimmed)
275 .map_err(|e| format!("MCP_ENCRYPTION_KEY: not valid base64: {e}"))?;
276 let key: [u8; 32] = decoded.try_into().map_err(|got: Vec<u8>| {
277 format!(
278 "MCP_ENCRYPTION_KEY: expected 32 bytes after base64-decode, got {}",
279 got.len(),
280 )
281 })?;
282 Ok(Some(key))
283}
284
/// Encodes `bytes` as base62 over the alphabet `0-9a-zA-Z`.
///
/// Each leading zero byte becomes one literal `'0'`. The remaining bytes are read as a
/// big-endian integer and written most-significant digit first. The empty slice encodes
/// to the empty string.
fn base62_encode_bytes(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 62] =
        b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";

    let zero_prefix = bytes.iter().position(|&b| b != 0).unwrap_or(bytes.len());

    // Repeated long division by 62, done in place. `head` skips quotient limbs that
    // have become zero, so the loop ends once the whole number reaches zero.
    let mut dividend: Vec<u32> = bytes[zero_prefix..].iter().map(|&b| u32::from(b)).collect();
    let mut head = 0;
    let mut little_endian_digits: Vec<u8> = Vec::with_capacity(bytes.len() * 2);
    while head < dividend.len() {
        let mut carry = 0u32;
        for limb in &mut dividend[head..] {
            let acc = carry * 256 + *limb;
            *limb = acc / 62;
            carry = acc % 62;
        }
        little_endian_digits.push(carry as u8);
        while head < dividend.len() && dividend[head] == 0 {
            head += 1;
        }
    }

    let mut out = String::with_capacity(zero_prefix + little_endian_digits.len());
    out.extend(std::iter::repeat(ALPHABET[0] as char).take(zero_prefix));
    out.extend(
        little_endian_digits
            .iter()
            .rev()
            .map(|&d| ALPHABET[usize::from(d)] as char),
    );
    out
}
323
/// Decodes a base62 string produced by `base62_encode_bytes`.
///
/// Each leading `'0'` becomes one zero byte. The remaining digits are a big-endian
/// base-62 integer. Returns `None` if any character is outside `0-9a-zA-Z`. The empty
/// string decodes to an empty vector.
fn base62_decode_bytes(s: &str) -> Option<Vec<u8>> {
    fn digit_value(c: u8) -> Option<u32> {
        let value = match c {
            b'0'..=b'9' => c - b'0',
            b'a'..=b'z' => c - b'a' + 10,
            b'A'..=b'Z' => c - b'A' + 36,
            _ => return None,
        };
        Some(u32::from(value))
    }

    // Working on bytes is safe: every alphabet symbol is ASCII, and any byte of a
    // multi-byte UTF-8 character is >= 0x80 and therefore rejected.
    let raw = s.as_bytes();
    let zero_prefix = raw.iter().take_while(|&&c| c == b'0').count();
    let mut dividend: Vec<u32> = raw[zero_prefix..]
        .iter()
        .map(|&c| digit_value(c))
        .collect::<Option<Vec<u32>>>()?;

    // In-place long division by 256; `head` skips quotient limbs that have become zero.
    let mut head = 0;
    let mut little_endian_bytes: Vec<u8> = Vec::new();
    while head < dividend.len() {
        let mut carry = 0u32;
        for limb in &mut dividend[head..] {
            let acc = carry * 62 + *limb;
            *limb = acc / 256;
            carry = acc % 256;
        }
        little_endian_bytes.push(carry as u8);
        while head < dividend.len() && dividend[head] == 0 {
            head += 1;
        }
    }

    let mut out = vec![0u8; zero_prefix];
    out.extend(little_endian_bytes.into_iter().rev());
    Some(out)
}
360
/// Rewrites every `'_'` and `'.'` in `s` as `'-'`, leaving all other characters unchanged.
fn normalize_prefix_token(s: &str) -> String {
    s.chars()
        .map(|c| match c {
            '_' | '.' => '-',
            other => other,
        })
        .collect()
}
368
369fn build_prefix_map(
387 mut connections: Vec<Connection>,
388) -> IndexMap<String, Connection> {
389 connections.sort_by(|a, b| a.url.cmp(&b.url));
390 let n = connections.len();
391
392 let names: Vec<String> = connections
393 .iter()
394 .map(|c| normalize_prefix_token(&c.initialize_result.server_info.name))
395 .collect();
396 let versions: Vec<String> = connections
397 .iter()
398 .map(|c| normalize_prefix_token(&c.initialize_result.server_info.version))
399 .collect();
400
401 let prefix_at = |i: usize, tier: u8| -> String {
403 match tier {
404 1 => names[i].clone(),
405 2 => format!("{}-{}", names[i], versions[i]),
406 _ => format!("{}-{}-{}", names[i], versions[i], i),
407 }
408 };
409
410 let mut tier: Vec<u8> = vec![1; n];
411 loop {
412 let current: Vec<String> = (0..n).map(|i| prefix_at(i, tier[i])).collect();
413 let mut counts: std::collections::HashMap<&str, usize> =
414 std::collections::HashMap::new();
415 for p in ¤t {
416 *counts.entry(p.as_str()).or_insert(0) += 1;
417 }
418 let mut changed = false;
419 for i in 0..n {
420 if counts[current[i].as_str()] > 1 && tier[i] < 3 {
421 tier[i] += 1;
422 changed = true;
423 }
424 }
425 if !changed {
426 break;
427 }
428 }
429
430 let mut by_prefix: IndexMap<String, Connection> = IndexMap::with_capacity(n);
431 for (i, connection) in connections.into_iter().enumerate() {
432 let key = prefix_at(i, tier[i]);
433 debug_assert!(
436 !by_prefix.contains_key(&key),
437 "duplicate routing prefix after escalation: {key}",
438 );
439 by_prefix.insert(key, connection);
440 }
441 by_prefix
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-upstream payload, already in canonical (URL- and header-sorted) order.
    fn sample_payload() -> SessionPayload {
        let mut connections: IndexMap<String, IndexMap<String, String>> = IndexMap::new();
        let mut h_a: IndexMap<String, String> = IndexMap::new();
        h_a.insert("Authorization".into(), "Bearer secret-A".into());
        h_a.insert("Mcp-Session-Id".into(), "sid-A".into());
        h_a.insert("X-Tenant".into(), "tenant-1".into());
        connections.insert("https://upstream-a.example/mcp".into(), h_a);
        let mut h_b: IndexMap<String, String> = IndexMap::new();
        h_b.insert("Mcp-Session-Id".into(), "sid-B".into());
        connections.insert("https://upstream-b.example/mcp".into(), h_b);
        SessionPayload { connections }
    }

    #[test]
    fn base62_round_trip() {
        for sample in [
            &b""[..],
            &b"a"[..],
            &b"\x00\x01\x02"[..],
            &b"hello world"[..],
            br#"{"http://127.0.0.1:1234":"abc123"}"#,
            &(0..=255u16).map(|b| b as u8).collect::<Vec<_>>()[..],
        ] {
            let encoded = base62_encode_bytes(sample);
            assert!(encoded.bytes().all(|b| (0x21..=0x7E).contains(&b)));
            let decoded = base62_decode_bytes(&encoded).expect("decode");
            assert_eq!(decoded, sample, "round-trip failed for {sample:?}");
        }
    }

    #[test]
    fn base62_known_vectors() {
        // Leading zero bytes map 1:1 to leading '0' digits; the rest is big-endian base 62.
        let cases: [(&[u8], &str); 6] = [
            (&[], ""),
            (&[0], "0"),
            (&[0, 1, 2], "04a"),
            (b"a", "1z"),
            (&[62], "10"),
            (&[255, 255], "h31"),
        ];
        for (bytes, text) in cases {
            assert_eq!(base62_encode_bytes(bytes), text);
            assert_eq!(base62_decode_bytes(text).as_deref(), Some(bytes));
        }
    }

    #[test]
    fn base62_decode_rejects_non_alphabet_characters() {
        assert!(base62_decode_bytes("abc-def").is_none());
        assert!(base62_decode_bytes("0+").is_none());
        assert!(base62_decode_bytes("é").is_none());
    }

    #[test]
    fn encrypt_decrypt_round_trip() {
        let key = [0x42u8; 32];
        let payload = sample_payload();
        let id = encrypt_and_encode(&payload, &key);
        let decoded = decode_with_key(&id, &key).expect("decode under same key");
        assert_eq!(decoded, payload);
    }

    #[test]
    fn encryption_is_deterministic_and_remint_stable() {
        // `IndexMap`'s `PartialEq` ignores order, so equality alone cannot catch an
        // order change on decode. Re-minting the decoded payload must reproduce the id.
        let key = [0x42u8; 32];
        let payload = sample_payload();
        let id = encrypt_and_encode(&payload, &key);
        assert_eq!(encrypt_and_encode(&payload, &key), id);
        let decoded = decode_with_key(&id, &key).expect("decode");
        assert_eq!(encrypt_and_encode(&decoded, &key), id);
    }

    #[test]
    fn decode_with_wrong_key_returns_none() {
        let key_a = [0x11u8; 32];
        let key_b = [0x22u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key_a);
        assert!(decode_with_key(&id, &key_b).is_none());
    }

    #[test]
    fn decode_tampered_id_returns_none() {
        let key = [0x42u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key);
        let mut envelope = base62_decode_bytes(&id).expect("valid base62");
        let last = envelope.len() - 1;
        envelope[last] ^= 0x01;
        // The version byte is non-zero, so the re-encoded id decodes back to this envelope.
        let tampered = base62_encode_bytes(&envelope);
        assert!(decode_with_key(&tampered, &key).is_none());
    }

    #[test]
    fn decode_unknown_version_returns_none() {
        let key = [0x42u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key);
        let mut envelope = base62_decode_bytes(&id).expect("valid base62");
        envelope[0] = VERSION.wrapping_add(1);
        let relabeled = base62_encode_bytes(&envelope);
        assert!(decode_with_key(&relabeled, &key).is_none());
    }

    #[test]
    fn decode_garbage_returns_none() {
        let key = [0x55u8; 32];
        assert!(decode_with_key("ABCdef123", &key).is_none());
        assert!(decode_with_key("", &key).is_none());
        assert!(decode_with_key("0", &key).is_none());
    }

    #[test]
    fn payload_roundtrip_preserves_canonical_order() {
        // Deliberately out of order: the "b" URL is supplied before the "a" URL.
        let conn_a_url = "https://b.example/mcp".to_string();
        let conn_b_url = "https://a.example/mcp".to_string();

        let mut h_unsorted: IndexMap<String, String> = IndexMap::new();
        h_unsorted.insert("Z-Header".into(), "z".into());
        h_unsorted.insert("Authorization".into(), "Bearer".into());

        let pairs_unsorted: Vec<(String, IndexMap<String, String>)> = vec![
            (conn_a_url.clone(), h_unsorted.clone()),
            (conn_b_url.clone(), h_unsorted.clone()),
        ];

        // Mirrors the canonicalization done by `build_payload`, which needs real
        // `Connection` values that this test cannot construct.
        let mut connections: IndexMap<String, IndexMap<String, String>> = IndexMap::new();
        let mut url_entries: Vec<(String, IndexMap<String, String>)> = pairs_unsorted
            .into_iter()
            .map(|(url, headers)| {
                let mut sorted: Vec<(&str, &str)> =
                    headers.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
                sorted.sort_by(|a, b| a.0.cmp(b.0));
                let inner: IndexMap<String, String> = sorted
                    .into_iter()
                    .map(|(k, v)| (k.to_string(), v.to_string()))
                    .collect();
                (url, inner)
            })
            .collect();
        url_entries.sort_by(|a, b| a.0.cmp(&b.0));
        for (u, h) in url_entries {
            connections.insert(u, h);
        }
        let payload = SessionPayload { connections };

        // Actually round-trip through the envelope: canonical order must survive decode.
        let key = [0x42u8; 32];
        let decoded =
            decode_with_key(&encrypt_and_encode(&payload, &key), &key).expect("decode");

        for p in [&payload, &decoded] {
            let urls: Vec<&String> = p.connections.keys().collect();
            assert_eq!(urls, vec![&conn_b_url, &conn_a_url]);
            let inner_keys: Vec<&String> = p.connections[&conn_b_url].keys().collect();
            assert_eq!(inner_keys, vec!["Authorization", "Z-Header"]);
        }
    }

    #[test]
    fn parse_key_env_round_trip() {
        let key = [0xAAu8; 32];
        let env = base64::engine::general_purpose::STANDARD.encode(key);
        let parsed = parse_key_env(&env).expect("parse").expect("Some");
        assert_eq!(parsed, key);

        assert!(parse_key_env("").unwrap().is_none());
        assert!(parse_key_env("   ").unwrap().is_none());
        assert!(parse_key_env("not-base64!@#").is_err());
        let short =
            base64::engine::general_purpose::STANDARD.encode(&[0u8; 16][..]);
        assert!(parse_key_env(&short).is_err());
    }
}