1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use hf_hub::api::tokio::ApiBuilder;
8use serde::{Deserialize, Serialize};
9use serde_json::{json, Value};
10use tokio::sync::{oneshot, Mutex, Notify, RwLock};
11
12use crate::error::Error;
13use crate::formatter::ChatFormatter;
14use crate::ipc::client::IPCClient;
15use crate::model::resolver::{ModelResolver, ResolvedModel};
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
19pub enum ModelLoadState {
20 Idle,
22 Downloading,
24 Loading,
26 Activating,
28 Ready,
30 Failed,
32}
33
34impl std::fmt::Display for ModelLoadState {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 Self::Idle => write!(f, "IDLE"),
38 Self::Downloading => write!(f, "DOWNLOADING"),
39 Self::Loading => write!(f, "LOADING"),
40 Self::Activating => write!(f, "ACTIVATING"),
41 Self::Ready => write!(f, "READY"),
42 Self::Failed => write!(f, "FAILED"),
43 }
44 }
45}
46
47#[derive(Debug, Clone)]
49pub struct ModelInfo {
50 pub model_id: String,
51 pub model_path: String,
52 pub formatter: Option<Arc<ChatFormatter>>,
53 pub capabilities: Option<HashMap<String, Vec<i32>>>,
54 pub minimum_memory_bytes: Option<u64>,
55}
56
57impl ModelInfo {
58 pub fn require_formatter(&self) -> std::result::Result<&ChatFormatter, crate::error::Error> {
59 self.formatter.as_deref().ok_or_else(|| {
60 crate::error::Error::ModelNotReady(format!(
61 "Model '{}' does not have a chat formatter",
62 self.model_id
63 ))
64 })
65 }
66}
67
68pub struct ModelEntry {
70 pub state: ModelLoadState,
71 pub info: Option<ModelInfo>,
72 pub error: Option<String>,
73 pub notify: Arc<Notify>,
74 pub resolved: Option<ResolvedModel>,
75 pub bytes_downloaded: Option<u64>,
76 pub bytes_total: Option<u64>,
77 pub activation_waiters: Vec<oneshot::Sender<Result<(), String>>>,
79}
80
81impl Default for ModelEntry {
82 fn default() -> Self {
83 Self {
84 state: ModelLoadState::Idle,
85 info: None,
86 error: None,
87 notify: Arc::new(Notify::new()),
88 resolved: None,
89 bytes_downloaded: None,
90 bytes_total: None,
91 activation_waiters: Vec::new(),
92 }
93 }
94}
95
96#[derive(Debug, Deserialize)]
97struct RemoteRepoFile {
98 rfilename: String,
99 #[serde(default)]
100 size: Option<u64>,
101}
102
103#[derive(Debug, Deserialize)]
104struct RemoteRepoInfo {
105 siblings: Vec<RemoteRepoFile>,
106}
107
108pub struct ModelRegistry {
110 entries: Arc<RwLock<HashMap<String, ModelEntry>>>,
111 resolver: Mutex<ModelResolver>,
112 alias_cache: RwLock<HashMap<String, String>>,
113 ipc_client: RwLock<Option<Arc<IPCClient>>>,
115}
116
117impl ModelRegistry {
118 pub fn new() -> Result<Self, Error> {
120 Ok(Self {
121 entries: Arc::new(RwLock::new(HashMap::new())),
122 resolver: Mutex::new(ModelResolver::new()?),
123 alias_cache: RwLock::new(HashMap::new()),
124 ipc_client: RwLock::new(None),
125 })
126 }
127
128 pub async fn set_ipc_client(&self, client: Arc<IPCClient>) {
130 let mut ipc = self.ipc_client.write().await;
131 *ipc = Some(client);
132 }
133
134 pub async fn ensure_loaded(&self, requested_model_id: &str) -> Result<ModelInfo, Error> {
143 let (_state, canonical_id) = self
144 .schedule_model(requested_model_id, false)
145 .await
146 .map_err(Error::ModelNotReady)?;
147
148 let (state, info, error) = self
150 .await_model(&canonical_id, None)
151 .await
152 .map_err(Error::ModelNotReady)?;
153
154 if state == ModelLoadState::Failed {
155 return Err(Error::ModelNotReady(error.unwrap_or_else(|| {
156 format!("Model '{}' failed to load", canonical_id)
157 })));
158 }
159
160 if state == ModelLoadState::Ready {
161 return info
162 .ok_or_else(|| Error::ModelNotReady("Model ready but info missing".to_string()));
163 }
164
165 let info = info.ok_or_else(|| {
167 Error::ModelNotReady(format!("Model '{}' info missing", canonical_id))
168 })?;
169
170 {
172 let entries = self.entries.read().await;
173 if let Some(entry) = entries.get(&canonical_id) {
174 if entry.state == ModelLoadState::Ready {
175 return entry
176 .info
177 .clone()
178 .ok_or_else(|| Error::ModelNotReady("Ready but no info".to_string()));
179 }
180 }
181 }
182
183 let activation_rx = self
185 .send_load_model_command(requested_model_id, &canonical_id, &info)
186 .await
187 .map_err(Error::ModelNotReady)?;
188
189 match activation_rx.await {
191 Ok(Ok(())) => {
192 self.get_if_ready(&canonical_id).await.ok_or_else(|| {
194 Error::ModelNotReady(format!("Model '{}' failed to activate", canonical_id))
195 })
196 }
197 Ok(Err(e)) => Err(Error::ModelNotReady(e)),
198 Err(_) => Err(Error::ModelNotReady(format!(
199 "Activation channel closed for '{}'",
200 canonical_id
201 ))),
202 }
203 }
204
205 pub async fn resolve_or_download(
207 &self,
208 requested_model_id: &str,
209 ) -> Result<(ModelLoadState, String), String> {
210 self.schedule_model(requested_model_id, false).await
211 }
212
213 async fn send_load_model_command(
215 &self,
216 requested_id: &str,
217 canonical_id: &str,
218 info: &ModelInfo,
219 ) -> Result<oneshot::Receiver<Result<(), String>>, String> {
220 let (tx, rx) = oneshot::channel();
222
223 {
225 let mut entries = self.entries.write().await;
226 let entry = entries
227 .get_mut(canonical_id)
228 .ok_or_else(|| format!("Model '{}' not in registry", canonical_id))?;
229
230 match entry.state {
231 ModelLoadState::Ready => {
232 let _ = tx.send(Ok(()));
233 return Ok(rx);
234 }
235 ModelLoadState::Failed => {
236 let message = entry
237 .error
238 .clone()
239 .unwrap_or_else(|| format!("Model '{}' failed to load", canonical_id));
240 let _ = tx.send(Err(message));
241 return Ok(rx);
242 }
243 ModelLoadState::Activating => {
244 entry.activation_waiters.push(tx);
245 return Ok(rx);
246 }
247 _ => {
248 entry.state = ModelLoadState::Activating;
249 entry.activation_waiters.push(tx);
250 }
251 }
252 }
253
254 let ipc = {
256 let guard = self.ipc_client.read().await;
257 guard
258 .clone()
259 .ok_or_else(|| "IPC client not set".to_string())?
260 };
261
262 let command = json!({
264 "type": "load_model",
265 "requested_id": requested_id,
266 "canonical_id": canonical_id,
267 "model_path": info.model_path,
268 "wait_for_completion": false,
269 });
270
271 let response = ipc
272 .send_management_command_async(command, Duration::from_secs(30))
273 .await
274 .map_err(|e| format!("Failed to send load_model command: {}", e))?;
275
276 let status = response
278 .get("status")
279 .and_then(|v| v.as_str())
280 .unwrap_or("");
281
282 match status {
283 "ok" => {
284 let capabilities = self.parse_capabilities(&response);
286 let minimum_memory_bytes = self.parse_minimum_memory_bytes(&response);
287 self.complete_activation(canonical_id, capabilities, minimum_memory_bytes)
288 .await;
289 }
290 "accepted" => {
291 tracing::debug!(
293 "Model '{}' activation accepted, waiting for model_loaded event",
294 canonical_id
295 );
296 }
297 _ => {
298 let message = response
299 .get("message")
300 .and_then(|v| v.as_str())
301 .unwrap_or("unknown error");
302 self.fail_activation(
303 canonical_id,
304 &format!("Engine rejected load_model: {}", message),
305 )
306 .await;
307 return Err(format!(
308 "Engine rejected load_model for '{}': {}",
309 requested_id, message
310 ));
311 }
312 }
313
314 Ok(rx)
315 }
316
317 fn parse_capabilities(&self, response: &Value) -> Option<HashMap<String, Vec<i32>>> {
319 response
320 .get("data")
321 .and_then(|d| d.get("load_model"))
322 .and_then(|lm| lm.get("capabilities"))
323 .and_then(|c| c.as_object())
324 .map(|obj| {
325 obj.iter()
326 .filter_map(|(k, v)| {
327 let vals: Vec<i32> = if let Some(arr) = v.as_array() {
328 arr.iter()
329 .filter_map(|x| x.as_i64().map(|n| n as i32))
330 .collect()
331 } else if let Some(n) = v.as_i64() {
332 vec![n as i32]
333 } else {
334 return None;
335 };
336 Some((k.clone(), vals))
337 })
338 .collect()
339 })
340 }
341
342 fn parse_minimum_memory_bytes(&self, response: &Value) -> Option<u64> {
343 response
344 .get("data")
345 .and_then(|d| d.get("load_model"))
346 .and_then(|lm| lm.get("minimum_memory_bytes"))
347 .and_then(|value| value.as_u64())
348 }
349
350 async fn complete_activation(
352 &self,
353 model_id: &str,
354 capabilities: Option<HashMap<String, Vec<i32>>>,
355 minimum_memory_bytes: Option<u64>,
356 ) {
357 let mut entries = self.entries.write().await;
358 if let Some(entry) = entries.get_mut(model_id) {
359 if let Some(ref mut info) = entry.info {
360 if let Some(capabilities) = capabilities {
361 info.capabilities = Some(capabilities);
362 }
363 if minimum_memory_bytes.is_some() {
364 info.minimum_memory_bytes = minimum_memory_bytes;
365 }
366 }
367 entry.state = ModelLoadState::Ready;
368 entry.notify.notify_waiters();
369
370 for tx in entry.activation_waiters.drain(..) {
372 let _ = tx.send(Ok(()));
373 }
374 }
375 }
376
377 async fn fail_activation(&self, model_id: &str, error: &str) {
379 let mut entries = self.entries.write().await;
380 if let Some(entry) = entries.get_mut(model_id) {
381 entry.state = ModelLoadState::Failed;
382 entry.error = Some(error.to_string());
383 entry.notify.notify_waiters();
384
385 for tx in entry.activation_waiters.drain(..) {
387 let _ = tx.send(Err(error.to_string()));
388 }
389 }
390 }
391
392 pub async fn schedule_model(
396 &self,
397 requested_model_id: &str,
398 force_reload: bool,
399 ) -> Result<(ModelLoadState, String), String> {
400 let resolved = {
401 let mut resolver = self.resolver.lock().await;
402 resolver
403 .resolve(requested_model_id)
404 .await
405 .map_err(|e| e.to_string())?
406 };
407
408 let canonical_id = resolved.canonical_id.clone();
409
410 {
411 let mut alias_cache = self.alias_cache.write().await;
412 alias_cache.insert(requested_model_id.to_lowercase(), canonical_id.clone());
413 alias_cache
414 .entry(canonical_id.to_lowercase())
415 .or_insert_with(|| canonical_id.clone());
416 }
417
418 let mut entries = self.entries.write().await;
419 let entry = entries
420 .entry(canonical_id.clone())
421 .or_insert_with(ModelEntry::default);
422
423 if entry.state == ModelLoadState::Ready && !force_reload {
424 return Ok((ModelLoadState::Ready, canonical_id));
425 }
426
427 if matches!(
428 entry.state,
429 ModelLoadState::Loading | ModelLoadState::Downloading | ModelLoadState::Activating
430 ) && !force_reload
431 {
432 return Ok((entry.state, canonical_id));
433 }
434
435 if entry.state == ModelLoadState::Failed && !force_reload {
436 return Ok((ModelLoadState::Failed, canonical_id));
437 }
438
439 entry.error = None;
440 entry.info = None;
441 entry.resolved = Some(resolved.clone());
442 entry.bytes_downloaded = None;
443 entry.bytes_total = None;
444 entry.notify = Arc::new(Notify::new());
445 entry.activation_waiters.clear();
446
447 if resolved.source == "local" || resolved.source == "hf_cache" {
448 let formatter = ChatFormatter::new(&resolved.model_path).ok().map(Arc::new);
449 entry.info = Some(ModelInfo {
450 model_id: canonical_id.clone(),
451 model_path: resolved.model_path.to_string_lossy().to_string(),
452 formatter,
453 capabilities: None,
454 minimum_memory_bytes: None,
455 });
456 entry.state = ModelLoadState::Loading;
457 entry.notify.notify_waiters();
458 return Ok((ModelLoadState::Loading, canonical_id));
459 }
460
461 entry.state = ModelLoadState::Downloading;
463 let notify = entry.notify.clone();
464
465 drop(entries);
467
468 let hf_repo = resolved
470 .hf_repo
471 .clone()
472 .unwrap_or_else(|| resolved.canonical_id.clone());
473 let canonical_id_for_task = canonical_id.clone();
474 let entries_ref = self.entries.clone();
475
476 tokio::spawn(async move {
477 let result = Self::download_model(
478 Arc::clone(&entries_ref),
479 canonical_id_for_task.as_str(),
480 hf_repo.as_str(),
481 )
482 .await;
483
484 let mut entries: tokio::sync::RwLockWriteGuard<'_, HashMap<String, ModelEntry>> =
485 entries_ref.write().await;
486 if let Some(entry) = entries.get_mut(&canonical_id_for_task) {
487 match result {
488 Ok(download_path) => {
489 if let Some(ref mut resolved) = entry.resolved {
491 resolved.model_path = download_path.clone();
492 resolved.source = "hf_cache".to_string();
493 }
494
495 let formatter = ChatFormatter::new(&download_path).ok().map(Arc::new);
496 entry.info = Some(ModelInfo {
497 model_id: canonical_id_for_task.clone(),
498 model_path: download_path.to_string_lossy().to_string(),
499 formatter,
500 capabilities: None,
501 minimum_memory_bytes: None,
502 });
503 entry.state = ModelLoadState::Loading;
504 }
505 Err(e) => {
506 entry.error = Some(format!("Download failed: {}", e));
507 entry.state = ModelLoadState::Failed;
508 }
509 }
510 notify.notify_waiters();
511 }
512 });
513
514 Ok((ModelLoadState::Downloading, canonical_id))
515 }
516
517 async fn download_model(
519 entries_ref: Arc<RwLock<HashMap<String, ModelEntry>>>,
520 canonical_id: &str,
521 repo_id: &str,
522 ) -> Result<std::path::PathBuf, String> {
523 tracing::info!("Downloading model from HuggingFace: {}", repo_id);
524
525 let api = ApiBuilder::from_env()
526 .with_progress(false)
527 .build()
528 .map_err(|e| format!("Failed to create HF API: {}", e))?;
529 let repo = api.model(repo_id.to_string());
530
531 let mut repo_info: RemoteRepoInfo = repo
532 .info_request()
533 .query(&[("blobs", "true")])
534 .send()
535 .await
536 .map_err(|e| format!("Failed to query repo info: {}", e))?
537 .json()
538 .await
539 .map_err(|e| format!("Failed to decode repo info: {}", e))?;
540
541 if repo_info.siblings.is_empty() {
542 return Err("Repository has no downloadable files".to_string());
543 }
544
545 if !repo_info
546 .siblings
547 .iter()
548 .any(|file| file.rfilename == "config.json")
549 {
550 return Err("Repository is missing config.json".to_string());
551 }
552
553 repo_info.siblings.sort_by(|left, right| {
554 let left_priority = u8::from(left.rfilename != "config.json");
555 let right_priority = u8::from(right.rfilename != "config.json");
556 left_priority
557 .cmp(&right_priority)
558 .then_with(|| left.rfilename.cmp(&right.rfilename))
559 });
560
561 let total_bytes = repo_info
562 .siblings
563 .iter()
564 .filter_map(|file| file.size)
565 .sum::<u64>();
566 let mut downloaded_bytes = 0u64;
567
568 {
569 let mut entries = entries_ref.write().await;
570 if let Some(entry) = entries.get_mut(canonical_id) {
571 entry.bytes_downloaded = Some(0);
572 entry.bytes_total = Some(total_bytes);
573 }
574 }
575
576 let mut model_dir: Option<std::path::PathBuf> = None;
577
578 for file in repo_info.siblings {
579 if Self::download_cancelled(&entries_ref, canonical_id).await {
580 return Err("Cancelled".to_string());
581 }
582
583 let path = repo
584 .get(file.rfilename.as_str())
585 .await
586 .map_err(|e| format!("Failed to download {}: {}", file.rfilename, e))?;
587
588 if file.rfilename == "config.json" {
589 model_dir = path.parent().map(|parent| parent.to_path_buf());
590 }
591
592 let file_size = tokio::fs::metadata(&path)
593 .await
594 .map(|metadata| metadata.len())
595 .unwrap_or_else(|_| file.size.unwrap_or(0));
596 downloaded_bytes = downloaded_bytes.saturating_add(file_size);
597
598 let completed_bytes = if total_bytes == 0 {
599 downloaded_bytes
600 } else {
601 downloaded_bytes.min(total_bytes)
602 };
603
604 let mut entries = entries_ref.write().await;
605 if let Some(entry) = entries.get_mut(canonical_id) {
606 entry.bytes_downloaded = Some(completed_bytes);
607 entry.bytes_total = Some(total_bytes.max(completed_bytes));
608 }
609 }
610
611 if Self::download_cancelled(&entries_ref, canonical_id).await {
612 return Err("Cancelled".to_string());
613 }
614
615 let model_dir =
616 model_dir.ok_or_else(|| "Downloaded repo is missing config.json".to_string())?;
617
618 tracing::info!("Model downloaded to {:?}", model_dir);
619 Ok(model_dir)
620 }
621
622 pub async fn cancel_download(&self, model_id: &str) -> Result<(), String> {
624 let canonical_id = self.canonicalize(model_id).await?;
625 let mut entries = self.entries.write().await;
626 let entry = entries
627 .get_mut(&canonical_id)
628 .ok_or_else(|| format!("Model '{}' has not been scheduled", canonical_id))?;
629
630 if entry.state != ModelLoadState::Downloading {
631 return Err(format!(
632 "Model '{}' is not downloading (current state: {})",
633 canonical_id, entry.state
634 ));
635 }
636
637 entry.state = ModelLoadState::Failed;
638 entry.error = Some("Cancelled".to_string());
639 entry.notify.notify_waiters();
640 Ok(())
641 }
642
643 pub async fn await_model(
648 &self,
649 model_id: &str,
650 timeout: Option<std::time::Duration>,
651 ) -> Result<(ModelLoadState, Option<ModelInfo>, Option<String>), String> {
652 let canonical_id = self.canonicalize(model_id).await?;
653
654 let _notify: Arc<Notify>;
655 let notified = {
656 let entries = self.entries.read().await;
657 let entry = entries
658 .get(&canonical_id)
659 .ok_or_else(|| format!("Model '{}' has not been scheduled", model_id))?;
660
661 if !matches!(
663 entry.state,
664 ModelLoadState::Downloading | ModelLoadState::Activating
665 ) {
666 return Ok((entry.state, entry.info.clone(), entry.error.clone()));
667 }
668
669 _notify = entry.notify.clone();
670 _notify.notified()
671 };
672
673 match timeout {
675 Some(d) => {
676 let _ = tokio::time::timeout(d, notified).await;
677 }
678 None => notified.await,
679 }
680
681 let entries = self.entries.read().await;
683 let entry = entries
684 .get(&canonical_id)
685 .ok_or_else(|| format!("Model '{}' not found", canonical_id))?;
686 Ok((entry.state, entry.info.clone(), entry.error.clone()))
687 }
688
689 pub async fn get_if_ready(&self, model_id: &str) -> Option<ModelInfo> {
691 let canonical_id = self.canonicalize(model_id).await.ok()?;
692 let entries = self.entries.read().await;
693 let entry = entries.get(&canonical_id)?;
694
695 if entry.state == ModelLoadState::Ready {
696 entry.info.clone()
697 } else {
698 None
699 }
700 }
701
702 pub async fn get_status(
704 &self,
705 model_id: &str,
706 ) -> (ModelLoadState, Option<String>, Option<(u64, u64)>) {
707 let canonical_id = match self.canonicalize(model_id).await {
708 Ok(id) => id,
709 Err(_) => return (ModelLoadState::Idle, None, None),
710 };
711
712 let entries = self.entries.read().await;
713 let entry = match entries.get(&canonical_id) {
714 Some(e) => e,
715 None => return (ModelLoadState::Idle, None, None),
716 };
717
718 let progress = match (entry.bytes_downloaded, entry.bytes_total) {
719 (Some(d), Some(t)) => Some((d, t)),
720 _ => None,
721 };
722
723 (entry.state, entry.error.clone(), progress)
724 }
725
726 pub async fn update_capabilities(
728 &self,
729 model_id: &str,
730 capabilities: HashMap<String, Vec<i32>>,
731 ) {
732 let canonical_id = match self.canonicalize(model_id).await {
733 Ok(id) => id,
734 Err(_) => {
735 tracing::warn!("Received capabilities for unknown model '{}'", model_id);
736 return;
737 }
738 };
739
740 let mut entries = self.entries.write().await;
741 if let Some(entry) = entries.get_mut(&canonical_id) {
742 if let Some(ref mut info) = entry.info {
743 info.capabilities = Some(capabilities);
744 }
745 }
746 }
747
748 pub async fn mark_ready(&self, model_id: &str) {
750 let canonical_id = match self.canonicalize(model_id).await {
751 Ok(id) => id,
752 Err(_) => return,
753 };
754
755 let mut entries = self.entries.write().await;
756 if let Some(entry) = entries.get_mut(&canonical_id) {
757 entry.state = ModelLoadState::Ready;
758 entry.notify.notify_waiters();
759 for tx in entry.activation_waiters.drain(..) {
760 let _ = tx.send(Ok(()));
761 }
762 }
763 }
764
765 pub async fn mark_failed(&self, model_id: &str, error: String) {
767 let canonical_id = match self.canonicalize(model_id).await {
768 Ok(id) => id,
769 Err(_) => return,
770 };
771
772 let mut entries = self.entries.write().await;
773 if let Some(entry) = entries.get_mut(&canonical_id) {
774 entry.state = ModelLoadState::Failed;
775 entry.error = Some(error.clone());
776 entry.notify.notify_waiters();
777
778 for tx in entry.activation_waiters.drain(..) {
780 let _ = tx.send(Err(error.clone()));
781 }
782 }
783 }
784
785 pub async fn handle_model_loaded(&self, payload: &Value) {
789 let model_id = match payload.get("model_id").and_then(|v| v.as_str()) {
790 Some(id) => id,
791 None => {
792 tracing::warn!("Received model_loaded event without model_id");
793 return;
794 }
795 };
796
797 if let Some(caps) = payload.get("capabilities").and_then(|c| c.as_object()) {
799 let capabilities: HashMap<String, Vec<i32>> = caps
800 .iter()
801 .filter_map(|(k, v)| {
802 let vals: Vec<i32> = if let Some(arr) = v.as_array() {
803 arr.iter()
804 .filter_map(|x| x.as_i64().map(|n| n as i32))
805 .collect()
806 } else if let Some(n) = v.as_i64() {
807 vec![n as i32]
808 } else {
809 return None;
810 };
811 Some((k.clone(), vals))
812 })
813 .collect();
814
815 if !capabilities.is_empty() {
816 self.update_capabilities(model_id, capabilities).await;
817 }
818 }
819
820 let canonical_id = match self.canonicalize(model_id).await {
822 Ok(id) => id,
823 Err(_) => {
824 tracing::debug!("Model '{}' not found in alias cache, using as-is", model_id);
826 model_id.to_string()
827 }
828 };
829
830 let minimum_memory_bytes = payload
832 .get("minimum_memory_bytes")
833 .and_then(|value| value.as_u64());
834 self.complete_activation(&canonical_id, None, minimum_memory_bytes)
835 .await;
836 }
837
838 pub async fn handle_model_load_failed(&self, payload: &Value) {
842 let model_id = match payload.get("model_id").and_then(|v| v.as_str()) {
843 Some(id) => id,
844 None => {
845 tracing::warn!("Received model_load_failed event without model_id");
846 return;
847 }
848 };
849
850 let error = match payload.get("error").and_then(|v| v.as_str()) {
851 Some(message) => message,
852 None => {
853 tracing::warn!(
854 "Received model_load_failed event without error for '{}'",
855 model_id
856 );
857 "unknown error"
858 }
859 };
860
861 let canonical_id = match self.canonicalize(model_id).await {
862 Ok(id) => id,
863 Err(_) => {
864 tracing::debug!("Model '{}' not found in alias cache, using as-is", model_id);
865 model_id.to_string()
866 }
867 };
868
869 self.fail_activation(&canonical_id, error).await;
870 }
871
872 pub async fn list_models(&self) -> Vec<HashMap<String, String>> {
874 let entries = self.entries.read().await;
875 let mut catalog = Vec::new();
876
877 for (canonical_id, entry) in entries.iter() {
878 if let Some(ref resolved) = entry.resolved {
879 let mut payload: HashMap<String, String> = resolved.metadata.clone();
880 payload.insert("canonical_id".to_string(), canonical_id.clone());
881 payload.insert(
882 "model_path".to_string(),
883 resolved.model_path.to_string_lossy().to_string(),
884 );
885 payload.insert("source".to_string(), resolved.source.clone());
886 payload.insert(
887 "hf_repo".to_string(),
888 resolved.hf_repo.clone().unwrap_or_default(),
889 );
890 payload.insert("state".to_string(), entry.state.to_string());
891 catalog.push(payload);
892 }
893 }
894
895 catalog
896 }
897
898 async fn canonicalize(&self, model_id: &str) -> Result<String, String> {
899 {
901 let entries = self.entries.read().await;
902 if entries.contains_key(model_id) {
903 return Ok(model_id.to_string());
904 }
905 }
906
907 {
909 let alias_cache = self.alias_cache.read().await;
910 if let Some(canonical) = alias_cache.get(&model_id.to_lowercase()) {
911 return Ok(canonical.clone());
912 }
913 }
914
915 Err(format!("Model '{}' not found in registry", model_id))
916 }
917
918 async fn download_cancelled(
919 entries_ref: &Arc<RwLock<HashMap<String, ModelEntry>>>,
920 canonical_id: &str,
921 ) -> bool {
922 let entries = entries_ref.read().await;
923 matches!(
924 entries.get(canonical_id),
925 Some(entry)
926 if entry.state == ModelLoadState::Failed
927 && entry.error.as_deref() == Some("Cancelled")
928 )
929 }
930}
931
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio::sync::oneshot;

    /// A freshly built registry has no scheduled models.
    #[tokio::test]
    async fn test_registry_creation() {
        let registry = ModelRegistry::new().unwrap();
        assert!(registry.list_models().await.is_empty());
    }

    /// A `model_load_failed` event addressed by alias must fail every pending activation
    /// waiter and leave the canonical entry `Failed` with the engine's message.
    #[tokio::test]
    async fn test_handle_model_load_failed_fails_activation_waiters() {
        let registry = ModelRegistry::new().unwrap();
        let canonical_id = String::from("moondream/moondream3-preview");
        let requested_id = canonical_id.clone();
        let error = String::from("Weight shard file not found");
        let (waiter_tx, waiter_rx) = oneshot::channel();

        registry
            .alias_cache
            .write()
            .await
            .insert(requested_id.to_lowercase(), canonical_id.clone());

        {
            let mut entries = registry.entries.write().await;
            let entry = entries.entry(canonical_id.clone()).or_default();
            entry.state = ModelLoadState::Activating;
            entry.activation_waiters.push(waiter_tx);
        }

        let payload = json!({
            "model_id": requested_id,
            "error": error,
        });
        registry.handle_model_load_failed(&payload).await;

        assert_eq!(waiter_rx.await.unwrap(), Err(error.clone()));

        let entries = registry.entries.read().await;
        let entry = entries
            .get(&canonical_id)
            .expect("entry should still be registered");
        assert_eq!(entry.state, ModelLoadState::Failed);
        assert_eq!(entry.error.as_deref(), Some(error.as_str()));
        assert!(entry.activation_waiters.is_empty());
    }

    /// `Display` renders upper-case labels.
    #[test]
    fn test_model_load_state_display() {
        let cases = [
            (ModelLoadState::Idle, "IDLE"),
            (ModelLoadState::Ready, "READY"),
            (ModelLoadState::Failed, "FAILED"),
        ];
        for (state, label) in cases.iter() {
            assert_eq!(state.to_string(), *label);
        }
    }
}
987}