use serde::{Deserialize, Serialize, de::DeserializeOwned};
use worker::Env;
/// Client-side handle for a key/value state store served by a Durable Object.
///
/// The handle only records which binding to talk to and, optionally, the
/// `Env` used to resolve that binding; requests are issued by the methods in
/// the accompanying `impl` block.
#[derive(Clone)]
pub struct DurableObjectStateStore {
    /// Name of the Durable Object binding, passed to `Env::durable_object`.
    namespace: String,
    /// Worker environment used to resolve the binding; `None` until one is
    /// attached, in which case requests fail with `StateStoreError::NoEnvironment`.
    env: Option<Env>,
}
impl DurableObjectStateStore {
pub fn new(namespace: impl Into<String>) -> Self {
Self {
namespace: namespace.into(),
env: None,
}
}
pub fn from_env(env: &Env, binding: &str) -> worker::Result<Self> {
let _ = env.durable_object(binding)?;
Ok(Self {
namespace: binding.to_string(),
env: Some(env.clone()),
})
}
pub fn with_env(mut self, env: Env) -> Self {
self.env = Some(env);
self
}
pub async fn get<T: DeserializeOwned>(
&self,
namespace_id: &str,
key: &str,
) -> Result<Option<T>, StateStoreError> {
#[derive(Serialize)]
struct GetRequest<'a> {
key: &'a str,
}
#[derive(Deserialize)]
struct GetResponse<T> {
value: Option<T>,
}
let request = GetRequest { key };
let response: GetResponse<T> = self
.do_request(namespace_id, "/state/get", Some(&request))
.await?;
Ok(response.value)
}
pub async fn set<T: Serialize>(
&self,
namespace_id: &str,
key: &str,
value: &T,
) -> Result<(), StateStoreError> {
#[derive(Serialize)]
struct SetRequest<'a, T> {
key: &'a str,
value: &'a T,
}
let request = SetRequest { key, value };
self.do_request::<()>(namespace_id, "/state/set", Some(&request))
.await
}
pub async fn delete(&self, namespace_id: &str, key: &str) -> Result<bool, StateStoreError> {
#[derive(Serialize)]
struct DeleteRequest<'a> {
key: &'a str,
}
#[derive(Deserialize)]
struct DeleteResponse {
deleted: bool,
}
let request = DeleteRequest { key };
let response: DeleteResponse = self
.do_request(namespace_id, "/state/delete", Some(&request))
.await?;
Ok(response.deleted)
}
pub async fn list(
&self,
namespace_id: &str,
prefix: Option<&str>,
) -> Result<Vec<String>, StateStoreError> {
#[derive(Serialize)]
struct ListRequest<'a> {
prefix: Option<&'a str>,
}
#[derive(Deserialize)]
struct ListResponse {
keys: Vec<String>,
}
let request = ListRequest { prefix };
let response: ListResponse = self
.do_request(namespace_id, "/state/list", Some(&request))
.await?;
Ok(response.keys)
}
pub async fn clear(&self, namespace_id: &str) -> Result<(), StateStoreError> {
self.do_request::<()>(namespace_id, "/state/clear", None::<&()>)
.await
}
async fn do_request<T: for<'de> Deserialize<'de>>(
&self,
namespace_id: &str,
path: &str,
body: Option<&impl Serialize>,
) -> Result<T, StateStoreError> {
let env = self.env.as_ref().ok_or(StateStoreError::NoEnvironment)?;
let ns = env
.durable_object(&self.namespace)
.map_err(StateStoreError::Worker)?;
let id = ns
.id_from_name(namespace_id)
.map_err(StateStoreError::Worker)?;
let stub = id.get_stub().map_err(StateStoreError::Worker)?;
let mut init = worker::RequestInit::new();
init.with_method(worker::Method::Post);
if let Some(body) = body {
let json = serde_json::to_string(body).map_err(StateStoreError::Serialization)?;
init.with_body(Some(json.into()));
}
let url = format!("https://do-internal{path}");
let request =
worker::Request::new_with_init(&url, &init).map_err(StateStoreError::Worker)?;
let mut response = stub
.fetch_with_request(request)
.await
.map_err(StateStoreError::Worker)?;
let text = response.text().await.map_err(StateStoreError::Worker)?;
serde_json::from_str(&text).map_err(StateStoreError::Deserialization)
}
}
/// Errors produced while talking to the Durable Object state store.
#[derive(Debug)]
pub enum StateStoreError {
    /// A request was attempted on a store that has no `Env` attached.
    NoEnvironment,
    /// An error surfaced through the `worker` crate.
    Worker(worker::Error),
    /// The request body could not be encoded as JSON.
    Serialization(serde_json::Error),
    /// The response body could not be decoded from JSON.
    Deserialization(serde_json::Error),
}
impl std::fmt::Display for StateStoreError {
    /// Writes a short human-readable description of the error.
    ///
    /// Each wrapped error is rendered with its own `Display` output so the
    /// message stays user-facing; the `Debug` form remains available through
    /// `{:?}` on the whole value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoEnvironment => f.write_str("No environment set"),
            Self::Worker(e) => write!(f, "Worker error: {e}"),
            Self::Serialization(e) => write!(f, "Serialization error: {e}"),
            Self::Deserialization(e) => write!(f, "Deserialization error: {e}"),
        }
    }
}
impl std::error::Error for StateStoreError {
    /// Returns the wrapped lower-level error, or `None` for
    /// `NoEnvironment`, which wraps nothing.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::NoEnvironment => None,
            Self::Worker(inner) => Some(inner),
            // Both JSON variants wrap the same `serde_json::Error` type.
            Self::Serialization(inner) | Self::Deserialization(inner) => Some(inner),
        }
    }
}
impl From<worker::Error> for StateStoreError {
fn from(e: worker::Error) -> Self {
Self::Worker(e)
}
}
#[allow(dead_code)]
pub mod protocol {
use super::*;
#[derive(Debug, Serialize, Deserialize)]
pub struct GetRequest {
pub key: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GetResponse<T> {
pub value: Option<T>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SetRequest<T> {
pub key: String,
pub value: T,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct DeleteRequest {
pub key: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct DeleteResponse {
pub deleted: bool,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ListRequest {
pub prefix: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ListResponse {
pub keys: Vec<String>,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` records the binding name and leaves the environment unset.
    #[test]
    fn test_state_store_creation() {
        let DurableObjectStateStore { namespace, env } = DurableObjectStateStore::new("MCP_STATE");
        assert_eq!(namespace, "MCP_STATE");
        assert!(env.is_none());
    }

    /// `Display` produces the expected text for representative variants.
    #[test]
    fn test_state_store_error_display() {
        assert_eq!(
            StateStoreError::NoEnvironment.to_string(),
            "No environment set"
        );

        // Any JSON parse failure will do as a stand-in `serde_json::Error`.
        let parse_failure = serde_json::from_str::<()>("invalid").unwrap_err();
        let rendered = StateStoreError::Serialization(parse_failure).to_string();
        assert!(rendered.contains("Serialization error"));
    }
}