use serde::{Deserialize, Serialize};
use worker::Env;
/// Client handle for an OAuth token store that lives behind a Durable Object binding.
///
/// Every operation resolves `namespace` through `env` at call time. While `env` is
/// `None` (e.g. after [`DurableObjectTokenStore::new`]) all token operations fail
/// with `TokenStoreError::NoEnvironment`.
#[derive(Clone)]
pub struct DurableObjectTokenStore {
    /// Name of the Durable Object binding passed to `Env::durable_object`.
    namespace: String,
    /// Worker environment used to resolve `namespace`; `None` until one is attached.
    env: Option<Env>,
}
/// Record stored alongside a hashed authorization code or refresh token.
///
/// Optional fields are omitted from the serialized JSON when they are `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OAuthTokenData {
    /// Kind of token this record describes (free-form string set by the caller).
    pub token_type: String,
    /// OAuth client the token was issued to.
    pub client_id: String,
    /// User the token was issued for, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
    /// Session the token belongs to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Granted scope string, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// PKCE code challenge recorded with an authorization code (RFC 7636 naming).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code_challenge: Option<String>,
    /// PKCE challenge method (e.g. `S256`) matching `code_challenge`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code_challenge_method: Option<String>,
    /// Redirect URI the authorization request used, if recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub redirect_uri: Option<String>,
    /// Creation timestamp. NOTE(review): unit/epoch not visible here — presumably
    /// milliseconds like `expires_in_ms`; confirm with the caller.
    pub created_at: u64,
    /// Expiry timestamp, same unit as `created_at` (TODO confirm).
    pub expires_at: u64,
    /// Usage flag stored with the token; presumably marks a consumed single-use
    /// token — confirm against the Durable Object implementation.
    pub used: bool,
    /// Arbitrary extra JSON attached by the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
impl DurableObjectTokenStore {
pub fn new(namespace: impl Into<String>) -> Self {
Self {
namespace: namespace.into(),
env: None,
}
}
pub fn from_env(env: &Env, binding: &str) -> worker::Result<Self> {
let _ = env.durable_object(binding)?;
Ok(Self {
namespace: binding.to_string(),
env: Some(env.clone()),
})
}
pub fn with_env(mut self, env: Env) -> Self {
self.env = Some(env);
self
}
pub async fn store_code(
&self,
code: &str,
data: &OAuthTokenData,
expires_in_ms: u64,
) -> Result<(), TokenStoreError> {
self.store_token_internal("code", code, data, expires_in_ms)
.await
}
pub async fn get_and_delete_code(
&self,
code: &str,
) -> Result<Option<OAuthTokenData>, TokenStoreError> {
self.get_and_delete_internal("code", code).await
}
pub async fn store_refresh_token(
&self,
token: &str,
data: &OAuthTokenData,
expires_in_ms: u64,
) -> Result<(), TokenStoreError> {
self.store_token_internal("refresh", token, data, expires_in_ms)
.await
}
pub async fn get_refresh_token(
&self,
token: &str,
) -> Result<Option<OAuthTokenData>, TokenStoreError> {
self.get_token_internal("refresh", token).await
}
pub async fn revoke_refresh_token(&self, token: &str) -> Result<bool, TokenStoreError> {
self.delete_token_internal("refresh", token).await
}
pub async fn revoke_all_for_user(&self, user_id: &str) -> Result<u64, TokenStoreError> {
#[derive(Serialize)]
struct RevokeRequest<'a> {
user_id: &'a str,
}
#[derive(Deserialize)]
struct RevokeResponse {
revoked: u64,
}
let request = RevokeRequest { user_id };
let response: RevokeResponse = self
.do_request(user_id, "/tokens/revoke-all-user", Some(&request))
.await?;
Ok(response.revoked)
}
pub async fn revoke_all_for_client(&self, client_id: &str) -> Result<u64, TokenStoreError> {
#[derive(Serialize)]
struct RevokeRequest<'a> {
client_id: &'a str,
}
#[derive(Deserialize)]
struct RevokeResponse {
revoked: u64,
}
let request = RevokeRequest { client_id };
let response: RevokeResponse = self
.do_request(client_id, "/tokens/revoke-all-client", Some(&request))
.await?;
Ok(response.revoked)
}
async fn store_token_internal(
&self,
token_type: &str,
token: &str,
data: &OAuthTokenData,
expires_in_ms: u64,
) -> Result<(), TokenStoreError> {
#[derive(Serialize)]
struct StoreRequest<'a> {
token_type: &'a str,
token_hash: String,
data: &'a OAuthTokenData,
expires_in_ms: u64,
}
let token_hash = hash_token(token);
let request = StoreRequest {
token_type,
token_hash,
data,
expires_in_ms,
};
self.do_request::<()>(&data.client_id, "/tokens/store", Some(&request))
.await
}
async fn get_token_internal(
&self,
token_type: &str,
token: &str,
) -> Result<Option<OAuthTokenData>, TokenStoreError> {
#[derive(Serialize)]
struct GetRequest<'a> {
token_type: &'a str,
token_hash: String,
}
#[derive(Deserialize)]
struct GetResponse {
data: Option<OAuthTokenData>,
}
let token_hash = hash_token(token);
let request = GetRequest {
token_type,
token_hash: token_hash.clone(),
};
let response: GetResponse = self
.do_request(&token_hash, "/tokens/get", Some(&request))
.await?;
Ok(response.data)
}
async fn get_and_delete_internal(
&self,
token_type: &str,
token: &str,
) -> Result<Option<OAuthTokenData>, TokenStoreError> {
#[derive(Serialize)]
struct GetDeleteRequest<'a> {
token_type: &'a str,
token_hash: String,
}
#[derive(Deserialize)]
struct GetDeleteResponse {
data: Option<OAuthTokenData>,
}
let token_hash = hash_token(token);
let request = GetDeleteRequest {
token_type,
token_hash: token_hash.clone(),
};
let response: GetDeleteResponse = self
.do_request(&token_hash, "/tokens/get-and-delete", Some(&request))
.await?;
Ok(response.data)
}
async fn delete_token_internal(
&self,
token_type: &str,
token: &str,
) -> Result<bool, TokenStoreError> {
#[derive(Serialize)]
struct DeleteRequest<'a> {
token_type: &'a str,
token_hash: String,
}
#[derive(Deserialize)]
struct DeleteResponse {
deleted: bool,
}
let token_hash = hash_token(token);
let request = DeleteRequest {
token_type,
token_hash: token_hash.clone(),
};
let response: DeleteResponse = self
.do_request(&token_hash, "/tokens/delete", Some(&request))
.await?;
Ok(response.deleted)
}
async fn do_request<T: for<'de> Deserialize<'de>>(
&self,
key: &str,
path: &str,
body: Option<&impl Serialize>,
) -> Result<T, TokenStoreError> {
let env = self.env.as_ref().ok_or(TokenStoreError::NoEnvironment)?;
let ns = env
.durable_object(&self.namespace)
.map_err(TokenStoreError::Worker)?;
let id = ns.id_from_name(key).map_err(TokenStoreError::Worker)?;
let stub = id.get_stub().map_err(TokenStoreError::Worker)?;
let mut init = worker::RequestInit::new();
init.with_method(worker::Method::Post);
if let Some(body) = body {
let json = serde_json::to_string(body).map_err(TokenStoreError::Serialization)?;
init.with_body(Some(json.into()));
}
let url = format!("https://do-internal{path}");
let request =
worker::Request::new_with_init(&url, &init).map_err(TokenStoreError::Worker)?;
let mut response = stub
.fetch_with_request(request)
.await
.map_err(TokenStoreError::Worker)?;
let text = response.text().await.map_err(TokenStoreError::Worker)?;
serde_json::from_str(&text).map_err(TokenStoreError::Deserialization)
}
}
/// Derives the storage key for a raw token.
///
/// The key is `tok_` followed by the lowercase hex SHA-256 digest of the token's UTF-8
/// bytes. Only this digest is sent to the Durable Object, never the raw token.
fn hash_token(token: &str) -> String {
    use sha2::{Digest, Sha256};
    use std::fmt::Write as _;

    let digest = Sha256::digest(token.as_bytes());
    digest.iter().fold(String::from("tok_"), |mut key, byte| {
        write!(key, "{byte:02x}").expect("writing to a String never fails");
        key
    })
}
/// Failure modes of [`DurableObjectTokenStore`] operations.
#[derive(Debug)]
pub enum TokenStoreError {
    /// The store has no worker environment attached, so the binding cannot be resolved.
    NoEnvironment,
    /// Resolving the binding, building the request, or calling the Durable Object failed.
    Worker(worker::Error),
    /// The outgoing request body could not be encoded as JSON.
    Serialization(serde_json::Error),
    /// The Durable Object's reply could not be decoded as the expected JSON.
    Deserialization(serde_json::Error),
}
impl std::fmt::Display for TokenStoreError {
    /// Writes a one-line message prefixed with the error category.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoEnvironment => f.write_str("No environment set"),
            // Debug formatting is kept for worker errors, presumably because it exposes
            // more detail for JS-originated errors than their Display form — confirm.
            Self::Worker(inner) => write!(f, "Worker error: {:?}", inner),
            Self::Serialization(inner) => write!(f, "Serialization error: {}", inner),
            Self::Deserialization(inner) => write!(f, "Deserialization error: {}", inner),
        }
    }
}
impl std::error::Error for TokenStoreError {
    /// Exposes the wrapped worker or JSON error. `NoEnvironment` has no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::NoEnvironment => None,
            Self::Worker(cause) => Some(cause),
            Self::Serialization(cause) | Self::Deserialization(cause) => Some(cause),
        }
    }
}
impl From<worker::Error> for TokenStoreError {
fn from(e: worker::Error) -> Self {
Self::Worker(e)
}
}
/// Owned wire types mirroring the JSON bodies exchanged with the token Durable Object.
///
/// The `DurableObjectTokenStore` methods serialize their own borrowed equivalents.
/// These owned versions are presumably for the Durable Object side (or tests) to
/// deserialize — confirm against the server implementation.
///
/// | Path                        | Request body                    | Response body       |
/// |-----------------------------|---------------------------------|---------------------|
/// | `/tokens/store`             | `StoreRequest`                  | not modelled here   |
/// | `/tokens/get`               | `GetRequest`                    | `GetResponse`       |
/// | `/tokens/get-and-delete`    | `GetRequest` (same shape)       | `GetResponse`       |
/// | `/tokens/delete`            | `GetRequest` (same shape)       | `DeleteResponse`    |
/// | `/tokens/revoke-all-user`   | `RevokeAllUserRequest`          | `RevokeAllResponse` |
/// | `/tokens/revoke-all-client` | `RevokeAllClientRequest`        | `RevokeAllResponse` |
// Nothing in this file constructs these types; they exist to describe the protocol,
// so the dead-code lint is silenced for the whole module.
#[allow(dead_code)]
pub mod protocol {
    use super::*;

    /// Body of `/tokens/store`.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct StoreRequest {
        /// Token category, e.g. `"code"` or `"refresh"`.
        pub token_type: String,
        /// `tok_`-prefixed hex SHA-256 of the raw token.
        pub token_hash: String,
        /// Record to store for the token.
        pub data: OAuthTokenData,
        /// Requested lifetime in milliseconds.
        pub expires_in_ms: u64,
    }

    /// Body of `/tokens/get`, `/tokens/get-and-delete` and `/tokens/delete`.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct GetRequest {
        /// Token category, e.g. `"code"` or `"refresh"`.
        pub token_type: String,
        /// `tok_`-prefixed hex SHA-256 of the raw token.
        pub token_hash: String,
    }

    /// Reply of `/tokens/get` and `/tokens/get-and-delete`; `data` is `None` when absent.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct GetResponse {
        pub data: Option<OAuthTokenData>,
    }

    /// Reply of `/tokens/delete`.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct DeleteResponse {
        pub deleted: bool,
    }

    /// Body of `/tokens/revoke-all-user`.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct RevokeAllUserRequest {
        pub user_id: String,
    }

    /// Body of `/tokens/revoke-all-client`.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct RevokeAllClientRequest {
        pub client_id: String,
    }

    /// Reply of both revoke-all endpoints: number of tokens revoked.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct RevokeAllResponse {
        pub revoked: u64,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_store_creation() {
        let store = DurableObjectTokenStore::new("MCP_OAUTH_TOKENS");
        assert_eq!(store.namespace, "MCP_OAUTH_TOKENS");
        assert!(store.env.is_none());
    }

    #[test]
    fn test_token_hashing() {
        let hash1 = hash_token("secret-token-123");
        let hash2 = hash_token("secret-token-123");
        let hash3 = hash_token("different-token");
        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
        assert!(hash1.starts_with("tok_"));
    }

    /// Pins the exact key format stored server-side, so a change to the hashing
    /// scheme (which would orphan every stored token) is caught.
    #[test]
    fn test_token_hash_format() {
        // SHA-256 of the empty string is a well-known test vector.
        assert_eq!(
            hash_token(""),
            "tok_e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        let hash = hash_token("secret-token-123");
        let hex = hash.strip_prefix("tok_").expect("tok_ prefix");
        assert_eq!(hex.len(), 64);
        assert!(hex
            .bytes()
            .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b)));
    }

    #[test]
    fn test_oauth_token_data_default() {
        let data = OAuthTokenData::default();
        assert!(data.token_type.is_empty());
        assert!(data.client_id.is_empty());
        assert!(data.user_id.is_none());
        assert!(!data.used);
    }

    #[test]
    fn test_oauth_token_data_omits_unset_optionals() {
        let value = serde_json::to_value(OAuthTokenData::default()).expect("serializable");
        let obj = value.as_object().expect("JSON object");
        for key in [
            "user_id",
            "session_id",
            "scope",
            "code_challenge",
            "code_challenge_method",
            "redirect_uri",
            "metadata",
        ] {
            assert!(!obj.contains_key(key), "{key} should be omitted when None");
        }
        assert_eq!(obj["used"], serde_json::Value::Bool(false));

        // Omitted optionals must still deserialize back to `None`.
        let back: OAuthTokenData = serde_json::from_value(value).expect("deserializable");
        assert!(back.user_id.is_none());
        assert!(back.metadata.is_none());
    }

    #[test]
    fn test_token_store_error_display() {
        let err = TokenStoreError::NoEnvironment;
        assert_eq!(err.to_string(), "No environment set");
        assert!(std::error::Error::source(&err).is_none());

        let json_err = serde_json::from_str::<u64>("not json").unwrap_err();
        let err = TokenStoreError::Deserialization(json_err);
        assert!(err.to_string().starts_with("Deserialization error: "));
        assert!(std::error::Error::source(&err).is_some());
    }
}