use std::collections::HashMap;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::mpsc;
use crate::jsonrpc::{JsonRpcError, JsonRpcRequest, JsonRpcResponse, RequestId};
use crate::error::ChannelSdkError;
use crate::trait_def::Channel;
use anyclaw_sdk_types::{
ChannelInitializeParams, ChannelInitializeResult, ChannelRequestPermission, ChannelSendMessage,
DeliverMessage, PermissionResponse, SessionCreated,
};
/// Drives a channel implementation over a newline-delimited JSON-RPC stream.
///
/// The type itself places no bound on `C`; the methods that need `C: Channel`
/// declare that bound on their `impl` block.
pub struct ChannelHarness<C> {
    /// The channel implementation that incoming messages are dispatched to.
    channel: C,
}
impl<C: Channel> ChannelHarness<C> {
pub fn new(channel: C) -> Self {
Self { channel }
}
pub async fn run_stdio(self) -> Result<(), ChannelSdkError> {
self.run(tokio::io::stdin(), tokio::io::stdout()).await
}
pub async fn run<R, W>(mut self, reader: R, mut writer: W) -> Result<(), ChannelSdkError>
where
R: tokio::io::AsyncRead + Unpin + Send,
W: tokio::io::AsyncWrite + Unpin + Send,
{
let mut lines = BufReader::new(reader).lines();
let (outbound_tx, mut outbound_rx) = mpsc::channel::<ChannelSendMessage>(64);
let (permission_tx, mut permission_rx) = mpsc::channel::<PermissionResponse>(16);
let mut pending_permissions: HashMap<String, RequestId> = HashMap::new();
loop {
tokio::select! {
line_result = lines.next_line() => {
match line_result {
Ok(Some(line)) => {
if line.trim().is_empty() {
continue;
}
let req: JsonRpcRequest = match serde_json::from_str(&line) {
Ok(v) => v,
Err(_) => continue,
};
if let Some(response) = self.dispatch(req, &outbound_tx, &permission_tx, &mut pending_permissions).await? {
Self::write_response(&mut writer, &response).await?;
}
}
Ok(None) | Err(_) => break,
}
}
Some(send_msg) = outbound_rx.recv() => {
Self::write_outbound(&mut writer, &send_msg).await?;
}
Some(perm_resp) = permission_rx.recv() => {
Self::flush_permission(&mut writer, perm_resp, &mut pending_permissions).await?;
}
}
}
while !pending_permissions.is_empty() {
tokio::select! {
Some(perm_resp) = permission_rx.recv() => {
Self::flush_permission(&mut writer, perm_resp, &mut pending_permissions).await?;
}
Some(send_msg) = outbound_rx.recv() => {
Self::write_outbound(&mut writer, &send_msg).await?;
}
else => break,
}
}
Ok(())
}
async fn write_response<W: tokio::io::AsyncWrite + Unpin>(
writer: &mut W,
msg: &JsonRpcResponse,
) -> Result<(), ChannelSdkError> {
let mut line = serde_json::to_vec(msg)?;
line.push(b'\n');
writer.write_all(&line).await?;
writer.flush().await?;
Ok(())
}
async fn write_outbound<W: tokio::io::AsyncWrite + Unpin>(
writer: &mut W,
send_msg: &ChannelSendMessage,
) -> Result<(), ChannelSdkError> {
let params = serde_json::to_value(send_msg)?;
let notification = JsonRpcRequest::new("channel/sendMessage", None, Some(params));
let mut line = serde_json::to_vec(¬ification)?;
line.push(b'\n');
writer.write_all(&line).await?;
writer.flush().await?;
Ok(())
}
async fn flush_permission<W: tokio::io::AsyncWrite + Unpin>(
writer: &mut W,
perm_resp: PermissionResponse,
pending_permissions: &mut HashMap<String, RequestId>,
) -> Result<(), ChannelSdkError> {
if let Some(jsonrpc_id) = pending_permissions.remove(&perm_resp.request_id) {
let result = serde_json::to_value(&perm_resp)?;
let response = JsonRpcResponse::success(Some(jsonrpc_id), result);
Self::write_response(writer, &response).await?;
} else {
tracing::warn!(
request_id = %perm_resp.request_id,
"received permission response for unknown request"
);
}
Ok(())
}
async fn dispatch(
&mut self,
req: JsonRpcRequest,
outbound_tx: &mpsc::Sender<ChannelSendMessage>,
permission_tx: &mpsc::Sender<PermissionResponse>,
pending_permissions: &mut HashMap<String, RequestId>,
) -> Result<Option<JsonRpcResponse>, ChannelSdkError> {
let id = req.id.clone();
let params = req.params.unwrap_or(serde_json::Value::Null);
match req.method.as_str() {
"initialize" => {
let caps = self.channel.capabilities();
let result = ChannelInitializeResult {
protocol_version: 1,
capabilities: caps,
defaults: self.channel.defaults(),
};
if let Ok(init_params) = serde_json::from_value::<ChannelInitializeParams>(params) {
if init_params.protocol_version != 1 {
tracing::warn!(
protocol_version = init_params.protocol_version,
"channel received unexpected protocol version; expected 1"
);
}
self.channel.on_initialize(init_params).await?;
}
self.channel
.on_ready(outbound_tx.clone(), permission_tx.clone())
.await?;
if let Some(req_id) = id {
return Ok(Some(JsonRpcResponse::success(
Some(req_id),
serde_json::to_value(&result)?,
)));
}
}
"channel/deliverMessage" => {
if let Ok(deliver) = serde_json::from_value::<DeliverMessage>(params) {
self.channel.deliver_message(deliver).await?;
}
}
"channel/sessionCreated" => {
if let Ok(msg) = serde_json::from_value::<SessionCreated>(params) {
self.channel.on_session_created(msg).await?;
}
}
"channel/requestPermission" => {
if let Ok(req) = serde_json::from_value::<ChannelRequestPermission>(params) {
tracing::debug!(request_id = %req.request_id, description = %req.description, "permission request dispatched to channel");
if let Some(req_id) = id {
pending_permissions.insert(req.request_id.clone(), req_id);
}
self.channel.show_permission_prompt(req).await?;
}
}
method => {
let result = self.channel.handle_unknown(method, params).await;
if let Some(req_id) = id {
return Ok(Some(match result {
Ok(val) => JsonRpcResponse::success(Some(req_id), val),
Err(e) => JsonRpcResponse::error(
Some(req_id),
JsonRpcError {
code: -32601,
message: e.to_string(),
data: None,
},
),
}));
}
}
}
Ok(None)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // `PermissionResponse` already arrives through `super::*`.
    use anyclaw_sdk_types::ChannelCapabilities;
    use rstest::rstest;
    use std::sync::{Arc, Mutex};

    /// Test double that records what the harness hands it.
    ///
    /// When asked for a permission decision, it immediately answers with
    /// `default_option_id` through the sender captured in `on_ready`.
    #[derive(Clone)]
    struct TestChannel {
        on_ready_called: Arc<Mutex<bool>>,
        delivered: Arc<Mutex<Vec<DeliverMessage>>>,
        permission_tx: Arc<Mutex<Option<mpsc::Sender<PermissionResponse>>>>,
        default_option_id: String,
    }

    impl TestChannel {
        fn new() -> Self {
            Self {
                on_ready_called: Arc::new(Mutex::new(false)),
                delivered: Arc::new(Mutex::new(Vec::new())),
                permission_tx: Arc::new(Mutex::new(None)),
                default_option_id: "allow".into(),
            }
        }
    }

    impl Channel for TestChannel {
        fn capabilities(&self) -> ChannelCapabilities {
            ChannelCapabilities {
                streaming: true,
                rich_text: false,
            }
        }

        async fn on_ready(
            &mut self,
            _outbound: mpsc::Sender<ChannelSendMessage>,
            permission_tx: mpsc::Sender<PermissionResponse>,
        ) -> Result<(), ChannelSdkError> {
            *self.on_ready_called.lock().unwrap() = true;
            *self.permission_tx.lock().unwrap() = Some(permission_tx);
            Ok(())
        }

        async fn deliver_message(&mut self, msg: DeliverMessage) -> Result<(), ChannelSdkError> {
            self.delivered.lock().unwrap().push(msg);
            Ok(())
        }

        async fn show_permission_prompt(
            &mut self,
            req: ChannelRequestPermission,
        ) -> Result<(), ChannelSdkError> {
            // Clone the sender out so the std mutex guard is released before `.await`.
            let tx = self.permission_tx.lock().unwrap().clone();
            if let Some(tx) = tx {
                let _ = tx
                    .send(PermissionResponse {
                        request_id: req.request_id,
                        option_id: self.default_option_id.clone(),
                    })
                    .await;
            }
            Ok(())
        }
    }

    /// Builds one newline-terminated JSON-RPC request line with the given id.
    fn make_jsonrpc_request(id: u64, method: &str, params: serde_json::Value) -> String {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "method": method,
            "params": params,
        });
        format!("{}\n", serde_json::to_string(&msg).unwrap())
    }

    /// Builds one newline-terminated JSON-RPC notification line (no id).
    fn make_jsonrpc_notification(method: &str, params: serde_json::Value) -> String {
        let msg = serde_json::json!({
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
        });
        format!("{}\n", serde_json::to_string(&msg).unwrap())
    }

    /// Parses every non-blank line the harness wrote, silently skipping non-JSON lines.
    fn parse_responses(output: &[u8]) -> Vec<serde_json::Value> {
        let text = String::from_utf8_lossy(output);
        text.lines()
            .filter(|l| !l.trim().is_empty())
            .filter_map(|l| serde_json::from_str(l).ok())
            .collect()
    }

    // Construction is synchronous, so no async runtime is needed here.
    #[test]
    fn when_channel_harness_created_then_constructs_successfully() {
        let ch = TestChannel::new();
        let _harness = ChannelHarness::new(ch);
    }

    #[tokio::test]
    async fn when_initialize_request_received_then_harness_responds_with_capabilities_and_calls_on_ready()
    {
        let ch = TestChannel::new();
        let on_ready_called = ch.on_ready_called.clone();
        let input =
            make_jsonrpc_request(1, "initialize", serde_json::json!({"protocolVersion": 1}));
        let reader = std::io::Cursor::new(input.into_bytes());
        let mut output = Vec::new();
        let harness = ChannelHarness::new(ch);
        harness.run(reader, &mut output).await.unwrap();
        assert!(*on_ready_called.lock().unwrap());
        let responses = parse_responses(&output);
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0]["id"], 1);
        assert_eq!(responses[0]["result"]["capabilities"]["streaming"], true);
        assert_eq!(responses[0]["result"]["capabilities"]["richText"], false);
        assert_eq!(responses[0]["result"]["protocolVersion"], 1);
    }

    #[tokio::test]
    async fn when_deliver_message_notification_received_then_harness_calls_channel_deliver_message()
    {
        let ch = TestChannel::new();
        let delivered = ch.delivered.clone();
        let mut input =
            make_jsonrpc_request(1, "initialize", serde_json::json!({"protocolVersion": 1}));
        input.push_str(&make_jsonrpc_notification(
            "channel/deliverMessage",
            serde_json::json!({"sessionId": "s1", "content": "hello"}),
        ));
        let reader = std::io::Cursor::new(input.into_bytes());
        let mut output = Vec::new();
        let harness = ChannelHarness::new(ch);
        harness.run(reader, &mut output).await.unwrap();
        let msgs = delivered.lock().unwrap();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].session_id, "s1");
        assert_eq!(msgs[0].content, "hello");
    }

    #[tokio::test]
    async fn when_request_permission_received_then_harness_calls_channel_and_sends_response() {
        let ch = TestChannel::new();
        let mut input =
            make_jsonrpc_request(1, "initialize", serde_json::json!({"protocolVersion": 1}));
        input.push_str(&make_jsonrpc_request(
            2,
            "channel/requestPermission",
            serde_json::json!({
                "requestId": "perm-1",
                "sessionId": "s1",
                "description": "Allow?",
                "options": [{"optionId": "allow", "label": "Allow"}]
            }),
        ));
        let reader = std::io::Cursor::new(input.into_bytes());
        let mut output = Vec::new();
        let harness = ChannelHarness::new(ch);
        harness.run(reader, &mut output).await.unwrap();
        let responses = parse_responses(&output);
        assert_eq!(responses.len(), 2);
        assert_eq!(responses[1]["id"], 2);
        assert_eq!(responses[1]["result"]["requestId"], "perm-1");
        assert_eq!(responses[1]["result"]["optionId"], "allow");
    }

    #[tokio::test]
    async fn when_unknown_method_received_then_harness_calls_handle_unknown_and_returns_error() {
        let ch = TestChannel::new();
        let mut input =
            make_jsonrpc_request(1, "initialize", serde_json::json!({"protocolVersion": 1}));
        input.push_str(&make_jsonrpc_request(
            2,
            "custom/method",
            serde_json::json!({}),
        ));
        let reader = std::io::Cursor::new(input.into_bytes());
        let mut output = Vec::new();
        let harness = ChannelHarness::new(ch);
        harness.run(reader, &mut output).await.unwrap();
        let responses = parse_responses(&output);
        assert_eq!(responses.len(), 2);
        assert_eq!(responses[1]["id"], 2);
        assert!(
            responses[1]["error"]["message"]
                .as_str()
                .unwrap()
                .contains("custom/method")
        );
    }

    #[tokio::test]
    async fn when_reader_reaches_eof_then_harness_exits_cleanly() {
        let ch = TestChannel::new();
        let reader = std::io::Cursor::new(Vec::<u8>::new());
        let mut output = Vec::new();
        let harness = ChannelHarness::new(ch);
        let result = harness.run(reader, &mut output).await;
        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn when_channel_tester_initialized_then_on_ready_called() {
        use crate::testing::ChannelTester;
        let ch = TestChannel::new();
        let on_ready_called = ch.on_ready_called.clone();
        let mut tester = ChannelTester::new(ch);
        tester.initialize(None).await.unwrap();
        assert!(*on_ready_called.lock().unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn when_channel_tester_delivers_message_then_channel_receives_it() {
        use crate::testing::ChannelTester;
        let ch = TestChannel::new();
        let delivered = ch.delivered.clone();
        let mut tester = ChannelTester::new(ch);
        tester.initialize(None).await.unwrap();
        tester
            .deliver(DeliverMessage {
                session_id: "s1".into(),
                content: serde_json::json!("test-msg"),
            })
            .await
            .unwrap();
        let msgs = delivered.lock().unwrap();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].session_id, "s1");
    }

    #[tokio::test]
    async fn when_channel_returns_defaults_then_init_response_includes_defaults() {
        /// Minimal channel whose only distinguishing feature is a `defaults()` map.
        struct ChannelWithDefaults;

        impl Channel for ChannelWithDefaults {
            fn capabilities(&self) -> ChannelCapabilities {
                ChannelCapabilities {
                    streaming: false,
                    rich_text: false,
                }
            }

            fn defaults(&self) -> Option<HashMap<String, serde_json::Value>> {
                let mut map = HashMap::new();
                map.insert("timeout".into(), serde_json::json!(60));
                Some(map)
            }

            async fn on_ready(
                &mut self,
                _outbound: mpsc::Sender<ChannelSendMessage>,
                _permission_tx: mpsc::Sender<PermissionResponse>,
            ) -> Result<(), ChannelSdkError> {
                Ok(())
            }

            async fn deliver_message(
                &mut self,
                _msg: DeliverMessage,
            ) -> Result<(), ChannelSdkError> {
                Ok(())
            }

            async fn show_permission_prompt(
                &mut self,
                _req: ChannelRequestPermission,
            ) -> Result<(), ChannelSdkError> {
                Ok(())
            }
        }

        let input =
            make_jsonrpc_request(1, "initialize", serde_json::json!({"protocolVersion": 1}));
        let reader = std::io::Cursor::new(input.into_bytes());
        let mut output = Vec::new();
        let harness = ChannelHarness::new(ChannelWithDefaults);
        harness.run(reader, &mut output).await.unwrap();
        let responses = parse_responses(&output);
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0]["result"]["defaults"]["timeout"], 60);
    }

    #[tokio::test]
    async fn when_channel_returns_no_defaults_then_init_response_omits_defaults() {
        let ch = TestChannel::new();
        let input =
            make_jsonrpc_request(1, "initialize", serde_json::json!({"protocolVersion": 1}));
        let reader = std::io::Cursor::new(input.into_bytes());
        let mut output = Vec::new();
        let harness = ChannelHarness::new(ch);
        harness.run(reader, &mut output).await.unwrap();
        let responses = parse_responses(&output);
        assert_eq!(responses.len(), 1);
        assert!(responses[0]["result"].get("defaults").is_none());
    }
}