use std::path::{Path, PathBuf};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixStream;
use tokio_util::bytes::{Bytes, BytesMut};
use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
use super::cursor::Cursor;
use super::ids::{AgentId, RoomId};
use super::model::{AgentCard, AgentRecord, MessageMeta, Room, RoomScope};
use super::protocol::{
CommsNotification, CommsOut, CommsRequest, CommsResponse, PROTO_VER, StatusReport,
};
use super::singleton::{self, CommsPaths};
use super::transport::MAX_FRAME_BYTES;
/// Initial capacity, in bytes, reserved for each client's socket read buffer.
const READ_CHUNK: usize = 8192;
/// Every failure a [`CommsClient`] can report to its caller.
#[derive(Debug, thiserror::Error)]
pub enum CommsClientError {
    /// Reading from or writing to the broker socket failed. This includes framing
    /// errors from the length-delimited codec and a connection that closes mid-frame.
    #[error("comms transport error: {0}")]
    Io(#[from] std::io::Error),

    /// A request could not be serialized with `rmp_serde`.
    #[error("encode error: {0}")]
    Encode(#[from] rmp_serde::encode::Error),

    /// A frame received from the broker could not be deserialized with `rmp_serde`.
    #[error("decode error: {0}")]
    Decode(#[from] rmp_serde::decode::Error),

    /// Resolving the daemon paths or ensuring the daemon failed; the message is passed through.
    #[error(transparent)]
    Singleton(#[from] super::singleton::SingletonError),

    /// The broker closed the connection before answering the outstanding request.
    #[error("connection closed before a response was received")]
    Closed,

    /// The broker answered with an explicit error response.
    #[error("broker error [{code}]: {message}")]
    Broker { code: String, message: String },

    /// The broker answered with a response variant the calling method did not expect;
    /// `request` names the client operation that received it.
    #[error("unexpected response shape for {request}")]
    Unexpected { request: &'static str },

    /// The daemon's handshake advertised a protocol version different from this client's.
    #[error("protocol skew: daemon speaks {daemon}, client speaks {client}")]
    ProtoSkew { daemon: u32, client: u32 },
}
/// A single connection to the comms broker over a Unix domain socket.
///
/// Requests and responses travel as length-delimited MessagePack frames. The client treats
/// the next response frame it reads as the answer to the request it just wrote. Notifications
/// that arrive while a response is being awaited are parked in `pending_notifications`.
pub struct CommsClient {
    /// Socket connected to the broker.
    stream: UnixStream,
    /// Length-prefix framing codec; its maximum frame length is set to `MAX_FRAME_BYTES`
    /// when the connection is opened.
    codec: LengthDelimitedCodec,
    /// Bytes read from `stream` that have not yet been decoded into a frame.
    read_buf: BytesMut,
    /// Agent identity announced in the `Hello` handshake.
    agent: AgentId,
    /// Notifications received while waiting for a response, oldest first.
    pending_notifications: std::collections::VecDeque<CommsNotification>,
}
impl CommsClient {
    /// Connects to the broker socket at `paths.socket_path` and performs the `Hello` handshake.
    ///
    /// `agent`, `remote` and `cwd` are forwarded verbatim in the `Hello` request. The client also
    /// announces this build's [`PROTO_VER`]. The connection is returned only if the daemon answers
    /// `Welcome` with the same protocol version.
    ///
    /// # Errors
    ///
    /// - [`CommsClientError::ProtoSkew`] if `Welcome` carries a different protocol version.
    /// - [`CommsClientError::Broker`] if the daemon rejects the handshake.
    /// - [`CommsClientError::Unexpected`] (`request: "hello"`) for any other response.
    /// - Transport, codec and [`CommsClientError::Closed`] errors from the exchange itself.
    pub async fn connect(
        paths: &CommsPaths,
        agent: AgentId,
        remote: Option<String>,
        cwd: Option<PathBuf>,
    ) -> Result<Self, CommsClientError> {
        let stream = UnixStream::connect(&paths.socket_path).await?;
        let mut codec = LengthDelimitedCodec::new();
        // Enforce the transport's frame-size cap instead of the codec's default.
        codec.set_max_frame_length(MAX_FRAME_BYTES);
        let mut client = Self {
            stream,
            codec,
            read_buf: BytesMut::with_capacity(READ_CHUNK),
            agent: agent.clone(),
            pending_notifications: std::collections::VecDeque::new(),
        };
        let resp = client
            .request(CommsRequest::Hello {
                agent,
                proto_ver: PROTO_VER,
                remote,
                cwd,
            })
            .await?;
        match resp {
            CommsResponse::Welcome { proto_ver, .. } if proto_ver == PROTO_VER => Ok(client),
            CommsResponse::Welcome { proto_ver, .. } => Err(CommsClientError::ProtoSkew {
                daemon: proto_ver,
                client: PROTO_VER,
            }),
            CommsResponse::Error { code, message } => {
                Err(CommsClientError::Broker { code, message })
            }
            _ => Err(CommsClientError::Unexpected { request: "hello" }),
        }
    }

    /// Resolves the daemon paths, calls `singleton::ensure_daemon` on them, then [`connect`]s.
    ///
    /// NOTE(review): `ensure_daemon` presumably starts the daemon when none is running; confirm
    /// against the singleton module.
    ///
    /// # Errors
    ///
    /// [`CommsClientError::Singleton`] if path resolution or `ensure_daemon` fails, otherwise
    /// anything [`connect`] can return.
    ///
    /// [`connect`]: CommsClient::connect
    pub async fn ensure_and_connect(
        agent: AgentId,
        remote: Option<String>,
        cwd: Option<PathBuf>,
    ) -> Result<Self, CommsClientError> {
        let paths = singleton::resolve_paths()?;
        singleton::ensure_daemon(&paths).await?;
        Self::connect(&paths, agent, remote, cwd).await
    }

    /// Agent identity this connection announced in its `Hello` handshake.
    pub fn agent(&self) -> &AgentId {
        &self.agent
    }

    /// Sends a `Register` request carrying `card` and expects a bare `Ok`.
    ///
    /// # Errors
    ///
    /// [`CommsClientError::Broker`] / [`CommsClientError::Unexpected`] for a non-`Ok` answer,
    /// plus transport errors.
    pub async fn register_agent(&mut self, card: AgentCard) -> Result<(), CommsClientError> {
        self.expect_ok(CommsRequest::Register { card }, "register")
            .await
    }

    /// Lists agent records known to the broker, optionally narrowed to `room`.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    pub async fn list_agents(
        &mut self,
        room: Option<RoomId>,
    ) -> Result<Vec<AgentRecord>, CommsClientError> {
        match self.request(CommsRequest::ListAgents { room }).await? {
            CommsResponse::Agents(a) => Ok(a),
            other => Err(self.shape_err(other, "list_agents")),
        }
    }

    /// Asks the broker to create `room` with the given `scope` and optional `title`, and returns
    /// the `Room` it answers with.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    pub async fn create_room(
        &mut self,
        room: RoomId,
        scope: RoomScope,
        title: Option<String>,
    ) -> Result<Room, CommsClientError> {
        match self
            .request(CommsRequest::CreateRoom { room, scope, title })
            .await?
        {
            CommsResponse::Room(r) => Ok(r),
            other => Err(self.shape_err(other, "create_room")),
        }
    }

    /// Lists rooms, passing the `remote` / `cwd` scope context through to the broker.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    pub async fn list_rooms(
        &mut self,
        remote: Option<String>,
        cwd: Option<PathBuf>,
    ) -> Result<Vec<Room>, CommsClientError> {
        match self
            .request(CommsRequest::ListRooms { remote, cwd })
            .await?
        {
            CommsResponse::Rooms(r) => Ok(r),
            other => Err(self.shape_err(other, "list_rooms")),
        }
    }

    /// Sends a `Join` request for `room` and expects a bare `Ok`.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    pub async fn join_room(&mut self, room: RoomId) -> Result<(), CommsClientError> {
        self.expect_ok(CommsRequest::Join { room }, "join_room")
            .await
    }

    /// Sends a `Leave` request for `room` and expects a bare `Ok`.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    pub async fn leave_room(&mut self, room: RoomId) -> Result<(), CommsClientError> {
        self.expect_ok(CommsRequest::Leave { room }, "leave_room")
            .await
    }

    /// Posts a message to `room` and returns the `message_id` from the broker's `Posted` answer.
    ///
    /// `body` is sent as raw bytes. `reply_to` is forwarded as-is; it is presumably the id of the
    /// message being answered (verify against the protocol).
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors. Also an encode/codec error if the
    /// request exceeds the frame-size cap.
    pub async fn post_message(
        &mut self,
        room: RoomId,
        subject: String,
        body: Vec<u8>,
        tags: Vec<String>,
        reply_to: Option<String>,
    ) -> Result<String, CommsClientError> {
        match self
            .request(CommsRequest::Post {
                room,
                subject,
                tags,
                reply_to,
                body,
            })
            .await?
        {
            CommsResponse::Posted { message_id } => Ok(message_id),
            other => Err(self.shape_err(other, "post_message")),
        }
    }

    /// Requests one page of `room`'s history starting at `cursor`, with `limit` forwarded as the
    /// page limit.
    ///
    /// Returns the message metadata and the broker's `next_cursor`, if any.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    pub async fn read_history(
        &mut self,
        room: RoomId,
        cursor: Option<Cursor>,
        limit: u32,
    ) -> Result<(Vec<MessageMeta>, Option<Cursor>), CommsClientError> {
        match self
            .request(CommsRequest::History {
                room,
                cursor,
                limit: Some(limit),
            })
            .await?
        {
            CommsResponse::History {
                messages,
                next_cursor,
            } => Ok((messages, next_cursor)),
            other => Err(self.shape_err(other, "read_history")),
        }
    }

    /// Fetches the body of `message_id`.
    ///
    /// Returns the broker's `body` field unchanged. NOTE(review): `None` presumably means the body
    /// is unknown or unavailable; confirm with the broker.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    pub async fn get_body(
        &mut self,
        message_id: String,
    ) -> Result<Option<Vec<u8>>, CommsClientError> {
        match self.request(CommsRequest::GetBody { message_id }).await? {
            CommsResponse::Body { body } => Ok(body),
            other => Err(self.shape_err(other, "get_body")),
        }
    }

    /// Reads one page of the inbox for the `remote` / `cwd` scope context.
    ///
    /// `cursor`, `limit` and `mark_read` are forwarded to the broker. Returns the message metadata,
    /// the broker's `unread` count and its `next_cursor`.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    // The tuple return type is part of the public API; a named struct would break callers.
    #[allow(clippy::type_complexity)]
    pub async fn read_inbox(
        &mut self,
        remote: Option<String>,
        cwd: Option<PathBuf>,
        cursor: Option<Cursor>,
        limit: u32,
        mark_read: bool,
    ) -> Result<(Vec<MessageMeta>, u32, Option<Cursor>), CommsClientError> {
        match self
            .request(CommsRequest::Inbox {
                remote,
                cwd,
                cursor,
                limit: Some(limit),
                mark_read,
            })
            .await?
        {
            CommsResponse::Inbox {
                messages,
                unread,
                next_cursor,
            } => Ok((messages, unread, next_cursor)),
            other => Err(self.shape_err(other, "read_inbox")),
        }
    }

    /// Subscribes to `room` and returns the `sub` handle from the broker's `Subscribed` answer.
    ///
    /// The handle is the value [`unsubscribe`](CommsClient::unsubscribe) takes.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    pub async fn subscribe(&mut self, room: RoomId) -> Result<u64, CommsClientError> {
        match self.request(CommsRequest::Subscribe { room }).await? {
            CommsResponse::Subscribed { sub } => Ok(sub),
            other => Err(self.shape_err(other, "subscribe")),
        }
    }

    /// Sends an `Unsubscribe` request for handle `sub` and expects a bare `Ok`.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    pub async fn unsubscribe(&mut self, sub: u64) -> Result<(), CommsClientError> {
        self.expect_ok(CommsRequest::Unsubscribe { sub }, "unsubscribe")
            .await
    }

    /// Requests the broker's status report.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    pub async fn status(&mut self) -> Result<StatusReport, CommsClientError> {
        match self.request(CommsRequest::Status).await? {
            CommsResponse::Status(s) => Ok(s),
            other => Err(self.shape_err(other, "status")),
        }
    }

    /// Sends a `Stop` request and expects a bare `Ok`.
    ///
    /// NOTE(review): presumably asks the daemon to shut down; confirm.
    ///
    /// # Errors
    ///
    /// Broker/unexpected-shape errors and transport errors.
    pub async fn stop(&mut self) -> Result<(), CommsClientError> {
        self.expect_ok(CommsRequest::Stop, "stop").await
    }

    /// Pops the oldest notification queued while a request was awaiting its response.
    ///
    /// Never touches the socket; returns `None` when the queue is empty.
    pub fn next_notification(&mut self) -> Option<CommsNotification> {
        self.pending_notifications.pop_front()
    }

    /// Returns the next notification, reading from the socket if none is queued.
    ///
    /// Queued notifications are drained first. Otherwise frames are read until a notification
    /// arrives. Returns `Ok(None)` when the broker closes the connection cleanly on a frame
    /// boundary.
    ///
    /// # Errors
    ///
    /// Transport, codec and decode errors, including an `UnexpectedEof` I/O error when the
    /// connection drops partway through a frame.
    pub async fn poll_notification(
        &mut self,
    ) -> Result<Option<CommsNotification>, CommsClientError> {
        if let Some(n) = self.pending_notifications.pop_front() {
            return Ok(Some(n));
        }
        loop {
            match self.read_frame().await? {
                Some(CommsOut::Notification(n)) => return Ok(Some(n)),
                // Requests hold `&mut self` until their response is read, so no caller is
                // waiting for this one (unless a `request` future was cancelled); drop it.
                Some(CommsOut::Response(_)) => continue,
                None => return Ok(None),
            }
        }
    }

    /// Sends `req` and maps a bare `Ok` answer to `()`.
    ///
    /// Any other answer goes through [`shape_err`](Self::shape_err) with `label`.
    async fn expect_ok(
        &mut self,
        req: CommsRequest,
        label: &'static str,
    ) -> Result<(), CommsClientError> {
        match self.request(req).await? {
            CommsResponse::Ok => Ok(()),
            other => Err(self.shape_err(other, label)),
        }
    }

    /// Converts an unwanted response into an error.
    ///
    /// A broker `Error` becomes [`CommsClientError::Broker`]; anything else becomes
    /// [`CommsClientError::Unexpected`] tagged with `request`.
    fn shape_err(&self, resp: CommsResponse, request: &'static str) -> CommsClientError {
        match resp {
            CommsResponse::Error { code, message } => CommsClientError::Broker { code, message },
            _ => CommsClientError::Unexpected { request },
        }
    }

    /// Writes `req` and returns the next response frame, queueing notifications read meanwhile.
    ///
    /// NOTE(review): not cancellation-safe. If this future is dropped after the request is
    /// written, its response will be taken as the answer to the next `request`.
    ///
    /// # Errors
    ///
    /// [`CommsClientError::Closed`] if the broker closes cleanly before responding, otherwise
    /// transport, codec, encode or decode errors.
    async fn request(&mut self, req: CommsRequest) -> Result<CommsResponse, CommsClientError> {
        self.write_request(&req).await?;
        loop {
            match self.read_frame().await? {
                Some(CommsOut::Response(resp)) => return Ok(resp),
                Some(CommsOut::Notification(n)) => self.pending_notifications.push_back(n),
                None => return Err(CommsClientError::Closed),
            }
        }
    }

    /// Serializes `req` with named fields, frames it with the length prefix and flushes it.
    ///
    /// # Errors
    ///
    /// [`CommsClientError::Encode`] on serialization failure. An I/O error if the frame exceeds
    /// the codec's maximum length or the socket write fails.
    async fn write_request(&mut self, req: &CommsRequest) -> Result<(), CommsClientError> {
        let body = rmp_serde::to_vec_named(req)?;
        let mut framed = BytesMut::new();
        self.codec.encode(Bytes::from(body), &mut framed)?;
        self.stream.write_all(&framed).await?;
        self.stream.flush().await?;
        Ok(())
    }

    /// Reads the next complete frame and deserializes it as a [`CommsOut`].
    ///
    /// Returns `Ok(None)` only for a clean close: EOF with no bytes of an unfinished frame seen.
    ///
    /// # Errors
    ///
    /// - An `UnexpectedEof` I/O error ("broker closed mid-frame") if the connection ends partway
    ///   through a frame, including right after its length prefix.
    /// - Codec errors, e.g. a frame longer than the configured maximum.
    /// - [`CommsClientError::Decode`] if the payload is not a valid `CommsOut`.
    async fn read_frame(&mut self) -> Result<Option<CommsOut>, CommsClientError> {
        // `LengthDelimitedCodec::decode` removes the length prefix from `read_buf` as soon as it
        // has parsed it. An empty buffer at EOF therefore does not prove we are on a frame
        // boundary, so remember whether any bytes of an unfinished frame were seen in this call.
        // NOTE(review): if an earlier `read_frame` future was cancelled after consuming a prefix,
        // the codec's private state still cannot be observed from here.
        let mut mid_frame = false;
        loop {
            mid_frame |= !self.read_buf.is_empty();
            if let Some(frame) = self.codec.decode(&mut self.read_buf)? {
                let out: CommsOut = rmp_serde::from_slice(&frame)?;
                return Ok(Some(out));
            }
            let n = self.stream.read_buf(&mut self.read_buf).await?;
            if n == 0 {
                if !mid_frame {
                    return Ok(None);
                }
                return Err(CommsClientError::Io(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "broker closed mid-frame",
                )));
            }
        }
    }
}
/// Builds the `(remote, cwd)` scope context for requests issued from `cwd`.
///
/// `remote` is `crate::git::scope_key` of the repository that `crate::git::Repo::discover` finds
/// for `cwd`. Any discovery error is treated as "no repository". A scope key starting with
/// `"path:"` is also dropped, so `remote` is `None` in those cases. The second element is always
/// `Some(cwd)` as an owned path.
pub fn scope_context_for(cwd: &Path) -> (Option<String>, Option<PathBuf>) {
    let remote = crate::git::Repo::discover(cwd)
        .ok()
        .map(|repo| crate::git::scope_key(&repo))
        // Path-derived keys are not a remote identity; leave them out.
        .filter(|key| !key.starts_with("path:"));
    (remote, Some(cwd.to_path_buf()))
}