use futures::future::join_all;
use futures::future::try_join_all;
use tokio::sync::RwLock;
use tokio::sync::Barrier;
use anyhow::Result;
use tokio::task::JoinHandle;
use indicatif::{ProgressBar, ProgressStyle, MultiProgress};
use std::{
sync::Arc,
collections::HashMap
};
use crate::{
ActionObject, Id,
ActionOutput,
context::Context
};
/// Executes the actions of a [`Context`] concurrently while honouring their dependencies.
///
/// The context, the outputs and the progress display sit behind `Arc`s and are shared by
/// clones; the barrier map is copied when the runtime is cloned.
#[derive(Clone)]
pub struct Runtime {
    /// Action registry the runtime reads its actions from.
    ctx: Arc<RwLock<Context>>,
    /// One barrier per action that has dependents, keyed by that action's id.
    barriers: HashMap<Id, Arc<Barrier>>,
    /// Outputs recorded by finished actions, keyed by action id.
    outputs: Arc<RwLock<HashMap<Id, ActionOutput>>>,
    /// Container that owns the per-action spinners.
    progress: Arc<RwLock<MultiProgress>>,
}
impl Runtime {
pub async fn run(mut self) -> Result<()> {
let actions = self.ctx.read().await.actions.clone();
let mut dependents: HashMap<Id, usize> = HashMap::new();
for action in actions.iter() {
dependents.insert(action.id, 0);
action.deps()
.iter()
.map(|dep| dep.id())
.for_each(|id| {
let count = dependents.entry(id).or_insert(0);
*count += 1;
});
}
for (id, dependents) in dependents.clone() {
if dependents == 0 {
continue;
}
let barrier = Arc::new(Barrier::new(dependents + 1));
self.barriers.insert(id, barrier);
}
let mut handles: Vec<JoinHandle<Result<()>>> = Vec::new();
let bars = Arc::new(RwLock::new(Vec::new()));
let bars_clone = bars.clone();
let tick_loop = tokio::spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
bars_clone.write().await.iter().for_each(|bar: &ProgressBar| bar.tick());
}
});
for action in actions {
let runtime_clone = self.clone();
let bars = bars.clone();
let action = action.clone();
let deps = action.deps();
let barriers = deps
.iter()
.map(|dep| dep.id());
let barriers = barriers
.map(|id| self.barriers.get(&id).unwrap().clone())
.collect::<Vec<_>>();
let self_barriers = self.barriers.clone();
handles.push(tokio::spawn(async move {
let self_barrier = self_barriers.get(&action.id).cloned();
for barrier in barriers {
barrier.wait().await;
}
if action.check(runtime_clone.clone()).await? {
return Ok(())
}
let display_name = action.display_name();
let progress = runtime_clone.progress.write().await.add(ProgressBar::new_spinner());
progress.set_style(ProgressStyle::default_spinner().template(" {spinner} [{elapsed_precise}] {wide_msg}")?);
progress.set_message(display_name.clone());
bars.write().await.push(progress.clone());
let output = action.perform(runtime_clone.clone()).await;
if output.is_err() {
progress.finish_with_message(format!("Error: {}", display_name));
return Err(output.err().unwrap())
}
progress.finish_and_clear();
let output = output.unwrap();
if let Some(barrier) = self_barrier {
barrier.wait().await;
}
if let Some(output) = output {
runtime_clone.outputs.write().await.insert(action.id, output);
}
Ok(())
}));
}
let results = try_join_all(handles).await;
tick_loop.abort();
bars.write().await.iter().for_each(|bar: &ProgressBar| bar.finish());
results?;
Ok(())
}
pub async fn get_output(&self, obj: ActionObject) -> Option<ActionOutput> {
self.outputs.read().await.get(&obj.id()).cloned()
}
}
/// Collects actions into a [`Context`] and turns them into a [`Runtime`].
pub struct RuntimeBuilder {
    /// Context that accumulates the actions passed to `add_action`.
    ctx: Context,
}
impl RuntimeBuilder {
pub fn new() -> Self {
Self {
ctx: Context::new()
}
}
pub fn add_action(mut self, action: ActionObject) -> Self {
self.ctx.add_action(action);
self
}
pub fn build(self) -> Runtime {
Runtime {
ctx: Arc::new(RwLock::new(self.ctx)),
barriers: HashMap::new(),
outputs: Arc::new(RwLock::new(HashMap::new())),
progress: Arc::new(RwLock::new(MultiProgress::new()))
}
}
}
impl Default for RuntimeBuilder {
fn default() -> Self {
Self::new()
}
}