Skip to main content

orchard/model/
registry.rs

1//! Model registry for tracking loaded models and their state.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use hf_hub::api::tokio::ApiBuilder;
8use serde::{Deserialize, Serialize};
9use serde_json::{json, Value};
10use tokio::sync::{oneshot, Mutex, Notify, RwLock};
11
12use crate::error::Error;
13use crate::formatter::ChatFormatter;
14use crate::ipc::client::IPCClient;
15use crate::model::resolver::{ModelResolver, ResolvedModel};
16
17/// Model load state machine.
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
19pub enum ModelLoadState {
20    /// Model not yet requested
21    Idle,
22    /// Downloading from HuggingFace
23    Downloading,
24    /// Loading weights into engine
25    Loading,
26    /// Waiting for engine activation
27    Activating,
28    /// Ready for inference
29    Ready,
30    /// Failed to load
31    Failed,
32}
33
34impl std::fmt::Display for ModelLoadState {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            Self::Idle => write!(f, "IDLE"),
38            Self::Downloading => write!(f, "DOWNLOADING"),
39            Self::Loading => write!(f, "LOADING"),
40            Self::Activating => write!(f, "ACTIVATING"),
41            Self::Ready => write!(f, "READY"),
42            Self::Failed => write!(f, "FAILED"),
43        }
44    }
45}
46
47/// Information about a loaded model.
48#[derive(Debug, Clone)]
49pub struct ModelInfo {
50    pub model_id: String,
51    pub model_path: String,
52    pub formatter: Option<Arc<ChatFormatter>>,
53    pub capabilities: Option<HashMap<String, Vec<i32>>>,
54    pub minimum_memory_bytes: Option<u64>,
55}
56
57impl ModelInfo {
58    pub fn require_formatter(&self) -> std::result::Result<&ChatFormatter, crate::error::Error> {
59        self.formatter.as_deref().ok_or_else(|| {
60            crate::error::Error::ModelNotReady(format!(
61                "Model '{}' does not have a chat formatter",
62                self.model_id
63            ))
64        })
65    }
66}
67
68/// Entry in the model registry tracking a model's state.
69pub struct ModelEntry {
70    pub state: ModelLoadState,
71    pub info: Option<ModelInfo>,
72    pub error: Option<String>,
73    pub notify: Arc<Notify>,
74    pub resolved: Option<ResolvedModel>,
75    pub bytes_downloaded: Option<u64>,
76    pub bytes_total: Option<u64>,
77    /// Activation completion waiters (set when ACTIVATING)
78    pub activation_waiters: Vec<oneshot::Sender<Result<(), String>>>,
79}
80
81impl Default for ModelEntry {
82    fn default() -> Self {
83        Self {
84            state: ModelLoadState::Idle,
85            info: None,
86            error: None,
87            notify: Arc::new(Notify::new()),
88            resolved: None,
89            bytes_downloaded: None,
90            bytes_total: None,
91            activation_waiters: Vec::new(),
92        }
93    }
94}
95
96#[derive(Debug, Deserialize)]
97struct RemoteRepoFile {
98    rfilename: String,
99    #[serde(default)]
100    size: Option<u64>,
101}
102
103#[derive(Debug, Deserialize)]
104struct RemoteRepoInfo {
105    siblings: Vec<RemoteRepoFile>,
106}
107
108/// Registry of loaded models.
109pub struct ModelRegistry {
110    entries: Arc<RwLock<HashMap<String, ModelEntry>>>,
111    resolver: Mutex<ModelResolver>,
112    alias_cache: RwLock<HashMap<String, String>>,
113    /// IPC client for sending management commands to PIE
114    ipc_client: RwLock<Option<Arc<IPCClient>>>,
115}
116
117impl ModelRegistry {
118    /// Create a new model registry.
119    pub fn new() -> Result<Self, Error> {
120        Ok(Self {
121            entries: Arc::new(RwLock::new(HashMap::new())),
122            resolver: Mutex::new(ModelResolver::new()?),
123            alias_cache: RwLock::new(HashMap::new()),
124            ipc_client: RwLock::new(None),
125        })
126    }
127
128    /// Set the IPC client for sending management commands to PIE.
129    pub async fn set_ipc_client(&self, client: Arc<IPCClient>) {
130        let mut ipc = self.ipc_client.write().await;
131        *ipc = Some(client);
132    }
133
134    /// Ensure a model is loaded and ready.
135    ///
136    /// This will:
137    /// 1. Resolve the model identifier
138    /// 2. Download if needed
139    /// 3. Load the formatter
140    /// 4. Send load_model command to PIE
141    /// 5. Wait for engine activation
142    pub async fn ensure_loaded(&self, requested_model_id: &str) -> Result<ModelInfo, Error> {
143        let (_state, canonical_id) = self
144            .schedule_model(requested_model_id, false)
145            .await
146            .map_err(Error::ModelNotReady)?;
147
148        // Wait for local readiness (download + formatter)
149        let (state, info, error) = self
150            .await_model(&canonical_id, None)
151            .await
152            .map_err(Error::ModelNotReady)?;
153
154        if state == ModelLoadState::Failed {
155            return Err(Error::ModelNotReady(error.unwrap_or_else(|| {
156                format!("Model '{}' failed to load", canonical_id)
157            })));
158        }
159
160        if state == ModelLoadState::Ready {
161            return info
162                .ok_or_else(|| Error::ModelNotReady("Model ready but info missing".to_string()));
163        }
164
165        // At this point we have LOADING state with info - need to activate on PIE
166        let info = info.ok_or_else(|| {
167            Error::ModelNotReady(format!("Model '{}' info missing", canonical_id))
168        })?;
169
170        // Check if already activating or ready
171        {
172            let entries = self.entries.read().await;
173            if let Some(entry) = entries.get(&canonical_id) {
174                if entry.state == ModelLoadState::Ready {
175                    return entry
176                        .info
177                        .clone()
178                        .ok_or_else(|| Error::ModelNotReady("Ready but no info".to_string()));
179                }
180            }
181        }
182
183        // Send load_model command and wait for activation
184        let activation_rx = self
185            .send_load_model_command(requested_model_id, &canonical_id, &info)
186            .await
187            .map_err(Error::ModelNotReady)?;
188
189        // Wait for activation to complete
190        match activation_rx.await {
191            Ok(Ok(())) => {
192                // Activation succeeded, get the ready info
193                self.get_if_ready(&canonical_id).await.ok_or_else(|| {
194                    Error::ModelNotReady(format!("Model '{}' failed to activate", canonical_id))
195                })
196            }
197            Ok(Err(e)) => Err(Error::ModelNotReady(e)),
198            Err(_) => Err(Error::ModelNotReady(format!(
199                "Activation channel closed for '{}'",
200                canonical_id
201            ))),
202        }
203    }
204
205    /// Resolve a model identifier and schedule any required local download work.
206    pub async fn resolve_or_download(
207        &self,
208        requested_model_id: &str,
209    ) -> Result<(ModelLoadState, String), String> {
210        self.schedule_model(requested_model_id, false).await
211    }
212
213    /// Send the load_model command to PIE.
214    async fn send_load_model_command(
215        &self,
216        requested_id: &str,
217        canonical_id: &str,
218        info: &ModelInfo,
219    ) -> Result<oneshot::Receiver<Result<(), String>>, String> {
220        // Create activation channel
221        let (tx, rx) = oneshot::channel();
222
223        // Set up activation state
224        {
225            let mut entries = self.entries.write().await;
226            let entry = entries
227                .get_mut(canonical_id)
228                .ok_or_else(|| format!("Model '{}' not in registry", canonical_id))?;
229
230            match entry.state {
231                ModelLoadState::Ready => {
232                    let _ = tx.send(Ok(()));
233                    return Ok(rx);
234                }
235                ModelLoadState::Failed => {
236                    let message = entry
237                        .error
238                        .clone()
239                        .unwrap_or_else(|| format!("Model '{}' failed to load", canonical_id));
240                    let _ = tx.send(Err(message));
241                    return Ok(rx);
242                }
243                ModelLoadState::Activating => {
244                    entry.activation_waiters.push(tx);
245                    return Ok(rx);
246                }
247                _ => {
248                    entry.state = ModelLoadState::Activating;
249                    entry.activation_waiters.push(tx);
250                }
251            }
252        }
253
254        // Get IPC client
255        let ipc = {
256            let guard = self.ipc_client.read().await;
257            guard
258                .clone()
259                .ok_or_else(|| "IPC client not set".to_string())?
260        };
261
262        // Build and send the command
263        let command = json!({
264            "type": "load_model",
265            "requested_id": requested_id,
266            "canonical_id": canonical_id,
267            "model_path": info.model_path,
268            "wait_for_completion": false,
269        });
270
271        let response = ipc
272            .send_management_command_async(command, Duration::from_secs(30))
273            .await
274            .map_err(|e| format!("Failed to send load_model command: {}", e))?;
275
276        // Check response status
277        let status = response
278            .get("status")
279            .and_then(|v| v.as_str())
280            .unwrap_or("");
281
282        match status {
283            "ok" => {
284                // Immediate success - extract capabilities and mark ready
285                let capabilities = self.parse_capabilities(&response);
286                let minimum_memory_bytes = self.parse_minimum_memory_bytes(&response);
287                self.complete_activation(canonical_id, capabilities, minimum_memory_bytes)
288                    .await;
289            }
290            "accepted" => {
291                // Async activation - wait for model_loaded event
292                tracing::debug!(
293                    "Model '{}' activation accepted, waiting for model_loaded event",
294                    canonical_id
295                );
296            }
297            _ => {
298                let message = response
299                    .get("message")
300                    .and_then(|v| v.as_str())
301                    .unwrap_or("unknown error");
302                self.fail_activation(
303                    canonical_id,
304                    &format!("Engine rejected load_model: {}", message),
305                )
306                .await;
307                return Err(format!(
308                    "Engine rejected load_model for '{}': {}",
309                    requested_id, message
310                ));
311            }
312        }
313
314        Ok(rx)
315    }
316
317    /// Parse capabilities from management response.
318    fn parse_capabilities(&self, response: &Value) -> Option<HashMap<String, Vec<i32>>> {
319        response
320            .get("data")
321            .and_then(|d| d.get("load_model"))
322            .and_then(|lm| lm.get("capabilities"))
323            .and_then(|c| c.as_object())
324            .map(|obj| {
325                obj.iter()
326                    .filter_map(|(k, v)| {
327                        let vals: Vec<i32> = if let Some(arr) = v.as_array() {
328                            arr.iter()
329                                .filter_map(|x| x.as_i64().map(|n| n as i32))
330                                .collect()
331                        } else if let Some(n) = v.as_i64() {
332                            vec![n as i32]
333                        } else {
334                            return None;
335                        };
336                        Some((k.clone(), vals))
337                    })
338                    .collect()
339            })
340    }
341
342    fn parse_minimum_memory_bytes(&self, response: &Value) -> Option<u64> {
343        response
344            .get("data")
345            .and_then(|d| d.get("load_model"))
346            .and_then(|lm| lm.get("minimum_memory_bytes"))
347            .and_then(|value| value.as_u64())
348    }
349
350    /// Complete activation successfully.
351    async fn complete_activation(
352        &self,
353        model_id: &str,
354        capabilities: Option<HashMap<String, Vec<i32>>>,
355        minimum_memory_bytes: Option<u64>,
356    ) {
357        let mut entries = self.entries.write().await;
358        if let Some(entry) = entries.get_mut(model_id) {
359            if let Some(ref mut info) = entry.info {
360                if let Some(capabilities) = capabilities {
361                    info.capabilities = Some(capabilities);
362                }
363                if minimum_memory_bytes.is_some() {
364                    info.minimum_memory_bytes = minimum_memory_bytes;
365                }
366            }
367            entry.state = ModelLoadState::Ready;
368            entry.notify.notify_waiters();
369
370            // Signal activation complete
371            for tx in entry.activation_waiters.drain(..) {
372                let _ = tx.send(Ok(()));
373            }
374        }
375    }
376
377    /// Fail activation with error.
378    async fn fail_activation(&self, model_id: &str, error: &str) {
379        let mut entries = self.entries.write().await;
380        if let Some(entry) = entries.get_mut(model_id) {
381            entry.state = ModelLoadState::Failed;
382            entry.error = Some(error.to_string());
383            entry.notify.notify_waiters();
384
385            // Signal activation failed
386            for tx in entry.activation_waiters.drain(..) {
387                let _ = tx.send(Err(error.to_string()));
388            }
389        }
390    }
391
392    /// Schedule a model for loading.
393    ///
394    /// Returns the current state and canonical ID.
395    pub async fn schedule_model(
396        &self,
397        requested_model_id: &str,
398        force_reload: bool,
399    ) -> Result<(ModelLoadState, String), String> {
400        let resolved = {
401            let mut resolver = self.resolver.lock().await;
402            resolver
403                .resolve(requested_model_id)
404                .await
405                .map_err(|e| e.to_string())?
406        };
407
408        let canonical_id = resolved.canonical_id.clone();
409
410        {
411            let mut alias_cache = self.alias_cache.write().await;
412            alias_cache.insert(requested_model_id.to_lowercase(), canonical_id.clone());
413            alias_cache
414                .entry(canonical_id.to_lowercase())
415                .or_insert_with(|| canonical_id.clone());
416        }
417
418        let mut entries = self.entries.write().await;
419        let entry = entries
420            .entry(canonical_id.clone())
421            .or_insert_with(ModelEntry::default);
422
423        if entry.state == ModelLoadState::Ready && !force_reload {
424            return Ok((ModelLoadState::Ready, canonical_id));
425        }
426
427        if matches!(
428            entry.state,
429            ModelLoadState::Loading | ModelLoadState::Downloading | ModelLoadState::Activating
430        ) && !force_reload
431        {
432            return Ok((entry.state, canonical_id));
433        }
434
435        if entry.state == ModelLoadState::Failed && !force_reload {
436            return Ok((ModelLoadState::Failed, canonical_id));
437        }
438
439        entry.error = None;
440        entry.info = None;
441        entry.resolved = Some(resolved.clone());
442        entry.bytes_downloaded = None;
443        entry.bytes_total = None;
444        entry.notify = Arc::new(Notify::new());
445        entry.activation_waiters.clear();
446
447        if resolved.source == "local" || resolved.source == "hf_cache" {
448            let formatter = ChatFormatter::new(&resolved.model_path).ok().map(Arc::new);
449            entry.info = Some(ModelInfo {
450                model_id: canonical_id.clone(),
451                model_path: resolved.model_path.to_string_lossy().to_string(),
452                formatter,
453                capabilities: None,
454                minimum_memory_bytes: None,
455            });
456            entry.state = ModelLoadState::Loading;
457            entry.notify.notify_waiters();
458            return Ok((ModelLoadState::Loading, canonical_id));
459        }
460
461        // Model needs to be downloaded from HuggingFace
462        entry.state = ModelLoadState::Downloading;
463        let notify = entry.notify.clone();
464
465        // Drop entries lock before spawning to avoid deadlock
466        drop(entries);
467
468        // Spawn download task
469        let hf_repo = resolved
470            .hf_repo
471            .clone()
472            .unwrap_or_else(|| resolved.canonical_id.clone());
473        let canonical_id_for_task = canonical_id.clone();
474        let entries_ref = self.entries.clone();
475
476        tokio::spawn(async move {
477            let result = Self::download_model(
478                Arc::clone(&entries_ref),
479                canonical_id_for_task.as_str(),
480                hf_repo.as_str(),
481            )
482            .await;
483
484            let mut entries: tokio::sync::RwLockWriteGuard<'_, HashMap<String, ModelEntry>> =
485                entries_ref.write().await;
486            if let Some(entry) = entries.get_mut(&canonical_id_for_task) {
487                match result {
488                    Ok(download_path) => {
489                        // Update resolved path
490                        if let Some(ref mut resolved) = entry.resolved {
491                            resolved.model_path = download_path.clone();
492                            resolved.source = "hf_cache".to_string();
493                        }
494
495                        let formatter = ChatFormatter::new(&download_path).ok().map(Arc::new);
496                        entry.info = Some(ModelInfo {
497                            model_id: canonical_id_for_task.clone(),
498                            model_path: download_path.to_string_lossy().to_string(),
499                            formatter,
500                            capabilities: None,
501                            minimum_memory_bytes: None,
502                        });
503                        entry.state = ModelLoadState::Loading;
504                    }
505                    Err(e) => {
506                        entry.error = Some(format!("Download failed: {}", e));
507                        entry.state = ModelLoadState::Failed;
508                    }
509                }
510                notify.notify_waiters();
511            }
512        });
513
514        Ok((ModelLoadState::Downloading, canonical_id))
515    }
516
517    /// Download a model from HuggingFace Hub.
518    async fn download_model(
519        entries_ref: Arc<RwLock<HashMap<String, ModelEntry>>>,
520        canonical_id: &str,
521        repo_id: &str,
522    ) -> Result<std::path::PathBuf, String> {
523        tracing::info!("Downloading model from HuggingFace: {}", repo_id);
524
525        let api = ApiBuilder::from_env()
526            .with_progress(false)
527            .build()
528            .map_err(|e| format!("Failed to create HF API: {}", e))?;
529        let repo = api.model(repo_id.to_string());
530
531        let mut repo_info: RemoteRepoInfo = repo
532            .info_request()
533            .query(&[("blobs", "true")])
534            .send()
535            .await
536            .map_err(|e| format!("Failed to query repo info: {}", e))?
537            .json()
538            .await
539            .map_err(|e| format!("Failed to decode repo info: {}", e))?;
540
541        if repo_info.siblings.is_empty() {
542            return Err("Repository has no downloadable files".to_string());
543        }
544
545        if !repo_info
546            .siblings
547            .iter()
548            .any(|file| file.rfilename == "config.json")
549        {
550            return Err("Repository is missing config.json".to_string());
551        }
552
553        repo_info.siblings.sort_by(|left, right| {
554            let left_priority = u8::from(left.rfilename != "config.json");
555            let right_priority = u8::from(right.rfilename != "config.json");
556            left_priority
557                .cmp(&right_priority)
558                .then_with(|| left.rfilename.cmp(&right.rfilename))
559        });
560
561        let total_bytes = repo_info
562            .siblings
563            .iter()
564            .filter_map(|file| file.size)
565            .sum::<u64>();
566        let mut downloaded_bytes = 0u64;
567
568        {
569            let mut entries = entries_ref.write().await;
570            if let Some(entry) = entries.get_mut(canonical_id) {
571                entry.bytes_downloaded = Some(0);
572                entry.bytes_total = Some(total_bytes);
573            }
574        }
575
576        let mut model_dir: Option<std::path::PathBuf> = None;
577
578        for file in repo_info.siblings {
579            if Self::download_cancelled(&entries_ref, canonical_id).await {
580                return Err("Cancelled".to_string());
581            }
582
583            let path = repo
584                .get(file.rfilename.as_str())
585                .await
586                .map_err(|e| format!("Failed to download {}: {}", file.rfilename, e))?;
587
588            if file.rfilename == "config.json" {
589                model_dir = path.parent().map(|parent| parent.to_path_buf());
590            }
591
592            let file_size = tokio::fs::metadata(&path)
593                .await
594                .map(|metadata| metadata.len())
595                .unwrap_or_else(|_| file.size.unwrap_or(0));
596            downloaded_bytes = downloaded_bytes.saturating_add(file_size);
597
598            let completed_bytes = if total_bytes == 0 {
599                downloaded_bytes
600            } else {
601                downloaded_bytes.min(total_bytes)
602            };
603
604            let mut entries = entries_ref.write().await;
605            if let Some(entry) = entries.get_mut(canonical_id) {
606                entry.bytes_downloaded = Some(completed_bytes);
607                entry.bytes_total = Some(total_bytes.max(completed_bytes));
608            }
609        }
610
611        if Self::download_cancelled(&entries_ref, canonical_id).await {
612            return Err("Cancelled".to_string());
613        }
614
615        let model_dir =
616            model_dir.ok_or_else(|| "Downloaded repo is missing config.json".to_string())?;
617
618        tracing::info!("Model downloaded to {:?}", model_dir);
619        Ok(model_dir)
620    }
621
622    /// Cancel an in-progress download.
623    pub async fn cancel_download(&self, model_id: &str) -> Result<(), String> {
624        let canonical_id = self.canonicalize(model_id).await?;
625        let mut entries = self.entries.write().await;
626        let entry = entries
627            .get_mut(&canonical_id)
628            .ok_or_else(|| format!("Model '{}' has not been scheduled", canonical_id))?;
629
630        if entry.state != ModelLoadState::Downloading {
631            return Err(format!(
632                "Model '{}' is not downloading (current state: {})",
633                canonical_id, entry.state
634            ));
635        }
636
637        entry.state = ModelLoadState::Failed;
638        entry.error = Some("Cancelled".to_string());
639        entry.notify.notify_waiters();
640        Ok(())
641    }
642
643    /// Wait for a model to finish loading.
644    ///
645    /// Returns immediately if state is already terminal (Loading/Ready/Failed).
646    /// Only blocks if state is Downloading or Activating.
647    pub async fn await_model(
648        &self,
649        model_id: &str,
650        timeout: Option<std::time::Duration>,
651    ) -> Result<(ModelLoadState, Option<ModelInfo>, Option<String>), String> {
652        let canonical_id = self.canonicalize(model_id).await?;
653
654        let _notify: Arc<Notify>;
655        let notified = {
656            let entries = self.entries.read().await;
657            let entry = entries
658                .get(&canonical_id)
659                .ok_or_else(|| format!("Model '{}' has not been scheduled", model_id))?;
660
661            // Notify is edge-triggered: arm the waiter while holding the lock.
662            if !matches!(
663                entry.state,
664                ModelLoadState::Downloading | ModelLoadState::Activating
665            ) {
666                return Ok((entry.state, entry.info.clone(), entry.error.clone()));
667            }
668
669            _notify = entry.notify.clone();
670            _notify.notified()
671        };
672
673        // Wait for state transition
674        match timeout {
675            Some(d) => {
676                let _ = tokio::time::timeout(d, notified).await;
677            }
678            None => notified.await,
679        }
680
681        // Re-read final state
682        let entries = self.entries.read().await;
683        let entry = entries
684            .get(&canonical_id)
685            .ok_or_else(|| format!("Model '{}' not found", canonical_id))?;
686        Ok((entry.state, entry.info.clone(), entry.error.clone()))
687    }
688
689    /// Get model info if ready.
690    pub async fn get_if_ready(&self, model_id: &str) -> Option<ModelInfo> {
691        let canonical_id = self.canonicalize(model_id).await.ok()?;
692        let entries = self.entries.read().await;
693        let entry = entries.get(&canonical_id)?;
694
695        if entry.state == ModelLoadState::Ready {
696            entry.info.clone()
697        } else {
698            None
699        }
700    }
701
702    /// Get model status.
703    pub async fn get_status(
704        &self,
705        model_id: &str,
706    ) -> (ModelLoadState, Option<String>, Option<(u64, u64)>) {
707        let canonical_id = match self.canonicalize(model_id).await {
708            Ok(id) => id,
709            Err(_) => return (ModelLoadState::Idle, None, None),
710        };
711
712        let entries = self.entries.read().await;
713        let entry = match entries.get(&canonical_id) {
714            Some(e) => e,
715            None => return (ModelLoadState::Idle, None, None),
716        };
717
718        let progress = match (entry.bytes_downloaded, entry.bytes_total) {
719            (Some(d), Some(t)) => Some((d, t)),
720            _ => None,
721        };
722
723        (entry.state, entry.error.clone(), progress)
724    }
725
726    /// Update capabilities for a loaded model.
727    pub async fn update_capabilities(
728        &self,
729        model_id: &str,
730        capabilities: HashMap<String, Vec<i32>>,
731    ) {
732        let canonical_id = match self.canonicalize(model_id).await {
733            Ok(id) => id,
734            Err(_) => {
735                tracing::warn!("Received capabilities for unknown model '{}'", model_id);
736                return;
737            }
738        };
739
740        let mut entries = self.entries.write().await;
741        if let Some(entry) = entries.get_mut(&canonical_id) {
742            if let Some(ref mut info) = entry.info {
743                info.capabilities = Some(capabilities);
744            }
745        }
746    }
747
748    /// Mark a model as ready (called when engine confirms activation).
749    pub async fn mark_ready(&self, model_id: &str) {
750        let canonical_id = match self.canonicalize(model_id).await {
751            Ok(id) => id,
752            Err(_) => return,
753        };
754
755        let mut entries = self.entries.write().await;
756        if let Some(entry) = entries.get_mut(&canonical_id) {
757            entry.state = ModelLoadState::Ready;
758            entry.notify.notify_waiters();
759            for tx in entry.activation_waiters.drain(..) {
760                let _ = tx.send(Ok(()));
761            }
762        }
763    }
764
765    /// Mark a model as failed.
766    pub async fn mark_failed(&self, model_id: &str, error: String) {
767        let canonical_id = match self.canonicalize(model_id).await {
768            Ok(id) => id,
769            Err(_) => return,
770        };
771
772        let mut entries = self.entries.write().await;
773        if let Some(entry) = entries.get_mut(&canonical_id) {
774            entry.state = ModelLoadState::Failed;
775            entry.error = Some(error.clone());
776            entry.notify.notify_waiters();
777
778            // Signal activation failed if waiting
779            for tx in entry.activation_waiters.drain(..) {
780                let _ = tx.send(Err(error.clone()));
781            }
782        }
783    }
784
785    /// Handle model_loaded event from PIE.
786    ///
787    /// Called by the event callback when a model_loaded event is received.
788    pub async fn handle_model_loaded(&self, payload: &Value) {
789        let model_id = match payload.get("model_id").and_then(|v| v.as_str()) {
790            Some(id) => id,
791            None => {
792                tracing::warn!("Received model_loaded event without model_id");
793                return;
794            }
795        };
796
797        // Extract and update capabilities
798        if let Some(caps) = payload.get("capabilities").and_then(|c| c.as_object()) {
799            let capabilities: HashMap<String, Vec<i32>> = caps
800                .iter()
801                .filter_map(|(k, v)| {
802                    let vals: Vec<i32> = if let Some(arr) = v.as_array() {
803                        arr.iter()
804                            .filter_map(|x| x.as_i64().map(|n| n as i32))
805                            .collect()
806                    } else if let Some(n) = v.as_i64() {
807                        vec![n as i32]
808                    } else {
809                        return None;
810                    };
811                    Some((k.clone(), vals))
812                })
813                .collect();
814
815            if !capabilities.is_empty() {
816                self.update_capabilities(model_id, capabilities).await;
817            }
818        }
819
820        // Canonicalize the model_id first (fix for issue #4)
821        let canonical_id = match self.canonicalize(model_id).await {
822            Ok(id) => id,
823            Err(_) => {
824                // Try the raw model_id as fallback
825                tracing::debug!("Model '{}' not found in alias cache, using as-is", model_id);
826                model_id.to_string()
827            }
828        };
829
830        // Complete the activation
831        let minimum_memory_bytes = payload
832            .get("minimum_memory_bytes")
833            .and_then(|value| value.as_u64());
834        self.complete_activation(&canonical_id, None, minimum_memory_bytes)
835            .await;
836    }
837
838    /// Handle model_load_failed event from PIE.
839    ///
840    /// Called by the event callback when a model_load_failed event is received.
841    pub async fn handle_model_load_failed(&self, payload: &Value) {
842        let model_id = match payload.get("model_id").and_then(|v| v.as_str()) {
843            Some(id) => id,
844            None => {
845                tracing::warn!("Received model_load_failed event without model_id");
846                return;
847            }
848        };
849
850        let error = match payload.get("error").and_then(|v| v.as_str()) {
851            Some(message) => message,
852            None => {
853                tracing::warn!(
854                    "Received model_load_failed event without error for '{}'",
855                    model_id
856                );
857                "unknown error"
858            }
859        };
860
861        let canonical_id = match self.canonicalize(model_id).await {
862            Ok(id) => id,
863            Err(_) => {
864                tracing::debug!("Model '{}' not found in alias cache, using as-is", model_id);
865                model_id.to_string()
866            }
867        };
868
869        self.fail_activation(&canonical_id, error).await;
870    }
871
872    /// List all registered models.
873    pub async fn list_models(&self) -> Vec<HashMap<String, String>> {
874        let entries = self.entries.read().await;
875        let mut catalog = Vec::new();
876
877        for (canonical_id, entry) in entries.iter() {
878            if let Some(ref resolved) = entry.resolved {
879                let mut payload: HashMap<String, String> = resolved.metadata.clone();
880                payload.insert("canonical_id".to_string(), canonical_id.clone());
881                payload.insert(
882                    "model_path".to_string(),
883                    resolved.model_path.to_string_lossy().to_string(),
884                );
885                payload.insert("source".to_string(), resolved.source.clone());
886                payload.insert(
887                    "hf_repo".to_string(),
888                    resolved.hf_repo.clone().unwrap_or_default(),
889                );
890                payload.insert("state".to_string(), entry.state.to_string());
891                catalog.push(payload);
892            }
893        }
894
895        catalog
896    }
897
898    async fn canonicalize(&self, model_id: &str) -> Result<String, String> {
899        // Check if it's already a canonical ID
900        {
901            let entries = self.entries.read().await;
902            if entries.contains_key(model_id) {
903                return Ok(model_id.to_string());
904            }
905        }
906
907        // Check alias cache
908        {
909            let alias_cache = self.alias_cache.read().await;
910            if let Some(canonical) = alias_cache.get(&model_id.to_lowercase()) {
911                return Ok(canonical.clone());
912            }
913        }
914
915        Err(format!("Model '{}' not found in registry", model_id))
916    }
917
918    async fn download_cancelled(
919        entries_ref: &Arc<RwLock<HashMap<String, ModelEntry>>>,
920        canonical_id: &str,
921    ) -> bool {
922        let entries = entries_ref.read().await;
923        matches!(
924            entries.get(canonical_id),
925            Some(entry)
926                if entry.state == ModelLoadState::Failed
927                    && entry.error.as_deref() == Some("Cancelled")
928        )
929    }
930}
931
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio::sync::oneshot;

    /// A freshly constructed registry lists no models.
    #[tokio::test]
    async fn test_registry_creation() {
        let registry = ModelRegistry::new().unwrap();
        let models = registry.list_models().await;
        assert!(models.is_empty());
    }

    /// A `model_load_failed` event addressed to an alias must fail the
    /// canonical entry: the waiter receives the error, the entry becomes
    /// `Failed` with the message recorded, and the waiter list is drained.
    #[tokio::test]
    async fn test_handle_model_load_failed_fails_activation_waiters() {
        let registry = ModelRegistry::new().unwrap();
        let canonical_id = "moondream/moondream3-preview".to_string();
        // Differs from the canonical id in case only, so it is not an entry
        // key and resolution has to go through the lowercase-keyed alias
        // cache instead of short-circuiting on the direct entry match.
        let requested_id = "Moondream/Moondream3-Preview".to_string();
        let error = "Weight shard file not found".to_string();
        let (tx, rx) = oneshot::channel();

        {
            let mut alias_cache = registry.alias_cache.write().await;
            alias_cache.insert(requested_id.to_lowercase(), canonical_id.clone());
        }

        {
            let mut entries = registry.entries.write().await;
            let entry = entries.entry(canonical_id.clone()).or_default();
            entry.state = ModelLoadState::Activating;
            entry.activation_waiters.push(tx);
        }

        registry
            .handle_model_load_failed(&json!({
                "model_id": requested_id,
                "error": error,
            }))
            .await;

        assert_eq!(rx.await.unwrap(), Err(error.clone()));

        let entries = registry.entries.read().await;
        let entry = entries.get(&canonical_id).unwrap();
        assert_eq!(entry.state, ModelLoadState::Failed);
        assert_eq!(entry.error.as_deref(), Some(error.as_str()));
        assert!(entry.activation_waiters.is_empty());
    }

    /// Every `ModelLoadState` variant renders as its upper-case wire name.
    #[test]
    fn test_model_load_state_display() {
        assert_eq!(ModelLoadState::Idle.to_string(), "IDLE");
        assert_eq!(ModelLoadState::Downloading.to_string(), "DOWNLOADING");
        assert_eq!(ModelLoadState::Loading.to_string(), "LOADING");
        assert_eq!(ModelLoadState::Activating.to_string(), "ACTIVATING");
        assert_eq!(ModelLoadState::Ready.to_string(), "READY");
        assert_eq!(ModelLoadState::Failed.to_string(), "FAILED");
    }
}