// agent_io/tools/depends.rs

use std::any::Any;
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

/// Marker trait for values the dependency system can hand out.
///
/// A qualifying type can be cloned, sent and shared across threads, and
/// borrows nothing non-`'static`.
pub trait Dependency
where
    Self: Clone + Send + Sync + 'static,
{
}

// Blanket implementation: every type meeting the bounds is a `Dependency`
// automatically, so nobody needs to implement this trait by hand.
impl<T> Dependency for T where T: Clone + Send + Sync + 'static {}

14#[derive(Clone)]
16pub struct Depends<T>
17where
18 T: Dependency,
19{
20 factory: Arc<dyn Fn() -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
21}
22
23impl<T: Dependency> Depends<T> {
24 pub fn new<F, Fut>(factory: F) -> Self
26 where
27 F: Fn() -> Fut + Send + Sync + 'static,
28 Fut: Future<Output = T> + Send + 'static,
29 {
30 Self {
31 factory: Arc::new(move || Box::pin(factory())),
32 }
33 }
34
35 pub fn with_value(value: T) -> Self {
37 Self::new(move || {
38 let v = value.clone();
39 async move { v }
40 })
41 }
42
43 pub async fn resolve(&self) -> T {
45 (self.factory)().await
46 }
47}
48
/// A registry that stores at most one value per concrete type.
///
/// Values are kept type-erased and keyed by their `TypeId`. Lookups clone the
/// stored value out rather than lending a reference into the container.
#[derive(Default)]
pub struct DependencyContainer {
    dependencies: HashMap<std::any::TypeId, Box<dyn Any + Send + Sync>>,
}

impl DependencyContainer {
    /// Creates an empty container.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `value` under its type `T`, replacing any `T` registered earlier.
    pub fn register<T: 'static + Send + Sync>(&mut self, value: T) {
        self.dependencies
            .insert(std::any::TypeId::of::<T>(), Box::new(value));
    }

    /// Returns a clone of the registered `T`, or `None` if no `T` was registered.
    pub fn get<T: 'static + Clone + Send + Sync>(&self) -> Option<T> {
        self.dependencies
            .get(&std::any::TypeId::of::<T>())
            .and_then(|v| v.downcast_ref::<T>())
            .cloned()
    }

    /// Returns `true` if a value of type `T` has been registered.
    pub fn contains<T: 'static>(&self) -> bool {
        self.dependencies.contains_key(&std::any::TypeId::of::<T>())
    }
}

impl fmt::Debug for DependencyContainer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Stored values are type-erased and cannot be formatted; report the count.
        f.debug_struct("DependencyContainer")
            .field("registered", &self.dependencies.len())
            .finish()
    }
}

/// Type-erased values stored under string keys.
///
/// Unlike [`DependencyContainer`], several values of the same type can coexist
/// under different keys. Lookups name both the key and the expected type.
#[derive(Default)]
pub struct DependencyOverrides {
    inner: HashMap<String, Box<dyn Any + Send + Sync>>,
}

impl DependencyOverrides {
    /// Creates an empty set of overrides.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under `key`, replacing whatever was stored there before,
    /// even if it had a different type.
    pub fn insert<T: 'static + Send + Sync>(&mut self, key: &str, value: T) {
        self.inner.insert(key.to_string(), Box::new(value));
    }

    /// Returns a clone of the value under `key` if it is present and has type `T`.
    ///
    /// Returns `None` when the key is absent or holds a value of another type.
    pub fn get<T: 'static + Clone + Send + Sync>(&self, key: &str) -> Option<T> {
        self.inner
            .get(key)
            .and_then(|v| v.downcast_ref::<T>())
            .cloned()
    }

    /// Returns `true` if any value, of any type, is stored under `key`.
    ///
    /// Counterpart of [`DependencyContainer::contains`], which checks by type instead.
    pub fn contains(&self, key: &str) -> bool {
        self.inner.contains_key(key)
    }
}

impl fmt::Debug for DependencyOverrides {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Values are type-erased, so only keys are listed. They are sorted so the
        // output is stable despite HashMap's unspecified iteration order.
        let mut keys: Vec<&str> = self.inner.keys().map(String::as_str).collect();
        keys.sort_unstable();
        f.debug_struct("DependencyOverrides")
            .field("keys", &keys)
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Connection string shared by every fixture in this module.
    const LOCAL_URL: &str = "postgresql://localhost";

    #[derive(Clone, Debug, PartialEq)]
    struct Database {
        url: String,
    }

    /// Builds the local database fixture used across the tests.
    fn local_db() -> Database {
        Database {
            url: LOCAL_URL.to_string(),
        }
    }

    #[tokio::test]
    async fn test_depends_with_value() {
        let expected = local_db();
        let resolved = Depends::with_value(expected.clone()).resolve().await;
        assert_eq!(resolved, expected);
    }

    #[tokio::test]
    async fn test_depends_with_factory() {
        let provider = Depends::new(|| async { local_db() });
        assert_eq!(provider.resolve().await.url, LOCAL_URL);
    }

    #[test]
    fn test_dependency_container() {
        let expected = local_db();
        let mut container = DependencyContainer::new();
        container.register(expected.clone());

        assert!(container.contains::<Database>());
        assert_eq!(container.get::<Database>(), Some(expected));
    }
}