use klieo_core::{ServerOutbound, ServerOutboundError};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
/// Timeout handed to the outbound `roots/list` request issued by `RootsCache::refresh`.
const ROOTS_LIST_TIMEOUT: Duration = Duration::from_millis(10_000);
/// One root entry as reported by the client in its `roots/list` response.
///
/// The type is `#[non_exhaustive]`, so code outside this crate must build it through
/// [`Root::new`] rather than a struct literal.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct Root {
    /// URI of the root exactly as sent by the client (required when deserialising).
    pub uri: String,
    /// Optional human-readable label; left out of the serialised form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
impl Root {
    /// Builds a [`Root`] from a URI and an optional display name.
    ///
    /// `uri` accepts anything convertible into a `String` (e.g. `&str`).
    pub fn new(uri: impl Into<String>, name: Option<String>) -> Self {
        let uri: String = uri.into();
        Self { name, uri }
    }
}
/// Latest known root list for a session, published to subscribers through a `watch` channel.
pub(crate) struct RootsCache {
    /// Owns the current root list; `subscribe` hands out receivers derived from it.
    tx: watch::Sender<Vec<Root>>,
    /// Outbound channel used to send `roots/list` requests to the client.
    outbound: Arc<dyn ServerOutbound>,
}
impl RootsCache {
    /// Creates an empty cache that fetches roots through `outbound`.
    ///
    /// The snapshot starts as an empty list and stays that way until a refresh stores
    /// something different.
    pub(crate) fn new(outbound: Arc<dyn ServerOutbound>) -> Self {
        // The initial receiver is dropped straight away; consumers get their own through
        // `subscribe`. Updates go through `send_if_modified`, which (unlike `send`) still
        // stores the value when no receiver is alive.
        let (tx, _rx) = watch::channel(Vec::new());
        Self { tx, outbound }
    }

    /// Returns a receiver that observes updates to the cached root list.
    pub(crate) fn subscribe(&self) -> watch::Receiver<Vec<Root>> {
        self.tx.subscribe()
    }

    /// Returns a copy of the currently cached roots.
    pub(crate) fn snapshot(&self) -> Vec<Root> {
        self.tx.borrow().clone()
    }

    /// Fetches the root list from the client via `roots/list` and caches the result.
    ///
    /// Subscribers are notified only when the fetched list differs from the cached one. A
    /// refresh that returns the same roots therefore does not wake `changed()` waiters.
    /// Parsing of the response is best-effort (see `parse_roots_payload`), so an unexpected
    /// payload shape is not reported as an error.
    ///
    /// # Errors
    ///
    /// Returns the [`ServerOutboundError`] from the outbound request unchanged. In that case
    /// the cached snapshot is left untouched.
    pub(crate) async fn refresh(&self) -> Result<(), ServerOutboundError> {
        let response = self
            .outbound
            .outbound_request("roots/list", serde_json::Value::Null, ROOTS_LIST_TIMEOUT)
            .await?;
        let roots = parse_roots_payload(&response);
        // `send_replace` would bump the channel version (and wake every subscriber) even
        // when nothing changed; only publish genuine changes.
        self.tx.send_if_modified(|current| {
            if *current == roots {
                false
            } else {
                *current = roots;
                true
            }
        });
        Ok(())
    }
}
fn parse_roots_payload(response: &serde_json::Value) -> Vec<Root> {
response
.get("roots")
.and_then(|v| serde_json::from_value(v.clone()).ok())
.unwrap_or_default()
}
// Unit tests for `RootsCache`, `parse_roots_payload`, `Root` serialisation, and the
// `McpServer` accessors that read the cache.
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// `ServerOutbound` test double that replays one canned result for every request.
    struct MockOutbound {
        response: Result<serde_json::Value, ServerOutboundError>,
    }

    #[async_trait]
    impl ServerOutbound for MockOutbound {
        async fn outbound_request(
            &self,
            _method: &str,
            _params: serde_json::Value,
            _timeout: Duration,
        ) -> Result<serde_json::Value, ServerOutboundError> {
            match &self.response {
                Ok(v) => Ok(v.clone()),
                // The stored error only selects the failure path; the caller always sees
                // `Timeout`. NOTE(review): presumably because `ServerOutboundError` is not
                // `Clone` — confirm.
                Err(_) => Err(ServerOutboundError::Timeout),
            }
        }
    }

    /// Outbound stub that answers every request with `payload`.
    fn ok_outbound(payload: serde_json::Value) -> Arc<dyn ServerOutbound> {
        Arc::new(MockOutbound {
            response: Ok(payload),
        })
    }

    /// Outbound stub that fails every request with `ServerOutboundError::Timeout`.
    fn err_outbound() -> Arc<dyn ServerOutbound> {
        Arc::new(MockOutbound {
            response: Err(ServerOutboundError::Timeout),
        })
    }

    /// Minimal `ToolInvoker`: empty catalogue; every invocation fails with `UnknownTool`.
    struct StubInvoker;

    #[async_trait]
    impl klieo_core::tool::ToolInvoker for StubInvoker {
        fn catalogue(&self) -> Vec<klieo_core::llm::ToolDef> {
            Vec::new()
        }
        async fn invoke(
            &self,
            name: &str,
            _args: serde_json::Value,
            _ctx: klieo_core::tool::ToolCtx,
        ) -> Result<serde_json::Value, klieo_core::error::ToolError> {
            Err(klieo_core::error::ToolError::UnknownTool(name.into()))
        }
    }

    // A server built with `expose_tools` has no roots cache wired in, so `client_roots`
    // must fall back to an empty list.
    #[tokio::test]
    async fn client_roots_returns_empty_when_cache_missing() {
        let server = crate::McpServer::expose_tools(Arc::new(StubInvoker));
        assert!(
            server
                .stdio_session
                .get()
                .and_then(|s| s.roots_cache.get())
                .is_none(),
            "default-built server must not wire a roots cache"
        );
        assert_eq!(server.client_roots(), Vec::<Root>::new());
    }

    // Without a cache there is nothing to subscribe to, so the accessor yields `None`.
    #[tokio::test]
    async fn subscribe_root_changes_returns_none_when_cache_missing() {
        let server = crate::McpServer::expose_tools(Arc::new(StubInvoker));
        assert!(server.subscribe_root_changes().is_none());
    }

    // A new cache starts empty; a value pushed into the channel is visible both to an
    // existing subscriber and through `snapshot`.
    #[tokio::test]
    async fn cache_snapshot_returns_initial_empty_then_updated() {
        let cache = RootsCache::new(ok_outbound(serde_json::json!({})));
        assert_eq!(cache.snapshot(), Vec::<Root>::new());
        // Kept alive so the `send` below has a live receiver.
        let rx = cache.subscribe();
        let next = vec![Root {
            uri: "file:///workspace".into(),
            name: None,
        }];
        cache.tx.send(next.clone()).expect("watch send");
        assert_eq!(*rx.borrow(), next);
        assert_eq!(cache.snapshot(), next);
    }

    // `refresh` parses the `roots` array, including entries with and without `name`.
    #[tokio::test]
    async fn cache_refresh_deserialises_roots_array() {
        let payload = serde_json::json!({
            "roots": [
                {"uri": "file:///a"},
                {"uri": "file:///b", "name": "home"}
            ]
        });
        let cache = RootsCache::new(ok_outbound(payload));
        cache.refresh().await.expect("refresh ok");
        let snapshot = cache.snapshot();
        assert_eq!(snapshot.len(), 2);
        assert_eq!(snapshot[0].uri, "file:///a");
        assert_eq!(snapshot[0].name, None);
        assert_eq!(snapshot[1].uri, "file:///b");
        assert_eq!(snapshot[1].name.as_deref(), Some("home"));
    }

    // Outbound errors surface from `refresh` and leave the cached snapshot unchanged.
    #[tokio::test]
    async fn cache_refresh_propagates_outbound_error() {
        let cache = RootsCache::new(err_outbound());
        let outcome = cache.refresh().await;
        assert!(
            matches!(outcome, Err(ServerOutboundError::Timeout)),
            "transport-level failure must propagate verbatim; got {outcome:?}"
        );
        assert_eq!(
            cache.snapshot(),
            Vec::<Root>::new(),
            "failed refresh must not mutate the cached snapshot"
        );
    }

    // A response without a `roots` field parses to an empty list rather than an error.
    #[test]
    fn parse_roots_payload_handles_missing_field() {
        assert_eq!(
            parse_roots_payload(&serde_json::json!({})),
            Vec::<Root>::new()
        );
    }

    // An entry lacking the required `uri` produces no roots.
    #[test]
    fn parse_roots_payload_handles_malformed_entry() {
        let response = serde_json::json!({"roots": [{"name": "no-uri"}]});
        assert_eq!(parse_roots_payload(&response), Vec::<Root>::new());
    }

    // `name: None` is omitted from the JSON output instead of being emitted as `null`.
    #[test]
    fn root_serialises_without_null_name() {
        let root = Root {
            uri: "file:///a".into(),
            name: None,
        };
        let serialised = serde_json::to_value(&root).expect("encode");
        assert_eq!(serialised, serde_json::json!({"uri": "file:///a"}));
    }
}