use std::collections::HashMap;
use std::mem;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use crate::entities::error::AtexError;
use crate::entities::task::{
NewPeriodic, NewRepoTask, NewTask, Task, TaskId, TaskStarted, TaskStats, TaskStatus,
};
use crate::proto::{
DynFut, ILogger, IPool, IPoolFactory, ITaskManager, ITaskRepo, IWrappedTaskExecutor,
};
use crate::utils::now;
use crate::ITaskController;
/// Handle to the task manager.
///
/// All state lives behind a single `Arc`, so cloning a `TaskManager` yields
/// another handle onto the same executors, pool factory, repository and logger.
pub(crate) struct TaskManager<P, R, L>
where
    P: IPoolFactory<TaskId, JoinHandle<TaskStatus>>,
    R: ITaskRepo,
    L: ILogger,
{
    data: Arc<TaskManagerData<P, R, L>>,
}
impl<P: IPoolFactory<TaskId, JoinHandle<TaskStatus>>, R: ITaskRepo, L: ILogger> Clone
for TaskManager<P, R, L>
{
fn clone(&self) -> Self {
Self {
data: self.data.clone(),
}
}
}
/// Shared state behind every `TaskManager` handle.
struct TaskManagerData<P, R, L>
where
    P: IPoolFactory<TaskId, JoinHandle<TaskStatus>>,
    R: ITaskRepo,
    L: ILogger,
{
    /// Registered executors, keyed by executor name.
    executors: HashMap<&'static str, Box<dyn IWrappedTaskExecutor + 'static>>,
    /// Reports the pool size and produces the worker pool used by the run loop.
    pool_factory: P,
    /// Task storage; every access goes through this async mutex.
    repo: Arc<Mutex<R>>,
    /// Destination for diagnostic messages.
    logger: Arc<L>,
    /// How long the run loop sleeps between polls when it is idle.
    sleep: Duration,
    /// One slot per pool cell: the executor and start time of the task that
    /// occupies the cell, or `None` while the cell is idle.
    last_report: Arc<Mutex<Vec<Option<TaskStarted>>>>,
}
/// State owned exclusively by the run loop.
struct MutCtx<P>
where
    P: IPool<TaskId, JoinHandle<TaskStatus>>,
{
    /// Cloned into every spawned task; a finished task sends its cell index here.
    sender: Sender<usize>,
    /// Polled by the run loop to learn which cell just finished.
    receiver: Receiver<usize>,
    /// Cells holding the id and join handle of each running task.
    pool: P,
}
// SAFETY: NOTE(review) — this is asserted manually, not proven by the compiler.
// `executors` stores `Box<dyn IWrappedTaskExecutor>` with no `Send` bound in this
// file (unless the trait declares `Send`/`Sync` as supertraits — not visible
// here), and `P`, `R`, `L` carry no `Send`/`Sync` bounds either. The struct is
// shared across tokio tasks through `Arc`, so this is only sound if every
// concrete executor, pool factory, repository and logger is actually safe to
// move between threads. Confirm, or replace with explicit `Send + Sync` bounds.
unsafe impl<P: IPoolFactory<TaskId, JoinHandle<TaskStatus>>, R: ITaskRepo, L: ILogger> Send
for TaskManagerData<P, R, L>
{
}
// SAFETY: NOTE(review) — same caveat as the `Send` impl above: `&TaskManagerData`
// is used concurrently (executors are called and `pool_factory` is read through
// shared references), which is only sound if those concrete types tolerate
// concurrent shared access. Not enforced by any bound here — confirm.
unsafe impl<P: IPoolFactory<TaskId, JoinHandle<TaskStatus>>, R: ITaskRepo, L: ILogger> Sync
for TaskManagerData<P, R, L>
{
}
impl<P, R, L> TaskManager<P, R, L>
where
    P: IPoolFactory<TaskId, JoinHandle<TaskStatus>>,
    R: ITaskRepo,
    L: ILogger,
{
    /// Builds a manager from its collaborators.
    ///
    /// The per-cell activity report starts out empty; it is sized to the pool
    /// once the run loop starts.
    pub(crate) fn new(
        executors: HashMap<&'static str, Box<dyn IWrappedTaskExecutor + 'static>>,
        pool_factory: P,
        repo: Arc<Mutex<R>>,
        logger: Arc<L>,
        sleep: Duration,
    ) -> Self {
        let last_report = Arc::new(Mutex::new(Vec::new()));
        Self {
            data: Arc::new(TaskManagerData {
                executors,
                pool_factory,
                repo,
                logger,
                sleep,
                last_report,
            }),
        }
    }
}
impl<
P: IPoolFactory<TaskId, JoinHandle<TaskStatus>> + 'static,
R: ITaskRepo + 'static,
L: ILogger + 'static,
> TaskManager<P, R, L>
{
async fn a_create_task(self, task: NewTask) -> Result<TaskId, AtexError> {
let lock = self.gen_lock(&task)?;
let db_task = NewRepoTask {
periodic_interval: None,
executor: task.executor,
payload: task.payload,
created: now(),
updated: now(),
execute_after: 0,
lock,
};
let mut db = self.data.repo.lock().await;
db.create(db_task)
}
fn gen_lock(&self, task: &NewTask) -> Result<String, AtexError> {
let ex_name: &str = &task.executor;
let executor = match self.data.executors.get(&ex_name) {
None => return Err(AtexError::InvalidExecutor),
Some(val) => val,
};
Ok(executor.lock_key(&task.payload))
}
async fn create_periodic(&self, task: NewPeriodic) -> Result<TaskId, AtexError> {
let lock = task.executor.to_string();
let db_task = NewRepoTask {
periodic_interval: Some(task.interval),
executor: task.executor,
payload: "".to_string(),
created: now(),
updated: now(),
execute_after: now(),
lock,
};
let mut db = self.data.repo.lock().await;
if let Some(task) = db.get_by_executor(&db_task.executor)? {
Ok(task.id)
} else {
Ok(db.create(db_task)?)
}
}
async fn a_get(self, id: TaskId) -> Result<Option<Task>, AtexError> {
let db = self.data.repo.lock().await;
db.get_by_id(id)
}
fn prepare(&self) -> MutCtx<P::Pool> {
let len = self.data.pool_factory.len();
let (sender, receiver) = mpsc::channel::<usize>(len);
let pool = self.data.pool_factory.produce();
MutCtx {
sender,
receiver,
pool,
}
}
async fn init_periodics(&self) -> Result<(), AtexError> {
let mut periodics = Vec::new();
for executor in self.data.executors.values() {
if let Some(ts) = executor.periodic_interval() {
periodics.push(NewPeriodic {
executor: executor.name(),
interval: ts,
});
}
}
for periodic in periodics {
self.create_periodic(periodic).await?;
}
Ok(())
}
async fn a_run(self) -> Result<(), AtexError> {
self.init_periodics().await?;
let mut ctx = self.prepare();
{
let mut mtx = self.data.last_report.lock().await;
*mtx = (0..ctx.pool.len()).map(|_| None).collect();
}
loop {
let started = self.loop_iteration(&mut ctx).await?;
if started == 0 {
tokio::time::sleep(self.data.sleep).await
}
}
}
async fn loop_iteration(&self, ctx: &mut MutCtx<P::Pool>) -> Result<usize, AtexError> {
let mut placed = Vec::new();
let mut removed = None;
let result = match ctx.receiver.try_recv() {
Ok(id) => {
removed = Some(id);
self.clear_task_in_cell(ctx, id).await?;
self.fill_empty_cells(ctx, &mut placed).await?
}
Err(TryRecvError::Empty) => self.fill_empty_cells(ctx, &mut placed).await?,
Err(TryRecvError::Disconnected) => return Err(AtexError::TaskQueueIsBroken),
};
if removed.is_none() && placed.is_empty() {
return Ok(result);
}
let mut report = self.data.last_report.lock().await;
if let Some(i) = removed {
report[i] = None;
}
let started = now();
for (i, executor) in placed {
report[i] = Some(TaskStarted { executor, started });
}
Ok(result)
}
async fn clear_task_in_cell(
&self,
ctx: &mut MutCtx<P::Pool>,
id: usize,
) -> Result<(), AtexError> {
let (task_id, handler) = ctx.pool.get_cell(id)?;
let status = match handler.await {
Ok(val) => val,
Err(err) => {
self.data.logger.log(&format!("{:?}", err));
return Err(AtexError::TaskJoinMechanismFailed);
}
};
let mut db = self.data.repo.lock().await;
let task = match db.get_by_id(task_id)? {
None => {
self.data.logger.log("Task has beed removed");
return Ok(());
}
Some(val) => val,
};
let log_msg = if let TaskStatus::Error(ref err) = status {
Some(format!(
"task (id={}, executor={:?}) failed: {:?}",
task.id, task.executor, err,
))
} else {
None
};
if let Some(interval) = task.periodic_interval {
let ts = now() + interval;
db.set_execute_after(task_id, ts)?;
} else {
db.set_status(task_id, status)?
}
mem::drop(db);
if let Some(msg) = log_msg {
self.data.logger.log(&msg)
}
Ok(())
}
async fn fill_empty_cells(
&self,
ctx: &mut MutCtx<P::Pool>,
placed: &mut Vec<(usize, &'static str)>,
) -> Result<usize, AtexError> {
let exclude = ctx.pool.get_keys();
let free_cells = ctx.pool.free_cells();
if free_cells.is_empty() {
return Ok(0);
}
let db = self.data.repo.lock().await;
let tasks = db.get_ready(&exclude, free_cells.len())?;
mem::drop(db);
let started = tasks.len();
for (i, task) in tasks.into_iter().enumerate() {
let cell = free_cells[i];
let id = task.id;
let executor = task.executor.clone();
let handler = match self.start_task(task, ctx.sender.clone(), cell) {
Ok(val) => val,
Err(AtexError::ExecutorNotFound) => {
let status = TaskStatus::Error("Can not find executor".to_string());
self.data.repo.lock().await.set_status(id, status)?;
continue;
}
Err(err) => return Err(err),
};
ctx.pool.put_cell(cell, id, handler)?;
placed.push((cell, executor));
}
Ok(started)
}
fn start_task(
&self,
task: Task,
signal: Sender<usize>,
cell_id: usize,
) -> Result<JoinHandle<TaskStatus>, AtexError> {
let ex_name = task.executor.clone();
let executor = match self.data.executors.get(&(&ex_name as &str)) {
Some(val) => val,
None => {
self.data
.logger
.log(&format!("Can not find executor: {}", task.executor));
return Err(AtexError::ExecutorNotFound);
}
};
let ctrl = Box::new(Self {
data: self.data.clone(),
});
let task_future = executor.execute(ctrl, task);
let future = async move {
let status = task_future.await;
if let Err(err) = signal.send(cell_id).await {
panic!("executor failed with channel error: {:?}", err)
}
status
};
Ok(tokio::task::spawn(future))
}
async fn a_set_progress(self, id: TaskId, progress: u8) -> Result<(), AtexError> {
let mut repo = self.data.repo.lock().await;
repo.set_progress(id, progress)?;
Ok(())
}
async fn a_execute_periodic_now(self, executor: &'static str) -> Result<(), AtexError> {
let mut repo = self.data.repo.lock().await;
let task = match repo.get_by_executor(executor)? {
None => return Err(AtexError::InvalidName),
Some(val) => val,
};
repo.set_execute_after(task.id, now() - 1)?;
Ok(())
}
async fn a_manual_execute_periodic(
self,
executor: &'static str,
) -> Result<TaskStatus, AtexError> {
self.init_periodics().await?;
let opt_task = self.data.repo.lock().await.get_by_executor(executor)?;
match opt_task {
None => return Err(AtexError::InvalidExecutor),
Some(task) => self.manual_exec_task(task).await,
}
}
async fn a_manual_execute_task(self, task_id: TaskId) -> Result<TaskStatus, AtexError> {
let opt_task = self.data.repo.lock().await.get_by_id(task_id)?;
match opt_task {
None => return Err(AtexError::InvalidId),
Some(task) => self.manual_exec_task(task).await,
}
}
async fn manual_exec_task(&self, task: Task) -> Result<TaskStatus, AtexError> {
let ex_name = task.executor.clone();
let executor = match self.data.executors.get(&(&ex_name as &str)) {
Some(val) => val,
None => {
self.data
.logger
.log(&format!("Can not find executor: {}", task.executor));
return Err(AtexError::ExecutorNotFound);
}
};
let ctrl = Box::new(Self {
data: self.data.clone(),
});
Ok(executor.execute(ctrl, task).await)
}
fn name_to_task_key(&self, name: &str) -> Option<&'static str> {
let exec = self.data.executors.get(name)?;
Some(exec.name())
}
async fn a_get_stats(self) -> Vec<Option<TaskStats>> {
let now_time = now();
let started = self.data.last_report.lock().await;
let mut result = Vec::with_capacity(started.len());
for opt_task in started.iter() {
if let Some(task) = opt_task {
let stats = TaskStats {
executor: task.executor,
active_for: now_time - task.started,
};
result.push(Some(stats))
} else {
result.push(None)
}
}
result
}
}
/// Public task-manager API. Each call boxes a future that owns its own clone
/// of the handle.
impl<P, R, L> ITaskManager for TaskManager<P, R, L>
where
    P: IPoolFactory<TaskId, JoinHandle<TaskStatus>> + 'static,
    R: ITaskRepo + 'static,
    L: ILogger + 'static,
{
    /// Stores a new one-off task.
    fn create_task(&self, task: NewTask) -> DynFut<Result<TaskId, AtexError>> {
        let this = self.clone();
        Box::pin(this.a_create_task(task))
    }

    /// Makes the named periodic task due right away. Unknown names resolve to
    /// `AtexError::InvalidName`.
    fn schedule_periodic_now(&self, name: &str) -> DynFut<Result<(), AtexError>> {
        match self.name_to_task_key(name) {
            Some(key) => Box::pin(self.clone().a_execute_periodic_now(key)),
            None => Box::pin(async move { Err(AtexError::InvalidName) }),
        }
    }

    /// Runs the named periodic task inline. Unknown names resolve to
    /// `AtexError::InvalidName`.
    fn manual_execute_periodic(&self, name: &str) -> DynFut<Result<TaskStatus, AtexError>> {
        match self.name_to_task_key(name) {
            Some(key) => Box::pin(self.clone().a_manual_execute_periodic(key)),
            None => Box::pin(async move { Err(AtexError::InvalidName) }),
        }
    }

    /// Runs a stored task inline by id.
    fn manual_execute_task(&self, task_id: TaskId) -> DynFut<Result<TaskStatus, AtexError>> {
        let this = self.clone();
        Box::pin(this.a_manual_execute_task(task_id))
    }

    /// Fetches a task by id.
    fn get(&self, id: TaskId) -> DynFut<Result<Option<Task>, AtexError>> {
        let this = self.clone();
        Box::pin(this.a_get(id))
    }

    /// Starts the scheduling loop.
    fn run(&self) -> DynFut<Result<(), AtexError>> {
        let this = self.clone();
        Box::pin(this.a_run())
    }

    /// Per-cell snapshot of running tasks.
    fn get_stats(&self) -> DynFut<Vec<Option<TaskStats>>> {
        let this = self.clone();
        Box::pin(this.a_get_stats())
    }
}
/// Controller API handed to executors while they run a task.
impl<P, R, L> ITaskController for TaskManager<P, R, L>
where
    P: IPoolFactory<TaskId, JoinHandle<TaskStatus>> + 'static,
    R: ITaskRepo + 'static,
    L: ILogger + 'static,
{
    /// Stores `progress` for task `id` in the repository.
    fn set_progress(&self, id: TaskId, progress: u8) -> DynFut<Result<(), AtexError>> {
        let this = self.clone();
        Box::pin(this.a_set_progress(id, progress))
    }

    /// Stores a new one-off task.
    fn create_task(&self, task: NewTask) -> DynFut<Result<TaskId, AtexError>> {
        let this = self.clone();
        Box::pin(this.a_create_task(task))
    }
}