#![allow(async_fn_in_trait)]
use serde::{Deserialize, Serialize};
use crate::plugin_protocol::{HookResult, PluginRegistration, PluginToolResult};
use crate::protocol::{Request, Response};
/// Per-session context handed to a plugin.
///
/// Passed by value to [`PluginService::init`] and [`PluginService::session_start`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionCtx {
    /// Working directory as a string (presumably the session's cwd on the host — confirm).
    pub cwd: String,
    /// Identifier of the session this context describes.
    pub session_id: String,
    /// Optional project name.
    ///
    /// Omitted from the serialized form when `None`; deserializes to `None` when the
    /// field is absent from the input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub project_name: Option<String>,
}
/// A tool invocation sent to a plugin through [`PluginService::tool_call`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallReq {
    /// Identifier of this call. `cancel_tool_call` and `output_delta` take a parameter
    /// of the same name; presumably they refer to this value (confirm with the host).
    pub tool_call_id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Tool arguments as an arbitrary JSON value; no schema is enforced here.
    pub arguments: serde_json::Value,
    // The optional fields below mirror `SessionCtx`. Each is omitted from the
    // serialized form when `None` and deserializes to `None` when absent.
    /// Working directory for this call, if the caller supplied one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cwd: Option<String>,
    /// Session the call belongs to, if the caller supplied one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Project name, if the caller supplied one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub project_name: Option<String>,
}
/// A hook invocation sent to a plugin through [`PluginService::hook`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookReq {
    /// Which hook to run (presumably one of the hooks the plugin registered — confirm).
    pub name: String,
    /// Hook payload as an arbitrary JSON value; its schema is not defined in this file.
    pub data: serde_json::Value,
}
/// RPC service implemented by a plugin and called by the host (server side).
///
/// `#[myelin::service]` exposes this trait under API id `0x0001`. The tests in this
/// file use `PluginRequest`, `PluginResponse`, `PLUGIN_API_ID`, `PluginClient` and
/// `plugin_dispatch`, none of which are defined in this file; presumably the macro
/// generates them from this trait.
///
/// NOTE(review): the macro likely derives the wire format from the method list
/// (order, names, parameter names) — treat edits to the signatures as protocol changes.
///
/// Public `async fn` in a trait triggers rustc's `async_fn_in_trait` lint (callers
/// cannot require `Send` on the returned futures); the `#![allow(async_fn_in_trait)]`
/// at the top of this file silences it.
#[myelin::service(api_id = 0x0001)]
pub trait PluginService {
/// Delivers a [`SessionCtx`] to the plugin.
///
/// NOTE(review): whether the host sends this once per connection is not established
/// here; the round-trip test sends it twice on the same connection.
async fn init(&self, ctx: SessionCtx);
/// Invokes the hook described by `req` and returns the plugin's [`HookResult`].
async fn hook(&self, req: HookReq) -> HookResult;
/// Runs a tool call and returns its [`PluginToolResult`].
///
/// NOTE(review): presumably the result's `tool_call_id` should echo
/// `call.tool_call_id` (the test double does) — confirm with the host.
async fn tool_call(&self, call: ToolCallReq) -> PluginToolResult;
/// Asks the plugin to cancel the tool call with this id
/// (presumably one previously sent via `tool_call`).
async fn cancel_tool_call(&self, tool_call_id: String);
/// Notifies the plugin of a session start, carrying that session's [`SessionCtx`].
async fn session_start(&self, ctx: SessionCtx);
/// Payload-less notification; the name suggests the host has gone idle — confirm.
async fn idle(&self);
}
/// RPC service implemented by the host (server side) and called by a plugin.
///
/// `#[myelin::service]` exposes this trait under API id `0x0002`. The tests in this
/// file use `PluginCallbackRequest`, `PluginCallbackResponse`,
/// `PLUGIN_CALLBACK_API_ID`, `PluginCallbackClient` and `plugin_callback_dispatch`,
/// presumably generated by the macro from this trait.
///
/// NOTE(review): as with `PluginService`, the method list likely defines the wire
/// format — treat signature edits as protocol changes.
#[myelin::service(api_id = 0x0002)]
pub trait PluginCallbackService {
/// Registers the plugin with the host using the given [`PluginRegistration`].
async fn register(&self, reg: PluginRegistration);
/// Forwards a [`Request`] from the plugin to the host and returns the host's
/// [`Response`].
async fn server_request(&self, req: Request) -> Response;
/// Sends a chunk of output text for the tool call `tool_call_id`
/// (presumably incremental/streamed output — confirm with the host).
async fn output_delta(&self, tool_call_id: String, text: String);
}
/// Slot count passed as a const generic to the `PluginDuplex` transport.
///
/// NOTE(review): the meaning is defined by `myelin::stream::DuplexStreamTransport`;
/// presumably the number of requests that may be in flight at once — confirm there.
pub const DUPLEX_SLOTS: usize = 32;
/// Buffer size (128 Ki units, presumably bytes) passed as a const generic to the
/// `PluginDuplex` transport.
///
/// NOTE(review): presumably bounds the size of a single frame; confirm in `myelin`
/// before sending large tool outputs.
pub const DUPLEX_BUF: usize = 128 * 1024;
/// Plugin transport: a `myelin` duplex stream over a reader/writer pair, configured
/// with `LengthPrefixed` framing, the `CborCodec` codec, and the
/// [`DUPLEX_SLOTS`]/[`DUPLEX_BUF`] sizes.
///
/// A single connection carries both [`PluginService`] and [`PluginCallbackService`]
/// traffic; each side obtains server/client halves keyed by API id.
pub type PluginDuplex<Reader, Writer> = myelin::stream::DuplexStreamTransport<
    Reader,
    Writer,
    myelin::stream::LengthPrefixed,
    myelin::stream::CborCodec,
    DUPLEX_SLOTS,
    DUPLEX_BUF,
>;
#[cfg(test)]
mod tests {
    use super::*;
    use myelin::io::futures_io::{FuturesIoReader, FuturesIoWriter};
    use myelin::transport::ServerTransport;
    use smol::Async;
    use std::os::unix::net::UnixStream;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// Creates a connected Unix socket pair and returns the reader and writer of the
    /// first end followed by the reader and writer of the second end.
    ///
    /// Each end is cloned so that reads and writes go through separate handles.
    fn make_pair() -> (
        FuturesIoReader<Async<UnixStream>>,
        FuturesIoWriter<Async<UnixStream>>,
        FuturesIoReader<Async<UnixStream>>,
        FuturesIoWriter<Async<UnixStream>>,
    ) {
        let (sa, sb) = UnixStream::pair().expect("UnixStream::pair");
        sa.set_nonblocking(true).expect("nonblocking sa");
        sb.set_nonblocking(true).expect("nonblocking sb");
        let sa_w = sa.try_clone().expect("clone sa");
        sa_w.set_nonblocking(true).expect("nonblocking sa_w");
        let sb_w = sb.try_clone().expect("clone sb");
        sb_w.set_nonblocking(true).expect("nonblocking sb_w");
        let sa_r = Async::new(sa).expect("Async sa");
        let sa_w = Async::new(sa_w).expect("Async sa_w");
        let sb_r = Async::new(sb).expect("Async sb");
        let sb_w = Async::new(sb_w).expect("Async sb_w");
        (
            FuturesIoReader::new(sa_r),
            FuturesIoWriter::new(sa_w),
            FuturesIoReader::new(sb_r),
            FuturesIoWriter::new(sb_w),
        )
    }

    /// Test double for the plugin: counts `init` and `idle` calls and answers the
    /// other methods with empty/default results.
    struct PluginSide {
        init_calls: Arc<AtomicU32>,
        idle_calls: Arc<AtomicU32>,
    }

    impl PluginService for PluginSide {
        async fn init(&self, _ctx: SessionCtx) {
            self.init_calls.fetch_add(1, Ordering::SeqCst);
        }
        async fn hook(&self, _req: HookReq) -> HookResult {
            HookResult::default()
        }
        async fn tool_call(&self, call: ToolCallReq) -> PluginToolResult {
            PluginToolResult {
                tool_call_id: call.tool_call_id,
                content: vec![],
                is_error: false,
                summary: None,
                post_persist_actions: vec![],
            }
        }
        async fn cancel_tool_call(&self, _id: String) {}
        async fn session_start(&self, _ctx: SessionCtx) {}
        async fn idle(&self) {
            self.idle_calls.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Test double for the host: counts `register` calls and answers
    /// `server_request` with `Response::Ok`.
    struct ServerSide {
        register_calls: Arc<AtomicU32>,
    }

    impl PluginCallbackService for ServerSide {
        async fn register(&self, _reg: PluginRegistration) {
            self.register_calls.fetch_add(1, Ordering::SeqCst);
        }
        async fn server_request(&self, _req: Request) -> Response {
            Response::Ok
        }
        async fn output_delta(&self, _id: String, _text: String) {}
    }

    /// Exercises both directions over one pair of connected `PluginDuplex` endpoints.
    /// The plugin calls `register` on the server (callback API); the server calls
    /// `init` twice and `idle` twice on the plugin (plugin API). Fails with a panic
    /// instead of hanging if the exchange stalls.
    #[test]
    fn duplex_round_trips_both_directions() {
        // Upper bound for the whole exchange. Each dispatch loop awaits a fixed
        // number of `recv()`s and client results are discarded, so a lost or
        // undecodable message could otherwise leave the test waiting forever.
        const TEST_TIMEOUT: Duration = Duration::from_secs(10);

        let (r_srv, w_srv, r_plg, w_plg) = make_pair();
        let dx_srv: PluginDuplex<_, _> = PluginDuplex::new(r_srv, w_srv);
        let dx_plg: PluginDuplex<_, _> = PluginDuplex::new(r_plg, w_plg);
        let srv_server = dx_srv
            .server_half::<PluginCallbackRequest, PluginCallbackResponse>(PLUGIN_CALLBACK_API_ID);
        let srv_client = dx_srv.client_half::<PluginRequest, PluginResponse>(PLUGIN_API_ID);
        let plg_server = dx_plg.server_half::<PluginRequest, PluginResponse>(PLUGIN_API_ID);
        let plg_client = dx_plg
            .client_half::<PluginCallbackRequest, PluginCallbackResponse>(PLUGIN_CALLBACK_API_ID);
        let (pump_srv, _h_srv) = dx_srv.split();
        let (pump_plg, _h_plg) = dx_plg.split();
        let plugin_init_count = Arc::new(AtomicU32::new(0));
        let plugin_idle_count = Arc::new(AtomicU32::new(0));
        let server_register_count = Arc::new(AtomicU32::new(0));
        let plg_impl = PluginSide {
            init_calls: plugin_init_count.clone(),
            idle_calls: plugin_idle_count.clone(),
        };
        let srv_impl = ServerSide {
            register_calls: server_register_count.clone(),
        };
        smol::block_on(async {
            let mut plg_server = plg_server;
            let mut srv_server = srv_server;
            // Serves exactly the four plugin-API calls issued by `work` (2x init, 2x idle).
            // Named `plugin_loop` so it does not shadow the generated `plugin_dispatch`.
            let plugin_loop = async move {
                for _ in 0..4 {
                    let (req, token) = plg_server.recv().await.expect("plugin recv");
                    let resp = plugin_dispatch(&plg_impl, req).await;
                    plg_server.reply(token, resp).await.expect("plugin reply");
                }
            };
            // Serves the single callback-API call (`register`).
            let server_loop = async move {
                let (req, token) = srv_server.recv().await.expect("server recv");
                let resp = plugin_callback_dispatch(&srv_impl, req).await;
                srv_server.reply(token, resp).await.expect("server reply");
            };
            let work = async {
                let plg_client = PluginCallbackClient::new(plg_client);
                let _ = plg_client
                    .register(PluginRegistration {
                        name: "smoke".into(),
                        tools: vec![],
                        hooks: vec![],
                        commands: vec![],
                    })
                    .await;
                let srv_client = PluginClient::new(srv_client);
                let i1 = srv_client.init(SessionCtx {
                    cwd: "/tmp".into(),
                    session_id: "s1".into(),
                    project_name: None,
                });
                let i2 = srv_client.init(SessionCtx {
                    cwd: "/tmp".into(),
                    session_id: "s2".into(),
                    project_name: None,
                });
                let id1 = srv_client.idle();
                let id2 = srv_client.idle();
                // The four plugin calls are issued concurrently.
                let ((_a, _b), (_c, _d)) = futures_lite::future::zip(
                    futures_lite::future::zip(i1, i2),
                    futures_lite::future::zip(id1, id2),
                )
                .await;
            };
            // The first arm is the test body. The second arm drives the transports'
            // `split()` halves (presumably the I/O pumps); if it finishes first, `or`
            // returns early and the counter assertions below report the failure.
            let exchange = futures_lite::future::or(
                async {
                    futures_lite::future::zip(
                        work,
                        futures_lite::future::zip(plugin_loop, server_loop),
                    )
                    .await;
                },
                async {
                    let _ = futures_lite::future::zip(pump_srv.run(), pump_plg.run()).await;
                },
            );
            let watchdog = async {
                smol::Timer::after(TEST_TIMEOUT).await;
                panic!(
                    "duplex round trip did not finish within {:?}",
                    TEST_TIMEOUT
                );
            };
            futures_lite::future::or(exchange, watchdog).await;
        });
        assert_eq!(plugin_init_count.load(Ordering::SeqCst), 2);
        assert_eq!(plugin_idle_count.load(Ordering::SeqCst), 2);
        assert_eq!(server_register_count.load(Ordering::SeqCst), 1);
    }
}