use std::collections::{HashMap, HashSet};
use std::sync::{Mutex, OnceLock};
use bamboo_subagent::proto::ParentFrame;
use tokio::sync::mpsc;
/// Process-wide registry of live child actors, keyed by child id.
///
/// Each value is the sending half of the channel used to push [`ParentFrame`]s
/// to that child. The registry is created lazily on first access.
fn map() -> &'static Mutex<HashMap<String, mpsc::UnboundedSender<ParentFrame>>> {
    static REGISTRY: OnceLock<Mutex<HashMap<String, mpsc::UnboundedSender<ParentFrame>>>> =
        OnceLock::new();
    REGISTRY.get_or_init(Default::default)
}
/// Outstanding approval requests, keyed by child id.
///
/// Each value is the set of request ids still awaiting a reply for that child.
/// The table is created lazily on first access.
fn pending() -> &'static Mutex<HashMap<String, HashSet<String>>> {
    static OUTSTANDING: OnceLock<Mutex<HashMap<String, HashSet<String>>>> = OnceLock::new();
    OUTSTANDING.get_or_init(Default::default)
}
/// Records `request_id` as an outstanding approval request for `child_id`.
///
/// Registering the same pair more than once is harmless, because the per-child
/// collection is a set.
pub fn register_pending_approval(child_id: &str, request_id: &str) {
    let mut table = pending().lock().unwrap();
    let outstanding = table.entry(child_id.to_string()).or_default();
    outstanding.insert(request_id.to_string());
}
/// Consumes the pending approval `(child_id, request_id)`, if present.
///
/// Returns `true` exactly once for each registered pair, and `false` for
/// unknown children or request ids. When a child's last outstanding request is
/// taken, the child's entry is removed from the table.
pub fn take_pending_approval(child_id: &str, request_id: &str) -> bool {
    let mut table = pending().lock().unwrap();
    match table.get_mut(child_id) {
        None => false,
        Some(outstanding) => {
            let was_pending = outstanding.remove(request_id);
            if outstanding.is_empty() {
                table.remove(child_id);
            }
            was_pending
        }
    }
}
/// Discards every outstanding approval request recorded for `child_id`.
///
/// Clearing a child that has no pending requests does nothing.
pub fn clear_pending_approvals_for(child_id: &str) {
    let mut table = pending().lock().unwrap();
    table.remove(child_id);
}
/// Delivers an approval reply only if `(child_id, request_id)` was registered
/// as pending.
///
/// The pending entry is consumed first, so a given request can be answered at
/// most once. Delivery is attempted only after a successful take. The return
/// value is `true` only if both steps succeed.
pub fn deliver_approval_checked(child_id: &str, request_id: &str, approved: bool) -> bool {
    take_pending_approval(child_id, request_id) && deliver_approval(child_id, request_id, approved)
}
pub struct LiveActorGuard {
child_id: String,
}
impl Drop for LiveActorGuard {
fn drop(&mut self) {
map().lock().unwrap().remove(&self.child_id);
clear_pending_approvals_for(&self.child_id);
}
}
pub fn register(child_id: &str, tx: mpsc::UnboundedSender<ParentFrame>) -> LiveActorGuard {
map().lock().unwrap().insert(child_id.to_string(), tx);
LiveActorGuard {
child_id: child_id.to_string(),
}
}
/// Sends `text` to the live child `child_id` as a [`ParentFrame::Message`].
///
/// Returns `false` in two cases: no sender is registered for `child_id`, or the
/// send fails (for example because the child's receiver has been dropped).
pub fn deliver_message(child_id: &str, text: &str) -> bool {
    let registry = map().lock().unwrap();
    let Some(sender) = registry.get(child_id) else {
        return false;
    };
    sender
        .send(ParentFrame::Message {
            text: text.to_string(),
        })
        .is_ok()
}
/// Sends a [`ParentFrame::ApprovalReply`] for `request_id` to the live child
/// `child_id`.
///
/// Returns `false` when no sender is registered for `child_id` or the send
/// fails. This function does not consult the pending-approval table;
/// [`deliver_approval_checked`] is the variant that does.
pub fn deliver_approval(child_id: &str, request_id: &str, approved: bool) -> bool {
    map().lock().unwrap().get(child_id).is_some_and(|sender| {
        let reply = ParentFrame::ApprovalReply {
            id: request_id.to_string(),
            approved,
        };
        sender.send(reply).is_ok()
    })
}
/// Reports whether a sender is currently registered for `child_id`.
///
/// This checks registration only; it does not test whether the child's receiver
/// is still open.
pub fn is_live(child_id: &str) -> bool {
    let registry = map().lock().unwrap();
    registry.contains_key(child_id)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered child is reachable until its guard is dropped. After that
    /// it is neither live nor deliverable.
    #[test]
    fn register_deliver_unregister() {
        let (outbox, mut inbox) = mpsc::unbounded_channel();
        let registration = register("c-live", outbox);
        assert!(is_live("c-live"));
        assert!(deliver_message("c-live", "hi"));
        match inbox.try_recv() {
            Ok(ParentFrame::Message { text }) => assert_eq!(text, "hi"),
            other => panic!("expected message frame, got {other:?}"),
        }
        drop(registration);
        assert!(!is_live("c-live"));
        assert!(!deliver_message("c-live", "gone"));
    }

    /// Delivery reports failure once the receiving half is gone, even though the
    /// registration is still held.
    #[test]
    fn deliver_fails_when_receiver_dropped() {
        let (outbox, inbox) = mpsc::unbounded_channel();
        let _registration = register("c-dead", outbox);
        drop(inbox);
        assert!(!deliver_message("c-dead", "hi"));
    }

    /// `deliver_approval` sends an `ApprovalReply` frame, and fails once the
    /// child is unregistered.
    #[test]
    fn deliver_approval_routes_reply_frame() {
        let (outbox, mut inbox) = mpsc::unbounded_channel();
        let registration = register("c-appr", outbox);
        assert!(deliver_approval("c-appr", "req-7", true));
        match inbox.try_recv() {
            Ok(ParentFrame::ApprovalReply { id, approved }) => {
                assert_eq!(id, "req-7");
                assert!(approved);
            }
            other => panic!("expected approval reply, got {other:?}"),
        }
        drop(registration);
        assert!(!deliver_approval("c-appr", "req-8", false));
    }

    /// A pending approval can be taken exactly once.
    #[test]
    fn pending_approval_is_one_shot() {
        let (child, request) = ("c-pend", "req-1");
        register_pending_approval(child, request);
        assert!(take_pending_approval(child, request));
        assert!(!take_pending_approval(child, request));
    }

    /// Taking an unknown child or an unknown request id fails, and leaves
    /// genuinely pending entries intact.
    #[test]
    fn take_of_unregistered_pair_is_false() {
        assert!(!take_pending_approval("c-unknown", "req-x"));
        register_pending_approval("c-known", "req-real");
        assert!(!take_pending_approval("c-known", "req-bogus"));
        assert!(take_pending_approval("c-known", "req-real"));
    }

    /// The checked variant delivers only for registered pairs, and only once.
    #[test]
    fn deliver_approval_checked_only_delivers_for_registered_pair() {
        let (outbox, mut inbox) = mpsc::unbounded_channel();
        let _registration = register("c-checked", outbox);

        assert!(!deliver_approval_checked("c-checked", "req-stray", true));
        assert!(inbox.try_recv().is_err());

        register_pending_approval("c-checked", "req-ok");
        assert!(deliver_approval_checked("c-checked", "req-ok", true));
        match inbox.try_recv() {
            Ok(ParentFrame::ApprovalReply { id, approved }) => {
                assert_eq!(id, "req-ok");
                assert!(approved);
            }
            other => panic!("expected approval reply, got {other:?}"),
        }

        assert!(!deliver_approval_checked("c-checked", "req-ok", true));
        assert!(inbox.try_recv().is_err());
    }

    /// Clearing a child discards all of its outstanding requests.
    #[test]
    fn clear_pending_approvals_for_drops_them() {
        let child = "c-clear";
        for request in ["req-a", "req-b"] {
            register_pending_approval(child, request);
        }
        clear_pending_approvals_for(child);
        for request in ["req-a", "req-b"] {
            assert!(!take_pending_approval(child, request));
        }
    }
}