use std::collections::HashMap;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use hickory_proto::rr::RecordType;
use koi_common::mdns_protocol::{RegisterPayload, RegistrationResult};
use koi_common::types::{ServiceRecord, META_QUERY};
use koi_mcp::{KoiSource, ResourceChange, SourceError};
use koi_mdns::{LeasePolicy, MdnsEvent};
use serde_json::{json, Value};
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crate::DaemonCores;
/// Capacity of the broadcast channel created in `CoreSource::new` for
/// `ResourceChange` notifications.
const CHANGE_CHANNEL_CAPACITY: usize = 256;
/// Lease applied by `register` when the payload does not specify `lease_secs`.
const HEARTBEAT_LEASE: Duration = Duration::from_secs(90);
/// Grace period used by `register` for every heartbeat lease policy it builds,
/// whether the lease itself is the default or caller-supplied.
const HEARTBEAT_GRACE: Duration = Duration::from_secs(30);
/// `KoiSource` implementation that answers requests directly from the
/// daemon's in-process capability cores.
pub struct CoreSource {
// Capability cores; the ones read in this file (mdns, health, dns, runtime,
// mdns_snapshot) are optional and checked on every request.
cores: DaemonCores,
// Start instant; `unified_status` derives `uptime_secs` from it.
started_at: Instant,
// Reported verbatim as `http_bind` by `unified_status`.
http_bind: String,
// Sender side of the change channel; `change_stream` subscribes to it and
// `spawn_change_pump` is given a clone to publish on.
changes: broadcast::Sender<ResourceChange>,
}
impl CoreSource {
pub fn new(
cores: DaemonCores,
started_at: Instant,
http_bind: String,
cancel: CancellationToken,
) -> Self {
let (changes, _) = broadcast::channel(CHANGE_CHANNEL_CAPACITY);
spawn_change_pump(&cores, changes.clone(), cancel);
Self {
cores,
started_at,
http_bind,
changes,
}
}
}
/// Builds the error returned when a request targets a capability whose core
/// is not present on this daemon.
fn disabled(capability: &str) -> SourceError {
    let message = format!("the '{capability}' capability is disabled on this daemon");
    SourceError(message)
}
/// Wraps any error's string form in a [`SourceError`].
fn source_err(err: impl ToString) -> SourceError {
    SourceError(err.to_string())
}

#[async_trait]
impl KoiSource for CoreSource {
    /// The in-process source always reports itself as available.
    async fn is_available(&self) -> bool {
        true
    }

    /// Collects the services seen for `service_type` during `window`.
    ///
    /// Subscribes to the given type (or `META_QUERY` when none is given) and
    /// gathers `Found`/`Resolved` records until the window elapses or the
    /// subscription ends. Records are keyed by name, so a later event for the
    /// same name replaces the earlier one. `Removed` events are ignored.
    ///
    /// # Errors
    /// Fails when the mdns capability is disabled or the subscription cannot
    /// be created.
    async fn browse(
        &self,
        service_type: Option<String>,
        window: Duration,
    ) -> Result<Vec<ServiceRecord>, SourceError> {
        let Some(mdns) = self.cores.mdns.as_ref() else {
            return Err(disabled("mdns"));
        };
        let query = match service_type.as_deref() {
            Some(requested) => requested,
            None => META_QUERY,
        };
        let subscription = mdns.subscribe_type(query).await.map_err(source_err)?;

        let deadline = tokio::time::Instant::now() + window;
        let mut by_name: HashMap<String, ServiceRecord> = HashMap::new();
        // The loop ends on either a closed subscription (`Ok(None)`) or the
        // deadline (`Err(_)`).
        while let Ok(Some(event)) = tokio::time::timeout_at(deadline, subscription.recv()).await {
            if let MdnsEvent::Found(record) | MdnsEvent::Resolved(record) = event {
                by_name.insert(record.name.clone(), record);
            }
        }
        Ok(by_name.into_values().collect())
    }

    /// Resolves a single service instance through the mdns core.
    async fn resolve(&self, instance: String) -> Result<ServiceRecord, SourceError> {
        let Some(mdns) = self.cores.mdns.as_ref() else {
            return Err(disabled("mdns"));
        };
        mdns.resolve(&instance).await.map_err(source_err)
    }

    /// Registers a service under a heartbeat lease.
    ///
    /// The lease is `payload.lease_secs` when given, otherwise
    /// `HEARTBEAT_LEASE`; the grace period is always `HEARTBEAT_GRACE`.
    ///
    /// # Errors
    /// Fails when mdns is disabled, when `lease_secs` is zero, or when the
    /// core rejects the registration.
    async fn register(&self, payload: RegisterPayload) -> Result<RegistrationResult, SourceError> {
        let Some(mdns) = self.cores.mdns.as_ref() else {
            return Err(disabled("mdns"));
        };
        let lease = match payload.lease_secs {
            Some(0) => return Err(SourceError("lease_secs must be greater than zero".into())),
            Some(secs) => Duration::from_secs(secs),
            None => HEARTBEAT_LEASE,
        };
        let policy = LeasePolicy::Heartbeat {
            lease,
            grace: HEARTBEAT_GRACE,
        };
        mdns.register_with_policy(payload, policy, None)
            .map_err(source_err)
    }

    /// Removes the registration identified by `id`.
    async fn unregister(&self, id: String) -> Result<(), SourceError> {
        let Some(mdns) = self.cores.mdns.as_ref() else {
            return Err(disabled("mdns"));
        };
        mdns.unregister(&id).map_err(source_err)
    }

    /// Sends a heartbeat for registration `id`, discarding the core's success
    /// value.
    async fn heartbeat(&self, id: String) -> Result<(), SourceError> {
        let Some(mdns) = self.cores.mdns.as_ref() else {
            return Err(disabled("mdns"));
        };
        mdns.heartbeat(&id).map_err(source_err)?;
        Ok(())
    }

    /// Reports version, platform, uptime, bind address and the status of each
    /// assembled capability as one JSON object.
    async fn unified_status(&self) -> Result<Value, SourceError> {
        let assembled = koi_compose::status::assemble_capabilities(&self.cores).await;
        let capabilities = assembled
            .into_iter()
            .map(|capability| capability.status)
            .collect::<Vec<_>>();
        let status = json!({
            "version": env!("CARGO_PKG_VERSION"),
            "platform": std::env::consts::OS,
            "uptime_secs": self.started_at.elapsed().as_secs(),
            "daemon": true,
            "http_bind": self.http_bind,
            "capabilities": capabilities,
        });
        Ok(status)
    }

    /// Serializes the health core's current snapshot.
    async fn health_status(&self) -> Result<Value, SourceError> {
        let Some(health) = self.cores.health.as_ref() else {
            return Err(disabled("health"));
        };
        let snapshot = health.core().snapshot().await;
        serde_json::to_value(snapshot).map_err(source_err)
    }

    /// Lists the names known to the dns core as `{ "names": [...] }`.
    async fn dns_list(&self) -> Result<Value, SourceError> {
        let Some(dns) = self.cores.dns.as_ref() else {
            return Err(disabled("dns"));
        };
        let names = dns.core().list_names();
        Ok(json!({ "names": names }))
    }

    /// Looks up `name` for `record_type`.
    ///
    /// # Errors
    /// Returns `record_not_found` when the core has no answer.
    async fn dns_lookup(
        &self,
        name: String,
        record_type: RecordType,
    ) -> Result<Value, SourceError> {
        let Some(dns) = self.cores.dns.as_ref() else {
            return Err(disabled("dns"));
        };
        let Some(result) = dns.core().lookup(&name, record_type).await else {
            return Err(SourceError("record_not_found".into()));
        };
        let ips: Vec<String> = result
            .ips
            .into_iter()
            .map(|addr| addr.to_string())
            .collect();
        Ok(json!({ "name": result.name, "ips": ips, "source": result.source }))
    }

    /// Adds a static entry for `name` pointing at `ip`.
    ///
    /// The name is normalized against the configured zone and the address
    /// must parse as an IP; the address is stored in the form it was given.
    ///
    /// # Errors
    /// Fails when dns is disabled, the zone is invalid, the name lies outside
    /// the zone, the address does not parse, or the core rejects the entry.
    async fn dns_add(
        &self,
        name: String,
        ip: String,
        ttl: Option<u32>,
    ) -> Result<Value, SourceError> {
        let Some(dns) = self.cores.dns.as_ref() else {
            return Err(disabled("dns"));
        };
        let core = dns.core();
        let zone = koi_dns::DnsZone::new(&core.config().zone).map_err(source_err)?;
        let Some(normalized) = zone.normalize_name(&name) else {
            return Err(SourceError(format!("name '{name}' is outside the zone")));
        };
        // Validation only: the parsed address is discarded.
        ip.parse::<std::net::IpAddr>()
            .map_err(|_| SourceError(format!("invalid IP address: {ip}")))?;

        let entry = koi_config::state::DnsEntry {
            name: normalized,
            ip,
            ttl,
        };
        let entries = core.add_entry(entry).map_err(source_err)?;
        Ok(json!({ "entries": entries }))
    }

    /// Removes the static entry for `name` after normalizing it against the
    /// configured zone.
    ///
    /// # Errors
    /// Returns `entry_not_found` when the core reports nothing was removed.
    async fn dns_remove(&self, name: String) -> Result<Value, SourceError> {
        let Some(dns) = self.cores.dns.as_ref() else {
            return Err(disabled("dns"));
        };
        let core = dns.core();
        let zone = koi_dns::DnsZone::new(&core.config().zone).map_err(source_err)?;
        let Some(normalized) = zone.normalize_name(&name) else {
            return Err(SourceError(format!("name '{name}' is outside the zone")));
        };
        let remaining = core.remove_entry(&normalized).map_err(source_err)?;
        remaining
            .map(|entries| json!({ "entries": entries }))
            .ok_or_else(|| SourceError("entry_not_found".into()))
    }

    /// Serializes the runtime core's instance list.
    async fn runtime_instances(&self) -> Result<Value, SourceError> {
        let Some(runtime) = self.cores.runtime.as_ref() else {
            return Err(disabled("runtime"));
        };
        let instances = runtime.list_instances().await.map_err(source_err)?;
        serde_json::to_value(instances).map_err(source_err)
    }

    /// Returns the cached mdns records, or the default (empty) value when no
    /// snapshot handle is present; never fails.
    async fn mdns_snapshot(&self) -> Result<Value, SourceError> {
        let records = match self.cores.mdns_snapshot.as_ref() {
            Some(snapshot) => snapshot.cached_records(),
            None => Default::default(),
        };
        Ok(json!({ "services": records }))
    }

    /// Hands out a fresh receiver on the change channel.
    fn change_stream(&self) -> Option<broadcast::Receiver<ResourceChange>> {
        let receiver = self.changes.subscribe();
        Some(receiver)
    }
}
fn spawn_change_pump(
cores: &DaemonCores,
tx: broadcast::Sender<ResourceChange>,
cancel: CancellationToken,
) -> JoinHandle<()> {
let mut mdns_rx = cores.mdns.as_ref().map(|c| c.subscribe());
let mut health_rx = cores.health.as_ref().map(|r| r.core().subscribe());
let mut dns_rx = cores.dns.as_ref().map(|r| r.core().subscribe());
let mut runtime_rx = cores.runtime.as_ref().map(|r| r.subscribe());
tokio::spawn(async move {
loop {
let change: Option<ResourceChange> = tokio::select! {
_ = cancel.cancelled() => break,
Some(Ok(_)) = recv_opt(&mut mdns_rx) => Some(ResourceChange::Mdns),
Some(Ok(_)) = recv_opt(&mut health_rx) => Some(ResourceChange::Health),
Some(Ok(_)) = recv_opt(&mut dns_rx) => Some(ResourceChange::Dns),
Some(Ok(_)) = recv_opt(&mut runtime_rx) => Some(ResourceChange::Inventory),
};
if let Some(change) = change {
let _ = tx.send(change);
}
}
})
}
/// Awaits the next message on an optional broadcast receiver.
///
/// Returns `None` immediately when `rx` holds no receiver; otherwise returns
/// `Some` with whatever `recv` yields, including its errors.
async fn recv_opt<T: Clone>(
    rx: &mut Option<broadcast::Receiver<T>>,
) -> Option<Result<T, broadcast::error::RecvError>> {
    let receiver = rx.as_mut()?;
    Some(receiver.recv().await)
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;

    /// Minimal `KoiSource` with canned answers. It uses the trait's default
    /// `change_stream`, since it does not override it.
    struct MockSource;

    #[async_trait]
    impl KoiSource for MockSource {
        async fn is_available(&self) -> bool {
            true
        }

        async fn browse(
            &self,
            _service_type: Option<String>,
            _window: Duration,
        ) -> Result<Vec<ServiceRecord>, SourceError> {
            Ok(vec![])
        }

        async fn resolve(&self, _instance: String) -> Result<ServiceRecord, SourceError> {
            Err(SourceError("not found".into()))
        }

        async fn register(
            &self,
            _payload: RegisterPayload,
        ) -> Result<RegistrationResult, SourceError> {
            Err(SourceError("mock".into()))
        }

        async fn unregister(&self, _id: String) -> Result<(), SourceError> {
            Ok(())
        }

        async fn heartbeat(&self, _id: String) -> Result<(), SourceError> {
            Ok(())
        }

        async fn unified_status(&self) -> Result<Value, SourceError> {
            Ok(json!({}))
        }

        async fn health_status(&self) -> Result<Value, SourceError> {
            Ok(json!({}))
        }

        async fn dns_list(&self) -> Result<Value, SourceError> {
            Ok(json!({ "names": [] }))
        }

        async fn dns_lookup(
            &self,
            _name: String,
            _record_type: RecordType,
        ) -> Result<Value, SourceError> {
            Err(SourceError("not found".into()))
        }

        async fn dns_add(
            &self,
            _name: String,
            _ip: String,
            _ttl: Option<u32>,
        ) -> Result<Value, SourceError> {
            Err(SourceError("mock".into()))
        }

        async fn dns_remove(&self, _name: String) -> Result<Value, SourceError> {
            Err(SourceError("mock".into()))
        }

        async fn runtime_instances(&self) -> Result<Value, SourceError> {
            Ok(json!([]))
        }

        async fn mdns_snapshot(&self) -> Result<Value, SourceError> {
            Ok(json!({ "services": [] }))
        }
    }

    /// POSTs `body` to `/v1/mcp` on `app`, attaching the `mcp-session-id`
    /// header when `session` is given.
    ///
    /// Returns the response status, the session id header (if any) and the
    /// response body decoded lossily as UTF-8.
    async fn post(
        app: &axum::Router,
        session: Option<&str>,
        body: &str,
    ) -> (StatusCode, Option<String>, String) {
        let base = Request::post("/v1/mcp")
            .header("host", "localhost")
            .header("content-type", "application/json")
            .header("accept", "application/json, text/event-stream");
        let builder = match session {
            Some(sid) => base.header("mcp-session-id", sid),
            None => base,
        };
        let request = builder.body(Body::from(body.to_string())).unwrap();

        let response = app.clone().oneshot(request).await.unwrap();
        let status = response.status();
        let session_id = response
            .headers()
            .get("mcp-session-id")
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned);
        let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let text = String::from_utf8_lossy(&payload).into_owned();
        (status, session_id, text)
    }

    /// Drives one stateful session in-process: initialize, acknowledge, list
    /// resources, then read one resource.
    #[tokio::test]
    async fn streamable_http_session_lists_and_reads_resources() {
        const INITIALIZE: &str = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"itest","version":"0.0.0"}}}"#;
        const INITIALIZED: &str = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        const LIST_RESOURCES: &str =
            r#"{"jsonrpc":"2.0","id":2,"method":"resources/list","params":{}}"#;
        const READ_HEALTH: &str = r#"{"jsonrpc":"2.0","id":3,"method":"resources/read","params":{"uri":"koi://health"}}"#;

        let service = koi_mcp::streamable_http_service(
            std::sync::Arc::new(MockSource),
            vec!["localhost".to_string()],
        );
        let app = axum::Router::new().nest_service("/v1/mcp", service);

        // Step 1: initialize; the server must hand back a session id.
        let (init_status, assigned, init_body) = post(&app, None, INITIALIZE).await;
        assert_eq!(init_status, StatusCode::OK, "initialize should return 200");
        let sid = assigned.expect("stateful transport must assign an mcp-session-id");
        assert!(
            init_body.contains("koi-mcp"),
            "serverInfo missing: {init_body}"
        );
        assert!(
            init_body.contains("resources"),
            "capabilities must advertise resources: {init_body}"
        );

        // Step 2: acknowledge initialization on the same session.
        let (status, _, _) = post(&app, Some(&sid), INITIALIZED).await;
        assert!(status.is_success(), "initialized notif rejected: {status}");

        // Step 3: list resources.
        let (list_status, _, listing) = post(&app, Some(&sid), LIST_RESOURCES).await;
        assert_eq!(list_status, StatusCode::OK);
        assert!(
            listing.contains("koi://lan/inventory") && listing.contains("koi://health"),
            "resources/list missing expected URIs: {listing}"
        );

        // Step 4: read one of the listed resources.
        let (read_status, _, contents) = post(&app, Some(&sid), READ_HEALTH).await;
        assert_eq!(read_status, StatusCode::OK);
        assert!(
            contents.contains("koi://health"),
            "resources/read must return contents for the uri: {contents}"
        );
    }

    /// Runs a real MCP client over TCP against the service wrapped in the
    /// token auth middleware, then checks that a tokenless client fails.
    #[tokio::test]
    async fn mcp_client_over_tcp_through_auth_layer() {
        use axum::http::{HeaderName, HeaderValue};
        use rmcp::transport::streamable_http_client::{
            StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
        };
        use rmcp::ServiceExt as _;

        let token = "itest-token";
        let service =
            koi_mcp::streamable_http_service(std::sync::Arc::new(MockSource), Vec::new());
        let expected = std::sync::Arc::new(token.to_string());
        let auth = axum::middleware::from_fn(move |req, next| {
            let expected = expected.clone();
            crate::adapters::http::dat_auth_middleware(req, next, expected)
        });
        let app = axum::Router::new()
            .nest_service("/v1/mcp", service)
            .layer(auth);

        // Bind to an ephemeral port and serve in the background.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let server = tokio::spawn(async move { axum::serve(listener, app).await });
        tokio::time::sleep(Duration::from_millis(100)).await;

        let url = format!("http://127.0.0.1:{port}/v1/mcp");

        // Authenticated client: every request carries the token header.
        let headers = HashMap::from([(
            HeaderName::from_static("x-koi-token"),
            HeaderValue::from_str(token).unwrap(),
        )]);
        let authed_config =
            StreamableHttpClientTransportConfig::with_uri(url.clone()).custom_headers(headers);
        let authed_transport = StreamableHttpClientTransport::from_config(authed_config);
        let client = ()
            .serve(authed_transport)
            .await
            .expect("authenticated client should initialize");

        let tools = client.list_tools(None).await.expect("list_tools");
        assert_eq!(
            tools.tools.len(),
            11,
            "expected the 11 v1 tools over the wire"
        );

        let resources = client.list_resources(None).await.expect("list_resources");
        let has_inventory = resources
            .resources
            .iter()
            .any(|resource| resource.uri == "koi://lan/inventory");
        assert!(
            has_inventory,
            "resources/list over the wire must include the inventory resource"
        );

        let read = client
            .read_resource(rmcp::model::ReadResourceRequestParams::new("koi://health"))
            .await
            .expect("read_resource");
        assert!(!read.contents.is_empty(), "read must return contents");
        let _ = client.cancel().await;

        // Tokenless client: either the handshake fails outright, or the first
        // real call must fail.
        let bare_config = StreamableHttpClientTransportConfig::with_uri(url);
        let bare_transport = StreamableHttpClientTransport::from_config(bare_config);
        if let Ok(intruder) = ().serve(bare_transport).await {
            assert!(
                intruder.list_tools(None).await.is_err(),
                "a tokenless client must be rejected by the auth layer"
            );
            let _ = intruder.cancel().await;
        }

        server.abort();
    }
}