claude_agent/common/
provider.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4
5use super::{Named, SourceType};
6
7#[async_trait]
8pub trait Provider<T: Named + Clone + Send + Sync>: Send + Sync {
9    fn provider_name(&self) -> &str;
10    fn priority(&self) -> i32 {
11        0
12    }
13    fn source_type(&self) -> SourceType {
14        SourceType::User
15    }
16
17    async fn list(&self) -> crate::Result<Vec<String>>;
18    async fn get(&self, name: &str) -> crate::Result<Option<T>>;
19    async fn load_all(&self) -> crate::Result<Vec<T>>;
20}
21
22#[derive(Debug, Clone)]
23pub struct InMemoryProvider<T> {
24    items: HashMap<String, T>,
25    priority: i32,
26    source_type: SourceType,
27}
28
29impl<T: Named + Clone + Send + Sync> Default for InMemoryProvider<T> {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl<T: Named + Clone + Send + Sync> InMemoryProvider<T> {
36    pub fn new() -> Self {
37        Self {
38            items: HashMap::new(),
39            priority: 0,
40            source_type: SourceType::User,
41        }
42    }
43
44    pub fn with_item(mut self, item: T) -> Self {
45        self.add(item);
46        self
47    }
48
49    pub fn add(&mut self, item: T) {
50        self.items.insert(item.name().to_string(), item);
51    }
52
53    pub fn with_items(mut self, items: impl IntoIterator<Item = T>) -> Self {
54        for item in items {
55            self.add(item);
56        }
57        self
58    }
59
60    pub fn with_priority(mut self, priority: i32) -> Self {
61        self.priority = priority;
62        self
63    }
64
65    pub fn with_source_type(mut self, source_type: SourceType) -> Self {
66        self.source_type = source_type;
67        self
68    }
69
70    pub fn len(&self) -> usize {
71        self.items.len()
72    }
73
74    pub fn is_empty(&self) -> bool {
75        self.items.is_empty()
76    }
77}
78
79#[async_trait]
80impl<T: Named + Clone + Send + Sync + 'static> Provider<T> for InMemoryProvider<T> {
81    fn provider_name(&self) -> &str {
82        "in-memory"
83    }
84
85    fn priority(&self) -> i32 {
86        self.priority
87    }
88
89    fn source_type(&self) -> SourceType {
90        self.source_type
91    }
92
93    async fn list(&self) -> crate::Result<Vec<String>> {
94        Ok(self.items.keys().cloned().collect())
95    }
96
97    async fn get(&self, name: &str) -> crate::Result<Option<T>> {
98        Ok(self.items.get(name).cloned())
99    }
100
101    async fn load_all(&self) -> crate::Result<Vec<T>> {
102        Ok(self.items.values().cloned().collect())
103    }
104}
105
106#[derive(Default)]
107pub struct ChainProvider<T: Named + Clone + Send + Sync + 'static> {
108    providers: Vec<Box<dyn Provider<T>>>,
109}
110
111impl<T: Named + Clone + Send + Sync + 'static> ChainProvider<T> {
112    pub fn new() -> Self {
113        Self {
114            providers: Vec::new(),
115        }
116    }
117
118    pub fn with(mut self, provider: impl Provider<T> + 'static) -> Self {
119        self.providers.push(Box::new(provider));
120        self
121    }
122
123    pub fn add(&mut self, provider: impl Provider<T> + 'static) {
124        self.providers.push(Box::new(provider));
125    }
126}
127
128#[async_trait]
129impl<T: Named + Clone + Send + Sync + 'static> Provider<T> for ChainProvider<T> {
130    fn provider_name(&self) -> &str {
131        "chain"
132    }
133
134    fn priority(&self) -> i32 {
135        self.providers
136            .iter()
137            .map(|p| p.priority())
138            .max()
139            .unwrap_or(0)
140    }
141
142    async fn list(&self) -> crate::Result<Vec<String>> {
143        let mut all = Vec::new();
144        for p in &self.providers {
145            all.extend(p.list().await?);
146        }
147        all.sort();
148        all.dedup();
149        Ok(all)
150    }
151
152    async fn get(&self, name: &str) -> crate::Result<Option<T>> {
153        let mut sorted: Vec<_> = self.providers.iter().collect();
154        sorted.sort_by_key(|b| std::cmp::Reverse(b.priority()));
155
156        for provider in sorted {
157            if let Some(item) = provider.get(name).await? {
158                return Ok(Some(item));
159            }
160        }
161        Ok(None)
162    }
163
164    async fn load_all(&self) -> crate::Result<Vec<T>> {
165        let mut map: HashMap<String, T> = HashMap::new();
166
167        let mut sorted: Vec<_> = self.providers.iter().collect();
168        sorted.sort_by_key(|p| std::cmp::Reverse(p.priority()));
169
170        for provider in sorted {
171            for item in provider.load_all().await? {
172                map.entry(item.name().to_string()).or_insert(item);
173            }
174        }
175
176        Ok(map.into_values().collect())
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Named` payload used to exercise the providers.
    #[derive(Debug, Clone, PartialEq)]
    struct TestItem {
        name: String,
        value: i32,
    }

    impl Named for TestItem {
        fn name(&self) -> &str {
            self.name.as_str()
        }
    }

    /// Shorthand constructor for a [`TestItem`].
    fn item(name: &str, value: i32) -> TestItem {
        TestItem {
            name: name.into(),
            value,
        }
    }

    // Covers the basic store / list / get round trip.
    #[tokio::test]
    async fn test_in_memory_provider() {
        let provider = InMemoryProvider::new()
            .with_item(item("a", 1))
            .with_item(item("b", 2));

        assert_eq!(provider.len(), 2);

        let names = provider.list().await.unwrap();
        for expected in ["a", "b"] {
            assert!(names.iter().any(|n| n.as_str() == expected));
        }

        let found = provider.get("a").await.unwrap().unwrap();
        assert_eq!(found.value, 1);

        let missing = provider.get("nonexistent").await.unwrap();
        assert!(missing.is_none());
    }

    // `get` must prefer the higher-priority provider even when it was added last.
    #[tokio::test]
    async fn test_chain_provider_priority() {
        let low = InMemoryProvider::new()
            .with_item(item("shared", 1))
            .with_priority(0);
        let high = InMemoryProvider::new()
            .with_item(item("shared", 10))
            .with_priority(10);

        let chain = ChainProvider::new().with(low).with(high);

        let resolved = chain.get("shared").await.unwrap().unwrap();
        assert_eq!(resolved.value, 10);
    }

    // Items with distinct names from different providers are all returned.
    #[tokio::test]
    async fn test_chain_provider_load_all() {
        let p1 = InMemoryProvider::new()
            .with_item(item("a", 1))
            .with_priority(0);
        let p2 = InMemoryProvider::new()
            .with_item(item("b", 2))
            .with_priority(10);

        let chain = ChainProvider::new().with(p1).with(p2);

        let items = chain.load_all().await.unwrap();
        assert_eq!(items.len(), 2);
    }

    // On a name clash `load_all` keeps only the higher-priority item.
    #[tokio::test]
    async fn test_chain_provider_load_all_priority_order() {
        let low = InMemoryProvider::new()
            .with_item(item("shared", 1))
            .with_priority(0);
        let high = InMemoryProvider::new()
            .with_item(item("shared", 100))
            .with_priority(10);

        let chain = ChainProvider::new().with(low).with(high);

        let items = chain.load_all().await.unwrap();
        assert_eq!(items.len(), 1);

        let kept = items
            .into_iter()
            .find(|candidate| candidate.name == "shared")
            .unwrap();
        assert_eq!(kept.value, 100, "High priority item should be kept");
    }
}
288}