use tokio::sync::RwLock;
use tokio::sync::Barrier;
use tokio::task::JoinSet;
use std::any::{Any, TypeId};
use tracing::{debug, info, error};
use std::{
sync::Arc,
collections::HashMap
};
use crate::Operation;
use crate::{
ActionObject, Id,
ActionOutput,
ActionError,
context::Context
};
/// Executes the actions held by a [`Context`], recording their outputs and
/// exposing shared, type-keyed state to them.
#[derive(Clone)]
pub struct Runtime {
    /// Holds the actions that `perform` / `rollback` operate on.
    ctx: Context,
    /// Barrier per action id that other actions depend on; populated by `perform`.
    barriers: HashMap<Id, Arc<Barrier>>,
    /// Outputs keyed by the id of the action that produced them. The `Arc` makes
    /// every clone of the runtime share the same map.
    outputs: Arc<RwLock<HashMap<Id, ActionOutput>>>,
    /// State values keyed by the `TypeId` of the stored value.
    state: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
impl Runtime {
pub async fn perform(mut self) -> Result<(), ActionError> {
let actions = self.ctx.actions.clone();
let mut dependents: HashMap<Id, usize> = HashMap::new();
for action in actions.iter() {
dependents.insert(action.id, 0);
action.deps()
.iter()
.map(|dep| dep.id())
.for_each(|id| {
let count = dependents.entry(id).or_insert(0);
*count += 1;
});
}
for (id, dependents) in dependents.clone() {
if dependents == 0 {
continue;
}
let barrier = Arc::new(Barrier::new(dependents + 1));
self.barriers.insert(id, barrier);
}
let mut join_set: JoinSet<Result<(), ActionError>> = JoinSet::new();
debug!("Starting actions");
for action in actions {
let runtime_clone = self.clone();
let action = action.clone();
let deps = action.deps();
let barriers = deps
.iter()
.map(|dep| dep.id());
let barriers = barriers
.map(|id| self.barriers.get(&id).unwrap().clone())
.collect::<Vec<_>>();
let self_barriers = self.barriers.clone();
join_set.spawn(async move {
let self_barrier = self_barriers.get(&action.id).cloned();
for barrier in barriers {
barrier.wait().await;
}
let probe = action.probe(runtime_clone.clone()).await?;
if !probe.needs_run {
return Ok(())
}
let display_name = action.display_name();
info!("Starting action: {}", display_name);
let output = action.run(runtime_clone.clone(), Operation::Perform).await;
if let Err(err) = &output {
error!("Action failed: {}", display_name);
error!("Error: {}", err);
return Err(err.clone())
} else {
info!("Action finished: {}", display_name);
}
let output = output.unwrap();
if let Some(barrier) = self_barrier {
barrier.wait().await;
}
if let Some(output) = output {
runtime_clone.outputs.write().await.insert(action.id, output);
}
Ok(())
});
}
while let Some(result) = join_set.join_next().await {
match result {
Ok(Ok(())) => {},
Ok(Err(err)) => {
join_set.abort_all();
if let ActionError::ActionFailed(_, long) = err.clone() {
println!("{}", long);
}
return Err(err)
},
Err(_) => {
join_set.abort_all();
return Err(ActionError::InternalError("JOIN_SET_ERROR"))
}
}
}
Ok(())
}
pub async fn rollback(self) -> Result<(), ActionError> {
let actions = self.ctx.actions.clone();
let mut dependencies: HashMap<Id, Vec<Id>> = HashMap::new();
for action in actions.iter() {
if !action.probe(self.clone()).await?.can_rollback {
return Err(ActionError::InternalError("NO_ROLLBACK"))
}
}
for action in actions.iter() {
dependencies.insert(action.id, Vec::new());
action.deps()
.iter()
.map(|dep| dep.id())
.for_each(|id| {
let deps = dependencies.entry(id).or_insert(Vec::new());
deps.push(action.id);
});
}
let mut actions = actions;
actions.sort_by(|a, b| {
let a_deps = dependencies.get(&a.id).unwrap();
let b_deps = dependencies.get(&b.id).unwrap();
if a_deps.contains(&b.id) {
return std::cmp::Ordering::Greater
}
if b_deps.contains(&a.id) {
return std::cmp::Ordering::Less
}
std::cmp::Ordering::Equal
});
let mut join_set: JoinSet<Result<(), ActionError>> = JoinSet::new();
for action in actions {
let runtime_clone = self.clone();
join_set.spawn(async move {
action.run(runtime_clone.clone(), Operation::Rollback).await?;
Ok(())
});
}
while let Some(result) = join_set.join_next().await {
match result {
Ok(Ok(())) => {},
Ok(Err(err)) => {
join_set.abort_all();
if let ActionError::ActionFailed(_, long) = err.clone() {
println!("{}", long);
}
return Err(err)
},
Err(_) => {
join_set.abort_all();
return Err(ActionError::InternalError("JOIN_SET_ERROR"))
}
}
}
Ok(())
}
pub async fn get_output(&self, obj: ActionObject) -> Option<ActionOutput> {
self.outputs.read().await.get(&obj.id()).cloned()
}
pub fn get_state<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
self.state.get(&TypeId::of::<T>()).cloned().map(|state| {
state.downcast::<T>().unwrap()
})
}
}
/// Collects actions and shared state, then turns them into a [`Runtime`].
pub struct RuntimeBuilder {
    /// Actions registered so far via `add_action`.
    ctx: Context,
    /// State values keyed by the `TypeId` of the stored value.
    state: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
impl RuntimeBuilder {
    /// Creates a builder with an empty [`Context`] and no registered state.
    pub fn new() -> Self {
        let ctx = Context::new();
        let state = HashMap::new();
        Self { ctx, state }
    }

    /// Registers `action`. The action is first given mutable access to this
    /// builder via `load_state`, and is then added to the context.
    pub async fn add_action(mut self, action: ActionObject) -> Self {
        action.load_state(&mut self).await;
        self.ctx.add_action(action);
        self
    }

    /// Consumes the builder, producing a [`Runtime`] with no barriers and no outputs yet.
    pub fn build(self) -> Runtime {
        let Self { ctx, state } = self;
        let outputs = Arc::new(RwLock::new(HashMap::new()));
        Runtime {
            ctx,
            barriers: HashMap::new(),
            outputs,
            state,
        }
    }

    /// Stores `state` under its own type, replacing any earlier value of type `T`.
    pub fn add_state<T: Send + Sync + 'static>(&mut self, state: T) -> &mut Self {
        let key = TypeId::of::<T>();
        self.state.insert(key, Arc::new(state));
        self
    }
}
impl Default for RuntimeBuilder {
fn default() -> Self {
Self::new()
}
}