1use std::collections::HashMap;
13use std::future::Future;
14
15use meerkat_core::TurnErrorMetadata;
16use meerkat_core::lifecycle::InputId;
17use meerkat_core::types::RunResult;
18use serde_json::Value;
19
20use crate::tokio::sync::oneshot;
21
22#[derive(Debug)]
24pub enum CompletionOutcome {
25 Completed(Box<RunResult>),
27 CompletedWithoutResult,
29 CallbackPending { tool_name: String, args: Value },
32 Cancelled,
34 Abandoned(String),
36 AbandonedWithError {
38 reason: String,
39 error: TurnErrorMetadata,
40 },
41 CompletedWithFinalizationFailure {
44 result: Box<RunResult>,
45 error: TurnErrorMetadata,
46 },
47 RuntimeTerminated(String),
49}
50
51#[derive(Debug, Clone, PartialEq, Eq)]
56pub struct CompletionWaiterEntrySnapshot {
57 pub input_id: InputId,
58 pub waiter_count: usize,
59}
60
61#[derive(Debug, Clone, PartialEq, Eq, Default)]
66pub struct CompletionRegistrySnapshot {
67 pub input_count: usize,
68 pub waiter_count: usize,
69 pub waiting_inputs: Vec<CompletionWaiterEntrySnapshot>,
70}
71
72#[derive(Debug)]
74pub struct CompletionHandle {
75 rx: oneshot::Receiver<CompletionOutcome>,
76}
77
78impl CompletionHandle {
79 pub async fn wait(self) -> CompletionOutcome {
81 match self.rx.await {
82 Ok(outcome) => outcome,
83 Err(_) => CompletionOutcome::RuntimeTerminated(
85 "completion channel closed without result".into(),
86 ),
87 }
88 }
89
90 pub fn with_cleanup<F, Fut>(self, cleanup: F) -> Self
94 where
95 F: FnOnce() -> Fut + Send + 'static,
96 Fut: Future<Output = ()> + Send + 'static,
97 {
98 let (tx, rx) = oneshot::channel();
99 crate::tokio::spawn(async move {
100 let outcome = self.wait().await;
101 cleanup().await;
102 let _ = tx.send(outcome);
103 });
104 Self { rx }
105 }
106
107 pub fn with_outcome_cleanup<F, Fut>(self, cleanup: F) -> Self
112 where
113 F: FnOnce(CompletionOutcome) -> Fut + Send + 'static,
114 Fut: Future<Output = CompletionOutcome> + Send + 'static,
115 {
116 let (tx, rx) = oneshot::channel();
117 crate::tokio::spawn(async move {
118 let outcome = self.wait().await;
119 let outcome = cleanup(outcome).await;
120 let _ = tx.send(outcome);
121 });
122 Self { rx }
123 }
124
125 pub fn already_resolved(outcome: CompletionOutcome) -> Self {
130 let (tx, rx) = oneshot::channel();
131 let _ = tx.send(outcome);
132 Self { rx }
133 }
134}
135
136impl CompletionOutcome {
137 pub fn abandoned_reason(&self) -> Option<&str> {
138 match self {
139 Self::Abandoned(reason) | Self::AbandonedWithError { reason, .. } => Some(reason),
140 _ => None,
141 }
142 }
143
144 pub fn error_metadata(&self) -> Option<&TurnErrorMetadata> {
145 match self {
146 Self::AbandonedWithError { error, .. }
147 | Self::CompletedWithFinalizationFailure { error, .. } => Some(error),
148 _ => None,
149 }
150 }
151}
152
153#[derive(Default)]
158pub(crate) struct CompletionRegistry {
159 waiters: HashMap<InputId, Vec<oneshot::Sender<CompletionOutcome>>>,
160}
161
162impl CompletionRegistry {
163 pub(crate) fn new() -> Self {
164 Self::default()
165 }
166
167 fn take_waiters(
168 &mut self,
169 input_id: &InputId,
170 ) -> Option<Vec<oneshot::Sender<CompletionOutcome>>> {
171 self.waiters.remove(input_id)
172 }
173
174 pub(crate) fn register(&mut self, input_id: InputId) -> CompletionHandle {
179 let (tx, rx) = oneshot::channel();
180 self.waiters.entry(input_id).or_default().push(tx);
181 CompletionHandle { rx }
182 }
183
184 pub(crate) fn resolve_completed(&mut self, input_id: &InputId, result: RunResult) {
186 if let Some(senders) = self.take_waiters(input_id) {
187 for tx in senders {
188 let _ = tx.send(CompletionOutcome::Completed(Box::new(result.clone())));
189 }
190 }
191 }
192
193 pub(crate) fn resolve_without_result(&mut self, input_id: &InputId) {
195 if let Some(senders) = self.take_waiters(input_id) {
196 for tx in senders {
197 let _ = tx.send(CompletionOutcome::CompletedWithoutResult);
198 }
199 }
200 }
201
202 pub(crate) fn resolve_callback_pending(
204 &mut self,
205 input_id: &InputId,
206 tool_name: String,
207 args: Value,
208 ) {
209 if let Some(senders) = self.take_waiters(input_id) {
210 for tx in senders {
211 let _ = tx.send(CompletionOutcome::CallbackPending {
212 tool_name: tool_name.clone(),
213 args: args.clone(),
214 });
215 }
216 }
217 }
218
219 pub(crate) fn resolve_cancelled(&mut self, input_id: &InputId) {
221 if let Some(senders) = self.take_waiters(input_id) {
222 for tx in senders {
223 let _ = tx.send(CompletionOutcome::Cancelled);
224 }
225 }
226 }
227
228 pub(crate) fn resolve_abandoned(&mut self, input_id: &InputId, reason: String) {
230 if let Some(senders) = self.take_waiters(input_id) {
231 for tx in senders {
232 let _ = tx.send(CompletionOutcome::Abandoned(reason.clone()));
233 }
234 }
235 }
236
237 pub(crate) fn resolve_abandoned_with_error(
239 &mut self,
240 input_id: &InputId,
241 reason: String,
242 error: TurnErrorMetadata,
243 ) {
244 if let Some(senders) = self.take_waiters(input_id) {
245 for tx in senders {
246 let _ = tx.send(CompletionOutcome::AbandonedWithError {
247 reason: reason.clone(),
248 error: error.clone(),
249 });
250 }
251 }
252 }
253
254 pub(crate) fn resolve_completed_with_finalization_failure(
257 &mut self,
258 input_id: &InputId,
259 result: RunResult,
260 error: TurnErrorMetadata,
261 ) {
262 if let Some(senders) = self.take_waiters(input_id) {
263 for tx in senders {
264 let _ = tx.send(CompletionOutcome::CompletedWithFinalizationFailure {
265 result: Box::new(result.clone()),
266 error: error.clone(),
267 });
268 }
269 }
270 }
271
272 pub(crate) fn resolve_all_terminated(&mut self, reason: &str) {
276 for (_, senders) in self.waiters.drain() {
277 for tx in senders {
278 let _ = tx.send(CompletionOutcome::RuntimeTerminated(reason.into()));
279 }
280 }
281 }
282
283 pub(crate) fn resolve_not_pending<F>(&mut self, mut is_still_pending: F, reason: &str)
286 where
287 F: FnMut(&InputId) -> bool,
288 {
289 self.waiters.retain(|input_id, senders| {
290 if is_still_pending(input_id) {
291 return true;
292 }
293
294 for tx in senders.drain(..) {
295 let _ = tx.send(CompletionOutcome::RuntimeTerminated(reason.into()));
296 }
297 false
298 });
299 }
300
301 pub(crate) fn diagnostic_snapshot(&self) -> CompletionRegistrySnapshot {
303 let mut waiting_inputs: Vec<_> = self
304 .waiters
305 .iter()
306 .map(|(input_id, senders)| CompletionWaiterEntrySnapshot {
307 input_id: input_id.clone(),
308 waiter_count: senders.len(),
309 })
310 .collect();
311 waiting_inputs
312 .sort_by(|left, right| left.input_id.to_string().cmp(&right.input_id.to_string()));
313
314 CompletionRegistrySnapshot {
315 input_count: waiting_inputs.len(),
316 waiter_count: waiting_inputs.iter().map(|entry| entry.waiter_count).sum(),
317 waiting_inputs,
318 }
319 }
320
321 #[cfg(test)]
326 pub fn debug_has_waiters(&self) -> bool {
327 !self.waiters.is_empty()
328 }
329
330 #[cfg(test)]
335 pub fn debug_waiter_count(&self) -> usize {
336 self.waiters.values().map(Vec::len).sum()
337 }
338}
339
#[cfg(test)]
// Tests report unexpected outcomes by panicking.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use meerkat_core::types::{SessionId, Usage};

    /// Builds a minimal `RunResult` whose text is `"hello"`.
    fn make_run_result() -> RunResult {
        RunResult {
            text: "hello".into(),
            session_id: SessionId::new(),
            usage: Usage::default(),
            turns: 1,
            tool_calls: 0,
            terminal_cause_kind: None,
            structured_output: None,
            extraction_error: None,
            schema_warnings: None,
            skill_diagnostics: None,
        }
    }

    #[tokio::test]
    async fn register_and_complete() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        assert!(registry.debug_has_waiters());
        assert_eq!(registry.debug_waiter_count(), 1);

        let result = make_run_result();
        registry.resolve_completed(&input_id, result);

        match handle.wait().await {
            CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
            other => panic!("Expected Completed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn register_and_abandon() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        registry.resolve_abandoned(&input_id, "retired".into());

        match handle.wait().await {
            CompletionOutcome::Abandoned(reason) => assert_eq!(reason, "retired"),
            other => panic!("Expected Abandoned, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_all_terminated() {
        let mut registry = CompletionRegistry::new();
        let h1 = registry.register(InputId::new());
        let h2 = registry.register(InputId::new());

        registry.resolve_all_terminated("runtime stopped");

        assert!(!registry.debug_has_waiters());

        match h1.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime stopped"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
        match h2.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime stopped"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_nonexistent_is_a_noop() {
        let mut registry = CompletionRegistry::new();
        registry.resolve_completed(&InputId::new(), make_run_result());
        registry.resolve_abandoned(&InputId::new(), "gone".into());
        assert!(!registry.debug_has_waiters());
    }

    #[tokio::test]
    async fn dropped_sender_gives_terminated() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id);

        // Dropping the registry drops every sender without sending.
        drop(registry);

        match handle.wait().await {
            CompletionOutcome::RuntimeTerminated(_) => {}
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_waiter_all_receive_result() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();

        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());
        let h3 = registry.register(input_id.clone());

        assert_eq!(registry.debug_waiter_count(), 3);

        let result = make_run_result();
        registry.resolve_completed(&input_id, result);

        assert!(!registry.debug_has_waiters());

        for handle in [h1, h2, h3] {
            match handle.wait().await {
                CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
                other => panic!("Expected Completed, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_without_result_sends_variant() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        registry.resolve_without_result(&input_id);

        match handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_without_result_multi_waiter() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());

        registry.resolve_without_result(&input_id);

        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::CompletedWithoutResult => {}
                other => panic!("Expected CompletedWithoutResult, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_callback_pending_sends_variant() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        registry.resolve_callback_pending(
            &input_id,
            "browser".to_string(),
            serde_json::json!({ "url": "https://example.com" }),
        );

        match handle.wait().await {
            CompletionOutcome::CallbackPending { tool_name, args } => {
                assert_eq!(tool_name, "browser");
                assert_eq!(args, serde_json::json!({ "url": "https://example.com" }));
            }
            other => panic!("Expected CallbackPending, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_callback_pending_reaches_every_waiter() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());

        let args = serde_json::json!({ "query": "rust" });
        registry.resolve_callback_pending(&input_id, "search".to_string(), args.clone());
        assert!(!registry.debug_has_waiters());

        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::CallbackPending {
                    tool_name,
                    args: received,
                } => {
                    assert_eq!(tool_name, "search");
                    assert_eq!(received, args);
                }
                other => panic!("Expected CallbackPending, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_cancelled_sends_variant() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        registry.resolve_cancelled(&input_id);

        match handle.wait().await {
            CompletionOutcome::Cancelled => {}
            other => panic!("Expected Cancelled, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolving_with_a_dropped_handle_still_notifies_the_rest() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let kept = registry.register(input_id.clone());
        let dropped = registry.register(input_id.clone());
        drop(dropped);

        registry.resolve_cancelled(&input_id);
        assert!(!registry.debug_has_waiters());

        match kept.wait().await {
            CompletionOutcome::Cancelled => {}
            other => panic!("Expected Cancelled, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn already_resolved_handle() {
        let handle = CompletionHandle::already_resolved(CompletionOutcome::CompletedWithoutResult);
        match handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn outcome_cleanup_observes_and_relays_result() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        let observed = Arc::new(AtomicBool::new(false));
        let cleanup_observed = Arc::clone(&observed);
        let handle = CompletionHandle::already_resolved(CompletionOutcome::Abandoned(
            "apply failed: test".to_string(),
        ))
        .with_outcome_cleanup(move |outcome| async move {
            if matches!(&outcome, CompletionOutcome::Abandoned(reason) if reason == "apply failed: test")
            {
                cleanup_observed.store(true, Ordering::Release);
            }
            outcome
        });

        match handle.wait().await {
            CompletionOutcome::Abandoned(reason) => {
                assert_eq!(reason, "apply failed: test");
            }
            other => panic!("Expected Abandoned, got {other:?}"),
        }
        assert!(observed.load(Ordering::Acquire));
    }

    #[tokio::test]
    async fn cleanup_runs_before_outcome_is_relayed() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        let ran = Arc::new(AtomicBool::new(false));
        let cleanup_ran = Arc::clone(&ran);
        let handle = CompletionHandle::already_resolved(CompletionOutcome::Cancelled)
            .with_cleanup(move || async move {
                cleanup_ran.store(true, Ordering::Release);
            });

        match handle.wait().await {
            CompletionOutcome::Cancelled => {}
            other => panic!("Expected Cancelled, got {other:?}"),
        }
        assert!(ran.load(Ordering::Acquire));
    }

    #[tokio::test]
    async fn multi_waiter_terminated_on_reset() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id);

        registry.resolve_all_terminated("runtime reset");

        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime reset"),
                other => panic!("Expected RuntimeTerminated, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_not_pending_keeps_pending_waiters() {
        let mut registry = CompletionRegistry::new();
        let keep_id = InputId::new();
        let drop_id = InputId::new();

        let keep_handle = registry.register(keep_id.clone());
        let drop_handle = registry.register(drop_id.clone());
        registry.resolve_not_pending(|input_id| input_id == &keep_id, "runtime recycled");
        assert_eq!(registry.debug_waiter_count(), 1);

        match drop_handle.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime recycled"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }

        registry.resolve_without_result(&keep_id);
        match keep_handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_without_result_nonexistent_is_a_noop() {
        let mut registry = CompletionRegistry::new();
        registry.resolve_without_result(&InputId::new());
        assert!(!registry.debug_has_waiters());
    }

    #[test]
    fn diagnostic_snapshot_of_empty_registry_is_default() {
        let registry = CompletionRegistry::new();
        assert_eq!(
            registry.diagnostic_snapshot(),
            CompletionRegistrySnapshot::default()
        );
    }

    #[test]
    fn diagnostic_snapshot_counts_and_sorts_waiters() {
        let mut registry = CompletionRegistry::new();
        let first = InputId::new();
        let second = InputId::new();
        let _h1 = registry.register(first.clone());
        let _h2 = registry.register(first.clone());
        let _h3 = registry.register(second.clone());

        let snapshot = registry.diagnostic_snapshot();
        assert_eq!(snapshot.input_count, 2);
        assert_eq!(snapshot.waiter_count, 3);

        let mut expected = vec![
            CompletionWaiterEntrySnapshot {
                input_id: first,
                waiter_count: 2,
            },
            CompletionWaiterEntrySnapshot {
                input_id: second,
                waiter_count: 1,
            },
        ];
        expected.sort_by_key(|entry| entry.input_id.to_string());
        assert_eq!(snapshot.waiting_inputs, expected);
    }

    #[test]
    fn outcome_accessors_without_error_metadata() {
        let abandoned = CompletionOutcome::Abandoned("retired".into());
        assert_eq!(abandoned.abandoned_reason(), Some("retired"));
        assert!(abandoned.error_metadata().is_none());

        for outcome in [
            CompletionOutcome::Completed(Box::new(make_run_result())),
            CompletionOutcome::CompletedWithoutResult,
            CompletionOutcome::Cancelled,
            CompletionOutcome::RuntimeTerminated("stopped".into()),
        ] {
            assert_eq!(outcome.abandoned_reason(), None, "{outcome:?}");
            assert!(outcome.error_metadata().is_none(), "{outcome:?}");
        }
    }
}