use async_trait::async_trait;
use daggy::Walker;
use daggy::{petgraph::visit::Topo, Dag, NodeIndex};
use log::{error, info, warn};
use parking_lot::Mutex;
use std::borrow::Borrow;
use std::sync::Arc;
use std::sync::Weak;
use std::time::Duration;
use tokio::sync::watch;
#[cfg(unix)]
use crate::server::ListenFds;
use crate::server::ShutdownWatch;
pub mod background;
pub mod listening;
/// One-shot signal a service uses to report that it has finished starting up.
///
/// The signal is a `true` sent on the wrapped `watch::Sender`. The send happens when the
/// notifier is dropped, so calling [`ServiceReadyNotifier::notify_ready`] and simply letting
/// the notifier go out of scope have the same effect.
pub struct ServiceReadyNotifier {
    /// Sending half of the readiness channel; `true` is sent on drop.
    sender: watch::Sender<bool>,
}

impl Drop for ServiceReadyNotifier {
    fn drop(&mut self) {
        // Best effort: a failed send is deliberately ignored (tokio reports one when no
        // receiver is left to observe it).
        self.sender.send(true).ok();
    }
}

impl ServiceReadyNotifier {
    /// Wraps `sender`; readiness is reported once the returned notifier is consumed or dropped.
    pub fn new(sender: watch::Sender<bool>) -> Self {
        ServiceReadyNotifier { sender }
    }

    /// Reports the service as ready by consuming the notifier.
    pub fn notify_ready(self) {
        // The send itself lives in `Drop::drop`; consuming `self` here triggers it.
        std::mem::drop(self);
    }
}
/// Receiving half of a service's readiness channel.
///
/// `ServiceReadyNotifier` sends `true` on the paired sender when the service reports ready.
pub type ServiceReadyWatch = watch::Receiver<bool>;
/// Cloneable handle to a service registered in a `DependencyGraph`.
///
/// The handle is used to declare which other services this one depends on. It refers to the
/// graph only weakly, so holding a handle does not keep the graph alive.
#[derive(Debug, Clone)]
pub struct ServiceHandle {
    /// Index of this service's node in the dependency graph.
    pub(crate) id: NodeIndex,
    /// Service name.
    name: String,
    /// Receiver observing this service's readiness.
    ready_watch: ServiceReadyWatch,
    /// Weak reference to the shared graph this service is registered in.
    dependencies: Weak<Mutex<DependencyGraph>>,
}
/// Per-node payload of the `DependencyGraph`: a service's name and its readiness receiver.
#[derive(Debug, Clone)]
pub(crate) struct ServiceDependency {
    /// Name of the service.
    pub name: String,
    /// Receiver observing the service's readiness.
    pub ready_watch: ServiceReadyWatch,
}
impl ServiceHandle {
pub(crate) fn new(
id: NodeIndex,
name: String,
ready_watch: ServiceReadyWatch,
dependencies: &Arc<Mutex<DependencyGraph>>,
) -> Self {
Self {
id,
name,
ready_watch,
dependencies: Arc::downgrade(dependencies),
}
}
#[cfg(test)]
fn get_dependencies(&self) -> Vec<ServiceDependency> {
let Some(deps_lock) = self.dependencies.upgrade() else {
return Vec::new();
};
let deps = deps_lock.lock();
deps.get_dependencies(self.id)
}
pub fn name(&self) -> &str {
&self.name
}
#[allow(dead_code)]
pub(crate) fn ready_watch(&self) -> ServiceReadyWatch {
self.ready_watch.clone()
}
pub fn add_dependency(&self, dependency: impl Borrow<ServiceHandle>) {
let Some(deps_lock) = self.dependencies.upgrade() else {
warn!("Attempted to add a dependency after the dependency tree was dropped");
return;
};
let mut deps = deps_lock.lock();
if let Err(e) = deps.add_dependency(self.id, dependency.borrow().id) {
error!("Error creating dependency edge: {e}");
}
}
pub fn add_dependencies<'a, D>(&self, dependencies: impl IntoIterator<Item = D>)
where
D: Borrow<ServiceHandle> + 'a,
{
for dependency in dependencies {
self.add_dependency(dependency);
}
}
}
/// Directed acyclic graph of services.
///
/// Each node holds a `ServiceDependency`. An edge runs from a dependency to the service that
/// depends on it.
pub(crate) struct DependencyGraph {
    /// Underlying DAG; attempts to add a cycle-forming edge are reported as errors.
    dag: Dag<ServiceDependency, ()>,
}
impl DependencyGraph {
pub(crate) fn new() -> Self {
Self { dag: Dag::new() }
}
pub(crate) fn add_node(&mut self, name: String, ready_watch: ServiceReadyWatch) -> NodeIndex {
self.dag.add_node(ServiceDependency { name, ready_watch })
}
pub(crate) fn add_dependency(
&mut self,
dependent_service_node_idx: NodeIndex,
dependency_service_node_idx: NodeIndex,
) -> Result<(), String> {
if let Err(cycle) =
self.dag
.add_edge(dependency_service_node_idx, dependent_service_node_idx, ())
{
return Err(format!(
"Circular service dependency detected between {} and {} creating cycle: {cycle}",
self.dag[dependency_service_node_idx].name,
self.dag[dependent_service_node_idx].name
));
}
Ok(())
}
pub(crate) fn topological_sort(&self) -> Result<Vec<(NodeIndex, ServiceDependency)>, String> {
let mut sorted = Vec::new();
let mut topo = Topo::new(&self.dag);
while let Some(service_id) = topo.next(&self.dag) {
sorted.push((service_id, self.dag[service_id].clone()));
}
Ok(sorted)
}
pub(crate) fn get_dependencies(&self, service_id: NodeIndex) -> Vec<ServiceDependency> {
self.dag
.parents(service_id)
.iter(&self.dag)
.map(|(_, n)| self.dag[n].clone())
.collect()
}
}
impl Default for DependencyGraph {
fn default() -> Self {
Self::new()
}
}
/// A service whose startup can be ordered relative to other services.
///
/// Implementors receive a [`ServiceReadyNotifier`] and report readiness by consuming or
/// dropping it. Services depending on this one presumably wait for that signal.
#[async_trait]
pub trait ServiceWithDependents: Send + Sync {
    /// Starts and runs the service.
    ///
    /// - `fds` (Unix only): optional `ListenFds` handed to the service.
    /// - `shutdown`: a `ShutdownWatch`, presumably signalling server shutdown.
    /// - `listeners_per_fd`: count forwarded to the service; semantics defined by implementors.
    /// - `ready_notifier`: consume or drop this to report the service as ready.
    async fn start_service(
        &mut self,
        #[cfg(unix)] fds: Option<ListenFds>,
        shutdown: ShutdownWatch,
        listeners_per_fd: usize,
        ready_notifier: ServiceReadyNotifier,
    );

    /// Name of the service; used in log messages such as the startup-delay report.
    fn name(&self) -> &str;

    /// Optional thread count for the service; `None` unless overridden.
    fn threads(&self) -> Option<usize> {
        None
    }

    /// Called with how long the service waited on its dependencies; logs it at info level.
    fn on_startup_delay(&self, time_waited: Duration) {
        let waited_ms = time_waited.as_millis();
        info!(
            "Service {} spent {}ms waiting on dependencies",
            self.name(),
            waited_ms
        );
    }
}
/// Every plain [`Service`] is a [`ServiceWithDependents`] that reports ready immediately.
#[async_trait]
impl<S> ServiceWithDependents for S
where
    S: Service,
{
    async fn start_service(
        &mut self,
        #[cfg(unix)] fds: Option<ListenFds>,
        shutdown: ShutdownWatch,
        listeners_per_fd: usize,
        ready_notifier: ServiceReadyNotifier,
    ) {
        // Readiness is signalled before delegating, so dependents are not held back until the
        // wrapped service's `start_service` future completes.
        ready_notifier.notify_ready();
        <S as Service>::start_service(
            self,
            #[cfg(unix)]
            fds,
            shutdown,
            listeners_per_fd,
        )
        .await
    }

    fn name(&self) -> &str {
        <S as Service>::name(self)
    }

    fn threads(&self) -> Option<usize> {
        <S as Service>::threads(self)
    }

    fn on_startup_delay(&self, time_waited: Duration) {
        <S as Service>::on_startup_delay(self, time_waited)
    }
}
/// A service with no readiness phase of its own.
///
/// Through the blanket `ServiceWithDependents` impl, such a service is reported ready as soon
/// as it is started.
#[async_trait]
pub trait Service: Sync + Send {
    /// Starts and runs the service. The default implementation does nothing and returns.
    ///
    /// - `_fds` (Unix only): optional `ListenFds` handed to the service.
    /// - `_shutdown`: a `ShutdownWatch`, presumably signalling server shutdown.
    /// - `_listeners_per_fd`: count forwarded to the service; semantics defined by implementors.
    async fn start_service(
        &mut self,
        #[cfg(unix)] _fds: Option<ListenFds>,
        _shutdown: ShutdownWatch,
        _listeners_per_fd: usize,
    ) {
    }

    /// Name of the service; used in log messages such as the startup-delay report.
    fn name(&self) -> &str;

    /// Optional thread count for the service; `None` unless overridden.
    fn threads(&self) -> Option<usize> {
        None
    }

    /// Called with how long the service waited on its dependencies; logs it at info level.
    fn on_startup_delay(&self, time_waited: Duration) {
        let waited_ms = time_waited.as_millis();
        info!(
            "Service {} spent {}ms waiting on dependencies",
            self.name(),
            waited_ms
        );
    }
}
// Unit tests for `ServiceHandle` dependency bookkeeping and for `DependencyGraph` ordering
// and cycle detection.
#[cfg(test)]
mod tests {
    use super::*;

    // A handle exposes its id and name, and its readiness receiver observes the sender.
    #[test]
    fn test_service_handle_creation() {
        let deps: Arc<Mutex<DependencyGraph>> = Arc::new(Mutex::new(DependencyGraph::new()));
        let (tx, rx) = watch::channel(false);
        // Node 0 is never added to `deps`; the graph is not consulted in this test.
        let service_id = ServiceHandle::new(0.into(), "test_service".to_string(), rx, &deps);
        assert_eq!(service_id.id, 0.into());
        assert_eq!(service_id.name(), "test_service");
        let watch_clone = service_id.ready_watch();
        assert!(!*watch_clone.borrow());
        tx.send(true).ok();
        assert!(*watch_clone.borrow());
    }

    // A dependency added through a handle is reported by `get_dependencies`, and its stored
    // receiver tracks the dependency's readiness sender.
    #[test]
    fn test_service_handle_add_dependency() {
        let graph: Arc<Mutex<DependencyGraph>> = Arc::new(Mutex::new(DependencyGraph::new()));
        let (tx1, rx1) = watch::channel(false);
        // Keep a sender clone so the dependency's readiness can be flipped later.
        let (tx1_clone, rx1_clone) = (tx1.clone(), rx1.clone());
        let (_tx2, rx2) = watch::channel(false);
        let (_tx2_clone, rx2_clone) = (_tx2.clone(), rx2.clone());
        let dep_node = {
            let mut g = graph.lock();
            g.add_node("dependency".to_string(), rx1)
        };
        let main_node = {
            let mut g = graph.lock();
            g.add_node("main".to_string(), rx2)
        };
        let dep_service = ServiceHandle::new(dep_node, "dependency".to_string(), rx1_clone, &graph);
        let main_service = ServiceHandle::new(main_node, "main".to_string(), rx2_clone, &graph);
        main_service.add_dependency(&dep_service);
        let deps = main_service.get_dependencies();
        assert_eq!(deps.len(), 1);
        assert_eq!(deps[0].name, "dependency");
        assert!(!*deps[0].ready_watch.borrow());
        tx1_clone.send(true).ok();
        assert!(*deps[0].ready_watch.borrow());
    }

    // Two dependencies of one service are both reported; their order is not asserted.
    #[test]
    fn test_service_handle_multiple_dependencies() {
        let graph: Arc<Mutex<DependencyGraph>> = Arc::new(Mutex::new(DependencyGraph::new()));
        let (_tx1, rx1) = watch::channel(false);
        let rx1_clone = rx1.clone();
        let (_tx2, rx2) = watch::channel(false);
        let rx2_clone = rx2.clone();
        let (_tx3, rx3) = watch::channel(false);
        let rx3_clone = rx3.clone();
        let dep1_node = {
            let mut g = graph.lock();
            g.add_node("dep1".to_string(), rx1)
        };
        let dep2_node = {
            let mut g = graph.lock();
            g.add_node("dep2".to_string(), rx2)
        };
        let main_node = {
            let mut g = graph.lock();
            g.add_node("main".to_string(), rx3)
        };
        let dep1 = ServiceHandle::new(dep1_node, "dep1".to_string(), rx1_clone, &graph);
        let dep2 = ServiceHandle::new(dep2_node, "dep2".to_string(), rx2_clone, &graph);
        let main_service = ServiceHandle::new(main_node, "main".to_string(), rx3_clone, &graph);
        main_service.add_dependency(&dep1);
        main_service.add_dependency(&dep2);
        let deps = main_service.get_dependencies();
        assert_eq!(deps.len(), 2);
        let dep_names: Vec<&str> = deps.iter().map(|d| d.name.as_str()).collect();
        assert!(dep_names.contains(&"dep1"));
        assert!(dep_names.contains(&"dep2"));
    }

    // A lone service sorts to a single-element order.
    #[test]
    fn test_single_service_no_dependencies() {
        let mut graph = DependencyGraph::new();
        let (_tx, rx) = watch::channel(false);
        let _node = graph.add_node("service1".to_string(), rx);
        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 1);
        assert_eq!(order[0].1.name, "service1");
    }

    // A linear chain (3 depends on 2, 2 depends on 1) sorts dependencies first.
    #[test]
    fn test_simple_dependency_chain() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let (_tx2, rx2) = watch::channel(false);
        let (_tx3, rx3) = watch::channel(false);
        let node1 = graph.add_node("service1".to_string(), rx1);
        let node2 = graph.add_node("service2".to_string(), rx2);
        let node3 = graph.add_node("service3".to_string(), rx3);
        graph.add_dependency(node2, node1).unwrap();
        graph.add_dependency(node3, node2).unwrap();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 3);
        assert_eq!(order[0].1.name, "service1");
        assert_eq!(order[1].1.name, "service2");
        assert_eq!(order[2].1.name, "service3");
    }

    // Fan-in: `api` depends on both `db` and `cache`, so it sorts last while the two
    // independent dependencies may appear in either order.
    #[test]
    fn test_diamond_dependency() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let (_tx2, rx2) = watch::channel(false);
        let (_tx3, rx3) = watch::channel(false);
        let db = graph.add_node("db".to_string(), rx1);
        let cache = graph.add_node("cache".to_string(), rx2);
        let api = graph.add_node("api".to_string(), rx3);
        graph.add_dependency(api, db).unwrap();
        graph.add_dependency(api, cache).unwrap();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 3);
        assert_eq!(order[2].1.name, "api");
        let first_two: Vec<&str> = order[0..2].iter().map(|(_, d)| d.name.as_str()).collect();
        assert!(first_two.contains(&"db"));
        assert!(first_two.contains(&"cache"));
    }

    // Referencing a node index that was never added panics inside the graph library; the
    // expected message comes from that library, not from this module.
    #[test]
    #[should_panic(expected = "node indices out of bounds")]
    fn test_missing_dependency() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let node1 = graph.add_node("service1".to_string(), rx1);
        let nonexistent = NodeIndex::new(999);
        let _ = graph.add_dependency(node1, nonexistent);
    }

    // A service depending on itself is rejected as a cycle.
    #[test]
    fn test_circular_dependency_self() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let node1 = graph.add_node("service1".to_string(), rx1);
        let result = graph.add_dependency(node1, node1);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Circular"));
    }

    // The edge closing a two-service cycle is rejected.
    #[test]
    fn test_circular_dependency_two_services() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let (_tx2, rx2) = watch::channel(false);
        let node1 = graph.add_node("service1".to_string(), rx1);
        let node2 = graph.add_node("service2".to_string(), rx2);
        graph.add_dependency(node1, node2).unwrap();
        let result = graph.add_dependency(node2, node1);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Circular"));
    }

    // The edge closing a three-service cycle is rejected.
    #[test]
    fn test_circular_dependency_three_services() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let (_tx2, rx2) = watch::channel(false);
        let (_tx3, rx3) = watch::channel(false);
        let node1 = graph.add_node("service1".to_string(), rx1);
        let node2 = graph.add_node("service2".to_string(), rx2);
        let node3 = graph.add_node("service3".to_string(), rx3);
        graph.add_dependency(node1, node2).unwrap();
        graph.add_dependency(node2, node3).unwrap();
        let result = graph.add_dependency(node3, node1);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Circular"));
    }

    // A five-service acyclic graph: only the partial-order constraints implied by the edges
    // are asserted, not one specific total order.
    #[test]
    fn test_complex_valid_graph() {
        let mut graph = DependencyGraph::new();
        let (_tx1, rx1) = watch::channel(false);
        let (_tx2, rx2) = watch::channel(false);
        let (_tx3, rx3) = watch::channel(false);
        let (_tx4, rx4) = watch::channel(false);
        let (_tx5, rx5) = watch::channel(false);
        let db = graph.add_node("db".to_string(), rx1);
        let cache = graph.add_node("cache".to_string(), rx2);
        let auth = graph.add_node("auth".to_string(), rx3);
        let api = graph.add_node("api".to_string(), rx4);
        let frontend = graph.add_node("frontend".to_string(), rx5);
        graph.add_dependency(auth, db).unwrap();
        graph.add_dependency(api, db).unwrap();
        graph.add_dependency(api, cache).unwrap();
        graph.add_dependency(api, auth).unwrap();
        graph.add_dependency(frontend, api).unwrap();
        let order = graph.topological_sort().unwrap();
        let db_pos = order.iter().position(|(_, d)| d.name == "db").unwrap();
        let cache_pos = order.iter().position(|(_, d)| d.name == "cache").unwrap();
        let auth_pos = order.iter().position(|(_, d)| d.name == "auth").unwrap();
        let api_pos = order.iter().position(|(_, d)| d.name == "api").unwrap();
        let frontend_pos = order
            .iter()
            .position(|(_, d)| d.name == "frontend")
            .unwrap();
        assert!(db_pos < auth_pos);
        assert!(auth_pos < api_pos);
        assert!(db_pos < api_pos);
        assert!(cache_pos < api_pos);
        assert!(api_pos < frontend_pos);
    }
}