use std::sync::Arc;
use base64::Engine;
use chacha20poly1305::aead::{Aead, KeyInit};
use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce};
use dashmap::DashMap;
use indexmap::IndexMap;
use objectiveai_sdk::mcp::Connection;
use rand::RngCore;
use crate::session::Session;
/// Session state carried (encrypted) inside a session ID:
/// upstream URL → (header name → header value).
///
/// `build_payload` fills it with URLs sorted and, within each URL, header
/// names sorted, so the same upstreams/headers always serialize in the same
/// order regardless of the order they were supplied in.
pub type SessionPayload = IndexMap<String, IndexMap<String, String>>;
/// Envelope format version; written as the first byte of every envelope and
/// checked before decryption.
const VERSION: u8 = 1;
/// Length in bytes of the XChaCha20-Poly1305 (extended) nonce stored after the version byte.
const NONCE_LEN: usize = 24;
/// Length in bytes of the Poly1305 authentication tag appended to the ciphertext.
const TAG_LEN: usize = 16;
/// Registry of live sessions plus the symmetric key used to mint and open
/// session IDs.
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// manager (e.g. in a log line) never prints the encryption key or the session
/// IDs. Session IDs are the lookup handles for live sessions and encrypt the
/// upstream request headers, which may include credentials such as
/// `Authorization`.
pub struct SessionManager {
    /// Live sessions keyed by their encrypted, base62-encoded session ID.
    sessions: DashMap<String, Arc<Session>>,
    /// 32-byte key used both for nonce derivation (keyed BLAKE3) and for
    /// XChaCha20-Poly1305 when minting and decoding IDs.
    key: [u8; 32],
}

impl std::fmt::Debug for SessionManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionManager")
            // Only the count: the IDs themselves are sensitive handles.
            .field("sessions", &self.sessions.len())
            .field("key", &"<redacted>")
            .finish()
    }
}
impl SessionManager {
    /// Creates an empty manager that mints and decodes IDs with `key`.
    pub fn new(key: [u8; 32]) -> Self {
        let sessions = DashMap::new();
        Self { sessions, key }
    }

    /// Creates an empty manager whose key is filled from `rand::rng()`.
    ///
    /// The key is private and has no accessor, so IDs minted here can only be
    /// decoded by this same manager instance.
    pub fn with_ephemeral_key() -> Self {
        let mut fresh = [0u8; 32];
        rand::rng().fill_bytes(&mut fresh);
        Self::new(fresh)
    }

    /// Registers a session for the given upstream connections and returns its ID.
    ///
    /// The ID is the encrypted canonical payload (URLs → headers). Encryption is
    /// deterministic for a given payload and key, so adding the same payload
    /// again yields the same ID and replaces the previously stored session.
    pub fn add(
        &self,
        connections_with_headers: Vec<(Connection, IndexMap<String, String>)>,
    ) -> String {
        let payload = build_payload(&connections_with_headers);
        let session_id = encrypt_and_encode(&payload, &self.key);
        let by_name = build_by_name_map(
            connections_with_headers
                .into_iter()
                .map(|(connection, _headers)| connection)
                .collect(),
        );
        let session = Arc::new(Session::new(by_name, payload));
        self.sessions.insert(session_id.clone(), session);
        session_id
    }

    /// Returns a shared handle to the session registered under `session_id`, if any.
    pub fn get(&self, session_id: &str) -> Option<Arc<Session>> {
        let entry = self.sessions.get(session_id)?;
        Some(Arc::clone(entry.value()))
    }

    /// Unregisters the session under `session_id`, returning it if it was present.
    pub fn remove(&self, session_id: &str) -> Option<Arc<Session>> {
        let (_id, session) = self.sessions.remove(session_id)?;
        Some(session)
    }

    /// Decrypts a session ID back into its payload without consulting the registry.
    ///
    /// Returns `None` when `id` is not valid base62, is too short, carries an
    /// unknown version byte, fails authentication under this manager's key, or
    /// does not deserialize as a [`SessionPayload`].
    pub fn decode_session_id(&self, id: &str) -> Option<SessionPayload> {
        decode_with_key(id, &self.key)
    }

    /// Encrypts `payload` into a session ID under this manager's key without
    /// registering a session for it.
    pub fn mint_id(&self, payload: &SessionPayload) -> String {
        encrypt_and_encode(payload, &self.key)
    }
}
fn build_payload(
pairs: &[(Connection, IndexMap<String, String>)],
) -> SessionPayload {
let mut url_entries: Vec<(String, IndexMap<String, String>)> = pairs
.iter()
.map(|(c, headers)| {
let mut sorted: Vec<(&str, &str)> = headers
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect();
sorted.sort_by(|a, b| a.0.cmp(b.0));
let inner: IndexMap<String, String> = sorted
.into_iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect();
(c.url.clone(), inner)
})
.collect();
url_entries.sort_by(|a, b| a.0.cmp(&b.0));
let mut payload: SessionPayload = IndexMap::with_capacity(url_entries.len());
for (url, headers) in url_entries {
payload.insert(url, headers);
}
payload
}
/// Seals `payload` into an opaque base62 session ID.
///
/// Envelope layout before base62 encoding:
/// `[VERSION (1)] [nonce (NONCE_LEN)] [ciphertext || tag (TAG_LEN)]`.
///
/// The nonce is the first `NONCE_LEN` bytes of a BLAKE3 keyed hash of the
/// serialized payload, so encryption is deterministic: the same payload under
/// the same key always produces the same ID.
///
/// NOTE(review): the same 32-byte key keys both BLAKE3 (nonce derivation) and
/// XChaCha20-Poly1305. A domain-separated subkey would be more conventional,
/// but changing it would change every newly minted ID — confirm no caller
/// depends on ID stability before altering.
fn encrypt_and_encode(payload: &SessionPayload, key: &[u8; 32]) -> String {
    let plaintext = serde_json::to_vec(payload).expect("SessionPayload serializes");

    // Synthetic nonce: a keyed hash of the plaintext, truncated to nonce size.
    let digest = blake3::keyed_hash(key, &plaintext);
    let mut nonce_bytes = [0u8; NONCE_LEN];
    nonce_bytes.copy_from_slice(&digest.as_bytes()[..NONCE_LEN]);

    let sealed = XChaCha20Poly1305::new(Key::from_slice(key))
        .encrypt(XNonce::from_slice(&nonce_bytes), plaintext.as_slice())
        .expect("XChaCha20-Poly1305 encrypt is infallible for valid key/nonce");

    let envelope: Vec<u8> = std::iter::once(VERSION)
        .chain(nonce_bytes)
        .chain(sealed)
        .collect();
    base62_encode_bytes(&envelope)
}
/// Opens a session ID produced by `encrypt_and_encode` under `key`.
///
/// Returns `None` if `id` is not valid base62, the envelope is shorter than
/// version + nonce + tag, the version byte is not [`VERSION`], AEAD
/// authentication fails, or the plaintext is not a valid [`SessionPayload`].
///
/// NOTE(review): `id` presumably comes from clients, and base62 decoding is
/// quadratic in its length; consider bounding the accepted length upstream.
fn decode_with_key(id: &str, key: &[u8; 32]) -> Option<SessionPayload> {
    let envelope = base62_decode_bytes(id)?;
    // Check the minimum size first so the split below cannot panic.
    if envelope.len() < 1 + NONCE_LEN + TAG_LEN {
        return None;
    }
    let (&version, body) = envelope.split_first()?;
    if version != VERSION {
        return None;
    }
    let (nonce_bytes, sealed) = body.split_at(NONCE_LEN);
    let plaintext = XChaCha20Poly1305::new(Key::from_slice(key))
        .decrypt(XNonce::from_slice(nonce_bytes), sealed)
        .ok()?;
    serde_json::from_slice(&plaintext).ok()
}
/// Parses the `MCP_ENCRYPTION_KEY` value: standard (padded) base64 of exactly 32 bytes.
///
/// Surrounding whitespace is ignored. An empty or whitespace-only value yields
/// `Ok(None)` (no key configured).
///
/// # Errors
///
/// Returns a human-readable message when the value is not valid base64 or does
/// not decode to exactly 32 bytes.
pub fn parse_key_env(s: &str) -> Result<Option<[u8; 32]>, String> {
    let encoded = s.trim();
    if encoded.is_empty() {
        return Ok(None);
    }
    let bytes = match base64::engine::general_purpose::STANDARD.decode(encoded) {
        Ok(bytes) => bytes,
        Err(e) => return Err(format!("MCP_ENCRYPTION_KEY: not valid base64: {e}")),
    };
    let decoded_len = bytes.len();
    <[u8; 32]>::try_from(bytes).map(Some).map_err(|_| {
        format!(
            "MCP_ENCRYPTION_KEY: expected 32 bytes after base64-decode, got {decoded_len}"
        )
    })
}
/// Encodes `bytes` as base62 using the alphabet `0-9a-zA-Z`.
///
/// The bytes are treated as one big-endian integer. Each leading zero byte is
/// emitted as one leading `'0'` so `base62_decode_bytes` can restore them
/// exactly. Empty input encodes to the empty string.
fn base62_encode_bytes(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 62] =
        b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";

    let zero_prefix = bytes.iter().position(|&b| b != 0).unwrap_or(bytes.len());

    // Base-256 limbs, repeatedly divided by 62 in place; `head` skips the
    // quotient's leading zeros so the loop ends once the value reaches zero.
    let mut limbs: Vec<u32> = bytes[zero_prefix..].iter().map(|&b| u32::from(b)).collect();
    let mut head = 0;
    let mut symbols_rev: Vec<u8> = Vec::with_capacity(bytes.len() * 2);
    while head < limbs.len() {
        let mut carry = 0u32;
        for limb in limbs[head..].iter_mut() {
            let acc = carry * 256 + *limb;
            *limb = acc / 62;
            carry = acc % 62;
        }
        symbols_rev.push(ALPHABET[carry as usize]);
        while head < limbs.len() && limbs[head] == 0 {
            head += 1;
        }
    }

    std::iter::repeat(ALPHABET[0])
        .take(zero_prefix)
        .chain(symbols_rev.into_iter().rev())
        .map(char::from)
        .collect()
}
/// Decodes a base62 string (alphabet `0-9a-zA-Z`) produced by `base62_encode_bytes`.
///
/// Each leading `'0'` becomes one leading zero byte; the remaining digits form
/// a big-endian base-62 integer converted back to bytes. Returns `None` if any
/// character is outside the alphabet. The empty string decodes to an empty Vec.
fn base62_decode_bytes(s: &str) -> Option<Vec<u8>> {
    let zero_prefix = s.chars().take_while(|&c| c == '0').count();

    let mut limbs: Vec<u32> = s
        .chars()
        .skip(zero_prefix)
        .map(|c| match c {
            '0'..='9' => Some(u32::from(c) - u32::from('0')),
            'a'..='z' => Some(u32::from(c) - u32::from('a') + 10),
            'A'..='Z' => Some(u32::from(c) - u32::from('A') + 36),
            _ => None,
        })
        .collect::<Option<Vec<u32>>>()?;

    // Base-62 limbs, repeatedly divided by 256 in place; each remainder is the
    // next byte, least significant first.
    let mut head = 0;
    let mut bytes_rev: Vec<u8> = Vec::new();
    while head < limbs.len() {
        let mut carry = 0u32;
        for limb in limbs[head..].iter_mut() {
            let acc = carry * 62 + *limb;
            *limb = acc / 256;
            carry = acc % 256;
        }
        bytes_rev.push(carry as u8);
        while head < limbs.len() && limbs[head] == 0 {
            head += 1;
        }
    }

    let mut out = vec![0u8; zero_prefix];
    out.extend(bytes_rev.into_iter().rev());
    Some(out)
}
/// Indexes connections by the server name from their initialize result.
///
/// Names reported by exactly one connection are used as-is. A name reported by
/// several connections gets a `_{idx}` suffix, where `idx` is the connection's
/// position in `connections`. A suffixed key can still collide with another
/// upstream's raw name (e.g. `foo`, `foo`, `foo_1`). In that case a warning is
/// logged and the later connection replaces the earlier one.
fn build_by_name_map(
    connections: Vec<Connection>,
) -> IndexMap<String, Connection> {
    let mut occurrences: std::collections::HashMap<String, usize> =
        std::collections::HashMap::new();
    for connection in &connections {
        let name = connection.initialize_result.server_info.name.clone();
        *occurrences.entry(name).or_default() += 1;
    }

    let mut by_name: IndexMap<String, Connection> =
        IndexMap::with_capacity(connections.len());
    for (idx, connection) in connections.into_iter().enumerate() {
        let name = connection.initialize_result.server_info.name.clone();
        let duplicated = matches!(occurrences.get(&name), Some(&count) if count > 1);
        let key = if duplicated { format!("{name}_{idx}") } else { name };
        if by_name.contains_key(&key) {
            tracing::warn!(
                key = %key,
                "two upstreams produce the same prefix after disambiguation; later upstream wins",
            );
        }
        by_name.insert(key, connection);
    }
    by_name
}
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_payload() -> SessionPayload {
        let mut p: SessionPayload = IndexMap::new();
        let mut h_a: IndexMap<String, String> = IndexMap::new();
        h_a.insert("Authorization".into(), "Bearer secret-A".into());
        h_a.insert("Mcp-Session-Id".into(), "sid-A".into());
        h_a.insert("X-Tenant".into(), "tenant-1".into());
        p.insert("https://upstream-a.example/mcp".into(), h_a);
        let mut h_b: IndexMap<String, String> = IndexMap::new();
        h_b.insert("Mcp-Session-Id".into(), "sid-B".into());
        p.insert("https://upstream-b.example/mcp".into(), h_b);
        p
    }

    #[test]
    fn base62_round_trip() {
        for sample in [
            &b""[..],
            &b"a"[..],
            &b"\x00\x01\x02"[..],
            &b"hello world"[..],
            br#"{"http://127.0.0.1:1234":"abc123"}"#,
            &(0..=255u16).map(|b| b as u8).collect::<Vec<_>>()[..],
        ] {
            let encoded = base62_encode_bytes(sample);
            assert!(encoded.bytes().all(|b| (0x21..=0x7E).contains(&b)));
            let decoded = base62_decode_bytes(&encoded).expect("decode");
            assert_eq!(decoded, sample, "round-trip failed for {sample:?}");
        }
    }

    #[test]
    fn encrypt_decrypt_round_trip() {
        let key = [0x42u8; 32];
        let payload = sample_payload();
        let id = encrypt_and_encode(&payload, &key);
        let decoded = decode_with_key(&id, &key).expect("decode under same key");
        assert_eq!(decoded, payload);
        // `IndexMap` equality ignores order, so check ordering explicitly: the
        // canonical order is what makes minted IDs deterministic.
        assert!(decoded.keys().eq(payload.keys()));
        for (url, headers) in &payload {
            assert!(decoded[url].keys().eq(headers.keys()));
        }
    }

    #[test]
    fn encryption_is_deterministic_per_payload_and_key() {
        let key = [0x42u8; 32];
        let a = encrypt_and_encode(&sample_payload(), &key);
        let b = encrypt_and_encode(&sample_payload(), &key);
        assert_eq!(a, b);

        let mut other = sample_payload();
        other.insert("https://upstream-c.example/mcp".into(), IndexMap::new());
        assert_ne!(encrypt_and_encode(&other, &key), a);
        assert_ne!(encrypt_and_encode(&sample_payload(), &[0x43u8; 32]), a);
    }

    #[test]
    fn decode_with_wrong_key_returns_none() {
        let key_a = [0x11u8; 32];
        let key_b = [0x22u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key_a);
        assert!(decode_with_key(&id, &key_b).is_none());
    }

    #[test]
    fn tampered_envelope_is_rejected() {
        let key = [0x42u8; 32];
        let id = encrypt_and_encode(&sample_payload(), &key);
        let envelope = base62_decode_bytes(&id).expect("minted id is valid base62");
        assert_eq!(envelope[0], VERSION);

        // Flip one bit in the authentication tag.
        let mut tampered = envelope.clone();
        *tampered.last_mut().expect("non-empty envelope") ^= 0x01;
        assert!(decode_with_key(&base62_encode_bytes(&tampered), &key).is_none());

        // Unknown envelope version.
        let mut wrong_version = envelope.clone();
        wrong_version[0] = VERSION.wrapping_add(1);
        assert!(decode_with_key(&base62_encode_bytes(&wrong_version), &key).is_none());

        // Shorter than version + nonce + tag.
        let truncated = &envelope[..NONCE_LEN + TAG_LEN];
        assert!(decode_with_key(&base62_encode_bytes(truncated), &key).is_none());
    }

    #[test]
    fn decode_garbage_returns_none() {
        let key = [0x55u8; 32];
        assert!(decode_with_key("ABCdef123", &key).is_none());
        assert!(decode_with_key("", &key).is_none());
        assert!(decode_with_key("0", &key).is_none());
    }

    #[test]
    fn session_manager_mint_and_decode_share_key() {
        let manager = SessionManager::new([0x07u8; 32]);
        let payload = sample_payload();
        let id = manager.mint_id(&payload);
        assert_eq!(manager.decode_session_id(&id), Some(payload));
        // Minting does not register a session.
        assert!(manager.get(&id).is_none());
        // A manager with a different (random) key cannot open the ID.
        let other = SessionManager::with_ephemeral_key();
        assert!(other.decode_session_id(&id).is_none());
    }

    // NOTE(review): this mirrors `build_payload`'s sorting inline because a
    // `Connection` cannot be constructed here, so it does not exercise
    // `build_payload` itself. Consider extracting the sorting into a helper over
    // `(url, headers)` pairs that both `build_payload` and this test can call.
    #[test]
    fn payload_roundtrip_preserves_canonical_order() {
        let conn_a_url = "https://b.example/mcp".to_string();
        let conn_b_url = "https://a.example/mcp".to_string();
        let mut h_unsorted: IndexMap<String, String> = IndexMap::new();
        h_unsorted.insert("Z-Header".into(), "z".into());
        h_unsorted.insert("Authorization".into(), "Bearer".into());
        let pairs_unsorted: Vec<(String, IndexMap<String, String>)> = vec![
            (conn_a_url.clone(), h_unsorted.clone()),
            (conn_b_url.clone(), h_unsorted.clone()),
        ];
        let mut payload: SessionPayload = IndexMap::new();
        let mut url_entries: Vec<(String, IndexMap<String, String>)> = pairs_unsorted
            .into_iter()
            .map(|(url, headers)| {
                let mut sorted: Vec<(&str, &str)> =
                    headers.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
                sorted.sort_by(|a, b| a.0.cmp(b.0));
                let inner: IndexMap<String, String> = sorted
                    .into_iter()
                    .map(|(k, v)| (k.to_string(), v.to_string()))
                    .collect();
                (url, inner)
            })
            .collect();
        url_entries.sort_by(|a, b| a.0.cmp(&b.0));
        for (u, h) in url_entries {
            payload.insert(u, h);
        }
        let urls: Vec<&String> = payload.keys().collect();
        assert_eq!(urls, vec![&conn_b_url, &conn_a_url]);
        let inner = &payload[&conn_b_url];
        let inner_keys: Vec<&String> = inner.keys().collect();
        assert_eq!(inner_keys, vec!["Authorization", "Z-Header"]);
    }

    #[test]
    fn parse_key_env_round_trip() {
        let key = [0xAAu8; 32];
        let env = base64::engine::general_purpose::STANDARD.encode(key);
        let parsed = parse_key_env(&env).expect("parse").expect("Some");
        assert_eq!(parsed, key);
        assert!(parse_key_env("").unwrap().is_none());
        assert!(parse_key_env(" ").unwrap().is_none());
        assert!(parse_key_env("not-base64!@#").is_err());
        let short = base64::engine::general_purpose::STANDARD.encode([0u8; 16]);
        assert!(parse_key_env(&short).is_err());
    }
}