use std::sync::{Arc, Mutex};
use crate::task::{TaskHandle, TaskNode, TaskWork, TaskId};
use crate::condition::{ConditionalHandle, BranchId};
use std::collections::HashMap;
pub struct Taskflow {
graph: Arc<Mutex<Vec<TaskNode>>>,
next_id: TaskId,
conditional_branches: Arc<Mutex<HashMap<TaskId, HashMap<BranchId, Vec<TaskId>>>>>,
}
impl Taskflow {
pub fn new() -> Self {
Self {
graph: Arc::new(Mutex::new(Vec::new())),
next_id: 0,
conditional_branches: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn emplace<F>(&mut self, work: F) -> TaskHandle
where
F: FnOnce() + Send + 'static,
{
let id = self.next_id;
self.next_id += 1;
let node = TaskNode::new(id, TaskWork::Static(Box::new(work)));
self.graph.lock().unwrap().push(node);
TaskHandle::new(id, Arc::clone(&self.graph))
}
pub fn emplace_subflow<F>(&mut self, work: F) -> TaskHandle
where
F: FnOnce(&mut crate::Subflow) + Send + 'static,
{
let id = self.next_id;
self.next_id += 1;
let node = TaskNode::new(id, TaskWork::Subflow(Box::new(work)));
self.graph.lock().unwrap().push(node);
TaskHandle::new(id, Arc::clone(&self.graph))
}
pub fn emplace_condition<F>(&mut self, condition: F) -> TaskHandle
where
F: FnOnce() -> usize + Send + 'static,
{
let id = self.next_id;
self.next_id += 1;
let node = TaskNode::new(id, TaskWork::Condition(Box::new(condition)));
self.graph.lock().unwrap().push(node);
TaskHandle::new(id, Arc::clone(&self.graph))
}
pub fn emplace_conditional<F>(&mut self, condition: F) -> ConditionalHandle
where
F: FnOnce() -> usize + Send + 'static,
{
let task = self.emplace_condition(condition);
ConditionalHandle::new(task)
}
pub fn register_branches(&mut self, cond_handle: &ConditionalHandle) {
let mut branches_map = self.conditional_branches.lock().unwrap();
let task_id = cond_handle.task_id();
let mut branch_successors = HashMap::new();
for (branch_id, successors) in &cond_handle.branches {
let successor_ids: Vec<TaskId> = successors.iter().map(|h| h.id).collect();
branch_successors.insert(*branch_id, successor_ids);
for successor in successors {
successor.succeed(&cond_handle.task);
}
}
branches_map.insert(task_id, branch_successors);
}
pub(crate) fn get_conditional_branches(&self) -> Arc<Mutex<HashMap<TaskId, HashMap<BranchId, Vec<TaskId>>>>> {
Arc::clone(&self.conditional_branches)
}
#[cfg(feature = "async")]
pub fn emplace_async<F, Fut>(&mut self, work: F) -> TaskHandle
where
F: FnOnce() -> Fut + Send + 'static,
Fut: std::future::Future<Output = ()> + Send + 'static,
{
let id = self.next_id;
self.next_id += 1;
let work_boxed: Box<dyn FnOnce() -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + 'static> =
Box::new(move || Box::pin(work()));
let node = TaskNode::new(id, TaskWork::Async(work_boxed));
self.graph.lock().unwrap().push(node);
TaskHandle::new(id, Arc::clone(&self.graph))
}
pub(crate) fn get_graph(&self) -> Arc<Mutex<Vec<TaskNode>>> {
Arc::clone(&self.graph)
}
#[allow(dead_code)]
pub(crate) fn next_id(&mut self) -> TaskId {
let id = self.next_id;
self.next_id += 1;
id
}
pub fn dump(&self) -> String {
let graph = self.graph.lock().unwrap();
let mut dot = String::from("digraph Taskflow {\n");
for node in graph.iter() {
dot.push_str(&format!(" {} [label=\"{}\"];\n", node.id, node.name));
for succ in &node.successors {
dot.push_str(&format!(" {} -> {};\n", node.id, succ));
}
}
dot.push_str("}\n");
dot
}
pub fn size(&self) -> usize {
self.graph.lock().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.graph.lock().unwrap().is_empty()
}
}
impl Default for Taskflow {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_taskflow_creation() {
let mut tf = Taskflow::new();
let a = tf.emplace(|| println!("A")).name("A");
let b = tf.emplace(|| println!("B")).name("B");
a.precede(&b);
assert_eq!(tf.size(), 2);
let dot = tf.dump();
assert!(dot.contains("A"));
assert!(dot.contains("B"));
assert!(dot.contains("->"));
}
}