1use std::collections::HashMap;
13
14use meerkat_core::lifecycle::InputId;
15use meerkat_core::types::RunResult;
16use serde_json::Value;
17
18use crate::tokio::sync::oneshot;
19
20#[derive(Debug)]
22pub enum CompletionOutcome {
23 Completed(RunResult),
25 CompletedWithoutResult,
27 CallbackPending { tool_name: String, args: Value },
30 Abandoned(String),
32 RuntimeTerminated(String),
34}
35
36pub struct CompletionHandle {
38 rx: oneshot::Receiver<CompletionOutcome>,
39}
40
41impl CompletionHandle {
42 pub async fn wait(self) -> CompletionOutcome {
44 match self.rx.await {
45 Ok(outcome) => outcome,
46 Err(_) => CompletionOutcome::RuntimeTerminated(
48 "completion channel closed without result".into(),
49 ),
50 }
51 }
52
53 pub fn already_resolved(outcome: CompletionOutcome) -> Self {
58 let (tx, rx) = oneshot::channel();
59 let _ = tx.send(outcome);
60 Self { rx }
61 }
62}
63
64#[derive(Default)]
69pub(crate) struct CompletionRegistry {
70 waiters: HashMap<InputId, Vec<oneshot::Sender<CompletionOutcome>>>,
71}
72
73impl CompletionRegistry {
74 pub(crate) fn new() -> Self {
75 Self::default()
76 }
77
78 fn take_waiters(
79 &mut self,
80 input_id: &InputId,
81 ) -> Option<Vec<oneshot::Sender<CompletionOutcome>>> {
82 self.waiters.remove(input_id)
83 }
84
85 pub(crate) fn register(&mut self, input_id: InputId) -> CompletionHandle {
90 let (tx, rx) = oneshot::channel();
91 self.waiters.entry(input_id).or_default().push(tx);
92 CompletionHandle { rx }
93 }
94
95 pub(crate) fn resolve_completed(&mut self, input_id: &InputId, result: RunResult) {
97 if let Some(senders) = self.take_waiters(input_id) {
98 for tx in senders {
99 let _ = tx.send(CompletionOutcome::Completed(result.clone()));
100 }
101 }
102 }
103
104 pub(crate) fn resolve_without_result(&mut self, input_id: &InputId) {
106 if let Some(senders) = self.take_waiters(input_id) {
107 for tx in senders {
108 let _ = tx.send(CompletionOutcome::CompletedWithoutResult);
109 }
110 }
111 }
112
113 pub(crate) fn resolve_callback_pending(
115 &mut self,
116 input_id: &InputId,
117 tool_name: String,
118 args: Value,
119 ) {
120 if let Some(senders) = self.take_waiters(input_id) {
121 for tx in senders {
122 let _ = tx.send(CompletionOutcome::CallbackPending {
123 tool_name: tool_name.clone(),
124 args: args.clone(),
125 });
126 }
127 }
128 }
129
130 pub(crate) fn resolve_abandoned(&mut self, input_id: &InputId, reason: String) {
132 if let Some(senders) = self.take_waiters(input_id) {
133 for tx in senders {
134 let _ = tx.send(CompletionOutcome::Abandoned(reason.clone()));
135 }
136 }
137 }
138
139 pub(crate) fn resolve_all_terminated(&mut self, reason: &str) {
143 for (_, senders) in self.waiters.drain() {
144 for tx in senders {
145 let _ = tx.send(CompletionOutcome::RuntimeTerminated(reason.into()));
146 }
147 }
148 }
149
150 pub(crate) fn resolve_not_pending<F>(&mut self, mut is_still_pending: F, reason: &str)
153 where
154 F: FnMut(&InputId) -> bool,
155 {
156 self.waiters.retain(|input_id, senders| {
157 if is_still_pending(input_id) {
158 return true;
159 }
160
161 for tx in senders.drain(..) {
162 let _ = tx.send(CompletionOutcome::RuntimeTerminated(reason.into()));
163 }
164 false
165 });
166 }
167
168 #[cfg(test)]
173 pub fn debug_has_waiters(&self) -> bool {
174 !self.waiters.is_empty()
175 }
176
177 #[cfg(test)]
182 pub fn debug_waiter_count(&self) -> usize {
183 self.waiters.values().map(Vec::len).sum()
184 }
185}
186
#[cfg(test)]
// `panic!` is how the match-based assertions below report an unexpected
// outcome variant; `unwrap` is likewise acceptable in test-only code.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use meerkat_core::types::{SessionId, Usage};

    /// Builds a minimal `RunResult` whose `text` is `"hello"`.
    fn make_run_result() -> RunResult {
        RunResult {
            text: "hello".into(),
            session_id: SessionId::new(),
            usage: Usage::default(),
            turns: 1,
            tool_calls: 0,
            structured_output: None,
            schema_warnings: None,
            skill_diagnostics: None,
        }
    }

    #[tokio::test]
    async fn register_and_complete() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        assert!(registry.debug_has_waiters());
        assert_eq!(registry.debug_waiter_count(), 1);

        let result = make_run_result();
        registry.resolve_completed(&input_id, result);

        match handle.wait().await {
            CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
            other => panic!("Expected Completed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn register_and_abandon() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        registry.resolve_abandoned(&input_id, "retired".into());

        match handle.wait().await {
            CompletionOutcome::Abandoned(reason) => assert_eq!(reason, "retired"),
            other => panic!("Expected Abandoned, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_all_terminated() {
        let mut registry = CompletionRegistry::new();
        let h1 = registry.register(InputId::new());
        let h2 = registry.register(InputId::new());

        registry.resolve_all_terminated("runtime stopped");

        assert!(!registry.debug_has_waiters());

        match h1.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime stopped"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
        match h2.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime stopped"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_nonexistent_is_a_noop() {
        let mut registry = CompletionRegistry::new();
        registry.resolve_completed(&InputId::new(), make_run_result());
        registry.resolve_abandoned(&InputId::new(), "gone".into());
        assert!(!registry.debug_has_waiters());
    }

    #[tokio::test]
    async fn dropped_sender_gives_terminated() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id);

        // Dropping the registry drops every sender without sending anything.
        drop(registry);

        match handle.wait().await {
            CompletionOutcome::RuntimeTerminated(_) => {}
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_waiter_all_receive_result() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();

        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());
        let h3 = registry.register(input_id.clone());

        assert_eq!(registry.debug_waiter_count(), 3);

        let result = make_run_result();
        registry.resolve_completed(&input_id, result);

        assert!(!registry.debug_has_waiters());

        for handle in [h1, h2, h3] {
            match handle.wait().await {
                CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
                other => panic!("Expected Completed, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_without_result_sends_variant() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        registry.resolve_without_result(&input_id);

        match handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_without_result_multi_waiter() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());

        registry.resolve_without_result(&input_id);

        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::CompletedWithoutResult => {}
                other => panic!("Expected CompletedWithoutResult, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_callback_pending_sends_variant() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        registry.resolve_callback_pending(
            &input_id,
            "browser".to_string(),
            serde_json::json!({ "url": "https://example.com" }),
        );

        match handle.wait().await {
            CompletionOutcome::CallbackPending { tool_name, args } => {
                assert_eq!(tool_name, "browser");
                assert_eq!(args, serde_json::json!({ "url": "https://example.com" }));
            }
            other => panic!("Expected CallbackPending, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_callback_pending_multi_waiter() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());

        registry.resolve_callback_pending(
            &input_id,
            "browser".to_string(),
            serde_json::json!({ "url": "https://example.com" }),
        );

        assert!(!registry.debug_has_waiters());

        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::CallbackPending { tool_name, args } => {
                    assert_eq!(tool_name, "browser");
                    assert_eq!(args, serde_json::json!({ "url": "https://example.com" }));
                }
                other => panic!("Expected CallbackPending, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn already_resolved_handle() {
        let handle = CompletionHandle::already_resolved(CompletionOutcome::CompletedWithoutResult);
        match handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_waiter_terminated_on_reset() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id);

        registry.resolve_all_terminated("runtime reset");

        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime reset"),
                other => panic!("Expected RuntimeTerminated, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_not_pending_keeps_pending_waiters() {
        let mut registry = CompletionRegistry::new();
        let keep_id = InputId::new();
        let drop_id = InputId::new();

        let keep_handle = registry.register(keep_id.clone());
        let drop_handle = registry.register(drop_id.clone());
        registry.resolve_not_pending(|input_id| input_id == &keep_id, "runtime recycled");
        assert_eq!(registry.debug_waiter_count(), 1);

        match drop_handle.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime recycled"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }

        registry.resolve_without_result(&keep_id);
        match keep_handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_not_pending_keeps_everything_when_all_pending() {
        let mut registry = CompletionRegistry::new();
        let _h1 = registry.register(InputId::new());
        let _h2 = registry.register(InputId::new());

        registry.resolve_not_pending(|_| true, "unused");

        assert_eq!(registry.debug_waiter_count(), 2);
    }

    #[tokio::test]
    async fn resolve_without_result_nonexistent_is_a_noop() {
        let mut registry = CompletionRegistry::new();
        registry.resolve_without_result(&InputId::new());
        assert!(!registry.debug_has_waiters());
    }

    #[tokio::test]
    async fn second_resolution_is_a_noop() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        registry.resolve_completed(&input_id, make_run_result());
        // The first resolution removed the waiters, so this must not reach the handle.
        registry.resolve_abandoned(&input_id, "too late".into());

        assert!(!registry.debug_has_waiters());
        match handle.wait().await {
            CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
            other => panic!("Expected Completed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_only_touches_matching_input() {
        let mut registry = CompletionRegistry::new();
        let done_id = InputId::new();
        let other_id = InputId::new();
        let done_handle = registry.register(done_id.clone());
        let _other_handle = registry.register(other_id);

        registry.resolve_without_result(&done_id);

        assert_eq!(registry.debug_waiter_count(), 1);
        match done_handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dropped_handle_does_not_block_other_waiters() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let dropped = registry.register(input_id.clone());
        let kept = registry.register(input_id.clone());
        drop(dropped);

        // Sending to the dropped handle fails; that must not affect `kept`.
        registry.resolve_completed(&input_id, make_run_result());

        assert!(!registry.debug_has_waiters());
        match kept.wait().await {
            CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
            other => panic!("Expected Completed, got {other:?}"),
        }
    }
}