1use std::collections::HashMap;
14use std::sync::{Arc, RwLock};
15
16use super::provider_trait::{ContextProvider, ProviderParams};
17use super::ProviderResult;
18use crate::core::bm25_index::ChunkKind;
19use crate::core::content_chunk::ContentChunk;
20
/// Registry of [`ContextProvider`]s keyed by provider id.
///
/// Providers are held as `Arc<dyn ContextProvider>` behind an [`RwLock`], so lookups take a
/// shared lock and registration takes an exclusive one.
pub struct ProviderRegistry {
    // Keyed by `ContextProvider::id()` as captured at registration time; registering the
    // same id again replaces the previous provider (plain `HashMap::insert`).
    providers: RwLock<HashMap<String, Arc<dyn ContextProvider>>>,
}
25
26impl ProviderRegistry {
27 pub fn new() -> Self {
28 Self {
29 providers: RwLock::new(HashMap::new()),
30 }
31 }
32
33 pub fn register(&self, provider: Arc<dyn ContextProvider>) {
34 let id = provider.id().to_string();
35 if let Ok(mut map) = self.providers.write() {
36 map.insert(id, provider);
37 }
38 }
39
40 pub fn get(&self, id: &str) -> Option<Arc<dyn ContextProvider>> {
41 self.providers
42 .read()
43 .ok()
44 .and_then(|map| map.get(id).cloned())
45 }
46
47 pub fn execute(
48 &self,
49 provider_id: &str,
50 action: &str,
51 params: &ProviderParams,
52 ) -> Result<ProviderResult, String> {
53 let provider = self
54 .get(provider_id)
55 .ok_or_else(|| format!("Provider '{provider_id}' not registered"))?;
56
57 if !provider.is_available() {
58 return Err(format!(
59 "Provider '{provider_id}' is not available (check config/auth)"
60 ));
61 }
62
63 if !provider.supported_actions().contains(&action) {
64 return Err(format!(
65 "Provider '{provider_id}' does not support action '{action}'. Supported: {:?}",
66 provider.supported_actions()
67 ));
68 }
69
70 provider.execute(action, params)
71 }
72
73 pub fn execute_as_chunks(
75 &self,
76 provider_id: &str,
77 action: &str,
78 params: &ProviderParams,
79 ) -> Result<Vec<ContentChunk>, String> {
80 let result = self.execute(provider_id, action, params)?;
81 Ok(result_to_chunks(&result))
82 }
83
84 pub fn discover(&self) -> Vec<ProviderInfo> {
86 let Ok(map) = self.providers.read() else {
87 return Vec::new();
88 };
89
90 let mut infos: Vec<ProviderInfo> = map
91 .values()
92 .map(|p| ProviderInfo {
93 id: p.id().to_string(),
94 display_name: p.display_name().to_string(),
95 available: p.is_available(),
96 requires_auth: p.requires_auth(),
97 actions: p
98 .supported_actions()
99 .iter()
100 .map(std::string::ToString::to_string)
101 .collect(),
102 cache_ttl_secs: p.cache_ttl_secs(),
103 })
104 .collect();
105
106 infos.sort_by(|a, b| a.id.cmp(&b.id));
107 infos
108 }
109
110 pub fn provider_count(&self) -> usize {
111 self.providers.read().map_or(0, |m| m.len())
112 }
113
114 pub fn available_provider_ids(&self) -> Vec<String> {
115 self.providers
116 .read()
117 .map(|m| {
118 m.values()
119 .filter(|p| p.is_available())
120 .map(|p| p.id().to_string())
121 .collect()
122 })
123 .unwrap_or_default()
124 }
125}
126
127impl Default for ProviderRegistry {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
/// Serializable summary of one registered provider, as built by
/// [`ProviderRegistry::discover`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct ProviderInfo {
    /// Provider identifier, from `ContextProvider::id`.
    pub id: String,
    /// Human-readable name, from `ContextProvider::display_name`.
    pub display_name: String,
    /// Value of `ContextProvider::is_available` at the time of discovery.
    pub available: bool,
    /// Value of `ContextProvider::requires_auth`.
    pub requires_auth: bool,
    /// The provider's supported actions (the names accepted by `ProviderRegistry::execute`).
    pub actions: Vec<String>,
    /// Value of `ContextProvider::cache_ttl_secs`; presumably a cache lifetime in seconds —
    /// TODO confirm against the trait documentation.
    pub cache_ttl_secs: u64,
}
143
144fn action_to_chunk_kind(resource_type: &str) -> ChunkKind {
149 match resource_type {
150 "issues" => ChunkKind::Issue,
151 "merge_requests" | "pull_requests" | "prs" => ChunkKind::PullRequest,
152 "wikis" | "pages" => ChunkKind::WikiPage,
153 "schemas" | "tables" => ChunkKind::DbSchema,
154 "endpoints" | "routes" => ChunkKind::ApiEndpoint,
155 "tickets" => ChunkKind::Ticket,
156 _ => ChunkKind::ExternalOther,
157 }
158}
159
160pub fn result_to_chunks(result: &ProviderResult) -> Vec<ContentChunk> {
162 let kind = action_to_chunk_kind(&result.resource_type);
163
164 result
165 .items
166 .iter()
167 .map(|item| {
168 let body = item.body.as_deref().unwrap_or("");
169 let content = format!(
170 "#{} {}{}\n{}",
171 item.id,
172 item.title,
173 item.state
174 .as_ref()
175 .map(|s| format!(" [{s}]"))
176 .unwrap_or_default(),
177 body,
178 );
179
180 let references = crate::core::content_chunk::extract_file_references(&content);
181
182 let metadata = serde_json::json!({
183 "state": item.state,
184 "author": item.author,
185 "created_at": item.created_at,
186 "updated_at": item.updated_at,
187 "url": item.url,
188 "labels": item.labels,
189 });
190
191 ContentChunk::from_provider(
192 &result.provider,
193 &result.resource_type,
194 &item.id,
195 &item.title,
196 kind.clone(),
197 content,
198 references,
199 Some(metadata),
200 )
201 })
202 .collect()
203}
204
205static GLOBAL_REGISTRY: std::sync::LazyLock<ProviderRegistry> =
210 std::sync::LazyLock::new(ProviderRegistry::new);
211
212pub fn global_registry() -> &'static ProviderRegistry {
213 &GLOBAL_REGISTRY
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::providers::{ProviderItem, ProviderResult};

    #[test]
    fn result_to_chunks_preserves_provider_id() {
        let result = ProviderResult {
            provider: "github".into(),
            resource_type: "issues".into(),
            items: vec![ProviderItem {
                id: "42".into(),
                title: "Auth bug".into(),
                state: Some("open".into()),
                author: Some("dev".into()),
                created_at: None,
                updated_at: None,
                url: Some("https://github.com/o/r/issues/42".into()),
                labels: vec!["bug".into()],
                body: Some("Fix in src/auth/handler.rs".into()),
            }],
            total_count: Some(1),
            truncated: false,
        };

        let chunks = result_to_chunks(&result);
        assert_eq!(chunks.len(), 1);
        let c = &chunks[0];
        assert_eq!(c.provider_id(), Some("github"));
        assert_eq!(c.kind, ChunkKind::Issue);
        assert!(c.file_path.contains("github://issues/42"));
        assert!(c.references.contains(&"src/auth/handler.rs".to_string()));
    }

    #[test]
    fn action_maps_to_correct_kind() {
        assert_eq!(action_to_chunk_kind("issues"), ChunkKind::Issue);
        assert_eq!(
            action_to_chunk_kind("pull_requests"),
            ChunkKind::PullRequest
        );
        assert_eq!(
            action_to_chunk_kind("merge_requests"),
            ChunkKind::PullRequest
        );
        assert_eq!(action_to_chunk_kind("prs"), ChunkKind::PullRequest);
        assert_eq!(action_to_chunk_kind("wikis"), ChunkKind::WikiPage);
        assert_eq!(action_to_chunk_kind("pages"), ChunkKind::WikiPage);
        assert_eq!(action_to_chunk_kind("schemas"), ChunkKind::DbSchema);
        assert_eq!(action_to_chunk_kind("tables"), ChunkKind::DbSchema);
        assert_eq!(action_to_chunk_kind("endpoints"), ChunkKind::ApiEndpoint);
        assert_eq!(action_to_chunk_kind("routes"), ChunkKind::ApiEndpoint);
        assert_eq!(action_to_chunk_kind("tickets"), ChunkKind::Ticket);
        assert_eq!(action_to_chunk_kind("unknown"), ChunkKind::ExternalOther);
    }

    // NOTE(review): the sort order of `discover` is not covered here; that needs a
    // `ContextProvider` test double, whose trait definition is not visible from this module.
    #[test]
    fn registry_discover_empty_without_providers() {
        let reg = ProviderRegistry::new();
        assert!(reg.discover().is_empty());
        assert_eq!(reg.provider_count(), 0);
        assert!(reg.available_provider_ids().is_empty());
    }

    #[test]
    fn registry_get_unknown_provider_returns_none() {
        let reg = ProviderRegistry::new();
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn registry_execute_unknown_provider_errors() {
        let reg = ProviderRegistry::new();
        let result = reg.execute("nonexistent", "issues", &ProviderParams::default());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("not registered"));
    }

    #[test]
    fn registry_execute_as_chunks_unknown_provider_errors() {
        let reg = ProviderRegistry::new();
        // `.err()` avoids requiring `ContentChunk: Debug`, which `unwrap_err` would need.
        let err = reg
            .execute_as_chunks("nonexistent", "issues", &ProviderParams::default())
            .err()
            .expect("an unregistered provider must be rejected");
        assert!(err.contains("not registered"));
    }

    #[test]
    fn global_registry_is_a_single_instance() {
        assert!(std::ptr::eq(global_registry(), global_registry()));
    }
}