use serde::{Deserialize, Serialize};
#[cfg(target_arch = "wasm32")]
use turbomcp_transport_streamable::SessionId;
#[cfg(target_arch = "wasm32")]
use turbomcp_transport_streamable::SessionStore;
use turbomcp_transport_streamable::{Session, StoredEvent};
use worker::Env;
#[cfg(target_arch = "wasm32")]
use worker::Stub;
/// Session store that delegates MCP session and event storage to a
/// Cloudflare Durable Object namespace.
///
/// Requests for a given session are sent to the object whose id is derived
/// from that session's id.
// `dead_code` is allowed because the only reader of `env` (and the request
// helpers) are compiled for `wasm32` targets only.
#[allow(dead_code)]
#[derive(Clone)]
pub struct DurableObjectSessionStore {
    /// Name of the Durable Object binding in the Worker environment.
    namespace: String,
    /// Worker environment; `None` until supplied through `from_env` or `with_env`.
    env: Option<Env>,
}
impl DurableObjectSessionStore {
    /// Creates a store that will use the Durable Object binding `namespace`.
    ///
    /// No [`Env`] is attached yet: call [`with_env`](Self::with_env) before
    /// issuing requests, otherwise every operation fails with
    /// "No environment set".
    pub fn new(namespace: impl Into<String>) -> Self {
        Self {
            namespace: namespace.into(),
            env: None,
        }
    }

    /// Creates a store from a Worker environment and a Durable Object binding name.
    ///
    /// # Errors
    ///
    /// Returns the error from `env.durable_object(binding)` when the binding
    /// cannot be resolved, so misconfiguration surfaces at construction time
    /// instead of on the first request.
    pub fn from_env(env: &Env, binding: &str) -> worker::Result<Self> {
        // Resolved only to validate the binding; the namespace is looked up again per request.
        let _ = env.durable_object(binding)?;
        Ok(Self {
            namespace: binding.to_string(),
            env: Some(env.clone()),
        })
    }

    /// Attaches (or replaces) the Worker environment used to reach the namespace.
    pub fn with_env(mut self, env: Env) -> Self {
        self.env = Some(env);
        self
    }

    /// Resolves the Durable Object stub for `session_id`.
    ///
    /// The object id is derived from the session id with `id_from_name`.
    ///
    /// # Errors
    ///
    /// Fails with `RustError("No environment set")` when no [`Env`] is attached.
    /// Otherwise it propagates the namespace lookup or stub creation error.
    #[cfg(target_arch = "wasm32")]
    fn get_stub(&self, session_id: &str) -> worker::Result<Stub> {
        let env = self
            .env
            .as_ref()
            .ok_or_else(|| worker::Error::RustError("No environment set".into()))?;
        let ns = env.durable_object(&self.namespace)?;
        let id = ns.id_from_name(session_id)?;
        id.get_stub()
    }

    /// Sends a JSON `POST` to `path` on the session's Durable Object and decodes the reply as `T`.
    ///
    /// `body`, when present, is serialized with `serde_json` and sent as the request body.
    ///
    /// # Errors
    ///
    /// * [`DoSessionError::Worker`]: stub lookup, request construction or fetch
    ///   failed, or the object answered with a non-2xx status. The status and
    ///   response text are included in the message.
    /// * [`DoSessionError::Serialization`]: `body` could not be encoded.
    /// * [`DoSessionError::Deserialization`]: a 2xx reply could not be decoded as `T`.
    #[cfg(target_arch = "wasm32")]
    async fn do_request<T: for<'de> Deserialize<'de>>(
        &self,
        session_id: &str,
        path: &str,
        body: Option<&impl Serialize>,
    ) -> Result<T, DoSessionError> {
        let stub = self.get_stub(session_id).map_err(DoSessionError::Worker)?;

        let mut init = worker::RequestInit::new();
        init.with_method(worker::Method::Post);
        if let Some(body) = body {
            let json = serde_json::to_string(body).map_err(DoSessionError::Serialization)?;
            init.with_body(Some(json.into()));
        }

        // `do-internal` is a placeholder host; presumably the object routes on `path`
        // alone. Confirm against the Durable Object implementation.
        let url = format!("https://do-internal{path}");
        let request = worker::Request::new_with_init(&url, &init)?;
        let mut response = stub.fetch_with_request(request).await?;
        let status = response.status_code();
        let text = response.text().await?;

        // Report error statuses as such, instead of trying to parse an error body as `T`.
        if !(200..300).contains(&status) {
            return Err(DoSessionError::Worker(worker::Error::RustError(format!(
                "Durable Object request {path} failed with status {status}: {text}"
            ))));
        }

        // Treat an empty body as JSON `null`, so unit replies (`T = ()`) succeed
        // when the object responds without a payload.
        let payload = if text.trim().is_empty() {
            "null"
        } else {
            text.as_str()
        };
        serde_json::from_str(payload).map_err(DoSessionError::Deserialization)
    }
}
/// Errors produced by [`DurableObjectSessionStore`] operations.
// NOTE(review): `dead_code` is presumably allowed because most code that
// constructs these variants is compiled for `wasm32` only.
#[allow(dead_code)]
#[derive(Debug)]
pub enum DoSessionError {
    /// Failure reported by the `worker` runtime (environment, stub, request or fetch).
    Worker(worker::Error),
    /// A request body could not be encoded as JSON.
    Serialization(serde_json::Error),
    /// A response body could not be decoded from JSON.
    Deserialization(serde_json::Error),
}
impl std::fmt::Display for DoSessionError {
    /// Writes a human-readable message prefixed with the failure category.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // `worker::Error` implements `Display` (required by `source()` below);
            // use it for user-facing text instead of the `Debug` representation.
            Self::Worker(e) => write!(f, "Worker error: {e}"),
            Self::Serialization(e) => write!(f, "Serialization error: {e}"),
            Self::Deserialization(e) => write!(f, "Deserialization error: {e}"),
        }
    }
}
impl std::error::Error for DoSessionError {
    /// Returns the wrapped runtime or JSON error as the underlying cause.
    ///
    /// Every variant wraps an inner error, so this never returns `None`.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::Worker(inner) => inner,
            Self::Serialization(inner) | Self::Deserialization(inner) => inner,
        };
        Some(cause)
    }
}
impl From<worker::Error> for DoSessionError {
fn from(e: worker::Error) -> Self {
Self::Worker(e)
}
}
/// [`SessionStore`] implementation in which every operation is a JSON `POST`
/// to the Durable Object that owns the session id.
#[cfg(target_arch = "wasm32")]
impl SessionStore for DurableObjectSessionStore {
    type Error = DoSessionError;

    /// Generates a new session id, sends the initial session to
    /// `/session/create`, and returns the id.
    async fn create(&self) -> Result<SessionId, Self::Error> {
        let session_id = SessionId::new();
        let initial = Session::new(session_id.clone());
        self.do_request::<()>(session_id.as_str(), "/session/create", Some(&initial))
            .await?;
        Ok(session_id)
    }

    /// Reads the session from `/session/get`. No request body is sent.
    async fn get(&self, id: &SessionId) -> Result<Option<Session>, Self::Error> {
        #[derive(Deserialize)]
        struct GetResponse {
            session: Option<Session>,
        }

        let reply = self
            .do_request::<GetResponse>(id.as_str(), "/session/get", None::<&()>)
            .await?;
        Ok(reply.session)
    }

    /// Sends the full session to `/session/update` on the object keyed by `session.id`.
    async fn update(&self, session: &Session) -> Result<(), Self::Error> {
        let target = session.id.as_str();
        self.do_request::<()>(target, "/session/update", Some(session))
            .await
    }

    /// Sends `event` to `/event/store` for the given session.
    async fn store_event(&self, id: &SessionId, event: StoredEvent) -> Result<(), Self::Error> {
        let target = id.as_str();
        self.do_request::<()>(target, "/event/store", Some(&event))
            .await
    }

    /// Asks `/event/replay` for the events following `last_event_id`.
    async fn replay_from(
        &self,
        id: &SessionId,
        last_event_id: &str,
    ) -> Result<Vec<StoredEvent>, Self::Error> {
        #[derive(Serialize)]
        struct ReplayRequest<'a> {
            last_event_id: &'a str,
        }

        #[derive(Deserialize)]
        struct ReplayResponse {
            events: Vec<StoredEvent>,
        }

        let reply = self
            .do_request::<ReplayResponse>(
                id.as_str(),
                "/event/replay",
                Some(&ReplayRequest { last_event_id }),
            )
            .await?;
        Ok(reply.events)
    }

    /// Calls `/session/destroy` on the session's object. No request body is sent.
    async fn destroy(&self, id: &SessionId) -> Result<(), Self::Error> {
        let target = id.as_str();
        self.do_request::<()>(target, "/session/destroy", None::<&()>)
            .await
    }

    /// No-op: ignores `timeout_ms` and always reports `0` removed sessions.
    ///
    /// NOTE(review): expiry presumably has to be handled inside the Durable
    /// Object itself; confirm.
    async fn cleanup_expired(&self, timeout_ms: u64) -> Result<u64, Self::Error> {
        let _ = timeout_ms;
        Ok(0)
    }
}
/// Wire types for the JSON endpoints that the `SessionStore` implementation in
/// this file calls on the Durable Object.
#[allow(dead_code)]
pub mod protocol {
    use super::*;

    /// Body of `/session/create`: the complete initial session.
    pub type CreateRequest = Session;

    /// Placeholder request type for `/session/get`; the client currently sends no body.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct GetRequest;

    /// Reply from `/session/get`.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct GetResponse {
        /// The stored session; presumably `None` when the object holds no session.
        pub session: Option<Session>,
    }

    /// Body of `/session/update`: the complete session to store.
    pub type UpdateRequest = Session;

    /// Body of `/event/store`: a single event.
    pub type StoreEventRequest = StoredEvent;

    /// Body of `/event/replay`.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ReplayRequest {
        /// Id of the last event the client has already received.
        pub last_event_id: String,
    }

    /// Reply from `/event/replay`.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ReplayResponse {
        /// Events to replay to the client.
        pub events: Vec<StoredEvent>,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `source()` into scope without clashing with names from `super::*`.
    use std::error::Error as _;

    /// `new` stores the namespace and leaves the environment unset.
    #[test]
    fn test_session_store_creation() {
        let store = DurableObjectSessionStore::new("MCP_SESSIONS");
        assert_eq!(store.namespace, "MCP_SESSIONS");
        assert!(store.env.is_none());
    }

    /// `new` also accepts an owned `String` through `impl Into<String>`.
    #[test]
    fn test_session_store_accepts_owned_namespace() {
        let store = DurableObjectSessionStore::new(String::from("OTHER_SESSIONS"));
        assert_eq!(store.namespace, "OTHER_SESSIONS");
        assert!(store.env.is_none());
    }

    /// The `Serialization` variant is labelled in its `Display` output.
    #[test]
    fn test_do_session_error_display() {
        let err = DoSessionError::Serialization(serde_json::from_str::<()>("invalid").unwrap_err());
        assert!(err.to_string().contains("Serialization error"));
    }

    /// The `Deserialization` variant is labelled and exposes its cause via `source()`.
    #[test]
    fn test_do_session_error_deserialization_display_and_source() {
        let err =
            DoSessionError::Deserialization(serde_json::from_str::<u32>("\"x\"").unwrap_err());
        assert!(err.to_string().starts_with("Deserialization error: "));
        assert!(err.source().is_some());
    }

    /// `?` conversion from `worker::Error` yields the `Worker` variant.
    #[test]
    fn test_worker_error_converts_via_from() {
        let err: DoSessionError = worker::Error::RustError("boom".into()).into();
        assert!(matches!(err, DoSessionError::Worker(_)));
        assert!(err.to_string().starts_with("Worker error"));
        assert!(err.source().is_some());
    }
}