use std::path::{Path, PathBuf};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixStream;
use tokio_util::bytes::{Bytes, BytesMut};
use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
use super::cursor::Cursor;
use super::ids::{AgentId, RoomId};
use super::model::{AgentCard, AgentRecord, Room, RoomScope};
use super::protocol::{
CommsNotification, CommsOut, CommsRequest, CommsResponse, PROTO_VER, SeqMeta, StatusReport,
};
use super::singleton::{self, CommsPaths};
use super::transport::MAX_FRAME_BYTES;
/// Initial capacity, in bytes, of the client's frame read buffer (8 KiB).
const READ_CHUNK: usize = 8 << 10;
/// Boxed hook handed to `singleton::ensure_daemon_with` when the client reconnects;
/// it is expected to start the broker daemon for the given paths.
type SpawnFn = Box<dyn for<'p> Fn(&'p CommsPaths) -> std::io::Result<()> + Send + Sync>;
/// Errors produced by [`CommsClient`] operations.
#[derive(Debug, thiserror::Error)]
pub enum CommsClientError {
    /// Socket-level failure: connecting, reading, writing, or frame (de)limiting,
    /// including an EOF in the middle of a frame.
    #[error("comms transport error: {0}")]
    Io(#[from] std::io::Error),

    /// A request could not be serialized with `rmp_serde`.
    #[error("encode error: {0}")]
    Encode(#[from] rmp_serde::encode::Error),

    /// A frame received from the broker could not be deserialized with `rmp_serde`.
    #[error("decode error: {0}")]
    Decode(#[from] rmp_serde::decode::Error),

    /// Failure reported by the `singleton` module (path resolution / daemon management).
    #[error(transparent)]
    Singleton(#[from] super::singleton::SingletonError),

    /// The stream ended cleanly while a response was still awaited.
    #[error("connection closed before a response was received")]
    Closed,

    /// The broker answered with `CommsResponse::Error`.
    #[error("broker error [{code}]: {message}")]
    Broker { code: String, message: String },

    /// The broker answered with a response variant that does not fit the request;
    /// `request` names the client call that received it.
    #[error("unexpected response shape for {request}")]
    Unexpected { request: &'static str },

    /// The daemon's `Welcome` advertised a protocol version different from `PROTO_VER`.
    #[error("protocol skew: daemon speaks {daemon}, client speaks {client}")]
    ProtoSkew { daemon: u32, client: u32 },
}
/// Connection to the comms broker daemon over a Unix socket.
///
/// Requests and replies travel as length-delimited `rmp_serde` frames. The
/// client keeps its handshake parameters and a daemon-spawn hook so that a
/// dropped connection can be re-established and the failed request retried once.
pub struct CommsClient {
    /// Socket connected to the broker.
    stream: UnixStream,
    /// Frame delimiter; `dial` caps its max frame length at `MAX_FRAME_BYTES`.
    codec: LengthDelimitedCodec,
    /// Bytes read from `stream` that have not yet been decoded into a frame.
    read_buf: BytesMut,
    /// Agent identity announced in the `Hello` handshake.
    agent: AgentId,
    /// Notifications that arrived while a response was awaited, queued for the caller.
    pending_notifications: std::collections::VecDeque<CommsNotification>,
    /// Socket/daemon locations, reused when reconnecting.
    paths: CommsPaths,
    /// Optional remote scope sent in the handshake.
    remote: Option<String>,
    /// Optional working directory sent in the handshake.
    cwd: Option<PathBuf>,
    /// Hook passed to `singleton::ensure_daemon_with` when reconnecting.
    spawn: SpawnFn,
}
impl CommsClient {
    /// Dials the broker socket at `paths` and performs the `Hello` handshake.
    ///
    /// If a later request finds the connection lost, the client reconnects and
    /// uses `singleton::spawn_detached_daemon` as its daemon-spawn hook.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be dialed or the handshake is rejected
    /// (e.g. [`CommsClientError::ProtoSkew`] on a protocol-version mismatch).
    pub async fn connect(
        paths: &CommsPaths,
        agent: AgentId,
        remote: Option<String>,
        cwd: Option<PathBuf>,
    ) -> Result<Self, CommsClientError> {
        Self::connect_with_respawn(paths, agent, remote, cwd, |paths| {
            singleton::spawn_detached_daemon(paths)
        })
        .await
    }

    /// Like [`CommsClient::connect`], but with a caller-supplied daemon-spawn hook.
    ///
    /// `spawn` is not invoked here. It is stored and handed to
    /// `singleton::ensure_daemon_with` whenever the client has to reconnect.
    pub async fn connect_with_respawn(
        paths: &CommsPaths,
        agent: AgentId,
        remote: Option<String>,
        cwd: Option<PathBuf>,
        spawn: impl Fn(&CommsPaths) -> std::io::Result<()> + Send + Sync + 'static,
    ) -> Result<Self, CommsClientError> {
        let (stream, codec) = Self::dial(paths).await?;
        let mut client = Self {
            stream,
            codec,
            read_buf: BytesMut::with_capacity(READ_CHUNK),
            agent,
            pending_notifications: std::collections::VecDeque::new(),
            paths: paths.clone(),
            remote,
            cwd,
            spawn: Box::new(spawn),
        };
        client.handshake().await?;
        Ok(client)
    }

    /// Resolves the default comms paths, calls `singleton::ensure_daemon`, then connects.
    pub async fn ensure_and_connect(
        agent: AgentId,
        remote: Option<String>,
        cwd: Option<PathBuf>,
    ) -> Result<Self, CommsClientError> {
        let paths = singleton::resolve_paths()?;
        singleton::ensure_daemon(&paths).await?;
        Self::connect(&paths, agent, remote, cwd).await
    }

    /// Opens the Unix socket and builds a codec capped at `MAX_FRAME_BYTES`.
    async fn dial(
        paths: &CommsPaths,
    ) -> Result<(UnixStream, LengthDelimitedCodec), CommsClientError> {
        let stream = UnixStream::connect(&paths.socket_path).await?;
        let mut codec = LengthDelimitedCodec::new();
        codec.set_max_frame_length(MAX_FRAME_BYTES);
        Ok((stream, codec))
    }

    /// Sends `Hello` and checks the daemon's `Welcome` protocol version.
    ///
    /// This uses `send_and_await` directly, not `request`, so a failed handshake
    /// never triggers a nested reconnect.
    async fn handshake(&mut self) -> Result<(), CommsClientError> {
        let hello = CommsRequest::Hello {
            agent: self.agent.clone(),
            proto_ver: PROTO_VER,
            remote: self.remote.clone(),
            cwd: self.cwd.clone(),
        };
        match self.send_and_await(&hello).await? {
            CommsResponse::Welcome { proto_ver, .. } if proto_ver == PROTO_VER => Ok(()),
            CommsResponse::Welcome { proto_ver, .. } => Err(CommsClientError::ProtoSkew {
                daemon: proto_ver,
                client: PROTO_VER,
            }),
            CommsResponse::Error { code, message } => {
                Err(CommsClientError::Broker { code, message })
            }
            _ => Err(CommsClientError::Unexpected { request: "hello" }),
        }
    }

    /// Ensures a daemon is alive (spawning via `self.spawn` if needed), redials,
    /// resets framing state, and repeats the handshake.
    async fn reconnect(&mut self) -> Result<(), CommsClientError> {
        let spawn = &self.spawn;
        singleton::ensure_daemon_with(&self.paths, singleton::probe_alive, |paths| spawn(paths))
            .await?;
        let (stream, codec) = Self::dial(&self.paths).await?;
        self.stream = stream;
        // The fresh codec and empty buffer discard any partial frame from the old socket.
        self.codec = codec;
        self.read_buf.clear();
        // NOTE(review): this also drops notifications that were fully received on the
        // old connection but not yet consumed. Confirm that is intended, e.g. because
        // subscription ids do not survive a reconnect.
        self.pending_notifications.clear();
        self.handshake().await
    }

    /// Agent identity this client announced in its handshake.
    pub fn agent(&self) -> &AgentId {
        &self.agent
    }

    /// Sends `Register` with `card`; succeeds when the broker answers `Ok`.
    pub async fn register_agent(&mut self, card: AgentCard) -> Result<(), CommsClientError> {
        self.expect_ok(CommsRequest::Register { card }, "register")
            .await
    }

    /// Sends `ListAgents` (with `room` as its optional filter) and returns the records.
    pub async fn list_agents(
        &mut self,
        room: Option<RoomId>,
    ) -> Result<Vec<AgentRecord>, CommsClientError> {
        match self.request(CommsRequest::ListAgents { room }).await? {
            CommsResponse::Agents(a) => Ok(a),
            other => Err(self.shape_err(other, "list_agents")),
        }
    }

    /// Sends `CreateRoom` and returns the room the broker reports.
    pub async fn create_room(
        &mut self,
        room: RoomId,
        scope: RoomScope,
        title: Option<String>,
    ) -> Result<Room, CommsClientError> {
        match self
            .request(CommsRequest::CreateRoom { room, scope, title })
            .await?
        {
            CommsResponse::Room(r) => Ok(r),
            other => Err(self.shape_err(other, "create_room")),
        }
    }

    /// Sends `ListRooms` with the given `remote` / `cwd` hints and returns the rooms.
    pub async fn list_rooms(
        &mut self,
        remote: Option<String>,
        cwd: Option<PathBuf>,
    ) -> Result<Vec<Room>, CommsClientError> {
        match self
            .request(CommsRequest::ListRooms { remote, cwd })
            .await?
        {
            CommsResponse::Rooms(r) => Ok(r),
            other => Err(self.shape_err(other, "list_rooms")),
        }
    }

    /// Sends `Join` for `room`; succeeds when the broker answers `Ok`.
    pub async fn join_room(&mut self, room: RoomId) -> Result<(), CommsClientError> {
        self.expect_ok(CommsRequest::Join { room }, "join_room")
            .await
    }

    /// Sends `Leave` for `room`; succeeds when the broker answers `Ok`.
    pub async fn leave_room(&mut self, room: RoomId) -> Result<(), CommsClientError> {
        self.expect_ok(CommsRequest::Leave { room }, "leave_room")
            .await
    }

    /// Posts a message to `room` and returns the `message_id` from the broker's `Posted` reply.
    ///
    /// NOTE(review): on connection loss the request is resent once after
    /// reconnecting. If the broker had already stored the post, it may be duplicated.
    #[allow(clippy::too_many_arguments)]
    pub async fn post_message(
        &mut self,
        room: RoomId,
        subject: String,
        body: Vec<u8>,
        tags: Vec<String>,
        reply_to: Option<String>,
        scope: Vec<String>,
    ) -> Result<String, CommsClientError> {
        match self
            .request(CommsRequest::Post {
                room,
                subject,
                tags,
                reply_to,
                scope,
                body,
            })
            .await?
        {
            CommsResponse::Posted { message_id } => Ok(message_id),
            other => Err(self.shape_err(other, "post_message")),
        }
    }

    /// Sends `AckInbox` and returns `(acked, cursors_advanced)` from the broker's reply.
    pub async fn ack_inbox(
        &mut self,
        message_ids: Vec<String>,
        room: Option<RoomId>,
        to_seq: Option<u64>,
    ) -> Result<(u32, Vec<(String, u64)>), CommsClientError> {
        match self
            .request(CommsRequest::AckInbox {
                message_ids,
                room,
                to_seq,
            })
            .await?
        {
            CommsResponse::Acked {
                acked,
                cursors_advanced,
            } => Ok((acked, cursors_advanced)),
            other => Err(self.shape_err(other, "ack_inbox")),
        }
    }

    /// Sends `History` for `room` and returns `(messages, next_cursor)`.
    ///
    /// `limit` is always forwarded to the broker as `Some(limit)`.
    pub async fn read_history(
        &mut self,
        room: RoomId,
        cursor: Option<Cursor>,
        limit: u32,
    ) -> Result<(Vec<SeqMeta>, Option<Cursor>), CommsClientError> {
        match self
            .request(CommsRequest::History {
                room,
                cursor,
                limit: Some(limit),
            })
            .await?
        {
            CommsResponse::History {
                messages,
                next_cursor,
            } => Ok((messages, next_cursor)),
            other => Err(self.shape_err(other, "read_history")),
        }
    }

    /// Sends `GetBody` for `message_id` and returns the body, or `None` if the broker returned none.
    pub async fn get_body(
        &mut self,
        message_id: String,
    ) -> Result<Option<Vec<u8>>, CommsClientError> {
        match self.request(CommsRequest::GetBody { message_id }).await? {
            CommsResponse::Body { body } => Ok(body),
            other => Err(self.shape_err(other, "get_body")),
        }
    }

    /// Sends `Inbox` and returns `(messages, unread, next_cursor)`.
    ///
    /// `limit` is forwarded as `Some(limit)`; `mark_read` is passed through unchanged.
    #[allow(clippy::type_complexity)]
    pub async fn read_inbox(
        &mut self,
        remote: Option<String>,
        cwd: Option<PathBuf>,
        cursor: Option<Cursor>,
        limit: u32,
        mark_read: bool,
    ) -> Result<(Vec<SeqMeta>, u32, Option<Cursor>), CommsClientError> {
        match self
            .request(CommsRequest::Inbox {
                remote,
                cwd,
                cursor,
                limit: Some(limit),
                mark_read,
            })
            .await?
        {
            CommsResponse::Inbox {
                messages,
                unread,
                next_cursor,
            } => Ok((messages, unread, next_cursor)),
            other => Err(self.shape_err(other, "read_inbox")),
        }
    }

    /// Sends `Subscribe` for `room` and returns the subscription id from the broker.
    pub async fn subscribe(&mut self, room: RoomId) -> Result<u64, CommsClientError> {
        match self.request(CommsRequest::Subscribe { room }).await? {
            CommsResponse::Subscribed { sub } => Ok(sub),
            other => Err(self.shape_err(other, "subscribe")),
        }
    }

    /// Sends `Unsubscribe` for `sub`; succeeds when the broker answers `Ok`.
    pub async fn unsubscribe(&mut self, sub: u64) -> Result<(), CommsClientError> {
        self.expect_ok(CommsRequest::Unsubscribe { sub }, "unsubscribe")
            .await
    }

    /// Sends `Status` and returns the broker's status report.
    pub async fn status(&mut self) -> Result<StatusReport, CommsClientError> {
        match self.request(CommsRequest::Status).await? {
            CommsResponse::Status(s) => Ok(s),
            other => Err(self.shape_err(other, "status")),
        }
    }

    /// Sends `Stop`; succeeds when the broker answers `Ok`.
    ///
    /// NOTE(review): like every request, a lost connection triggers `reconnect`,
    /// which may spawn a daemon before `Stop` is resent. Confirm that is acceptable here.
    pub async fn stop(&mut self) -> Result<(), CommsClientError> {
        self.expect_ok(CommsRequest::Stop, "stop").await
    }

    /// Pops a notification queued during an earlier request, without touching the socket.
    pub fn next_notification(&mut self) -> Option<CommsNotification> {
        self.pending_notifications.pop_front()
    }

    /// Returns a queued notification, otherwise reads frames until one arrives.
    ///
    /// Returns `Ok(None)` when the broker closes the stream cleanly. This method
    /// does not attempt to reconnect.
    pub async fn poll_notification(
        &mut self,
    ) -> Result<Option<CommsNotification>, CommsClientError> {
        if let Some(n) = self.pending_notifications.pop_front() {
            return Ok(Some(n));
        }
        loop {
            match self.read_frame().await? {
                Some(CommsOut::Notification(n)) => return Ok(Some(n)),
                // Response frames arriving outside `send_and_await` have no awaiting
                // caller, so they are discarded.
                Some(CommsOut::Response(_)) => continue,
                None => return Ok(None),
            }
        }
    }

    /// Sends `req` and maps a `CommsResponse::Ok` reply to `Ok(())`.
    async fn expect_ok(
        &mut self,
        req: CommsRequest,
        label: &'static str,
    ) -> Result<(), CommsClientError> {
        match self.request(req).await? {
            CommsResponse::Ok => Ok(()),
            other => Err(self.shape_err(other, label)),
        }
    }

    /// Converts a non-matching reply into an error.
    ///
    /// A broker `Error` reply becomes `Broker`; any other variant becomes `Unexpected`.
    fn shape_err(&self, resp: CommsResponse, request: &'static str) -> CommsClientError {
        match resp {
            CommsResponse::Error { code, message } => CommsClientError::Broker { code, message },
            _ => CommsClientError::Unexpected { request },
        }
    }

    /// Sends `req` and awaits its response, reconnecting and retrying once if the
    /// connection turns out to be lost (see `is_connection_lost`).
    ///
    /// The request is borrowed for both attempts, so it is never cloned. This avoids
    /// copying large payloads such as `Post` bodies on every call.
    ///
    /// NOTE(review): the retry is not idempotency-aware. If the broker processed the
    /// first attempt before the connection dropped, the request runs twice.
    async fn request(&mut self, req: CommsRequest) -> Result<CommsResponse, CommsClientError> {
        match self.send_and_await(&req).await {
            Err(err) if is_connection_lost(&err) => {
                self.reconnect().await?;
                self.send_and_await(&req).await
            }
            other => other,
        }
    }

    /// Writes `req`, then reads frames until the first response arrives.
    ///
    /// Notifications read in the meantime are queued in `pending_notifications`.
    /// A clean EOF yields [`CommsClientError::Closed`].
    ///
    /// NOTE(review): responses are not correlated with requests. If this future is
    /// dropped after the write, its response is taken by the next request.
    async fn send_and_await(
        &mut self,
        req: &CommsRequest,
    ) -> Result<CommsResponse, CommsClientError> {
        self.write_request(req).await?;
        loop {
            match self.read_frame().await? {
                Some(CommsOut::Response(resp)) => return Ok(resp),
                Some(CommsOut::Notification(n)) => self.pending_notifications.push_back(n),
                None => return Err(CommsClientError::Closed),
            }
        }
    }

    /// Serializes `req` with `rmp_serde::to_vec_named`, frames it, writes it, and flushes.
    async fn write_request(&mut self, req: &CommsRequest) -> Result<(), CommsClientError> {
        let body = rmp_serde::to_vec_named(req)?;
        let mut framed = BytesMut::new();
        self.codec.encode(Bytes::from(body), &mut framed)?;
        self.stream.write_all(&framed).await?;
        self.stream.flush().await?;
        Ok(())
    }

    /// Decodes the next frame, reading more bytes from the socket as needed.
    ///
    /// Returns `Ok(None)` on EOF with no buffered bytes. EOF with buffered bytes
    /// is reported as an `UnexpectedEof` I/O error.
    async fn read_frame(&mut self) -> Result<Option<CommsOut>, CommsClientError> {
        loop {
            if let Some(frame) = self.codec.decode(&mut self.read_buf)? {
                let out: CommsOut = rmp_serde::from_slice(&frame)?;
                return Ok(Some(out));
            }
            let n = self.stream.read_buf(&mut self.read_buf).await?;
            if n == 0 {
                if self.read_buf.is_empty() {
                    return Ok(None);
                }
                return Err(CommsClientError::Io(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "broker closed mid-frame",
                )));
            }
        }
    }
}
/// Reports whether `err` means the broker connection has dropped.
///
/// `CommsClient::request` reconnects and retries once when this is `true`.
/// That covers `Closed` and I/O errors whose kind is broken pipe, connection
/// reset/aborted, not connected, or unexpected EOF.
fn is_connection_lost(err: &CommsClientError) -> bool {
    use std::io::ErrorKind;

    const LOST_KINDS: [ErrorKind; 5] = [
        ErrorKind::BrokenPipe,
        ErrorKind::ConnectionReset,
        ErrorKind::ConnectionAborted,
        ErrorKind::NotConnected,
        ErrorKind::UnexpectedEof,
    ];

    if let CommsClientError::Io(io_err) = err {
        return LOST_KINDS.contains(&io_err.kind());
    }
    matches!(err, CommsClientError::Closed)
}
/// Builds the `(remote, cwd)` scope pair for `cwd`.
///
/// `remote` is the `crate::git::scope_key` of the repository discovered at `cwd`.
/// It is `None` when no repository is found or when the key starts with `path:`
/// (presumably a path-based fallback key rather than a real remote — confirm).
/// `cwd` is always returned as `Some`.
pub fn scope_context_for(cwd: &Path) -> (Option<String>, Option<PathBuf>) {
    let remote = crate::git::Repo::discover(cwd)
        .ok()
        .map(|repo| crate::git::scope_key(&repo))
        .filter(|key| !key.starts_with("path:"));
    (remote, Some(cwd.to_path_buf()))
}