use apalis_core::backend::BackendExt;
use apalis_core::backend::codec::Codec;
use apalis_core::task::builder::TaskBuilder;
use apalis_core::task::metadata::Meta;
use apalis_core::task::status::Status;
use apalis_core::{
backend::WaitForCompletion,
error::BoxDynError,
task::{Task, metadata::MetadataExt, task_id::TaskId},
};
use futures::future::BoxFuture;
use futures::{FutureExt, Sink, SinkExt, StreamExt};
use petgraph::Direction;
use petgraph::graph::NodeIndex;
use std::collections::HashMap;
use std::fmt::Debug;
use tower::Service;
use crate::DagExecutor;
use crate::dag::context::DagFlowContext;
use crate::dag::error::{DagFlowError, DagServiceError};
use crate::dag::response::DagExecutionResponse;
use crate::id_generator::GenerateId;
/// Root [`tower::Service`] that drives a whole DAG execution run.
///
/// Pairs the per-node [`DagExecutor`] with the backend used to enqueue
/// follow-up node tasks and (via `WaitForCompletion`) to await the results
/// of dependency tasks during fan-in.
pub struct RootDagService<B>
where
    B: BackendExt,
{
    // Runs the service registered for the current graph node.
    executor: DagExecutor<B>,
    // Task sink for enqueuing next-node tasks; also awaited for
    // dependency completion in the fan-in path.
    backend: B,
}
impl<B> std::fmt::Debug for RootDagService<B>
where
    B: BackendExt,
{
    /// Manual `Debug`: the executor and backend are opaque, so fixed
    /// placeholder markers are printed in their place.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("RootDagService");
        dbg.field("executor", &"<DagExecutor>");
        dbg.field("backend", &"<Backend>");
        dbg.finish()
    }
}
impl<B> RootDagService<B>
where
    B: BackendExt,
{
    /// Builds a root service from a node executor and the backend it will
    /// use to enqueue and await tasks.
    pub(crate) fn new(executor: DagExecutor<B>, backend: B) -> Self {
        Self { executor, backend }
    }
}
impl<B> Clone for RootDagService<B>
where
B: BackendExt + Clone,
{
fn clone(&self) -> Self {
Self {
executor: self.executor.clone(),
backend: self.backend.clone(),
}
}
}
/// Picks, among a fan-in node's incoming neighbors, the single node that is
/// responsible for collecting dependency results.
///
/// The choice is deterministic — the incoming node with the highest graph
/// index wins — so every sibling task independently agrees on the same
/// handler.
///
/// # Errors
/// Returns `DagServiceError::MissingFaninHandler` when `incoming_nodes`
/// is empty.
fn find_designated_fan_in_handler(
    incoming_nodes: &[NodeIndex],
) -> Result<&NodeIndex, DagFlowError> {
    match incoming_nodes.iter().max_by_key(|candidate| candidate.index()) {
        Some(handler) => Ok(handler),
        None => Err(DagFlowError::Service(DagServiceError::MissingFaninHandler)),
    }
}
impl<B, Err, CdcErr, MetaError, IdType> Service<Task<B::Compact, B::Context, B::IdType>>
    for RootDagService<B>
where
    B: BackendExt<Error = Err, IdType = IdType>
        + Send
        + Sync
        + 'static
        + Clone
        + WaitForCompletion<DagExecutionResponse<B::Compact, IdType>>,
    IdType: GenerateId + Send + Sync + 'static + PartialEq + Debug + Clone,
    B::Compact: Send + Sync + 'static + Clone,
    B::Context:
        Send + Sync + Default + MetadataExt<DagFlowContext<B::IdType>, Error = MetaError> + 'static,
    Err: std::error::Error + Send + Sync + 'static,
    B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err> + Unpin,
    B::Codec: Codec<Vec<B::Compact>, Compact = B::Compact, Error = CdcErr>
        + 'static
        + Codec<DagExecutionResponse<B::Compact, B::IdType>, Compact = B::Compact, Error = CdcErr>,
    CdcErr: Into<BoxDynError>,
    MetaError: Into<BoxDynError> + Send + Sync + 'static,
{
    type Response = DagExecutionResponse<B::Compact, B::IdType>;
    type Error = DagFlowError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        // Readiness is delegated entirely to the inner per-node executor.
        self.executor.poll_ready(cx)
    }

    /// Drives one step of a DAG run.
    ///
    /// A task carrying a `DagFlowContext` in its metadata belongs to an
    /// in-flight run: the current node is executed (after fan-in
    /// coordination when it has multiple dependencies) and its successors
    /// are enqueued. A task without that metadata is the initial request:
    /// it either starts the single entry node directly or fans out to all
    /// entry nodes.
    fn call(&mut self, mut req: Task<B::Compact, B::Context, B::IdType>) -> Self::Future {
        let mut executor = self.executor.clone();
        let mut backend = self.backend.clone();
        let start_nodes = executor.start_nodes.clone();
        let end_nodes = executor.end_nodes.clone();
        async move {
            let ctx = req.extract::<Meta<DagFlowContext<B::IdType>>>().await;
            let (response, context) = if let Ok(Meta(context)) = ctx {
                #[cfg(feature = "tracing")]
                tracing::debug!(
                    node = ?context.current_node,
                    "Extracted DagFlowContext for task"
                );
                let incoming_nodes = executor
                    .graph
                    .neighbors_directed(context.current_node, Direction::Incoming)
                    .collect::<Vec<_>>();
                match incoming_nodes.len() {
                    // An entry node (no incoming edges, enqueued by the
                    // entry fan-out) needs no fan-in coordination.
                    // NOTE(review): this merges two previously identical
                    // arms guarded by `len() == 1` / `len() > 1`; the guard
                    // `!start_nodes.is_empty()` is exactly equivalent.
                    0 if !start_nodes.is_empty() => {
                        let response = executor.call(req).await?;
                        (response, context)
                    }
                    // A single dependency: the payload is already the
                    // predecessor's output, execute directly.
                    1 => {
                        let response = executor.call(req).await?;
                        (response, context)
                    }
                    _ => {
                        // Fan-in: several dependencies feed this node.
                        let dependency_task_ids = context.get_dependency_task_ids(&incoming_nodes);
                        #[cfg(feature = "tracing")]
                        tracing::debug!(
                            prev_node = ?context.prev_node,
                            node = ?context.current_node,
                            deps = ?dependency_task_ids,
                            "Fanning in from multiple dependencies",
                        );
                        let prev_node = context
                            .prev_node
                            .ok_or(DagFlowError::Service(DagServiceError::MissingPreviousNode))?;
                        // Only the designated handler branch performs the
                        // fan-in; every other incoming branch yields.
                        if *find_designated_fan_in_handler(&incoming_nodes)? != prev_node {
                            return Ok(DagExecutionResponse::WaitingForDependencies {
                                pending_dependencies: dependency_task_ids,
                            });
                        }
                        // Block until every dependency task has a result.
                        let results = backend
                            .wait_for(dependency_task_ids.values().cloned().collect::<Vec<_>>())
                            .collect::<Vec<_>>()
                            .await
                            .into_iter()
                            .collect::<Result<Vec<_>, _>>()
                            .map_err(|e| DagFlowError::Backend(e.into()))?;
                        if results.iter().all(|s| matches!(s.status, Status::Done)) {
                            // Order the results deterministically by the
                            // (reversed) incoming-node order so the fan-in
                            // input vector is stable across runs.
                            let sorted_results = {
                                let res = incoming_nodes
                                    .iter()
                                    .rev()
                                    .map(|node_index| {
                                        let task_id = context
                                            .node_task_ids
                                            .iter()
                                            .find(|(n, _)| *n == node_index)
                                            .map(|(_, task_id)| task_id)
                                            .ok_or(DagFlowError::Service(
                                                DagServiceError::MissingIncomingTaskId,
                                            ))?;
                                        let task_result = results
                                            .iter()
                                            .find(|r| &r.task_id == task_id)
                                            .ok_or(DagFlowError::Service(
                                                DagServiceError::MissingTaskIdResult(format!(
                                                    "{:?}",
                                                    task_id.inner()
                                                )),
                                            ))?;
                                        Ok(task_result)
                                    })
                                    .collect::<Result<Vec<_>, DagFlowError>>();
                                match res {
                                    Ok(v) => v,
                                    // NOTE(review): a missing id/result is
                                    // treated as "dependencies not ready
                                    // yet" rather than a hard failure —
                                    // presumably so a racing sibling retries;
                                    // confirm this is intentional.
                                    Err(_) => {
                                        return Ok(DagExecutionResponse::WaitingForDependencies {
                                            pending_dependencies: dependency_task_ids,
                                        });
                                    }
                                }
                            };
                            // Unwrap each dependency's response into its raw
                            // output value; any non-output variant is invalid
                            // as fan-in input.
                            let res = sorted_results
                                .iter()
                                .map(|s| match &s.result {
                                    Ok(val) => match val {
                                        DagExecutionResponse::FanOut { response, .. } => {
                                            Ok(response.clone())
                                        }
                                        DagExecutionResponse::EnqueuedNext { result }
                                        | DagExecutionResponse::Complete { result } => {
                                            Ok(result.clone())
                                        }
                                        _ => Err(DagFlowError::Service(
                                            DagServiceError::InvalidFanInDependencyResult,
                                        )),
                                    },
                                    Err(e) => Err(DagFlowError::Service(
                                        DagServiceError::DependencyTaskFailed(e.as_str().into()),
                                    )),
                                })
                                .collect::<Result<Vec<_>, _>>()?;
                            // Replace the task payload with the encoded
                            // vector of dependency outputs and run the node.
                            let encoded_input = B::Codec::encode(&res)
                                .map_err(|e| DagFlowError::Codec(e.into()))?;
                            let req = req.map(|_| encoded_input);
                            let response = executor.call(req).await?;
                            (response, context)
                        } else {
                            return Err(DagFlowError::Service(
                                DagServiceError::DependencyTaskFailed(
                                    "An adjacent node failed. Terminating".into(),
                                ),
                            ));
                        }
                    }
                }
            } else {
                #[cfg(feature = "tracing")]
                tracing::debug!("Extracting DagFlowContext for task without meta");
                if start_nodes.len() == 1 {
                    // Single entry point: inject a fresh context and run the
                    // first node inline.
                    #[cfg(feature = "tracing")]
                    tracing::debug!("Single start node detected, proceeding with execution");
                    let context = DagFlowContext::new(req.parts.task_id.clone());
                    req.parts
                        .ctx
                        .inject(context.clone())
                        .map_err(|e| DagFlowError::Metadata(e.into()))?;
                    let response = executor.call(req).await?;
                    #[cfg(feature = "tracing")]
                    tracing::debug!(node = ?context.current_node, "Execution complete at node");
                    (response, context)
                } else {
                    // Multiple entry points: decode the composite input and
                    // enqueue one task per start node.
                    let new_node_task_ids = fan_out_entry_nodes(
                        &executor,
                        &backend,
                        &DagFlowContext::new(req.parts.task_id.clone()),
                        &req.args,
                    )
                    .await?;
                    return Ok(DagExecutionResponse::EntryFanOut {
                        node_task_ids: new_node_task_ids,
                    });
                }
            };
            // The current node has produced a response; route it onward.
            let current_node = context.current_node;
            let outgoing_nodes = executor
                .graph
                .neighbors_directed(current_node, Direction::Outgoing)
                .collect::<Vec<_>>();
            match outgoing_nodes.len() {
                0 => {
                    // Terminal node: by construction it must be one of the
                    // graph's declared end nodes.
                    // BUGFIX: restored `&current_node` — the source had the
                    // mojibake `¤t_node` (garbled HTML entity), which
                    // does not compile.
                    assert!(
                        end_nodes.contains(&current_node),
                        "Current node is not an end node"
                    );
                    return Ok(DagExecutionResponse::Complete { result: response });
                }
                1 => {
                    // Single successor: enqueue it with an advanced context.
                    let next_node = outgoing_nodes[0];
                    let mut new_context = context.clone();
                    new_context.prev_node = Some(current_node);
                    new_context.current_node = next_node;
                    new_context.current_position += 1;
                    new_context.is_initial = false;
                    let task = TaskBuilder::new(response.clone())
                        .with_task_id(TaskId::new(B::IdType::generate()))
                        .meta(new_context)
                        .build();
                    backend
                        .send(task)
                        .await
                        .map_err(|e| DagFlowError::Backend(e.into()))?;
                }
                _ => {
                    // Several successors: fan out one task per outgoing node.
                    let mut new_context = context.clone();
                    new_context.prev_node = Some(current_node);
                    new_context.current_position += 1;
                    new_context.is_initial = false;
                    let next_task_ids = fan_out_next_nodes(
                        &executor,
                        outgoing_nodes,
                        &backend,
                        &new_context,
                        &response,
                    )
                    .await?;
                    return Ok(DagExecutionResponse::FanOut {
                        response,
                        node_task_ids: next_task_ids,
                    });
                }
            }
            Ok(DagExecutionResponse::EnqueuedNext { result: response })
        }
        .boxed()
    }
}
/// Enqueues one follow-up task per outgoing node of the node that just ran.
///
/// All outgoing nodes receive the same `input` payload. Fresh task ids are
/// generated up front so each enqueued sibling's context contains the ids
/// of all the others; ids already present in `context.node_task_ids`
/// overwrite the freshly generated entries in the merged map.
///
/// Returns the map of outgoing node -> freshly generated task id.
async fn fan_out_next_nodes<B, Err, CdcErr>(
    _executor: &DagExecutor<B>,
    outgoing_nodes: Vec<NodeIndex>,
    backend: &B,
    context: &DagFlowContext<B::IdType>,
    input: &B::Compact,
) -> Result<HashMap<NodeIndex, TaskId<B::IdType>>, DagFlowError>
where
    B::IdType: GenerateId + Send + Sync + 'static + PartialEq,
    B::Compact: Send + Sync + 'static + Clone,
    B::Context: Send + Sync + Default + MetadataExt<DagFlowContext<B::IdType>> + 'static,
    B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err> + Unpin,
    Err: std::error::Error + Send + Sync + 'static,
    B: BackendExt<Error = Err> + Send + Sync + 'static + Clone,
    B::Codec: Codec<Vec<B::Compact>, Compact = B::Compact, Error = CdcErr>,
    CdcErr: Into<BoxDynError>,
{
    // Assign a fresh task id to every outgoing node before enqueuing.
    let assigned = outgoing_nodes
        .iter()
        .map(|node| (*node, TaskId::new(B::IdType::generate())))
        .collect::<HashMap<NodeIndex, TaskId<B::IdType>>>();
    // Merge with the ids the context already knows about; existing context
    // entries replace the fresh ones for the same node.
    let mut merged_ids = assigned.clone();
    merged_ids.extend(context.node_task_ids.clone());
    let mut sends = vec![];
    for node in outgoing_nodes.into_iter() {
        let id = assigned
            .get(&node)
            .ok_or(DagFlowError::Service(DagServiceError::MissingNextNode))?
            .clone();
        let meta = DagFlowContext {
            prev_node: context.prev_node,
            current_node: node,
            completed_nodes: context.completed_nodes.clone(),
            node_task_ids: merged_ids.clone(),
            current_position: context.current_position + 1,
            is_initial: context.is_initial,
            root_task_id: context.root_task_id.clone(),
        };
        let task = TaskBuilder::new(input.clone())
            .with_task_id(id)
            .meta(meta)
            .build();
        let mut sink = backend.clone();
        sends.push(
            async move {
                sink.send(task)
                    .await
                    .map_err(|e| DagFlowError::Backend(e.into()))
            }
            .boxed(),
        );
    }
    // Enqueue everything concurrently; the first failure aborts the rest.
    futures::future::try_join_all(sends).await?;
    Ok(assigned)
}
/// Decodes the composite root input and enqueues one task per start node.
///
/// The root task's payload must decode (via `B::Codec`) to a
/// `Vec<B::Compact>` with exactly one element per start node; each element
/// becomes the payload of the corresponding entry task. Fresh task ids are
/// generated up front so each entry task's context contains its siblings'
/// ids; ids already present in `context.node_task_ids` overwrite the
/// freshly generated entries in the merged map.
///
/// # Errors
/// - `DagFlowError::Codec` when the payload cannot be decoded.
/// - `DagFlowError::InputCountMismatch` when the element count differs
///   from the number of start nodes.
/// - `DagFlowError::Backend` when enqueuing a task fails.
///
/// Returns the map of start node -> freshly generated task id.
async fn fan_out_entry_nodes<B, Err, CdcErr>(
    executor: &DagExecutor<B>,
    backend: &B,
    context: &DagFlowContext<B::IdType>,
    input: &B::Compact,
) -> Result<HashMap<NodeIndex, TaskId<B::IdType>>, DagFlowError>
where
    B::IdType: GenerateId + Send + Sync + 'static + PartialEq + Debug,
    B::Compact: Send + Sync + 'static + Clone,
    B::Context: Send + Sync + Default + MetadataExt<DagFlowContext<B::IdType>> + 'static,
    B: Sink<Task<B::Compact, B::Context, B::IdType>, Error = Err> + Unpin,
    Err: std::error::Error + Send + Sync + 'static,
    B: BackendExt<Error = Err> + Send + Sync + 'static + Clone,
    B::Codec: Codec<Vec<B::Compact>, Compact = B::Compact, Error = CdcErr>,
    CdcErr: Into<BoxDynError>,
{
    let values: Vec<B::Compact> =
        B::Codec::decode(input).map_err(|e: CdcErr| DagFlowError::Codec(e.into()))?;
    let start_nodes = executor.start_nodes.clone();
    // Exactly one input element is required per start node.
    if values.len() != start_nodes.len() {
        return Err(DagFlowError::InputCountMismatch {
            expected: start_nodes.len(),
            actual: values.len(),
        });
    }
    let mut enqueue_futures = vec![];
    // Pre-generate ids so every entry task can see its siblings' ids.
    let next_nodes = start_nodes
        .iter()
        .map(|node| (*node, TaskId::new(B::IdType::generate())))
        .collect::<HashMap<NodeIndex, TaskId<B::IdType>>>();
    let mut node_task_ids = next_nodes.clone();
    // Context entries replace the fresh ones for the same node.
    node_task_ids.extend(context.node_task_ids.clone());
    for (outgoing_node, input) in start_nodes.into_iter().zip(values) {
        let task_id = next_nodes
            .get(&outgoing_node)
            .ok_or(DagFlowError::Service(DagServiceError::MissingNextNode))?;
        let task = TaskBuilder::new(input)
            .with_task_id(task_id.clone())
            .meta(DagFlowContext {
                prev_node: None,
                current_node: outgoing_node,
                completed_nodes: Default::default(),
                node_task_ids: node_task_ids.clone(),
                current_position: context.current_position,
                is_initial: true,
                root_task_id: context.root_task_id.clone(),
            })
            .build();
        let mut b = backend.clone();
        enqueue_futures.push(
            async move {
                // Uses `e.into()` for consistency with `fan_out_next_nodes`
                // (previously `BoxDynError::from(e)` — same conversion,
                // different spelling).
                b.send(task)
                    .await
                    .map_err(|e| DagFlowError::Backend(e.into()))?;
                Ok::<(), DagFlowError>(())
            }
            .boxed(),
        );
    }
    // Enqueue all entry tasks concurrently; fail fast on the first error.
    futures::future::try_join_all(enqueue_futures).await?;
    Ok(next_nodes)
}