Skip to main content

agent_io/tools/
depends.rs

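//! Dependency injection primitives: lazily resolved [`Depends`] values, a type-keyed
//! [`DependencyContainer`], and string-keyed [`DependencyOverrides`].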
2
3use std::any::Any;
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8
/// Marker trait for types that can be injected as dependencies.
///
/// It declares no methods; it only names the bound set
/// `Clone + Send + Sync + 'static` so that it can be written as a single bound.
pub trait Dependency: Clone + Send + Sync + 'static {}

/// Blanket implementation: every type that satisfies the bounds is a
/// [`Dependency`], so the trait never has to be implemented by hand.
impl<T> Dependency for T where T: Clone + Send + Sync + 'static {}
13
14/// A dependency that can be resolved at runtime
15#[derive(Clone)]
16pub struct Depends<T>
17where
18    T: Dependency,
19{
20    factory: Arc<dyn Fn() -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
21}
22
23impl<T: Dependency> Depends<T> {
24    /// Create a new dependency with a factory function
25    pub fn new<F, Fut>(factory: F) -> Self
26    where
27        F: Fn() -> Fut + Send + Sync + 'static,
28        Fut: Future<Output = T> + Send + 'static,
29    {
30        Self {
31            factory: Arc::new(move || Box::pin(factory())),
32        }
33    }
34
35    /// Create a dependency with a static value
36    pub fn with_value(value: T) -> Self {
37        Self::new(move || {
38            let v = value.clone();
39            async move { v }
40        })
41    }
42
43    /// Resolve the dependency
44    pub async fn resolve(&self) -> T {
45        (self.factory)().await
46    }
47}
48
/// Type-keyed store holding at most one shared value per concrete type.
#[derive(Default)]
pub struct DependencyContainer {
    dependencies: HashMap<std::any::TypeId, Box<dyn Any + Send + Sync>>,
}

impl DependencyContainer {
    /// Creates an empty container.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under its own type, replacing any value of the same
    /// type that was registered earlier.
    pub fn register<T: 'static + Send + Sync>(&mut self, value: T) {
        self.dependencies
            .insert(std::any::TypeId::of::<T>(), Box::new(value));
    }

    /// Returns a clone of the registered `T`, or `None` if no `T` was registered.
    pub fn get<T: 'static + Clone + Send + Sync>(&self) -> Option<T> {
        let stored = self.dependencies.get(&std::any::TypeId::of::<T>())?;
        stored.downcast_ref::<T>().cloned()
    }

    /// Reports whether a value of type `T` has been registered.
    pub fn contains<T: 'static>(&self) -> bool {
        self.dependencies.contains_key(&std::any::TypeId::of::<T>())
    }
}
81
/// String-keyed store of type-erased override values.
#[derive(Default)]
pub struct DependencyOverrides {
    inner: HashMap<String, Box<dyn Any + Send + Sync>>,
}

impl DependencyOverrides {
    /// Creates an empty set of overrides.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under `key`, replacing whatever was stored under that
    /// key before (regardless of its type).
    pub fn insert<T: 'static + Send + Sync>(&mut self, key: &str, value: T) {
        self.inner.insert(key.to_string(), Box::new(value));
    }

    /// Returns a clone of the value stored under `key`.
    ///
    /// Yields `None` when the key is absent or when the stored value is not
    /// of type `T`.
    pub fn get<T: 'static + Clone + Send + Sync>(&self, key: &str) -> Option<T> {
        let stored = self.inner.get(key)?;
        stored.downcast_ref::<T>().cloned()
    }
}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection string used by every fixture in this module.
    const URL: &str = "postgresql://localhost";

    /// Minimal cloneable stand-in for a real dependency.
    #[derive(Clone, Debug, PartialEq)]
    struct Database {
        url: String,
    }

    /// Builds the `Database` fixture pointing at [`URL`].
    fn sample_db() -> Database {
        Database {
            url: URL.to_string(),
        }
    }

    /// A value-backed dependency resolves to an equal value.
    #[tokio::test]
    async fn test_depends_with_value() {
        let db = sample_db();
        let depends = Depends::with_value(db.clone());

        let resolved = depends.resolve().await;
        assert_eq!(resolved, db);
    }

    /// A factory-backed dependency resolves to what the factory's future yields.
    #[tokio::test]
    async fn test_depends_with_factory() {
        let depends = Depends::new(|| async { sample_db() });

        let resolved = depends.resolve().await;
        assert_eq!(resolved.url, URL);
    }

    /// A registered value is reported as present and returned by `get`.
    #[test]
    fn test_dependency_container() {
        let mut container = DependencyContainer::new();
        let db = sample_db();

        container.register(db.clone());

        assert!(container.contains::<Database>());
        assert_eq!(container.get::<Database>(), Some(db));
    }
}