1use crate::protocol::{ProtocolError, ProtocolManifest};
6use arc_swap::ArcSwap;
7use lru::LruCache;
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10
/// Resolves protocol manifests for providers and models from local directories or remote URLs.
///
/// Results of `load_model` are memoized in a bounded LRU cache keyed by the full
/// `provider/model` string.
pub struct ProtocolLoader {
    /// Root directory probed before the environment-configured and built-in default roots.
    base_path: Option<PathBuf>,
    /// Set through `with_hot_reload`.
    // NOTE(review): this flag is not read anywhere in this file (the cache is always consulted);
    // confirm whether hot reload is implemented elsewhere.
    hot_reload: bool,
    /// Validator run on every provider manifest loaded from a file or URL.
    validator: crate::protocol::validator::ProtocolValidator,
    /// LRU cache of manifests returned by `load_model`, keyed by the requested model string.
    cache: Mutex<LruCache<String, Arc<ProtocolManifest>>>,
}
18
19impl ProtocolLoader {
20 pub fn new() -> Self {
21 Self {
22 base_path: None,
23 hot_reload: false,
24 validator: crate::protocol::validator::ProtocolValidator::default(),
25 cache: Mutex::new(LruCache::new(
28 std::num::NonZeroUsize::new(100)
29 .expect("Cache size must be non-zero (this should never happen)"),
30 )),
31 }
32 }
33
34 pub fn with_base_path(mut self, path: impl AsRef<Path>) -> Self {
36 self.base_path = Some(path.as_ref().to_path_buf());
37 self
38 }
39
40 pub fn with_hot_reload(mut self, enable: bool) -> Self {
42 self.hot_reload = enable;
43 self
44 }
45
46 pub async fn load_model(&self, model: &str) -> Result<ProtocolManifest, ProtocolError> {
49 {
51 let mut cache = self.cache.lock().map_err(|e| {
52 ProtocolError::Internal(format!(
53 "Failed to acquire cache lock while loading model '{}': {}",
54 model, e
55 ))
56 })?;
57 if let Some(manifest) = cache.get(model) {
58 return Ok(manifest.as_ref().clone());
59 }
60 }
61
62 let parts: Vec<&str> = model.split('/').collect();
64 if parts.len() < 2 {
65 return Err(ProtocolError::NotFound {
66 id: model.to_string(),
67 hint: Some("Ensure the model name follows the 'provider/model' format".to_string()),
68 });
69 }
70
71 let provider = parts[0];
72 let model_name = parts[1..].join("/");
73
74 let manifest = match self.load_model_config(&model_name).await {
78 Ok(model_config) => self.load_provider(&model_config.provider).await?,
79 Err(ProtocolError::NotFound { .. }) => self.load_provider(provider).await?,
80 Err(e) => return Err(e),
81 };
82
83 {
85 let mut cache = self.cache.lock().map_err(|e| {
86 ProtocolError::Internal(format!(
87 "Failed to acquire cache lock while caching model '{}': {}",
88 model, e
89 ))
90 })?;
91 cache.put(model.to_string(), Arc::new(manifest.clone()));
92 }
93
94 Ok(manifest)
95 }
96
97 pub async fn load_provider(
99 &self,
100 provider_id: &str,
101 ) -> Result<ProtocolManifest, ProtocolError> {
102 let mut search_locations: Vec<(PathBuf, bool)> = Vec::new(); if let Some(ref base_path) = self.base_path {
113 search_locations.push((base_path.join("dist").join("v1").join("providers"), true));
115 search_locations.push((base_path.join("v1").join("providers"), false));
117 }
118
119 if let Ok(root) =
121 std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
122 {
123 if root.starts_with("http://") || root.starts_with("https://") {
124 let url = if root.ends_with('/') {
128 format!("{}dist/v1/providers/{}.json", root, provider_id)
129 } else {
130 format!("{}/dist/v1/providers/{}.json", root, provider_id)
131 };
132
133 if let Ok(manifest) = self.load_from_json_url(&url).await {
135 return Ok(manifest);
136 }
137
138 let url_yaml = if root.ends_with('/') {
140 format!("{}v1/providers/{}.yaml", root, provider_id)
141 } else {
142 format!("{}/v1/providers/{}.yaml", root, provider_id)
143 };
144 return self.load_from_url(&url_yaml).await;
145 } else {
146 let root = PathBuf::from(root);
148 search_locations.push((root.join("dist").join("v1").join("providers"), true));
149 search_locations.push((root.join("v1").join("providers"), false));
150 }
151 }
152
153 let default_roots = vec![
155 PathBuf::from("ai-protocol"),
156 PathBuf::from("../ai-protocol"),
157 PathBuf::from("../../ai-protocol"),
158 PathBuf::from("D:\\ai-protocol"),
159 ];
160
161 for root in default_roots {
162 search_locations.push((root.join("dist").join("v1").join("providers"), true));
163 search_locations.push((root.join("v1").join("providers"), false));
164 }
165
166 for (base, prefer_json) in search_locations {
168 if prefer_json {
169 let json_path = base.join(format!("{}.json", provider_id));
170 if json_path.exists() {
171 return self.load_from_json_file(&json_path).await;
172 }
173 } else {
174 let yaml_path = base.join(format!("{}.yaml", provider_id));
175 if yaml_path.exists() {
176 return self.load_from_file(&yaml_path).await;
177 }
178 }
179 }
180
181 let github_json = format!(
183 "https://raw.githubusercontent.com/hiddenpath/ai-protocol/main/dist/v1/providers/{}.json",
184 provider_id
185 );
186 if let Ok(manifest) = self.load_from_json_url(&github_json).await {
187 return Ok(manifest);
188 }
189
190 let github_yaml = format!(
192 "https://raw.githubusercontent.com/hiddenpath/ai-protocol/main/v1/providers/{}.yaml",
193 provider_id
194 );
195 if let Ok(manifest) = self.load_from_url(&github_yaml).await {
196 return Ok(manifest);
197 }
198
199 Err(ProtocolError::NotFound {
200 id: provider_id.to_string(),
201 hint: Some(format!(
202 "Check if the provider file '{}.json' or '{}.yaml' exists in your protocol directory",
203 provider_id, provider_id
204 )),
205 })
206 }
207
208 async fn load_from_json_file(&self, path: &Path) -> Result<ProtocolManifest, ProtocolError> {
210 let content = tokio::fs::read(path)
211 .await
212 .map_err(|e| ProtocolError::LoadError {
213 path: path.to_string_lossy().to_string(),
214 reason: e.to_string(),
215 hint: Some("Check file permissions.".to_string()),
216 })?;
217
218 let manifest: ProtocolManifest = serde_json::from_slice(&content)
219 .map_err(|e| ProtocolError::ValidationError(format!("Invalid JSON manifest: {}", e)))?;
220
221 self.validator.validate(&manifest)?;
224
225 Ok(manifest)
226 }
227
228 async fn load_from_file(&self, path: &Path) -> Result<ProtocolManifest, ProtocolError> {
230 let bytes = tokio::fs::read(path)
232 .await
233 .map_err(|e| ProtocolError::LoadError {
234 path: path.to_string_lossy().to_string(),
235 reason: e.to_string(),
236 hint: Some("Check if the file exists and you have read permissions.".to_string()),
237 })?;
238
239 let content = if bytes.len() >= 2 && bytes[0] == 0xFF && bytes[1] == 0xFE {
241 let utf16_bytes = &bytes[2..];
243 let mut utf16_chars = Vec::new();
244 for i in (0..utf16_bytes.len()).step_by(2) {
245 if i + 1 < utf16_bytes.len() {
246 let code_unit = u16::from_le_bytes([utf16_bytes[i], utf16_bytes[i + 1]]);
247 utf16_chars.push(code_unit);
248 }
249 }
250 String::from_utf16(&utf16_chars).map_err(|e| ProtocolError::LoadError {
251 path: path.to_string_lossy().to_string(),
252 reason: format!("Invalid UTF-16: {}", e),
253 hint: None,
254 })?
255 } else if bytes.len() >= 3 && bytes[0] == 0xEF && bytes[1] == 0xBB && bytes[2] == 0xBF {
256 String::from_utf8(bytes[3..].to_vec()).map_err(|e| ProtocolError::LoadError {
257 path: path.to_string_lossy().to_string(),
258 reason: format!("Invalid UTF-8 (after BOM): {}", e),
259 hint: None,
260 })?
261 } else {
262 String::from_utf8(bytes).map_err(|e| ProtocolError::LoadError {
263 path: path.to_string_lossy().to_string(),
264 reason: format!("Invalid UTF-8: {}", e),
265 hint: None,
266 })?
267 };
268
269 let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
270 self.validator.validate(&manifest)?;
271 Ok(manifest)
272 }
273
274 async fn load_from_json_url(&self, url: &str) -> Result<ProtocolManifest, ProtocolError> {
276 let client = reqwest::Client::builder()
277 .timeout(std::time::Duration::from_secs(30))
278 .build()
279 .map_err(|e| ProtocolError::Internal(format!("Failed to create HTTP client: {}", e)))?;
280
281 let response = client
282 .get(url)
283 .send()
284 .await
285 .map_err(|e| ProtocolError::LoadError {
286 path: url.to_string(),
287 reason: format!("HTTP request failed: {}", e),
288 hint: None,
289 })?;
290
291 if !response.status().is_success() {
292 return Err(ProtocolError::LoadError {
293 path: url.to_string(),
294 reason: format!("HTTP {}", response.status()),
295 hint: None,
296 });
297 }
298
299 let content = response
300 .bytes()
301 .await
302 .map_err(|e| ProtocolError::LoadError {
303 path: url.to_string(),
304 reason: format!("Failed to read bytes: {}", e),
305 hint: None,
306 })?;
307
308 let manifest: ProtocolManifest = serde_json::from_slice(&content).map_err(|e| {
309 ProtocolError::ValidationError(format!("Invalid JSON manifest from URL: {}", e))
310 })?;
311
312 self.validator.validate(&manifest)?;
313 Ok(manifest)
314 }
315
316 async fn load_from_url(&self, url: &str) -> Result<ProtocolManifest, ProtocolError> {
318 let client = reqwest::Client::builder()
319 .timeout(std::time::Duration::from_secs(30))
320 .build()
321 .map_err(|e| ProtocolError::Internal(format!("Failed to create HTTP client: {}", e)))?;
322
323 let response = client
324 .get(url)
325 .send()
326 .await
327 .map_err(|e| ProtocolError::LoadError {
328 path: url.to_string(),
329 reason: format!("HTTP request failed: {}", e),
330 hint: Some(
331 "Check your internet connection and verify the URL is accessible.".to_string(),
332 ),
333 })?;
334
335 if !response.status().is_success() {
336 return Err(ProtocolError::LoadError {
337 path: url.to_string(),
338 reason: format!(
339 "HTTP {}: {}",
340 response.status(),
341 response.text().await.unwrap_or_default()
342 ),
343 hint: Some(
344 "Verify the remote registry URL and your API permissions if any.".to_string(),
345 ),
346 });
347 }
348
349 let content = response
350 .text()
351 .await
352 .map_err(|e| ProtocolError::LoadError {
353 path: url.to_string(),
354 reason: format!("Failed to read response: {}", e),
355 hint: None,
356 })?;
357
358 let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
359
360 self.validator.validate(&manifest)?;
362
363 Ok(manifest)
364 }
365
366 fn parse_manifest_yaml(content: &str) -> Result<ProtocolManifest, ProtocolError> {
372 serde_yaml::from_str::<ProtocolManifest>(content).map_err(|e| {
373 let msg = e.to_string();
374 let looks_structural = msg.contains("missing field")
377 || msg.contains("unknown field")
378 || msg.contains("invalid type")
379 || msg.contains("invalid value")
380 || msg.contains("expected");
381
382 if looks_structural {
383 ProtocolError::ValidationError(format!("Invalid manifest structure: {}", msg))
384 } else {
385 ProtocolError::YamlError(msg)
386 }
387 })
388 }
389
390 async fn load_model_config(&self, model_name: &str) -> Result<ModelConfig, ProtocolError> {
392 let mut search_locations: Vec<(PathBuf, bool)> = Vec::new(); if let Ok(root) =
399 std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
400 {
401 if !root.starts_with("http://") && !root.starts_with("https://") {
405 let root = PathBuf::from(root);
406 search_locations.push((root.join("dist").join("v1").join("models"), true));
407 search_locations.push((root.join("v1").join("models"), false));
408 }
409 }
410
411 let default_roots = vec![
413 PathBuf::from("ai-protocol"),
414 PathBuf::from("../ai-protocol"),
415 PathBuf::from("../../ai-protocol"),
416 PathBuf::from("D:\\ai-protocol"),
417 ];
418
419 for root in default_roots {
420 search_locations.push((root.join("dist").join("v1").join("models"), true));
421 search_locations.push((root.join("v1").join("models"), false));
422 }
423
424 for (base, prefer_json) in search_locations {
425 if !base.exists() {
426 continue;
427 }
428 let mut rd = match tokio::fs::read_dir(&base).await {
429 Ok(rd) => rd,
430 Err(_) => continue,
431 };
432
433 while let Ok(Some(entry)) = rd.next_entry().await {
434 let path = entry.path();
435 let extension = path.extension().and_then(|s| s.to_str());
436
437 let is_match = if prefer_json {
438 extension.map(|s| s.eq_ignore_ascii_case("json")) == Some(true)
439 } else {
440 extension
441 .map(|s| s.eq_ignore_ascii_case("yaml") || s.eq_ignore_ascii_case("yml"))
442 == Some(true)
443 };
444
445 if !is_match {
446 continue;
447 }
448
449 if prefer_json {
450 if let Ok(config) = self.load_model_registry_json(&path).await {
451 if let Some(model) = config.models.get(model_name) {
452 return Ok(model.clone());
453 }
454 }
455 } else {
456 if let Ok(config) = self.load_model_registry_yaml(&path).await {
457 if let Some(model) = config.models.get(model_name) {
458 return Ok(model.clone());
459 }
460 }
461 }
462 }
463 }
464
465 Err(ProtocolError::NotFound {
466 id: model_name.to_string(),
467 hint: Some(
468 "Check if the model is registered in the manifests/v1/models/ directory"
469 .to_string(),
470 ),
471 })
472 }
473
474 async fn load_model_registry_json(&self, path: &Path) -> Result<ModelRegistry, ProtocolError> {
475 let content = tokio::fs::read(path)
476 .await
477 .map_err(|e| ProtocolError::LoadError {
478 path: path.to_string_lossy().to_string(),
479 reason: e.to_string(),
480 hint: None,
481 })?;
482 let registry: ModelRegistry = serde_json::from_slice(&content).map_err(|e| {
483 ProtocolError::ValidationError(format!("Invalid JSON model registry: {}", e))
484 })?;
485 Ok(registry)
486 }
487
488 async fn load_model_registry_yaml(&self, path: &Path) -> Result<ModelRegistry, ProtocolError> {
489 let content =
490 tokio::fs::read_to_string(path)
491 .await
492 .map_err(|e| ProtocolError::LoadError {
493 path: path.to_string_lossy().to_string(),
494 reason: format!("Failed to read model registry: {}", e),
495 hint: None,
496 })?;
497
498 let registry: ModelRegistry = serde_yaml::from_str(&content).map_err(|e| {
499 ProtocolError::YamlError(format!("Failed to parse model registry: {}", e))
500 })?;
501
502 Ok(registry)
503 }
504}
505
506impl Default for ProtocolLoader {
507 fn default() -> Self {
508 Self::new()
509 }
510}
511
/// Top-level shape of a model registry file (JSON or YAML).
#[derive(Debug, Clone, serde::Deserialize)]
struct ModelRegistry {
    /// Map from model name (the part of a `provider/model` request after the first `/`)
    /// to that model's configuration.
    models: std::collections::HashMap<String, ModelConfig>,
}
517
// Only `provider` is read in this file (by `ProtocolLoader::load_model`). The other fields are
// deserialized but unused here, hence the `dead_code` allowance.
#[allow(dead_code)]
/// One model entry from a model registry file.
#[derive(Debug, Clone, serde::Deserialize)]
struct ModelConfig {
    /// Id of the provider whose manifest is loaded for this model.
    provider: String,
    // Presumably the provider-side model identifier — not read in this file; confirm.
    #[serde(default)]
    model_id: Option<String>,
    // Presumably the context size in tokens — not read in this file; confirm units.
    #[serde(default)]
    context_window: Option<u32>,
    // Capability tags; empty when absent from the registry. Not read in this file.
    #[serde(default)]
    capabilities: Vec<String>,
}
530
/// Lazily populated cache of provider manifests keyed by provider id.
pub struct ProtocolRegistry {
    /// Current provider-id → manifest map; updates swap in a whole new map.
    manifests: ArcSwap<std::collections::HashMap<String, Arc<ProtocolManifest>>>,
    /// Loader used for providers that are not cached yet.
    loader: ProtocolLoader,
}
536
537impl ProtocolRegistry {
538 pub fn new() -> Self {
539 Self {
540 manifests: ArcSwap::from_pointee(std::collections::HashMap::new()),
541 loader: ProtocolLoader::new(),
542 }
543 }
544
545 pub async fn get_manifest(
547 &self,
548 provider_id: &str,
549 ) -> Result<Arc<ProtocolManifest>, ProtocolError> {
550 let current = self.manifests.load();
552 if let Some(manifest) = current.get(provider_id) {
553 return Ok(Arc::clone(manifest));
554 }
555
556 let manifest = self.loader.load_provider(provider_id).await?;
558 let manifest_arc = Arc::new(manifest);
559
560 let mut updated_map = std::collections::HashMap::new();
562 for (k, v) in current.iter() {
563 updated_map.insert(k.clone(), v.clone());
564 }
565 updated_map.insert(provider_id.to_string(), manifest_arc.clone());
566 self.manifests.store(Arc::new(updated_map));
567
568 Ok(manifest_arc)
569 }
570}
571
572impl Default for ProtocolRegistry {
573 fn default() -> Self {
574 Self::new()
575 }
576}