1use crate::store::Store;
6use std::sync::{Arc, Mutex};
7use std::time::Duration;
8use tokio::sync::mpsc;
9
10pub trait StreamWriterTrait: Send + Sync + 'static {
17 fn emit_custom(&self, node: &str, data: serde_json::Value);
19}
20
21impl StreamWriterTrait for mpsc::UnboundedSender<(String, serde_json::Value)> {
22 fn emit_custom(&self, node: &str, data: serde_json::Value) {
23 let _ = self.send((node.to_string(), data));
24 }
25}
26
27impl std::fmt::Debug for dyn StreamWriterTrait {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("StreamWriterTrait").finish_non_exhaustive()
30 }
31}
32
33#[derive(Clone)]
57pub struct Runtime<C: Clone + Send + Sync + 'static = ()> {
58 pub context: C,
60
61 pub store: Option<Arc<dyn Store>>,
63
64 pub heartbeat: Heartbeat,
66
67 pub previous: Option<serde_json::Value>,
69
70 pub execution_info: Option<ExecutionInfo>,
72
73 pub control: Option<RunControl>,
75
76 pub stream_writer: Option<Arc<dyn StreamWriterTrait>>,
83}
84
85impl<C: Clone + Send + Sync + 'static> std::fmt::Debug for Runtime<C>
86where
87 C: std::fmt::Debug,
88{
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("Runtime")
91 .field("context", &self.context)
92 .field("store", &self.store.as_ref().map(|_| "<Store>"))
93 .field("heartbeat", &self.heartbeat)
94 .field("previous", &self.previous)
95 .field("execution_info", &self.execution_info)
96 .field("control", &self.control)
97 .field("stream_writer", &self.stream_writer)
98 .finish()
99 }
100}
101
102impl<C: Clone + Send + Sync + 'static> Runtime<C> {
103 #[must_use]
105 pub fn new() -> Self
106 where
107 C: Default,
108 {
109 Self {
110 context: C::default(),
111 store: None,
112 heartbeat: Heartbeat::default(),
113 previous: None,
114 execution_info: None,
115 control: None,
116 stream_writer: None,
117 }
118 }
119
120 #[must_use]
122 pub fn with_context(context: C) -> Self {
123 Self {
124 context,
125 store: None,
126 heartbeat: Heartbeat::default(),
127 previous: None,
128 execution_info: None,
129 control: None,
130 stream_writer: None,
131 }
132 }
133
134 pub fn set_execution_info(&mut self, info: ExecutionInfo) {
139 self.execution_info = Some(info);
140 }
141
142 #[must_use]
148 pub fn managed_values(&self) -> ManagedValues {
149 let Some(info) = self.execution_info.as_ref() else {
150 return ManagedValues {
151 is_last_step: false,
152 remaining_steps: 25,
153 };
154 };
155
156 let remaining = info.recursion_limit.saturating_sub(info.step);
157
158 ManagedValues {
159 is_last_step: remaining <= 1,
160 remaining_steps: u32::try_from(remaining).unwrap_or(u32::MAX),
161 }
162 }
163
164 #[must_use]
169 pub const fn heartbeat(&self) -> &Heartbeat {
170 &self.heartbeat
171 }
172}
173
174impl Default for Runtime<()>
175where
176 (): std::fmt::Debug,
177{
178 fn default() -> Self {
179 Self::new()
180 }
181}
182
/// Handle for sending liveness pings over an unbounded `()` channel.
pub struct Heartbeat {
    /// Sending half; `ping` pushes `()` into it.
    tx: tokio::sync::mpsc::UnboundedSender<()>,
    /// Receiver retained only by `Heartbeat::default()`; `new`, `new_pair`, and
    /// `clone` all store `None`. Nothing in this file reads it — it exists to
    /// keep a receiving end open.
    _rx: Option<tokio::sync::mpsc::UnboundedReceiver<()>>,
}
207
208impl Clone for Heartbeat {
209 fn clone(&self) -> Self {
210 Self {
211 tx: self.tx.clone(),
212 _rx: None,
216 }
217 }
218}
219
220impl std::fmt::Debug for Heartbeat {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 f.debug_struct("Heartbeat")
223 .field("tx", &"<UnboundedSender>")
224 .finish()
225 }
226}
227
228impl Heartbeat {
229 #[must_use]
231 pub const fn new(tx: tokio::sync::mpsc::UnboundedSender<()>) -> Self {
232 Self { tx, _rx: None }
233 }
234
235 #[must_use]
241 pub fn new_pair() -> (Self, HeartbeatWatcher) {
242 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
243 let watcher = HeartbeatWatcher::new(rx);
244 (Self { tx, _rx: None }, watcher)
245 }
246
247 pub fn ping(&self) -> Result<(), tokio::sync::mpsc::error::SendError<()>> {
253 self.tx.send(())
254 }
255}
256
257impl Default for Heartbeat {
258 fn default() -> Self {
259 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
260 Self { tx, _rx: Some(rx) }
261 }
262}
263
/// Receiving side of a heartbeat channel; it tracks when a ping was last observed.
pub struct HeartbeatWatcher {
    /// Receiver that incoming pings arrive on.
    rx: tokio::sync::mpsc::UnboundedReceiver<()>,
    /// When `is_alive` last drained a ping, or construction time if none yet.
    last_beat: crate::time::Instant,
}
289
290impl std::fmt::Debug for HeartbeatWatcher {
291 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292 f.debug_struct("HeartbeatWatcher")
293 .field("last_beat", &self.last_beat)
294 .finish_non_exhaustive()
295 }
296}
297
298impl HeartbeatWatcher {
299 #[must_use]
301 pub fn new(rx: tokio::sync::mpsc::UnboundedReceiver<()>) -> Self {
302 Self {
303 rx,
304 last_beat: crate::time::Instant::now(),
305 }
306 }
307
308 #[must_use]
317 pub fn is_alive(&mut self, idle_timeout: Duration) -> bool {
318 while self.rx.try_recv().is_ok() {
320 self.last_beat = crate::time::Instant::now();
321 }
322 self.last_beat.elapsed() < idle_timeout
323 }
324}
325
/// Metadata describing the step currently being executed.
#[derive(Clone, Debug)]
pub struct ExecutionInfo {
    /// Checkpoint identifier for this step.
    pub checkpoint_id: String,

    /// Checkpoint namespace (the tests in this file use `"default"`).
    pub checkpoint_ns: String,

    /// Identifier of the task being executed.
    pub task_id: String,

    /// Current step index; `Runtime::managed_values` subtracts it from `recursion_limit`.
    pub step: usize,

    /// Step budget for the run; see `Runtime::managed_values`.
    pub recursion_limit: usize,

    /// Optional thread identifier.
    pub thread_id: Option<String>,

    /// Optional run identifier.
    pub run_id: Option<String>,

    /// Attempt counter for the node — presumably 1-based (the tests use 1); confirm.
    pub node_attempt: u32,

    /// Time of the node's first attempt — presumably seconds since the Unix epoch; TODO confirm units.
    pub node_first_attempt_time: Option<f64>,
}
359
/// Step-budget snapshot, as computed by `Runtime::managed_values`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so whole snapshots can be compared or used
/// as keys, not just inspected field by field.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ManagedValues {
    /// `true` once at most one step remains before the recursion limit.
    pub is_last_step: bool,

    /// Steps left before the recursion limit (saturating at 0, clamped to `u32::MAX`).
    pub remaining_steps: u32,
}
371
/// Shared slot recording whether (and why) a drain of the run has been requested.
///
/// Clones share one `Arc`-backed slot, so a request made through any clone is
/// visible through all of them.
#[derive(Clone, Debug)]
pub struct RunControl {
    drain_reason: Arc<Mutex<Option<String>>>,
}

impl RunControl {
    /// Creates a control with no drain requested.
    #[must_use]
    pub fn new() -> Self {
        Self {
            drain_reason: Arc::new(Mutex::new(None)),
        }
    }

    /// Locks the reason slot, recovering the data if an earlier holder panicked.
    ///
    /// Every writer replaces the `Option<String>` wholesale, so a poisoned lock
    /// cannot expose a half-updated value. Recovering keeps one stray panic from
    /// turning every later call into a panic as well.
    fn reason_slot(&self) -> std::sync::MutexGuard<'_, Option<String>> {
        self.drain_reason
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records a drain request with `reason`, replacing any earlier reason.
    pub fn request_drain(&self, reason: &str) {
        *self.reason_slot() = Some(reason.to_owned());
    }

    /// Returns `true` if a drain has been requested through any clone.
    #[must_use]
    pub fn drain_requested(&self) -> bool {
        self.reason_slot().is_some()
    }

    /// Returns a copy of the most recent drain reason, if any.
    #[must_use]
    pub fn drain_reason(&self) -> Option<String> {
        self.reason_slot().clone()
    }
}

impl Default for RunControl {
    fn default() -> Self {
        Self::new()
    }
}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a unit-context runtime whose execution info sits at `step` of `recursion_limit`.
    fn runtime_at(step: usize, recursion_limit: usize) -> Runtime<()> {
        let mut runtime = Runtime::<()>::new();
        runtime.set_execution_info(ExecutionInfo {
            checkpoint_id: "cp-1".to_string(),
            checkpoint_ns: "default".to_string(),
            task_id: "task-1".to_string(),
            step,
            recursion_limit,
            thread_id: None,
            run_id: None,
            node_attempt: 1,
            node_first_attempt_time: None,
        });
        runtime
    }

    #[test]
    fn test_default_managed_values_no_execution_info() {
        let values = Runtime::<()>::new().managed_values();
        assert!(!values.is_last_step, "default should not be last step");
        assert_eq!(
            values.remaining_steps, 25,
            "default remaining steps should be 25"
        );
    }

    #[test]
    fn test_managed_values_early_step() {
        let values = runtime_at(3, 25).managed_values();
        assert!(!values.is_last_step, "early step should not be last step");
        assert_eq!(values.remaining_steps, 22, "remaining: 25 - 3 = 22");
    }

    #[test]
    fn test_managed_values_last_step() {
        let values = runtime_at(24, 25).managed_values();
        assert!(values.is_last_step, "step 24 of 25 should be last step");
        assert_eq!(values.remaining_steps, 1, "remaining: 25 - 24 = 1");
    }

    #[test]
    fn test_managed_values_past_recursion_limit() {
        let values = runtime_at(25, 25).managed_values();
        assert!(
            values.is_last_step,
            "step >= recursion_limit should be last step"
        );
        assert_eq!(
            values.remaining_steps, 0,
            "no remaining steps when at limit"
        );
    }

    #[test]
    fn test_managed_values_custom_recursion_limit() {
        let values = runtime_at(8, 10).managed_values();
        assert!(!values.is_last_step, "step 8 of 10 should not be last step");
        assert_eq!(values.remaining_steps, 2, "remaining: 10 - 8 = 2");
    }

    #[test]
    fn test_managed_values_exact_countdown() {
        let values = runtime_at(9, 10).managed_values();
        assert!(values.is_last_step, "step 9 of 10 should be last step");
        assert_eq!(values.remaining_steps, 1, "remaining: 10 - 9 = 1");
    }
}
560
561