use std::sync::{Arc, OnceLock};
use std::time::Duration;
use tokio::sync::broadcast;
use uni_plugin::PluginRegistry;
use uni_plugin::circuit_breaker::{BreakerConfig, CircuitBreaker};
use uni_plugin::plugin::PluginId;
use uni_plugin::qname::QName;
use uni_plugin::scheduler::{MemoryPersistence, Scheduler, SchedulerPersistence};
use uni_plugin::traits::background::{BackgroundJobProvider, JobContext, JobHost};
use uni_store::storage::manager::StorageManager;
use crate::host::HostCypherExecutor;
use crate::shutdown::ShutdownHandle;
/// Tick period that `spawn_with_memory_persistence` passes to
/// `SchedulerHost::spawn`: 100 milliseconds between driver ticks.
pub const DEFAULT_TICK_INTERVAL: Duration = Duration::from_millis(100);
/// Handle to a running scheduler, returned by `SchedulerHost::spawn` and
/// `SchedulerHost::spawn_with_job_host`.
///
/// Every field is an `Arc` (or an optional `Arc`) that `spawn_with_job_host`
/// also hands to the background driver task, so this handle and the driver
/// operate on the same scheduler, persistence backend and circuit breaker.
#[derive(Debug)]
pub struct SchedulerHost {
/// Scheduler that holds the registered jobs and reports which are due on each tick.
scheduler: Arc<Scheduler>,
/// Backend used to load saved jobs at startup and to record scheduled,
/// started and finished events as well as checkpoints.
persistence: Arc<dyn SchedulerPersistence>,
/// Breaker consulted before each dispatch and updated with each run's outcome.
circuit_breaker: Arc<CircuitBreaker>,
/// Optional host services offered to jobs; `None` when created through `spawn`.
job_host: Option<Arc<SchedulerJobHost>>,
}
/// Host-side services made available to background jobs.
///
/// Implements `JobHost`; the dispatcher attaches it to a job's `JobContext`
/// when a job host was supplied at spawn time.
pub struct SchedulerJobHost {
/// Storage manager used by `compact_storage` and exposed through `storage()`.
storage: Arc<StorageManager>,
/// Cypher executor, wired after construction via `set_host_executor`.
/// While unset, `execute_write_cypher` logs at debug level and returns `Ok(())`.
host_executor: OnceLock<Arc<dyn HostCypherExecutor>>,
}
// Hand-written `Debug`: prints only whether a host executor has been wired
// and marks the remaining fields as elided.
impl std::fmt::Debug for SchedulerJobHost {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let wired = self.host_executor.get().is_some();
        let mut out = f.debug_struct("SchedulerJobHost");
        out.field("host_executor_wired", &wired);
        out.finish_non_exhaustive()
    }
}
impl SchedulerJobHost {
    /// Creates a job host over `storage` with no host executor wired yet.
    #[must_use]
    pub fn new(storage: Arc<StorageManager>) -> Self {
        let host_executor = OnceLock::new();
        Self {
            storage,
            host_executor,
        }
    }

    /// Wires the Cypher executor used by `execute_write_cypher`.
    ///
    /// Only the first call takes effect: the executor is stored in a
    /// `OnceLock`, and the `Err` returned by any later `set` is discarded, so
    /// the originally wired executor stays in place.
    pub fn set_host_executor(&self, exec: Arc<dyn HostCypherExecutor>) {
        // A rejected `set` hands the new executor back inside `Err`; dropping
        // the result releases it without touching the stored one.
        drop(self.host_executor.set(exec));
    }

    /// Returns the shared storage manager this host was created with.
    #[must_use]
    pub fn storage(&self) -> &Arc<StorageManager> {
        let Self { storage, .. } = self;
        storage
    }
}
impl uni_plugin::scheduler::SchedulerControl for SchedulerHost {
    /// Registers `id` with `schedule`.
    ///
    /// The registration is first written to the persistence backend; if that
    /// write fails the error is logged at warn level and the job is still
    /// added to the in-memory scheduler.
    fn add_scheduled_job(&self, id: QName, schedule: uni_plugin::traits::background::Schedule) {
        let persisted = self.persistence.record_scheduled(&id, &schedule);
        if let Err(e) = persisted {
            tracing::warn!(
                qname = %id,
                error = %e,
                "SchedulerHost: record_scheduled failed; in-memory registration continues",
            );
        }
        self.scheduler.add_scheduled_job(id, schedule);
    }

    /// Forwards the cancellation request for `id` to the in-memory scheduler
    /// and returns its result.
    fn cancel(&self, id: &QName) -> bool {
        self.scheduler.cancel(id)
    }

    /// Returns the job records currently reported by the in-memory scheduler.
    fn list(&self) -> Vec<uni_plugin::scheduler::SchedulerJobRecord> {
        self.scheduler.list()
    }

    /// Runs `cypher` as a write statement through the wired job host.
    ///
    /// # Errors
    /// Returns error code `0xD21` when this host was spawned without a job
    /// host; otherwise propagates whatever the job host returns.
    fn submit_cypher(&self, cypher: &str) -> Result<(), uni_plugin::FnError> {
        match self.job_host.as_ref() {
            Some(host) => host.execute_write_cypher(cypher),
            None => Err(uni_plugin::FnError::new(
                0xD21,
                "submit_cypher: scheduler host has no JobHost wired",
            )),
        }
    }

    /// Asks the persistence backend to flush a checkpoint.
    ///
    /// # Errors
    /// A backend failure is wrapped into an `FnError` with code `0xD22`.
    fn flush_checkpoint(&self) -> Result<(), uni_plugin::FnError> {
        let flushed = self.persistence.flush_checkpoint();
        flushed.map_err(|e| uni_plugin::FnError::new(0xD22, format!("flush_checkpoint: {e}")))
    }
}
impl JobHost for SchedulerJobHost {
    /// Exposes the concrete host as `&dyn Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Runs a storage compaction to completion from synchronous code.
    ///
    /// The async `compact()` call is driven with `Handle::block_on` inside
    /// `block_in_place`; the returned statistics are discarded.
    ///
    /// NOTE(review): this requires a current Tokio runtime handle, and
    /// `block_in_place` is documented to reject current-thread runtimes when
    /// invoked from a runtime worker — confirm all callers satisfy that.
    ///
    /// # Errors
    /// A compaction failure is wrapped into an `FnError` with code `0xD11`.
    fn compact_storage(&self) -> Result<(), uni_plugin::FnError> {
        let storage = Arc::clone(&self.storage);
        // Futures are lazy: nothing runs until `block_on` polls this below.
        let compaction = async move { storage.compact().await };
        let outcome = tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(compaction)
        });
        match outcome {
            Ok(_stats) => Ok(()),
            Err(e) => Err(uni_plugin::FnError::new(
                0xD11,
                format!("compact_storage: {e}"),
            )),
        }
    }

    /// Executes `cypher` through the wired host executor.
    ///
    /// When no executor has been wired the call is a no-op: it logs at debug
    /// level and returns `Ok(())` without executing anything.
    ///
    /// # Errors
    /// An executor failure is wrapped into an `FnError` with code `0xD12`.
    fn execute_write_cypher(&self, cypher: &str) -> Result<(), uni_plugin::FnError> {
        match self.host_executor.get() {
            Some(exec) => exec.execute_write_cypher(cypher).map_err(|e| {
                uni_plugin::FnError::new(0xD12, format!("execute_write_cypher: {e}"))
            }),
            None => {
                tracing::debug!("execute_write_cypher: host executor not wired (shutdown race?)",);
                Ok(())
            }
        }
    }
}
impl SchedulerHost {
    /// Starts a scheduler driver that has no job host attached.
    ///
    /// Shorthand for [`Self::spawn_with_job_host`] with `job_host = None`.
    #[must_use]
    pub fn spawn(
        registry: Arc<PluginRegistry>,
        persistence: Arc<dyn SchedulerPersistence>,
        shutdown: &ShutdownHandle,
        tick_interval: Duration,
    ) -> Arc<Self> {
        Self::spawn_with_job_host(registry, persistence, shutdown, tick_interval, None)
    }

    /// Builds a scheduler, restores persisted jobs and spawns the driver task.
    ///
    /// Steps, in order:
    /// 1. Create a fresh scheduler and re-register every record returned by
    ///    `persistence.load_all()`, then requeue orphaned runs. If loading
    ///    fails, a warning is logged and the scheduler starts empty.
    /// 2. Resume the scheduler and create a circuit breaker with the default
    ///    configuration.
    /// 3. Spawn `driver_loop` on the Tokio runtime, subscribed to `shutdown`,
    ///    and register its join handle with `shutdown.track_task`.
    ///
    /// The returned handle shares the scheduler, persistence backend, circuit
    /// breaker and job host with the driver task.
    ///
    /// This calls `tokio::spawn`, so it has to run inside a Tokio runtime.
    #[must_use]
    pub fn spawn_with_job_host(
        registry: Arc<PluginRegistry>,
        persistence: Arc<dyn SchedulerPersistence>,
        shutdown: &ShutdownHandle,
        tick_interval: Duration,
        job_host: Option<Arc<SchedulerJobHost>>,
    ) -> Arc<Self> {
        let scheduler = Arc::new(Scheduler::new());
        match persistence.load_all() {
            Ok(records) => {
                for record in records {
                    scheduler.add_scheduled_job(record.id.clone(), record.schedule);
                }
                let requeued = scheduler.requeue_orphaned_runs();
                if requeued > 0 {
                    tracing::info!(
                        requeued,
                        "scheduler: requeued orphaned runs from previous shutdown"
                    );
                }
            }
            Err(e) => tracing::warn!(error = %e, "scheduler: load_all failed; starting empty"),
        }
        scheduler.resume();
        let circuit_breaker = Arc::new(CircuitBreaker::new(BreakerConfig::default()));

        // The handle keeps its own references; the originals are moved into
        // the driver below, so no second round of clones is needed.
        let host = Arc::new(Self {
            scheduler: Arc::clone(&scheduler),
            persistence: Arc::clone(&persistence),
            circuit_breaker: Arc::clone(&circuit_breaker),
            job_host: job_host.clone(),
        });

        let shutdown_rx = shutdown.subscribe();
        let handle = tokio::spawn(driver_loop(
            scheduler,
            persistence,
            registry,
            circuit_breaker,
            job_host,
            shutdown_rx,
            tick_interval,
        ));
        shutdown.track_task(handle);
        host
    }

    /// Returns the job host supplied at spawn time, if any.
    #[must_use]
    pub fn job_host(&self) -> Option<&Arc<SchedulerJobHost>> {
        self.job_host.as_ref()
    }

    /// Returns the circuit breaker shared with the driver task.
    #[must_use]
    pub fn circuit_breaker(&self) -> &Arc<CircuitBreaker> {
        &self.circuit_breaker
    }

    /// Returns the in-memory scheduler shared with the driver task.
    #[must_use]
    pub fn scheduler(&self) -> &Arc<Scheduler> {
        &self.scheduler
    }

    /// Returns the persistence backend shared with the driver task.
    #[must_use]
    pub fn persistence(&self) -> &Arc<dyn SchedulerPersistence> {
        &self.persistence
    }
}
#[must_use]
pub fn spawn_with_memory_persistence(
registry: Arc<PluginRegistry>,
shutdown: &ShutdownHandle,
) -> Arc<SchedulerHost> {
SchedulerHost::spawn(
registry,
Arc::new(MemoryPersistence),
shutdown,
DEFAULT_TICK_INTERVAL,
)
}
async fn driver_loop(
scheduler: Arc<Scheduler>,
persistence: Arc<dyn SchedulerPersistence>,
registry: Arc<PluginRegistry>,
circuit_breaker: Arc<CircuitBreaker>,
job_host: Option<Arc<SchedulerJobHost>>,
mut shutdown_rx: broadcast::Receiver<()>,
tick_interval: Duration,
) {
let mut ticker = tokio::time::interval(tick_interval);
ticker.tick().await;
loop {
tokio::select! {
_ = ticker.tick() => {
dispatch_one_tick(
&scheduler,
&persistence,
®istry,
&circuit_breaker,
job_host.as_ref(),
);
}
_ = shutdown_rx.recv() => {
tracing::info!("scheduler driver: shutdown received");
break;
}
}
}
}
/// Performs one scheduler tick: asks the scheduler which jobs are due and
/// starts each of them.
///
/// For every due job id, in order:
/// * if the circuit breaker does not allow it, the run is marked finished
///   (unsuccessful) on the scheduler and skipped;
/// * if no registered provider carries that id, the run is marked finished
///   as a failure, the breaker records a failure and `record_finished` is
///   written to persistence;
/// * otherwise `record_started` is written, the provider's `execute` is
///   started via `tokio::task::spawn_blocking`, and a second spawned task
///   waits for completion or cancellation and then records the outcome on
///   the scheduler, the breaker and the persistence backend.
///
/// The function itself never awaits. It calls `spawn_blocking` and
/// `tokio::spawn`, so it has to run inside a Tokio runtime.
fn dispatch_one_tick(
scheduler: &Arc<Scheduler>,
persistence: &Arc<dyn SchedulerPersistence>,
registry: &Arc<PluginRegistry>,
circuit_breaker: &Arc<CircuitBreaker>,
job_host: Option<&Arc<SchedulerJobHost>>,
) {
let due = scheduler.tick();
// Nothing due: skip the provider lookup entirely.
if due.is_empty() {
return;
}
let providers = registry.background_jobs();
// Every breaker call in this function is keyed by this fixed plugin id plus the job id.
let plugin_id = PluginId::new("uni");
for id in due {
if !circuit_breaker.allow(&plugin_id, &id) {
tracing::debug!(
job = %id,
"scheduler: circuit breaker open; skipping this tick"
);
// Skipped runs are reported to the scheduler only; neither the breaker
// nor the persistence backend is updated on this path.
scheduler.mark_finished(&id, false);
continue;
}
let Some(provider) = find_provider(&providers, &id) else {
tracing::warn!(
job = %id,
"scheduler: no provider registered; marking finished with failure"
);
let now = std::time::SystemTime::now();
scheduler.mark_finished(&id, false);
circuit_breaker.record_failure(&plugin_id, &id);
// Any error from persistence is discarded here.
let _ = persistence.record_finished(&id, now, false);
continue;
};
// Owned handles for the completion task spawned below.
let scheduler_clone = Arc::clone(scheduler);
let persistence_clone = Arc::clone(persistence);
let breaker_clone = Arc::clone(circuit_breaker);
let plugin_id_clone = plugin_id.clone();
let job_host_clone = job_host.cloned();
let started_at = std::time::SystemTime::now();
// A failed start record is only logged; the job is dispatched regardless.
if let Err(e) = persistence_clone.record_started(&id, started_at) {
tracing::warn!(
job = %id,
error = %e,
"scheduler: record_started failed; continuing"
);
}
// Falls back to a default token when the scheduler has none for this id.
let cancel = scheduler.cancel_token_for(&id).unwrap_or_default();
let cancel_for_select = cancel.clone();
let id_for_log = id.clone();
// The provider's synchronous `execute` is run through `spawn_blocking`.
let blocking = tokio::task::spawn_blocking(move || {
let mut ctx = JobContext::new(cancel, None);
if let Some(host) = job_host_clone.as_deref() {
ctx = ctx.with_host(host as &dyn JobHost);
}
provider.execute(ctx)
});
// Completion task: resolves to success/failure and records the result.
tokio::spawn(async move {
let success = tokio::select! {
joined = blocking => {
match joined {
// Success means the provider returned `Ok`; the outcome value itself is not inspected.
Ok(outcome) => outcome.is_ok(),
Err(join_err) => {
tracing::warn!(
job = %id_for_log,
error = %join_err,
"scheduler: blocking dispatch join failed"
);
false
}
}
}
// NOTE(review): when cancellation wins, the run is recorded as a failure
// (including on the circuit breaker) and the blocking task is no longer
// awaited; presumably the provider stops only if it observes the token
// passed into its `JobContext` — confirm this is intended.
() = cancel_for_select.cancelled() => {
tracing::info!(
job = %id_for_log,
"scheduler: cancellation observed before job completion"
);
false
}
};
let finished_at = std::time::SystemTime::now();
scheduler_clone.mark_finished(&id, success);
if success {
breaker_clone.record_success(&plugin_id_clone, &id);
} else {
breaker_clone.record_failure(&plugin_id_clone, &id);
}
if let Err(e) = persistence_clone.record_finished(&id, finished_at, success) {
tracing::warn!(
job = %id,
error = %e,
"scheduler: record_finished failed"
);
}
});
}
}
/// Looks up the provider whose job definition id equals `id`.
///
/// Returns a new `Arc` handle to the first match in the vector's order, or
/// `None` when no provider declares that id.
fn find_provider(
    providers: &Arc<Vec<Arc<dyn BackgroundJobProvider>>>,
    id: &QName,
) -> Option<Arc<dyn BackgroundJobProvider>> {
    for candidate in providers.iter() {
        if candidate.definition().id == *id {
            return Some(Arc::clone(candidate));
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use uni_plugin::Capability;
    use uni_plugin::CapabilitySet;
    use uni_plugin::PluginRegistrar;
    use uni_plugin::errors::FnError;
    use uni_plugin::traits::background::{
        ConcurrencyLimit, JobDefinition, JobOutcome, RetryPolicy, Schedule,
    };

    /// Provider that succeeds on every run and counts its executions.
    #[derive(Debug)]
    struct CountingJob {
        definition: JobDefinition,
        counter: Arc<AtomicU64>,
    }

    impl BackgroundJobProvider for CountingJob {
        fn definition(&self) -> &JobDefinition {
            &self.definition
        }

        fn execute(&self, _ctx: JobContext<'_>) -> Result<JobOutcome, FnError> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(JobOutcome::Done)
        }
    }

    /// Provider that fails on every run and counts its attempts.
    #[derive(Debug)]
    struct AlwaysFailJob {
        definition: JobDefinition,
        attempts: Arc<AtomicU64>,
    }

    impl BackgroundJobProvider for AlwaysFailJob {
        fn definition(&self) -> &JobDefinition {
            &self.definition
        }

        fn execute(&self, _ctx: JobContext<'_>) -> Result<JobOutcome, FnError> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Err(FnError::new(0xC1F, "always fails"))
        }
    }

    /// Builds the job definition shared by every test: an exclusive periodic
    /// job in the `test` namespace with a one-second timeout and no retries.
    fn periodic_definition(name: &'static str, period: Duration, docs: &str) -> JobDefinition {
        JobDefinition {
            id: QName::new("test", name),
            schedule: Schedule::Periodic(period),
            concurrency: ConcurrencyLimit::Exclusive,
            timeout: Duration::from_secs(1),
            retry: RetryPolicy::Never,
            docs: docs.to_owned(),
        }
    }

    /// Creates a registry containing exactly `provider`, registered by a
    /// plugin named `test` that holds the background-job capability.
    fn make_registry_with_job(provider: Arc<dyn BackgroundJobProvider>) -> Arc<PluginRegistry> {
        let registry = Arc::new(PluginRegistry::new());
        let caps = CapabilitySet::from_iter_of([Capability::BackgroundJob { max_concurrent: 0 }]);
        let plugin_id = uni_plugin::PluginId::new("test");
        let mut r = PluginRegistrar::new(plugin_id, &caps, &registry);
        r.background_job(provider).expect("background_job register");
        r.commit_to_registry().expect("commit");
        registry
    }

    /// A 50 ms periodic job driven by a 25 ms tick must run at least twice
    /// within 400 ms.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn driver_fires_periodic_job() {
        let counter = Arc::new(AtomicU64::new(0));
        let provider = Arc::new(CountingJob {
            definition: periodic_definition("ticker", Duration::from_millis(50), "test ticker"),
            counter: Arc::clone(&counter),
        });
        let registry = make_registry_with_job(provider);
        let shutdown = ShutdownHandle::new(Duration::from_secs(5));
        let host = SchedulerHost::spawn(
            registry,
            Arc::new(MemoryPersistence),
            &shutdown,
            Duration::from_millis(25),
        );
        host.scheduler().add_scheduled_job(
            QName::new("test", "ticker"),
            Schedule::Periodic(Duration::from_millis(50)),
        );
        tokio::time::sleep(Duration::from_millis(400)).await;
        let fires = counter.load(Ordering::SeqCst);
        assert!(
            fires >= 2,
            "expected the periodic job to fire at least twice, got {fires}"
        );
        let _ = shutdown.shutdown_async().await;
    }

    /// After `cancel`, at most one additional run (one that may already have
    /// been in flight) is tolerated.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn cancel_halts_further_runs() {
        let counter = Arc::new(AtomicU64::new(0));
        let provider = Arc::new(CountingJob {
            definition: periodic_definition("cancelme", Duration::from_millis(50), "cancelme"),
            counter: Arc::clone(&counter),
        });
        let registry = make_registry_with_job(provider);
        let shutdown = ShutdownHandle::new(Duration::from_secs(5));
        let host = SchedulerHost::spawn(
            registry,
            Arc::new(MemoryPersistence),
            &shutdown,
            Duration::from_millis(25),
        );
        let job_id = QName::new("test", "cancelme");
        host.scheduler().add_scheduled_job(
            job_id.clone(),
            Schedule::Periodic(Duration::from_millis(50)),
        );
        tokio::time::sleep(Duration::from_millis(150)).await;
        let pre_cancel = counter.load(Ordering::SeqCst);
        assert!(pre_cancel >= 1, "expected at least one pre-cancel fire");
        host.scheduler().cancel(&job_id);
        tokio::time::sleep(Duration::from_millis(300)).await;
        let post_cancel = counter.load(Ordering::SeqCst);
        assert!(
            post_cancel <= pre_cancel + 1,
            "expected cancel to halt firing; pre={pre_cancel} post={post_cancel}"
        );
        let _ = shutdown.shutdown_async().await;
    }

    /// A job that always fails must stop being attempted once the breaker
    /// opens.
    ///
    /// NOTE(review): the accepted range 10..=20 presumably tracks the default
    /// `BreakerConfig` failure threshold — confirm if that default changes.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn circuit_breaker_opens_after_threshold_failures() {
        let attempts = Arc::new(AtomicU64::new(0));
        let provider = Arc::new(AlwaysFailJob {
            definition: periodic_definition("flaky", Duration::from_millis(20), "flaky"),
            attempts: Arc::clone(&attempts),
        });
        let registry = make_registry_with_job(provider);
        let shutdown = ShutdownHandle::new(Duration::from_secs(5));
        let host = SchedulerHost::spawn(
            registry,
            Arc::new(MemoryPersistence),
            &shutdown,
            Duration::from_millis(10),
        );
        host.scheduler().add_scheduled_job(
            QName::new("test", "flaky"),
            Schedule::Periodic(Duration::from_millis(20)),
        );
        tokio::time::sleep(Duration::from_millis(500)).await;
        let total_attempts = attempts.load(Ordering::SeqCst);
        assert!(
            (10..=20).contains(&total_attempts),
            "expected the breaker to cap attempts around the failure_threshold (10); \
             got {total_attempts}"
        );
        let _ = shutdown.shutdown_async().await;
    }
}