use crate::transport::http::ServiceUrl;
use once_cell::sync::OnceCell;
use std::sync::RwLock;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
/// State kept for one MCP session on the HTTP transport.
///
/// Every accessor on this type takes `&self`; mutation goes through the
/// interior-mutable fields (`OnceCell`, `RwLock`, `Notify`).
pub(super) struct McpSession {
/// Signalled by `notify_session_initialized` and awaited by `initialized`.
initialized: Notify,
/// Signalled by `notify_sse_initialized` and awaited by `sse_ready`.
sse_ready: Notify,
/// Service URL handed to `new`; exposed read-only through `url`.
url: ServiceUrl,
/// Session id; a `OnceCell`, so only the first `set_session_id` call takes effect.
session_id: OnceCell<uuid::Uuid>,
/// Most recent event id stored by `set_last_event_id`, `None` until the first call.
// NOTE(review): presumably the SSE `Last-Event-ID` used for stream resumption — confirm against the caller.
last_event_id: RwLock<Option<String>>,
/// Token handed to `new`; `cancellation_token` returns clones of it.
cancellation_token: CancellationToken,
}
impl McpSession {
    /// Creates a session for `url` with no session id and no recorded event id.
    ///
    /// `token` is stored as given; [`McpSession::cancellation_token`] returns clones of it.
    pub(super) fn new(url: ServiceUrl, token: CancellationToken) -> Self {
        Self {
            initialized: Notify::new(),
            sse_ready: Notify::new(),
            session_id: OnceCell::new(),
            last_event_id: RwLock::new(None),
            cancellation_token: token,
            url,
        }
    }

    /// Returns the service URL this session was created with.
    pub(super) fn url(&self) -> &ServiceUrl {
        &self.url
    }

    /// Returns a clone of the session's cancellation token.
    pub(super) fn cancellation_token(&self) -> CancellationToken {
        self.cancellation_token.clone()
    }

    /// Returns `true` once a session id has been stored.
    pub(super) fn has_session_id(&self) -> bool {
        self.session_id.get().is_some()
    }

    /// Returns the stored session id, or `None` if none has been set yet.
    pub(super) fn session_id(&self) -> Option<&uuid::Uuid> {
        self.session_id.get()
    }

    /// Stores the session id if none is set yet.
    ///
    /// The id is write-once: a later call leaves the first value in place and,
    /// when the `tracing` feature is enabled, logs that it was ignored.
    pub(super) fn set_session_id(&self, id: uuid::Uuid) {
        if let Err(_err) = self.session_id.set(id) {
            #[cfg(feature = "tracing")]
            tracing::info!("MCP Session Id already set");
        }
    }

    /// Returns a copy of the most recently stored event id, if any.
    pub(super) fn last_event_id(&self) -> Option<String> {
        // A poisoned lock only means another thread panicked while holding it.
        // The guarded `Option<String>` has no invariant that a panic could break,
        // so recover the guard instead of pretending no id was ever recorded.
        let guard = self
            .last_event_id
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        guard.clone()
    }

    /// Replaces the stored event id with `id`.
    pub(super) fn set_last_event_id(&self, id: String) {
        // Recover from poisoning for the same reason as in `last_event_id`;
        // silently dropping the update would lose the latest event id.
        let mut guard = self
            .last_event_id
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *guard = Some(id);
    }

    /// Signals that the session has been initialized.
    ///
    /// Uses `Notify::notify_one`, so a single call wakes at most one task
    /// waiting in [`McpSession::initialized`].
    #[inline]
    pub(super) fn notify_session_initialized(&self) {
        self.initialized.notify_one();
    }

    /// Signals that the SSE side is ready.
    ///
    /// Uses `Notify::notify_one`, so a single call wakes at most one task
    /// waiting in [`McpSession::sse_ready`].
    #[inline]
    pub(super) fn notify_sse_initialized(&self) {
        self.sse_ready.notify_one();
    }

    /// Waits for a notification sent by [`McpSession::notify_session_initialized`].
    #[inline]
    pub(super) async fn initialized(&self) {
        self.initialized.notified().await;
    }

    /// Waits for a notification sent by [`McpSession::notify_sse_initialized`].
    #[inline]
    pub(super) async fn sse_ready(&self) {
        self.sse_ready.notified().await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::http::HttpProto;
    use std::sync::Arc;
    use tokio::time::{Duration, timeout};
    use tokio_util::sync::CancellationToken;
    use uuid::Uuid;

    /// Builds a session with a fixed localhost URL and a fresh cancellation token.
    fn create_session() -> McpSession {
        let url = ServiceUrl {
            proto: HttpProto::Http,
            addr: "localhost",
            endpoint: "init",
        };
        let token = CancellationToken::new();
        McpSession::new(url, token)
    }

    #[tokio::test]
    async fn it_has_url() {
        let session = create_session();
        assert_eq!(session.url().addr, "localhost");
        assert_eq!(session.url().endpoint, "init");
    }

    #[tokio::test]
    async fn it_has_cancellable_and_synced_cancellation_token() {
        let session = create_session();
        let token = session.cancellation_token();
        assert!(!token.is_cancelled());
        assert!(!session.cancellation_token().is_cancelled());

        token.cancel();

        assert!(token.is_cancelled());
        // The handed-out clone must share state with the session's own token,
        // otherwise cancelling it would not stop the session.
        assert!(session.cancellation_token().is_cancelled());
    }

    #[tokio::test]
    async fn it_sets_and_gets_session_id() {
        let session = create_session();
        let id = Uuid::new_v4();
        assert!(!session.has_session_id());
        assert!(session.session_id().is_none());
        session.set_session_id(id);
        assert!(session.has_session_id());
        assert_eq!(session.session_id(), Some(&id));
    }

    #[test]
    fn it_returns_none_last_event_id_by_default() {
        let session = create_session();
        assert!(session.last_event_id().is_none());
    }

    #[test]
    fn it_sets_and_gets_last_event_id() {
        let session = create_session();
        session.set_last_event_id("abc-123".to_string());
        assert_eq!(session.last_event_id(), Some("abc-123".to_string()));
    }

    #[test]
    fn it_overwrites_last_event_id_on_each_set() {
        let session = create_session();
        session.set_last_event_id("first".to_string());
        session.set_last_event_id("second".to_string());
        assert_eq!(session.last_event_id(), Some("second".to_string()));
    }

    #[tokio::test]
    async fn it_guarantees_session_id_cannot_be_overwritten() {
        let session = create_session();
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        session.set_session_id(id1);
        session.set_session_id(id2);
        assert_eq!(session.session_id(), Some(&id1));
        assert_ne!(session.session_id(), Some(&id2));
    }

    #[tokio::test]
    async fn it_notifies_and_initialized() {
        let session = Arc::new(create_session());
        let handle = tokio::spawn({
            let session = session.clone();
            async move {
                session.initialized().await;
            }
        });
        // Give the spawned task a chance to start waiting before notifying.
        tokio::time::sleep(Duration::from_millis(10)).await;
        session.notify_session_initialized();
        timeout(Duration::from_secs(1), handle)
            .await
            .expect("waiter was not woken within the timeout")
            .expect("waiter task panicked");
    }

    #[tokio::test]
    async fn it_notifies_and_sse_ready() {
        let session = Arc::new(create_session());
        let handle = tokio::spawn({
            let session = session.clone();
            async move {
                session.sse_ready().await;
            }
        });
        // Give the spawned task a chance to start waiting before notifying.
        tokio::time::sleep(Duration::from_millis(10)).await;
        session.notify_sse_initialized();
        timeout(Duration::from_secs(1), handle)
            .await
            .expect("waiter was not woken within the timeout")
            .expect("waiter task panicked");
    }

    #[tokio::test]
    async fn it_keeps_a_notification_for_a_late_waiter() {
        let session = create_session();
        // Notify before anyone waits: `notify_one` stores a permit that the
        // next `notified().await` consumes immediately.
        session.notify_session_initialized();
        timeout(Duration::from_secs(1), session.initialized())
            .await
            .expect("stored notification was not delivered to a late waiter");
    }
}