1use std::{
9 collections::HashMap,
10 fmt,
11 sync::{Arc, Mutex},
12 time::{Duration, SystemTime},
13};
14
15use axum::http::{HeaderMap, HeaderName};
16use serde_json::Value;
17use thiserror::Error;
18use uuid::Uuid;
19
20use crate::config::{SessionConfig, SessionFallbackScope};
21
/// Granularity at which a resolved session identifier is shared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionScope {
    /// The identifier came from the request (header or body metadata) or from
    /// the manager-wide agent fallback id, so later requests can reuse it.
    Agent,
    /// The identifier was generated for a single request by the request
    /// fallback, so no later request resolves to the same session.
    Request,
}

impl SessionScope {
    /// Returns the lowercase name of the scope (`"agent"` or `"request"`).
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Agent => "agent",
            Self::Request => "request",
        }
    }
}

impl fmt::Display for SessionScope {
    /// Writes the same text as [`SessionScope::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (unlike `write_str`) honors width, fill, alignment and
        // precision flags, so `format!("{:>8}", scope)` pads as expected.
        f.pad(self.as_str())
    }
}
47
/// Why a stored session is no longer usable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionExpirationReason {
    /// The time since the session was last used reached the configured idle TTL.
    IdleTtl,
    /// The current time reached the session's absolute `expires_at` deadline.
    MaxTtl,
    /// The session's request count reached the configured maximum.
    MaxRequests,
}
55
56#[derive(Debug, Clone, PartialEq, Eq)]
58pub struct SessionContext {
59 pub session_key: String,
60 pub model_id: String,
61 pub agent_session_id: String,
62 pub scope: SessionScope,
63 pub created_at: SystemTime,
64 pub last_used_at: SystemTime,
65 pub expires_at: SystemTime,
66 pub request_count: u64,
67 pub attested_model_public_key: Option<String>,
68 pub attestation_tee_provider: Option<String>,
69 pub attestation_tdx_debug: Option<bool>,
70 pub attestation_nvidia_verified: Option<String>,
71 pub verified_at: Option<SystemTime>,
72}
73
74#[derive(Debug, Clone, PartialEq, Eq)]
76pub struct SessionResolution {
77 pub session: SessionContext,
78 pub created: bool,
79 pub replaced_expired: Option<SessionExpirationReason>,
80}
81
/// Attestation values to record on an existing session.
///
/// Each field is copied onto the matching `attest*` / `verified_at` field of
/// the stored session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttestedModelState {
    /// Public key of the attested model.
    pub model_public_key: String,
    /// TEE provider name, when known.
    pub tee_provider: Option<String>,
    /// TDX debug flag, when known.
    pub tdx_debug: Option<bool>,
    /// NVIDIA verification value, stored verbatim as a string.
    pub nvidia_verified: String,
    /// When the attestation was verified.
    pub verified_at: SystemTime,
}
91
92#[derive(Debug, Clone, Copy)]
94pub struct SessionRequest<'a> {
95 pub model_id: &'a str,
96 pub headers: &'a HeaderMap,
97 pub body: Option<&'a Value>,
98}
99
100impl<'a> SessionRequest<'a> {
101 pub fn new(model_id: &'a str, headers: &'a HeaderMap) -> Self {
103 Self {
104 model_id,
105 headers,
106 body: None,
107 }
108 }
109
110 pub fn with_body(mut self, body: &'a Value) -> Self {
112 self.body = Some(body);
113 self
114 }
115}
116
117#[derive(Debug, Clone)]
119pub struct SessionManager {
120 config: SessionConfig,
121 sessions: Arc<Mutex<HashMap<String, SessionContext>>>,
122 agent_fallback_session_id: Arc<str>,
123}
124
125impl SessionManager {
126 pub fn new(config: SessionConfig) -> Self {
128 Self {
129 config,
130 sessions: Arc::new(Mutex::new(HashMap::new())),
131 agent_fallback_session_id: Arc::from(Uuid::new_v4().to_string()),
132 }
133 }
134
135 pub fn get_or_create(
138 &self,
139 request: SessionRequest<'_>,
140 ) -> Result<SessionResolution, SessionError> {
141 self.get_or_create_at(request, SystemTime::now())
142 }
143
144 pub fn get_or_create_at(
146 &self,
147 request: SessionRequest<'_>,
148 now: SystemTime,
149 ) -> Result<SessionResolution, SessionError> {
150 if request.model_id.trim().is_empty() {
151 return Err(SessionError::InvalidModelId);
152 }
153
154 let resolved = self.resolve_identifier(request)?;
155 let session_key = session_key(request.model_id, &resolved.agent_session_id);
156 let mut sessions = self.lock_sessions();
157 let replaced_expired = match sessions.get(&session_key) {
158 Some(existing) => self.expiration_reason(existing, now),
159 None => None,
160 };
161
162 if replaced_expired.is_some() {
163 sessions.remove(&session_key);
164 }
165
166 if let Some(existing) = sessions.get_mut(&session_key) {
167 existing.request_count += 1;
168 existing.last_used_at = now;
169 return Ok(SessionResolution {
170 session: existing.clone(),
171 created: false,
172 replaced_expired: None,
173 });
174 }
175
176 let context = SessionContext::new(
177 request.model_id,
178 resolved.agent_session_id,
179 resolved.scope,
180 now,
181 &self.config,
182 );
183 sessions.insert(session_key, context.clone());
184
185 Ok(SessionResolution {
186 session: context,
187 created: true,
188 replaced_expired,
189 })
190 }
191
192 pub fn set_attested_model_state(
194 &self,
195 session_key: &str,
196 state: AttestedModelState,
197 ) -> Result<SessionContext, SessionError> {
198 self.set_attested_model_state_at(session_key, state, SystemTime::now())
199 }
200
201 pub fn set_attested_model_state_at(
203 &self,
204 session_key: &str,
205 state: AttestedModelState,
206 now: SystemTime,
207 ) -> Result<SessionContext, SessionError> {
208 let mut sessions = self.lock_sessions();
209 let expired = sessions
210 .get(session_key)
211 .and_then(|session| self.expiration_reason(session, now));
212
213 if let Some(reason) = expired {
214 sessions.remove(session_key);
215 return Err(SessionError::SessionExpired { reason });
216 }
217
218 let session =
219 sessions
220 .get_mut(session_key)
221 .ok_or_else(|| SessionError::SessionNotFound {
222 session_key: session_key.to_owned(),
223 })?;
224 session.attested_model_public_key = Some(state.model_public_key);
225 session.attestation_tee_provider = state.tee_provider;
226 session.attestation_tdx_debug = state.tdx_debug;
227 session.attestation_nvidia_verified = Some(state.nvidia_verified);
228 session.verified_at = Some(state.verified_at);
229
230 Ok(session.clone())
231 }
232
233 pub fn cleanup_expired(&self) -> usize {
235 self.cleanup_expired_at(SystemTime::now())
236 }
237
238 pub fn cleanup_expired_at(&self, now: SystemTime) -> usize {
240 let mut sessions = self.lock_sessions();
241 let before = sessions.len();
242 sessions.retain(|_, session| self.expiration_reason(session, now).is_none());
243 before - sessions.len()
244 }
245
246 pub fn len(&self) -> usize {
248 self.lock_sessions().len()
249 }
250
251 pub fn is_empty(&self) -> bool {
253 self.len() == 0
254 }
255
256 fn resolve_identifier(
258 &self,
259 request: SessionRequest<'_>,
260 ) -> Result<ResolvedSessionIdentifier, SessionError> {
261 if let Some(value) = self.explicit_identifier(&request)? {
262 return Ok(ResolvedSessionIdentifier::agent(value));
263 }
264
265 match self.config.fallback_scope {
266 SessionFallbackScope::Agent => Ok(ResolvedSessionIdentifier::agent(
267 self.agent_fallback_session_id.to_string(),
268 )),
269 SessionFallbackScope::Request => Ok(ResolvedSessionIdentifier {
270 agent_session_id: Uuid::new_v4().to_string(),
271 scope: SessionScope::Request,
272 }),
273 SessionFallbackScope::Disabled => Err(SessionError::MissingSessionIdentifier),
274 }
275 }
276
277 fn explicit_identifier(
279 &self,
280 request: &SessionRequest<'_>,
281 ) -> Result<Option<String>, SessionError> {
282 if let Some(value) =
283 header_identifier(request.headers, &self.config.headers.incoming_session_id)?
284 {
285 return Ok(Some(value));
286 }
287
288 Ok(metadata_identifier(request.body, "session_id"))
289 }
290
291 fn expiration_reason(
293 &self,
294 session: &SessionContext,
295 now: SystemTime,
296 ) -> Option<SessionExpirationReason> {
297 if session.request_count >= self.config.max_requests {
298 return Some(SessionExpirationReason::MaxRequests);
299 }
300
301 if now >= session.expires_at {
302 return Some(SessionExpirationReason::MaxTtl);
303 }
304
305 if elapsed_since(session.last_used_at, now) >= self.config.idle_ttl {
306 return Some(SessionExpirationReason::IdleTtl);
307 }
308
309 None
310 }
311
312 fn lock_sessions(&self) -> std::sync::MutexGuard<'_, HashMap<String, SessionContext>> {
314 self.sessions
315 .lock()
316 .unwrap_or_else(std::sync::PoisonError::into_inner)
317 }
318}
319
320#[derive(Debug, Clone)]
322struct ResolvedSessionIdentifier {
323 agent_session_id: String,
324 scope: SessionScope,
325}
326
327impl ResolvedSessionIdentifier {
328 fn agent(agent_session_id: String) -> Self {
330 Self {
331 agent_session_id,
332 scope: SessionScope::Agent,
333 }
334 }
335}
336
337impl SessionContext {
338 fn new(
340 model_id: &str,
341 agent_session_id: String,
342 scope: SessionScope,
343 now: SystemTime,
344 config: &SessionConfig,
345 ) -> Self {
346 let session_key = session_key(model_id, &agent_session_id);
347 Self {
348 session_key,
349 model_id: model_id.to_owned(),
350 agent_session_id,
351 scope,
352 created_at: now,
353 last_used_at: now,
354 expires_at: now + config.max_ttl,
355 request_count: 1,
356 attested_model_public_key: None,
357 attestation_tee_provider: None,
358 attestation_tdx_debug: None,
359 attestation_nvidia_verified: None,
360 verified_at: None,
361 }
362 }
363}
364
365#[derive(Debug, Error, PartialEq, Eq)]
367pub enum SessionError {
368 #[error("request model id must not be empty")]
369 InvalidModelId,
370 #[error("request does not include a session identifier and session fallback is disabled")]
371 MissingSessionIdentifier,
372 #[error("configured session header name {header:?} is invalid")]
373 InvalidHeaderName { header: String },
374 #[error("session header {header} contains non-UTF-8 data")]
375 InvalidHeaderValue { header: String },
376 #[error("session {session_key} was not found")]
377 SessionNotFound { session_key: String },
378 #[error("session expired before attestation state could be stored: {reason:?}")]
379 SessionExpired { reason: SessionExpirationReason },
380}
381
382fn header_identifier(
384 headers: &HeaderMap,
385 configured_name: &str,
386) -> Result<Option<String>, SessionError> {
387 let name = HeaderName::from_bytes(configured_name.as_bytes()).map_err(|_| {
388 SessionError::InvalidHeaderName {
389 header: configured_name.to_owned(),
390 }
391 })?;
392
393 let Some(value) = headers.get(&name) else {
394 return Ok(None);
395 };
396 let value = value
397 .to_str()
398 .map_err(|_| SessionError::InvalidHeaderValue {
399 header: configured_name.to_owned(),
400 })?;
401 Ok(non_empty_string(value))
402}
403
404fn metadata_identifier(body: Option<&Value>, key: &str) -> Option<String> {
406 body.and_then(|body| body.get("metadata"))
407 .and_then(|metadata| metadata.get(key))
408 .and_then(Value::as_str)
409 .and_then(non_empty_string)
410}
411
/// Trims `value` and returns it as an owned string, or `None` if nothing
/// remains after trimming.
fn non_empty_string(value: &str) -> Option<String> {
    match value.trim() {
        "" => None,
        trimmed => Some(trimmed.to_owned()),
    }
}
417
/// Builds the session table key by joining the model id and the session id
/// with a single `:`.
fn session_key(model_id: &str, agent_session_id: &str) -> String {
    [model_id, agent_session_id].join(":")
}
422
/// Time from `start` to `now`, or zero when `now` is earlier than `start`
/// (for example after a backwards clock adjustment).
fn elapsed_since(start: SystemTime, now: SystemTime) -> Duration {
    match now.duration_since(start) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    }
}
427
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;
    use serde_json::json;

    /// Header the default test configuration is expected to read session ids from.
    const DEFAULT_SESSION_HEADER: &str = "X-Venice-Proxy-Session-Id";

    /// Baseline configuration: 10s idle TTL, 30s max TTL, 3 requests per
    /// session, per-request fallback, default header names.
    fn test_config() -> SessionConfig {
        SessionConfig {
            idle_ttl: Duration::from_secs(10),
            max_ttl: Duration::from_secs(30),
            max_requests: 3,
            fallback_scope: SessionFallbackScope::Request,
            headers: Default::default(),
        }
    }

    /// Manager built from the baseline configuration.
    fn manager() -> SessionManager {
        SessionManager::new(test_config())
    }

    /// Manager built from the baseline configuration after `customize` ran on it.
    fn manager_with(customize: impl FnOnce(&mut SessionConfig)) -> SessionManager {
        let mut config = test_config();
        customize(&mut config);
        SessionManager::new(config)
    }

    /// Deterministic timestamp `seconds` after the Unix epoch.
    fn now(seconds: u64) -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(seconds)
    }

    /// Body-less request for `model_id` carrying `headers`.
    fn request<'a>(model_id: &'a str, headers: &'a HeaderMap) -> SessionRequest<'a> {
        SessionRequest::new(model_id, headers)
    }

    /// Header map containing a single `name: session_id` entry.
    fn headers_with(name: &'static str, session_id: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(name, HeaderValue::from_static(session_id));
        headers
    }

    /// Header map carrying `session_id` in the default session header.
    fn session_headers(session_id: &'static str) -> HeaderMap {
        headers_with(DEFAULT_SESSION_HEADER, session_id)
    }

    /// Attestation payload shared by the attestation tests.
    fn attested_state(verified_at: SystemTime) -> AttestedModelState {
        AttestedModelState {
            model_public_key: "model-public-key".to_owned(),
            tee_provider: Some("phala".to_owned()),
            tdx_debug: Some(false),
            nvidia_verified: "ignored".to_owned(),
            verified_at,
        }
    }

    #[test]
    fn creates_new_agent_session_from_incoming_session_id_header() {
        let manager = manager();
        let headers = session_headers("chat-1");

        let resolved = manager
            .get_or_create_at(request("model-a", &headers), now(0))
            .expect("session should resolve");

        assert!(resolved.created);
        assert_eq!(resolved.replaced_expired, None);
        assert_eq!(resolved.session.session_key, "model-a:chat-1");
        assert_eq!(resolved.session.model_id, "model-a");
        assert_eq!(resolved.session.agent_session_id, "chat-1");
        assert_eq!(resolved.session.scope, SessionScope::Agent);
        assert_eq!(resolved.session.request_count, 1);
    }

    #[test]
    fn reuses_existing_session_from_configured_header() {
        let manager = manager_with(|config| {
            config.headers.incoming_session_id = "X-Custom-Session-Id".to_owned();
        });
        let headers = headers_with("X-Custom-Session-Id", "configured-chat");

        let first = manager
            .get_or_create_at(request("model-a", &headers), now(0))
            .expect("first request should create");
        let second = manager
            .get_or_create_at(request("model-a", &headers), now(5))
            .expect("second request should reuse");

        assert!(first.created);
        assert!(!second.created);
        assert_eq!(second.session.session_key, first.session.session_key);
        assert_eq!(second.session.request_count, 2);
        assert_eq!(second.session.last_used_at, now(5));
        assert_eq!(manager.len(), 1);
    }

    #[test]
    fn configured_header_wins_over_metadata() {
        let manager = manager();
        let headers = session_headers("header-session");
        let body = json!({ "metadata": { "session_id": "body-session" } });

        let resolved = manager
            .get_or_create_at(
                SessionRequest::new("model-a", &headers).with_body(&body),
                now(0),
            )
            .expect("session should resolve");

        assert_eq!(resolved.session.session_key, "model-a:header-session");
    }

    #[test]
    fn metadata_session_id_is_used_when_headers_are_missing() {
        let manager = manager();
        let headers = HeaderMap::new();
        let body = json!({ "metadata": { "session_id": "metadata-session" } });

        let resolved = manager
            .get_or_create_at(
                SessionRequest::new("model-a", &headers).with_body(&body),
                now(0),
            )
            .expect("session should resolve");

        assert_eq!(resolved.session.session_key, "model-a:metadata-session");
        assert_eq!(resolved.session.scope, SessionScope::Agent);
    }

    #[test]
    fn blank_model_id_is_rejected_without_creating_session() {
        let manager = manager();
        let headers = session_headers("chat-1");

        let error = manager
            .get_or_create_at(request("   ", &headers), now(0))
            .expect_err("whitespace-only model id should be rejected");

        assert_eq!(error, SessionError::InvalidModelId);
        assert!(manager.is_empty());
    }

    #[test]
    fn idle_ttl_expiration_discards_old_session_and_creates_fresh_one() {
        let manager = manager();
        let headers = session_headers("chat-1");

        let first = manager
            .get_or_create_at(request("model-a", &headers), now(0))
            .expect("first request should create");
        let second = manager
            .get_or_create_at(request("model-a", &headers), now(10))
            .expect("idle-expired request should recreate");

        assert!(second.created);
        assert_eq!(
            second.replaced_expired,
            Some(SessionExpirationReason::IdleTtl)
        );
        assert_eq!(second.session.session_key, first.session.session_key);
        assert_eq!(second.session.request_count, 1);
        assert_eq!(second.session.created_at, now(10));
    }

    #[test]
    fn max_ttl_expiration_discards_old_session_and_creates_fresh_one() {
        let manager = manager_with(|config| {
            config.idle_ttl = Duration::from_secs(20);
            config.max_ttl = Duration::from_secs(30);
        });
        let headers = session_headers("chat-1");

        let first = manager
            .get_or_create_at(request("model-a", &headers), now(0))
            .expect("first request should create");
        manager
            .get_or_create_at(request("model-a", &headers), now(15))
            .expect("within idle ttl should reuse");
        let third = manager
            .get_or_create_at(request("model-a", &headers), now(30))
            .expect("max-ttl-expired request should recreate");

        assert!(third.created);
        assert_eq!(
            third.replaced_expired,
            Some(SessionExpirationReason::MaxTtl)
        );
        assert_eq!(third.session.session_key, first.session.session_key);
        assert_eq!(third.session.request_count, 1);
        assert_eq!(third.session.created_at, now(30));
    }

    #[test]
    fn max_request_expiration_discards_old_session_and_creates_fresh_one() {
        let manager = manager();
        let headers = session_headers("chat-1");

        manager
            .get_or_create_at(request("model-a", &headers), now(0))
            .expect("first request should create");
        manager
            .get_or_create_at(request("model-a", &headers), now(1))
            .expect("second request should reuse");
        let third = manager
            .get_or_create_at(request("model-a", &headers), now(2))
            .expect("third request should reuse and reach max");
        let fourth = manager
            .get_or_create_at(request("model-a", &headers), now(3))
            .expect("fourth request should recreate");

        assert!(!third.created);
        assert_eq!(third.session.request_count, 3);
        assert!(fourth.created);
        assert_eq!(
            fourth.replaced_expired,
            Some(SessionExpirationReason::MaxRequests)
        );
        assert_eq!(fourth.session.request_count, 1);
    }

    #[test]
    fn request_fallback_creates_distinct_request_scoped_sessions() {
        let manager = manager();
        let headers = HeaderMap::new();

        let first = manager
            .get_or_create_at(request("model-a", &headers), now(0))
            .expect("fallback should create");
        let second = manager
            .get_or_create_at(request("model-a", &headers), now(1))
            .expect("fallback should create again");

        assert!(first.created);
        assert!(second.created);
        assert_eq!(first.session.scope, SessionScope::Request);
        assert_eq!(second.session.scope, SessionScope::Request);
        assert_ne!(
            first.session.agent_session_id,
            second.session.agent_session_id
        );
        assert_eq!(manager.len(), 2);
    }

    #[test]
    fn agent_fallback_reuses_generated_agent_scoped_session() {
        let manager = manager_with(|config| {
            config.fallback_scope = SessionFallbackScope::Agent;
        });
        let headers = HeaderMap::new();

        let first = manager
            .get_or_create_at(request("model-a", &headers), now(0))
            .expect("fallback should create");
        let second = manager
            .get_or_create_at(request("model-a", &headers), now(1))
            .expect("fallback should reuse");

        assert!(first.created);
        assert!(!second.created);
        assert_eq!(first.session.scope, SessionScope::Agent);
        assert_eq!(
            first.session.agent_session_id,
            second.session.agent_session_id
        );
        assert_eq!(second.session.request_count, 2);
    }

    #[test]
    fn disabled_fallback_returns_clear_error_without_creating_session() {
        let manager = manager_with(|config| {
            config.fallback_scope = SessionFallbackScope::Disabled;
        });
        let headers = HeaderMap::new();

        let error = manager
            .get_or_create_at(request("model-a", &headers), now(0))
            .expect_err("missing session id should fail when fallback is disabled");

        assert_eq!(error, SessionError::MissingSessionIdentifier);
        assert_eq!(
            error.to_string(),
            "request does not include a session identifier and session fallback is disabled"
        );
        assert!(manager.is_empty());
    }

    #[test]
    fn cleanup_removes_expired_sessions_and_keeps_valid_sessions() {
        let manager = manager();
        let headers_a = session_headers("chat-a");
        let headers_b = session_headers("chat-b");

        manager
            .get_or_create_at(request("model-a", &headers_a), now(0))
            .expect("session a should create");
        manager
            .get_or_create_at(request("model-a", &headers_b), now(15))
            .expect("session b should create");

        let removed = manager.cleanup_expired_at(now(20));

        assert_eq!(removed, 1);
        assert_eq!(manager.len(), 1);
        let reused_b = manager
            .get_or_create_at(request("model-a", &headers_b), now(21))
            .expect("session b should remain valid");
        assert!(!reused_b.created);
    }

    #[test]
    fn stores_attested_model_state_on_existing_unexpired_session() {
        let manager = manager();
        let headers = session_headers("chat-1");
        let session = manager
            .get_or_create_at(request("model-a", &headers), now(0))
            .expect("session should create")
            .session;

        let updated = manager
            .set_attested_model_state_at(&session.session_key, attested_state(now(1)), now(1))
            .expect("attestation state should update");

        assert_eq!(
            updated.attested_model_public_key.as_deref(),
            Some("model-public-key")
        );
        assert_eq!(updated.attestation_tee_provider.as_deref(), Some("phala"));
        assert_eq!(updated.attestation_tdx_debug, Some(false));
        assert_eq!(
            updated.attestation_nvidia_verified.as_deref(),
            Some("ignored")
        );
        assert_eq!(updated.verified_at, Some(now(1)));
    }

    #[test]
    fn attested_model_state_on_unknown_session_reports_not_found() {
        let manager = manager();

        let error = manager
            .set_attested_model_state_at("model-a:missing", attested_state(now(0)), now(0))
            .expect_err("unknown session key should fail");

        assert_eq!(
            error,
            SessionError::SessionNotFound {
                session_key: "model-a:missing".to_owned(),
            }
        );
    }

    #[test]
    fn attested_model_state_on_expired_session_reports_expiry_and_evicts() {
        let manager = manager();
        let headers = session_headers("chat-1");
        let session = manager
            .get_or_create_at(request("model-a", &headers), now(0))
            .expect("session should create")
            .session;

        // 10s of inactivity equals the idle TTL, so the session counts as expired.
        let error = manager
            .set_attested_model_state_at(&session.session_key, attested_state(now(10)), now(10))
            .expect_err("expired session should reject attestation state");

        assert_eq!(
            error,
            SessionError::SessionExpired {
                reason: SessionExpirationReason::IdleTtl,
            }
        );
        assert!(manager.is_empty());
    }
}