agent_io/tools/
depends.rs1use std::any::Any;
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8
/// Marker trait for values that a [`Depends`] factory can produce.
///
/// It is blanket-implemented for every `Clone + Send + Sync + 'static` type, so it never needs a
/// manual impl; it only exists to name that bound set once.
pub trait Dependency: Clone + Send + Sync + 'static {}

impl<T: Clone + Send + Sync + 'static> Dependency for T {}

/// Type-erased, shareable async factory returning a `T`.
type BoxedFactory<T> = Arc<dyn Fn() -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>;

/// A cloneable, thread-safe handle to an async factory that produces a `T` on demand.
///
/// Clones share the same factory through an [`Arc`]; every call to [`Depends::resolve`] invokes
/// the factory again and awaits a fresh value.
pub struct Depends<T> {
    factory: BoxedFactory<T>,
}

impl<T> Clone for Depends<T> {
    /// Cheap clone: only the reference count of the shared factory is bumped.
    fn clone(&self) -> Self {
        Self {
            factory: Arc::clone(&self.factory),
        }
    }
}

impl<T> std::fmt::Debug for Depends<T> {
    /// The factory closure is opaque, so only the produced type's name is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Depends")
            .field("output", &std::any::type_name::<T>())
            .finish_non_exhaustive()
    }
}

impl<T: Dependency> Depends<T> {
    /// Wraps an async factory: each resolution calls `factory()` and awaits the returned future.
    pub fn new<F, Fut>(factory: F) -> Self
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        Self {
            // Boxing erases the concrete future type so every `Depends<T>` has one field type.
            factory: Arc::new(move || Box::pin(factory())),
        }
    }

    /// Creates a dependency that always resolves to a clone of `value`.
    pub fn with_value(value: T) -> Self {
        Self::new(move || std::future::ready(value.clone()))
    }

    /// Runs the factory and awaits the value it produces.
    pub async fn resolve(&self) -> T {
        (self.factory)().await
    }
}
48
/// A type-keyed registry holding at most one value per concrete type.
///
/// Values are stored type-erased and looked up by their [`std::any::TypeId`]. Registering a
/// second value of the same type replaces the first.
#[derive(Debug, Default)]
pub struct DependencyContainer {
    dependencies: HashMap<std::any::TypeId, Box<dyn Any + Send + Sync>>,
}

impl DependencyContainer {
    /// Creates an empty container; equivalent to [`DependencyContainer::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under its type, replacing (and dropping) any previous value of type `T`.
    pub fn register<T: 'static + Send + Sync>(&mut self, value: T) {
        self.dependencies
            .insert(std::any::TypeId::of::<T>(), Box::new(value));
    }

    /// Returns a clone of the value registered for `T`, or `None` if none is registered.
    pub fn get<T: 'static + Clone + Send + Sync>(&self) -> Option<T> {
        self.get_ref::<T>().cloned()
    }

    /// Borrows the value registered for `T` without cloning it, or returns `None`.
    ///
    /// Unlike [`DependencyContainer::get`], this does not require `T: Clone`. It therefore also
    /// reaches values that `register` accepted but `get` cannot return.
    pub fn get_ref<T: 'static>(&self) -> Option<&T> {
        self.dependencies
            .get(&std::any::TypeId::of::<T>())
            // Resolves to the inherent `<dyn Any + Send + Sync>::downcast_ref`, i.e. it checks
            // the boxed value's type rather than the `Box` itself.
            .and_then(|v| v.downcast_ref::<T>())
    }

    /// Reports whether a value of type `T` has been registered.
    pub fn contains<T: 'static>(&self) -> bool {
        self.dependencies.contains_key(&std::any::TypeId::of::<T>())
    }
}
81
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq)]
    struct Database {
        url: String,
    }

    /// Shared fixture so every test uses the same connection string.
    fn database() -> Database {
        Database {
            url: "postgresql://localhost".to_string(),
        }
    }

    #[tokio::test]
    async fn test_depends_with_value() {
        let db = database();
        let depends = Depends::with_value(db.clone());

        assert_eq!(depends.resolve().await, db);
        // A second resolution must still yield the same value (a fresh clone each time).
        assert_eq!(depends.resolve().await, db);
    }

    #[tokio::test]
    async fn test_depends_with_factory() {
        let depends = Depends::new(|| async { database() });

        let resolved = depends.resolve().await;
        assert_eq!(resolved.url, "postgresql://localhost");
    }

    #[tokio::test]
    async fn test_depends_clone_resolves_same_value() {
        let original = Depends::with_value(database());
        let cloned = original.clone();

        assert_eq!(cloned.resolve().await, original.resolve().await);
    }

    #[test]
    fn test_dependency_container() {
        let mut container = DependencyContainer::new();
        let db = database();

        container.register(db.clone());

        assert!(container.contains::<Database>());
        assert_eq!(container.get::<Database>(), Some(db));
    }

    #[test]
    fn test_dependency_container_missing_type() {
        let container = DependencyContainer::new();

        assert!(!container.contains::<Database>());
        assert_eq!(container.get::<Database>(), None);
    }

    #[test]
    fn test_dependency_container_register_overwrites() {
        let mut container = DependencyContainer::new();
        container.register(Database {
            url: "first".to_string(),
        });
        container.register(Database {
            url: "second".to_string(),
        });

        assert_eq!(
            container.get::<Database>().map(|db| db.url),
            Some("second".to_string())
        );
    }
}