use crate::store::Store;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::mpsc;
/// Sink for custom stream events, each attributed to a node name.
///
/// The `Send + Sync + 'static` supertraits let implementors be stored and
/// shared as `Arc<dyn StreamWriterTrait>`.
pub trait StreamWriterTrait: Send + Sync + 'static {
    /// Emits `data` as a custom event tagged with `node`.
    fn emit_custom(&self, node: &str, data: serde_json::Value);
}

/// A tokio unbounded sender acts as a stream writer by forwarding each event
/// as a `(node, data)` tuple.
impl StreamWriterTrait for mpsc::UnboundedSender<(String, serde_json::Value)> {
    fn emit_custom(&self, node: &str, data: serde_json::Value) {
        let event = (node.to_owned(), data);
        // Best-effort delivery: a failed send (e.g. no receiver) is ignored.
        let _ = self.send(event);
    }
}

/// Opaque `Debug` so an `Arc<dyn StreamWriterTrait>` can be printed with `{:?}`.
impl std::fmt::Debug for dyn StreamWriterTrait {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("StreamWriterTrait { .. }")
    }
}
/// Execution-time context handed to a running node.
///
/// `C` is a caller-chosen context value (defaults to `()`); the remaining
/// fields are optional collaborators that may or may not be attached.
#[derive(Clone)]
pub struct Runtime<C = ()>
where
    C: Clone + Send + Sync + 'static,
{
    /// Caller-supplied context value.
    pub context: C,
    /// Optional shared store backend.
    pub store: Option<Arc<dyn Store>>,
    /// Heartbeat handle used to signal liveness.
    pub heartbeat: Heartbeat,
    /// Optional JSON value. NOTE(review): presumably the previous result for
    /// this thread/node; confirm against the code that sets it.
    pub previous: Option<serde_json::Value>,
    /// Step/checkpoint metadata; read by `Runtime::managed_values`.
    pub execution_info: Option<ExecutionInfo>,
    /// Optional shared drain-request handle.
    pub control: Option<RunControl>,
    /// Optional sink for custom stream events.
    pub stream_writer: Option<Arc<dyn StreamWriterTrait>>,
}

impl<C> std::fmt::Debug for Runtime<C>
where
    C: Clone + Send + Sync + 'static + std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn Store` carries no Debug bound, so only report whether one is set.
        let store_marker: Option<&str> = self.store.as_ref().map(|_| "<Store>");
        let mut out = f.debug_struct("Runtime");
        out.field("context", &self.context);
        out.field("store", &store_marker);
        out.field("heartbeat", &self.heartbeat);
        out.field("previous", &self.previous);
        out.field("execution_info", &self.execution_info);
        out.field("control", &self.control);
        out.field("stream_writer", &self.stream_writer);
        out.finish()
    }
}
impl<C: Clone + Send + Sync + 'static> Runtime<C> {
    /// `remaining_steps` reported by [`Runtime::managed_values`] when no
    /// [`ExecutionInfo`] has been attached.
    const DEFAULT_REMAINING_STEPS: u32 = 25;

    /// Creates a runtime whose context is `C::default()`.
    ///
    /// Equivalent to `Runtime::with_context(C::default())`.
    #[must_use]
    pub fn new() -> Self
    where
        C: Default,
    {
        Self::with_context(C::default())
    }

    /// Creates a runtime around `context` with a default [`Heartbeat`] and no
    /// store, previous value, execution info, run control, or stream writer.
    #[must_use]
    pub fn with_context(context: C) -> Self {
        Self {
            context,
            store: None,
            heartbeat: Heartbeat::default(),
            previous: None,
            execution_info: None,
            control: None,
            stream_writer: None,
        }
    }

    /// Attaches `info`, replacing any previously set execution info.
    pub fn set_execution_info(&mut self, info: ExecutionInfo) {
        self.execution_info = Some(info);
    }

    /// Derives step-budget values from the attached [`ExecutionInfo`].
    ///
    /// `remaining_steps` is `recursion_limit - step`, saturating at 0 and
    /// clamped to `u32::MAX`. `is_last_step` is true once at most one step
    /// remains. Without execution info this reports `is_last_step = false`
    /// and `remaining_steps = 25`.
    #[must_use]
    pub fn managed_values(&self) -> ManagedValues {
        let Some(info) = self.execution_info.as_ref() else {
            return ManagedValues {
                is_last_step: false,
                remaining_steps: Self::DEFAULT_REMAINING_STEPS,
            };
        };
        let remaining = info.recursion_limit.saturating_sub(info.step);
        ManagedValues {
            is_last_step: remaining <= 1,
            // Saturate instead of truncating when usize is wider than u32.
            remaining_steps: u32::try_from(remaining).unwrap_or(u32::MAX),
        }
    }

    /// Returns the heartbeat handle that can be `ping`ed to signal liveness.
    #[must_use]
    pub const fn heartbeat(&self) -> &Heartbeat {
        &self.heartbeat
    }
}
/// `Runtime<()>` defaults to [`Runtime::new`]: a unit context with none of
/// the optional collaborators attached.
impl Default for Runtime<()> {
    fn default() -> Self {
        Self::new()
    }
}
/// Sending half of a liveness signal.
///
/// Each [`Heartbeat::ping`] pushes a unit message that a paired
/// [`HeartbeatWatcher`] can observe.
pub struct Heartbeat {
    tx: tokio::sync::mpsc::UnboundedSender<()>,
    /// Receiver held only by heartbeats built with `Default`. NOTE(review):
    /// presumably kept so `ping` keeps succeeding without a watcher. Nothing
    /// visible in this module reads from it, so pings on a default heartbeat
    /// appear to accumulate in the unbounded channel. Confirm this is intended.
    _rx: Option<tokio::sync::mpsc::UnboundedReceiver<()>>,
}

impl Clone for Heartbeat {
    /// Clones share the sender; the held receiver (if any) stays with the original.
    fn clone(&self) -> Self {
        Self::new(self.tx.clone())
    }
}

impl std::fmt::Debug for Heartbeat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Heartbeat");
        out.field("tx", &"<UnboundedSender>");
        out.finish()
    }
}

impl Heartbeat {
    /// Wraps an existing sender; no receiver is retained.
    #[must_use]
    pub const fn new(tx: tokio::sync::mpsc::UnboundedSender<()>) -> Self {
        Self { tx, _rx: None }
    }

    /// Creates a connected heartbeat and its watcher over a fresh channel.
    #[must_use]
    pub fn new_pair() -> (Self, HeartbeatWatcher) {
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        let watcher = HeartbeatWatcher::new(receiver);
        (Self::new(sender), watcher)
    }

    /// Sends one heartbeat.
    ///
    /// # Errors
    /// Returns the channel's `SendError` when the message cannot be delivered.
    pub fn ping(&self) -> Result<(), tokio::sync::mpsc::error::SendError<()>> {
        self.tx.send(())
    }
}

impl Default for Heartbeat {
    /// Builds a standalone heartbeat that keeps its own receiver.
    fn default() -> Self {
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        Self {
            tx: sender,
            _rx: Some(receiver),
        }
    }
}
/// Receiving half of a heartbeat; remembers when the last ping was observed.
pub struct HeartbeatWatcher {
    rx: tokio::sync::mpsc::UnboundedReceiver<()>,
    /// Creation time, or the time the most recent ping was drained by `is_alive`.
    last_beat: crate::time::Instant,
}

impl std::fmt::Debug for HeartbeatWatcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("HeartbeatWatcher");
        out.field("last_beat", &self.last_beat);
        out.finish_non_exhaustive()
    }
}

impl HeartbeatWatcher {
    /// Starts watching `rx`, treating "now" as the most recent beat.
    #[must_use]
    pub fn new(rx: tokio::sync::mpsc::UnboundedReceiver<()>) -> Self {
        let last_beat = crate::time::Instant::now();
        Self { rx, last_beat }
    }

    /// Drains pending pings and reports whether the last beat is younger
    /// than `idle_timeout`.
    ///
    /// Every queued ping refreshes the timestamp; draining stops at the first
    /// `try_recv` failure (empty or disconnected channel).
    #[must_use]
    pub fn is_alive(&mut self, idle_timeout: Duration) -> bool {
        while let Ok(()) = self.rx.try_recv() {
            self.last_beat = crate::time::Instant::now();
        }
        let since_last_beat = self.last_beat.elapsed();
        since_last_beat < idle_timeout
    }
}
/// Metadata describing where the current node invocation sits in a run.
#[derive(Clone, Debug)]
pub struct ExecutionInfo {
    /// Checkpoint identifier.
    pub checkpoint_id: String,
    /// Checkpoint namespace.
    pub checkpoint_ns: String,
    /// Identifier of the task being executed.
    pub task_id: String,
    /// Current step; `Runtime::managed_values` subtracts it from `recursion_limit`.
    pub step: usize,
    /// Step budget for the run.
    pub recursion_limit: usize,
    /// Optional thread identifier.
    pub thread_id: Option<String>,
    /// Optional run identifier.
    pub run_id: Option<String>,
    /// Attempt counter for this node. NOTE(review): appears 1-based; confirm.
    pub node_attempt: u32,
    /// Time of the node's first attempt. NOTE(review): units/epoch unverified.
    pub node_first_attempt_time: Option<f64>,
}
/// Step-budget snapshot produced by `Runtime::managed_values`.
#[derive(Clone, Copy, Debug)]
pub struct ManagedValues {
    /// True when at most one step of the budget remains.
    pub is_last_step: bool,
    /// Steps left before the recursion limit, saturating at zero.
    pub remaining_steps: u32,
}
/// Shared, clonable handle for requesting that a run drain.
///
/// Clones share one `Arc`-backed slot, so a drain requested through any
/// clone is visible through all of them.
#[derive(Clone, Debug)]
pub struct RunControl {
    drain_reason: Arc<Mutex<Option<String>>>,
}

impl RunControl {
    /// Creates a handle with no drain requested.
    #[must_use]
    pub fn new() -> Self {
        Self {
            drain_reason: Arc::new(Mutex::new(None)),
        }
    }

    /// Locks the shared reason slot and recovers the data if the mutex is poisoned.
    ///
    /// The guarded `Option<String>` has no invariant a panicking holder could
    /// break. Ignoring poisoning keeps the drain handle usable instead of
    /// panicking on every later call.
    fn reason_slot(&self) -> std::sync::MutexGuard<'_, Option<String>> {
        self.drain_reason
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records a drain request with `reason`, overwriting any earlier reason.
    pub fn request_drain(&self, reason: &str) {
        let reason = reason.to_string();
        *self.reason_slot() = Some(reason);
    }

    /// Returns true once any clone has called [`RunControl::request_drain`].
    #[must_use]
    pub fn drain_requested(&self) -> bool {
        self.reason_slot().is_some()
    }

    /// Returns a copy of the most recent drain reason, if any.
    #[must_use]
    pub fn drain_reason(&self) -> Option<String> {
        self.reason_slot().clone()
    }
}

impl Default for RunControl {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a unit-context runtime whose execution info sits at `step`
    /// of `recursion_limit`; all other metadata is fixed filler.
    fn runtime_at(step: usize, recursion_limit: usize) -> Runtime<()> {
        let mut runtime = Runtime::<()>::new();
        runtime.set_execution_info(ExecutionInfo {
            checkpoint_id: "cp-1".to_string(),
            checkpoint_ns: "default".to_string(),
            task_id: "task-1".to_string(),
            step,
            recursion_limit,
            thread_id: None,
            run_id: None,
            node_attempt: 1,
            node_first_attempt_time: None,
        });
        runtime
    }

    #[test]
    fn test_default_managed_values_no_execution_info() {
        let values = Runtime::<()>::new().managed_values();
        assert!(!values.is_last_step, "default should not be last step");
        assert_eq!(
            values.remaining_steps, 25,
            "default remaining steps should be 25"
        );
    }

    #[test]
    fn test_managed_values_early_step() {
        let values = runtime_at(3, 25).managed_values();
        assert!(!values.is_last_step, "early step should not be last step");
        assert_eq!(values.remaining_steps, 22, "remaining: 25 - 3 = 22");
    }

    #[test]
    fn test_managed_values_last_step() {
        let values = runtime_at(24, 25).managed_values();
        assert!(values.is_last_step, "step 24 of 25 should be last step");
        assert_eq!(values.remaining_steps, 1, "remaining: 25 - 24 = 1");
    }

    #[test]
    fn test_managed_values_past_recursion_limit() {
        let values = runtime_at(25, 25).managed_values();
        assert!(
            values.is_last_step,
            "step >= recursion_limit should be last step"
        );
        assert_eq!(
            values.remaining_steps, 0,
            "no remaining steps when at limit"
        );
    }

    #[test]
    fn test_managed_values_custom_recursion_limit() {
        let values = runtime_at(8, 10).managed_values();
        assert!(!values.is_last_step, "step 8 of 10 should not be last step");
        assert_eq!(values.remaining_steps, 2, "remaining: 10 - 8 = 2");
    }

    #[test]
    fn test_managed_values_exact_countdown() {
        let values = runtime_at(9, 10).managed_values();
        assert!(values.is_last_step, "step 9 of 10 should be last step");
        assert_eq!(values.remaining_steps, 1, "remaining: 10 - 9 = 1");
    }
}