use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{Context, Result, bail};
use kanade_shared::ipc::envelope::{
JSONRPC_VERSION, RpcMessage, RpcNotification, RpcRequest, RpcResponse, RpcResponsePayload,
};
use kanade_shared::ipc::handshake::{HandshakeParams, HandshakeResult, PROTOCOL_V1};
use kanade_shared::ipc::jobs::{
JobsExecuteParams, JobsExecuteResult, JobsKillParams, JobsKillResult, JobsListParams,
JobsListResult,
};
use kanade_shared::ipc::method;
use kanade_shared::ipc::state::{StateSnapshot, StateSnapshotParams};
use kanade_shared::ipc::system::{PingParams, PingResult};
use serde::{Serialize, de::DeserializeOwned};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, WriteHalf, split};
use tokio::net::windows::named_pipe::{ClientOptions, NamedPipeClient};
use tokio::sync::{Mutex, broadcast, oneshot};
use tracing::{debug, info, warn};
/// Named Pipe the agent listens on.
const PIPE_NAME: &str = r"\\.\pipe\kanade-agent";
/// Upper bound on a frame body, in bytes, enforced when reading and when writing (1 MiB).
const MAX_FRAME_BYTES: usize = 1 << 20;
/// Capacity of the broadcast channel that fans out agent notifications.
const NOTIFICATION_CAPACITY: usize = 256;
/// Client name sent in the handshake request.
const CLIENT_NAME: &str = "kanade-client";
/// Client version sent in the handshake request, taken from this crate's Cargo version.
const CLIENT_VERSION: &str = env!("CARGO_PKG_VERSION");
/// In-flight requests: request id → one-shot sender that the reader task completes with the
/// matching response.
///
/// A blocking `std::sync::Mutex` is sufficient because every use holds it for a single map
/// operation and never across an `.await`.
type Pending = Arc<std::sync::Mutex<HashMap<String, oneshot::Sender<RpcResponse>>>>;

/// Removes one request's entry from [`Pending`] when dropped.
///
/// `KlpClient::request` keeps one alive for the duration of the call, so the entry is cleaned
/// up on every exit path: normal return, early error, or the caller dropping the future.
struct PendingGuard {
    pending: Pending,
    id: String,
}

impl Drop for PendingGuard {
    fn drop(&mut self) {
        // Destructors must not panic: a panic here while already unwinding would abort the
        // process. Recover the map from a poisoned lock instead of unwrapping.
        self.pending
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(&self.id);
    }
}
#[derive(Clone)]
pub struct KlpClient {
write: Arc<Mutex<WriteHalf<NamedPipeClient>>>,
pending: Pending,
handshake: Arc<HandshakeResult>,
notifications: broadcast::Sender<RpcNotification>,
}
impl KlpClient {
pub async fn connect() -> Result<Self> {
let mut pipe = ClientOptions::new()
.open(PIPE_NAME)
.with_context(|| format!("open Named Pipe {PIPE_NAME}"))?;
info!(pipe = PIPE_NAME, "KLP client: pipe connected");
let handshake = handshake(&mut pipe).await.context("KLP handshake")?;
info!(
agent_version = %handshake.agent_version,
user = %handshake.session.user,
session_id = handshake.session.session_id,
pc_id = %handshake.session.pc_id,
"KLP client: handshake complete",
);
let (read, write) = split(pipe);
let pending: Pending = Arc::new(std::sync::Mutex::new(HashMap::new()));
let (notifications, _) = broadcast::channel(NOTIFICATION_CAPACITY);
tokio::spawn(reader_loop(read, pending.clone(), notifications.clone()));
Ok(Self {
write: Arc::new(Mutex::new(write)),
pending,
handshake: Arc::new(handshake),
notifications,
})
}
pub fn handshake(&self) -> Arc<HandshakeResult> {
self.handshake.clone()
}
pub fn subscribe(&self) -> broadcast::Receiver<RpcNotification> {
self.notifications.subscribe()
}
pub async fn request<P: Serialize, R: DeserializeOwned>(
&self,
method: &str,
params: &P,
) -> Result<R> {
let id = uuid::Uuid::new_v4().to_string();
let req = RpcRequest::new(&id, method, params).context("encode KLP request")?;
let body = serde_json::to_vec(&req).context("serialise KLP request")?;
let rx = {
let (tx, rx) = oneshot::channel();
self.pending.lock().unwrap().insert(id.clone(), tx);
rx
};
let _guard = PendingGuard {
pending: self.pending.clone(),
id: id.clone(),
};
let write_result = {
let mut writer = self.write.lock().await;
write_frame(&mut *writer, &body).await
};
write_result.context("write frame")?;
let resp = rx
.await
.map_err(|_| anyhow::anyhow!("KLP connection closed before response (id {id})"))?;
decode_response::<R>(resp)
}
pub async fn ping(&self) -> Result<PingResult> {
self.request::<PingParams, PingResult>(method::SYSTEM_PING, &PingParams::default())
.await
}
pub async fn snapshot(&self) -> Result<StateSnapshot> {
self.request::<StateSnapshotParams, StateSnapshot>(
method::STATE_SNAPSHOT,
&StateSnapshotParams::default(),
)
.await
}
pub async fn jobs_list(&self, params: &JobsListParams) -> Result<JobsListResult> {
self.request::<JobsListParams, JobsListResult>(method::JOBS_LIST, params)
.await
}
pub async fn jobs_execute(&self, id: &str) -> Result<JobsExecuteResult> {
self.request::<JobsExecuteParams, JobsExecuteResult>(
method::JOBS_EXECUTE,
&JobsExecuteParams { id: id.to_string() },
)
.await
}
pub async fn jobs_kill(&self, run_id: &str) -> Result<JobsKillResult> {
self.request::<JobsKillParams, JobsKillResult>(
method::JOBS_KILL,
&JobsKillParams {
run_id: run_id.to_string(),
},
)
.await
}
}
async fn reader_loop<R: AsyncRead + Unpin>(
mut read: R,
pending: Pending,
notifications: broadcast::Sender<RpcNotification>,
) {
loop {
let bytes = match read_frame(&mut read).await {
Ok(b) => b,
Err(e) => {
debug!(error = %e, "klp reader: pipe closed, exiting");
break;
}
};
let msg: RpcMessage = match serde_json::from_slice(&bytes) {
Ok(m) => m,
Err(e) => {
warn!(error = %e, "klp reader: undecodable frame, skipping");
continue;
}
};
match msg {
RpcMessage::Response(resp) => match resp.id.as_deref() {
Some(id) => {
let waiter = pending.lock().unwrap().remove(id);
match waiter {
Some(tx) => {
let _ = tx.send(resp);
}
None => {
debug!(id, "klp reader: response for unknown/expired request")
}
}
}
None => debug!("klp reader: response without id, ignoring"),
},
RpcMessage::Notification(notif) => {
let _ = notifications.send(notif);
}
RpcMessage::Request(_) => {
debug!("klp reader: agent sent a Request (unexpected), ignoring");
}
}
}
pending.lock().unwrap().clear();
}
fn decode_response<R: DeserializeOwned>(resp: RpcResponse) -> Result<R> {
match resp.payload {
RpcResponsePayload::Ok { result } => {
serde_json::from_value(result).context("decode typed result")
}
RpcResponsePayload::Err { error } => {
let detail = error
.data
.as_ref()
.map(|d| d.detail.clone())
.unwrap_or_default();
bail!("KLP error {} ({}): {detail}", error.code, error.message);
}
}
}
/// Performs the KLP handshake on a freshly opened pipe, before it is split.
///
/// Sends one handshake request carrying `CLIENT_NAME`, `CLIENT_VERSION`, protocol
/// `PROTOCOL_V1` and no features, then reads exactly one frame as the reply. An unexpected
/// `jsonrpc` field in the reply is only logged.
///
/// # Errors
/// Fails in any of these cases:
/// - an I/O error occurs;
/// - the reply envelope cannot be decoded;
/// - the reply is not a `Response`;
/// - the response id does not match the request id;
/// - the agent returned an error object.
async fn handshake(pipe: &mut NamedPipeClient) -> Result<HandshakeResult> {
    let request_id = uuid::Uuid::new_v4().to_string();
    let params = HandshakeParams {
        client: CLIENT_NAME.to_string(),
        client_version: CLIENT_VERSION.to_string(),
        protocol: vec![PROTOCOL_V1],
        features: vec![],
    };
    let request = RpcRequest::new(&request_id, method::SYSTEM_HANDSHAKE, &params)
        .context("encode handshake request")?;
    let encoded = serde_json::to_vec(&request).context("serialise handshake request")?;
    write_frame(pipe, &encoded)
        .await
        .context("write handshake")?;

    let reply = read_frame(pipe)
        .await
        .context("read handshake response")?;
    let envelope: RpcMessage = serde_json::from_slice(&reply).context("decode envelope")?;
    let resp = match envelope {
        RpcMessage::Response(resp) => resp,
        other => bail!("expected handshake Response, got {other:?}"),
    };

    if resp.id.as_deref() != Some(request_id.as_str()) {
        bail!(
            "handshake response id mismatch: expected {request_id:?}, got {:?}",
            resp.id
        );
    }
    if resp.jsonrpc != JSONRPC_VERSION {
        debug!(jsonrpc = %resp.jsonrpc, "unexpected jsonrpc field (proceeding)");
    }

    decode_response::<HandshakeResult>(resp)
}
/// Reads one frame: a 4-byte little-endian length prefix followed by that many body bytes.
///
/// # Errors
/// - Propagates I/O errors from `read_exact`, including end-of-stream.
/// - Returns `InvalidData` when the prefix exceeds `MAX_FRAME_BYTES`. In that case the body
///   is left unread, so the stream can no longer be parsed.
async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Vec<u8>> {
    let mut header = [0u8; 4];
    reader.read_exact(&mut header).await?;

    let frame_len = u32::from_le_bytes(header) as usize;
    if frame_len > MAX_FRAME_BYTES {
        let msg = format!("KLP frame {frame_len} bytes exceeds 1 MiB cap");
        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, msg));
    }

    let mut payload = vec![0u8; frame_len];
    reader.read_exact(&mut payload).await?;
    Ok(payload)
}
/// Writes `body` as one frame (4-byte little-endian length prefix, then the bytes) and flushes.
///
/// # Errors
/// - Returns `InvalidData`, without writing anything, when `body` exceeds `MAX_FRAME_BYTES`.
/// - Otherwise propagates I/O errors from the writes or the flush.
async fn write_frame<W: AsyncWrite + Unpin>(writer: &mut W, body: &[u8]) -> std::io::Result<()> {
    let body_len = body.len();
    if body_len > MAX_FRAME_BYTES {
        let msg = format!("KLP frame {body_len} bytes exceeds 1 MiB cap");
        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, msg));
    }

    // Lossless: the cap above keeps the length far below u32::MAX.
    let header = (body_len as u32).to_le_bytes();
    writer.write_all(&header).await?;
    writer.write_all(body).await?;
    writer.flush().await
}
#[cfg(test)]
mod tests {
    // These tests drive `reader_loop` and the framing helpers over an in-memory
    // `tokio::io::duplex` stream instead of a real Named Pipe.
    use super::*;
    use kanade_shared::ipc::jobs::{JobProgress, RunStatus};

    /// Writes `value` as one KLP frame, the way the agent would.
    async fn push_frame<W: AsyncWrite + Unpin>(w: &mut W, value: &serde_json::Value) {
        let body = serde_json::to_vec(value).unwrap();
        write_frame(w, &body).await.unwrap();
    }

    #[tokio::test]
    async fn reader_routes_response_to_pending_request() {
        let (client_side, mut agent_side) = tokio::io::duplex(64 * 1024);
        let pending: Pending = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let (notifications, _rx) = broadcast::channel(16);
        let (tx, rx) = oneshot::channel();
        pending.lock().unwrap().insert("req-1".into(), tx);
        tokio::spawn(reader_loop(client_side, pending.clone(), notifications));
        push_frame(
            &mut agent_side,
            &serde_json::json!({
                "jsonrpc": "2.0",
                "id": "req-1",
                "result": { "run_id": "run-xyz" }
            }),
        )
        .await;
        let resp = rx.await.expect("reader should deliver the response");
        assert_eq!(resp.id.as_deref(), Some("req-1"));
        assert!(pending.lock().unwrap().is_empty(), "pending entry consumed");
    }

    #[tokio::test]
    async fn reader_forwards_notification_to_subscribers() {
        let (client_side, mut agent_side) = tokio::io::duplex(64 * 1024);
        let pending: Pending = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let (notifications, mut sub) = broadcast::channel(16);
        tokio::spawn(reader_loop(client_side, pending, notifications));
        let progress = JobProgress {
            run_id: "run-1".into(),
            status: RunStatus::Running,
            stdout_chunk: None,
            stderr_chunk: None,
            exit_code: None,
        };
        let notif = RpcNotification::new(method::JOBS_PROGRESS, &progress).unwrap();
        push_frame(&mut agent_side, &serde_json::to_value(&notif).unwrap()).await;
        let got = sub.recv().await.expect("notification forwarded");
        assert_eq!(got.method, method::JOBS_PROGRESS);
        assert_eq!(got.params["run_id"], "run-1");
        assert_eq!(got.params["status"], "running");
    }

    #[tokio::test]
    async fn reader_exit_fails_pending_requests() {
        let (client_side, agent_side) = tokio::io::duplex(64 * 1024);
        let pending: Pending = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let (notifications, _rx) = broadcast::channel(16);
        let (tx, rx) = oneshot::channel::<RpcResponse>();
        pending.lock().unwrap().insert("req-orphan".into(), tx);
        let handle = tokio::spawn(reader_loop(client_side, pending, notifications));
        drop(agent_side);
        handle.await.unwrap();
        assert!(
            rx.await.is_err(),
            "pending request should be failed, not hung"
        );
    }

    #[tokio::test]
    async fn reader_exits_on_oversized_frame_and_fails_pending() {
        let (client_side, mut agent_side) = tokio::io::duplex(64 * 1024);
        let pending: Pending = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let (notifications, _rx) = broadcast::channel(16);
        let (tx, rx) = oneshot::channel::<RpcResponse>();
        pending.lock().unwrap().insert("req-stuck".into(), tx);
        let handle = tokio::spawn(reader_loop(client_side, pending, notifications));
        // Advertise a body one byte over the cap while keeping the agent side open: the
        // reader must still stop and fail the waiter.
        let too_big = u32::try_from(MAX_FRAME_BYTES + 1).unwrap().to_le_bytes();
        agent_side.write_all(&too_big).await.unwrap();
        handle.await.unwrap();
        assert!(
            rx.await.is_err(),
            "pending request should be failed, not hung"
        );
        drop(agent_side);
    }

    #[tokio::test]
    async fn frame_round_trips_through_duplex() {
        let (mut writer, mut reader) = tokio::io::duplex(1024);
        write_frame(&mut writer, b"{\"k\":1}").await.unwrap();
        assert_eq!(read_frame(&mut reader).await.unwrap(), b"{\"k\":1}".to_vec());
    }

    #[tokio::test]
    async fn read_frame_rejects_oversized_length_prefix() {
        let (mut client_side, mut agent_side) = tokio::io::duplex(64);
        let too_big = u32::try_from(MAX_FRAME_BYTES + 1).unwrap().to_le_bytes();
        agent_side.write_all(&too_big).await.unwrap();
        let err = read_frame(&mut client_side).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn write_frame_rejects_oversized_body() {
        let (mut client_side, _agent_side) = tokio::io::duplex(64);
        let body = vec![0u8; MAX_FRAME_BYTES + 1];
        let err = write_frame(&mut client_side, &body).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }
}