use std::collections::HashSet;
use std::io::ErrorKind;
use std::path::PathBuf;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::error::{atomic_write, Result, StoreError};
/// Message identifier wrapping a plain `String`.
///
/// [`MsgId::new`] fills it with a random UUID, but the inner field is public,
/// so any string can be used as an ID.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MsgId(pub String);
impl MsgId {
pub fn new() -> Self {
MsgId(uuid::Uuid::new_v4().to_string())
}
pub fn as_str(&self) -> &str {
&self.0
}
}
impl Default for MsgId {
fn default() -> Self {
Self::new()
}
}
/// Reference to an agent; used as the `from` (sender) field of an
/// [`InboxMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentRef {
    /// Session identifier of the agent.
    pub session_id: String,
    /// Optional role label. Omitted from the serialized form when `None`, and
    /// read back as `None` when the key is absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
}
/// Discriminator for the kind of an [`InboxMessage`].
///
/// Serialized as a snake_case string (e.g. `"task"`, `"mcp_request"`).
/// NOTE(review): the exact semantics of each variant are defined by the code
/// that produces and consumes messages, which is not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InboxKind {
    /// Serialized as `"task"`.
    Task,
    /// Serialized as `"ask"`; presumably carries an [`AskBody`] — confirm against callers.
    Ask,
    /// Serialized as `"handoff"`.
    Handoff,
    /// Serialized as `"reply"`; the tests pair it with a [`ReplyBody`] and a correlation ID.
    Reply,
    /// Serialized as `"mcp_request"`.
    McpRequest,
    /// Serialized as `"mcp_reply"`.
    McpReply,
}
/// Mode of an [`AskBody`]; serialized as `"query"` or `"steer"`.
///
/// NOTE(review): how receivers treat the two modes differently is not visible
/// in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AskMode {
    /// The default mode.
    #[default]
    Query,
    /// The alternative mode.
    Steer,
}
/// Typed payload holding a question and the mode it is asked in.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AskBody {
    /// The question text.
    pub question: String,
    /// Ask mode; falls back to [`AskMode::Query`] when the key is absent on input.
    #[serde(default)]
    pub mode: AskMode,
}
/// Typed payload holding an answer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReplyBody {
    /// The answer text.
    pub answer: String,
}
/// A single message as stored (JSON-encoded) in a [`Mailbox`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InboxMessage {
    /// Identifier of this message; also embedded in its on-disk file name.
    pub id: MsgId,
    /// The sending agent.
    pub from: AgentRef,
    /// What kind of message this is.
    pub kind: InboxKind,
    /// Free-form JSON payload (e.g. a serialized [`ReplyBody`]).
    pub body: serde_json::Value,
    /// Creation time; its nanosecond timestamp prefixes the on-disk file name,
    /// which is what orders messages when the mailbox is drained.
    pub created_at: DateTime<Utc>,
    /// Optional ID of a related message (the tests set it to the ID of the ask
    /// being replied to). Omitted when `None`; absent keys read back as `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub correlation_id: Option<MsgId>,
}
/// A message read back from a [`Mailbox`], together with the location of its
/// backing file.
#[derive(Debug, Clone)]
pub struct Delivered {
    /// The decoded message.
    pub msg: InboxMessage,
    /// Path of the message file inside the mailbox's `cur` directory; this is
    /// the file `Mailbox::ack_delivered` removes.
    pub cur_path: PathBuf,
}
/// Directory-backed mailbox.
///
/// Messages are JSON files that live in subdirectories of `dir`: `new` for
/// delivered messages, `cur` for messages that have been drained but not yet
/// acknowledged, and `corrupt` for files that could not be read or decoded.
pub struct Mailbox {
    /// Root directory of the mailbox.
    dir: PathBuf,
}
impl Mailbox {
    /// Creates a handle for the mailbox rooted at `dir`.
    ///
    /// This does not touch the filesystem; see [`Mailbox::ensure_dirs`].
    pub fn at(dir: impl Into<PathBuf>) -> Self {
        Self { dir: dir.into() }
    }

    /// Directory holding delivered messages that have not been drained yet.
    fn new_dir(&self) -> PathBuf {
        self.dir.join("new")
    }

    /// Directory holding drained messages that have not been acknowledged yet.
    fn cur_dir(&self) -> PathBuf {
        self.dir.join("cur")
    }

    /// Directory where unreadable or undecodable message files are moved.
    fn corrupt_dir(&self) -> PathBuf {
        self.dir.join("corrupt")
    }

    /// Creates the `new`, `cur` and `corrupt` directories (and any missing
    /// parents).
    ///
    /// # Errors
    /// Returns an I/O error naming the directory that could not be created.
    pub async fn ensure_dirs(&self) -> Result<()> {
        for d in [self.new_dir(), self.cur_dir(), self.corrupt_dir()] {
            tokio::fs::create_dir_all(&d)
                .await
                .map_err(|e| StoreError::io(&d, e))?;
        }
        Ok(())
    }

    /// Writes `msg` as pretty-printed JSON to `new/<nanos>-<id>.json` and
    /// returns a clone of its ID.
    ///
    /// NOTE(review): this does not call [`Mailbox::ensure_dirs`]; it appears to
    /// rely on `atomic_write` to create the `new` directory when missing —
    /// confirm against `crate::error::atomic_write`.
    ///
    /// # Errors
    /// Fails if the message cannot be serialized or the file cannot be written.
    pub async fn deliver(&self, msg: &InboxMessage) -> Result<MsgId> {
        let bytes = serde_json::to_vec_pretty(msg).map_err(|e| StoreError::decode(&self.dir, e))?;
        // Out-of-range and pre-epoch timestamps collapse to 0. A non-negative
        // i64 has at most 19 digits, so the zero-padded 20-digit prefix makes
        // lexicographic file-name order equal to chronological order.
        let nanos = msg.created_at.timestamp_nanos_opt().unwrap_or(0).max(0);
        let name = format!("{nanos:020}-{}.json", msg.id.0);
        atomic_write(&self.new_dir().join(&name), &bytes).await?;
        Ok(msg.id.clone())
    }

    /// Moves every message from `new` to `cur` and returns the ones that
    /// decode, sorted by file name (i.e. by creation time).
    ///
    /// A file whose move into `cur` fails is skipped silently (for example
    /// because it disappeared after the directory was listed). Files that
    /// cannot be read or decoded are moved to `corrupt` and left out.
    ///
    /// # Errors
    /// Fails if the mailbox directories cannot be created or `new` cannot be
    /// listed.
    pub async fn drain(&self) -> Result<Vec<Delivered>> {
        self.ensure_dirs().await?;
        let new_dir = self.new_dir();
        let cur_dir = self.cur_dir();
        let names = self.sorted_json_names(&new_dir).await?;
        let mut out = Vec::with_capacity(names.len());
        for name in names {
            let claimed = cur_dir.join(&name);
            // The rename is the claim step: a file only counts as drained once
            // it has been moved into `cur`.
            if tokio::fs::rename(new_dir.join(&name), &claimed).await.is_err() {
                continue;
            }
            if let Some(delivered) = self.load_or_quarantine(&name, claimed).await {
                out.push(delivered);
            }
        }
        Ok(out)
    }

    /// Acknowledges a drained message by deleting its file in `cur`.
    ///
    /// A file that is already gone counts as acknowledged, so calling this
    /// twice for the same message succeeds.
    ///
    /// # Errors
    /// Returns an I/O error for any removal failure other than "not found".
    pub async fn ack_delivered(&self, delivered: &Delivered) -> Result<()> {
        match tokio::fs::remove_file(&delivered.cur_path).await {
            Ok(()) => Ok(()),
            Err(e) if e.kind() == ErrorKind::NotFound => Ok(()),
            Err(e) => Err(StoreError::io(&delivered.cur_path, e)),
        }
    }

    /// Acknowledges the message with the given `id` by deleting the first
    /// matching file found in `cur`.
    ///
    /// Succeeds without doing anything when `cur` does not exist, when no file
    /// belongs to `id`, or when the file vanishes between listing and removal
    /// (it has then already been acknowledged, consistent with
    /// [`Mailbox::ack_delivered`]).
    ///
    /// # Errors
    /// Returns an I/O error if `cur` cannot be listed or the file cannot be
    /// removed for a reason other than "not found".
    pub async fn ack(&self, id: &MsgId) -> Result<()> {
        let cur = self.cur_dir();
        let mut rd = match tokio::fs::read_dir(&cur).await {
            Ok(rd) => rd,
            Err(e) if e.kind() == ErrorKind::NotFound => return Ok(()),
            Err(e) => return Err(StoreError::io(&cur, e)),
        };
        while let Some(ent) = rd.next_entry().await.map_err(|e| StoreError::io(&cur, e))? {
            let fname = ent.file_name();
            if !Self::file_name_is_for(&fname.to_string_lossy(), id) {
                continue;
            }
            let path = ent.path();
            return match tokio::fs::remove_file(&path).await {
                Ok(()) => Ok(()),
                Err(e) if e.kind() == ErrorKind::NotFound => Ok(()),
                Err(e) => Err(StoreError::io(&path, e)),
            };
        }
        Ok(())
    }

    /// Returns the messages still sitting in `cur` (drained but never
    /// acknowledged), sorted by file name.
    ///
    /// Files that cannot be read or decoded are moved to `corrupt` and left
    /// out of the result.
    ///
    /// # Errors
    /// Fails if the mailbox directories cannot be created or `cur` cannot be
    /// listed.
    pub async fn recover(&self) -> Result<Vec<Delivered>> {
        self.ensure_dirs().await?;
        let cur_dir = self.cur_dir();
        let names = self.sorted_json_names(&cur_dir).await?;
        let mut out = Vec::with_capacity(names.len());
        for name in names {
            let path = cur_dir.join(&name);
            if let Some(delivered) = self.load_or_quarantine(&name, path).await {
                out.push(delivered);
            }
        }
        Ok(out)
    }

    /// Reports whether `new` contains no message files.
    ///
    /// Only `new` is inspected; unacknowledged messages in `cur` do not count.
    ///
    /// # Errors
    /// Fails if `new` exists but cannot be listed.
    pub async fn is_empty(&self) -> Result<bool> {
        Ok(self.sorted_json_names(&self.new_dir()).await?.is_empty())
    }

    /// Reads the message stored at `path` (a file inside `cur` named `name`).
    ///
    /// On any read or decode failure the file is moved to `corrupt/<name>` and
    /// `None` is returned. The move is best-effort: its own failure is ignored.
    async fn load_or_quarantine(&self, name: &str, path: PathBuf) -> Option<Delivered> {
        match read_msg(&path).await {
            Ok(msg) => Some(Delivered {
                msg,
                cur_path: path,
            }),
            Err(_) => {
                let _ = tokio::fs::rename(&path, self.corrupt_dir().join(name)).await;
                None
            }
        }
    }

    /// Reports whether `fname` is the file name `deliver` generates for `id`,
    /// i.e. `<stamp>-<id>.json`.
    ///
    /// The stamp consists of digits only, so the first `-` is always the
    /// separator and everything after it must equal the ID exactly. A plain
    /// suffix comparison is not enough: the file of ID `a-b` also ends in
    /// `-b.json` and would be mistaken for ID `b`.
    fn file_name_is_for(fname: &str, id: &MsgId) -> bool {
        matches!(
            fname
                .strip_suffix(".json")
                .and_then(|stem| stem.split_once('-')),
            Some((_, rest)) if rest == id.0.as_str()
        )
    }

    /// Lists the `.json` file names in `dir`, sorted ascending.
    ///
    /// Names starting with `.` are skipped (presumably in-progress temporary
    /// files — confirm against `atomic_write`). A missing directory yields an
    /// empty list.
    ///
    /// # Errors
    /// Fails if `dir` exists but cannot be read.
    async fn sorted_json_names(&self, dir: &std::path::Path) -> Result<Vec<String>> {
        let mut rd = match tokio::fs::read_dir(dir).await {
            Ok(rd) => rd,
            Err(e) if e.kind() == ErrorKind::NotFound => return Ok(Vec::new()),
            Err(e) => return Err(StoreError::io(dir, e)),
        };
        let mut names = Vec::new();
        while let Some(ent) = rd.next_entry().await.map_err(|e| StoreError::io(dir, e))? {
            let fname = ent.file_name().to_string_lossy().into_owned();
            if fname.starts_with('.') || !fname.ends_with(".json") {
                continue;
            }
            names.push(fname);
        }
        // Names within one directory are unique, so an unstable sort gives the
        // same order as a stable one.
        names.sort_unstable();
        Ok(names)
    }
}
/// Loads the file at `path` and decodes its contents as a JSON
/// [`InboxMessage`].
///
/// Read failures are reported as I/O errors and malformed contents as decode
/// errors, both tagged with `path`.
async fn read_msg(path: &std::path::Path) -> Result<InboxMessage> {
    let raw = match tokio::fs::read(path).await {
        Ok(raw) => raw,
        Err(err) => return Err(StoreError::io(path, err)),
    };
    match serde_json::from_slice::<InboxMessage>(&raw) {
        Ok(msg) => Ok(msg),
        Err(err) => Err(StoreError::decode(path, err)),
    }
}
/// Bounded, insertion-ordered set of message IDs.
///
/// It (de)serializes as a plain list of IDs via the `Vec<MsgId>` conversions,
/// written out in insertion order.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(from = "Vec<MsgId>", into = "Vec<MsgId>")]
pub struct AdmittedSet {
    /// IDs in insertion order, oldest at the front; drives eviction.
    order: std::collections::VecDeque<MsgId>,
    /// The same IDs, indexed for membership tests.
    index: HashSet<MsgId>,
}
/// Maximum number of IDs an [`AdmittedSet`] retains; inserting beyond this
/// evicts the oldest entries first.
pub const ADMITTED_SET_CAPACITY: usize = 4096;
impl AdmittedSet {
    /// Reports whether `id` is currently remembered.
    pub fn contains(&self, id: &MsgId) -> bool {
        self.index.get(id).is_some()
    }

    /// Remembers `id`.
    ///
    /// Returns `false` and changes nothing if the ID was already present.
    /// Otherwise the ID is appended, the oldest IDs are evicted until at most
    /// [`ADMITTED_SET_CAPACITY`] remain, and `true` is returned.
    pub fn insert(&mut self, id: MsgId) -> bool {
        let is_new = self.index.insert(id.clone());
        if is_new {
            self.order.push_back(id);
            while ADMITTED_SET_CAPACITY < self.order.len() {
                if let Some(oldest) = self.order.pop_front() {
                    self.index.remove(&oldest);
                }
            }
        }
        is_new
    }

    /// Number of IDs currently remembered.
    pub fn len(&self) -> usize {
        self.order.len()
    }

    /// Reports whether no IDs are remembered.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl From<Vec<MsgId>> for AdmittedSet {
fn from(ids: Vec<MsgId>) -> Self {
let mut set = AdmittedSet::default();
for id in ids {
set.insert(id);
}
set
}
}
impl From<AdmittedSet> for Vec<MsgId> {
fn from(set: AdmittedSet) -> Self {
set.order.into_iter().collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;
    use serde_json::json;
    use tempfile::TempDir;

    /// Returns a mailbox rooted inside a fresh temporary directory. The
    /// `TempDir` guard is returned as well so the caller keeps it alive for
    /// the duration of the test.
    fn mailbox() -> (TempDir, Mailbox) {
        let dir = TempDir::new().unwrap();
        let mb = Mailbox::at(dir.path().join("mailbox"));
        (dir, mb)
    }

    /// Builds a task message with a fresh ID whose body is `{ "seq": seq }`.
    fn msg(seq: u32) -> InboxMessage {
        InboxMessage {
            id: MsgId::new(),
            from: AgentRef {
                session_id: "parent".into(),
                role: None,
            },
            kind: InboxKind::Task,
            body: json!({ "seq": seq }),
            created_at: Utc::now(),
            correlation_id: None,
        }
    }

    #[test]
    fn ask_reply_bodies_and_correlation_round_trip() {
        // AskBody: snake_case mode, lossless round trip, defaulted mode.
        let ask = AskBody {
            question: "what did you find?".into(),
            mode: AskMode::Query,
        };
        let ask_json = serde_json::to_value(&ask).unwrap();
        assert_eq!(ask_json["mode"], "query");
        assert_eq!(serde_json::from_value::<AskBody>(ask_json).unwrap(), ask);
        let defaulted: AskBody = serde_json::from_value(json!({ "question": "q" })).unwrap();
        assert_eq!(defaulted.mode, AskMode::Query);
        assert_eq!(
            serde_json::from_value::<AskMode>(json!("steer")).unwrap(),
            AskMode::Steer
        );

        // A reply keeps its correlation ID through a JSON round trip.
        let ask_id = MsgId::new();
        let reply = InboxMessage {
            id: MsgId::new(),
            from: AgentRef {
                session_id: "child".into(),
                role: None,
            },
            kind: InboxKind::Reply,
            body: serde_json::to_value(ReplyBody {
                answer: "found X".into(),
            })
            .unwrap(),
            created_at: Utc::now(),
            correlation_id: Some(ask_id.clone()),
        };
        let round: InboxMessage =
            serde_json::from_value(serde_json::to_value(&reply).unwrap()).unwrap();
        assert_eq!(round.correlation_id, Some(ask_id));
        assert_eq!(round.kind, InboxKind::Reply);

        // Messages without a `correlation_id` key still decode.
        let legacy: InboxMessage = serde_json::from_value(json!({
            "id": MsgId::new(),
            "from": { "session_id": "p" },
            "kind": "task",
            "body": {},
            "created_at": Utc::now().to_rfc3339(),
        }))
        .unwrap();
        assert_eq!(legacy.correlation_id, None);
    }

    #[tokio::test]
    async fn deliver_then_drain_then_ack() {
        let (_d, mb) = mailbox();
        let m = msg(1);
        mb.deliver(&m).await.unwrap();
        assert!(!mb.is_empty().await.unwrap());

        let batch = mb.drain().await.unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].msg.id, m.id);
        assert!(mb.is_empty().await.unwrap());

        // After the ack nothing is left to recover.
        mb.ack(&m.id).await.unwrap();
        assert!(mb.recover().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn multi_writer_no_loss() {
        let (_d, mb) = mailbox();
        mb.ensure_dirs().await.unwrap();
        let dir = mb.dir.clone();

        // Fifty concurrent writers, each with its own handle on the same directory.
        let mut handles = Vec::new();
        for i in 0..50u32 {
            let d = dir.clone();
            handles.push(tokio::spawn(async move {
                let mb = Mailbox::at(d);
                mb.deliver(&msg(i)).await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        let batch = mb.drain().await.unwrap();
        assert_eq!(batch.len(), 50);
        let ids: HashSet<_> = batch.iter().map(|d| d.msg.id.clone()).collect();
        assert_eq!(ids.len(), 50);
    }

    #[tokio::test]
    async fn drain_is_time_ordered() {
        let (_d, mb) = mailbox();
        let base = Utc.timestamp_opt(1_700_000_000, 0).unwrap();
        for i in 0..5u32 {
            let mut m = msg(i);
            m.created_at = base + chrono::Duration::seconds(i as i64);
            mb.deliver(&m).await.unwrap();
        }

        let batch = mb.drain().await.unwrap();
        let seqs: Vec<u32> = batch
            .iter()
            .map(|d| d.msg.body["seq"].as_u64().unwrap() as u32)
            .collect();
        assert_eq!(seqs, vec![0, 1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn recover_returns_unacked_leftovers() {
        let (_d, mb) = mailbox();
        let m = msg(1);
        mb.deliver(&m).await.unwrap();
        let batch = mb.drain().await.unwrap();
        assert_eq!(batch.len(), 1);

        // A second handle on the same directory sees the drained-but-unacked message.
        let mb2 = Mailbox::at(mb.dir.clone());
        let recovered = mb2.recover().await.unwrap();
        assert_eq!(recovered.len(), 1);
        assert_eq!(recovered[0].msg.id, m.id);
    }

    #[tokio::test]
    async fn corrupt_file_is_quarantined() {
        let (_d, mb) = mailbox();
        mb.ensure_dirs().await.unwrap();
        mb.deliver(&msg(1)).await.unwrap();
        tokio::fs::write(
            mb.new_dir().join("00000000000000000001-bogus.json"),
            b"not json",
        )
        .await
        .unwrap();

        // Only the valid message is returned...
        let batch = mb.drain().await.unwrap();
        assert_eq!(batch.len(), 1);

        // ...and the undecodable file ends up in `corrupt`.
        let mut rd = tokio::fs::read_dir(mb.corrupt_dir()).await.unwrap();
        let mut corrupt = 0;
        while rd.next_entry().await.unwrap().is_some() {
            corrupt += 1;
        }
        assert_eq!(corrupt, 1);
    }

    // Purely synchronous, so no async runtime is needed.
    #[test]
    fn admitted_set_dedupes() {
        let mut seen = AdmittedSet::default();
        let id = MsgId::new();
        assert!(seen.insert(id.clone()));
        assert!(seen.contains(&id));
        assert!(!seen.insert(id.clone()));
        assert_eq!(seen.len(), 1);
    }

    #[test]
    fn admitted_set_is_bounded_and_serde_round_trips() {
        let mut seen = AdmittedSet::default();
        let first = MsgId::new();
        seen.insert(first.clone());
        for _ in 0..ADMITTED_SET_CAPACITY {
            seen.insert(MsgId::new());
        }
        // One insertion over capacity evicts the oldest ID.
        assert_eq!(seen.len(), ADMITTED_SET_CAPACITY);
        assert!(!seen.contains(&first));

        let json = serde_json::to_string(&seen).unwrap();
        let restored: AdmittedSet = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.len(), seen.len());
        let probe = Vec::<MsgId>::from(seen.clone())[0].clone();
        assert!(restored.contains(&probe));
    }

    #[tokio::test]
    async fn ack_delivered_removes_by_path() {
        let (_d, mb) = mailbox();
        mb.deliver(&msg(1)).await.unwrap();
        let batch = mb.drain().await.unwrap();
        mb.ack_delivered(&batch[0]).await.unwrap();
        assert!(mb.recover().await.unwrap().is_empty());

        // Acknowledging the same delivery again is a no-op.
        mb.ack_delivered(&batch[0]).await.unwrap();
    }
}