use anyhow::Result;
use futures_core::Stream;
use futures_util::StreamExt;
use std::net::{Ipv4Addr, SocketAddr};
use std::path::Path;
#[cfg(unix)]
use std::path::PathBuf;
use transport::tcp::TcpConnection;
#[cfg(unix)]
use transport::uds::{ClientConfig, Connection, CrabtalkClient};
use wcore::protocol::{
api::Client,
message::{
AgentEventMsg, AskQuestion, ClientMessage, HubEvent, InstallPackageMsg, KillMsg,
ReplyToAsk, ServerMessage, SessionInfo, StreamMsg, SubscribeEvents, UninstallPackageMsg,
client_message, hub_event, server_message, stream_event,
},
};
/// A single unit of output produced while consuming an agent stream.
///
/// Values are built by `Runner::stream` from the server's stream events.
pub enum OutputChunk {
/// Text content taken from a `Chunk` stream event.
Text(String),
/// Content taken from a `Thinking` stream event.
Thinking(String),
/// Tool calls announced by a `ToolStart` event, as `(name, arguments)` pairs.
ToolStart(Vec<(String, String)>),
/// Result of one tool call, as `(call_id, output)`.
ToolResult(String, String),
/// Emitted when a `ToolsComplete` event arrives; `Runner::stream` always
/// passes `true` here.
ToolDone(bool),
/// Questions forwarded from an `AskUser` stream event.
AskUser {
/// The questions carried by the event.
questions: Vec<AskQuestion>,
/// Session id recorded from the stream's `Start` event (0 if no `Start`
/// event was seen before this one).
session: u64,
},
}
/// Connection to the server over one of the supported transports.
///
/// Implements `Client` by forwarding every call to the wrapped connection.
pub enum Transport {
/// Connection from the `transport::uds` module (only compiled on Unix targets).
#[cfg(unix)]
Uds(Connection),
/// Connection from the `transport::tcp` module.
Tcp(TcpConnection),
}
/// Evaluates `$action` against whichever connection a `Transport` wraps.
///
/// `$conn` is the identifier that `$action` uses to refer to the inner
/// connection; the same expression is expanded once for every variant, so it
/// must type-check for each connection type.
macro_rules! dispatch {
    ($transport:expr, |$conn:ident| $action:expr) => {
        match $transport {
            #[cfg(unix)]
            Transport::Uds($conn) => $action,
            Transport::Tcp($conn) => $action,
        }
    };
}
/// Forwards the `Client` API to whichever connection the `Transport` wraps.
impl Client for Transport {
/// Sends `msg` over the wrapped connection and awaits its single response.
async fn request(&mut self, msg: ClientMessage) -> Result<ServerMessage> {
dispatch!(self, |c| c.request(msg).await)
}
/// Sends `msg` over the wrapped connection and re-yields every item of the
/// connection's response stream.
///
/// An `Err` item from the inner stream is propagated with `?` inside
/// `try_stream!`.
fn request_stream(
&mut self,
msg: ClientMessage,
) -> impl Stream<Item = Result<ServerMessage>> + Send + '_ {
async_stream::try_stream! {
// The same loop is expanded for each variant, because the two connection
// types produce different concrete stream types.
dispatch!(self, |c| {
let s = c.request_stream(msg);
// Pin the inner stream on the stack before calling `next()` on it.
tokio::pin!(s);
while let Some(item) = s.next().await {
yield item?;
}
});
}
}
}
/// Details needed to open a connection to the server.
///
/// Stored on every `Runner` so that further connections to the same server can
/// be opened later (see `Runner::connect_from` and `send_reply`).
#[derive(Clone)]
pub enum ConnectionInfo {
/// Socket path used for the UDS transport (Unix targets only).
#[cfg(unix)]
Uds(PathBuf),
/// TCP port; connections made from it target the IPv4 loopback address.
Tcp(u16),
}
/// A connected client: a `Transport` plus the details it was opened with.
pub struct Runner {
// Live connection that carries every request made through this runner.
transport: Transport,
/// Connection details this runner was created from.
pub conn_info: ConnectionInfo,
}
/// Lets a `Runner` be used wherever a shared reference to its `Transport` is
/// expected.
impl std::ops::Deref for Runner {
    type Target = Transport;

    fn deref(&self) -> &Transport {
        &self.transport
    }
}
impl std::ops::DerefMut for Runner {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.transport
}
}
impl Runner {
/// Connects to the server through the socket at `socket_path`.
///
/// The path is remembered in `conn_info` so later connections can reuse it.
///
/// # Errors
/// Propagates the error returned by `CrabtalkClient::connect`.
#[cfg(unix)]
pub async fn connect(socket_path: &Path) -> Result<Self> {
    let path = socket_path.to_path_buf();
    let client = CrabtalkClient::new(ClientConfig {
        socket_path: path.clone(),
    });
    let transport = Transport::Uds(client.connect().await?);
    Ok(Self {
        transport,
        conn_info: ConnectionInfo::Uds(path),
    })
}
/// Connects to the server over TCP on the IPv4 loopback address at `port`.
///
/// # Errors
/// Propagates the error returned by `TcpConnection::connect`.
pub async fn connect_tcp(port: u16) -> Result<Self> {
    let loopback = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port);
    let transport = Transport::Tcp(TcpConnection::connect(loopback).await?);
    Ok(Self {
        transport,
        conn_info: ConnectionInfo::Tcp(port),
    })
}
/// Opens a new connection described by `info`, choosing the constructor that
/// matches its variant.
///
/// # Errors
/// Propagates whatever the selected constructor returns.
pub async fn connect_from(info: &ConnectionInfo) -> Result<Self> {
    match info {
        #[cfg(unix)]
        ConnectionInfo::Uds(socket_path) => Self::connect(socket_path.as_path()).await,
        ConnectionInfo::Tcp(port) => {
            let port = *port;
            Self::connect_tcp(port).await
        }
    }
}
pub fn stream<'a>(
&'a mut self,
agent: &'a str,
content: &'a str,
cwd: Option<&'a Path>,
new_chat: bool,
resume_file: Option<String>,
sender: Option<String>,
) -> impl Stream<Item = Result<OutputChunk>> + Send + 'a {
let cwd = cwd.map(|p| p.to_string_lossy().into_owned()).or_else(|| {
std::env::current_dir()
.ok()
.map(|p| p.to_string_lossy().into_owned())
});
self.transport
.request_stream(ClientMessage::from(StreamMsg {
agent: agent.to_string(),
content: content.to_string(),
session: None,
sender,
cwd,
new_chat,
resume_file,
}))
.take_while(|r| {
std::future::ready(!matches!(
r,
Ok(ServerMessage {
msg: Some(server_message::Msg::Stream(e))
}) if matches!(&e.event, Some(stream_event::Event::End(end)) if end.error.is_empty())
))
})
.scan(0u64, |session_id, result| {
let chunk = match result {
Ok(ServerMessage {
msg: Some(server_message::Msg::Stream(e)),
}) => match &e.event {
Some(stream_event::Event::Start(s)) => {
*session_id = s.session;
None
}
Some(stream_event::Event::Chunk(c)) => {
Some(Ok(OutputChunk::Text(c.content.clone())))
}
Some(stream_event::Event::Thinking(t)) => {
Some(Ok(OutputChunk::Thinking(t.content.clone())))
}
Some(stream_event::Event::ToolStart(ts)) => {
let calls: Vec<_> = ts
.calls
.iter()
.map(|c| (c.name.clone(), c.arguments.clone()))
.collect();
Some(Ok(OutputChunk::ToolStart(calls)))
}
Some(stream_event::Event::ToolResult(tr)) => Some(Ok(
OutputChunk::ToolResult(tr.call_id.clone(), tr.output.clone()),
)),
Some(stream_event::Event::ToolsComplete(_)) => {
Some(Ok(OutputChunk::ToolDone(true)))
}
Some(stream_event::Event::AskUser(ask)) => Some(Ok(OutputChunk::AskUser {
questions: ask.questions.clone(),
session: *session_id,
})),
Some(stream_event::Event::End(end)) if !end.error.is_empty() => {
Some(Err(anyhow::anyhow!("{}", end.error)))
}
Some(stream_event::Event::End(_)) => None,
None => None,
},
Ok(ServerMessage {
msg: Some(server_message::Msg::Error(e)),
}) => Some(Err(anyhow::anyhow!(
"server error ({}): {}",
e.code,
e.message
))),
Ok(_) => None,
Err(e) => Some(Err(e)),
};
std::future::ready(Some(chunk))
})
.filter_map(std::future::ready)
}
pub async fn list_sessions(&mut self) -> Result<Vec<SessionInfo>> {
let msg = ClientMessage {
msg: Some(client_message::Msg::Sessions(Default::default())),
};
match self.transport.request(msg).await? {
ServerMessage {
msg: Some(server_message::Msg::Sessions(sl)),
} => Ok(sl.sessions),
ServerMessage {
msg: Some(server_message::Msg::Error(e)),
} => {
anyhow::bail!("server error ({}): {}", e.code, e.message)
}
other => anyhow::bail!("unexpected response: {other:?}"),
}
}
pub async fn kill_session(&mut self, session: u64) -> Result<bool> {
let msg = ClientMessage {
msg: Some(client_message::Msg::Kill(KillMsg { session })),
};
match self.transport.request(msg).await? {
ServerMessage {
msg: Some(server_message::Msg::Pong(_)),
} => Ok(true),
ServerMessage {
msg: Some(server_message::Msg::Error(e)),
} if e.code == 404 => Ok(false),
ServerMessage {
msg: Some(server_message::Msg::Error(e)),
} => {
anyhow::bail!("server error ({}): {}", e.code, e.message)
}
other => anyhow::bail!("unexpected response: {other:?}"),
}
}
pub async fn reload(&mut self) -> Result<()> {
let msg = ClientMessage {
msg: Some(client_message::Msg::Reload(Default::default())),
};
match self.transport.request(msg).await? {
ServerMessage {
msg: Some(server_message::Msg::Pong(_)),
} => Ok(()),
ServerMessage {
msg: Some(server_message::Msg::Error(e)),
} => {
anyhow::bail!("server error ({}): {}", e.code, e.message)
}
other => anyhow::bail!("unexpected response: {other:?}"),
}
}
pub fn subscribe_events(&mut self) -> impl Stream<Item = Result<AgentEventMsg>> + Send + '_ {
self.transport
.request_stream(ClientMessage {
msg: Some(client_message::Msg::SubscribeEvents(SubscribeEvents {})),
})
.filter_map(|r| async {
match r {
Ok(ServerMessage {
msg: Some(server_message::Msg::AgentEvent(e)),
}) => Some(Ok(e)),
Ok(ServerMessage {
msg: Some(server_message::Msg::Error(e)),
}) => Some(Err(anyhow::anyhow!(
"server error ({}): {}",
e.code,
e.message
))),
Ok(_) => None,
Err(e) => Some(Err(e)),
}
})
}
pub fn install_package<'a>(
&'a mut self,
package: &str,
branch: &str,
path: &str,
force: bool,
) -> impl Stream<Item = Result<hub_event::Event>> + Send + 'a {
self.transport
.request_stream(ClientMessage {
msg: Some(client_message::Msg::InstallPackage(InstallPackageMsg {
package: package.to_string(),
branch: branch.to_string(),
path: path.to_string(),
force,
})),
})
.take_while(|r| {
std::future::ready(!matches!(
r,
Ok(ServerMessage {
msg: Some(server_message::Msg::HubEvent(HubEvent {
event: Some(hub_event::Event::Done(d))
}))
}) if d.error.is_empty()
))
})
.filter_map(|r| {
std::future::ready(match r {
Ok(ServerMessage {
msg: Some(server_message::Msg::HubEvent(e)),
}) => e.event.map(Ok),
Ok(ServerMessage {
msg: Some(server_message::Msg::Error(e)),
}) => Some(Err(anyhow::anyhow!(
"server error ({}): {}",
e.code,
e.message
))),
Ok(_) => None,
Err(e) => Some(Err(e)),
})
})
}
pub fn uninstall_package<'a>(
&'a mut self,
package: &str,
) -> impl Stream<Item = Result<hub_event::Event>> + Send + 'a {
self.transport
.request_stream(ClientMessage {
msg: Some(client_message::Msg::UninstallPackage(UninstallPackageMsg {
package: package.to_string(),
})),
})
.take_while(|r| {
std::future::ready(!matches!(
r,
Ok(ServerMessage {
msg: Some(server_message::Msg::HubEvent(HubEvent {
event: Some(hub_event::Event::Done(d))
}))
}) if d.error.is_empty()
))
})
.filter_map(|r| {
std::future::ready(match r {
Ok(ServerMessage {
msg: Some(server_message::Msg::HubEvent(e)),
}) => e.event.map(Ok),
Ok(ServerMessage {
msg: Some(server_message::Msg::Error(e)),
}) => Some(Err(anyhow::anyhow!(
"server error ({}): {}",
e.code,
e.message
))),
Ok(_) => None,
Err(e) => Some(Err(e)),
})
})
}
}
pub async fn send_reply(conn_info: &ConnectionInfo, session: u64, content: String) -> Result<()> {
let msg = ClientMessage::from(ReplyToAsk { session, content });
match conn_info {
#[cfg(unix)]
ConnectionInfo::Uds(path) => {
let client = CrabtalkClient::new(ClientConfig {
socket_path: path.clone(),
});
let mut conn = client.connect().await?;
conn.request(msg).await?;
}
ConnectionInfo::Tcp(port) => {
let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, *port));
let mut conn = TcpConnection::connect(addr).await?;
conn.request(msg).await?;
}
}
Ok(())
}