use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use bamboo_subagent::proto::ParentFrame;
use tokio::sync::mpsc;
/// Returns the process-wide registry that maps a child id to the sender half
/// of that child's `ParentFrame` channel.
///
/// The registry lives in a function-local `static` and is created empty on
/// the first call; every later call returns the same instance.
fn map() -> &'static Mutex<HashMap<String, mpsc::UnboundedSender<ParentFrame>>> {
    type Registry = Mutex<HashMap<String, mpsc::UnboundedSender<ParentFrame>>>;

    static MAP: OnceLock<Registry> = OnceLock::new();
    // `Registry::default()` is an unlocked mutex around an empty map.
    MAP.get_or_init(Registry::default)
}
/// Registration handle returned by `register`.
///
/// The guard remembers the child id it was created for; its `Drop`
/// implementation removes that id from the registry. Keep the guard alive for
/// as long as the child should stay reachable.
#[must_use = "dropping the guard immediately unregisters the child"]
pub struct LiveActorGuard {
    /// Registry key removed again when the guard is dropped.
    child_id: String,
}
impl Drop for LiveActorGuard {
    /// Removes this guard's `child_id` entry from the registry.
    ///
    /// A poisoned registry mutex is recovered instead of unwrapped. Panicking
    /// here would skip the removal, and a panic raised from `drop` while the
    /// thread is already unwinding aborts the process. The map only holds
    /// key/sender pairs, so it is still usable after another thread panicked
    /// while holding the lock.
    fn drop(&mut self) {
        let mut registry = map()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        registry.remove(&self.child_id);
    }
}
/// Registers `tx` as the sender for `child_id` and returns a guard whose drop
/// removes the entry again.
///
/// If `child_id` is already present, the stored sender is replaced by `tx`.
/// NOTE(review): a guard from an earlier registration of the same id still
/// removes the entry by key when it drops, so it would also remove the newer
/// sender — confirm callers never register one id twice concurrently.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
pub fn register(child_id: &str, tx: mpsc::UnboundedSender<ParentFrame>) -> LiveActorGuard {
    let key = child_id.to_string();
    map().lock().unwrap().insert(key.clone(), tx);
    LiveActorGuard { child_id: key }
}
/// Sends a `ParentFrame::Message` carrying `text` to the child registered
/// under `child_id`.
///
/// Returns `true` when a sender is registered for `child_id` and the send
/// succeeds; returns `false` when the id is unknown or the send fails. The
/// registry lock is held for the duration of the send.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
pub fn deliver_message(child_id: &str, text: &str) -> bool {
    let registry = map().lock().unwrap();
    let Some(sender) = registry.get(child_id) else {
        return false;
    };
    // The frame is only built once a target is known to exist.
    let frame = ParentFrame::Message {
        text: text.to_string(),
    };
    sender.send(frame).is_ok()
}
/// Sends a `ParentFrame::ApprovalReply` for `request_id` with the given
/// `approved` decision to the child registered under `child_id`.
///
/// Returns `true` when a sender is registered for `child_id` and the send
/// succeeds; returns `false` when the id is unknown or the send fails. The
/// registry lock is held for the duration of the send.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
pub fn deliver_approval(child_id: &str, request_id: &str, approved: bool) -> bool {
    let registry = map().lock().unwrap();
    let Some(sender) = registry.get(child_id) else {
        return false;
    };
    // The frame is only built once a target is known to exist.
    let reply = ParentFrame::ApprovalReply {
        id: request_id.to_string(),
        approved,
    };
    sender.send(reply).is_ok()
}
/// Reports whether a sender is currently registered under `child_id`.
///
/// This only checks for the registry entry; it does not test whether the
/// receiving end of the channel is still open.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
pub fn is_live(child_id: &str) -> bool {
    let registry = map().lock().unwrap();
    registry.get(child_id).is_some()
}
#[cfg(test)]
mod tests {
    use super::*;

    // The registry is a process-wide static shared by all tests, so each test
    // uses its own child id.

    /// A registered child is live and reachable until its guard is dropped.
    #[test]
    fn register_deliver_unregister() {
        let (sender, mut receiver) = mpsc::unbounded_channel();
        let registration = register("c-live", sender);

        assert!(is_live("c-live"));
        assert!(deliver_message("c-live", "hi"));
        match receiver.try_recv() {
            Ok(ParentFrame::Message { text }) => assert_eq!(text, "hi"),
            other => panic!("expected message frame, got {other:?}"),
        }

        // Dropping the guard unregisters the child.
        drop(registration);
        assert!(!is_live("c-live"));
        assert!(!deliver_message("c-live", "gone"));
    }

    /// Delivery reports failure once the receiving half has been dropped,
    /// even though the registry entry still exists.
    #[test]
    fn deliver_fails_when_receiver_dropped() {
        let (sender, receiver) = mpsc::unbounded_channel();
        let _registration = register("c-dead", sender);

        drop(receiver);

        assert!(!deliver_message("c-dead", "hi"));
    }

    /// Approval replies arrive as `ApprovalReply` frames with the request id
    /// and decision intact, and stop being deliverable after unregistration.
    #[test]
    fn deliver_approval_routes_reply_frame() {
        let (sender, mut receiver) = mpsc::unbounded_channel();
        let registration = register("c-appr", sender);

        assert!(deliver_approval("c-appr", "req-7", true));
        match receiver.try_recv() {
            Ok(ParentFrame::ApprovalReply { id, approved }) => {
                assert_eq!(id, "req-7");
                assert!(approved);
            }
            other => panic!("expected approval reply, got {other:?}"),
        }

        drop(registration);
        assert!(!deliver_approval("c-appr", "req-8", false));
    }
}