bamboo_engine/external_agents/
live.rs1use std::collections::{HashMap, HashSet};
12use std::sync::{Mutex, OnceLock};
13
14use bamboo_subagent::proto::ParentFrame;
15use tokio::sync::mpsc;
16
17fn map() -> &'static Mutex<HashMap<String, mpsc::UnboundedSender<ParentFrame>>> {
18 static MAP: OnceLock<Mutex<HashMap<String, mpsc::UnboundedSender<ParentFrame>>>> =
19 OnceLock::new();
20 MAP.get_or_init(|| Mutex::new(HashMap::new()))
21}
22
/// Returns the process-wide table of pending approvals.
///
/// Each key is a child id and each value is the set of request ids recorded
/// for that child. The table is created empty on first access and the same
/// instance is handed out on every later call.
fn pending() -> &'static Mutex<HashMap<String, HashSet<String>>> {
    static OUTSTANDING: OnceLock<Mutex<HashMap<String, HashSet<String>>>> = OnceLock::new();
    OUTSTANDING.get_or_init(Mutex::default)
}
33
34pub fn register_pending_approval(child_id: &str, request_id: &str) {
38 pending()
39 .lock()
40 .unwrap()
41 .entry(child_id.to_string())
42 .or_default()
43 .insert(request_id.to_string());
44}
45
46pub fn take_pending_approval(child_id: &str, request_id: &str) -> bool {
50 let mut guard = pending().lock().unwrap();
51 let Some(set) = guard.get_mut(child_id) else {
52 return false;
53 };
54 let took = set.remove(request_id);
55 if set.is_empty() {
56 guard.remove(child_id);
57 }
58 took
59}
60
61pub fn clear_pending_approvals_for(child_id: &str) {
63 pending().lock().unwrap().remove(child_id);
64}
65
66pub fn deliver_approval_checked(child_id: &str, request_id: &str, approved: bool) -> bool {
74 if take_pending_approval(child_id, request_id) {
75 deliver_approval(child_id, request_id, approved)
76 } else {
77 false
78 }
79}
80
/// Registration handle returned by [`register`].
///
/// While the guard is alive the child stays in the live registry. Dropping it
/// removes the child's sender and discards the child's pending approvals, so
/// the guard must be bound to a variable that lives as long as the child.
#[derive(Debug)]
#[must_use = "dropping the guard immediately unregisters the child"]
pub struct LiveActorGuard {
    /// Id of the child this guard unregisters when dropped.
    child_id: String,
}
86
87impl Drop for LiveActorGuard {
88 fn drop(&mut self) {
89 map().lock().unwrap().remove(&self.child_id);
90 clear_pending_approvals_for(&self.child_id);
93 }
94}
95
96pub fn register(child_id: &str, tx: mpsc::UnboundedSender<ParentFrame>) -> LiveActorGuard {
98 map().lock().unwrap().insert(child_id.to_string(), tx);
99 LiveActorGuard {
100 child_id: child_id.to_string(),
101 }
102}
103
104pub fn deliver_message(child_id: &str, text: &str) -> bool {
107 let guard = map().lock().unwrap();
108 match guard.get(child_id) {
109 Some(tx) => tx
110 .send(ParentFrame::Message {
111 text: text.to_string(),
112 })
113 .is_ok(),
114 None => false,
115 }
116}
117
118pub fn deliver_approval(child_id: &str, request_id: &str, approved: bool) -> bool {
129 let guard = map().lock().unwrap();
130 match guard.get(child_id) {
131 Some(tx) => tx
132 .send(ParentFrame::ApprovalReply {
133 id: request_id.to_string(),
134 approved,
135 })
136 .is_ok(),
137 None => false,
138 }
139}
140
141pub fn is_live(child_id: &str) -> bool {
143 map().lock().unwrap().contains_key(child_id)
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    // Both registries are process-wide statics, so every test below works on
    // its own child id to avoid touching another test's entries.

    /// Asserts that the next frame queued on `rx` is an `ApprovalReply`
    /// carrying `want_id` and an affirmative decision.
    fn expect_approved_reply(rx: &mut mpsc::UnboundedReceiver<ParentFrame>, want_id: &str) {
        match rx.try_recv() {
            Ok(ParentFrame::ApprovalReply { id, approved }) => {
                assert_eq!(id, want_id);
                assert!(approved);
            }
            other => panic!("expected approval reply, got {other:?}"),
        }
    }

    #[test]
    fn register_deliver_unregister() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let registration = register("c-live", tx);

        // While registered, the child is live and receives messages.
        assert!(is_live("c-live"));
        assert!(deliver_message("c-live", "hi"));
        match rx.try_recv() {
            Ok(ParentFrame::Message { text }) => assert_eq!(text, "hi"),
            other => panic!("expected message frame, got {other:?}"),
        }

        // Dropping the guard unregisters the child.
        drop(registration);
        assert!(!is_live("c-live"));
        assert!(!deliver_message("c-live", "gone"));
    }

    #[test]
    fn deliver_fails_when_receiver_dropped() {
        let (tx, rx) = mpsc::unbounded_channel();
        let _registration = register("c-dead", tx);
        // The child is still registered, but nobody is listening any more.
        drop(rx);
        assert!(!deliver_message("c-dead", "hi"));
    }

    #[test]
    fn deliver_approval_routes_reply_frame() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let registration = register("c-appr", tx);

        assert!(deliver_approval("c-appr", "req-7", true));
        expect_approved_reply(&mut rx, "req-7");

        // After unregistering, there is no sender left to deliver to.
        drop(registration);
        assert!(!deliver_approval("c-appr", "req-8", false));
    }

    #[test]
    fn pending_approval_is_one_shot() {
        register_pending_approval("c-pend", "req-1");
        assert!(take_pending_approval("c-pend", "req-1"));
        // The second take finds nothing.
        assert!(!take_pending_approval("c-pend", "req-1"));
    }

    #[test]
    fn take_of_unregistered_pair_is_false() {
        // Unknown child.
        assert!(!take_pending_approval("c-unknown", "req-x"));

        // Known child, unknown request id. The real id stays untouched.
        register_pending_approval("c-known", "req-real");
        assert!(!take_pending_approval("c-known", "req-bogus"));
        assert!(take_pending_approval("c-known", "req-real"));
    }

    #[test]
    fn deliver_approval_checked_only_delivers_for_registered_pair() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _registration = register("c-checked", tx);

        // A request id that was never registered produces no frame.
        assert!(!deliver_approval_checked("c-checked", "req-stray", true));
        assert!(rx.try_recv().is_err());

        // A registered request id is delivered exactly once.
        register_pending_approval("c-checked", "req-ok");
        assert!(deliver_approval_checked("c-checked", "req-ok", true));
        expect_approved_reply(&mut rx, "req-ok");

        // Replaying the same id is rejected and sends nothing.
        assert!(!deliver_approval_checked("c-checked", "req-ok", true));
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn clear_pending_approvals_for_drops_them() {
        for request_id in ["req-a", "req-b"] {
            register_pending_approval("c-clear", request_id);
        }
        clear_pending_approvals_for("c-clear");
        for request_id in ["req-a", "req-b"] {
            assert!(!take_pending_approval("c-clear", request_id));
        }
    }
}