1use std::sync::{Arc, OnceLock};
31use std::time::Duration;
32
33use tokio::sync::broadcast;
34
35use uni_plugin::PluginRegistry;
36use uni_plugin::circuit_breaker::{BreakerConfig, CircuitBreaker};
37use uni_plugin::plugin::PluginId;
38use uni_plugin::qname::QName;
39use uni_plugin::scheduler::{MemoryPersistence, Scheduler, SchedulerPersistence};
40use uni_plugin::traits::background::{BackgroundJobProvider, JobContext, JobHost};
41use uni_store::storage::manager::StorageManager;
42
43use crate::host::HostCypherExecutor;
44use crate::shutdown::ShutdownHandle;
45
/// Tick period (100 ms) that [`spawn_with_memory_persistence`] passes to the
/// driver when the caller does not pick one.
pub const DEFAULT_TICK_INTERVAL: Duration = Duration::from_millis(100);
50
51#[derive(Debug)]
58pub struct SchedulerHost {
59 scheduler: Arc<Scheduler>,
61 persistence: Arc<dyn SchedulerPersistence>,
63 circuit_breaker: Arc<CircuitBreaker>,
70 job_host: Option<Arc<SchedulerJobHost>>,
74}
75
76pub struct SchedulerJobHost {
83 storage: Arc<StorageManager>,
84 host_executor: OnceLock<Arc<dyn HostCypherExecutor>>,
89}
90
91impl std::fmt::Debug for SchedulerJobHost {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("SchedulerJobHost")
94 .field("host_executor_wired", &self.host_executor.get().is_some())
95 .finish_non_exhaustive()
96 }
97}
98
99impl SchedulerJobHost {
100 #[must_use]
104 pub fn new(storage: Arc<StorageManager>) -> Self {
105 Self {
106 storage,
107 host_executor: OnceLock::new(),
108 }
109 }
110
111 pub fn set_host_executor(&self, exec: Arc<dyn HostCypherExecutor>) {
114 let _ = self.host_executor.set(exec);
115 }
116
117 #[must_use]
119 pub fn storage(&self) -> &Arc<StorageManager> {
120 &self.storage
121 }
122}
123
124impl uni_plugin::scheduler::SchedulerControl for SchedulerHost {
132 fn add_scheduled_job(&self, id: QName, schedule: uni_plugin::traits::background::Schedule) {
133 if let Err(e) = self.persistence.record_scheduled(&id, &schedule) {
139 tracing::warn!(
140 qname = %id,
141 error = %e,
142 "SchedulerHost: record_scheduled failed; in-memory registration continues",
143 );
144 }
145 self.scheduler.add_scheduled_job(id, schedule);
146 }
147
148 fn cancel(&self, id: &QName) -> bool {
149 self.scheduler.cancel(id)
150 }
151
152 fn list(&self) -> Vec<uni_plugin::scheduler::SchedulerJobRecord> {
153 self.scheduler.list()
154 }
155
156 fn submit_cypher(&self, cypher: &str) -> Result<(), uni_plugin::FnError> {
157 let Some(host) = self.job_host.as_ref() else {
158 return Err(uni_plugin::FnError::new(
159 0xD21,
160 "submit_cypher: scheduler host has no JobHost wired",
161 ));
162 };
163 host.execute_write_cypher(cypher)
164 }
165
166 fn flush_checkpoint(&self) -> Result<(), uni_plugin::FnError> {
167 self.persistence
168 .flush_checkpoint()
169 .map_err(|e| uni_plugin::FnError::new(0xD22, format!("flush_checkpoint: {e}")))
170 }
171}
172
173impl JobHost for SchedulerJobHost {
174 fn as_any(&self) -> &dyn std::any::Any {
175 self
176 }
177
178 fn compact_storage(&self) -> Result<(), uni_plugin::FnError> {
179 let storage = Arc::clone(&self.storage);
184 tokio::task::block_in_place(|| {
185 tokio::runtime::Handle::current().block_on(async move { storage.compact().await })
186 })
187 .map(|_stats| ())
188 .map_err(|e| uni_plugin::FnError::new(0xD11, format!("compact_storage: {e}")))
189 }
190
191 fn execute_write_cypher(&self, cypher: &str) -> Result<(), uni_plugin::FnError> {
192 let Some(exec) = self.host_executor.get() else {
200 tracing::debug!("execute_write_cypher: host executor not wired (shutdown race?)",);
201 return Ok(());
202 };
203 exec.execute_write_cypher(cypher)
204 .map_err(|e| uni_plugin::FnError::new(0xD12, format!("execute_write_cypher: {e}")))
205 }
206}
207
208impl SchedulerHost {
209 #[must_use]
224 pub fn spawn(
225 registry: Arc<PluginRegistry>,
226 persistence: Arc<dyn SchedulerPersistence>,
227 shutdown: &ShutdownHandle,
228 tick_interval: Duration,
229 ) -> Arc<Self> {
230 Self::spawn_with_job_host(registry, persistence, shutdown, tick_interval, None)
231 }
232
233 #[must_use]
237 pub fn spawn_with_job_host(
238 registry: Arc<PluginRegistry>,
239 persistence: Arc<dyn SchedulerPersistence>,
240 shutdown: &ShutdownHandle,
241 tick_interval: Duration,
242 job_host: Option<Arc<SchedulerJobHost>>,
243 ) -> Arc<Self> {
244 let scheduler = Arc::new(Scheduler::new());
245
246 match persistence.load_all() {
250 Ok(records) => {
251 for record in records {
252 scheduler.add_scheduled_job(record.id.clone(), record.schedule);
253 }
254 let requeued = scheduler.requeue_orphaned_runs();
257 if requeued > 0 {
258 tracing::info!(
259 requeued,
260 "scheduler: requeued orphaned runs from previous shutdown"
261 );
262 }
263 }
264 Err(e) => tracing::warn!(error = %e, "scheduler: load_all failed; starting empty"),
265 }
266
267 scheduler.resume();
268
269 let circuit_breaker = Arc::new(CircuitBreaker::new(BreakerConfig::default()));
270
271 let host = Arc::new(Self {
272 scheduler: Arc::clone(&scheduler),
273 persistence: Arc::clone(&persistence),
274 circuit_breaker: Arc::clone(&circuit_breaker),
275 job_host: job_host.clone(),
276 });
277
278 let driver_scheduler = Arc::clone(&scheduler);
280 let driver_persistence = Arc::clone(&persistence);
281 let driver_registry = Arc::clone(®istry);
282 let driver_breaker = Arc::clone(&circuit_breaker);
283 let driver_job_host = job_host;
284 let shutdown_rx = shutdown.subscribe();
285 let handle = tokio::spawn(driver_loop(
286 driver_scheduler,
287 driver_persistence,
288 driver_registry,
289 driver_breaker,
290 driver_job_host,
291 shutdown_rx,
292 tick_interval,
293 ));
294 shutdown.track_task(handle);
295
296 host
297 }
298
299 #[must_use]
301 pub fn job_host(&self) -> Option<&Arc<SchedulerJobHost>> {
302 self.job_host.as_ref()
303 }
304
305 #[must_use]
309 pub fn circuit_breaker(&self) -> &Arc<CircuitBreaker> {
310 &self.circuit_breaker
311 }
312
313 #[must_use]
315 pub fn scheduler(&self) -> &Arc<Scheduler> {
316 &self.scheduler
317 }
318
319 #[must_use]
321 pub fn persistence(&self) -> &Arc<dyn SchedulerPersistence> {
322 &self.persistence
323 }
324}
325
326#[must_use]
330pub fn spawn_with_memory_persistence(
331 registry: Arc<PluginRegistry>,
332 shutdown: &ShutdownHandle,
333) -> Arc<SchedulerHost> {
334 SchedulerHost::spawn(
335 registry,
336 Arc::new(MemoryPersistence),
337 shutdown,
338 DEFAULT_TICK_INTERVAL,
339 )
340}
341
342async fn driver_loop(
347 scheduler: Arc<Scheduler>,
348 persistence: Arc<dyn SchedulerPersistence>,
349 registry: Arc<PluginRegistry>,
350 circuit_breaker: Arc<CircuitBreaker>,
351 job_host: Option<Arc<SchedulerJobHost>>,
352 mut shutdown_rx: broadcast::Receiver<()>,
353 tick_interval: Duration,
354) {
355 let mut ticker = tokio::time::interval(tick_interval);
356 ticker.tick().await;
359 loop {
360 tokio::select! {
361 _ = ticker.tick() => {
362 dispatch_one_tick(
363 &scheduler,
364 &persistence,
365 ®istry,
366 &circuit_breaker,
367 job_host.as_ref(),
368 );
369 }
370 _ = shutdown_rx.recv() => {
371 tracing::info!("scheduler driver: shutdown received");
372 break;
373 }
374 }
375 }
376}
377
378fn dispatch_one_tick(
383 scheduler: &Arc<Scheduler>,
384 persistence: &Arc<dyn SchedulerPersistence>,
385 registry: &Arc<PluginRegistry>,
386 circuit_breaker: &Arc<CircuitBreaker>,
387 job_host: Option<&Arc<SchedulerJobHost>>,
388) {
389 let due = scheduler.tick();
390 if due.is_empty() {
391 return;
392 }
393 let providers = registry.background_jobs();
394 let plugin_id = PluginId::new("uni");
398 for id in due {
399 if !circuit_breaker.allow(&plugin_id, &id) {
402 tracing::debug!(
403 job = %id,
404 "scheduler: circuit breaker open; skipping this tick"
405 );
406 scheduler.mark_finished(&id, false);
412 continue;
413 }
414 let Some(provider) = find_provider(&providers, &id) else {
415 tracing::warn!(
416 job = %id,
417 "scheduler: no provider registered; marking finished with failure"
418 );
419 let now = std::time::SystemTime::now();
420 scheduler.mark_finished(&id, false);
421 circuit_breaker.record_failure(&plugin_id, &id);
422 let _ = persistence.record_finished(&id, now, false);
423 continue;
424 };
425 let scheduler_clone = Arc::clone(scheduler);
426 let persistence_clone = Arc::clone(persistence);
427 let breaker_clone = Arc::clone(circuit_breaker);
428 let plugin_id_clone = plugin_id.clone();
429 let job_host_clone = job_host.cloned();
430 let started_at = std::time::SystemTime::now();
431 if let Err(e) = persistence_clone.record_started(&id, started_at) {
432 tracing::warn!(
433 job = %id,
434 error = %e,
435 "scheduler: record_started failed; continuing"
436 );
437 }
438 let cancel = scheduler.cancel_token_for(&id).unwrap_or_default();
445 let cancel_for_select = cancel.clone();
446 let id_for_log = id.clone();
447 let blocking = tokio::task::spawn_blocking(move || {
448 let mut ctx = JobContext::new(cancel, None);
449 if let Some(host) = job_host_clone.as_deref() {
456 ctx = ctx.with_host(host as &dyn JobHost);
457 }
458 provider.execute(ctx)
459 });
460 tokio::spawn(async move {
469 let success = tokio::select! {
470 joined = blocking => {
471 match joined {
472 Ok(outcome) => outcome.is_ok(),
473 Err(join_err) => {
474 tracing::warn!(
475 job = %id_for_log,
476 error = %join_err,
477 "scheduler: blocking dispatch join failed"
478 );
479 false
480 }
481 }
482 }
483 () = cancel_for_select.cancelled() => {
484 tracing::info!(
485 job = %id_for_log,
486 "scheduler: cancellation observed before job completion"
487 );
488 false
489 }
490 };
491 let finished_at = std::time::SystemTime::now();
492 scheduler_clone.mark_finished(&id, success);
493 if success {
494 breaker_clone.record_success(&plugin_id_clone, &id);
495 } else {
496 breaker_clone.record_failure(&plugin_id_clone, &id);
497 }
498 if let Err(e) = persistence_clone.record_finished(&id, finished_at, success) {
499 tracing::warn!(
500 job = %id,
501 error = %e,
502 "scheduler: record_finished failed"
503 );
504 }
505 });
506 }
507}
508
509fn find_provider(
517 providers: &Arc<Vec<Arc<dyn BackgroundJobProvider>>>,
518 id: &QName,
519) -> Option<Arc<dyn BackgroundJobProvider>> {
520 providers
521 .iter()
522 .find(|p| &p.definition().id == id)
523 .map(Arc::clone)
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use uni_plugin::Capability;
    use uni_plugin::CapabilitySet;
    use uni_plugin::PluginRegistrar;
    use uni_plugin::errors::FnError;
    use uni_plugin::traits::background::{
        ConcurrencyLimit, JobDefinition, JobOutcome, RetryPolicy, Schedule,
    };

    /// Job that bumps a shared counter on every run and always succeeds.
    #[derive(Debug)]
    struct CountingJob {
        definition: JobDefinition,
        counter: Arc<AtomicU64>,
    }

    impl BackgroundJobProvider for CountingJob {
        fn definition(&self) -> &JobDefinition {
            &self.definition
        }

        fn execute(&self, _ctx: JobContext<'_>) -> Result<JobOutcome, FnError> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(JobOutcome::Done)
        }
    }

    /// Job that counts its attempts and fails every one of them.
    #[derive(Debug)]
    struct AlwaysFailJob {
        definition: JobDefinition,
        attempts: Arc<AtomicU64>,
    }

    impl BackgroundJobProvider for AlwaysFailJob {
        fn definition(&self) -> &JobDefinition {
            &self.definition
        }

        fn execute(&self, _ctx: JobContext<'_>) -> Result<JobOutcome, FnError> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Err(FnError::new(0xC1F, "always fails"))
        }
    }

    /// Builds the definition shared by all tests: namespace `test`, periodic
    /// schedule, exclusive concurrency, one-second timeout, no retries.
    fn periodic_definition(name: &'static str, period: Duration, docs: &str) -> JobDefinition {
        JobDefinition {
            id: QName::new("test", name),
            schedule: Schedule::Periodic(period),
            concurrency: ConcurrencyLimit::Exclusive,
            timeout: Duration::from_secs(1),
            retry: RetryPolicy::Never,
            docs: docs.to_owned(),
        }
    }

    /// Registers `provider` under plugin id `test` in a fresh registry.
    fn make_registry_with_job(provider: Arc<dyn BackgroundJobProvider>) -> Arc<PluginRegistry> {
        let registry = Arc::new(PluginRegistry::new());
        let caps = CapabilitySet::from_iter_of([Capability::BackgroundJob { max_concurrent: 0 }]);
        let plugin_id = uni_plugin::PluginId::new("test");
        let mut r = PluginRegistrar::new(plugin_id, &caps, &registry);
        r.background_job(provider).expect("background_job register");
        r.commit_to_registry().expect("commit");
        registry
    }

    /// A 50 ms periodic job driven at a 25 ms tick must run at least twice
    /// within 400 ms.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn driver_fires_periodic_job() {
        let counter = Arc::new(AtomicU64::new(0));
        let provider = Arc::new(CountingJob {
            definition: periodic_definition("ticker", Duration::from_millis(50), "test ticker"),
            counter: Arc::clone(&counter),
        });
        let registry = make_registry_with_job(provider);
        let shutdown = ShutdownHandle::new(Duration::from_secs(5));
        let host = SchedulerHost::spawn(
            registry,
            Arc::new(MemoryPersistence),
            &shutdown,
            Duration::from_millis(25),
        );
        host.scheduler().add_scheduled_job(
            QName::new("test", "ticker"),
            Schedule::Periodic(Duration::from_millis(50)),
        );

        tokio::time::sleep(Duration::from_millis(400)).await;

        let fires = counter.load(Ordering::SeqCst);
        assert!(
            fires >= 2,
            "expected the periodic job to fire at least twice, got {fires}"
        );

        let _ = shutdown.shutdown_async().await;
    }

    /// After `cancel`, at most one further run (one already in flight) may
    /// be observed.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn cancel_halts_further_runs() {
        let counter = Arc::new(AtomicU64::new(0));
        let provider = Arc::new(CountingJob {
            definition: periodic_definition("cancelme", Duration::from_millis(50), "cancelme"),
            counter: Arc::clone(&counter),
        });
        let registry = make_registry_with_job(provider);
        let shutdown = ShutdownHandle::new(Duration::from_secs(5));
        let host = SchedulerHost::spawn(
            registry,
            Arc::new(MemoryPersistence),
            &shutdown,
            Duration::from_millis(25),
        );
        let job_id = QName::new("test", "cancelme");
        host.scheduler().add_scheduled_job(
            job_id.clone(),
            Schedule::Periodic(Duration::from_millis(50)),
        );

        tokio::time::sleep(Duration::from_millis(150)).await;
        let pre_cancel = counter.load(Ordering::SeqCst);
        assert!(pre_cancel >= 1, "expected at least one pre-cancel fire");

        host.scheduler().cancel(&job_id);

        tokio::time::sleep(Duration::from_millis(300)).await;
        let post_cancel = counter.load(Ordering::SeqCst);

        // One extra run is tolerated: a dispatch may already have been in
        // flight when the cancel landed.
        assert!(
            post_cancel <= pre_cancel + 1,
            "expected cancel to halt firing; pre={pre_cancel} post={post_cancel}"
        );

        let _ = shutdown.shutdown_async().await;
    }

    /// A job that always fails must stop being dispatched once the breaker
    /// opens; the attempt count is expected to land in `10..=20`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn circuit_breaker_opens_after_threshold_failures() {
        let attempts = Arc::new(AtomicU64::new(0));
        let provider = Arc::new(AlwaysFailJob {
            definition: periodic_definition("flaky", Duration::from_millis(20), "flaky"),
            attempts: Arc::clone(&attempts),
        });
        let registry = make_registry_with_job(provider);
        let shutdown = ShutdownHandle::new(Duration::from_secs(5));
        let host = SchedulerHost::spawn(
            registry,
            Arc::new(MemoryPersistence),
            &shutdown,
            Duration::from_millis(10),
        );
        host.scheduler().add_scheduled_job(
            QName::new("test", "flaky"),
            Schedule::Periodic(Duration::from_millis(20)),
        );

        tokio::time::sleep(Duration::from_millis(500)).await;

        let total_attempts = attempts.load(Ordering::SeqCst);
        assert!(
            (10..=20).contains(&total_attempts),
            "expected the breaker to cap attempts around the failure_threshold (10); \
             got {total_attempts}"
        );

        let _ = shutdown.shutdown_async().await;
    }
}