use std::collections::BTreeMap;
use std::sync::Arc;
use rmcp::ErrorData;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use crate::errors::{McpServerError, map_error};
use crate::state::SessionState;
use crate::tools::common::current_tab;
/// Which storage area of the current tab a storage tool operates on:
/// `local` selects the tab's local storage, `session` selects its session storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StorageKind {
    // Variants are described in the enum doc above rather than with `///`:
    // per-variant docs would make schemars emit a `oneOf` instead of a plain string enum.
    Local,
    Session,
}
fn pick_storage(tab: &zendriver::Tab, kind: StorageKind) -> zendriver::Storage {
match kind {
StorageKind::Local => tab.local_storage(),
StorageKind::Session => tab.session_storage(),
}
}
/// Arguments for reading from a tab's storage. `kind` selects local or session storage.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct StorageGetInput {
    pub kind: StorageKind,
    /// Key to read. When omitted, every entry in the selected storage area is returned.
    #[serde(default)]
    pub key: Option<String>,
}
/// Result of a storage read.
#[derive(Debug, Serialize, JsonSchema)]
pub struct StorageGetOutput {
    /// Key/value pairs found, ordered by key. Empty when a single requested key is absent.
    pub values: BTreeMap<String, String>,
}
/// Reads entries from the current tab's local or session storage.
///
/// With `input.key` set, only that key is looked up and the returned map holds at most one
/// entry (none if the key is absent). Without a key, every entry reported by the storage
/// area is returned.
///
/// The session lock is held for the whole call, including the storage round-trips.
///
/// # Errors
///
/// Propagates the error from [`current_tab`] (for example when no browser is open) and maps
/// any storage failure through [`map_error`].
pub async fn storage_get(
    state: Arc<Mutex<SessionState>>,
    input: StorageGetInput,
) -> Result<StorageGetOutput, ErrorData> {
    let session = state.lock().await;
    let tab = current_tab(&session).await?;
    let storage = pick_storage(&tab, input.kind);

    let values: BTreeMap<String, String> = if let Some(key) = input.key {
        let found = storage
            .get(&key)
            .await
            .map_err(|e| map_error(McpServerError::from(e)))?;
        // An absent key yields an empty map rather than an error.
        found.map(|value| (key, value)).into_iter().collect()
    } else {
        let everything = storage
            .get_all()
            .await
            .map_err(|e| map_error(McpServerError::from(e)))?;
        everything.into_iter().collect()
    };

    Ok(StorageGetOutput { values })
}
/// Arguments for writing one entry to a tab's storage. `kind` selects local or session
/// storage.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct StorageSetInput {
    pub kind: StorageKind,
    /// Key to write.
    pub key: String,
    /// String value to store under `key`.
    pub value: String,
}
/// Result of a storage write.
#[derive(Debug, Serialize, JsonSchema)]
pub struct StorageSetOutput {
    /// Always `true` when the write succeeded; failures are reported as errors instead.
    pub ok: bool,
}
pub async fn storage_set(
state: Arc<Mutex<SessionState>>,
input: StorageSetInput,
) -> Result<StorageSetOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
let storage = pick_storage(&tab, input.kind);
storage
.set(&input.key, &input.value)
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
Ok(StorageSetOutput { ok: true })
}
/// Arguments for removing one entry from a tab's storage. `kind` selects local or session
/// storage.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct StorageDeleteInput {
    pub kind: StorageKind,
    /// Key to remove.
    pub key: String,
}
/// Result of a storage removal.
#[derive(Debug, Serialize, JsonSchema)]
pub struct StorageDeleteOutput {
    /// Always `true` when the remove call succeeded. It does NOT indicate whether the key
    /// existed beforehand.
    pub deleted: bool,
}
pub async fn storage_delete(
state: Arc<Mutex<SessionState>>,
input: StorageDeleteInput,
) -> Result<StorageDeleteOutput, ErrorData> {
let s = state.lock().await;
let tab = current_tab(&s).await?;
let storage = pick_storage(&tab, input.kind);
storage
.remove(&input.key)
.await
.map_err(|e| map_error(McpServerError::from(e)))?;
Ok(StorageDeleteOutput { deleted: true })
}
/// Arguments for removing every entry from a tab's storage. `kind` selects local or
/// session storage.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct StorageClearInput {
    pub kind: StorageKind,
}
/// Result of clearing a storage area.
#[derive(Debug, Serialize, JsonSchema)]
pub struct StorageClearOutput {
    /// Always `true` when the clear succeeded; failures are reported as errors instead.
    pub ok: bool,
}
/// Removes every entry from the current tab's local or session storage.
///
/// # Errors
///
/// Propagates the error from [`current_tab`] (for example when no browser is open) and maps
/// any storage failure through [`map_error`].
pub async fn storage_clear(
    state: Arc<Mutex<SessionState>>,
    input: StorageClearInput,
) -> Result<StorageClearOutput, ErrorData> {
    let session = state.lock().await;
    let tab = current_tab(&session).await?;
    let area = pick_storage(&tab, input.kind);

    area.clear()
        .await
        .map_err(|e| map_error(McpServerError::from(e)))?;

    Ok(StorageClearOutput { ok: true })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a brand-new session state; the tests below rely on it having no open browser.
    fn fresh() -> Arc<Mutex<SessionState>> {
        Arc::new(Mutex::new(SessionState::new()))
    }

    /// Asserts that `err` steers the caller toward `browser_open`, both in its
    /// human-readable message and in the structured `suggested_next` data field.
    ///
    /// `#[track_caller]` makes failures point at the calling test rather than this helper.
    #[track_caller]
    fn assert_suggests_browser_open(err: &ErrorData) {
        assert!(err.message.contains("browser_open"), "msg: {}", err.message);
        let data = err.data.as_ref().expect("data populated");
        assert_eq!(data["suggested_next"], "browser_open");
    }

    #[tokio::test]
    async fn storage_get_with_no_browser_suggests_browser_open() {
        let err = storage_get(
            fresh(),
            StorageGetInput {
                kind: StorageKind::Local,
                key: None,
            },
        )
        .await
        .expect_err("must error without an open browser");
        assert_suggests_browser_open(&err);
    }

    #[tokio::test]
    async fn storage_set_with_no_browser_suggests_browser_open() {
        let err = storage_set(
            fresh(),
            StorageSetInput {
                kind: StorageKind::Local,
                key: "k".into(),
                value: "v".into(),
            },
        )
        .await
        .expect_err("must error without an open browser");
        assert_suggests_browser_open(&err);
    }

    #[tokio::test]
    async fn storage_delete_with_no_browser_suggests_browser_open() {
        let err = storage_delete(
            fresh(),
            StorageDeleteInput {
                kind: StorageKind::Session,
                key: "k".into(),
            },
        )
        .await
        .expect_err("must error without an open browser");
        assert_suggests_browser_open(&err);
    }

    #[tokio::test]
    async fn storage_clear_with_no_browser_suggests_browser_open() {
        let err = storage_clear(
            fresh(),
            StorageClearInput {
                kind: StorageKind::Local,
            },
        )
        .await
        .expect_err("must error without an open browser");
        assert_suggests_browser_open(&err);
    }

    #[test]
    fn storage_kind_round_trips_serde_snake_case() {
        let local: StorageKind =
            serde_json::from_value(serde_json::json!("local")).expect("parse local");
        assert_eq!(local, StorageKind::Local);
        let session: StorageKind =
            serde_json::from_value(serde_json::json!("session")).expect("parse session");
        assert_eq!(session, StorageKind::Session);
        assert_eq!(
            serde_json::to_value(StorageKind::Local).unwrap(),
            serde_json::json!("local")
        );
        assert_eq!(
            serde_json::to_value(StorageKind::Session).unwrap(),
            serde_json::json!("session")
        );
    }
}