use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use super::erased::{ErasedPipeline, SpawnedTask};
use super::store::{TaskError, TaskStore};
pub struct Runner<S: TaskStore> {
store: Arc<S>,
pipelines: HashMap<&'static str, Arc<dyn ErasedPipeline>>,
poll_interval: Duration,
max_concurrent: usize,
pipeline_semaphores: HashMap<&'static str, Arc<tokio::sync::Semaphore>>,
}
impl<S: TaskStore + 'static> Runner<S> {
    /// Serializes `input` to JSON and enqueues it as a new task for `pipeline`.
    ///
    /// # Errors
    ///
    /// Returns [`TaskError::SerializationError`] if `input` cannot be converted to a
    /// JSON value, otherwise whatever error the store's `enqueue` reports.
    pub async fn submit<T: serde::Serialize>(
        &self,
        pipeline: &str,
        input: T,
    ) -> Result<super::store::TaskId, TaskError> {
        let json = serde_json::to_value(input)
            .map_err(|e| TaskError::SerializationError(e.to_string()))?;
        self.store.enqueue(pipeline, json).await
    }

    /// Runs the worker loop forever.
    ///
    /// Calls the store's `recover_orphans` once, then on every tick:
    /// 1. for each pipeline with a dedicated limit, claims one task per free permit of
    ///    that pipeline's semaphore;
    /// 2. claims up to the number of free global permits from all *other* pipelines
    ///    (pipelines with a dedicated limit are excluded and do not count against
    ///    `max_concurrent`);
    /// 3. spawns every claimed task, then sleeps for `poll_interval`.
    ///
    /// Store errors while recovering, claiming, enqueueing spawned children, or
    /// recording an outcome are ignored; the loop keeps polling.
    pub async fn run(&self) -> ! {
        let _ = self.store.recover_orphans().await;
        let global_semaphore = Arc::new(Semaphore::new(self.max_concurrent));
        // The registry is immutable while running, so share one copy across all spawned
        // tasks instead of cloning the whole map for each of them.
        let pipelines = Arc::new(self.pipelines.clone());
        // Loop-invariant: pipelines served by their own semaphore.
        let excluded: Vec<&str> = self.pipeline_semaphores.keys().copied().collect();

        loop {
            let mut claimed = Vec::new();

            // Keep claiming while the pipeline still has free permits, so a limit of N can
            // actually start up to N tasks per tick (not just one).
            for (pipeline_name, pipeline_sem) in &self.pipeline_semaphores {
                while let Ok(permit) = Arc::clone(pipeline_sem).try_acquire_owned() {
                    match self.store.claim_for_pipeline(pipeline_name, 1).await {
                        Ok(mut tasks) => match tasks.pop() {
                            Some(task) => claimed.push((task, permit)),
                            // Queue drained; the unused permit is released on break.
                            None => break,
                        },
                        Err(_) => break,
                    }
                }
            }

            let global_available = global_semaphore.available_permits();
            if global_available > 0 {
                if let Ok(tasks) = self
                    .store
                    .claim_excluding(global_available, &excluded)
                    .await
                {
                    for task in tasks {
                        // Only this loop acquires global permits, so at least
                        // `global_available` are still free here.
                        if let Ok(permit) = Arc::clone(&global_semaphore).try_acquire_owned() {
                            claimed.push((task, permit));
                        }
                    }
                }
            }

            for (task, permit) in claimed {
                let store = Arc::clone(&self.store);
                let task_pipelines = Arc::clone(&pipelines);
                tokio::spawn(async move {
                    // Hold the concurrency permit until the outcome has been recorded.
                    let _permit = permit;
                    match Self::execute_task(&task_pipelines, store.as_ref(), &task).await {
                        Ok(spawned) => {
                            for sp in spawned {
                                let _ = store.enqueue(sp.pipeline, sp.input).await;
                            }
                            let _ = store.complete(task.id).await;
                        }
                        Err(e) => {
                            let _ = store.fail(task.id, &e.to_string()).await;
                        }
                    }
                });
            }

            tokio::time::sleep(self.poll_interval).await;
        }
    }

    /// Looks up the pipeline named by `task` and runs it on a clone of the task's input.
    ///
    /// # Errors
    ///
    /// Returns [`TaskError::UnknownPipeline`] if no pipeline with that name is
    /// registered, otherwise whatever error the pipeline's `run_erased` produces.
    async fn execute_task(
        pipelines: &HashMap<&'static str, Arc<dyn ErasedPipeline>>,
        _store: &S,
        task: &super::store::StoredTask,
    ) -> Result<Vec<SpawnedTask>, TaskError> {
        let pipeline = pipelines
            .get(task.pipeline.as_str())
            .ok_or_else(|| TaskError::UnknownPipeline(task.pipeline.clone()))?;
        pipeline.run_erased(task.input.clone()).await
    }
}
/// Configures and constructs a [`Runner`].
pub struct RunnerBuilder<S: TaskStore> {
    /// Store handed to the runner; wrapped in an `Arc` by `build`.
    store: S,
    /// Pipelines to register, keyed by pipeline name.
    pipelines: HashMap<&'static str, Arc<dyn ErasedPipeline>>,
    /// Dedicated concurrency caps, keyed by pipeline name.
    pipeline_concurrency: HashMap<&'static str, usize>,
    /// Cap for pipelines that have no dedicated cap.
    max_concurrent: usize,
    /// Delay between polling rounds.
    poll_interval: Duration,
}
impl<S: TaskStore + 'static> RunnerBuilder<S> {
pub fn new(store: S) -> Self {
Self {
store,
pipelines: HashMap::new(),
poll_interval: Duration::from_secs(1),
max_concurrent: 1,
pipeline_concurrency: HashMap::new(),
}
}
pub fn pipeline(mut self, pipeline: impl ErasedPipeline + 'static) -> Self {
let name = pipeline.name();
self.pipelines.insert(name, Arc::new(pipeline));
self
}
pub fn pipeline_with_concurrency(
mut self,
pipeline: impl ErasedPipeline + 'static,
max_concurrent: usize,
) -> Self {
assert!(
max_concurrent > 0,
"pipeline concurrency must be at least 1"
);
let name = pipeline.name();
self.pipeline_concurrency.insert(name, max_concurrent);
self.pipelines.insert(name, Arc::new(pipeline));
self
}
pub fn poll_interval(mut self, interval: Duration) -> Self {
self.poll_interval = interval;
self
}
pub fn max_concurrent(mut self, n: usize) -> Self {
self.max_concurrent = n;
self
}
pub fn get_pipeline_graphs(&self) -> Vec<crate::pipeline::PipelineGraph> {
self.pipelines.values().map(|p| p.to_graph()).collect()
}
pub fn build(self) -> Runner<S> {
let pipeline_semaphores = self
.pipeline_concurrency
.into_iter()
.map(|(name, limit)| (name, Arc::new(tokio::sync::Semaphore::new(limit))))
.collect();
Runner {
store: Arc::new(self.store),
pipelines: self.pipelines,
poll_interval: self.poll_interval,
max_concurrent: self.max_concurrent,
pipeline_semaphores,
}
}
}