rusty_commit/providers/
registry.rs1use crate::config::Config;
8use anyhow::{Context, Result};
9use std::collections::HashMap;
10use std::sync::RwLock;
11
12#[derive(thiserror::Error, Debug)]
14#[error("Registry lock error")]
15pub struct LockError;
16
/// Acquires a shared read guard on `$lock`.
///
/// Evaluates to `Result<RwLockReadGuard<_>, LockError>`. A poisoned lock is
/// logged (using `$field` as the label) and mapped to [`LockError`].
macro_rules! read_lock {
    ($lock:expr, $field:ident) => {
        match $lock.read() {
            Ok(guard) => Ok(guard),
            Err(_poisoned) => {
                tracing::error!("{} lock is poisoned", stringify!($field));
                Err(LockError)
            }
        }
    };
}
25
/// Acquires an exclusive write guard on `$lock`.
///
/// Evaluates to `Result<RwLockWriteGuard<_>, LockError>`. A poisoned lock is
/// logged (using `$field` as the label) and mapped to [`LockError`].
macro_rules! write_lock {
    ($lock:expr, $field:ident) => {
        match $lock.write() {
            Ok(guard) => Ok(guard),
            Err(_poisoned) => {
                tracing::error!("{} lock is poisoned", stringify!($field));
                Err(LockError)
            }
        }
    };
}
34
35pub trait ProviderBuilder: Send + Sync {
37 fn name(&self) -> &'static str;
39
40 fn aliases(&self) -> Vec<&'static str> {
42 vec![]
43 }
44
45 fn category(&self) -> ProviderCategory {
47 ProviderCategory::Standard
48 }
49
50 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>>;
52
53 fn requires_api_key(&self) -> bool {
55 true
56 }
57
58 fn default_model(&self) -> Option<&'static str> {
60 None
61 }
62}
63
/// Coarse grouping a builder assigns to its provider; used by
/// `ProviderRegistry::by_category` to filter entries.
// NOTE(review): `dead_code` is allowed presumably because not every variant
// is constructed in every build — confirm.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderCategory {
    /// Default category when a builder does not pick one.
    Standard,
    /// Presumably providers that speak an OpenAI-compatible API.
    OpenAICompatible,
    /// Presumably providers that run on the local machine.
    Local,
    /// Presumably hosted cloud-platform providers.
    Cloud,
}
77
78#[derive(Clone)]
80pub struct ProviderEntry {
81 pub name: &'static str,
82 pub aliases: Vec<&'static str>,
83 pub category: ProviderCategory,
84 #[allow(dead_code)]
85 pub requires_api_key: bool,
86 #[allow(dead_code)]
87 pub default_model: Option<&'static str>,
88}
89
90impl ProviderEntry {
91 pub fn from_builder(builder: &dyn ProviderBuilder) -> Self {
92 Self {
93 name: builder.name(),
94 aliases: builder.aliases(),
95 category: builder.category(),
96 requires_api_key: builder.requires_api_key(),
97 default_model: builder.default_model(),
98 }
99 }
100
101 #[allow(dead_code)]
103 pub fn matches(&self, provider: &str) -> bool {
104 let lower = provider.to_lowercase();
105 self.name.eq_ignore_ascii_case(&lower)
106 || self.aliases.iter().any(|&a| a.eq_ignore_ascii_case(&lower))
107 }
108}
109
110pub struct ProviderRegistry {
112 entries: RwLock<HashMap<&'static str, ProviderEntry>>,
113 builders: RwLock<HashMap<&'static str, Box<dyn ProviderBuilder>>>,
114 by_alias: RwLock<HashMap<&'static str, &'static str>>,
115}
116
117impl ProviderRegistry {
118 pub fn new() -> Self {
120 Self {
121 entries: RwLock::new(HashMap::new()),
122 builders: RwLock::new(HashMap::new()),
123 by_alias: RwLock::new(HashMap::new()),
124 }
125 }
126
127 pub fn register(&self, builder: Box<dyn ProviderBuilder>) -> Result<()> {
129 let name = builder.name();
130 let entry = ProviderEntry::from_builder(&*builder);
131
132 write_lock!(self.entries, entries)?.insert(name, entry.clone());
134 write_lock!(self.builders, builders)?.insert(name, builder);
135
136 for &alias in &entry.aliases {
138 write_lock!(self.by_alias, by_alias)?.insert(alias, name);
139 }
140
141 Ok(())
142 }
143
144 #[allow(dead_code)]
146 pub fn get(&self, provider: &str) -> Option<ProviderEntry> {
147 let lower = provider.to_lowercase();
148
149 let entries = read_lock!(self.entries, entries).ok()?;
151 if let Some(entry) = entries.get(lower.as_str()) {
152 return Some(entry.clone());
153 }
154
155 let by_alias = read_lock!(self.by_alias, by_alias).ok()?;
157 if let Some(&primary) = by_alias.get(lower.as_str()) {
158 return entries.get(primary).cloned();
159 }
160
161 None
162 }
163
164 pub fn all(&self) -> Option<Vec<ProviderEntry>> {
166 let entries = read_lock!(self.entries, entries).ok()?;
167 Some(entries.values().cloned().collect())
168 }
169
170 pub fn by_category(&self, category: ProviderCategory) -> Option<Vec<ProviderEntry>> {
172 let entries = read_lock!(self.entries, entries).ok()?;
173 Some(
174 entries
175 .values()
176 .filter(|e| e.category == category)
177 .cloned()
178 .collect(),
179 )
180 }
181
182 #[allow(dead_code)]
184 pub fn is_empty(&self) -> bool {
185 match read_lock!(self.entries, entries) {
186 Ok(entries) => entries.is_empty(),
187 Err(_) => true,
188 }
189 }
190
191 #[allow(dead_code)]
193 pub fn len(&self) -> usize {
194 match read_lock!(self.entries, entries) {
195 Ok(entries) => entries.len(),
196 Err(_) => 0,
197 }
198 }
199
200 pub fn create(
202 &self,
203 name: &str,
204 config: &Config,
205 ) -> Result<Option<Box<dyn super::AIProvider>>> {
206 let lower = name.to_lowercase();
207
208 let builders = read_lock!(self.builders, builders).context("Failed to read builders")?;
209 let by_alias = read_lock!(self.by_alias, by_alias).context("Failed to read aliases")?;
210
211 if let Some(builder) = builders.get(lower.as_str()) {
213 return Ok(Some(builder.create(config)?));
214 }
215
216 if let Some(&primary) = by_alias.get(lower.as_str()) {
218 if let Some(builder) = builders.get(primary) {
219 return Ok(Some(builder.create(config)?));
220 }
221 }
222
223 Ok(None)
224 }
225}
226
227impl Default for ProviderRegistry {
228 fn default() -> Self {
229 Self::new()
230 }
231}