use std::path::Path;
use crate::client::{discover, Client, DiscoveredInstance};
use crate::types::{Query, QueryError, QueryResult};
/// Forwards `q` to a running instance of `app_name`, or runs `fallback`.
///
/// Instances reported by [`discover`] are tried in descending `pid` order.
/// The first instance whose socket can be connected to and queried decides
/// the outcome: its [`QueryResult`] is returned unchanged, even when that
/// result is itself an `Err`. Instances that fail at the I/O level are logged
/// at debug level and skipped.
///
/// `fallback` is only invoked when no instance produced a reply.
pub async fn forward<F>(app_name: &str, q: &Query, fallback: F) -> QueryResult
where
    F: FnOnce() -> QueryResult,
{
    let mut candidates = discover(Some(app_name));
    // Highest pid first; the sort is stable, so equal pids keep discovery order.
    candidates.sort_by(|a, b| b.pid.cmp(&a.pid));

    for instance in candidates {
        let io_err = match connect_and_query(&instance.socket_path, q).await {
            Ok(reply) => return reply,
            Err(err) => err,
        };
        tracing::debug!(
            target = %instance.app_name,
            pid = instance.pid,
            error = ?io_err,
            "kanshou socket unreachable (stale?); trying next"
        );
    }

    fallback()
}
pub async fn forward_to_pid<F>(app_name: &str, pid: u32, q: &Query, fallback: F) -> QueryResult
where
F: FnOnce() -> QueryResult,
{
let target = discover(Some(app_name))
.into_iter()
.find(|inst| inst.pid == pid);
match target {
Some(t) => match connect_and_query(&t.socket_path, q).await {
Ok(result) => result,
Err(_) => fallback(),
},
None => fallback(),
}
}
/// Opens a [`Client`] connection to `socket` and issues the single query `q`.
///
/// The outer `std::io::Result` carries connection/transport failures from
/// `Client::connect` or `Client::query`; the inner [`QueryResult`] is whatever
/// the query call yielded.
async fn connect_and_query(socket: &Path, q: &Query) -> std::io::Result<QueryResult> {
    let mut connection = Client::connect(socket).await?;
    let reply = connection.query(q).await;
    reply
}
pub enum ForwardOutcome {
Live { pid: u32, value: serde_json::Value },
Fallback { value: serde_json::Value },
LiveError { pid: u32, error: QueryError },
}
pub async fn forward_status<F>(
app_name: &str,
q: &Query,
fallback: F,
) -> ForwardOutcome
where
F: FnOnce() -> QueryResult,
{
let mut all = discover(Some(app_name));
all.sort_by_key(|i| std::cmp::Reverse(i.pid));
for target in all {
match connect_and_query(&target.socket_path, q).await {
Ok(Ok(value)) => {
return ForwardOutcome::Live {
pid: target.pid,
value,
}
}
Ok(Err(error)) => {
return ForwardOutcome::LiveError {
pid: target.pid,
error,
}
}
Err(e) => {
tracing::debug!(
pid = target.pid,
error = ?e,
"kanshou socket unreachable (stale?); trying next"
);
}
}
}
match fallback() {
Ok(value) => ForwardOutcome::Fallback { value },
Err(error) => ForwardOutcome::LiveError {
pid: 0,
error,
},
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Stub introspection target: answers `"live"` for the single-segment
    /// path `ok` and an unknown-field error for every other path.
    struct StubState;

    impl crate::Introspect for StubState {
        fn query(&self, q: &crate::Query) -> crate::QueryResult {
            let is_ok_path = matches!(q.path.as_slice(), [only] if only == "ok");
            if is_ok_path {
                Ok(serde_json::json!("live"))
            } else {
                Err(QueryError::unknown_field(q.path.join(".")))
            }
        }
    }

    /// Counter that makes app names distinct across tests in this process.
    static NEXT_APP_ID: AtomicU64 = AtomicU64::new(0);

    /// Builds an app name from the process id and a per-process sequence number.
    fn fresh_app_name() -> String {
        let process_id = std::process::id();
        let sequence = NEXT_APP_ID.fetch_add(1, Ordering::SeqCst);
        format!("kanshou-mcp-test-{}-{}", process_id, sequence)
    }

    #[tokio::test]
    async fn forward_hits_live_consumer() {
        let app_name = fresh_app_name();
        let live_server = crate::Server::new(&app_name, Arc::new(StubState)).unwrap();
        let serve_handle = tokio::spawn(async move {
            let _ = live_server.serve().await;
        });
        // Give the spawned task a moment to start serving before querying.
        tokio::time::sleep(Duration::from_millis(20)).await;

        let query = Query::field(["ok"]);
        let outcome = forward(&app_name, &query, || Ok(serde_json::json!("fallback"))).await;

        assert_eq!(outcome.unwrap(), serde_json::json!("live"));
        serve_handle.abort();
    }

    #[tokio::test]
    async fn forward_uses_fallback_when_no_consumer() {
        let app_name = fresh_app_name();

        let query = Query::field(["ok"]);
        let outcome = forward(&app_name, &query, || Ok(serde_json::json!("fallback"))).await;

        assert_eq!(outcome.unwrap(), serde_json::json!("fallback"));
    }
}