1use async_trait::async_trait;
25use daggy::Walker;
26use daggy::{petgraph::visit::Topo, Dag, NodeIndex};
27use log::{error, info, warn};
28use parking_lot::Mutex;
29use std::borrow::Borrow;
30use std::sync::Arc;
31use std::sync::Weak;
32use std::time::Duration;
33use tokio::sync::watch;
34
35#[cfg(unix)]
36use crate::server::ListenFds;
37use crate::server::ShutdownWatch;
38
39pub mod background;
40pub mod listening;
41
/// Signals that a service has finished starting up.
///
/// Readiness is published as `true` on the wrapped watch channel. The value is
/// sent from this type's `Drop` impl, so both calling
/// [`ServiceReadyNotifier::notify_ready`] and simply dropping the notifier mark
/// the service as ready.
pub struct ServiceReadyNotifier {
    /// Sending half of the readiness channel; `true` means "ready".
    sender: watch::Sender<bool>,
}
64
65impl Drop for ServiceReadyNotifier {
66 fn drop(&mut self) {
69 let _ = self.sender.send(true);
71 }
72}
73
74impl ServiceReadyNotifier {
75 pub fn new(sender: watch::Sender<bool>) -> Self {
79 Self { sender }
80 }
81
82 pub fn notify_ready(self) {
86 drop(self);
88 }
89}
90
91pub type ServiceReadyWatch = watch::Receiver<bool>;
93
/// A handle to a service registered in a dependency graph.
///
/// Handles are used to declare dependencies between services via
/// [`ServiceHandle::add_dependency`] and [`ServiceHandle::add_dependencies`].
#[derive(Debug, Clone)]
pub struct ServiceHandle {
    /// Index of this service's node in the dependency graph.
    pub(crate) id: NodeIndex,
    /// Service name.
    name: String,
    /// Receiver observing this service's readiness flag.
    ready_watch: ServiceReadyWatch,
    /// Non-owning reference to the shared dependency graph; upgrading fails
    /// once the graph has been dropped.
    dependencies: Weak<Mutex<DependencyGraph>>,
}
116
/// Node payload stored in the dependency graph: a service's name and a
/// receiver for its readiness flag.
#[derive(Debug, Clone)]
pub(crate) struct ServiceDependency {
    /// Service name (used in cycle error messages).
    pub name: String,
    /// Receiver observing the service's readiness flag.
    pub ready_watch: ServiceReadyWatch,
}
123
124impl ServiceHandle {
125 pub(crate) fn new(
127 id: NodeIndex,
128 name: String,
129 ready_watch: ServiceReadyWatch,
130 dependencies: &Arc<Mutex<DependencyGraph>>,
131 ) -> Self {
132 Self {
133 id,
134 name,
135 ready_watch,
136 dependencies: Arc::downgrade(dependencies),
137 }
138 }
139
140 #[cfg(test)]
141 fn get_dependencies(&self) -> Vec<ServiceDependency> {
142 let Some(deps_lock) = self.dependencies.upgrade() else {
143 return Vec::new();
144 };
145
146 let deps = deps_lock.lock();
147 deps.get_dependencies(self.id)
148 }
149
150 pub fn name(&self) -> &str {
152 &self.name
153 }
154
155 #[allow(dead_code)]
157 pub(crate) fn ready_watch(&self) -> ServiceReadyWatch {
158 self.ready_watch.clone()
159 }
160
161 pub fn add_dependency(&self, dependency: impl Borrow<ServiceHandle>) {
176 let Some(deps_lock) = self.dependencies.upgrade() else {
177 warn!("Attempted to add a dependency after the dependency tree was dropped");
178 return;
179 };
180
181 let mut deps = deps_lock.lock();
182 if let Err(e) = deps.add_dependency(self.id, dependency.borrow().id) {
183 error!("Error creating dependency edge: {e}");
184 }
185 }
186
187 pub fn add_dependencies<'a, D>(&self, dependencies: impl IntoIterator<Item = D>)
203 where
204 D: Borrow<ServiceHandle> + 'a,
205 {
206 for dependency in dependencies {
207 self.add_dependency(dependency);
208 }
209 }
210}
211
/// Directed acyclic graph of services.
///
/// An edge `a -> b` means service `b` depends on service `a`, so a service's
/// dependencies are its parents.
pub(crate) struct DependencyGraph {
    /// Service nodes; edges carry no payload.
    dag: Dag<ServiceDependency, ()>,
}
217
218impl DependencyGraph {
219 pub(crate) fn new() -> Self {
221 Self { dag: Dag::new() }
222 }
223
224 pub(crate) fn add_node(&mut self, name: String, ready_watch: ServiceReadyWatch) -> NodeIndex {
228 self.dag.add_node(ServiceDependency { name, ready_watch })
229 }
230 pub(crate) fn add_dependency(
235 &mut self,
236 dependent_service_node_idx: NodeIndex,
237 dependency_service_node_idx: NodeIndex,
238 ) -> Result<(), String> {
239 if let Err(cycle) =
242 self.dag
243 .add_edge(dependency_service_node_idx, dependent_service_node_idx, ())
244 {
245 return Err(format!(
246 "Circular service dependency detected between {} and {} creating cycle: {cycle}",
247 self.dag[dependency_service_node_idx].name,
248 self.dag[dependent_service_node_idx].name
249 ));
250 }
251
252 Ok(())
253 }
254
255 pub(crate) fn topological_sort(&self) -> Result<Vec<(NodeIndex, ServiceDependency)>, String> {
260 let mut sorted = Vec::new();
262 let mut topo = Topo::new(&self.dag);
263
264 while let Some(service_id) = topo.next(&self.dag) {
265 sorted.push((service_id, self.dag[service_id].clone()));
266 }
267
268 Ok(sorted)
269 }
270
271 pub(crate) fn get_dependencies(&self, service_id: NodeIndex) -> Vec<ServiceDependency> {
272 self.dag
273 .parents(service_id)
274 .iter(&self.dag)
275 .map(|(_, n)| self.dag[n].clone())
276 .collect()
277 }
278}
279
280impl Default for DependencyGraph {
281 fn default() -> Self {
282 Self::new()
283 }
284}
285
#[async_trait]
/// A service that takes part in startup dependency ordering.
///
/// Implementors receive a [`ServiceReadyNotifier`] when started and signal
/// readiness through it. Every [`Service`] also implements this trait via a
/// blanket impl that signals readiness immediately.
pub trait ServiceWithDependents: Send + Sync {
    /// Starts the service.
    ///
    /// - `fds` (unix only): optional `ListenFds`; presumably inherited
    ///   listening sockets — TODO confirm against the server.
    /// - `shutdown`: the server's `ShutdownWatch`.
    /// - `listeners_per_fd`: passed through from the caller; meaning defined
    ///   by the server — confirm.
    /// - `ready_notifier`: call `notify_ready` (or drop it) to mark this
    ///   service as ready.
    async fn start_service(
        &mut self,
        #[cfg(unix)] fds: Option<ListenFds>,
        shutdown: ShutdownWatch,
        listeners_per_fd: usize,
        ready_notifier: ServiceReadyNotifier,
    );

    /// The name of the service, used in log messages.
    fn name(&self) -> &str;

    /// Optional thread count for this service; defaults to `None`.
    // NOTE(review): presumably `None` means "use the server default" — confirm.
    fn threads(&self) -> Option<usize> {
        None
    }

    /// Hook receiving how long the service waited on its dependencies.
    ///
    /// The default logs the wait in milliseconds at `info` level.
    fn on_startup_delay(&self, time_waited: Duration) {
        info!(
            "Service {} spent {}ms waiting on dependencies",
            self.name(),
            time_waited.as_millis()
        );
    }
}
340
341#[async_trait]
342impl<S> ServiceWithDependents for S
343where
344 S: Service,
345{
346 async fn start_service(
347 &mut self,
348 #[cfg(unix)] fds: Option<ListenFds>,
349 shutdown: ShutdownWatch,
350 listeners_per_fd: usize,
351 ready_notifier: ServiceReadyNotifier,
352 ) {
353 ready_notifier.notify_ready();
355
356 S::start_service(
357 self,
358 #[cfg(unix)]
359 fds,
360 shutdown,
361 listeners_per_fd,
362 )
363 .await
364 }
365
366 fn name(&self) -> &str {
367 S::name(self)
368 }
369
370 fn threads(&self) -> Option<usize> {
371 S::threads(self)
372 }
373
374 fn on_startup_delay(&self, time_waited: Duration) {
375 S::on_startup_delay(self, time_waited)
376 }
377}
378
#[async_trait]
/// A server component without an explicit readiness signal.
///
/// Through a blanket impl, every `Service` is also a
/// [`ServiceWithDependents`] that reports itself ready before
/// `start_service` is called.
pub trait Service: Sync + Send {
    /// Starts the service. The default implementation does nothing.
    ///
    /// Parameters mirror [`ServiceWithDependents::start_service`], minus the
    /// readiness notifier.
    async fn start_service(
        &mut self,
        #[cfg(unix)] _fds: Option<ListenFds>,
        _shutdown: ShutdownWatch,
        _listeners_per_fd: usize,
    ) {
    }

    /// The name of the service, used in log messages.
    fn name(&self) -> &str;

    /// Optional thread count for this service; defaults to `None`.
    // NOTE(review): presumably `None` means "use the server default" — confirm.
    fn threads(&self) -> Option<usize> {
        None
    }

    /// Hook receiving how long the service waited on its dependencies.
    ///
    /// The default logs the wait in milliseconds at `info` level.
    fn on_startup_delay(&self, time_waited: Duration) {
        info!(
            "Service {} spent {}ms waiting on dependencies",
            self.name(),
            time_waited.as_millis()
        );
    }
}
431
#[cfg(test)]
mod tests {
    use super::*;

    // A handle can be built for an index that is not in the graph; only its
    // own fields are exercised here.
    #[test]
    fn test_service_handle_creation() {
        let deps: Arc<Mutex<DependencyGraph>> = Arc::new(Mutex::new(DependencyGraph::new()));
        let (tx, rx) = watch::channel(false);
        let service_id = ServiceHandle::new(0.into(), "test_service".to_string(), rx, &deps);

        assert_eq!(service_id.id, 0.into());
        assert_eq!(service_id.name(), "test_service");

        // A cloned receiver still observes sends on the original channel.
        let watch_clone = service_id.ready_watch();
        assert!(!*watch_clone.borrow());

        tx.send(true).ok();
        assert!(*watch_clone.borrow());
    }

    // Adding a dependency through handles is visible via the graph, and the
    // stored ready_watch tracks the dependency's channel.
    #[test]
    fn test_service_handle_add_dependency() {
        let graph: Arc<Mutex<DependencyGraph>> = Arc::new(Mutex::new(DependencyGraph::new()));
        let (tx1, rx1) = watch::channel(false);
        let (tx1_clone, rx1_clone) = (tx1.clone(), rx1.clone());
        let (_tx2, rx2) = watch::channel(false);
        let (_tx2_clone, rx2_clone) = (_tx2.clone(), rx2.clone());

        // Each lock is scoped so it is released before the next one.
        let dep_node = {
            let mut g = graph.lock();
            g.add_node("dependency".to_string(), rx1)
        };
        let main_node = {
            let mut g = graph.lock();
            g.add_node("main".to_string(), rx2)
        };

        let dep_service = ServiceHandle::new(dep_node, "dependency".to_string(), rx1_clone, &graph);
        let main_service = ServiceHandle::new(main_node, "main".to_string(), rx2_clone, &graph);

        main_service.add_dependency(&dep_service);

        let deps = main_service.get_dependencies();
        assert_eq!(deps.len(), 1);
        assert_eq!(deps[0].name, "dependency");

        assert!(!*deps[0].ready_watch.borrow());
        tx1_clone.send(true).ok();
        assert!(*deps[0].ready_watch.borrow());
    }

    // Two separate add_dependency calls both end up as parents of "main";
    // order of the returned list is not asserted.
    #[test]
    fn test_service_handle_multiple_dependencies() {
        let graph: Arc<Mutex<DependencyGraph>> = Arc::new(Mutex::new(DependencyGraph::new()));
        let (_tx1, rx1) = watch::channel(false);
        let rx1_clone = rx1.clone();
        let (_tx2, rx2) = watch::channel(false);
        let rx2_clone = rx2.clone();
        let (_tx3, rx3) = watch::channel(false);
        let rx3_clone = rx3.clone();

        let dep1_node = {
            let mut g = graph.lock();
            g.add_node("dep1".to_string(), rx1)
        };
        let dep2_node = {
            let mut g = graph.lock();
            g.add_node("dep2".to_string(), rx2)
        };
        let main_node = {
            let mut g = graph.lock();
            g.add_node("main".to_string(), rx3)
        };

        let dep1 = ServiceHandle::new(dep1_node, "dep1".to_string(), rx1_clone, &graph);
        let dep2 = ServiceHandle::new(dep2_node, "dep2".to_string(), rx2_clone, &graph);
        let main_service = ServiceHandle::new(main_node, "main".to_string(), rx3_clone, &graph);

        main_service.add_dependency(&dep1);
        main_service.add_dependency(&dep2);

        let deps = main_service.get_dependencies();
        assert_eq!(deps.len(), 2);

        let dep_names: Vec<&str> = deps.iter().map(|d| d.name.as_str()).collect();
        assert!(dep_names.contains(&"dep1"));
        assert!(dep_names.contains(&"dep2"));
    }

    #[test]
    fn test_single_service_no_dependencies() {
        let mut graph = DependencyGraph::new();
        let (_tx, rx) = watch::channel(false);
        let _node = graph.add_node("service1".to_string(), rx);

        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 1);
        assert_eq!(order[0].1.name, "service1");
    }

    // service2 depends on service1, service3 on service2: a linear chain
    // has exactly one valid order.
    #[test]
    fn test_simple_dependency_chain() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let (_tx2, rx2) = watch::channel(false);
        let (_tx3, rx3) = watch::channel(false);

        let node1 = graph.add_node("service1".to_string(), rx1);
        let node2 = graph.add_node("service2".to_string(), rx2);
        let node3 = graph.add_node("service3".to_string(), rx3);

        graph.add_dependency(node2, node1).unwrap();
        graph.add_dependency(node3, node2).unwrap();

        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 3);
        assert_eq!(order[0].1.name, "service1");
        assert_eq!(order[1].1.name, "service2");
        assert_eq!(order[2].1.name, "service3");
    }

    // Despite the name, this is a fan-in (api depends on db and cache); the
    // relative order of db and cache is unspecified, so only membership of
    // the first two slots is checked.
    #[test]
    fn test_diamond_dependency() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let (_tx2, rx2) = watch::channel(false);
        let (_tx3, rx3) = watch::channel(false);

        let db = graph.add_node("db".to_string(), rx1);
        let cache = graph.add_node("cache".to_string(), rx2);
        let api = graph.add_node("api".to_string(), rx3);

        graph.add_dependency(api, db).unwrap();
        graph.add_dependency(api, cache).unwrap();

        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 3);
        assert_eq!(order[2].1.name, "api");
        let first_two: Vec<&str> = order[0..2].iter().map(|(_, d)| d.name.as_str()).collect();
        assert!(first_two.contains(&"db"));
        assert!(first_two.contains(&"cache"));
    }

    // An index that was never added is not reported as an Err; the
    // underlying graph panics instead.
    #[test]
    #[should_panic(expected = "node indices out of bounds")]
    fn test_missing_dependency() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);

        let node1 = graph.add_node("service1".to_string(), rx1);
        let nonexistent = NodeIndex::new(999);

        let _ = graph.add_dependency(node1, nonexistent);
    }

    #[test]
    fn test_circular_dependency_self() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);

        let node1 = graph.add_node("service1".to_string(), rx1);

        let result = graph.add_dependency(node1, node1);

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Circular"));
    }

    // service1 -> depends on service2, then service2 -> depends on service1.
    #[test]
    fn test_circular_dependency_two_services() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let (_tx2, rx2) = watch::channel(false);

        let node1 = graph.add_node("service1".to_string(), rx1);
        let node2 = graph.add_node("service2".to_string(), rx2);

        graph.add_dependency(node1, node2).unwrap();
        let result = graph.add_dependency(node2, node1);

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Circular"));
    }

    // The closing edge of a three-node ring is rejected.
    #[test]
    fn test_circular_dependency_three_services() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let (_tx2, rx2) = watch::channel(false);
        let (_tx3, rx3) = watch::channel(false);

        let node1 = graph.add_node("service1".to_string(), rx1);
        let node2 = graph.add_node("service2".to_string(), rx2);
        let node3 = graph.add_node("service3".to_string(), rx3);

        graph.add_dependency(node1, node2).unwrap();
        graph.add_dependency(node2, node3).unwrap();
        let result = graph.add_dependency(node3, node1);

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Circular"));
    }

    // Only the pairwise ordering constraints implied by the edges are
    // asserted, not a specific full order.
    #[test]
    fn test_complex_valid_graph() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let (_tx2, rx2) = watch::channel(false);
        let (_tx3, rx3) = watch::channel(false);
        let (_tx4, rx4) = watch::channel(false);
        let (_tx5, rx5) = watch::channel(false);

        let db = graph.add_node("db".to_string(), rx1);
        let cache = graph.add_node("cache".to_string(), rx2);
        let auth = graph.add_node("auth".to_string(), rx3);
        let api = graph.add_node("api".to_string(), rx4);
        let frontend = graph.add_node("frontend".to_string(), rx5);

        graph.add_dependency(auth, db).unwrap();
        graph.add_dependency(api, db).unwrap();
        graph.add_dependency(api, cache).unwrap();
        graph.add_dependency(api, auth).unwrap();
        graph.add_dependency(frontend, api).unwrap();

        let order = graph.topological_sort().unwrap();

        let db_pos = order.iter().position(|(_, d)| d.name == "db").unwrap();
        let cache_pos = order.iter().position(|(_, d)| d.name == "cache").unwrap();
        let auth_pos = order.iter().position(|(_, d)| d.name == "auth").unwrap();
        let api_pos = order.iter().position(|(_, d)| d.name == "api").unwrap();
        let frontend_pos = order
            .iter()
            .position(|(_, d)| d.name == "frontend")
            .unwrap();

        assert!(db_pos < auth_pos);
        assert!(auth_pos < api_pos);
        assert!(db_pos < api_pos);
        assert!(cache_pos < api_pos);
        assert!(api_pos < frontend_pos);
    }
}