Skip to main content

agent_io/tools/
depends.rs

1//! Dependency injection system
2
3use std::any::Any;
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8
/// Marker trait for values that can be handed out as dependencies.
///
/// There is nothing to implement by hand: the blanket impl below covers every
/// type that is `Clone + Send + Sync + 'static`.
pub trait Dependency: Clone + Send + Sync + 'static {}

// Blanket impl: any cloneable, thread-safe, `'static` type is a `Dependency`.
impl<T> Dependency for T where T: Clone + Send + Sync + 'static {}
13
14/// A dependency that can be resolved at runtime
15#[derive(Clone)]
16pub struct Depends<T>
17where
18    T: Dependency,
19{
20    factory: Arc<dyn Fn() -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
21}
22
23impl<T: Dependency> Depends<T> {
24    /// Create a new dependency with a factory function
25    pub fn new<F, Fut>(factory: F) -> Self
26    where
27        F: Fn() -> Fut + Send + Sync + 'static,
28        Fut: Future<Output = T> + Send + 'static,
29    {
30        Self {
31            factory: Arc::new(move || Box::pin(factory())),
32        }
33    }
34
35    /// Create a dependency with a static value
36    pub fn with_value(value: T) -> Self {
37        Self::new(move || {
38            let v = value.clone();
39            async move { v }
40        })
41    }
42
43    /// Resolve the dependency
44    pub async fn resolve(&self) -> T {
45        (self.factory)().await
46    }
47}
48
/// Registry of shared dependencies, keyed by the concrete type of each value.
///
/// At most one value per type is stored.
#[derive(Default)]
pub struct DependencyContainer {
    dependencies: HashMap<std::any::TypeId, Box<dyn Any + Send + Sync>>,
}

impl DependencyContainer {
    /// Creates an empty container.
    pub fn new() -> Self {
        Self::default()
    }

    /// Map key used for values of type `T`.
    fn key_of<T: 'static>() -> std::any::TypeId {
        std::any::TypeId::of::<T>()
    }

    /// Stores `value` under its type.
    ///
    /// A value of the same type registered earlier is replaced.
    pub fn register<T: 'static + Send + Sync>(&mut self, value: T) {
        let boxed: Box<dyn Any + Send + Sync> = Box::new(value);
        self.dependencies.insert(Self::key_of::<T>(), boxed);
    }

    /// Returns a clone of the stored value of type `T`.
    ///
    /// Yields `None` when nothing has been registered for `T`.
    pub fn get<T: 'static + Clone + Send + Sync>(&self) -> Option<T> {
        let stored = self.dependencies.get(&Self::key_of::<T>())?;
        stored.downcast_ref::<T>().cloned()
    }

    /// Reports whether a value of type `T` has been registered.
    pub fn contains<T: 'static>(&self) -> bool {
        let key = Self::key_of::<T>();
        self.dependencies.contains_key(&key)
    }
}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal cloneable stand-in for a real dependency.
    #[derive(Clone, Debug, PartialEq)]
    struct Database {
        url: String,
    }

    /// Builds the fixture shared by the tests below.
    fn sample_database() -> Database {
        Database {
            url: String::from("postgresql://localhost"),
        }
    }

    /// A value-backed dependency resolves to an equal value.
    #[tokio::test]
    async fn test_depends_with_value() {
        let expected = sample_database();
        let dependency = Depends::with_value(expected.clone());

        assert_eq!(dependency.resolve().await, expected);
    }

    /// A factory-backed dependency resolves to what the factory produces.
    #[tokio::test]
    async fn test_depends_with_factory() {
        let dependency = Depends::new(|| async { sample_database() });

        let database = dependency.resolve().await;
        assert_eq!(database.url, "postgresql://localhost");
    }

    /// A registered value is reported as present and can be fetched back.
    #[test]
    fn test_dependency_container() {
        let expected = sample_database();
        let mut registry = DependencyContainer::new();
        registry.register(expected.clone());

        assert!(registry.contains::<Database>());
        assert_eq!(registry.get::<Database>(), Some(expected));
    }
}