use std::{
    any::TypeId,
    collections::{HashMap, HashSet},
    sync::Arc,
    time::{Duration, Instant},
};
use serde_json::Value;
use tokio::{
    sync::Mutex,
    task::{JoinHandle, JoinSet},
};
use tracing::{error, info, warn};
use crate::{Context, Error, Job, Output, State};
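/// Lifecycle of a [`Worker`]: not yet started, driving the job, or finished.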
enum Mode<S: State> {
    Init { job: Job<S>, state: S },
    Running(Instant, JoinHandle<Output>),
    Done(Output),
}
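/// Runs a [`Job`] on a background Tokio task and exposes its per-node progress.
/// Cloning is cheap; clones share the same underlying state.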
#[derive(Clone)]
pub struct Worker<S: State> {
    out: Arc<Mutex<HashMap<&'static str, NodeState>>>,
    mode: Arc<Mutex<Option<Mode<S>>>>,
}
impl<S: State> Worker<S> {
    pub fn new(job: Job<S>, state: S) -> Self {
        Self {
            out: Arc::new(Mutex::new(HashMap::new())),
            mode: Arc::new(Mutex::new(Some(Mode::Init { job, state }))),
        }
    }
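    /// Spawns the job on the Tokio runtime.
    ///
    /// # Errors
    /// Returns an error if the worker has already been started.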
    #[allow(clippy::missing_panics_doc)]
    pub async fn run(&mut self) -> Result<(), &'static str> {
        let mut mode = self.mode.lock().await;
        // Only take the mode once we know it is still `Init`; otherwise a second
        // call would clobber a running or finished worker.
        if !matches!(mode.as_ref(), Some(Mode::Init { .. })) {
            return Err("Has already been started");
        }
        let Some(Mode::Init { job, state }) = std::mem::take(&mut *mode) else {
            unreachable!();
        };
        let t0 = Instant::now();
        let fut = run_job(job, state, self.out.clone(), t0);
        let handle = tokio::spawn(fut);
        *mode = Some(Mode::Running(t0, handle));
        Ok(())
    }
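    /// Stops a running job and records how long it ran. Does nothing if the
    /// worker has not been started or has already finished.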
    #[allow(clippy::missing_panics_doc)]
    pub async fn stop(&mut self) {
        let mut mode = self.mode.lock().await;
        let Mode::Running(t0, handle) = mode.as_ref().unwrap() else {
            return;
        };
        let duration = t0.elapsed();
        // Abort the driver task; dropping its `JoinSet` cancels any in-flight nodes.
        handle.abort();
        *mode = Some(Mode::Done(Output::Stopped { duration }));
    }
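    /// Waits for the job to finish and returns its final [`Output`].
    ///
    /// # Errors
    /// Returns an error if the worker has not been started.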
    #[allow(clippy::missing_panics_doc)]
    pub async fn get_output(&mut self) -> Result<Output, &'static str> {
        let mut mode = self.mode.lock().await;
        match mode.as_ref().unwrap() {
            Mode::Done(output) => return Ok(output.clone()),
            Mode::Init { .. } => return Err("Not running"),
            Mode::Running(_, _) => {}
        }
        let Mode::Running(_, handle) = std::mem::take(&mut *mode).unwrap() else {
            unreachable!();
        };
        let output = handle.await.expect("run_job should not be able to panic");
        *mode = Some(Mode::Done(output.clone()));
        Ok(output)
    }
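    /// Returns the payloads that are currently available (provided or produced),
    /// keyed by node name.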
    pub async fn data(&self) -> HashMap<String, Value> {
        let mut data = HashMap::new();
        for (&name, state) in self.out.lock().await.iter() {
            if let NodeState::Provided { value } | NodeState::Done { value, .. } = state {
                data.insert(name.to_string(), value.clone());
            }
        }
        data
    }
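    /// Returns a snapshot of every node's current [`NodeState`].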
    pub async fn status(&self) -> HashMap<&'static str, NodeState> {
        self.out.lock().await.clone()
    }
}
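/// Progress of a single node. All `start` and `duration` values are offsets
/// from the moment the job was started.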
#[derive(Debug, Clone)]
pub enum NodeState {
    Provided { value: Value },
    Running { start: Duration },
    Done {
        duration: Duration,
        retries: u32,
        value: Value,
    },
    Retrying { start: Duration, retries: u32 },
    Failed {
        duration: Duration,
        retries: u32,
        error: Error,
    },
}
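/// Drives the job's dependency graph: every node whose inputs are available is
/// spawned concurrently, results are recorded in `out`, and failing nodes are
/// retried when their error requests it.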
#[allow(clippy::too_many_lines)]
async fn run_job<S: State, T: ::std::hash::BuildHasher>(
    job: Job<S>,
    state: S,
    out: Arc<Mutex<HashMap<&'static str, NodeState, T>>>,
    t0: Instant,
) -> Output {
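    // Messages sent back from the tasks in the `JoinSet`.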
    enum Node {
        Done(TypeId, u32, Duration, Result<Value, Error>),
        Retry(TypeId, u32),
    }
    let nodes = job.nodes;
    let adj = job.adj;
    let mut results = HashMap::new();
    let mut handles = JoinSet::new();
    let mut abort_handles = HashMap::new();
    let mut pending: HashSet<TypeId> = nodes.keys().copied().collect();
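    // Record the payloads provided up front so dependent nodes can run immediately.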
    let mut o = out.lock().await;
    for (id, (name, data)) in job.provided {
        info!(name, "Provided");
        o.insert(name, NodeState::Provided { value: data.clone() });
        results.insert(id, data);
    }
    drop(o);
    let ctx = |retry, start| Context {
        retry,
        start,
        state: state.clone(),
    };
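    // Main scheduling loop: spawn every node whose dependencies are satisfied,
    // then handle the next task that finishes.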
    loop {
        let is_done = |i| results.contains_key(i);
        let get_payloads = |id| adj[&id].iter().map(|id| results[id].clone()).collect();
        let ready = pending.extract_if(|id| adj[id].iter().all(is_done));
        for id in ready {
            let payloads = get_payloads(id);
            let node = &nodes[&id];
            let producer = node.producer.clone();
            let start = t0.elapsed();
            let context = ctx(0, start);
            let state = NodeState::Running { start };
            out.lock().await.insert(node.name, state);
            info!(name = node.name, "Node start");
            let abort_handle = handles.spawn(async move {
                let result = producer(context, payloads).await;
                Node::Done(id, 0, t0.elapsed(), result)
            });
            abort_handles.insert(abort_handle.id(), id);
        }
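        // Wait for the next task to finish; when the set is empty, every node has completed.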
        let result = handles.join_next().await;
        let Some(result) = result else {
            let duration = t0.elapsed();
            info!(?duration, "Job done");
            return Output::Done { duration };
        };
        let result = match result {
            Ok(result) => result,
            // A join error here means a node task panicked: report it and bail out.
            Err(e) => {
                let duration = t0.elapsed();
                let id = abort_handles[&e.id()];
                let name = nodes[&id].name;
                error!(name, "Node panicked");
                return Output::NodePanic {
                    duration,
                    name,
                    error: format!("{e:?}"),
                };
            }
        };
        match result {
            Node::Done(id, retry, time, Ok(payload)) => {
                results.insert(id, payload.clone());
                let name = nodes[&id].name;
                let state = NodeState::Done {
                    duration: time,
                    retries: retry,
                    value: payload,
                };
                out.lock().await.insert(name, state);
                info!(name, "Node done");
            }
            Node::Done(id, retry, time, Err(e)) => {
                let name = nodes[&id].name;
                if let Some(retry_in) = e.retry_in {
                    warn!(name, retry, error = e.message, ?retry_in, "Node failed");
                    // Wait out the requested backoff, then ask for another attempt.
                    handles.spawn(async move {
                        tokio::time::sleep(retry_in).await;
                        Node::Retry(id, retry)
                    });
                } else {
                    let duration = t0.elapsed();
                    let msg = format!("Node {name} failed ({retry} retries): {}", e.message);
                    let state = NodeState::Failed {
                        duration: time,
                        retries: retry,
                        error: e,
                    };
                    out.lock().await.insert(name, state);
                    error!(name, "Node failed");
                    return Output::NodeFailed {
                        duration,
                        name,
                        error: msg,
                    };
                }
            }
            Node::Retry(id, mut retry) => {
                retry += 1;
                let payloads = get_payloads(id);
                let producer = nodes[&id].producer.clone();
                let start = t0.elapsed();
                let context = ctx(retry, start);
                let name = nodes[&id].name;
                let state = NodeState::Retrying { start, retries: retry };
                out.lock().await.insert(name, state);
                info!(name, retry, "Node retrying");
                // Track the new task so a panic during a retry is still attributed to this node.
                let abort_handle = handles.spawn(async move {
                    let result = producer(context, payloads).await;
                    Node::Done(id, retry, t0.elapsed(), result)
                });
                abort_handles.insert(abort_handle.id(), id);
            }
        }
    }
}