use std::path::{Path, PathBuf};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[cfg(unix)]
use tokio::net::UnixStream as PlatformStream;
#[cfg(windows)]
use tokio::net::windows::named_pipe::NamedPipeClient as PlatformStream;
use tokio_util::bytes::{Bytes, BytesMut};
use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
use super::cursor::Cursor;
use super::ids::{AgentId, RoomId};
use super::model::{AgentCard, AgentRecord, Room, RoomScope};
use super::protocol::{
CommsNotification, CommsOut, CommsRequest, CommsResponse, PROTO_VER, SeqMeta, StatusReport,
};
use super::singleton::{self, CommsPaths};
use super::transport::MAX_FRAME_BYTES;
/// Initial capacity (8 KiB) of each connection's read buffer.
const READ_CHUNK: usize = 8192;
/// How long a Windows client keeps retrying a named pipe that reports
/// `ERROR_PIPE_BUSY` before giving up.
#[cfg(windows)]
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(5_000);
/// Daemon launcher handed to `singleton::ensure_daemon_with` when a lost
/// connection has to be re-established.
type SpawnFn = Box<dyn Fn(&CommsPaths) -> std::io::Result<()> + Send + Sync>;
/// Every failure a [`CommsClient`] operation can report.
#[derive(Debug, thiserror::Error)]
pub enum CommsClientError {
    /// Socket / named-pipe I/O failure, including failures to dial the daemon.
    #[error("comms transport error: {0}")]
    Io(#[from] std::io::Error),
    /// An outgoing request could not be serialized with `rmp_serde`.
    #[error("encode error: {0}")]
    Encode(#[from] rmp_serde::encode::Error),
    /// An incoming frame could not be deserialized with `rmp_serde`.
    #[error("decode error: {0}")]
    Decode(#[from] rmp_serde::decode::Error),
    /// Failure from the daemon-singleton helpers (path resolution, ensuring a
    /// daemon is running).
    #[error(transparent)]
    Singleton(#[from] super::singleton::SingletonError),
    /// The stream reached end-of-file while a response was still awaited.
    #[error("connection closed before a response was received")]
    Closed,
    /// The broker answered with an explicit `Error { code, message }` response.
    #[error("broker error [{code}]: {message}")]
    Broker { code: String, message: String },
    /// The broker answered with a response variant the request does not expect.
    #[error("unexpected response shape for {request}")]
    Unexpected { request: &'static str },
    /// The daemon's `Welcome` carried a protocol version other than [`PROTO_VER`].
    #[error("protocol skew: daemon speaks {daemon}, client speaks {client}")]
    ProtoSkew { daemon: u32, client: u32 },
}
/// A single connection to the comms broker daemon.
///
/// Every request method takes `&mut self`, so requests go out one at a time;
/// notifications that arrive while a response is awaited are parked in
/// `pending_notifications` until the caller asks for them.
pub struct CommsClient {
    /// Live socket / named-pipe connection to the daemon.
    stream: PlatformStream,
    /// Length-prefix framing shared by requests and incoming frames.
    codec: LengthDelimitedCodec,
    /// Bytes read from `stream` that have not yet formed a complete frame.
    read_buf: BytesMut,
    /// Identity announced in every `Hello` handshake.
    agent: AgentId,
    /// Notifications received while waiting on a response.
    pending_notifications: std::collections::VecDeque<CommsNotification>,
    /// Daemon locations, kept so the client can redial after a disconnect.
    paths: CommsPaths,
    /// Scope hint re-sent in every handshake.
    remote: Option<String>,
    /// Working-directory hint re-sent in every handshake.
    cwd: Option<PathBuf>,
    /// Session lineage re-sent in every handshake.
    session: SessionContext,
    /// Launcher passed to `singleton::ensure_daemon_with` on reconnect.
    spawn: SpawnFn,
}
/// Session lineage a client announces in its `Hello` handshake.
///
/// The default (both fields `None`) describes a top-level agent.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct SessionContext {
    /// Session identifier; [`SessionContext::from_env`] reads it from [`SESSION_ID_ENV`].
    pub session_id: Option<String>,
    /// Parent agent identifier; [`SessionContext::from_env`] reads it from [`PARENT_AGENT_ENV`].
    pub parent_agent: Option<String>,
}

/// Environment variable carrying the session id.
pub const SESSION_ID_ENV: &str = "BASEMIND_SESSION_ID";
/// Environment variable carrying the parent agent id.
pub const PARENT_AGENT_ENV: &str = "BASEMIND_PARENT_AGENT_ID";

impl SessionContext {
    /// Builds a context from [`SESSION_ID_ENV`] and [`PARENT_AGENT_ENV`].
    ///
    /// A variable that is unset, empty, or not valid Unicode becomes `None`.
    #[must_use]
    pub fn from_env() -> Self {
        let session_id = non_empty_env(SESSION_ID_ENV);
        let parent_agent = non_empty_env(PARENT_AGENT_ENV);
        Self {
            session_id,
            parent_agent,
        }
    }

    /// Returns `true` when neither a session id nor a parent agent is set.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        !(self.session_id.is_some() || self.parent_agent.is_some())
    }
}

/// Reads `key` from the environment, treating empty and non-Unicode values
/// the same as an unset variable.
fn non_empty_env(key: &str) -> Option<String> {
    std::env::var(key).ok().filter(|value| !value.is_empty())
}
impl CommsClient {
pub async fn connect(
paths: &CommsPaths,
agent: AgentId,
remote: Option<String>,
cwd: Option<PathBuf>,
) -> Result<Self, CommsClientError> {
Self::connect_with_respawn(paths, agent, remote, cwd, |paths| {
singleton::spawn_detached_daemon(paths)
})
.await
}
pub async fn connect_with_respawn(
paths: &CommsPaths,
agent: AgentId,
remote: Option<String>,
cwd: Option<PathBuf>,
spawn: impl Fn(&CommsPaths) -> std::io::Result<()> + Send + Sync + 'static,
) -> Result<Self, CommsClientError> {
let (stream, codec) = Self::dial(paths).await?;
let mut client = Self {
stream,
codec,
read_buf: BytesMut::with_capacity(READ_CHUNK),
agent,
pending_notifications: std::collections::VecDeque::new(),
paths: paths.clone(),
remote,
cwd,
session: SessionContext::default(),
spawn: Box::new(spawn),
};
client.handshake().await?;
Ok(client)
}
pub async fn connect_with_session(
paths: &CommsPaths,
agent: AgentId,
remote: Option<String>,
cwd: Option<PathBuf>,
session: SessionContext,
) -> Result<Self, CommsClientError> {
let (stream, codec) = Self::dial(paths).await?;
let mut client = Self {
stream,
codec,
read_buf: BytesMut::with_capacity(READ_CHUNK),
agent,
pending_notifications: std::collections::VecDeque::new(),
paths: paths.clone(),
remote,
cwd,
session,
spawn: Box::new(singleton::spawn_detached_daemon),
};
client.handshake().await?;
Ok(client)
}
pub async fn ensure_and_connect(
agent: AgentId,
remote: Option<String>,
cwd: Option<PathBuf>,
) -> Result<Self, CommsClientError> {
Self::ensure_and_connect_with_session(agent, remote, cwd, SessionContext::from_env()).await
}
pub async fn ensure_and_connect_with_session(
agent: AgentId,
remote: Option<String>,
cwd: Option<PathBuf>,
session: SessionContext,
) -> Result<Self, CommsClientError> {
let paths = singleton::resolve_paths()?;
singleton::ensure_daemon(&paths).await?;
Self::connect_with_session(&paths, agent, remote, cwd, session).await
}
async fn dial(
paths: &CommsPaths,
) -> Result<(PlatformStream, LengthDelimitedCodec), CommsClientError> {
let stream = Self::connect_stream(&paths.socket_path).await?;
let mut codec = LengthDelimitedCodec::new();
codec.set_max_frame_length(MAX_FRAME_BYTES);
Ok((stream, codec))
}
#[cfg(unix)]
async fn connect_stream(socket_path: &Path) -> Result<PlatformStream, CommsClientError> {
PlatformStream::connect(socket_path)
.await
.map_err(|source| daemon_unreachable_error(socket_path, source))
}
#[cfg(windows)]
async fn connect_stream(socket_path: &Path) -> Result<PlatformStream, CommsClientError> {
use tokio::net::windows::named_pipe::ClientOptions;
const ERROR_PIPE_BUSY: i32 = 231;
const RETRY_INTERVAL: std::time::Duration = std::time::Duration::from_millis(50);
let deadline = std::time::Instant::now() + CONNECT_TIMEOUT;
loop {
match ClientOptions::new().open(socket_path) {
Ok(client) => return Ok(client),
Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY) => {
if std::time::Instant::now() >= deadline {
return Err(daemon_unreachable_error(socket_path, e));
}
tokio::time::sleep(RETRY_INTERVAL).await;
}
Err(source) => return Err(daemon_unreachable_error(socket_path, source)),
}
}
}
async fn handshake(&mut self) -> Result<(), CommsClientError> {
let resp = self
.send_and_await(CommsRequest::Hello {
agent: self.agent.clone(),
proto_ver: PROTO_VER,
remote: self.remote.clone(),
cwd: self.cwd.clone(),
session_id: self.session.session_id.clone(),
parent_agent: self.session.parent_agent.clone(),
})
.await?;
match resp {
CommsResponse::Welcome { proto_ver, .. } if proto_ver == PROTO_VER => Ok(()),
CommsResponse::Welcome { proto_ver, .. } => Err(CommsClientError::ProtoSkew {
daemon: proto_ver,
client: PROTO_VER,
}),
CommsResponse::Error { code, message } => {
Err(CommsClientError::Broker { code, message })
}
_ => Err(CommsClientError::Unexpected { request: "hello" }),
}
}
async fn reconnect(&mut self) -> Result<(), CommsClientError> {
let spawn = &self.spawn;
singleton::ensure_daemon_with(&self.paths, singleton::probe_alive, |paths| spawn(paths))
.await?;
let (stream, codec) = Self::dial(&self.paths).await?;
self.stream = stream;
self.codec = codec;
self.read_buf.clear();
self.pending_notifications.clear();
self.handshake().await
}
pub fn agent(&self) -> &AgentId {
&self.agent
}
pub async fn register_agent(&mut self, card: AgentCard) -> Result<(), CommsClientError> {
self.expect_ok(CommsRequest::Register { card }, "register")
.await
}
pub async fn list_agents(
&mut self,
room: Option<RoomId>,
) -> Result<Vec<AgentRecord>, CommsClientError> {
match self.request(CommsRequest::ListAgents { room }).await? {
CommsResponse::Agents(a) => Ok(a),
other => Err(self.shape_err(other, "list_agents")),
}
}
pub async fn create_room(
&mut self,
room: RoomId,
scope: RoomScope,
title: Option<String>,
) -> Result<Room, CommsClientError> {
match self
.request(CommsRequest::CreateRoom { room, scope, title })
.await?
{
CommsResponse::Room(r) => Ok(r),
other => Err(self.shape_err(other, "create_room")),
}
}
pub async fn list_rooms(
&mut self,
remote: Option<String>,
cwd: Option<PathBuf>,
) -> Result<Vec<Room>, CommsClientError> {
match self
.request(CommsRequest::ListRooms { remote, cwd })
.await?
{
CommsResponse::Rooms(r) => Ok(r),
other => Err(self.shape_err(other, "list_rooms")),
}
}
pub async fn join_room(&mut self, room: RoomId) -> Result<(), CommsClientError> {
self.expect_ok(CommsRequest::Join { room }, "join_room")
.await
}
pub async fn leave_room(&mut self, room: RoomId) -> Result<(), CommsClientError> {
self.expect_ok(CommsRequest::Leave { room }, "leave_room")
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn post_message(
&mut self,
room: RoomId,
subject: String,
body: Vec<u8>,
tags: Vec<String>,
reply_to: Option<String>,
scope: Vec<String>,
) -> Result<String, CommsClientError> {
match self
.request(CommsRequest::Post {
room,
subject,
tags,
reply_to,
scope,
body,
})
.await?
{
CommsResponse::Posted { message_id } => Ok(message_id),
other => Err(self.shape_err(other, "post_message")),
}
}
pub async fn ack_inbox(
&mut self,
message_ids: Vec<String>,
room: Option<RoomId>,
to_seq: Option<u64>,
) -> Result<(u32, Vec<(String, u64)>), CommsClientError> {
match self
.request(CommsRequest::AckInbox {
message_ids,
room,
to_seq,
})
.await?
{
CommsResponse::Acked {
acked,
cursors_advanced,
} => Ok((acked, cursors_advanced)),
other => Err(self.shape_err(other, "ack_inbox")),
}
}
pub async fn read_history(
&mut self,
room: RoomId,
cursor: Option<Cursor>,
limit: u32,
since_micros: Option<i64>,
) -> Result<(Vec<SeqMeta>, Option<Cursor>), CommsClientError> {
match self
.request(CommsRequest::History {
room,
cursor,
limit: Some(limit),
since_micros,
})
.await?
{
CommsResponse::History {
messages,
next_cursor,
} => Ok((messages, next_cursor)),
other => Err(self.shape_err(other, "read_history")),
}
}
pub async fn get_body(
&mut self,
message_id: String,
) -> Result<Option<Vec<u8>>, CommsClientError> {
match self.request(CommsRequest::GetBody { message_id }).await? {
CommsResponse::Body { body } => Ok(body),
other => Err(self.shape_err(other, "get_body")),
}
}
#[allow(clippy::type_complexity, clippy::too_many_arguments)]
pub async fn read_inbox(
&mut self,
remote: Option<String>,
cwd: Option<PathBuf>,
cursor: Option<Cursor>,
limit: u32,
mark_read: bool,
since_micros: Option<i64>,
) -> Result<(Vec<SeqMeta>, u32, Option<Cursor>), CommsClientError> {
match self
.request(CommsRequest::Inbox {
remote,
cwd,
cursor,
limit: Some(limit),
mark_read,
since_micros,
})
.await?
{
CommsResponse::Inbox {
messages,
unread,
next_cursor,
} => Ok((messages, unread, next_cursor)),
other => Err(self.shape_err(other, "read_inbox")),
}
}
pub async fn subscribe(&mut self, room: RoomId) -> Result<u64, CommsClientError> {
match self.request(CommsRequest::Subscribe { room }).await? {
CommsResponse::Subscribed { sub } => Ok(sub),
other => Err(self.shape_err(other, "subscribe")),
}
}
pub async fn unsubscribe(&mut self, sub: u64) -> Result<(), CommsClientError> {
self.expect_ok(CommsRequest::Unsubscribe { sub }, "unsubscribe")
.await
}
pub async fn status(&mut self) -> Result<StatusReport, CommsClientError> {
match self.request(CommsRequest::Status).await? {
CommsResponse::Status(s) => Ok(s),
other => Err(self.shape_err(other, "status")),
}
}
pub async fn list_sessions(
&mut self,
) -> Result<Vec<crate::comms::model::SessionLineage>, CommsClientError> {
match self.request(CommsRequest::ListSessions {}).await? {
CommsResponse::Sessions { sessions } => Ok(sessions),
other => Err(self.shape_err(other, "list_sessions")),
}
}
pub async fn delete_session(&mut self, session_id: &str) -> Result<(), CommsClientError> {
self.expect_ok(
CommsRequest::DeleteSession {
session_id: session_id.to_string(),
},
"delete_session",
)
.await
}
pub async fn stop(&mut self) -> Result<(), CommsClientError> {
self.expect_ok(CommsRequest::Stop, "stop").await
}
pub fn next_notification(&mut self) -> Option<CommsNotification> {
self.pending_notifications.pop_front()
}
pub async fn poll_notification(
&mut self,
) -> Result<Option<CommsNotification>, CommsClientError> {
if let Some(n) = self.pending_notifications.pop_front() {
return Ok(Some(n));
}
loop {
match self.read_frame().await? {
Some(CommsOut::Notification(n)) => return Ok(Some(n)),
Some(CommsOut::Response(_)) => continue, None => return Ok(None),
}
}
}
async fn expect_ok(
&mut self,
req: CommsRequest,
label: &'static str,
) -> Result<(), CommsClientError> {
match self.request(req).await? {
CommsResponse::Ok => Ok(()),
other => Err(self.shape_err(other, label)),
}
}
fn shape_err(&self, resp: CommsResponse, request: &'static str) -> CommsClientError {
match resp {
CommsResponse::Error { code, message } => CommsClientError::Broker { code, message },
_ => CommsClientError::Unexpected { request },
}
}
async fn request(&mut self, req: CommsRequest) -> Result<CommsResponse, CommsClientError> {
match self.send_and_await(req.clone()).await {
Ok(resp) => Ok(resp),
Err(err) if is_connection_lost(&err) => {
self.reconnect().await?;
self.send_and_await(req).await
}
Err(err) => Err(err),
}
}
async fn send_and_await(
&mut self,
req: CommsRequest,
) -> Result<CommsResponse, CommsClientError> {
self.write_request(&req).await?;
loop {
match self.read_frame().await? {
Some(CommsOut::Response(resp)) => return Ok(resp),
Some(CommsOut::Notification(n)) => self.pending_notifications.push_back(n),
None => return Err(CommsClientError::Closed),
}
}
}
async fn write_request(&mut self, req: &CommsRequest) -> Result<(), CommsClientError> {
let body = rmp_serde::to_vec_named(req)?;
let mut framed = BytesMut::new();
self.codec.encode(Bytes::from(body), &mut framed)?;
self.stream.write_all(&framed).await?;
self.stream.flush().await?;
Ok(())
}
async fn read_frame(&mut self) -> Result<Option<CommsOut>, CommsClientError> {
loop {
if let Some(frame) = self.codec.decode(&mut self.read_buf)? {
let out: CommsOut = rmp_serde::from_slice(&frame)?;
return Ok(Some(out));
}
let n = self.stream.read_buf(&mut self.read_buf).await?;
if n == 0 {
if self.read_buf.is_empty() {
return Ok(None);
}
return Err(CommsClientError::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"broker closed mid-frame",
)));
}
}
}
}
/// Reports whether `err` means the daemon connection is gone.
///
/// `CommsClient::request` reconnects and retries on such errors instead of
/// surfacing them. This covers a clean EOF while awaiting a response
/// ([`CommsClientError::Closed`]) and the I/O kinds produced by a dropped peer,
/// including the `UnexpectedEof` that `read_frame` raises for a mid-frame close.
fn is_connection_lost(err: &CommsClientError) -> bool {
    use std::io::ErrorKind;
    const LOST_KINDS: [ErrorKind; 5] = [
        ErrorKind::BrokenPipe,
        ErrorKind::ConnectionReset,
        ErrorKind::ConnectionAborted,
        ErrorKind::NotConnected,
        ErrorKind::UnexpectedEof,
    ];
    if let CommsClientError::Io(io) = err {
        LOST_KINDS.contains(&io.kind())
    } else {
        matches!(err, CommsClientError::Closed)
    }
}
#[cfg(all(test, feature = "comms", unix))]
mod tests {
    use super::*;

    /// Serializes the tests in this module that mutate process environment
    /// variables. NOTE(review): it cannot stop tests in other modules from
    /// reading the environment concurrently.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Captures an environment variable and writes it back on drop.
    ///
    /// A failing assertion therefore cannot leak a mutated value into later
    /// tests. `var_os` is used so non-Unicode prior values survive too.
    struct EnvRestore {
        key: &'static str,
        prior: Option<std::ffi::OsString>,
    }

    impl EnvRestore {
        fn capture(key: &'static str) -> Self {
            Self {
                key,
                prior: std::env::var_os(key),
            }
        }
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            // SAFETY: env mutation requires that no other thread touches the
            // environment concurrently. Guards are only created while ENV_LOCK
            // is held, which serializes this module's env-mutating tests.
            unsafe {
                match self.prior.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[tokio::test]
    async fn dial_missing_socket_reports_daemon_not_running_with_start_hint() {
        let dir = std::env::temp_dir().join(format!("basemind-comms-test-{}", std::process::id()));
        let paths = CommsPaths {
            comms_dir: dir.clone(),
            socket_path: dir.join("definitely-absent.sock"),
        };
        let err = match CommsClient::dial(&paths).await {
            Ok(_) => panic!("dialing an absent socket must fail"),
            Err(err) => err,
        };
        let msg = err.to_string();
        assert!(
            msg.contains("comms daemon is not running"),
            "error should name that the daemon is not running, got: {msg}"
        );
        assert!(
            msg.contains("basemind comms start"),
            "error should name the start command, got: {msg}"
        );
        assert!(
            !msg.starts_with("comms transport error: No such file or directory"),
            "error must not be the bare OS string, got: {msg}"
        );
    }

    #[test]
    fn session_context_explicit_seam_carries_lineage() {
        let top = SessionContext::default();
        assert!(top.is_empty(), "a default context is a top-level agent");
        let child = SessionContext {
            session_id: Some("bmsh-1-0".to_string()),
            parent_agent: Some("parent".to_string()),
        };
        assert!(!child.is_empty(), "a session child is not empty");
        assert_eq!(child.session_id.as_deref(), Some("bmsh-1-0"));
        assert_eq!(child.parent_agent.as_deref(), Some("parent"));
    }

    #[test]
    fn session_context_from_env_reads_the_boundary_variables() {
        // Take the lock first so it is released last: the restore guards
        // below drop (writing the environment back) while it is still held.
        let _lock = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
        let _restore_session = EnvRestore::capture(SESSION_ID_ENV);
        let _restore_parent = EnvRestore::capture(PARENT_AGENT_ENV);

        // SAFETY: ENV_LOCK is held; see `EnvRestore::drop`.
        unsafe {
            std::env::set_var(SESSION_ID_ENV, "bmsh-9-3");
            std::env::set_var(PARENT_AGENT_ENV, "lead-agent");
        }
        let present = SessionContext::from_env();
        assert_eq!(present.session_id.as_deref(), Some("bmsh-9-3"));
        assert_eq!(present.parent_agent.as_deref(), Some("lead-agent"));

        // SAFETY: ENV_LOCK is held; see `EnvRestore::drop`.
        unsafe {
            std::env::set_var(SESSION_ID_ENV, "");
            std::env::remove_var(PARENT_AGENT_ENV);
        }
        assert!(
            SessionContext::from_env().is_empty(),
            "empty / unset env maps to a top-level (empty) context"
        );
    }
}
/// Maps a failed dial of `socket_path` to a user-facing error.
///
/// Two kinds are rewritten into a "daemon is not running" message with the
/// `basemind comms start` hint:
///
/// - `NotFound`: nothing exists at the path.
/// - `ConnectionRefused`: the socket file exists but nothing accepts on it,
///   e.g. a stale socket.
///
/// Each gets its own detail so the user is not told "no socket" when one is
/// present. The original `ErrorKind` is preserved. Any other error is
/// returned unchanged.
fn daemon_unreachable_error(socket_path: &Path, source: std::io::Error) -> CommsClientError {
    let detail = match source.kind() {
        std::io::ErrorKind::NotFound => format!("no socket at {}", socket_path.display()),
        std::io::ErrorKind::ConnectionRefused => format!(
            "socket at {} refused the connection; it may be stale",
            socket_path.display()
        ),
        _ => return CommsClientError::Io(source),
    };
    CommsClientError::Io(std::io::Error::new(
        source.kind(),
        format!("comms daemon is not running ({detail}); start it with `basemind comms start`"),
    ))
}
/// Derives the `(remote, cwd)` scope hints for a client running in `cwd`.
///
/// `remote` is the git scope key of the repository discovered from `cwd`. It is
/// `None` when no repository is found or when the key is path-based (starts
/// with `path:`). The second element is always `cwd` itself.
pub fn scope_context_for(cwd: &Path) -> (Option<String>, Option<PathBuf>) {
    let remote = crate::git::Repo::discover(cwd)
        .ok()
        .map(|repo| crate::git::scope_key(&repo))
        .filter(|key| !key.starts_with("path:"));
    (remote, Some(cwd.to_path_buf()))
}