1use crate::protocol::{ProtocolError, ProtocolManifest};
6use arc_swap::ArcSwap;
7use lru::LruCache;
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10
11pub struct ProtocolLoader {
13 base_path: Option<PathBuf>,
14 hot_reload: bool,
15 validator: crate::protocol::validator::ProtocolValidator,
16 cache: Mutex<LruCache<String, Arc<ProtocolManifest>>>,
17}
18
19impl ProtocolLoader {
20 pub fn new() -> Self {
21 Self {
22 base_path: None,
23 hot_reload: false,
24 validator: crate::protocol::validator::ProtocolValidator::default(),
25 cache: Mutex::new(LruCache::new(
28 std::num::NonZeroUsize::new(100)
29 .expect("Cache size must be non-zero (this should never happen)"),
30 )),
31 }
32 }
33
34 pub fn with_base_path(mut self, path: impl AsRef<Path>) -> Self {
36 self.base_path = Some(path.as_ref().to_path_buf());
37 self
38 }
39
40 pub fn with_hot_reload(mut self, enable: bool) -> Self {
42 self.hot_reload = enable;
43 self
44 }
45
46 pub async fn load_model(&self, model: &str) -> Result<ProtocolManifest, ProtocolError> {
49 {
51 let mut cache = self.cache.lock().map_err(|e| {
52 ProtocolError::Internal(format!(
53 "Failed to acquire cache lock while loading model '{}': {}",
54 model, e
55 ))
56 })?;
57 if let Some(manifest) = cache.get(model) {
58 return Ok(manifest.as_ref().clone());
59 }
60 }
61
62 let parts: Vec<&str> = model.split('/').collect();
63 if parts.len() != 2 {
64 return Err(ProtocolError::NotFound {
65 id: model.to_string(),
66 hint: Some("Ensure the model name follows the 'provider/model' format".to_string()),
67 });
68 }
69
70 let provider = parts[0];
71 let model_name = parts[1];
72
73 let manifest = match self.load_model_config(model_name).await {
77 Ok(model_config) => self.load_provider(&model_config.provider).await?,
78 Err(ProtocolError::NotFound { .. }) => self.load_provider(provider).await?,
79 Err(e) => return Err(e),
80 };
81
82 {
84 let mut cache = self.cache.lock().map_err(|e| {
85 ProtocolError::Internal(format!(
86 "Failed to acquire cache lock while caching model '{}': {}",
87 model, e
88 ))
89 })?;
90 cache.put(model.to_string(), Arc::new(manifest.clone()));
91 }
92
93 Ok(manifest)
94 }
95
96 pub async fn load_provider(
98 &self,
99 provider_id: &str,
100 ) -> Result<ProtocolManifest, ProtocolError> {
101 if let Some(ref base_path) = self.base_path {
108 let provider_path = base_path
109 .join("v1")
110 .join("providers")
111 .join(format!("{}.yaml", provider_id));
112
113 if provider_path.exists() {
114 return self.load_from_file(&provider_path).await;
115 }
116 }
117
118 if let Ok(root) =
120 std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
121 {
122 if root.starts_with("http://") || root.starts_with("https://") {
124 let url = if root.ends_with('/') {
126 format!("{}v1/providers/{}.yaml", root, provider_id)
127 } else {
128 format!("{}/v1/providers/{}.yaml", root, provider_id)
129 };
130 return self.load_from_url(&url).await;
131 }
132 }
133
134 let mut default_paths: Vec<PathBuf> = Vec::new();
139 if let Ok(root) =
140 std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
141 {
142 if !root.starts_with("http://") && !root.starts_with("https://") {
144 let root = PathBuf::from(root);
145 default_paths.push(root.join("v1").join("providers"));
146 }
147 }
148 default_paths.push(PathBuf::from("ai-protocol/v1/providers"));
149 default_paths.push(PathBuf::from("../ai-protocol/v1/providers"));
150 default_paths.push(PathBuf::from("../../ai-protocol/v1/providers"));
151 let win_dev = PathBuf::from("D:\\ai-protocol\\v1\\providers");
152 if win_dev.exists() {
153 default_paths.push(win_dev);
154 }
155
156 for base in default_paths {
157 let provider_path = base.join(format!("{}.yaml", provider_id));
158 if provider_path.exists() {
159 return self.load_from_file(&provider_path).await;
160 }
161 }
162
163 let github_url = format!(
165 "https://raw.githubusercontent.com/hiddenpath/ai-protocol/main/v1/providers/{}.yaml",
166 provider_id
167 );
168 if let Ok(manifest) = self.load_from_url(&github_url).await {
169 return Ok(manifest);
170 }
171
172 Err(ProtocolError::NotFound {
173 id: provider_id.to_string(),
174 hint: Some(format!(
175 "Check if the provider file '{}.yaml' exists in your protocol directory",
176 provider_id
177 )),
178 })
179 }
180
181 async fn load_from_file(&self, path: &Path) -> Result<ProtocolManifest, ProtocolError> {
183 let bytes = tokio::fs::read(path)
185 .await
186 .map_err(|e| ProtocolError::LoadError {
187 path: path.to_string_lossy().to_string(),
188 reason: e.to_string(),
189 hint: Some("Check if the file exists and you have read permissions.".to_string()),
190 })?;
191
192 let content = if bytes.len() >= 2 && bytes[0] == 0xFF && bytes[1] == 0xFE {
194 let utf16_bytes = &bytes[2..];
196 let mut utf16_chars = Vec::new();
198 for i in (0..utf16_bytes.len()).step_by(2) {
199 if i + 1 < utf16_bytes.len() {
200 let code_unit = u16::from_le_bytes([utf16_bytes[i], utf16_bytes[i + 1]]);
201 utf16_chars.push(code_unit);
202 }
203 }
204 String::from_utf16(&utf16_chars).map_err(|e| ProtocolError::LoadError {
205 path: path.to_string_lossy().to_string(),
206 reason: format!(
207 "Invalid UTF-16: {}. Please convert the file to UTF-8 encoding.",
208 e
209 ),
210 hint: Some(
211 "The runtime expects UTF-8 manifests. Try converting the file encoding."
212 .to_string(),
213 ),
214 })?
215 } else if bytes.len() >= 3 && bytes[0] == 0xEF && bytes[1] == 0xBB && bytes[2] == 0xBF {
216 String::from_utf8(bytes[3..].to_vec()).map_err(|e| ProtocolError::LoadError {
218 path: path.to_string_lossy().to_string(),
219 reason: format!("Invalid UTF-8 (after BOM): {}", e),
220 hint: Some(
221 "Remove Byte Order Mark (BOM) and ensure the file is valid UTF-8.".to_string(),
222 ),
223 })?
224 } else {
225 String::from_utf8(bytes).map_err(|e| ProtocolError::LoadError {
227 path: path.to_string_lossy().to_string(),
228 reason: format!(
229 "Invalid UTF-8: {}. Please convert the file to UTF-8 encoding.",
230 e
231 ),
232 hint: Some("Verify the file content is valid UTF-8.".to_string()),
233 })?
234 };
235
236 let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
237
238 self.validator.validate(&manifest)?;
240
241 Ok(manifest)
242 }
243
244 async fn load_from_url(&self, url: &str) -> Result<ProtocolManifest, ProtocolError> {
246 let client = reqwest::Client::builder()
247 .timeout(std::time::Duration::from_secs(30))
248 .build()
249 .map_err(|e| ProtocolError::Internal(format!("Failed to create HTTP client: {}", e)))?;
250
251 let response = client
252 .get(url)
253 .send()
254 .await
255 .map_err(|e| ProtocolError::LoadError {
256 path: url.to_string(),
257 reason: format!("HTTP request failed: {}", e),
258 hint: Some(
259 "Check your internet connection and verify the URL is accessible.".to_string(),
260 ),
261 })?;
262
263 if !response.status().is_success() {
264 return Err(ProtocolError::LoadError {
265 path: url.to_string(),
266 reason: format!(
267 "HTTP {}: {}",
268 response.status(),
269 response.text().await.unwrap_or_default()
270 ),
271 hint: Some(
272 "Verify the remote registry URL and your API permissions if any.".to_string(),
273 ),
274 });
275 }
276
277 let content = response
278 .text()
279 .await
280 .map_err(|e| ProtocolError::LoadError {
281 path: url.to_string(),
282 reason: format!("Failed to read response: {}", e),
283 hint: None,
284 })?;
285
286 let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
287
288 self.validator.validate(&manifest)?;
290
291 Ok(manifest)
292 }
293
294 fn parse_manifest_yaml(content: &str) -> Result<ProtocolManifest, ProtocolError> {
300 serde_yaml::from_str::<ProtocolManifest>(content).map_err(|e| {
301 let msg = e.to_string();
302 let looks_structural = msg.contains("missing field")
305 || msg.contains("unknown field")
306 || msg.contains("invalid type")
307 || msg.contains("invalid value")
308 || msg.contains("expected");
309
310 if looks_structural {
311 ProtocolError::ValidationError(format!("Invalid manifest structure: {}", msg))
312 } else {
313 ProtocolError::YamlError(msg)
314 }
315 })
316 }
317
318 async fn load_model_config(&self, model_name: &str) -> Result<ModelConfig, ProtocolError> {
320 let mut model_paths: Vec<PathBuf> = Vec::new();
322 if let Ok(root) =
323 std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
324 {
325 let root = PathBuf::from(root);
326 model_paths.push(root.join("v1").join("models"));
327 }
328 model_paths.push(PathBuf::from("ai-protocol/v1/models"));
329 model_paths.push(PathBuf::from("../ai-protocol/v1/models"));
330 model_paths.push(PathBuf::from("../../ai-protocol/v1/models"));
331 let win_dev = PathBuf::from("D:\\ai-protocol\\v1\\models");
332 if win_dev.exists() {
333 model_paths.push(win_dev);
334 }
335
336 for base in model_paths {
337 if !base.exists() {
338 continue;
339 }
340 let mut rd = match tokio::fs::read_dir(&base).await {
341 Ok(rd) => rd,
342 Err(_) => continue,
343 };
344 while let Ok(Some(entry)) = rd.next_entry().await {
345 let path = entry.path();
346 if path
347 .extension()
348 .and_then(|s| s.to_str())
349 .map(|s| s.eq_ignore_ascii_case("yaml") || s.eq_ignore_ascii_case("yml"))
350 != Some(true)
351 {
352 continue;
353 }
354 if let Ok(config) = self.load_model_registry(&path).await {
355 if let Some(model) = config.models.get(model_name) {
356 return Ok(model.clone());
357 }
358 }
359 }
360 }
361
362 Err(ProtocolError::NotFound {
363 id: model_name.to_string(),
364 hint: Some(
365 "Check if the model is registered in the manifests/v1/models/ directory"
366 .to_string(),
367 ),
368 })
369 }
370
371 async fn load_model_registry(&self, path: &Path) -> Result<ModelRegistry, ProtocolError> {
372 let content =
373 tokio::fs::read_to_string(path)
374 .await
375 .map_err(|e| ProtocolError::LoadError {
376 path: path.to_string_lossy().to_string(),
377 reason: format!("Failed to read model registry: {}", e),
378 hint: None,
379 })?;
380
381 let registry: ModelRegistry = serde_yaml::from_str(&content).map_err(|e| {
382 ProtocolError::YamlError(format!("Failed to parse model registry: {}", e))
383 })?;
384
385 Ok(registry)
386 }
387}
388
389impl Default for ProtocolLoader {
390 fn default() -> Self {
391 Self::new()
392 }
393}
394
395#[derive(Debug, Clone, serde::Deserialize)]
397struct ModelRegistry {
398 models: std::collections::HashMap<String, ModelConfig>,
399}
400
401#[allow(dead_code)]
403#[derive(Debug, Clone, serde::Deserialize)]
404struct ModelConfig {
405 provider: String,
406 #[serde(default)]
407 model_id: Option<String>,
408 #[serde(default)]
409 context_window: Option<u32>,
410 #[serde(default)]
411 capabilities: Vec<String>,
412}
413
414pub struct ProtocolRegistry {
416 manifests: ArcSwap<std::collections::HashMap<String, Arc<ProtocolManifest>>>,
417 loader: ProtocolLoader,
418}
419
420impl ProtocolRegistry {
421 pub fn new() -> Self {
422 Self {
423 manifests: ArcSwap::from_pointee(std::collections::HashMap::new()),
424 loader: ProtocolLoader::new(),
425 }
426 }
427
428 pub async fn get_manifest(
430 &self,
431 provider_id: &str,
432 ) -> Result<Arc<ProtocolManifest>, ProtocolError> {
433 let current = self.manifests.load();
435 if let Some(manifest) = current.get(provider_id) {
436 return Ok(Arc::clone(manifest));
437 }
438
439 let manifest = self.loader.load_provider(provider_id).await?;
441 let manifest_arc = Arc::new(manifest);
442
443 let mut updated_map = std::collections::HashMap::new();
445 for (k, v) in current.iter() {
446 updated_map.insert(k.clone(), v.clone());
447 }
448 updated_map.insert(provider_id.to_string(), manifest_arc.clone());
449 self.manifests.store(Arc::new(updated_map));
450
451 Ok(manifest_arc)
452 }
453}
454
455impl Default for ProtocolRegistry {
456 fn default() -> Self {
457 Self::new()
458 }
459}