use crate::{
protos::temporal::api::common::v1::WorkflowExecution,
protos::temporal::api::history::v1::History,
workflow::{NextWfActivation, Result, WorkflowError, WorkflowManager},
};
use crossbeam::channel::{bounded, unbounded, Receiver, Select, Sender, TryRecvError};
use dashmap::DashMap;
use parking_lot::Mutex;
use std::{
fmt::Debug,
thread::{self, JoinHandle},
};
use tracing::Span;
/// Owns a dedicated thread on which workflow managers live, and routes requests to it.
///
/// Callers never touch a `WorkflowManager` directly: creation requests travel over
/// `machine_creator`, and mutations are boxed closures sent over the per-run channels stored in
/// `machines`.
pub(crate) struct WorkflowConcurrencyManager {
/// Maps a run id to the channel used to send mutation closures to that run's workflow manager.
machines: DashMap<String, MachineMutationSender>,
/// Handle of the thread spawned in `new`; taken (leaving `None`) by `shutdown`.
wf_thread: Mutex<Option<JoinHandle<()>>>,
/// Sends requests to create a new workflow manager to the workflow thread.
machine_creator: Sender<MachineCreatorMsg>,
/// Signals the workflow thread to exit its loop.
shutdown_chan: Sender<bool>,
}
/// Sending half of a per-workflow channel carrying boxed closures that mutate a
/// `WorkflowManager`.
type MachineMutationSender = Sender<Box<dyn FnOnce(&mut WorkflowManager) + Send>>;
/// Receiving half of the per-workflow mutation channel; kept by the workflow thread alongside
/// the `WorkflowManager` the closures are applied to.
type MachineMutationReceiver = Receiver<Box<dyn FnOnce(&mut WorkflowManager) + Send>>;
/// Request sent to the workflow thread asking it to construct a new `WorkflowManager`.
struct MachineCreatorMsg {
/// History passed to `WorkflowManager::new`.
history: History,
/// Workflow execution passed to `WorkflowManager::new`.
workflow_execution: WorkflowExecution,
/// Channel on which the workflow thread sends back the outcome of the creation attempt.
resp_chan: Sender<MachineCreatorResponseMsg>,
/// Tracing span entered by the workflow thread while it handles this request.
span: Span,
}
/// Reply to a `MachineCreatorMsg`: on success, the first activation produced by the new workflow
/// manager together with the sender used to mutate it afterwards; otherwise the creation error.
type MachineCreatorResponseMsg = Result<(NextWfActivation, MachineMutationSender)>;
impl WorkflowConcurrencyManager {
pub fn new() -> Self {
let (machine_creator, create_rcv) = unbounded::<MachineCreatorMsg>();
let (shutdown_chan, shutdown_rx) = bounded(1);
let wf_thread = thread::spawn(move || {
WorkflowConcurrencyManager::workflow_thread(create_rcv, shutdown_rx)
});
Self {
machines: Default::default(),
wf_thread: Mutex::new(Some(wf_thread)),
machine_creator,
shutdown_chan,
}
}
/// Returns `true` if a mutation channel is currently registered for `run_id`.
pub fn exists(&self, run_id: &str) -> bool {
self.machines.contains_key(run_id)
}
/// Feeds `history` to the workflow manager for `run_id`, creating that manager first if the
/// run is not yet known.
///
/// * Known run: the history is applied through `access` and the resulting activation (if
///   any) is returned.
/// * Unknown run: a creation request is sent to the workflow thread and this call blocks for
///   the reply. On success the returned mutation sender is stored under `run_id` and the first
///   activation is returned as `Some`.
///
/// # Errors
/// Returns whatever error the workflow manager reports while being created or while
/// consuming the history.
///
/// # Panics
/// Panics if the creation channel or the reply channel has been disconnected.
pub fn create_or_update(
    &self,
    run_id: &str,
    history: History,
    workflow_execution: WorkflowExecution,
) -> Result<Option<NextWfActivation>> {
    let span = debug_span!("create_or_update machines", %run_id);

    // Known run: hand the new history to the existing manager.
    if self.exists(run_id) {
        return self.access(run_id, move |wfm: &mut WorkflowManager| {
            let _entered = span.enter();
            wfm.feed_history_from_server(history)
        });
    }

    // Unknown run: ask the workflow thread to build a manager, then wait for its reply.
    let (reply_tx, reply_rx) = bounded(1);
    let request = MachineCreatorMsg {
        history,
        workflow_execution,
        resp_chan: reply_tx,
        span,
    };
    self.machine_creator
        .send(request)
        .expect("wfm creation channel can't be dropped if we are inside this method");
    let reply = reply_rx
        .recv()
        .expect("wfm create resp channel can't be dropped, it is in this stackframe");
    let (activation, machine_sender) = reply?;
    self.machines.insert(run_id.to_string(), machine_sender);
    Ok(Some(activation))
}
/// Runs `mutator` against the workflow manager for `run_id` on the workflow thread and
/// returns its result.
///
/// The closure is wrapped so that its result is sent back over a one-shot channel; this call
/// blocks until that result arrives.
///
/// # Errors
/// Returns `WorkflowError::MissingMachine` when no machine is registered for `run_id`,
/// otherwise whatever `mutator` itself returns.
///
/// # Panics
/// Panics if the machine's mutation channel is disconnected, or if the response channel is
/// dropped without a reply being sent.
pub fn access<F, Fout>(&self, run_id: &str, mutator: F) -> Result<Fout>
where
    F: FnOnce(&mut WorkflowManager) -> Result<Fout> + Send + 'static,
    Fout: Send + Debug + 'static,
{
    let machine = match self.machines.get(run_id) {
        Some(entry) => entry,
        None => {
            return Err(WorkflowError::MissingMachine {
                run_id: run_id.to_string(),
            })
        }
    };

    let (result_tx, result_rx) = bounded(1);
    // The send result is deliberately ignored: if nobody is waiting any more there is nothing
    // useful to do with the value.
    let job = move |wfm: &mut WorkflowManager| {
        let _ = result_tx.send(mutator(wfm));
    };
    machine
        .send(Box::new(job))
        .expect("wfm mutation send can't fail, if it does a wfm is missing from their thread");
    result_rx
        .recv()
        .expect("wfm access resp channel can't be dropped, it is in this stackframe")
}
/// Tells the workflow thread to stop and waits for it to exit.
///
/// Only the first call does any work: it takes the join handle out of `wf_thread`, so later
/// calls find `None` and return immediately without sending another shutdown signal.
///
/// # Panics
/// Panics if the workflow thread itself panicked.
pub fn shutdown(&self) {
    // The lock is held for the whole operation, so concurrent callers wait for the join.
    let mut wf_thread = self.wf_thread.lock();
    if let Some(handle) = wf_thread.take() {
        // The send result is ignored on purpose; a failure here presumably means the thread
        // has already dropped its receiver, and the join below still applies.
        let _ = self.shutdown_chan.send(true);
        handle
            .join()
            .expect("Workflow manager thread should shut down cleanly");
    }
}
/// Forgets the workflow identified by `run_id`, if it is known.
///
/// This drops the map's mutation sender for that run. When the workflow thread later finds the
/// matching receiver disconnected, it discards the corresponding workflow manager (see
/// `handle_access_msg`).
pub fn evict(&self, run_id: &str) {
self.machines.remove(run_id);
}
fn workflow_thread(create_rcv: Receiver<MachineCreatorMsg>, shutdown_rx: Receiver<bool>) {
let mut machine_rcvs: Vec<(MachineMutationReceiver, WorkflowManager)> = vec![];
loop {
let mut sel = Select::new();
sel.recv(&shutdown_rx);
sel.recv(&create_rcv);
for (rcv, _) in machine_rcvs.iter() {
sel.recv(rcv);
}
let index = sel.ready();
if index == 0 {
break;
} else if index == 1 {
let maybe_create_chan_msg = create_rcv.try_recv();
let should_break = WorkflowConcurrencyManager::handle_creation_message(
&mut machine_rcvs,
maybe_create_chan_msg,
);
if should_break {
break;
}
} else {
WorkflowConcurrencyManager::handle_access_msg(index - 2, &mut machine_rcvs)
}
}
}
/// Handles the outcome of a `try_recv` on the creation channel.
///
/// For a received request, builds a `WorkflowManager`, asks it for its first activation and
/// replies on the request's `resp_chan`. On success the new manager and the receiving half of
/// a fresh mutation channel are appended to `machine_rcvs`, and the sending half is returned to
/// the requester. A manager that yields no activation is reported as
/// `MachineWasCreatedWithNoActivations` and is not kept.
///
/// Returns `true` only when the creation channel is disconnected, meaning the workflow thread
/// should stop; returns `false` otherwise.
///
/// # Panics
/// Panics if the reply cannot be sent because the requester dropped its receiver.
fn handle_creation_message(
    machine_rcvs: &mut Vec<(MachineMutationReceiver, WorkflowManager)>,
    maybe_create_chan_msg: Result<MachineCreatorMsg, TryRecvError>,
) -> bool {
    let request = match maybe_create_chan_msg {
        Ok(request) => request,
        // Nothing was actually waiting on the channel; keep looping.
        Err(TryRecvError::Empty) => return false,
        Err(TryRecvError::Disconnected) => {
            warn!(
                "Sending side of workflow machine creator was dropped. Likely the \
                 WorkflowConcurrencyManager was dropped. This indicates a failure to call \
                 shutdown."
            );
            return true;
        }
    };

    let MachineCreatorMsg {
        history,
        workflow_execution,
        resp_chan,
        span,
    } = request;
    let _entered = span.enter();

    let created = WorkflowManager::new(history, workflow_execution)
        .map_err(Into::into)
        .and_then(|mut wfm| Ok((wfm.get_next_activation()?, wfm)));
    let reply = match created {
        Err(e) => Err(e),
        Ok((None, wfm)) => Err(WorkflowError::MachineWasCreatedWithNoActivations {
            run_id: wfm.machines.run_id,
        }),
        Ok((Some(activation), wfm)) => {
            let (machine_sender, machine_rcv) = unbounded();
            machine_rcvs.push((machine_rcv, wfm));
            Ok((activation, machine_sender))
        }
    };
    resp_chan
        .send(reply)
        .expect("wfm create resp rx side can't be dropped");
    false
}
/// Services the mutation channel of the machine at position `index` in `machine_rcvs`.
///
/// A received closure is run against that machine's `WorkflowManager`. If the channel is
/// disconnected, the machine is removed from `machine_rcvs`. An empty channel is a no-op.
///
/// # Panics
/// Panics if `index` is out of bounds for `machine_rcvs`.
fn handle_access_msg(
    index: usize,
    machine_rcvs: &mut Vec<(MachineMutationReceiver, WorkflowManager)>,
) {
    let (rcv, wfm) = &mut machine_rcvs[index];
    match rcv.try_recv() {
        Ok(mutate) => mutate(wfm),
        Err(TryRecvError::Empty) => {}
        Err(TryRecvError::Disconnected) => {
            let wfid = &wfm.machines.workflow_id;
            debug!(wfid = %wfid, "Workflow manager thread done",);
            machine_rcvs.remove(index);
        }
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::machines::test_help::TestHistoryBuilder;
use crate::protos::temporal::api::common::v1::WorkflowExecution;
use crate::protos::temporal::api::enums::v1::EventType;
use crate::protos::temporal::api::history::v1::History;
/// Shutdown must complete cleanly after a workflow machine has been created on the thread.
#[test]
fn can_shutdown_after_creating_machine() {
    let mgr = WorkflowConcurrencyManager::new();

    // Minimal history: a workflow start followed by one full workflow task.
    let mut history_builder = TestHistoryBuilder::default();
    history_builder.add_by_type(EventType::WorkflowExecutionStarted);
    history_builder.add_full_wf_task();
    let history = History {
        events: history_builder
            .get_history_info(1)
            .unwrap()
            .events()
            .to_vec(),
    };
    let execution = WorkflowExecution {
        workflow_id: String::from("wid"),
        run_id: String::from("rid"),
    };

    let activation = mgr
        .create_or_update("some_run_id", history, execution)
        .unwrap();
    assert!(activation.is_some());
    mgr.shutdown();
}
/// Creating a machine from a default (empty) history must surface a history error.
#[test]
fn returns_errors_on_creation() {
    let mgr = WorkflowConcurrencyManager::new();
    let err = mgr
        .create_or_update(
            "some_run_id",
            History::default(),
            WorkflowExecution::default(),
        )
        .unwrap_err();
    assert_matches!(err, WorkflowError::HistoryError(_));
}
}