use std::collections::HashMap;
use meerkat_core::lifecycle::InputId;
use meerkat_core::types::RunResult;
use crate::tokio::sync::oneshot;
/// Outcome delivered to a [`CompletionHandle`] when its waiter is resolved.
#[derive(Debug)]
pub enum CompletionOutcome {
/// Sent by `CompletionRegistry::resolve_completed`, carrying the run result.
Completed(RunResult),
/// Sent by `CompletionRegistry::resolve_without_result`; no run result is attached.
CompletedWithoutResult,
/// Sent by `CompletionRegistry::resolve_abandoned`, carrying the caller-supplied reason.
Abandoned(String),
/// Sent by `CompletionRegistry::resolve_all_terminated` with the given reason, or
/// produced by `CompletionHandle::wait` when the channel closes without an outcome.
RuntimeTerminated(String),
}
/// Receiving side of one registered waiter; consume it with `CompletionHandle::wait`.
pub struct CompletionHandle {
// Single-use channel on which the registry delivers exactly one outcome.
rx: oneshot::Receiver<CompletionOutcome>,
}
impl CompletionHandle {
    /// Waits for the outcome delivered to this handle.
    ///
    /// If the sending side is dropped before any outcome is sent (for example,
    /// because the owning `CompletionRegistry` was dropped), this returns
    /// [`CompletionOutcome::RuntimeTerminated`] instead of waiting forever.
    pub async fn wait(self) -> CompletionOutcome {
        self.rx.await.unwrap_or_else(|_closed| {
            CompletionOutcome::RuntimeTerminated(
                "completion channel closed without result".into(),
            )
        })
    }

    /// Builds a handle whose outcome has already been sent, so `wait` yields it
    /// without needing any registry.
    pub fn already_resolved(outcome: CompletionOutcome) -> Self {
        let (sender, receiver) = oneshot::channel();
        // `receiver` is still held locally, so the send cannot be rejected;
        // the result carries no useful information and is discarded.
        let _ = sender.send(outcome);
        Self { rx: receiver }
    }
}
/// Tracks pending completion waiters, grouped by the input they wait on.
///
/// Any number of waiters may be registered per input. Each `resolve_*` call
/// removes and resolves every waiter for the affected input(s).
#[derive(Default)]
pub struct CompletionRegistry {
// Sending halves for every registered waiter, keyed by input id. Entries are only
// created by `register` (which always pushes), so no stored `Vec` is ever empty.
waiters: HashMap<InputId, Vec<oneshot::Sender<CompletionOutcome>>>,
}
impl CompletionRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register(&mut self, input_id: InputId) -> CompletionHandle {
let (tx, rx) = oneshot::channel();
self.waiters.entry(input_id).or_default().push(tx);
CompletionHandle { rx }
}
pub fn resolve_completed(&mut self, input_id: &InputId, result: RunResult) -> bool {
if let Some(senders) = self.waiters.remove(input_id) {
for tx in senders {
let _ = tx.send(CompletionOutcome::Completed(result.clone()));
}
true
} else {
false
}
}
pub fn resolve_without_result(&mut self, input_id: &InputId) -> bool {
if let Some(senders) = self.waiters.remove(input_id) {
for tx in senders {
let _ = tx.send(CompletionOutcome::CompletedWithoutResult);
}
true
} else {
false
}
}
pub fn resolve_abandoned(&mut self, input_id: &InputId, reason: String) -> bool {
if let Some(senders) = self.waiters.remove(input_id) {
for tx in senders {
let _ = tx.send(CompletionOutcome::Abandoned(reason.clone()));
}
true
} else {
false
}
}
pub fn resolve_all_terminated(&mut self, reason: &str) {
for (_, senders) in self.waiters.drain() {
for tx in senders {
let _ = tx.send(CompletionOutcome::RuntimeTerminated(reason.into()));
}
}
}
pub fn has_pending(&self) -> bool {
!self.waiters.is_empty()
}
pub fn pending_count(&self) -> usize {
self.waiters.values().map(Vec::len).sum()
}
}
#[cfg(test)]
// Tests report unexpected outcomes through `panic!` in match arms.
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use meerkat_core::types::{SessionId, Usage};

    /// Minimal `RunResult` whose `text` is `"hello"`, used as a recognizable payload.
    fn make_run_result() -> RunResult {
        RunResult {
            text: "hello".into(),
            session_id: SessionId::new(),
            usage: Usage::default(),
            turns: 1,
            tool_calls: 0,
            structured_output: None,
            schema_warnings: None,
            skill_diagnostics: None,
        }
    }

    #[tokio::test]
    async fn register_and_complete() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());
        assert!(registry.has_pending());
        assert_eq!(registry.pending_count(), 1);
        let result = make_run_result();
        assert!(registry.resolve_completed(&input_id, result));
        match handle.wait().await {
            CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
            other => panic!("Expected Completed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn register_and_abandon() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());
        assert!(registry.resolve_abandoned(&input_id, "retired".into()));
        match handle.wait().await {
            CompletionOutcome::Abandoned(reason) => assert_eq!(reason, "retired"),
            other => panic!("Expected Abandoned, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_all_terminated() {
        let mut registry = CompletionRegistry::new();
        let h1 = registry.register(InputId::new());
        let h2 = registry.register(InputId::new());
        registry.resolve_all_terminated("runtime stopped");
        assert!(!registry.has_pending());
        match h1.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime stopped"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
        match h2.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime stopped"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_nonexistent_returns_false() {
        let mut registry = CompletionRegistry::new();
        assert!(!registry.resolve_completed(&InputId::new(), make_run_result()));
        assert!(!registry.resolve_abandoned(&InputId::new(), "gone".into()));
    }

    // Dropping the registry drops every sender, so waiting handles must
    // observe a closed channel rather than hang.
    #[tokio::test]
    async fn dropped_sender_gives_terminated() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id);
        drop(registry);
        match handle.wait().await {
            CompletionOutcome::RuntimeTerminated(_) => {}
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_waiter_all_receive_result() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());
        let h3 = registry.register(input_id.clone());
        assert_eq!(registry.pending_count(), 3);
        let result = make_run_result();
        assert!(registry.resolve_completed(&input_id, result));
        assert!(!registry.has_pending());
        for handle in [h1, h2, h3] {
            match handle.wait().await {
                CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
                other => panic!("Expected Completed, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_without_result_sends_variant() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());
        assert!(registry.resolve_without_result(&input_id));
        match handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_without_result_multi_waiter() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());
        assert!(registry.resolve_without_result(&input_id));
        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::CompletedWithoutResult => {}
                other => panic!("Expected CompletedWithoutResult, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn already_resolved_handle() {
        let handle = CompletionHandle::already_resolved(CompletionOutcome::CompletedWithoutResult);
        match handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_waiter_terminated_on_reset() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id);
        registry.resolve_all_terminated("runtime reset");
        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime reset"),
                other => panic!("Expected RuntimeTerminated, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_without_result_nonexistent_returns_false() {
        let mut registry = CompletionRegistry::new();
        assert!(!registry.resolve_without_result(&InputId::new()));
    }

    // Resolving one input must not touch waiters registered for a different
    // input. This assumes `InputId::new()` yields distinct ids, as the other
    // tests do.
    #[tokio::test]
    async fn resolving_one_input_leaves_other_inputs_pending() {
        let mut registry = CompletionRegistry::new();
        let first = InputId::new();
        let second = InputId::new();
        let h_first = registry.register(first.clone());
        let h_second = registry.register(second.clone());
        assert_eq!(registry.pending_count(), 2);

        assert!(registry.resolve_without_result(&first));
        assert!(registry.has_pending());
        assert_eq!(registry.pending_count(), 1);
        match h_first.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }

        assert!(registry.resolve_abandoned(&second, "cancelled".into()));
        assert!(!registry.has_pending());
        match h_second.wait().await {
            CompletionOutcome::Abandoned(reason) => assert_eq!(reason, "cancelled"),
            other => panic!("Expected Abandoned, got {other:?}"),
        }
    }

    // A waiter whose handle was dropped is still registered: it counts as
    // pending and makes the resolve report `true`. Resolving consumes it, so a
    // repeat resolve finds nothing.
    #[tokio::test]
    async fn resolve_with_dropped_handle_still_reports_registered() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        drop(registry.register(input_id.clone()));
        assert_eq!(registry.pending_count(), 1);
        assert!(registry.resolve_completed(&input_id, make_run_result()));
        assert!(!registry.has_pending());
        assert!(!registry.resolve_completed(&input_id, make_run_result()));
    }
}