use std::collections::HashMap;
use kimi_wire::{
protocol::*,
transport::{ChannelTransport, Transport, TransportWireClient},
WireClient,
};
/// Agent-side helper: sends `request` as a JSON-RPC `"request"` call.
///
/// The call is wrapped in an envelope with RPC id `id`, serialized to JSON,
/// and written to `transport` as a single line.
///
/// # Errors
///
/// Returns an error if serialization or the transport write fails.
async fn agent_send_request(
    transport: &mut ChannelTransport,
    id: &str,
    request: Request,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let envelope = JsonRpcRequest {
        jsonrpc: JsonRpcVersion,
        id: id.to_owned(),
        method: String::from("request"),
        params: request,
    };
    let encoded = serde_json::to_string(&envelope)?;
    transport.write_line(&encoded).await?;
    Ok(())
}
/// Agent-side helper: reads and decodes the next JSON-RPC success response.
///
/// The next line from `transport` is decoded as a success response whose
/// `result` has type `R`.
///
/// # Errors
///
/// Returns `"stream closed"` if the transport yields no further line.
/// Otherwise returns the underlying read or JSON-decode error.
async fn agent_read_response<R: serde::de::DeserializeOwned>(
    transport: &mut ChannelTransport,
) -> Result<JsonRpcSuccessResponse<R>, Box<dyn std::error::Error + Send + Sync>> {
    match transport.read_line().await? {
        Some(text) => {
            let decoded = serde_json::from_str::<JsonRpcSuccessResponse<R>>(&text)?;
            Ok(decoded)
        }
        None => Err("stream closed".into()),
    }
}
/// Agent-side helper: reads the next line from `transport` and decodes it as
/// a JSON-RPC error response.
///
/// # Errors
///
/// Returns `"stream closed"` if the transport yields no further line.
/// Otherwise returns the underlying read or JSON-decode error.
async fn agent_read_error(
    transport: &mut ChannelTransport,
) -> Result<JsonRpcErrorResponse, Box<dyn std::error::Error + Send + Sync>> {
    if let Some(text) = transport.read_line().await? {
        let decoded = serde_json::from_str::<JsonRpcErrorResponse>(&text)?;
        return Ok(decoded);
    }
    Err("stream closed".into())
}
/// Client-side helper: reads the next raw wire message and decodes it into a
/// [`Request`].
///
/// The message must be a `"request"` call.
///
/// # Panics
///
/// Panics if any of these hold:
/// - the read fails;
/// - the method is not `"request"`;
/// - the message carries no params;
/// - the params do not decode into a [`Request`].
async fn read_request(client: &mut TransportWireClient<ChannelTransport>) -> Request {
    let message = client.read_raw_message().await.unwrap();
    assert_eq!(message.method.as_deref(), Some("request"));
    let params = message.params.expect("params present");
    serde_json::from_value::<Request>(params).unwrap()
}
/// Client-side helper: like `read_request`, but also returns the JSON-RPC id
/// the request arrived under.
///
/// The caller can use that id to answer the request with `send_response`.
///
/// # Panics
///
/// Panics if any of these hold:
/// - the read fails;
/// - the method is not `"request"`;
/// - the params are missing or do not decode;
/// - the message has no id.
async fn read_request_with_rpc_id(
    client: &mut TransportWireClient<ChannelTransport>,
) -> (Request, String) {
    let message = client.read_raw_message().await.unwrap();
    assert_eq!(message.method.as_deref(), Some("request"));
    // Decode the params before extracting the id, matching the original panic order.
    let params = message.params.expect("params present");
    let request = serde_json::from_value::<Request>(params).unwrap();
    let rpc_id = message.id.expect("id present");
    (request, rpc_id)
}
/// Builds a minimal `Request::ApprovalRequest` for tests.
///
/// The request has sender `"Shell"` and action `"test action"`. Every
/// optional field (display, source, agent, subagent) is left as `None`.
fn approval_request(id: &str, tool_call_id: &str, description: &str) -> Request {
    let approval = ApprovalRequest {
        id: id.into(),
        tool_call_id: tool_call_id.into(),
        description: description.into(),
        sender: "Shell".into(),
        action: "test action".into(),
        display: None,
        source_kind: None,
        source_id: None,
        agent_id: None,
        subagent_type: None,
        source_description: None,
    };
    Request::ApprovalRequest(approval)
}
/// Round-trips an `ApprovalRequest` from the agent to the client.
///
/// Checks that the client's `ApprovalResponse` reaches the agent under the
/// same JSON-RPC id and refers to the same approval.
#[tokio::test]
async fn test_bidirectional_approval_request_response_flow() {
    const RPC_ID: &str = "req-1";
    const APPROVAL_ID: &str = "approval-req-1";

    let (client_side, mut agent_side) = ChannelTransport::pair();
    let mut client = TransportWireClient::new(client_side);

    let agent = tokio::spawn(async move {
        let approval = ApprovalRequest {
            id: APPROVAL_ID.to_string(),
            tool_call_id: "tc-42".to_string(),
            sender: "Shell".to_string(),
            action: "run shell command".to_string(),
            description: "Run `ls -la`".to_string(),
            display: Some(vec![DisplayBlock::brief("listing directory contents")]),
            source_kind: Some(SourceKind::ForegroundTurn),
            source_id: None,
            agent_id: None,
            subagent_type: None,
            source_description: None,
        };
        agent_send_request(&mut agent_side, RPC_ID, Request::ApprovalRequest(approval)).await?;

        let reply = agent_read_response::<ApprovalResponse>(&mut agent_side).await?;
        assert_eq!(reply.id, RPC_ID);
        assert_eq!(reply.result.request_id, APPROVAL_ID);
        assert_eq!(reply.result.response, ApprovalResponseKind::Approve);
        Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
    });

    let received = read_request(&mut client).await;
    assert!(matches!(received, Request::ApprovalRequest(ref ar) if ar.id == APPROVAL_ID));

    let answer = ApprovalResponse {
        request_id: APPROVAL_ID.to_string(),
        response: ApprovalResponseKind::Approve,
        feedback: None,
    };
    client.send_response(RPC_ID, &answer).await.unwrap();
    agent.await.unwrap().unwrap();
}
#[tokio::test]
async fn test_bidirectional_tool_call_request_response_flow() {
let (client_transport, mut agent_transport) = ChannelTransport::pair();
let mut client = TransportWireClient::new(client_transport);
let agent = tokio::spawn(async move {
let request = Request::ToolCallRequest(ToolCallRequest {
id: "tool-req-1".to_string(),
name: "write_file".to_string(),
arguments: Some(r#"{"path": "/tmp/hello.txt", "content": "hi"}"#.to_string()),
});
agent_send_request(&mut agent_transport, "req-2", request).await?;
let resp: JsonRpcSuccessResponse<ToolCallResponse> =
agent_read_response(&mut agent_transport).await?;
assert_eq!(resp.id, "req-2");
assert_eq!(resp.result.tool_call_id, "tool-req-1");
assert_eq!(
resp.result.return_value,
ToolReturnValue::new("done").with_output("output text")
);
Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
});
let request = read_request(&mut client).await;
assert!(matches!(request, Request::ToolCallRequest(ref tcr) if tcr.id == "tool-req-1"));
let response = ToolCallResponse {
tool_call_id: "tool-req-1".to_string(),
return_value: ToolReturnValue::new("done").with_output("output text"),
};
client.send_response("req-2", &response).await.unwrap();
agent.await.unwrap().unwrap();
}
/// Round-trips a single-question `QuestionRequest` from the agent to the client.
///
/// The question has two options. The test checks that the chosen answer map
/// reaches the agent under the same JSON-RPC id.
#[tokio::test]
async fn test_bidirectional_question_request_response_flow() {
    const RPC_ID: &str = "req-3";
    const QUESTION_ID: &str = "question-req-1";
    const QUESTION: &str = "Which region?";

    // The answer map the client sends and the agent expects to decode.
    fn expected_answers() -> HashMap<String, String> {
        HashMap::from([(QUESTION.to_string(), "us-east-1".to_string())])
    }

    let (client_side, mut agent_side) = ChannelTransport::pair();
    let mut client = TransportWireClient::new(client_side);

    let agent = tokio::spawn(async move {
        let option = |label: &str, description: &str| QuestionOption {
            label: label.to_string(),
            description: Some(description.to_string()),
        };
        let item = QuestionItem {
            question: QUESTION.to_string(),
            header: Some("region".to_string()),
            options: vec![
                option("us-east-1", "N. Virginia"),
                option("eu-west-1", "Ireland"),
            ],
            multi_select: Some(false),
        };
        let question = QuestionRequest {
            id: QUESTION_ID.to_string(),
            tool_call_id: "tc-q-1".to_string(),
            questions: vec![item],
        };
        agent_send_request(&mut agent_side, RPC_ID, Request::QuestionRequest(question)).await?;

        let reply = agent_read_response::<QuestionResponse>(&mut agent_side).await?;
        assert_eq!(reply.id, RPC_ID);
        assert_eq!(reply.result.request_id, QUESTION_ID);
        assert_eq!(reply.result.answers, expected_answers());
        Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
    });

    let received = read_request(&mut client).await;
    assert!(matches!(received, Request::QuestionRequest(ref qr) if qr.id == QUESTION_ID));

    let answer = QuestionResponse {
        request_id: QUESTION_ID.to_string(),
        answers: expected_answers(),
    };
    client.send_response(RPC_ID, &answer).await.unwrap();
    agent.await.unwrap().unwrap();
}
/// Round-trips a `HookRequest` from the agent to the client.
///
/// Checks that the client's `Allow` decision and its reason reach the agent
/// under the same JSON-RPC id.
#[tokio::test]
async fn test_bidirectional_hook_request_response_flow() {
    const RPC_ID: &str = "req-4";
    const HOOK_ID: &str = "hook-req-1";
    const REASON: &str = "test allow";

    let (client_side, mut agent_side) = ChannelTransport::pair();
    let mut client = TransportWireClient::new(client_side);

    let agent = tokio::spawn(async move {
        let hook = HookRequest {
            id: HOOK_ID.to_string(),
            subscription_id: "sub-1".to_string(),
            event: "before_tool_call".to_string(),
            target: "write_file".to_string(),
            input_data: serde_json::json!({"path": "/tmp/test.txt"}),
        };
        agent_send_request(&mut agent_side, RPC_ID, Request::HookRequest(hook)).await?;

        let reply = agent_read_response::<HookResponse>(&mut agent_side).await?;
        assert_eq!(reply.id, RPC_ID);
        assert_eq!(reply.result.request_id, HOOK_ID);
        assert_eq!(reply.result.action, HookAction::Allow);
        assert_eq!(reply.result.reason, REASON);
        Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
    });

    let received = read_request(&mut client).await;
    assert!(matches!(received, Request::HookRequest(ref hr) if hr.id == HOOK_ID));

    let answer = HookResponse {
        request_id: HOOK_ID.to_string(),
        action: HookAction::Allow,
        reason: REASON.to_string(),
    };
    client.send_response(RPC_ID, &answer).await.unwrap();
    agent.await.unwrap().unwrap();
}
/// The client answers an approval request with a JSON-RPC error instead of a
/// result.
///
/// The agent must receive an error response carrying the same RPC id, code
/// and message.
#[tokio::test]
async fn test_bidirectional_request_with_send_error() {
    const RPC_ID: &str = "req-err";
    const APPROVAL_ID: &str = "approval-req-error";
    const REJECTION: &str = "user rejected approval";

    let (client_side, mut agent_side) = ChannelTransport::pair();
    let mut client = TransportWireClient::new(client_side);

    let agent = tokio::spawn(async move {
        let outgoing = approval_request(APPROVAL_ID, "tc-e-1", "Dangerous command");
        agent_send_request(&mut agent_side, RPC_ID, outgoing).await?;

        let failure = agent_read_error(&mut agent_side).await?;
        assert_eq!(failure.id, RPC_ID);
        assert_eq!(failure.error.code, -32000);
        assert_eq!(failure.error.message, REJECTION);
        Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
    });

    let received = read_request(&mut client).await;
    assert!(matches!(received, Request::ApprovalRequest(ref ar) if ar.id == APPROVAL_ID));

    client.send_error(RPC_ID, -32000, REJECTION).await.unwrap();
    agent.await.unwrap().unwrap();
}
/// Checks that several in-flight requests are answered in client-send order.
///
/// Sequence:
/// 1. The agent sends three approval requests back-to-back before reading any
///    response.
/// 2. The client reads all three, asserting they arrive in send order.
/// 3. The client answers them in reverse order.
///
/// The agent must observe the responses in the order the client sent them
/// (`r3`, `r2`, `r1`). Each response must still carry the payload for its own
/// request, so a mix-up between RPC ids and approval ids fails the test.
#[tokio::test]
async fn test_bidirectional_multiple_in_flight_requests_handled_in_order() {
    const RPC_IDS: [&str; 3] = ["r1", "r2", "r3"];

    let (client_transport, mut agent_transport) = ChannelTransport::pair();
    let mut client = TransportWireClient::new(client_transport);

    let agent = tokio::spawn(async move {
        // Queue every request before reading anything back, so all three are
        // in flight at the same time.
        for id in RPC_IDS {
            let request = approval_request(
                &format!("approval-{id}"),
                &format!("tc-{id}"),
                &format!("request {id}"),
            );
            agent_send_request(&mut agent_transport, id, request).await?;
        }

        let mut ids = Vec::new();
        for _ in 0..RPC_IDS.len() {
            let resp: JsonRpcSuccessResponse<ApprovalResponse> =
                agent_read_response(&mut agent_transport).await?;
            // Checking the RPC id order alone would miss a client that paired
            // the wrong payload with an id. Verify the correlation as well.
            assert_eq!(resp.result.request_id, format!("approval-{}", resp.id));
            assert_eq!(resp.result.response, ApprovalResponseKind::Approve);
            ids.push(resp.id);
        }
        assert_eq!(ids, vec!["r3", "r2", "r1"]);
        Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
    });

    let mut requests = Vec::new();
    for expected_rpc_id in RPC_IDS {
        let (request, rpc_id) = read_request_with_rpc_id(&mut client).await;
        let req_id = match &request {
            Request::ApprovalRequest(ar) => ar.id.clone(),
            // Reachable if the wire misbehaves, so this is a test failure,
            // not an `unreachable!` invariant.
            other => panic!("unexpected variant: {other:?}"),
        };
        // The agent writes sequentially on one stream, so the client must see
        // the requests in send order, each under its own RPC id.
        assert_eq!(rpc_id, expected_rpc_id);
        assert_eq!(req_id, format!("approval-{expected_rpc_id}"));
        requests.push((rpc_id, req_id));
    }

    // Answer in reverse to prove responses are delivered in client-send order,
    // not in request order.
    for (rpc_id, req_id) in requests.into_iter().rev() {
        let response = ApprovalResponse {
            request_id: req_id,
            response: ApprovalResponseKind::Approve,
            feedback: None,
        };
        client.send_response(&rpc_id, &response).await.unwrap();
    }
    agent.await.unwrap().unwrap();
}