use std::{
sync::Arc,
time::{Duration, Instant},
};
use anyhow::Result;
use dashmap::{mapref::entry::Entry, DashMap};
use futures::{stream::BoxStream, Stream, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::{select, task::JoinHandle, try_join};
use tracing::{debug_span, error, instrument, trace, warn, Instrument};
use self::dynamic_channel::{DynamicChannel, DynamicChannelFactory};
use crate::{
acker::{Acker, ComposedAcker},
channel::{
coordinated_channel::coordinated_channel, Channel, ChannelFactory, ChannelType, LeaseGuard,
},
common::get_random_routing_key,
config::Config,
operation::{marker::Marker, FatalStrategy, Operation},
queue::{Publisher, PublisherExt},
serializer::{Serializable, Serializer},
task::{AnyTask, AnyTaskOutput, AnyTaskResult, Task, TaskResult},
};
/// Boxed stream of received items, each paired with the [`Acker`] used to acknowledge it.
type Receiver<'a, Item> = Box<dyn Stream<Item = (Item, Box<dyn Acker>)> + Send + Unpin + 'a>;
/// Boxed [`Publisher`] for items of type `Item`.
type Sender<'a, Item> = Box<dyn Publisher<Item> + Send + Unpin + Sync + 'a>;
/// Value returned by `Runtime::lease_coordinated_task_channel`:
/// `(result channel identifier, task sender, leased result receiver)`.
type CoordinatedTaskChannel<'a, Op, Metadata> = (
    String,
    Sender<'a, Task<'a, Op, Metadata>>,
    LeaseGuard<DynamicChannel, Receiver<'a, TaskResult<Op, Metadata>>>,
);
/// Handle used to publish tasks on the task bus and to lease result channels for them.
///
/// With the in-memory runtime it also owns the emulated worker tasks, which are aborted when
/// the `Runtime` is dropped.
pub struct Runtime {
    /// Creates/looks up channels; result channels are issued from here.
    channel_factory: DynamicChannelFactory,
    /// The task bus that tasks are published to.
    task_channel: DynamicChannel,
    /// Converts typed tasks into `AnyTask` before publishing.
    serializer: Serializer,
    /// Handles of the emulated worker tasks; `Some` only for the in-memory runtime.
    worker_emulator: Option<Vec<JoinHandle<Result<()>>>>,
    _marker: Marker,
}
/// Routing key of the broadcast channel over which workers exchange [`WorkerIpc`] messages.
const IPC_ROUTING_KEY: &str = "ipc-routing-key";
/// Task bus routing key used when `Config::task_bus_routing_key` is not set.
pub const DEFAULT_ROUTING_KEY: &str = "default";
impl Runtime {
pub async fn from_config(config: &Config, marker: Marker) -> Result<Self> {
let channel_factory = DynamicChannelFactory::from_config(config).await?;
let task_channel = channel_factory
.get(
config
.task_bus_routing_key
.clone()
.unwrap_or_else(|| DEFAULT_ROUTING_KEY.to_string()),
ChannelType::ExactlyOnce,
)
.await?;
let serializer = Serializer::from(config);
let worker_emulator = match config.runtime {
crate::config::Runtime::InMemory => Some(Self::spawn_emulator(
channel_factory.clone(),
task_channel.clone(),
config.num_workers.unwrap_or(3),
)),
_ => None,
};
Ok(Self {
channel_factory,
task_channel,
serializer,
worker_emulator,
_marker: marker,
})
}
pub async fn in_memory() -> Result<Self> {
let config = Config {
runtime: crate::config::Runtime::InMemory,
..Default::default()
};
Self::from_config(&config, Marker).await
}
fn spawn_emulator(
channel_factory: DynamicChannelFactory,
task_channel: DynamicChannel,
num_threads: usize,
) -> Vec<JoinHandle<Result<()>>> {
(0..num_threads)
.map(|_| {
let channel_factory = channel_factory.clone();
let task_channel = task_channel.clone();
tokio::spawn(async move {
let worker_runtime = WorkerRuntime {
channel_factory,
task_channel,
_marker: Marker,
};
worker_runtime.main_loop().await?;
Ok(())
})
})
.collect()
}
pub async fn close(&self) -> Result<()> {
self.task_channel.close().await
}
#[instrument(skip_all, level = "debug")]
async fn get_task_sender<'a, Op: Operation + 'a, Metadata: Serializable + 'a>(
&self,
) -> Result<Sender<'a, Task<'a, Op, Metadata>>> {
let sender = self.task_channel.sender::<AnyTask>().await?;
let serializer = self.serializer;
let transformed_sender =
sender.with(move |task: &Task<'_, Op, Metadata>| task.as_any_task(serializer));
Ok(Box::new(transformed_sender))
}
#[instrument(skip_all, level = "debug")]
pub async fn lease_coordinated_task_channel<
'a,
Op: Operation + 'a,
Metadata: Serializable + 'a,
>(
&self,
) -> Result<CoordinatedTaskChannel<'a, Op, Metadata>> {
let (task_sender, (result_channel_identifier, result_channel)) = try_join!(
self.get_task_sender(),
self.channel_factory.issue(ChannelType::ExactlyOnce)
)?;
let receiver = result_channel
.receiver::<AnyTaskResult>()
.await?
.map(move |(result, acker)| (result.into_task_result::<Op, Metadata>(), acker));
let (sender, receiver) = coordinated_channel(task_sender, receiver);
let ack_composed_receiver =
receiver.map(|((result, original_acker), coordinated_acker)| {
(
result,
Box::new(ComposedAcker::new(original_acker, coordinated_acker))
as Box<dyn Acker>,
)
});
Ok((
result_channel_identifier,
Box::new(sender),
LeaseGuard::new(result_channel, Box::new(ack_composed_receiver)),
))
}
}
impl Drop for Runtime {
fn drop(&mut self) {
if let Some(worker_emulator) = self.worker_emulator.take() {
for handle in worker_emulator {
handle.abort();
}
}
}
}
/// Worker side of the runtime: consumes tasks from the task bus, executes them, and publishes
/// their results (see `main_loop`).
#[derive(Clone)]
pub struct WorkerRuntime {
    /// Creates/looks up result and IPC channels.
    channel_factory: DynamicChannelFactory,
    /// The task bus tasks are received from.
    task_channel: DynamicChannel,
    _marker: Marker,
}
/// Successful task execution, to be published on the result channel named by `routing_key`.
#[derive(Debug)]
pub struct ExecutionOk {
    /// Identifier of the result channel the output is published to.
    pub routing_key: String,
    /// Output produced by the task.
    pub output: AnyTaskOutput,
}
/// Failed task execution, together with the strategy deciding how the failure is reported.
#[derive(Debug)]
pub struct ExecutionErr<E> {
    /// Identifier of the result channel (and job) the failure belongs to.
    routing_key: String,
    /// The execution error; with `Terminate`, its `Display` text is sent as the task result.
    err: E,
    /// Whether the failure is ignored or broadcast as a job termination.
    strategy: FatalStrategy,
}
/// Messages broadcast between workers over the IPC channel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WorkerIpc {
    /// A task of the job identified by `routing_key` failed fatally; receiving workers cancel
    /// that job's in-flight task and skip its remaining tasks.
    ExecutionError { routing_key: String },
}
impl WorkerRuntime {
    /// Creates a worker runtime from `config`.
    ///
    /// The task bus is the `ExactlyOnce` channel named by `config.task_bus_routing_key`, or
    /// [`DEFAULT_ROUTING_KEY`] when that is unset.
    ///
    /// # Errors
    ///
    /// Fails if the channel factory cannot be built or the task channel cannot be obtained.
    pub async fn from_config(config: &Config, marker: Marker) -> Result<Self> {
        let channel_factory = DynamicChannelFactory::from_config(config).await?;
        let task_channel = channel_factory
            .get(
                config
                    .task_bus_routing_key
                    .clone()
                    .unwrap_or_else(|| DEFAULT_ROUTING_KEY.to_string()),
                ChannelType::ExactlyOnce,
            )
            .await?;
        Ok(Self {
            channel_factory,
            task_channel,
            _marker: marker,
        })
    }

    /// Returns a sender publishing [`AnyTaskResult`]s on the `ExactlyOnce` result channel
    /// named `identifier`.
    #[instrument(skip(self), level = "trace")]
    pub async fn get_result_sender(&self, identifier: String) -> Result<Sender<AnyTaskResult>> {
        self.channel_factory
            .get(identifier, ChannelType::ExactlyOnce)
            .await?
            .sender::<AnyTaskResult>()
            .await
    }

    /// Returns a sender on the worker IPC broadcast channel.
    #[instrument(skip(self), level = "trace")]
    pub async fn get_ipc_sender(&self) -> Result<Sender<WorkerIpc>> {
        self.channel_factory
            .get(IPC_ROUTING_KEY.to_string(), ChannelType::Broadcast)
            .await?
            .sender()
            .await
    }

    /// Returns the stream of [`WorkerIpc`] messages from the IPC broadcast channel.
    ///
    /// Each message is acknowledged before it is yielded. A failed acknowledgement is logged
    /// and otherwise ignored, so the message is still delivered and the stream keeps going.
    #[instrument(skip(self), level = "trace")]
    pub async fn get_ipc_receiver(&self) -> Result<BoxStream<'static, WorkerIpc>> {
        let s = self
            .channel_factory
            .get(IPC_ROUTING_KEY.to_string(), ChannelType::Broadcast)
            .await?
            .receiver::<WorkerIpc>()
            .await?;
        Ok(s.then(|(message, acker)| async move {
            if let Err(err) = acker.ack().await {
                warn!("failed to ack IPC message: {err:?}");
            }
            message
        })
        .boxed())
    }

    /// Returns the receiver of the task bus.
    #[instrument(skip_all, level = "trace")]
    pub async fn get_task_receiver(&self) -> Result<Receiver<AnyTask>> {
        self.task_channel.receiver().await
    }

    /// Reports a fatal execution error according to its [`FatalStrategy`].
    ///
    /// * `Ignore`: nothing is sent.
    /// * `Terminate`: broadcasts [`WorkerIpc::ExecutionError`] on the IPC channel and publishes
    ///   the error's `Display` text as [`AnyTaskResult::Err`] on the job's result channel,
    ///   then closes both senders.
    ///
    /// # Errors
    ///
    /// Fails if a sender cannot be obtained, or publishing or closing fails.
    #[instrument(skip(self), level = "trace")]
    pub async fn dispatch_fatal<E>(
        &self,
        ExecutionErr {
            routing_key,
            err,
            strategy,
        }: ExecutionErr<E>,
    ) -> Result<()>
    where
        E: std::fmt::Display + std::fmt::Debug,
    {
        match strategy {
            FatalStrategy::Ignore => Ok(()),
            FatalStrategy::Terminate => {
                let (ipc, sender) = try_join!(
                    self.get_ipc_sender(),
                    self.get_result_sender(routing_key.clone())
                )?;
                let ipc_msg = WorkerIpc::ExecutionError { routing_key };
                let sender_msg = AnyTaskResult::Err(err.to_string());
                try_join!(ipc.publish(&ipc_msg), sender.publish(&sender_msg))?;
                try_join!(ipc.close(), sender.close())?;
                Ok(())
            }
        }
    }

    /// Publishes a successful task output as [`AnyTaskResult::Ok`] on the result channel named
    /// by `routing_key`, then closes the sender.
    ///
    /// # Errors
    ///
    /// Fails if the sender cannot be obtained, or publishing or closing fails.
    #[instrument(skip(self), level = "trace")]
    pub async fn dispatch_ok(
        &self,
        ExecutionOk {
            routing_key,
            output,
        }: ExecutionOk,
    ) -> Result<()> {
        let sender = self.get_result_sender(routing_key).await?;
        sender.publish(&AnyTaskResult::Ok(output)).await?;
        sender.close().await?;
        Ok(())
    }

    /// Runs the worker loop until the task stream ends.
    ///
    /// For each task received from the task bus:
    /// * tasks of jobs recorded as terminated are nacked and skipped;
    /// * otherwise the task is executed, racing against an IPC termination signal for its job;
    /// * on success the task is acked and its output sent via [`Self::dispatch_ok`];
    /// * on failure the task is nacked and the error sent via [`Self::dispatch_fatal`]; with
    ///   [`FatalStrategy::Terminate`] the job is also recorded as terminated locally;
    /// * when cancelled via IPC the task is nacked.
    ///
    /// Two helper tasks run alongside the loop. One consumes [`WorkerIpc`] messages and
    /// records/signals terminated jobs. The other periodically forgets terminations older than
    /// `TERMINATION_CLEAR_INTERVAL`. Both are aborted whenever this future finishes, returns
    /// early with an error, or is dropped.
    ///
    /// # Errors
    ///
    /// Returns the first error from obtaining the task or IPC receivers, acking/nacking a task,
    /// or dispatching a result.
    #[instrument(skip(self), level = "trace")]
    pub async fn main_loop(&self) -> Result<()> {
        /// Aborts the wrapped task when dropped, so helper tasks cannot outlive `main_loop`.
        struct AbortOnDrop<T>(JoinHandle<T>);

        impl<T> Drop for AbortOnDrop<T> {
            fn drop(&mut self) {
                self.0.abort();
            }
        }

        /// Records `routing_key` as terminated; returns `true` only if it was not yet recorded.
        fn mark_terminated(
            terminated_jobs: &DashMap<String, Instant>,
            routing_key: String,
        ) -> bool {
            match terminated_jobs.entry(routing_key) {
                Entry::Vacant(entry) => {
                    entry.insert(Instant::now());
                    true
                }
                Entry::Occupied(_) => false,
            }
        }

        /// How long a termination is remembered, and how often stale ones are purged.
        const TERMINATION_CLEAR_INTERVAL: Duration = Duration::from_secs(60);

        let mut task_stream = self.get_task_receiver().await?;
        let terminated_jobs: Arc<DashMap<String, Instant>> = Default::default();

        let _reaper_guard = AbortOnDrop(tokio::spawn({
            let terminated_jobs = Arc::clone(&terminated_jobs);
            async move {
                loop {
                    terminated_jobs.retain(|_, v| v.elapsed() < TERMINATION_CLEAR_INTERVAL);
                    tokio::time::sleep(TERMINATION_CLEAR_INTERVAL).await;
                }
            }
        }));

        // Seeded with a random key rather than any real job's routing key.
        let identifier: String = get_random_routing_key();
        let (ipc_sig_term_tx, ipc_sig_term_rx) = tokio::sync::watch::channel::<String>(identifier);
        let mut ipc_receiver = self.get_ipc_receiver().await?;
        let _ipc_handler_guard = AbortOnDrop(tokio::spawn({
            let terminated_jobs = Arc::clone(&terminated_jobs);
            async move {
                while let Some(ipc) = ipc_receiver.next().await {
                    match ipc {
                        WorkerIpc::ExecutionError { routing_key } => {
                            // Record before signalling: the cancellation future below relies on
                            // `terminated_jobs` being up to date when it wakes up.
                            if mark_terminated(&terminated_jobs, routing_key.clone()) {
                                warn!(routing_key = %routing_key, "received IPC termination signal");
                                ipc_sig_term_tx.send_replace(routing_key);
                            }
                        }
                    }
                }
            }
        }));

        while let Some((payload, acker)) = task_stream.next().await {
            // Clone only the key, not the whole task payload.
            let routing_key = payload.routing_key.clone();
            if terminated_jobs.contains_key(&routing_key) {
                trace!(routing_key = %routing_key, "skipping terminated job");
                acker.nack().await?;
                continue;
            }

            let span = debug_span!("remote_execute", routing_key = %routing_key);
            let execution_task = payload.remote_execute().instrument(span);

            let cancelled_via_ipc = {
                let mut sig_term_rx = ipc_sig_term_rx.clone();
                let terminated_jobs = &terminated_jobs;
                let watched_key = routing_key.clone();
                async move {
                    loop {
                        if sig_term_rx.changed().await.is_err() {
                            // The IPC handler exited and dropped the sender, so no further
                            // signal can arrive: never resolve instead of panicking.
                            std::future::pending::<()>().await;
                        }
                        // The watch channel keeps only the latest key, so a signal for this job
                        // may be overwritten by a later one before it is observed; the handler
                        // records every termination in `terminated_jobs` first, so check there
                        // as well.
                        let latest_matches = *sig_term_rx.borrow() == watched_key;
                        if latest_matches || terminated_jobs.contains_key(&watched_key) {
                            return;
                        }
                    }
                }
            };

            select! {
                execution = execution_task => {
                    match execution {
                        Ok(output) => {
                            try_join!(
                                acker.ack(),
                                self.dispatch_ok(ExecutionOk {
                                    routing_key,
                                    output,
                                })
                            )?;
                        }
                        Err(err) => {
                            error!(routing_key = %routing_key, "execution error: {err:?}");
                            let strategy = err.fatal_strategy();
                            // Only a terminating failure ends the job. `Ignore` failures are not
                            // broadcast, so recording the job here would make this worker alone
                            // drop the job's remaining tasks.
                            if matches!(strategy, FatalStrategy::Terminate) {
                                mark_terminated(&terminated_jobs, routing_key.clone());
                            }
                            try_join!(
                                acker.nack(),
                                self.dispatch_fatal(ExecutionErr {
                                    routing_key,
                                    err,
                                    strategy,
                                })
                            )?;
                        }
                    }
                }
                _ = cancelled_via_ipc => {
                    warn!(routing_key = %routing_key, "task cancelled via IPC sigterm");
                    // Best-effort: a failed nack must not stop the worker loop.
                    if let Err(err) = acker.nack().await {
                        warn!(routing_key = %routing_key, "failed to nack cancelled task: {err:?}");
                    }
                }
            }
        }

        // Dropping `_ipc_handler_guard` and `_reaper_guard` aborts both helper tasks.
        Ok(())
    }
}
mod dynamic_channel;