1use std::collections::HashMap;
13
14use meerkat_core::lifecycle::InputId;
15use meerkat_core::types::RunResult;
16
17use crate::tokio::sync::oneshot;
18
19#[derive(Debug)]
21pub enum CompletionOutcome {
22 Completed(RunResult),
24 CompletedWithoutResult,
26 Abandoned(String),
28 RuntimeTerminated(String),
30}
31
32pub struct CompletionHandle {
34 rx: oneshot::Receiver<CompletionOutcome>,
35}
36
37impl CompletionHandle {
38 pub async fn wait(self) -> CompletionOutcome {
40 match self.rx.await {
41 Ok(outcome) => outcome,
42 Err(_) => CompletionOutcome::RuntimeTerminated(
44 "completion channel closed without result".into(),
45 ),
46 }
47 }
48
49 pub fn already_resolved(outcome: CompletionOutcome) -> Self {
54 let (tx, rx) = oneshot::channel();
55 let _ = tx.send(outcome);
56 Self { rx }
57 }
58}
59
60#[derive(Default)]
65pub(crate) struct CompletionRegistry {
66 waiters: HashMap<InputId, Vec<oneshot::Sender<CompletionOutcome>>>,
67}
68
69impl CompletionRegistry {
70 pub(crate) fn new() -> Self {
71 Self::default()
72 }
73
74 fn take_waiters(
75 &mut self,
76 input_id: &InputId,
77 ) -> Option<Vec<oneshot::Sender<CompletionOutcome>>> {
78 self.waiters.remove(input_id)
79 }
80
81 pub(crate) fn register(&mut self, input_id: InputId) -> CompletionHandle {
86 let (tx, rx) = oneshot::channel();
87 self.waiters.entry(input_id).or_default().push(tx);
88 CompletionHandle { rx }
89 }
90
91 pub(crate) fn resolve_completed(&mut self, input_id: &InputId, result: RunResult) {
93 if let Some(senders) = self.take_waiters(input_id) {
94 for tx in senders {
95 let _ = tx.send(CompletionOutcome::Completed(result.clone()));
96 }
97 }
98 }
99
100 pub(crate) fn resolve_without_result(&mut self, input_id: &InputId) {
102 if let Some(senders) = self.take_waiters(input_id) {
103 for tx in senders {
104 let _ = tx.send(CompletionOutcome::CompletedWithoutResult);
105 }
106 }
107 }
108
109 pub(crate) fn resolve_abandoned(&mut self, input_id: &InputId, reason: String) {
111 if let Some(senders) = self.take_waiters(input_id) {
112 for tx in senders {
113 let _ = tx.send(CompletionOutcome::Abandoned(reason.clone()));
114 }
115 }
116 }
117
118 pub(crate) fn resolve_all_terminated(&mut self, reason: &str) {
122 for (_, senders) in self.waiters.drain() {
123 for tx in senders {
124 let _ = tx.send(CompletionOutcome::RuntimeTerminated(reason.into()));
125 }
126 }
127 }
128
129 pub(crate) fn resolve_not_pending<F>(&mut self, mut is_still_pending: F, reason: &str)
132 where
133 F: FnMut(&InputId) -> bool,
134 {
135 self.waiters.retain(|input_id, senders| {
136 if is_still_pending(input_id) {
137 return true;
138 }
139
140 for tx in senders.drain(..) {
141 let _ = tx.send(CompletionOutcome::RuntimeTerminated(reason.into()));
142 }
143 false
144 });
145 }
146
147 #[cfg(test)]
152 pub fn debug_has_waiters(&self) -> bool {
153 !self.waiters.is_empty()
154 }
155
156 #[cfg(test)]
161 pub fn debug_waiter_count(&self) -> usize {
162 self.waiters.values().map(Vec::len).sum()
163 }
164}
165
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    // Unit tests for CompletionHandle / CompletionRegistry: single and multi-waiter
    // resolution for every outcome kind, bulk termination, and dropped-handle behavior.
    use super::*;
    use meerkat_core::types::{SessionId, Usage};

    /// Builds a minimal `RunResult` whose `text` is "hello".
    fn make_run_result() -> RunResult {
        RunResult {
            text: "hello".into(),
            session_id: SessionId::new(),
            usage: Usage::default(),
            turns: 1,
            tool_calls: 0,
            structured_output: None,
            schema_warnings: None,
            skill_diagnostics: None,
        }
    }

    #[tokio::test]
    async fn register_and_complete() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        assert!(registry.debug_has_waiters());
        assert_eq!(registry.debug_waiter_count(), 1);

        let result = make_run_result();
        registry.resolve_completed(&input_id, result);

        match handle.wait().await {
            CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
            other => panic!("Expected Completed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn register_and_abandon() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        registry.resolve_abandoned(&input_id, "retired".into());

        match handle.wait().await {
            CompletionOutcome::Abandoned(reason) => assert_eq!(reason, "retired"),
            other => panic!("Expected Abandoned, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_all_terminated() {
        let mut registry = CompletionRegistry::new();
        let h1 = registry.register(InputId::new());
        let h2 = registry.register(InputId::new());

        registry.resolve_all_terminated("runtime stopped");

        assert!(!registry.debug_has_waiters());

        match h1.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime stopped"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
        match h2.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime stopped"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_nonexistent_is_a_noop() {
        let mut registry = CompletionRegistry::new();
        registry.resolve_completed(&InputId::new(), make_run_result());
        registry.resolve_abandoned(&InputId::new(), "gone".into());
        assert!(!registry.debug_has_waiters());
    }

    #[tokio::test]
    async fn dropped_sender_gives_terminated() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id);

        // Dropping the registry drops every sender without sending.
        drop(registry);

        match handle.wait().await {
            CompletionOutcome::RuntimeTerminated(_) => {}
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_waiter_all_receive_result() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();

        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());
        let h3 = registry.register(input_id.clone());

        assert_eq!(registry.debug_waiter_count(), 3);

        let result = make_run_result();
        registry.resolve_completed(&input_id, result);

        assert!(!registry.debug_has_waiters());

        for handle in [h1, h2, h3] {
            match handle.wait().await {
                CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
                other => panic!("Expected Completed, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn multi_waiter_all_receive_abandoned() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());

        registry.resolve_abandoned(&input_id, "retired".into());

        assert!(!registry.debug_has_waiters());
        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::Abandoned(reason) => assert_eq!(reason, "retired"),
                other => panic!("Expected Abandoned, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn dropped_handle_does_not_block_other_waiters() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let kept = registry.register(input_id.clone());
        let dropped = registry.register(input_id.clone());

        // The most recently registered waiter goes away before resolution; the
        // failed send to it must not affect delivery to the remaining waiter.
        drop(dropped);
        registry.resolve_completed(&input_id, make_run_result());

        assert!(!registry.debug_has_waiters());
        match kept.wait().await {
            CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
            other => panic!("Expected Completed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_without_result_sends_variant() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        registry.resolve_without_result(&input_id);

        match handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_without_result_multi_waiter() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());

        registry.resolve_without_result(&input_id);

        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::CompletedWithoutResult => {}
                other => panic!("Expected CompletedWithoutResult, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn already_resolved_handle() {
        let handle = CompletionHandle::already_resolved(CompletionOutcome::CompletedWithoutResult);
        match handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_waiter_terminated_on_reset() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id);

        registry.resolve_all_terminated("runtime reset");

        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime reset"),
                other => panic!("Expected RuntimeTerminated, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_not_pending_keeps_pending_waiters() {
        let mut registry = CompletionRegistry::new();
        let keep_id = InputId::new();
        let drop_id = InputId::new();

        let keep_handle = registry.register(keep_id.clone());
        let drop_handle = registry.register(drop_id.clone());
        registry.resolve_not_pending(|input_id| input_id == &keep_id, "runtime recycled");
        assert_eq!(registry.debug_waiter_count(), 1);

        match drop_handle.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime recycled"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }

        registry.resolve_without_result(&keep_id);
        match keep_handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_not_pending_all_pending_is_a_noop() {
        let mut registry = CompletionRegistry::new();
        // Bind (not `_`) so the handles stay alive for the duration of the test.
        let _h1 = registry.register(InputId::new());
        let _h2 = registry.register(InputId::new());

        registry.resolve_not_pending(|_| true, "unused");

        assert_eq!(registry.debug_waiter_count(), 2);
    }

    #[tokio::test]
    async fn resolve_without_result_nonexistent_is_a_noop() {
        let mut registry = CompletionRegistry::new();
        registry.resolve_without_result(&InputId::new());
        assert!(!registry.debug_has_waiters());
    }
}