meerkat_runtime/
completion.rs1use std::collections::HashMap;
9
10use meerkat_core::lifecycle::InputId;
11use meerkat_core::types::RunResult;
12
13use crate::tokio::sync::oneshot;
14
15#[derive(Debug)]
17pub enum CompletionOutcome {
18 Completed(RunResult),
20 CompletedWithoutResult,
22 Abandoned(String),
24 RuntimeTerminated(String),
26}
27
28pub struct CompletionHandle {
30 rx: oneshot::Receiver<CompletionOutcome>,
31}
32
33impl CompletionHandle {
34 pub async fn wait(self) -> CompletionOutcome {
36 match self.rx.await {
37 Ok(outcome) => outcome,
38 Err(_) => CompletionOutcome::RuntimeTerminated(
40 "completion channel closed without result".into(),
41 ),
42 }
43 }
44
45 pub fn already_resolved(outcome: CompletionOutcome) -> Self {
50 let (tx, rx) = oneshot::channel();
51 let _ = tx.send(outcome);
52 Self { rx }
53 }
54}
55
56#[derive(Default)]
61pub struct CompletionRegistry {
62 waiters: HashMap<InputId, Vec<oneshot::Sender<CompletionOutcome>>>,
63}
64
65impl CompletionRegistry {
66 pub fn new() -> Self {
67 Self::default()
68 }
69
70 pub fn register(&mut self, input_id: InputId) -> CompletionHandle {
75 let (tx, rx) = oneshot::channel();
76 self.waiters.entry(input_id).or_default().push(tx);
77 CompletionHandle { rx }
78 }
79
80 pub fn resolve_completed(&mut self, input_id: &InputId, result: RunResult) -> bool {
84 if let Some(senders) = self.waiters.remove(input_id) {
85 for tx in senders {
86 let _ = tx.send(CompletionOutcome::Completed(result.clone()));
87 }
88 true
89 } else {
90 false
91 }
92 }
93
94 pub fn resolve_without_result(&mut self, input_id: &InputId) -> bool {
98 if let Some(senders) = self.waiters.remove(input_id) {
99 for tx in senders {
100 let _ = tx.send(CompletionOutcome::CompletedWithoutResult);
101 }
102 true
103 } else {
104 false
105 }
106 }
107
108 pub fn resolve_abandoned(&mut self, input_id: &InputId, reason: String) -> bool {
112 if let Some(senders) = self.waiters.remove(input_id) {
113 for tx in senders {
114 let _ = tx.send(CompletionOutcome::Abandoned(reason.clone()));
115 }
116 true
117 } else {
118 false
119 }
120 }
121
122 pub fn resolve_all_terminated(&mut self, reason: &str) {
126 for (_, senders) in self.waiters.drain() {
127 for tx in senders {
128 let _ = tx.send(CompletionOutcome::RuntimeTerminated(reason.into()));
129 }
130 }
131 }
132
133 pub fn has_pending(&self) -> bool {
135 !self.waiters.is_empty()
136 }
137
138 pub fn pending_count(&self) -> usize {
140 self.waiters.values().map(Vec::len).sum()
141 }
142}
143
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use meerkat_core::types::{SessionId, Usage};

    /// Builds a minimal `RunResult` whose `text` is `"hello"`.
    fn make_run_result() -> RunResult {
        RunResult {
            text: "hello".into(),
            session_id: SessionId::new(),
            usage: Usage::default(),
            turns: 1,
            tool_calls: 0,
            structured_output: None,
            schema_warnings: None,
            skill_diagnostics: None,
        }
    }

    #[tokio::test]
    async fn register_and_complete() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        assert!(registry.has_pending());
        assert_eq!(registry.pending_count(), 1);

        let result = make_run_result();
        assert!(registry.resolve_completed(&input_id, result));

        match handle.wait().await {
            CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
            other => panic!("Expected Completed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn register_and_abandon() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        assert!(registry.resolve_abandoned(&input_id, "retired".into()));

        match handle.wait().await {
            CompletionOutcome::Abandoned(reason) => assert_eq!(reason, "retired"),
            other => panic!("Expected Abandoned, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_all_terminated() {
        let mut registry = CompletionRegistry::new();
        let h1 = registry.register(InputId::new());
        let h2 = registry.register(InputId::new());

        registry.resolve_all_terminated("runtime stopped");

        assert!(!registry.has_pending());

        match h1.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime stopped"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
        match h2.wait().await {
            CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime stopped"),
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_nonexistent_returns_false() {
        let mut registry = CompletionRegistry::new();
        assert!(!registry.resolve_completed(&InputId::new(), make_run_result()));
        assert!(!registry.resolve_abandoned(&InputId::new(), "gone".into()));
    }

    #[tokio::test]
    async fn dropped_sender_gives_terminated() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id);

        // Dropping the registry drops the sender without sending anything.
        drop(registry);

        match handle.wait().await {
            CompletionOutcome::RuntimeTerminated(_) => {}
            other => panic!("Expected RuntimeTerminated, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_waiter_all_receive_result() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();

        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());
        let h3 = registry.register(input_id.clone());

        assert_eq!(registry.pending_count(), 3);

        let result = make_run_result();
        assert!(registry.resolve_completed(&input_id, result));

        assert!(!registry.has_pending());

        for handle in [h1, h2, h3] {
            match handle.wait().await {
                CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
                other => panic!("Expected Completed, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_without_result_sends_variant() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        assert!(registry.resolve_without_result(&input_id));

        match handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn resolve_without_result_multi_waiter() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id.clone());

        assert!(registry.resolve_without_result(&input_id));

        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::CompletedWithoutResult => {}
                other => panic!("Expected CompletedWithoutResult, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn already_resolved_handle() {
        let handle = CompletionHandle::already_resolved(CompletionOutcome::CompletedWithoutResult);
        match handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn multi_waiter_terminated_on_reset() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let h1 = registry.register(input_id.clone());
        let h2 = registry.register(input_id);

        registry.resolve_all_terminated("runtime reset");

        for handle in [h1, h2] {
            match handle.wait().await {
                CompletionOutcome::RuntimeTerminated(r) => assert_eq!(r, "runtime reset"),
                other => panic!("Expected RuntimeTerminated, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn resolve_without_result_nonexistent_returns_false() {
        let mut registry = CompletionRegistry::new();
        assert!(!registry.resolve_without_result(&InputId::new()));
    }

    #[tokio::test]
    async fn resolve_only_affects_matching_input() {
        let mut registry = CompletionRegistry::new();
        let first = InputId::new();
        let second = InputId::new();
        let h_first = registry.register(first.clone());
        let h_second = registry.register(second.clone());

        assert!(registry.resolve_completed(&first, make_run_result()));

        // The other input's waiter must still be registered.
        assert!(registry.has_pending());
        assert_eq!(registry.pending_count(), 1);

        match h_first.wait().await {
            CompletionOutcome::Completed(r) => assert_eq!(r.text, "hello"),
            other => panic!("Expected Completed, got {other:?}"),
        }

        assert!(registry.resolve_abandoned(&second, "superseded".into()));
        assert!(!registry.has_pending());

        match h_second.wait().await {
            CompletionOutcome::Abandoned(reason) => assert_eq!(reason, "superseded"),
            other => panic!("Expected Abandoned, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dropped_handle_does_not_block_other_waiters() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let dropped = registry.register(input_id.clone());
        let kept = registry.register(input_id.clone());

        drop(dropped);

        // The dropped waiter is still counted until the input is resolved.
        assert_eq!(registry.pending_count(), 2);
        assert!(registry.resolve_without_result(&input_id));
        assert!(!registry.has_pending());

        match kept.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn second_resolution_returns_false_and_first_wins() {
        let mut registry = CompletionRegistry::new();
        let input_id = InputId::new();
        let handle = registry.register(input_id.clone());

        assert!(registry.resolve_without_result(&input_id));
        assert!(!registry.resolve_abandoned(&input_id, "late".into()));
        assert!(!registry.resolve_completed(&input_id, make_run_result()));

        match handle.wait().await {
            CompletionOutcome::CompletedWithoutResult => {}
            other => panic!("Expected CompletedWithoutResult, got {other:?}"),
        }
    }
}