vorma 0.86.0-pre.3

Vorma framework.
Documentation
use std::future::Future;
use std::sync::Arc;

use vorma_tasks::{CancelToken, ExecCtx};

use super::error::Error;

/// Shared table of cancellation tokens, indexed by ordered task position.
///
/// Entry `i` holds exactly the tokens that `cancel_later(i)` cancels. The
/// table lives behind an [`Arc`], so the derived `Clone` only bumps a
/// reference count and every clone observes the same table.
#[derive(Clone)]
pub(in crate::mux) struct OrderedCancellation {
	later_tokens_by_index: Arc<Vec<Vec<CancelToken>>>,
}

impl OrderedCancellation {
	/// Moves a per-index token table into shared ownership.
	fn new(table: Vec<Vec<CancelToken>>) -> Self {
		let later_tokens_by_index = Arc::new(table);
		Self { later_tokens_by_index }
	}

	/// Cancels every token registered for `index`, in table order.
	///
	/// An `index` past the end of the table is silently ignored.
	pub(in crate::mux) fn cancel_later(&self, index: usize) {
		// `None` (out of range) flattens to zero tokens, so no explicit branch.
		let registered = self.later_tokens_by_index.get(index);
		for token in registered.into_iter().flatten() {
			token.cancel();
		}
	}
}

/// A batch of per-task execution contexts sharing one [`OrderedCancellation`]
/// table.
///
/// Built either as siblings under a common parent
/// ([`OrderedTaskContexts::sibling_children`]) or as a parent-to-descendant
/// chain ([`OrderedTaskContexts::descendant_chain`]).
pub(in crate::mux) struct OrderedTaskContexts<E> {
	contexts: Vec<OrderedTaskCtx<E>>,
	cancellation: OrderedCancellation,
}

impl<E> OrderedTaskContexts<E>
where
	E: Send + Sync + 'static,
{
	/// Creates `count` child contexts of `parent`.
	///
	/// For task `i`, `cancel_later` cancels the tokens of every child created
	/// after it (positions `i + 1..count`). The last task has nothing to cancel.
	pub(in crate::mux) fn sibling_children(parent: &ExecCtx<E>, count: usize) -> Self {
		// All children are created before any token is read.
		let children: Vec<_> = (0..count).map(|_| parent.child()).collect();
		let mut tokens = Vec::with_capacity(children.len());
		for child in &children {
			tokens.push(child.cancel_token().clone());
		}
		// Suffix of the token list after each position; the last suffix is empty.
		let suffixes: Vec<Vec<_>> = (0..tokens.len())
			.map(|position| tokens[position + 1..].to_vec())
			.collect();
		Self::from_parts(children, suffixes)
	}

	/// Creates a chain of `count` contexts.
	///
	/// Task 0 runs on `parent` itself, and every following task runs on a child
	/// of the previous task's context. For task `i` (except the last),
	/// `cancel_later` cancels the token of task `i + 1`'s context. Whether that
	/// also reaches deeper descendants depends on `vorma_tasks` token
	/// propagation (NOTE(review): confirm).
	pub(in crate::mux) fn descendant_chain(parent: ExecCtx<E>, count: usize) -> Self {
		let mut exec_ctxs = Vec::with_capacity(count);
		let mut later_tokens_by_index = Vec::with_capacity(count);
		let mut cursor = parent;
		for position in 0..count {
			exec_ctxs.push(cursor.clone());
			let is_last = position + 1 == count;
			if is_last {
				later_tokens_by_index.push(Vec::new());
			} else {
				let next = cursor.child();
				later_tokens_by_index.push(vec![next.cancel_token().clone()]);
				cursor = next;
			}
		}
		Self::from_parts(exec_ctxs, later_tokens_by_index)
	}

	/// Returns a handle to the batch's shared cancellation table.
	pub(in crate::mux) fn cancellation(&self) -> OrderedCancellation {
		OrderedCancellation::clone(&self.cancellation)
	}

	/// Consumes the batch and yields the per-task contexts in index order.
	pub(in crate::mux) fn into_contexts(self) -> Vec<OrderedTaskCtx<E>> {
		let Self { contexts, .. } = self;
		contexts
	}

	/// Pairs each execution context with its position and the shared table.
	fn from_parts(
		exec_ctxs: Vec<ExecCtx<E>>,
		later_tokens_by_index: Vec<Vec<CancelToken>>,
	) -> Self {
		let cancellation = OrderedCancellation::new(later_tokens_by_index);
		let mut contexts = Vec::with_capacity(exec_ctxs.len());
		for (index, exec_ctx) in exec_ctxs.into_iter().enumerate() {
			contexts.push(OrderedTaskCtx {
				index,
				exec_ctx,
				cancellation: cancellation.clone(),
			});
		}
		Self { contexts, cancellation }
	}
}

/// Per-task handle: the task's position in its batch, the execution context it
/// runs under, and the batch's shared cancellation table.
#[derive(Clone)]
pub(in crate::mux) struct OrderedTaskCtx<E> {
	/// Position of this task within its batch.
	index: usize,
	/// Execution context this task runs under.
	exec_ctx: ExecCtx<E>,
	/// Cancellation table shared by every task in the batch.
	cancellation: OrderedCancellation,
}

impl<E> OrderedTaskCtx<E> {
	/// Returns this task's position within its batch.
	pub(in crate::mux) fn index(&self) -> usize {
		let Self { index, .. } = self;
		*index
	}

	/// Returns a clone of this task's execution context.
	pub(in crate::mux) fn exec_ctx(&self) -> ExecCtx<E> {
		let Self { exec_ctx, .. } = self;
		exec_ctx.clone()
	}

	/// Cancels the tokens registered for this task's position in the shared
	/// table.
	pub(in crate::mux) fn cancel_later(&self) {
		let Self {
			index, cancellation, ..
		} = self;
		cancellation.cancel_later(*index);
	}
}

pub(in crate::mux) async fn run_ordered_parallel<I, O, E, F, Fut>(
	items: Vec<I>,
	contexts: OrderedTaskContexts<E>,
	run_task: F,
) -> Result<OrderedParallelOutputs<O>, Error>
where
	I: Send + 'static,
	O: Send + 'static,
	E: Send + Sync + 'static,
	F: Fn(I, OrderedTaskCtx<E>) -> Fut + Clone + Send + Sync + 'static,
	Fut: Future<Output = O> + Send + 'static,
{
	let task_contexts = contexts.into_contexts();
	if items.len() != task_contexts.len() {
		return Err(Error::Invariant(format!(
			"ordered parallel item count {} does not match task context count {}",
			items.len(),
			task_contexts.len()
		)));
	}

	let mut handles = Vec::with_capacity(task_contexts.len());
	for (item, task_ctx) in items.into_iter().zip(task_contexts) {
		let run_task = run_task.clone();
		let index = task_ctx.index();
		handles.push(tokio::spawn(async move {
			let output = run_task(item, task_ctx).await;
			(index, output)
		}));
	}

	let mut outputs = Vec::with_capacity(handles.len());
	for handle in handles {
		outputs.push(
			handle
				.await
				.map_err(|error| Error::TaskJoin(error.to_string()))?,
		);
	}
	outputs.sort_by_key(|(index, _)| *index);
	Ok(OrderedParallelOutputs { outputs })
}

/// Outputs of an ordered parallel run as `(task index, output)` pairs.
///
/// When built by `run_ordered_parallel`, the pairs are sorted by task index.
pub(in crate::mux) struct OrderedParallelOutputs<O> {
	outputs: Vec<(usize, O)>,
}

impl<O> OrderedParallelOutputs<O> {
	/// Consumes the wrapper and hands back the `(index, output)` pairs.
	pub(in crate::mux) fn into_vec(self) -> Vec<(usize, O)> {
		let Self { outputs } = self;
		outputs
	}
}