1use std::sync::Arc;
40
41use base64::Engine;
42use chacha20poly1305::aead::{Aead, KeyInit};
43use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce};
44use dashmap::DashMap;
45use indexmap::IndexMap;
46use objectiveai_sdk::mcp::Connection;
47use rand::RngCore;
48
49use crate::session::Session;
50
51#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
67pub struct SessionPayload {
68 pub connections: IndexMap<String, IndexMap<String, String>>,
69}
70
71const VERSION: u8 = 0x01;
76const NONCE_LEN: usize = 24; const TAG_LEN: usize = 16; #[derive(Debug)]
82pub struct SessionManager {
83 sessions: DashMap<String, Arc<Session>>,
84 key: [u8; 32],
88}
89
90impl SessionManager {
91 pub fn new(key: [u8; 32]) -> Self {
92 Self {
93 sessions: DashMap::new(),
94 key,
95 }
96 }
97
98 pub fn with_ephemeral_key() -> Self {
103 let mut key = [0u8; 32];
104 rand::rng().fill_bytes(&mut key);
105 Self::new(key)
106 }
107
108 pub fn add(
121 &self,
122 connections_with_headers: Vec<(Connection, IndexMap<String, String>)>,
123 ) -> String {
124 let payload = build_payload(&connections_with_headers);
125 let id = encrypt_and_encode(&payload, &self.key);
126 let connections: Vec<Connection> =
127 connections_with_headers.into_iter().map(|(c, _)| c).collect();
128 let by_name = build_by_name_map(connections);
129 self.sessions
130 .insert(id.clone(), Arc::new(Session::new(by_name, payload)));
131 id
132 }
133
134 pub fn get(&self, session_id: &str) -> Option<Arc<Session>> {
137 self.sessions.get(session_id).map(|e| e.value().clone())
138 }
139
140 pub fn remove(&self, session_id: &str) -> Option<Arc<Session>> {
150 self.sessions.remove(session_id).map(|(_, session)| session)
151 }
152
153 pub fn decode_session_id(&self, id: &str) -> Option<SessionPayload> {
157 decode_with_key(id, &self.key)
158 }
159
160 pub fn mint_id(&self, payload: &SessionPayload) -> String {
168 encrypt_and_encode(payload, &self.key)
169 }
170}
171
172fn build_payload(
181 pairs: &[(Connection, IndexMap<String, String>)],
182) -> SessionPayload {
183 let mut url_entries: Vec<(String, IndexMap<String, String>)> = pairs
185 .iter()
186 .map(|(c, headers)| {
187 let mut sorted: Vec<(&str, &str)> = headers
188 .iter()
189 .map(|(k, v)| (k.as_str(), v.as_str()))
190 .collect();
191 sorted.sort_by(|a, b| a.0.cmp(b.0));
192 let inner: IndexMap<String, String> = sorted
193 .into_iter()
194 .map(|(k, v)| (k.to_string(), v.to_string()))
195 .collect();
196 (c.url.clone(), inner)
197 })
198 .collect();
199 url_entries.sort_by(|a, b| a.0.cmp(&b.0));
200
201 let mut connections: IndexMap<String, IndexMap<String, String>> =
202 IndexMap::with_capacity(url_entries.len());
203 for (url, headers) in url_entries {
204 connections.insert(url, headers);
205 }
206
207 SessionPayload { connections }
208}
209
210fn encrypt_and_encode(payload: &SessionPayload, key: &[u8; 32]) -> String {
225 let plaintext =
226 serde_json::to_vec(payload).expect("SessionPayload serializes");
227
228 let mut hasher = blake3::Hasher::new_keyed(key);
233 hasher.update(&plaintext);
234 let mut nonce_bytes = [0u8; NONCE_LEN];
235 nonce_bytes.copy_from_slice(&hasher.finalize().as_bytes()[..NONCE_LEN]);
236
237 let cipher = XChaCha20Poly1305::new(Key::from_slice(key));
238 let nonce = XNonce::from_slice(&nonce_bytes);
239 let ciphertext_with_tag = cipher
240 .encrypt(nonce, plaintext.as_ref())
241 .expect("XChaCha20-Poly1305 encrypt is infallible for valid key/nonce");
242
243 let mut envelope = Vec::with_capacity(1 + NONCE_LEN + ciphertext_with_tag.len());
244 envelope.push(VERSION);
245 envelope.extend_from_slice(&nonce_bytes);
246 envelope.extend_from_slice(&ciphertext_with_tag);
247 base62_encode_bytes(&envelope)
248}
249
250fn decode_with_key(id: &str, key: &[u8; 32]) -> Option<SessionPayload> {
252 let envelope = base62_decode_bytes(id)?;
253 if envelope.len() < 1 + NONCE_LEN + TAG_LEN {
254 return None;
255 }
256 if envelope[0] != VERSION {
257 return None;
258 }
259 let nonce = XNonce::from_slice(&envelope[1..1 + NONCE_LEN]);
260 let ciphertext = &envelope[1 + NONCE_LEN..];
261 let cipher = XChaCha20Poly1305::new(Key::from_slice(key));
262 let plaintext = cipher.decrypt(nonce, ciphertext).ok()?;
263 serde_json::from_slice(&plaintext).ok()
264}
265
266pub fn parse_key_env(s: &str) -> Result<Option<[u8; 32]>, String> {
269 let trimmed = s.trim();
270 if trimmed.is_empty() {
271 return Ok(None);
272 }
273 let decoded = base64::engine::general_purpose::STANDARD
274 .decode(trimmed)
275 .map_err(|e| format!("MCP_ENCRYPTION_KEY: not valid base64: {e}"))?;
276 let key: [u8; 32] = decoded.try_into().map_err(|got: Vec<u8>| {
277 format!(
278 "MCP_ENCRYPTION_KEY: expected 32 bytes after base64-decode, got {}",
279 got.len(),
280 )
281 })?;
282 Ok(Some(key))
283}
284
/// Encodes `bytes` as a big-endian base62 number using `0-9a-zA-Z`.
///
/// Every leading `0x00` byte is emitted as one literal `'0'`, so leading zeros
/// survive a round trip through `base62_decode_bytes`. Empty input yields `""`.
/// Runs in time quadratic in the input length (schoolbook repeated division).
fn base62_encode_bytes(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 62] =
        b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";

    let zero_prefix = bytes
        .iter()
        .position(|&byte| byte != 0)
        .unwrap_or(bytes.len());

    // Base-256 limbs of the remaining number, most significant first. Each pass
    // divides them by 62 in place; `head` skips limbs that have become zero.
    let mut limbs: Vec<u32> = bytes[zero_prefix..]
        .iter()
        .map(|&byte| u32::from(byte))
        .collect();
    let mut head = 0;
    let mut reversed: Vec<u8> = Vec::with_capacity(bytes.len() * 2);
    while head < limbs.len() {
        let mut carry: u32 = 0;
        for limb in &mut limbs[head..] {
            let acc = carry * 256 + *limb;
            *limb = acc / 62;
            carry = acc % 62;
        }
        reversed.push(ALPHABET[carry as usize]);
        while head < limbs.len() && limbs[head] == 0 {
            head += 1;
        }
    }

    let mut out = "0".repeat(zero_prefix);
    out.reserve(reversed.len());
    out.extend(reversed.iter().rev().map(|&symbol| symbol as char));
    out
}
323
/// Decodes a string produced by `base62_encode_bytes` back into bytes.
///
/// Each leading `'0'` becomes one `0x00` byte. The rest of the string is read
/// as a big-endian base62 number over `0-9a-zA-Z`.
///
/// Returns `None` if any character lies outside that alphabet, and
/// `Some(vec![])` for the empty string. Runs in time quadratic in the input
/// length.
fn base62_decode_bytes(s: &str) -> Option<Vec<u8>> {
    let mut limbs: Vec<u32> = s
        .chars()
        .map(|c| match c {
            '0'..='9' => Some(c as u32 - '0' as u32),
            'a'..='z' => Some(c as u32 - 'a' as u32 + 10),
            'A'..='Z' => Some(c as u32 - 'A' as u32 + 36),
            _ => None,
        })
        .collect::<Option<Vec<u32>>>()?;

    let zero_prefix = limbs.iter().take_while(|&&value| value == 0).count();

    // Base-62 limbs, most significant first, divided by 256 in place each pass;
    // `head` skips limbs that have become zero.
    let mut head = zero_prefix;
    let mut reversed: Vec<u8> = Vec::new();
    while head < limbs.len() {
        let mut carry: u32 = 0;
        for limb in &mut limbs[head..] {
            let acc = carry * 62 + *limb;
            *limb = acc / 256;
            carry = acc % 256;
        }
        reversed.push(carry as u8);
        while head < limbs.len() && limbs[head] == 0 {
            head += 1;
        }
    }

    let mut out = vec![0u8; zero_prefix];
    out.extend(reversed.into_iter().rev());
    Some(out)
}
360
361fn build_by_name_map(
362 connections: Vec<Connection>,
363) -> IndexMap<String, Connection> {
364 let mut name_counts: std::collections::HashMap<String, usize> =
367 std::collections::HashMap::new();
368 for c in &connections {
369 *name_counts
370 .entry(c.initialize_result.server_info.name.clone())
371 .or_insert(0) += 1;
372 }
373 let mut by_name: IndexMap<String, Connection> =
374 IndexMap::with_capacity(connections.len());
375 for (idx, connection) in connections.into_iter().enumerate() {
376 let raw = connection.initialize_result.server_info.name.clone();
377 let key = if name_counts.get(&raw).copied().unwrap_or(0) > 1 {
378 format!("{raw}_{idx}")
379 } else {
380 raw
381 };
382 if by_name.contains_key(&key) {
383 tracing::warn!(
384 key = %key,
385 "two upstreams produce the same prefix after disambiguation; later upstream wins",
386 );
387 }
388 by_name.insert(key, connection);
389 }
390 by_name
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_payload() -> SessionPayload {
        let mut connections: IndexMap<String, IndexMap<String, String>> = IndexMap::new();
        let mut h_a: IndexMap<String, String> = IndexMap::new();
        h_a.insert("Authorization".into(), "Bearer secret-A".into());
        h_a.insert("Mcp-Session-Id".into(), "sid-A".into());
        h_a.insert("X-Tenant".into(), "tenant-1".into());
        connections.insert("https://upstream-a.example/mcp".into(), h_a);
        let mut h_b: IndexMap<String, String> = IndexMap::new();
        h_b.insert("Mcp-Session-Id".into(), "sid-B".into());
        connections.insert("https://upstream-b.example/mcp".into(), h_b);
        SessionPayload { connections }
    }

    #[test]
    fn base62_round_trip() {
        for sample in [
            &b""[..],
            &b"a"[..],
            &b"\x00\x01\x02"[..],
            &b"hello world"[..],
            br#"{"http://127.0.0.1:1234":"abc123"}"#,
            &(0..=255u16).map(|b| b as u8).collect::<Vec<_>>()[..],
        ] {
            let encoded = base62_encode_bytes(sample);
            // The alphabet is exactly 0-9a-zA-Z, so ids are header/URL safe.
            assert!(encoded.bytes().all(|b| b.is_ascii_alphanumeric()));
            let decoded = base62_decode_bytes(&encoded).expect("decode");
            assert_eq!(decoded, sample, "round-trip failed for {sample:?}");
        }
    }

    #[test]
    fn base62_preserves_leading_zero_bytes() {
        assert_eq!(base62_encode_bytes(&[0, 0, 1]), "001");
        assert_eq!(base62_decode_bytes("001"), Some(vec![0, 0, 1]));
        assert_eq!(base62_encode_bytes(&[0, 0]), "00");
        assert_eq!(base62_decode_bytes("00"), Some(vec![0, 0]));
    }

    #[test]
    fn base62_decode_rejects_characters_outside_alphabet() {
        assert!(base62_decode_bytes("abc-def").is_none());
        assert!(base62_decode_bytes("abc def").is_none());
        assert!(base62_decode_bytes("\u{e9}").is_none());
    }

    #[test]
    fn encrypt_decrypt_round_trip() {
        let key = [0x42u8; 32];
        let payload = sample_payload();
        let id = encrypt_and_encode(&payload, &key);
        let decoded = decode_with_key(&id, &key).expect("decode under same key");
        assert_eq!(decoded, payload);
    }

    #[test]
    fn encrypt_is_deterministic_per_key() {
        // The nonce is derived from the plaintext, so equal payloads under an
        // equal key must mint equal ids; a different key must not.
        let payload = sample_payload();
        let first = encrypt_and_encode(&payload, &[0x42u8; 32]);
        let second = encrypt_and_encode(&payload, &[0x42u8; 32]);
        assert_eq!(first, second);
        let other_key = encrypt_and_encode(&payload, &[0x43u8; 32]);
        assert_ne!(first, other_key);
    }

    #[test]
    fn decode_with_wrong_key_returns_none() {
        let key_a = [0x11u8; 32];
        let key_b = [0x22u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key_a);
        assert!(decode_with_key(&id, &key_b).is_none());
    }

    #[test]
    fn decode_garbage_returns_none() {
        let key = [0x55u8; 32];
        assert!(decode_with_key("ABCdef123", &key).is_none());
        assert!(decode_with_key("", &key).is_none());
        assert!(decode_with_key("0", &key).is_none());
    }

    #[test]
    fn truncated_envelope_is_rejected() {
        let key = [0x55u8; 32];
        let one_byte_short = [VERSION; 1 + NONCE_LEN + TAG_LEN - 1];
        assert!(decode_with_key(&base62_encode_bytes(&one_byte_short), &key).is_none());
    }

    #[test]
    fn tampered_id_is_rejected() {
        let key = [0x42u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key);
        let mut tampered = id.clone();
        let last = tampered.pop().expect("id is non-empty");
        tampered.push(if last == 'a' { 'b' } else { 'a' });
        assert_ne!(tampered, id);
        assert!(decode_with_key(&tampered, &key).is_none());
    }

    #[test]
    fn unknown_version_is_rejected() {
        let key = [0x42u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key);
        let mut envelope = base62_decode_bytes(&id).expect("minted ids are valid base62");
        assert_eq!(envelope[0], VERSION);
        envelope[0] = VERSION.wrapping_add(1);
        let reencoded = base62_encode_bytes(&envelope);
        assert!(decode_with_key(&reencoded, &key).is_none());
    }

    #[test]
    fn payload_round_trip_preserves_entry_order() {
        // `build_payload` needs live `Connection`s, which this module cannot
        // construct. This therefore pins the property it depends on instead:
        // sealing and unsealing a payload keeps the exact entry order of both
        // the URL map and each header map.
        //
        // Entries are deliberately inserted out of sorted order, so a round
        // trip that re-sorted (rather than preserved) the maps would fail.
        let url_b = "https://b.example/mcp".to_string();
        let url_a = "https://a.example/mcp".to_string();

        let mut headers: IndexMap<String, String> = IndexMap::new();
        headers.insert("Z-Header".into(), "z".into());
        headers.insert("Authorization".into(), "Bearer".into());

        let mut connections: IndexMap<String, IndexMap<String, String>> = IndexMap::new();
        connections.insert(url_b.clone(), headers.clone());
        connections.insert(url_a.clone(), headers);
        let payload = SessionPayload { connections };

        let key = [0x42u8; 32];
        let id = encrypt_and_encode(&payload, &key);
        let decoded = decode_with_key(&id, &key).expect("decode under same key");
        assert_eq!(decoded, payload);

        // IndexMap equality ignores order, so check the order explicitly.
        let urls: Vec<&String> = decoded.connections.keys().collect();
        assert_eq!(urls, vec![&url_b, &url_a]);
        let header_names: Vec<&String> = decoded.connections[&url_a].keys().collect();
        assert_eq!(header_names, vec!["Z-Header", "Authorization"]);
    }

    #[test]
    fn parse_key_env_round_trip() {
        let key = [0xAAu8; 32];
        let env = base64::engine::general_purpose::STANDARD.encode(key);
        let parsed = parse_key_env(&env).expect("parse").expect("Some");
        assert_eq!(parsed, key);

        assert!(parse_key_env("").unwrap().is_none());
        assert!(parse_key_env("   ").unwrap().is_none());
        assert!(parse_key_env("not-base64!@#").is_err());
        let short = base64::engine::general_purpose::STANDARD.encode(&[0u8; 16][..]);
        assert!(parse_key_env(&short).is_err());
    }
}