claude_agent/common/
provider.rs1use std::collections::HashMap;
2
3use async_trait::async_trait;
4
5use super::{Named, SourceType};
6
7#[async_trait]
8pub trait Provider<T: Named + Clone + Send + Sync>: Send + Sync {
9 fn provider_name(&self) -> &str;
10 fn priority(&self) -> i32 {
11 0
12 }
13 fn source_type(&self) -> SourceType {
14 SourceType::User
15 }
16
17 async fn list(&self) -> crate::Result<Vec<String>>;
18 async fn get(&self, name: &str) -> crate::Result<Option<T>>;
19 async fn load_all(&self) -> crate::Result<Vec<T>>;
20}
21
22#[derive(Debug, Clone)]
23pub struct InMemoryProvider<T> {
24 items: HashMap<String, T>,
25 priority: i32,
26 source_type: SourceType,
27}
28
29impl<T: Named + Clone + Send + Sync> Default for InMemoryProvider<T> {
30 fn default() -> Self {
31 Self::new()
32 }
33}
34
35impl<T: Named + Clone + Send + Sync> InMemoryProvider<T> {
36 pub fn new() -> Self {
37 Self {
38 items: HashMap::new(),
39 priority: 0,
40 source_type: SourceType::User,
41 }
42 }
43
44 pub fn with_item(mut self, item: T) -> Self {
45 self.add(item);
46 self
47 }
48
49 pub fn add(&mut self, item: T) {
50 self.items.insert(item.name().to_string(), item);
51 }
52
53 pub fn with_items(mut self, items: impl IntoIterator<Item = T>) -> Self {
54 for item in items {
55 self.add(item);
56 }
57 self
58 }
59
60 pub fn with_priority(mut self, priority: i32) -> Self {
61 self.priority = priority;
62 self
63 }
64
65 pub fn with_source_type(mut self, source_type: SourceType) -> Self {
66 self.source_type = source_type;
67 self
68 }
69
70 pub fn len(&self) -> usize {
71 self.items.len()
72 }
73
74 pub fn is_empty(&self) -> bool {
75 self.items.is_empty()
76 }
77}
78
79#[async_trait]
80impl<T: Named + Clone + Send + Sync + 'static> Provider<T> for InMemoryProvider<T> {
81 fn provider_name(&self) -> &str {
82 "in-memory"
83 }
84
85 fn priority(&self) -> i32 {
86 self.priority
87 }
88
89 fn source_type(&self) -> SourceType {
90 self.source_type
91 }
92
93 async fn list(&self) -> crate::Result<Vec<String>> {
94 Ok(self.items.keys().cloned().collect())
95 }
96
97 async fn get(&self, name: &str) -> crate::Result<Option<T>> {
98 Ok(self.items.get(name).cloned())
99 }
100
101 async fn load_all(&self) -> crate::Result<Vec<T>> {
102 Ok(self.items.values().cloned().collect())
103 }
104}
105
106#[derive(Default)]
107pub struct ChainProvider<T: Named + Clone + Send + Sync + 'static> {
108 providers: Vec<Box<dyn Provider<T>>>,
109}
110
111impl<T: Named + Clone + Send + Sync + 'static> ChainProvider<T> {
112 pub fn new() -> Self {
113 Self {
114 providers: Vec::new(),
115 }
116 }
117
118 pub fn with(mut self, provider: impl Provider<T> + 'static) -> Self {
119 self.providers.push(Box::new(provider));
120 self
121 }
122
123 pub fn add(&mut self, provider: impl Provider<T> + 'static) {
124 self.providers.push(Box::new(provider));
125 }
126}
127
128#[async_trait]
129impl<T: Named + Clone + Send + Sync + 'static> Provider<T> for ChainProvider<T> {
130 fn provider_name(&self) -> &str {
131 "chain"
132 }
133
134 fn priority(&self) -> i32 {
135 self.providers
136 .iter()
137 .map(|p| p.priority())
138 .max()
139 .unwrap_or(0)
140 }
141
142 async fn list(&self) -> crate::Result<Vec<String>> {
143 let mut all = Vec::new();
144 for p in &self.providers {
145 all.extend(p.list().await?);
146 }
147 all.sort();
148 all.dedup();
149 Ok(all)
150 }
151
152 async fn get(&self, name: &str) -> crate::Result<Option<T>> {
153 let mut sorted: Vec<_> = self.providers.iter().collect();
154 sorted.sort_by_key(|b| std::cmp::Reverse(b.priority()));
155
156 for provider in sorted {
157 if let Some(item) = provider.get(name).await? {
158 return Ok(Some(item));
159 }
160 }
161 Ok(None)
162 }
163
164 async fn load_all(&self) -> crate::Result<Vec<T>> {
165 let mut map: HashMap<String, T> = HashMap::new();
166
167 let mut sorted: Vec<_> = self.providers.iter().collect();
168 sorted.sort_by_key(|p| std::cmp::Reverse(p.priority()));
169
170 for provider in sorted {
171 for item in provider.load_all().await? {
172 map.entry(item.name().to_string()).or_insert(item);
173 }
174 }
175
176 Ok(map.into_values().collect())
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, PartialEq)]
    struct TestItem {
        name: String,
        value: i32,
    }

    impl Named for TestItem {
        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Builds a `TestItem` without repeating the struct literal in each test.
    fn item(name: &str, value: i32) -> TestItem {
        TestItem {
            name: name.to_owned(),
            value,
        }
    }

    /// An in-memory provider holding exactly one item at the given priority.
    fn single(name: &str, value: i32, priority: i32) -> InMemoryProvider<TestItem> {
        InMemoryProvider::new()
            .with_item(item(name, value))
            .with_priority(priority)
    }

    #[tokio::test]
    async fn test_in_memory_provider() {
        let provider = InMemoryProvider::new()
            .with_item(item("a", 1))
            .with_item(item("b", 2));

        assert_eq!(provider.len(), 2);

        let names = provider.list().await.unwrap();
        for expected in ["a", "b"] {
            assert!(
                names.iter().any(|n| n == expected),
                "missing {} in {:?}",
                expected,
                names
            );
        }

        let found = provider
            .get("a")
            .await
            .unwrap()
            .expect("item `a` should be present");
        assert_eq!(found.value, 1);

        let missing = provider.get("nonexistent").await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_chain_provider_priority() {
        let chain = ChainProvider::new()
            .with(single("shared", 1, 0))
            .with(single("shared", 10, 10));

        let winner = chain.get("shared").await.unwrap().unwrap();
        assert_eq!(winner.value, 10);
    }

    #[tokio::test]
    async fn test_chain_provider_load_all() {
        let chain = ChainProvider::new()
            .with(single("a", 1, 0))
            .with(single("b", 2, 10));

        let loaded = chain.load_all().await.unwrap();
        assert_eq!(loaded.len(), 2);
    }

    #[tokio::test]
    async fn test_chain_provider_load_all_priority_order() {
        let chain = ChainProvider::new()
            .with(single("shared", 1, 0))
            .with(single("shared", 100, 10));

        let loaded = chain.load_all().await.unwrap();
        assert_eq!(loaded.len(), 1);

        let kept = loaded
            .iter()
            .find(|i| i.name == "shared")
            .expect("`shared` should survive the merge");
        assert_eq!(kept.value, 100, "High priority item should be kept");
    }
}