use std::{
any::Any,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
thread,
};
use dashmap::DashMap;
use mpsc::Sender;
use reifydb_core::{
interface::version::{ComponentType, HasVersion, SystemVersion},
util::ioc::IocContainer,
};
use reifydb_engine::engine::StandardEngine;
use reifydb_runtime::SharedRuntime;
use reifydb_sub_api::subsystem::{HealthStatus, Subsystem};
use reifydb_value::Result;
use tokio::{sync::mpsc, task::JoinHandle};
use tracing::{info, instrument};
use crate::{
coordinator,
coordinator::TaskCoordinatorMessage,
handle::TaskHandle,
registry::{TaskEntry, TaskRegistry},
task::ScheduledTask,
};
/// Subsystem that hosts the periodic task scheduler.
///
/// Built stopped by `TaskSubsystem::new`. `start` seeds the registry and spawns the
/// coordinator on the shared runtime; `shutdown` asks the coordinator to stop and waits for it.
///
/// NOTE: fields are dropped in declaration order; keep that in mind before reordering them.
pub struct TaskSubsystem {
// `true` once `start` has finished; flipped back to `false` at the very beginning of `shutdown`.
running: AtomicBool,
// Handle given out to callers via `handle()`; `None` before `start` and after `shutdown`.
handle: Option<TaskHandle>,
// Sender for coordinator messages; retained so `shutdown` can send `TaskCoordinatorMessage::Shutdown`.
coordinator_tx: Option<Sender<TaskCoordinatorMessage>>,
// Join handle of the spawned `coordinator::run_coordinator` task; awaited during `shutdown`.
coordinator_handle: Option<JoinHandle<()>>,
// Runtime used to spawn the coordinator and to read the clock when scheduling initial tasks.
runtime: SharedRuntime,
// Engine passed through to the coordinator.
engine: StandardEngine,
// Shared task registry; cloned into both the `TaskHandle` and the coordinator. Not cleared by `shutdown`.
registry: TaskRegistry,
// Tasks supplied at construction; drained into `registry` by the first `start`.
initial_tasks: Vec<ScheduledTask>,
}
impl TaskSubsystem {
	/// Creates a stopped task subsystem wired to services from `ioc`.
	///
	/// `initial_tasks` are kept aside and only registered when the subsystem is started.
	///
	/// # Panics
	///
	/// Panics if `SharedRuntime` (checked first) or `StandardEngine` is not registered
	/// in `ioc`.
	#[instrument(name = "task::subsystem::new", level = "debug", skip(ioc, initial_tasks))]
	pub fn new(ioc: &IocContainer, initial_tasks: Vec<ScheduledTask>) -> Self {
		let runtime = ioc.resolve::<SharedRuntime>().expect("SharedRuntime not registered in IoC");
		let engine = ioc.resolve::<StandardEngine>().expect("StandardEngine not registered in IoC");

		Self {
			// Collaborators resolved from the container.
			runtime,
			engine,
			// Fresh, empty registry; populated on start.
			registry: Arc::new(DashMap::new()),
			initial_tasks,
			// Lifecycle state: nothing is running yet.
			running: AtomicBool::new(false),
			handle: None,
			coordinator_tx: None,
			coordinator_handle: None,
		}
	}

	/// Returns a clone of the task handle.
	///
	/// This is `None` until the subsystem has been started, and again after it has
	/// been shut down.
	pub fn handle(&self) -> Option<TaskHandle> {
		self.handle.clone()
	}
}
impl Subsystem for TaskSubsystem {
	/// Identifier reported for this subsystem.
	fn name(&self) -> &'static str {
		"sub-task"
	}

	/// Starts the scheduler.
	///
	/// Returns `Ok(())` immediately if the subsystem is already running. Otherwise it:
	/// 1. moves every pending initial task into the registry, due at the runtime clock's
	///    current instant plus the task schedule's initial delay;
	/// 2. creates the coordinator channel and a [`TaskHandle`] over the shared registry;
	/// 3. spawns `coordinator::run_coordinator` on the shared runtime;
	/// 4. records the handle, sender and join handle, and only then marks the subsystem running.
	///
	/// Initial tasks are consumed by the first start; a later start does not re-seed them.
	#[instrument(name = "task::subsystem::start", level = "debug", skip(self))]
	fn start(&mut self) -> Result<()> {
		// Bounded capacity of the coordinator's inbound message queue.
		const COORDINATOR_CHANNEL_CAPACITY: usize = 100;

		let already_running = self.running.load(Ordering::Acquire);
		if already_running {
			return Ok(());
		}

		info!("Starting task subsystem");

		let (tx, rx) = mpsc::channel(COORDINATOR_CHANNEL_CAPACITY);

		// Seed the registry; the clock is read per task, exactly when each one is registered.
		for pending in self.initial_tasks.drain(..) {
			let due_at = self.runtime.clock().instant() + pending.schedule.initial_delay();
			let id = pending.id;
			let entry = TaskEntry {
				task: Arc::new(pending),
				next_execution: due_at,
			};
			self.registry.insert(id, entry);
		}

		let task_handle = TaskHandle::new(self.registry.clone(), tx.clone());

		let coordinator_registry = self.registry.clone();
		let coordinator_runtime = self.runtime.clone();
		let coordinator_engine = self.engine.clone();
		let coordinator_join = self.runtime.spawn(async move {
			coordinator::run_coordinator(coordinator_registry, rx, coordinator_runtime, coordinator_engine)
				.await;
		});

		self.handle = Some(task_handle);
		self.coordinator_tx = Some(tx);
		self.coordinator_handle = Some(coordinator_join);
		// Published last so observers never see "running" before the coordinator exists.
		self.running.store(true, Ordering::Release);

		info!("Task subsystem started");
		Ok(())
	}

	/// Stops the scheduler.
	///
	/// Returns `Ok(())` without doing anything unless the subsystem is currently running.
	/// Otherwise it sends `TaskCoordinatorMessage::Shutdown` to the coordinator, waits for the
	/// coordinator task to finish, and clears the task handle. Failures to send, to join the
	/// coordinator, or to join the helper thread are ignored. The registry is left as-is.
	#[instrument(name = "task::subsystem::shutdown", level = "debug", skip(self))]
	fn shutdown(&mut self) -> Result<()> {
		// Only the caller that actually flips `running` from true to false does the teardown.
		let Ok(_) = self.running.compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire) else {
			return Ok(());
		};

		info!("Shutting down task subsystem");

		let tx = self.coordinator_tx.take();
		let join = self.coordinator_handle.take();
		let runtime = self.runtime.clone();

		// tokio's `Sender::blocking_send` panics inside an async execution context (and
		// `block_on` presumably has the same restriction), so the blocking teardown runs on a
		// dedicated OS thread that this call then waits for.
		let teardown = thread::spawn(move || {
			if let Some(sender) = tx {
				let _ = sender.blocking_send(TaskCoordinatorMessage::Shutdown);
			}
			if let Some(coordinator) = join {
				let _ = runtime.block_on(coordinator);
			}
		});
		let _ = teardown.join();

		self.handle = None;
		info!("Task subsystem shut down");
		Ok(())
	}

	/// Whether `start` has completed and `shutdown` has not yet begun.
	#[instrument(name = "task::subsystem::is_running", level = "trace", skip(self))]
	fn is_running(&self) -> bool {
		self.running.load(Ordering::Acquire)
	}

	/// `Healthy` while running, `Unknown` otherwise.
	#[instrument(name = "task::subsystem::health_status", level = "debug", skip(self))]
	fn health_status(&self) -> HealthStatus {
		match self.is_running() {
			true => HealthStatus::Healthy,
			false => HealthStatus::Unknown,
		}
	}

	/// Exposes `self` for downcasting.
	fn as_any(&self) -> &dyn Any {
		self
	}

	/// Exposes `self` mutably for downcasting.
	fn as_any_mut(&mut self) -> &mut dyn Any {
		self
	}
}
impl HasVersion for TaskSubsystem {
	/// Reports version metadata for this crate.
	///
	/// The component name is the Cargo package name with a leading `reifydb-` removed; if the
	/// prefix is absent, the full package name is used unchanged.
	fn version(&self) -> SystemVersion {
		let package = env!("CARGO_PKG_NAME");
		let short_name = package.strip_prefix("reifydb-").unwrap_or(package);

		SystemVersion {
			name: short_name.to_owned(),
			version: String::from(env!("CARGO_PKG_VERSION")),
			description: "Periodic task scheduler subsystem".into(),
			r#type: ComponentType::Subsystem,
		}
	}
}