Skip to main content

modelexpress_server/
services.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::registry::backend::ClaimOutcome;
5use crate::registry::state::RegistryManager;
6use modelexpress_common::{
7    cache::{CacheConfig, resolve_model_path},
8    constants, download,
9    grpc::{
10        api::{ApiRequest, ApiResponse, api_service_server::ApiService},
11        health::{HealthRequest, HealthResponse, health_service_server::HealthService},
12        model::{
13            DeleteModelRequest, DeleteModelResponse, FileChunk, ModelDownloadRequest,
14            ModelFileInfo, ModelFileList, ModelFileSelector, ModelFilesRequest,
15            ModelProvider as GrpcModelProvider, ModelStatusUpdate,
16            model_service_server::ModelService,
17        },
18    },
19    models::{ModelProvider, ModelStatus},
20};
21use std::{
22    collections::HashMap,
23    path::{Path, PathBuf},
24    sync::{Arc, Mutex},
25    time::SystemTime,
26};
27use tokio::io::AsyncReadExt;
28use tokio_stream::wrappers::ReceiverStream;
29use tonic::{Request, Response, Status};
30use tracing::{debug, error, info};
31
/// Reference instant used to compute the uptime reported by the health service.
///
/// Set at most once, by whichever caller first invokes `get_or_init`; within this file
/// that is the health handler.
static START_TIME: std::sync::OnceLock<SystemTime> = std::sync::OnceLock::new();
33
34/// Get the configured cache directory for model downloads
35fn get_server_cache_dir() -> Option<std::path::PathBuf> {
36    // Try to get cache configuration
37    if let Ok(config) = CacheConfig::discover() {
38        Some(config.local_path)
39    } else {
40        // Fall back to environment variable
41        std::env::var("HF_HUB_CACHE")
42            .ok()
43            .map(std::path::PathBuf::from)
44    }
45}
46
47/// Returns true if the model's files are present in the given cache directory. Used to
48/// guard against stale `DOWNLOADED` registry records that point at a cache entry which no
49/// longer exists on disk (e.g. left behind by a client-side `model clear`). When no cache
50/// directory is configured we cannot verify, so we assume the files are present to preserve
51/// existing behavior rather than loop re-downloading forever.
52async fn model_files_present(
53    cache_dir: Option<std::path::PathBuf>,
54    model_name: &str,
55    provider: ModelProvider,
56) -> bool {
57    let Some(cache_dir) = cache_dir else {
58        return true;
59    };
60    download::get_provider(provider)
61        .get_model_path(model_name, cache_dir)
62        .await
63        .is_ok()
64}
65
66/// Health service implementation
67#[derive(Debug, Default)]
68pub struct HealthServiceImpl;
69
70#[tonic::async_trait]
71impl HealthService for HealthServiceImpl {
72    async fn get_health(
73        &self,
74        _request: Request<HealthRequest>,
75    ) -> Result<Response<HealthResponse>, Status> {
76        let start_time = START_TIME.get_or_init(SystemTime::now);
77        let uptime = SystemTime::now()
78            .duration_since(*start_time)
79            .unwrap_or_default()
80            .as_secs();
81
82        let response = HealthResponse {
83            version: env!("CARGO_PKG_VERSION").to_string(),
84            status: "ok".to_string(),
85            uptime,
86        };
87
88        Ok(Response::new(response))
89    }
90}
91
92/// API service implementation
93#[derive(Debug, Default)]
94pub struct ApiServiceImpl;
95
96#[tonic::async_trait]
97impl ApiService for ApiServiceImpl {
98    async fn send_request(
99        &self,
100        request: Request<ApiRequest>,
101    ) -> Result<Response<ApiResponse>, Status> {
102        let api_request = request.into_inner();
103        info!("Received gRPC request: {:?}", api_request);
104
105        // Process the request based on the action
106        if api_request.action.as_str() == "ping" {
107            info!("Processing ping request");
108            let response_data = serde_json::json!({ "message": "pong" });
109            let data_bytes = serde_json::to_vec(&response_data)
110                .map_err(|e| Status::internal(format!("Serialization error: {e}")))?;
111
112            Ok(Response::new(ApiResponse {
113                success: true,
114                data: Some(data_bytes),
115                error: None,
116            }))
117        } else {
118            error!("Unknown action: {}", api_request.action);
119            Ok(Response::new(ApiResponse {
120                success: false,
121                data: None,
122                error: Some(format!("Unknown action: {}", api_request.action)),
123            }))
124        }
125    }
126}
127
/// Stateless gRPC handler for model download, file listing/streaming and deletion.
#[derive(Default, Debug)]
pub struct ModelServiceImpl;
131
132/// Helper function to collect all files in a model directory recursively
133fn collect_model_files(
134    base_path: &Path,
135    current_path: &Path,
136    file_selector: Option<&ModelFileSelector>,
137) -> Vec<(PathBuf, u64)> {
138    let mut files = Vec::new();
139
140    if let Ok(entries) = std::fs::read_dir(current_path) {
141        for entry in entries.flatten() {
142            let path = entry.path();
143            if path.is_file() {
144                if let Ok(metadata) = std::fs::metadata(&path) {
145                    // Get relative path from base_path
146                    if let Ok(relative) = path.strip_prefix(base_path) {
147                        // Validate that the relative path does not contain any '..' components or is absolute
148                        let mut is_safe = true;
149                        for comp in relative.components() {
150                            use std::path::Component;
151                            match comp {
152                                Component::ParentDir
153                                | Component::RootDir
154                                | Component::Prefix(_) => {
155                                    is_safe = false;
156                                    break;
157                                }
158                                _ => {}
159                            }
160                        }
161                        if !is_safe {
162                            tracing::warn!(
163                                "Skipping potentially unsafe file path: {:?} (relative: {:?})",
164                                path,
165                                relative
166                            );
167                        } else if file_selector.is_none_or(|selector| {
168                            selector
169                                .paths
170                                .iter()
171                                .any(|selector_path| Path::new(selector_path) == relative)
172                        }) {
173                            files.push((relative.to_path_buf(), metadata.len()));
174                        }
175                    }
176                }
177            } else if path.is_dir() {
178                files.extend(collect_model_files(base_path, &path, file_selector));
179            }
180        }
181    }
182
183    files
184}
185
186fn ensure_selected_files_exist(
187    files: &[(PathBuf, u64)],
188    file_selector: Option<&ModelFileSelector>,
189) -> Result<(), String> {
190    let Some(selector) = file_selector else {
191        return Ok(());
192    };
193
194    if let Some(missing_path) = selector.paths.iter().find(|selector_path| {
195        !files
196            .iter()
197            .any(|(path, _)| Path::new(selector_path) == path.as_path())
198    }) {
199        Err(format!(
200            "Selected file not found in model directory: {missing_path}"
201        ))
202    } else {
203        Ok(())
204    }
205}
206
207#[tonic::async_trait]
208impl ModelService for ModelServiceImpl {
209    type EnsureModelDownloadedStream = ReceiverStream<Result<ModelStatusUpdate, Status>>;
210    type StreamModelFilesStream = ReceiverStream<Result<FileChunk, Status>>;
211
212    async fn ensure_model_downloaded(
213        &self,
214        request: Request<ModelDownloadRequest>,
215    ) -> Result<Response<Self::EnsureModelDownloadedStream>, Status> {
216        info!("Starting model download stream");
217        let model_request = request.into_inner();
218        let (tx, rx) = tokio::sync::mpsc::channel(4);
219
220        // Convert gRPC provider to our enum
221        let grpc_provider = GrpcModelProvider::try_from(model_request.provider).map_err(|_| {
222            Status::invalid_argument(format!(
223                "Invalid provider value: {}",
224                model_request.provider
225            ))
226        })?;
227        let provider = ModelProvider::from(grpc_provider);
228        let model_name = download::canonical_model_name(&model_request.model_name, provider)
229            .map_err(|e| Status::invalid_argument(e.to_string()))?;
230        let ignore_weights = model_request.ignore_weights;
231
232        // Spawn a task to handle the streaming download updates
233        tokio::spawn(async move {
234            let Some(tracker) = model_tracker() else {
235                let _ = tx
236                    .send(Err(Status::unavailable(
237                        "server startup incomplete: model tracker not initialized",
238                    )))
239                    .await;
240                return;
241            };
242            // Run the full claim + wait + retry flow. `ensure_model_downloaded` sends
243            // its own initial status update (based on the `ClaimOutcome` returned by the
244            // registry), so we don't do a pre-check here — a pre-check would either
245            // duplicate that update or, worse, emit `status=ERROR` on a model we're
246            // about to retry and trip the client-lib's terminal-error bailout before
247            // the retry completion broadcast arrives.
248            let final_status = tracker
249                .ensure_model_downloaded(&model_name, provider, &tx, ignore_weights)
250                .await;
251
252            // Send final status update
253            let final_update = ModelStatusUpdate {
254                model_name: model_name.clone(),
255                status: modelexpress_common::grpc::model::ModelStatus::from(final_status) as i32,
256                message: match final_status {
257                    ModelStatus::DOWNLOADED => {
258                        Some("Model download completed successfully".to_string())
259                    }
260                    ModelStatus::ERROR => Some("Model download failed".to_string()),
261                    ModelStatus::DOWNLOADING => Some("Download still in progress".to_string()),
262                },
263                provider: grpc_provider as i32,
264            };
265
266            let _ = tx.send(Ok(final_update)).await;
267        });
268
269        Ok(Response::new(ReceiverStream::new(rx)))
270    }
271
272    async fn stream_model_files(
273        &self,
274        request: Request<ModelFilesRequest>,
275    ) -> Result<Response<Self::StreamModelFilesStream>, Status> {
276        let files_request = request.into_inner();
277        let chunk_size = if files_request.chunk_size == 0 {
278            constants::DEFAULT_TRANSFER_CHUNK_SIZE
279        } else {
280            files_request.chunk_size as usize
281        };
282
283        // Convert gRPC provider to our enum
284        let grpc_provider = GrpcModelProvider::try_from(files_request.provider).map_err(|_| {
285            Status::invalid_argument(format!(
286                "Invalid provider value: {}",
287                files_request.provider
288            ))
289        })?;
290        let provider = ModelProvider::from(grpc_provider);
291        let model_name = download::canonical_model_name(&files_request.model_name, provider)
292            .map_err(|e| Status::invalid_argument(e.to_string()))?;
293        let provider_impl = download::get_provider(provider);
294
295        info!(
296            "Starting file stream for model: {} with chunk size: {} bytes",
297            model_name, chunk_size
298        );
299
300        // Get the cache directory
301        let cache_dir = get_server_cache_dir()
302            .ok_or_else(|| Status::internal("Server cache directory not configured"))?;
303
304        // Get the model path using the provider from the request
305        let model_path = provider_impl
306            .get_model_path(&model_name, cache_dir.clone())
307            .await
308            .map_err(|e| Status::not_found(format!("Model not found: {e}")))?;
309
310        debug!("Model path resolved to: {:?}", model_path);
311
312        let commit_hash = if provider == ModelProvider::HuggingFace {
313            model_path
314                .file_name()
315                .and_then(|name| name.to_str())
316                .map(String::from)
317        } else {
318            None
319        };
320
321        if provider == ModelProvider::HuggingFace && commit_hash.is_none() {
322            return Err(Status::internal(
323                "Resolved Hugging Face model path did not contain a revision",
324            ));
325        }
326
327        let expected_model_path =
328            resolve_model_path(&cache_dir, provider, &model_name, commit_hash.as_deref()).map_err(
329                |e| Status::internal(format!("Failed to resolve expected cache layout: {e}")),
330            )?;
331
332        if model_path != expected_model_path {
333            error!(
334                "Resolved model path '{}' does not match expected cache layout '{}' for model '{}'",
335                model_path.display(),
336                expected_model_path.display(),
337                model_name
338            );
339            return Err(Status::internal(
340                "Resolved model path does not match expected cache layout",
341            ));
342        }
343
344        // Collect all files to stream
345        let files = collect_model_files(
346            &model_path,
347            &model_path,
348            files_request.file_selector.as_ref(),
349        );
350        ensure_selected_files_exist(&files, files_request.file_selector.as_ref())
351            .map_err(Status::not_found)?;
352
353        if files.is_empty() {
354            return Err(Status::not_found("No files found in model directory"));
355        }
356
357        let total_files = files.len();
358        info!(
359            "Found {} files to stream for model {}",
360            total_files, model_name
361        );
362
363        let (tx, rx) = tokio::sync::mpsc::channel(16);
364
365        // Spawn a task to stream files
366        tokio::spawn(async move {
367            // Allocate buffer once and reuse across all files
368            let mut buffer = vec![0u8; chunk_size];
369            let mut is_first_chunk = true;
370
371            for (file_idx, (relative_path, total_size)) in files.iter().enumerate() {
372                let file_path = model_path.join(relative_path);
373                let is_last_file = file_idx == total_files.saturating_sub(1);
374
375                debug!("Streaming file: {:?} ({} bytes)", relative_path, total_size);
376
377                // Open the file
378                let file = match tokio::fs::File::open(&file_path).await {
379                    Ok(f) => f,
380                    Err(e) => {
381                        error!("Failed to open file {:?}: {}", file_path, e);
382                        let _ = tx
383                            .send(Err(Status::internal(format!("Failed to open file: {e}"))))
384                            .await;
385                        return;
386                    }
387                };
388
389                let mut reader = tokio::io::BufReader::new(file);
390                let mut offset: u64 = 0;
391
392                if *total_size == 0 {
393                    let first_chunk = std::mem::replace(&mut is_first_chunk, false);
394                    let chunk = FileChunk {
395                        relative_path: relative_path.to_string_lossy().to_string(),
396                        data: Vec::new(),
397                        offset: 0,
398                        total_size: 0,
399                        is_last_chunk: true,
400                        is_last_file,
401                        commit_hash: if first_chunk {
402                            commit_hash.clone()
403                        } else {
404                            None
405                        },
406                    };
407
408                    if tx.send(Ok(chunk)).await.is_err() {
409                        debug!("Client disconnected during file stream");
410                        return;
411                    }
412
413                    continue;
414                }
415
416                loop {
417                    let bytes_read = match reader.read(&mut buffer).await {
418                        Ok(0) => break, // EOF
419                        Ok(n) => n,
420                        Err(e) => {
421                            error!("Failed to read file {:?}: {}", file_path, e);
422                            let _ = tx
423                                .send(Err(Status::internal(format!("Failed to read file: {e}"))))
424                                .await;
425                            return;
426                        }
427                    };
428
429                    let is_last_chunk = offset.saturating_add(bytes_read as u64) >= *total_size;
430
431                    let first_chunk = std::mem::replace(&mut is_first_chunk, false);
432
433                    let chunk = FileChunk {
434                        relative_path: relative_path.to_string_lossy().to_string(),
435                        data: buffer[..bytes_read].to_vec(),
436                        offset,
437                        total_size: *total_size,
438                        is_last_chunk,
439                        is_last_file: is_last_file && is_last_chunk,
440                        commit_hash: if first_chunk {
441                            commit_hash.clone()
442                        } else {
443                            None
444                        },
445                    };
446
447                    if tx.send(Ok(chunk)).await.is_err() {
448                        debug!("Client disconnected during file stream");
449                        return;
450                    }
451
452                    offset = offset.saturating_add(bytes_read as u64);
453                }
454            }
455
456            info!("File streaming completed for model");
457        });
458
459        Ok(Response::new(ReceiverStream::new(rx)))
460    }
461
462    async fn list_model_files(
463        &self,
464        request: Request<ModelFilesRequest>,
465    ) -> Result<Response<ModelFileList>, Status> {
466        let files_request = request.into_inner();
467
468        // Convert gRPC provider to our enum
469        let grpc_provider = GrpcModelProvider::try_from(files_request.provider).map_err(|_| {
470            Status::invalid_argument(format!(
471                "Invalid provider value: {}",
472                files_request.provider
473            ))
474        })?;
475        let provider = ModelProvider::from(grpc_provider);
476        let model_name = download::canonical_model_name(&files_request.model_name, provider)
477            .map_err(|e| Status::invalid_argument(e.to_string()))?;
478        let provider_impl = download::get_provider(provider);
479
480        info!("Listing files for model: {}", model_name);
481
482        // Get the cache directory
483        let cache_dir = get_server_cache_dir()
484            .ok_or_else(|| Status::internal("Server cache directory not configured"))?;
485
486        // Get the model path using the provider from the request
487        let model_path = provider_impl
488            .get_model_path(&model_name, cache_dir)
489            .await
490            .map_err(|e| Status::not_found(format!("Model not found: {e}")))?;
491
492        // Collect all files
493        let files = collect_model_files(
494            &model_path,
495            &model_path,
496            files_request.file_selector.as_ref(),
497        );
498        ensure_selected_files_exist(&files, files_request.file_selector.as_ref())
499            .map_err(Status::not_found)?;
500
501        let file_infos: Vec<ModelFileInfo> = files
502            .iter()
503            .map(|(path, size)| ModelFileInfo {
504                relative_path: path.to_string_lossy().to_string(),
505                size: *size,
506            })
507            .collect();
508
509        let total_size: u64 = files.iter().map(|(_, size)| size).sum();
510
511        Ok(Response::new(ModelFileList {
512            model_name,
513            files: file_infos,
514            total_size,
515        }))
516    }
517
518    async fn delete_model(
519        &self,
520        request: Request<DeleteModelRequest>,
521    ) -> Result<Response<DeleteModelResponse>, Status> {
522        let delete_request = request.into_inner();
523
524        let grpc_provider = GrpcModelProvider::try_from(delete_request.provider).map_err(|_| {
525            Status::invalid_argument(format!(
526                "Invalid provider value: {}",
527                delete_request.provider
528            ))
529        })?;
530        let provider = ModelProvider::from(grpc_provider);
531        let model_name = download::canonical_model_name(&delete_request.model_name, provider)
532            .map_err(|e| Status::invalid_argument(e.to_string()))?;
533
534        let Some(tracker) = model_tracker() else {
535            return Err(Status::unavailable(
536                "server startup incomplete: model tracker not initialized",
537            ));
538        };
539
540        tracker.delete_status(&model_name).await;
541        info!("Deleted registry record for model '{model_name}'");
542
543        Ok(Response::new(DeleteModelResponse {
544            success: true,
545            message: Some(format!("Model '{model_name}' removed from registry")),
546        }))
547    }
548}
549
550/// Type alias for the complex waiting channels type
551type WaitingChannels =
552    Arc<Mutex<HashMap<String, Vec<tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>>>>>;
553
554/// Tracks the status of model downloads through the distributed registry backend.
555#[derive(Clone)]
556pub struct ModelDownloadTracker {
557    /// Distributed registry (Redis today, K8s CRDs in a follow-up).
558    registry: Arc<RegistryManager>,
559    /// Maps model names to list of channels waiting for updates on this server replica.
560    waiting_channels: WaitingChannels,
561}
562
563impl ModelDownloadTracker {
564    pub fn new(registry: Arc<RegistryManager>) -> Self {
565        Self {
566            registry,
567            waiting_channels: Arc::new(Mutex::new(HashMap::new())),
568        }
569    }
570
571    /// Gets the status of a model from the registry, bumping `last_used_at` on hit.
572    /// Returns None on lookup failure (error logged) or unknown model.
573    pub async fn get_status(&self, model_name: &str) -> Option<ModelStatus> {
574        match self.registry.get_status(model_name).await {
575            Ok(Some(status)) => {
576                if let Err(e) = self.registry.touch_model(model_name).await {
577                    error!("Failed to touch model {model_name}: {e}");
578                }
579                Some(status)
580            }
581            Ok(None) => None,
582            Err(e) => {
583                error!("Failed to get model status from registry: {e}");
584                None
585            }
586        }
587    }
588
589    /// Sets the status of a model and notifies all waiting channels on this replica.
590    pub async fn set_status_and_notify(
591        &self,
592        model_name: String,
593        status: ModelStatus,
594        provider: ModelProvider,
595        message: Option<String>,
596    ) {
597        if let Err(e) = self
598            .registry
599            .set_status(&model_name, provider, status, message.clone())
600            .await
601        {
602            error!("Failed to update model status in registry: {e}");
603            return;
604        }
605
606        let mut waiting = match self.waiting_channels.lock() {
607            Ok(guard) => guard,
608            Err(poisoned) => {
609                error!("Waiting channels mutex is poisoned, recovering");
610                poisoned.into_inner()
611            }
612        };
613        if let Some(channels) = waiting.get(&model_name) {
614            let update = ModelStatusUpdate {
615                model_name: model_name.clone(),
616                status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
617                message,
618                provider: GrpcModelProvider::from(provider) as i32,
619            };
620            for channel in channels {
621                let _ = channel.try_send(Ok(update.clone()));
622            }
623            if status == ModelStatus::DOWNLOADED || status == ModelStatus::ERROR {
624                waiting.remove(&model_name);
625            }
626        }
627    }
628
629    /// Sets the status of a model (no message), notifying waiters.
630    pub async fn set_status(
631        &self,
632        model_name: String,
633        status: ModelStatus,
634        provider: ModelProvider,
635    ) {
636        self.set_status_and_notify(model_name, status, provider, None)
637            .await;
638    }
639
640    /// Adds a channel that wants updates on a specific model (server-replica-local).
641    pub fn add_waiting_channel(
642        &self,
643        model_name: &str,
644        tx: tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
645    ) {
646        let mut waiting = match self.waiting_channels.lock() {
647            Ok(guard) => guard,
648            Err(poisoned) => {
649                error!("Waiting channels mutex is poisoned, recovering");
650                poisoned.into_inner()
651            }
652        };
653        waiting.entry(model_name.to_string()).or_default().push(tx);
654    }
655
656    /// Deletes a model record from the registry and clears local waiters.
657    pub async fn delete_status(&self, model_name: &str) {
658        if let Err(e) = self.registry.delete_model(model_name).await {
659            error!("Failed to delete model from registry: {e}");
660        }
661        let mut waiting = match self.waiting_channels.lock() {
662            Ok(guard) => guard,
663            Err(poisoned) => {
664                error!("Waiting channels mutex is poisoned, recovering");
665                poisoned.into_inner()
666            }
667        };
668        waiting.remove(model_name);
669    }
670
671    /// Spawn a background task that actually downloads the model, updating the tracker on
672    /// success or failure. Extracted here so the claim and retry paths share the code.
673    fn spawn_download_task(
674        &self,
675        model_name: String,
676        provider: ModelProvider,
677        ignore_weights: bool,
678        retry: bool,
679    ) {
680        let tracker = self.clone();
681        tokio::spawn(async move {
682            let cache_dir = get_server_cache_dir();
683            match download::download_model(&model_name, provider, cache_dir, ignore_weights).await {
684                Ok(_path) => {
685                    tracker
686                        .set_status_and_notify(
687                            model_name,
688                            ModelStatus::DOWNLOADED,
689                            provider,
690                            Some("Model download completed successfully".to_string()),
691                        )
692                        .await;
693                }
694                Err(e) => {
695                    if retry {
696                        error!("Failed to download model {model_name} on retry: {e}");
697                    } else {
698                        error!("Failed to download model {model_name}: {e}");
699                    }
700                    let msg = if retry {
701                        format!("Download failed on retry: {e}")
702                    } else {
703                        format!("Download failed: {e}")
704                    };
705                    tracker
706                        .set_status_and_notify(model_name, ModelStatus::ERROR, provider, Some(msg))
707                        .await;
708                }
709            }
710        });
711    }
712
713    /// Initiates a download for a model and streams status updates.
714    pub async fn ensure_model_downloaded(
715        &self,
716        model_name: &str,
717        provider: ModelProvider,
718        tx: &tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
719        ignore_weights: bool,
720    ) -> ModelStatus {
721        // Atomically try to claim this model for download. The `ClaimOutcome` tells us
722        // whether THIS replica won the claim or is observing someone else's claim —
723        // status alone (`DOWNLOADING`) can't distinguish those cases across replicas.
724        // A claim may report an existing `DOWNLOADED` record whose files no longer exist
725        // on disk (e.g. after a client-side `model clear` that only removed local files).
726        // When that happens we drop the stale record and re-claim once, so the download
727        // path runs instead of returning a false success. Bounded to two attempts to
728        // avoid looping if the delete or a concurrent re-claim keeps the record around.
729        const MAX_CLAIM_ATTEMPTS: usize = 2;
730        let mut attempt: usize = 0;
731        let (status, is_owner) = loop {
732            attempt = attempt.saturating_add(1);
733            match self
734                .registry
735                .try_claim_for_download(model_name, provider)
736                .await
737            {
738                Ok(ClaimOutcome::Claimed) => break (ModelStatus::DOWNLOADING, true),
739                Ok(ClaimOutcome::AlreadyExists(existing)) => {
740                    if existing == ModelStatus::DOWNLOADED
741                        && attempt < MAX_CLAIM_ATTEMPTS
742                        && !model_files_present(get_server_cache_dir(), model_name, provider).await
743                    {
744                        error!(
745                            "Registry reports model '{model_name}' as DOWNLOADED but its files \
746                             are missing from the cache; clearing the stale record and \
747                             re-downloading"
748                        );
749                        self.delete_status(model_name).await;
750                        continue;
751                    }
752                    break (existing, false);
753                }
754                Err(e) => {
755                    error!("Failed to claim model for download: {e}");
756                    let error_update = ModelStatusUpdate {
757                        model_name: model_name.to_string(),
758                        status: modelexpress_common::grpc::model::ModelStatus::from(
759                            ModelStatus::ERROR,
760                        ) as i32,
761                        message: Some("Registry error occurred".to_string()),
762                        provider: GrpcModelProvider::from(provider) as i32,
763                    };
764                    let _ = tx.send(Ok(error_update)).await;
765                    return ModelStatus::ERROR;
766                }
767            }
768        };
769
770        // If we observed a previous ERROR, attempt the ERROR -> DOWNLOADING CAS up front.
771        // Only the CAS winner spawns the retry download; observers fall through to the
772        // wait loop. Doing this *before* the initial stream update keeps the reported
773        // status honest: after this block, the record is DOWNLOADING (the record may
774        // briefly have been ERROR, but the client should wait, not bail).
775        let (effective_status, is_retry_owner) = if status == ModelStatus::ERROR {
776            let won = match self
777                .registry
778                .try_reset_error_for_retry(model_name, provider)
779                .await
780            {
781                Ok(won) => won,
782                Err(e) => {
783                    error!("Failed to CAS status for retry: {e}");
784                    let _ = tx
785                        .send(Ok(ModelStatusUpdate {
786                            model_name: model_name.to_string(),
787                            status: modelexpress_common::grpc::model::ModelStatus::from(
788                                ModelStatus::ERROR,
789                            ) as i32,
790                            message: Some("Registry error occurred during retry".to_string()),
791                            provider: GrpcModelProvider::from(provider) as i32,
792                        }))
793                        .await;
794                    return ModelStatus::ERROR;
795                }
796            };
797            (ModelStatus::DOWNLOADING, won)
798        } else {
799            (status, false)
800        };
801
802        let update = ModelStatusUpdate {
803            model_name: model_name.to_string(),
804            status: modelexpress_common::grpc::model::ModelStatus::from(effective_status) as i32,
805            message: match (status, effective_status) {
806                (_, ModelStatus::DOWNLOADED) => Some("Model already downloaded".to_string()),
807                (ModelStatus::ERROR, _) => Some("Previous download failed, retrying".to_string()),
808                (_, ModelStatus::DOWNLOADING) => Some("Model download in progress".to_string()),
809                // effective can never be ERROR: ERROR observations are CAS'd above.
810                (_, ModelStatus::ERROR) => Some("Download error".to_string()),
811            },
812            provider: GrpcModelProvider::from(provider) as i32,
813        };
814        let _ = tx.send(Ok(update)).await;
815
816        if effective_status == ModelStatus::DOWNLOADING {
817            // Every caller is a waiter — whether we own the download or not, we still
818            // need a channel so the completion broadcast reaches this stream.
819            self.add_waiting_channel(model_name, tx.clone());
820
821            // Spawn the download only on the replica that won the claim (fresh
822            // download) or won the ERROR-retry CAS. Everyone else waits.
823            if is_owner || is_retry_owner {
824                let retry = status == ModelStatus::ERROR;
825                self.spawn_download_task(model_name.to_string(), provider, ignore_weights, retry);
826            }
827
828            // Wait for completion by polling the registry.
829            loop {
830                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
831                if let Some(current_status) = self.get_status(model_name).await
832                    && current_status != ModelStatus::DOWNLOADING
833                {
834                    return current_status;
835                }
836            }
837        }
838
839        effective_status
840    }
841}
842
/// Global model download tracker. Initialized by `main.rs` after the registry is
/// connected, then read via `model_tracker()` from gRPC service handlers.
///
/// Held in a `OnceLock`, so a value can be stored at most once; `init_model_tracker`
/// is the only writer in this file and `model_tracker` is the reader.
static MODEL_TRACKER: std::sync::OnceLock<Arc<ModelDownloadTracker>> = std::sync::OnceLock::new();
846
847/// Initialize the process-wide tracker. Called exactly once during server startup.
848/// Returns an error if called twice — main.rs propagates it as a startup failure.
849pub fn init_model_tracker(
850    tracker: Arc<ModelDownloadTracker>,
851) -> Result<Arc<ModelDownloadTracker>, &'static str> {
852    MODEL_TRACKER
853        .set(tracker.clone())
854        .map(|()| tracker)
855        .map_err(|_| "init_model_tracker called more than once")
856}
857
/// Read the process-wide tracker. Returns `None` if `init_model_tracker` hasn't run yet,
/// letting gRPC handlers return `Status::unavailable` instead of crashing the server.
///
/// The reference is `'static` because the tracker is stored in the `MODEL_TRACKER`
/// static.
pub fn model_tracker() -> Option<&'static Arc<ModelDownloadTracker>> {
    // `OnceLock::get` yields `None` until a value has been set.
    MODEL_TRACKER.get()
}
863
864#[cfg(test)]
865#[allow(clippy::expect_used)]
866mod tests {
867    use super::*;
868    use modelexpress_common::grpc::{api::ApiRequest, health::HealthRequest};
869    use modelexpress_common::test_support::{EnvVarGuard, acquire_env_mutex};
870    use tempfile::TempDir;
871    use tokio_stream::StreamExt;
872    use tonic::Request;
873
874    #[tokio::test]
875    async fn test_health_service() {
876        let service = HealthServiceImpl;
877        let request = Request::new(HealthRequest {});
878
879        let response = service.get_health(request).await;
880        assert!(response.is_ok());
881
882        let health_response = response.expect("Health response should be ok").into_inner();
883        assert_eq!(health_response.version, env!("CARGO_PKG_VERSION"));
884        assert_eq!(health_response.status, "ok");
885        // uptime is u64, always >= 0, so just verify it exists
886        let _uptime = health_response.uptime;
887    }
888
889    #[tokio::test]
890    async fn test_api_service_ping() {
891        let service = ApiServiceImpl;
892        let request = Request::new(ApiRequest {
893            id: "test-id".to_string(),
894            action: "ping".to_string(),
895            payload: None,
896        });
897
898        let response = service.send_request(request).await;
899        assert!(response.is_ok());
900
901        let api_response = response.expect("API response should be ok").into_inner();
902        assert!(api_response.success);
903        assert!(api_response.data.is_some());
904        assert!(api_response.error.is_none());
905
906        // Check that the response data contains "pong"
907        let data_bytes = api_response.data.expect("Data should be present");
908        let data: serde_json::Value =
909            serde_json::from_slice(&data_bytes).expect("Data should be valid JSON");
910        assert_eq!(data["message"], "pong");
911    }
912
913    #[tokio::test]
914    async fn test_api_service_unknown_action() {
915        let service = ApiServiceImpl;
916        let request = Request::new(ApiRequest {
917            id: "test-id".to_string(),
918            action: "unknown-action".to_string(),
919            payload: None,
920        });
921
922        let response = service.send_request(request).await;
923        assert!(response.is_ok());
924
925        let api_response = response.expect("API response should be ok").into_inner();
926        assert!(!api_response.success);
927        assert!(api_response.data.is_none());
928        assert!(api_response.error.is_some());
929
930        let error_message = api_response.error.expect("Error should be present");
931        assert!(error_message.contains("Unknown action"));
932    }
933
934    // Tracker tests exercise the ModelDownloadTracker's interaction with a mocked
935    // RegistryBackend. The full backend semantics (claim atomicity, LRU ordering, etc.) are
936    // covered by the per-backend unit tests in modelexpress_server::registry and by the
937    // testcontainers-based integration tests.
938    fn tracker_with_mock(
939        mock: crate::registry::backend::MockRegistryBackend,
940    ) -> ModelDownloadTracker {
941        let registry = Arc::new(RegistryManager::with_backend(Arc::new(mock)));
942        ModelDownloadTracker::new(registry)
943    }
944
945    #[tokio::test]
946    async fn test_tracker_get_status_missing_returns_none() {
947        let mut mock = crate::registry::backend::MockRegistryBackend::new();
948        mock.expect_get_status().once().returning(|_| Ok(None));
949        // touch is NOT called when status is missing
950        let tracker = tracker_with_mock(mock);
951        assert!(tracker.get_status("unknown").await.is_none());
952    }
953
954    #[tokio::test]
955    async fn test_tracker_get_status_hit_bumps_last_used_at() {
956        let mut mock = crate::registry::backend::MockRegistryBackend::new();
957        mock.expect_get_status()
958            .once()
959            .returning(|_| Ok(Some(ModelStatus::DOWNLOADED)));
960        mock.expect_touch_model().once().returning(|_| Ok(()));
961        let tracker = tracker_with_mock(mock);
962        assert_eq!(tracker.get_status("m").await, Some(ModelStatus::DOWNLOADED));
963    }
964
965    #[tokio::test]
966    async fn test_tracker_set_status_notifies_waiting_channel() {
967        let mut mock = crate::registry::backend::MockRegistryBackend::new();
968        mock.expect_set_status()
969            .once()
970            .returning(|_, _, _, _| Ok(()));
971        let tracker = tracker_with_mock(mock);
972
973        let (tx, mut rx) = tokio::sync::mpsc::channel(4);
974        tracker.add_waiting_channel("m", tx);
975
976        tracker
977            .set_status_and_notify(
978                "m".to_string(),
979                ModelStatus::DOWNLOADED,
980                ModelProvider::HuggingFace,
981                Some("done".to_string()),
982            )
983            .await;
984
985        let update = rx.recv().await.expect("waiter should receive update");
986        let update = update.expect("notify should send Ok");
987        assert_eq!(update.model_name, "m");
988        assert_eq!(
989            update.status,
990            modelexpress_common::grpc::model::ModelStatus::Downloaded as i32
991        );
992        assert_eq!(update.message.as_deref(), Some("done"));
993
994        // Terminal status removes waiters.
995        let waiters = tracker
996            .waiting_channels
997            .lock()
998            .expect("waiters lock")
999            .get("m")
1000            .map_or(0, std::vec::Vec::len);
1001        assert_eq!(waiters, 0);
1002    }
1003
1004    #[tokio::test]
1005    async fn test_tracker_delete_status_clears_backend_and_waiters() {
1006        let mut mock = crate::registry::backend::MockRegistryBackend::new();
1007        mock.expect_delete_model().once().returning(|_| Ok(()));
1008        let tracker = tracker_with_mock(mock);
1009
1010        let (tx, _rx) = tokio::sync::mpsc::channel(1);
1011        tracker.add_waiting_channel("m", tx);
1012        tracker.delete_status("m").await;
1013
1014        let waiters = tracker
1015            .waiting_channels
1016            .lock()
1017            .expect("waiters lock")
1018            .contains_key("m");
1019        assert!(!waiters);
1020    }
1021
1022    #[tokio::test]
1023    async fn test_tracker_set_status_delegates_without_message() {
1024        let mut mock = crate::registry::backend::MockRegistryBackend::new();
1025        mock.expect_set_status()
1026            .withf(|_, _, status, msg| *status == ModelStatus::DOWNLOADING && msg.is_none())
1027            .once()
1028            .returning(|_, _, _, _| Ok(()));
1029        let tracker = tracker_with_mock(mock);
1030        tracker
1031            .set_status(
1032                "m".to_string(),
1033                ModelStatus::DOWNLOADING,
1034                ModelProvider::HuggingFace,
1035            )
1036            .await;
1037    }
1038
1039    #[tokio::test]
1040    async fn test_tracker_error_status_clears_waiters() {
1041        let mut mock = crate::registry::backend::MockRegistryBackend::new();
1042        mock.expect_set_status()
1043            .once()
1044            .returning(|_, _, _, _| Ok(()));
1045        let tracker = tracker_with_mock(mock);
1046        let (tx, _rx) = tokio::sync::mpsc::channel(1);
1047        tracker.add_waiting_channel("m", tx);
1048        tracker
1049            .set_status_and_notify(
1050                "m".to_string(),
1051                ModelStatus::ERROR,
1052                ModelProvider::HuggingFace,
1053                Some("fail".to_string()),
1054            )
1055            .await;
1056        let waiters = tracker
1057            .waiting_channels
1058            .lock()
1059            .expect("waiters lock")
1060            .get("m")
1061            .map_or(0, std::vec::Vec::len);
1062        assert_eq!(waiters, 0, "ERROR is terminal, waiters must be cleared");
1063    }
1064
1065    #[tokio::test]
1066    async fn test_tracker_downloading_status_keeps_waiters() {
1067        let mut mock = crate::registry::backend::MockRegistryBackend::new();
1068        mock.expect_set_status()
1069            .once()
1070            .returning(|_, _, _, _| Ok(()));
1071        let tracker = tracker_with_mock(mock);
1072        let (tx, _rx) = tokio::sync::mpsc::channel(1);
1073        tracker.add_waiting_channel("m", tx);
1074        tracker
1075            .set_status_and_notify(
1076                "m".to_string(),
1077                ModelStatus::DOWNLOADING,
1078                ModelProvider::HuggingFace,
1079                None,
1080            )
1081            .await;
1082        let waiters = tracker
1083            .waiting_channels
1084            .lock()
1085            .expect("waiters lock")
1086            .get("m")
1087            .map_or(0, std::vec::Vec::len);
1088        assert_eq!(
1089            waiters, 1,
1090            "DOWNLOADING is non-terminal, waiter must remain"
1091        );
1092    }
1093
1094    #[tokio::test]
1095    async fn test_tracker_set_status_swallows_backend_error() {
1096        let mut mock = crate::registry::backend::MockRegistryBackend::new();
1097        mock.expect_set_status()
1098            .once()
1099            .returning(|_, _, _, _| Err("redis down".into()));
1100        let tracker = tracker_with_mock(mock);
1101        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
1102        tracker.add_waiting_channel("m", tx);
1103        // Error is logged but set_status_and_notify returns ()
1104        tracker
1105            .set_status_and_notify(
1106                "m".to_string(),
1107                ModelStatus::DOWNLOADED,
1108                ModelProvider::HuggingFace,
1109                None,
1110            )
1111            .await;
1112        // Nothing should be notified on the channel because set_status failed early.
1113        assert!(
1114            rx.try_recv().is_err(),
1115            "waiter shouldn't receive on backend error"
1116        );
1117    }
1118
1119    #[tokio::test]
1120    async fn test_model_tracker_uninitialized_returns_none() {
1121        // The process-wide MODEL_TRACKER hasn't been init'd in tests, so
1122        // model_tracker() returns None — the service layer uses this to respond with
1123        // Status::unavailable rather than panic.
1124        assert!(model_tracker().is_none());
1125    }
1126
1127    #[test]
1128    fn test_collect_model_files_empty_dir() {
1129        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1130        let files = collect_model_files(temp_dir.path(), temp_dir.path(), None);
1131        assert!(files.is_empty());
1132    }
1133
1134    #[test]
1135    fn test_collect_model_files_with_files() {
1136        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1137
1138        // Create some test files
1139        let file1_path = temp_dir.path().join("config.json");
1140        std::fs::write(&file1_path, r#"{"test": "data"}"#).expect("Failed to write file1");
1141
1142        let file2_path = temp_dir.path().join("model.bin");
1143        std::fs::write(&file2_path, vec![0u8; 100]).expect("Failed to write file2");
1144
1145        let files = collect_model_files(temp_dir.path(), temp_dir.path(), None);
1146
1147        assert_eq!(files.len(), 2);
1148
1149        // Check file sizes
1150        let total_size: u64 = files.iter().map(|(_, size)| size).sum();
1151        assert!(total_size > 0);
1152
1153        // Check that relative paths are correct
1154        let paths: Vec<_> = files
1155            .iter()
1156            .map(|(p, _)| p.to_string_lossy().to_string())
1157            .collect();
1158        assert!(paths.contains(&"config.json".to_string()));
1159        assert!(paths.contains(&"model.bin".to_string()));
1160    }
1161
1162    #[test]
1163    fn test_collect_model_files_nested() {
1164        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1165
1166        // Create nested directory structure
1167        let subdir = temp_dir.path().join("subdir");
1168        std::fs::create_dir(&subdir).expect("Failed to create subdir");
1169
1170        let file1_path = temp_dir.path().join("root_file.txt");
1171        std::fs::write(&file1_path, "root content").expect("Failed to write file1");
1172
1173        let file2_path = subdir.join("nested_file.txt");
1174        std::fs::write(&file2_path, "nested content").expect("Failed to write file2");
1175
1176        let files = collect_model_files(temp_dir.path(), temp_dir.path(), None);
1177
1178        assert_eq!(files.len(), 2);
1179
1180        // Check that nested path is correct
1181        let paths: Vec<_> = files
1182            .iter()
1183            .map(|(p, _)| p.to_string_lossy().to_string())
1184            .collect();
1185        assert!(paths.iter().any(|p| p.contains("nested_file")));
1186    }
1187
1188    #[test]
1189    fn test_collect_model_files_with_selector_filters_exact_paths() {
1190        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1191        let subdir = temp_dir.path().join("subdir");
1192        std::fs::create_dir(&subdir).expect("Failed to create subdir");
1193        std::fs::write(temp_dir.path().join("config.json"), "{}").expect("Failed to write config");
1194        std::fs::write(temp_dir.path().join("model.bin"), vec![0u8; 100])
1195            .expect("Failed to write model");
1196        std::fs::write(temp_dir.path().join("ignored.txt"), "ignore")
1197            .expect("Failed to write ignored");
1198        std::fs::write(subdir.join("nested.txt"), "nested").expect("Failed to write nested");
1199
1200        let selector = ModelFileSelector {
1201            paths: vec!["config.json".to_string(), "subdir/nested.txt".to_string()],
1202        };
1203        let files = collect_model_files(temp_dir.path(), temp_dir.path(), Some(&selector));
1204
1205        let mut paths: Vec<_> = files
1206            .iter()
1207            .map(|(p, _)| p.to_string_lossy().to_string())
1208            .collect();
1209        paths.sort();
1210        assert_eq!(
1211            paths,
1212            vec!["config.json".to_string(), "subdir/nested.txt".to_string()]
1213        );
1214    }
1215
1216    #[test]
1217    fn test_collect_model_files_with_selector_empty_and_nonmatching_paths() {
1218        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1219        std::fs::write(temp_dir.path().join("config.json"), "{}").expect("Failed to write config");
1220
1221        let empty_selector = ModelFileSelector { paths: vec![] };
1222        assert!(
1223            collect_model_files(temp_dir.path(), temp_dir.path(), Some(&empty_selector)).is_empty()
1224        );
1225
1226        let nonmatching_selector = ModelFileSelector {
1227            paths: vec!["missing.json".to_string(), "../config.json".to_string()],
1228        };
1229        assert!(
1230            collect_model_files(
1231                temp_dir.path(),
1232                temp_dir.path(),
1233                Some(&nonmatching_selector)
1234            )
1235            .is_empty()
1236        );
1237    }
1238
1239    #[tokio::test]
1240    #[allow(clippy::await_holding_lock)]
1241    async fn test_list_model_files_hf_honors_file_selector() {
1242        let env_lock = acquire_env_mutex();
1243        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1244        let _cache_dir_guard = EnvVarGuard::set(
1245            &env_lock,
1246            "MODEL_EXPRESS_CACHE_DIRECTORY",
1247            temp_dir.path().to_str().expect("Expected temp dir path"),
1248        );
1249        let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1250
1251        let model_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
1252        std::fs::create_dir_all(model_dir.join("subdir")).expect("Failed to create model dir");
1253        std::fs::write(model_dir.join("config.json"), br#"{"model":"test"}"#)
1254            .expect("Failed to write config");
1255        std::fs::write(model_dir.join("model.bin"), vec![0u8; 100]).expect("Failed to write model");
1256        std::fs::write(model_dir.join("subdir/nested.txt"), b"nested")
1257            .expect("Failed to write nested");
1258
1259        let service = ModelServiceImpl;
1260        let request = Request::new(ModelFilesRequest {
1261            model_name: "test/model".to_string(),
1262            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1263            chunk_size: 0,
1264            file_selector: Some(ModelFileSelector {
1265                paths: vec!["config.json".to_string(), "subdir/nested.txt".to_string()],
1266            }),
1267        });
1268
1269        let response = service
1270            .list_model_files(request)
1271            .await
1272            .expect("Expected file list")
1273            .into_inner();
1274        let mut paths: Vec<_> = response
1275            .files
1276            .iter()
1277            .map(|file| file.relative_path.clone())
1278            .collect();
1279        paths.sort();
1280
1281        assert_eq!(
1282            paths,
1283            vec!["config.json".to_string(), "subdir/nested.txt".to_string()]
1284        );
1285        assert_eq!(
1286            response.total_size,
1287            br#"{"model":"test"}"#.len() as u64 + b"nested".len() as u64
1288        );
1289    }
1290
1291    #[tokio::test]
1292    #[allow(clippy::await_holding_lock)]
1293    async fn test_model_files_present_reflects_disk_state() {
1294        let env_lock = acquire_env_mutex();
1295        let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1296        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1297        let cache_dir = temp_dir.path().to_path_buf();
1298
1299        // No files on disk: a stale DOWNLOADED record must not be honored.
1300        assert!(
1301            !model_files_present(
1302                Some(cache_dir.clone()),
1303                "test/model",
1304                ModelProvider::HuggingFace
1305            )
1306            .await
1307        );
1308
1309        // Once the snapshot exists, the cache hit is real.
1310        let model_dir = cache_dir.join("models--test--model/snapshots/abc123");
1311        std::fs::create_dir_all(&model_dir).expect("Failed to create model dir");
1312        std::fs::write(model_dir.join("config.json"), b"{}").expect("Failed to write config");
1313        assert!(
1314            model_files_present(Some(cache_dir), "test/model", ModelProvider::HuggingFace).await
1315        );
1316    }
1317
1318    #[tokio::test]
1319    async fn test_model_files_present_assumes_present_without_cache_dir() {
1320        // With no configured cache directory we cannot verify, so we must not force a
1321        // re-download loop: assume the files are present.
1322        assert!(model_files_present(None, "test/model", ModelProvider::HuggingFace).await);
1323    }
1324
1325    #[tokio::test]
1326    #[allow(clippy::await_holding_lock)]
1327    async fn test_stream_model_files_hf_honors_file_selector() {
1328        let env_lock = acquire_env_mutex();
1329        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1330        let _cache_dir_guard = EnvVarGuard::set(
1331            &env_lock,
1332            "MODEL_EXPRESS_CACHE_DIRECTORY",
1333            temp_dir.path().to_str().expect("Expected temp dir path"),
1334        );
1335        let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1336
1337        let model_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
1338        std::fs::create_dir_all(&model_dir).expect("Failed to create model dir");
1339        std::fs::write(model_dir.join("config.json"), br#"{"model":"test"}"#)
1340            .expect("Failed to write config");
1341        std::fs::write(model_dir.join("model.bin"), vec![0u8; 100]).expect("Failed to write model");
1342
1343        let service = ModelServiceImpl;
1344        let request = Request::new(ModelFilesRequest {
1345            model_name: "test/model".to_string(),
1346            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1347            chunk_size: 1024,
1348            file_selector: Some(ModelFileSelector {
1349                paths: vec!["config.json".to_string()],
1350            }),
1351        });
1352
1353        let response = service
1354            .stream_model_files(request)
1355            .await
1356            .expect("Expected stream response");
1357        let chunks: Vec<_> = response
1358            .into_inner()
1359            .map(|chunk| chunk.expect("Expected chunk"))
1360            .collect()
1361            .await;
1362
1363        assert_eq!(chunks.len(), 1);
1364        assert_eq!(chunks[0].relative_path, "config.json");
1365        assert_eq!(chunks[0].commit_hash.as_deref(), Some("abc123"));
1366        assert!(chunks[0].is_last_file);
1367    }
1368
1369    #[tokio::test]
1370    #[allow(clippy::await_holding_lock)]
1371    async fn test_stream_model_files_hf_returns_not_found_for_missing_selector_path() {
1372        let env_lock = acquire_env_mutex();
1373        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1374        let _cache_dir_guard = EnvVarGuard::set(
1375            &env_lock,
1376            "MODEL_EXPRESS_CACHE_DIRECTORY",
1377            temp_dir.path().to_str().expect("Expected temp dir path"),
1378        );
1379        let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1380
1381        let model_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
1382        std::fs::create_dir_all(&model_dir).expect("Failed to create model dir");
1383        std::fs::write(model_dir.join("config.json"), br#"{"model":"test"}"#)
1384            .expect("Failed to write config");
1385
1386        let service = ModelServiceImpl;
1387        let request = Request::new(ModelFilesRequest {
1388            model_name: "test/model".to_string(),
1389            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1390            chunk_size: 1024,
1391            file_selector: Some(ModelFileSelector {
1392                paths: vec!["config.json".to_string(), "missing.json".to_string()],
1393            }),
1394        });
1395
1396        let result = service.stream_model_files(request).await;
1397        let status = result.expect_err("Expected not found");
1398        assert_eq!(status.code(), tonic::Code::NotFound);
1399        assert_eq!(
1400            status.message(),
1401            "Selected file not found in model directory: missing.json"
1402        );
1403    }
1404
1405    #[tokio::test]
1406    async fn test_list_model_files_not_found() {
1407        let service = ModelServiceImpl;
1408
1409        let request = Request::new(ModelFilesRequest {
1410            model_name: "non-existent-model-12345".to_string(),
1411            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1412            chunk_size: 0,
1413            file_selector: None,
1414        });
1415
1416        let result = service.list_model_files(request).await;
1417        assert!(result.is_err());
1418        let status = result.expect_err("Should return error");
1419        assert_eq!(status.code(), tonic::Code::NotFound);
1420    }
1421
1422    #[tokio::test]
1423    async fn test_stream_model_files_not_found() {
1424        let service = ModelServiceImpl;
1425
1426        let request = Request::new(ModelFilesRequest {
1427            model_name: "non-existent-model-12345".to_string(),
1428            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1429            chunk_size: 1024,
1430            file_selector: None,
1431        });
1432
1433        let result = service.stream_model_files(request).await;
1434        assert!(result.is_err());
1435        let status = result.expect_err("Should return error");
1436        assert_eq!(status.code(), tonic::Code::NotFound);
1437    }
1438
1439    #[tokio::test]
1440    async fn test_ensure_model_downloaded_rejects_invalid_provider() {
1441        let service = ModelServiceImpl;
1442
1443        let request = Request::new(ModelDownloadRequest {
1444            model_name: "test/model".to_string(),
1445            provider: 99,
1446            ignore_weights: false,
1447        });
1448
1449        let result = service.ensure_model_downloaded(request).await;
1450        assert!(result.is_err());
1451        let status = result.expect_err("Should return error");
1452        assert_eq!(status.code(), tonic::Code::InvalidArgument);
1453        assert!(status.message().contains("Invalid provider value"));
1454    }
1455
1456    #[tokio::test]
1457    async fn test_list_model_files_rejects_invalid_provider() {
1458        let service = ModelServiceImpl;
1459
1460        let request = Request::new(ModelFilesRequest {
1461            model_name: "test/model".to_string(),
1462            provider: 99,
1463            chunk_size: 0,
1464            file_selector: None,
1465        });
1466
1467        let result = service.list_model_files(request).await;
1468        assert!(result.is_err());
1469        let status = result.expect_err("Should return error");
1470        assert_eq!(status.code(), tonic::Code::InvalidArgument);
1471        assert!(status.message().contains("Invalid provider value"));
1472    }
1473
1474    #[tokio::test]
1475    async fn test_stream_model_files_rejects_invalid_provider() {
1476        let service = ModelServiceImpl;
1477
1478        let request = Request::new(ModelFilesRequest {
1479            model_name: "test/model".to_string(),
1480            provider: 99,
1481            chunk_size: 1024,
1482            file_selector: None,
1483        });
1484
1485        let result = service.stream_model_files(request).await;
1486        assert!(result.is_err());
1487        let status = result.expect_err("Should return error");
1488        assert_eq!(status.code(), tonic::Code::InvalidArgument);
1489        assert!(status.message().contains("Invalid provider value"));
1490    }
1491
1492    #[tokio::test]
1493    #[allow(clippy::await_holding_lock)]
1494    async fn test_stream_model_files_hf_first_chunk_includes_commit_hash() {
1495        let env_lock = acquire_env_mutex();
1496        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1497        let _cache_dir_guard = EnvVarGuard::set(
1498            &env_lock,
1499            "MODEL_EXPRESS_CACHE_DIRECTORY",
1500            temp_dir.path().to_str().expect("Expected temp dir path"),
1501        );
1502        let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1503
1504        let model_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
1505        std::fs::create_dir_all(&model_dir).expect("Failed to create model dir");
1506        std::fs::write(model_dir.join("config.json"), br#"{"model":"test"}"#)
1507            .expect("Failed to write model file");
1508
1509        let service = ModelServiceImpl;
1510        let request = Request::new(ModelFilesRequest {
1511            model_name: "test/model".to_string(),
1512            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1513            chunk_size: 1024,
1514            file_selector: None,
1515        });
1516
1517        let response = service
1518            .stream_model_files(request)
1519            .await
1520            .expect("Expected stream response");
1521        let mut stream = response.into_inner();
1522        let first_chunk = stream
1523            .next()
1524            .await
1525            .expect("Expected stream item")
1526            .expect("Expected first chunk");
1527
1528        assert_eq!(first_chunk.relative_path, "config.json");
1529        assert_eq!(first_chunk.commit_hash.as_deref(), Some("abc123"));
1530        assert!(first_chunk.is_last_chunk);
1531        assert!(first_chunk.is_last_file);
1532    }
1533
1534    #[tokio::test]
1535    #[allow(clippy::await_holding_lock)]
1536    async fn test_stream_model_files_hf_emits_chunk_for_zero_byte_file() {
1537        let env_lock = acquire_env_mutex();
1538        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1539        let _cache_dir_guard = EnvVarGuard::set(
1540            &env_lock,
1541            "MODEL_EXPRESS_CACHE_DIRECTORY",
1542            temp_dir.path().to_str().expect("Expected temp dir path"),
1543        );
1544        let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1545
1546        let model_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
1547        std::fs::create_dir_all(&model_dir).expect("Failed to create model dir");
1548        std::fs::write(model_dir.join("empty.bin"), []).expect("Failed to write empty file");
1549
1550        let service = ModelServiceImpl;
1551        let request = Request::new(ModelFilesRequest {
1552            model_name: "test/model".to_string(),
1553            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1554            chunk_size: 1024,
1555            file_selector: None,
1556        });
1557
1558        let response = service
1559            .stream_model_files(request)
1560            .await
1561            .expect("Expected stream response");
1562        let mut stream = response.into_inner();
1563        let first_chunk = stream
1564            .next()
1565            .await
1566            .expect("Expected stream item")
1567            .expect("Expected first chunk");
1568
1569        assert_eq!(first_chunk.relative_path, "empty.bin");
1570        assert_eq!(first_chunk.total_size, 0);
1571        assert_eq!(first_chunk.data.len(), 0);
1572        assert_eq!(first_chunk.offset, 0);
1573        assert_eq!(first_chunk.commit_hash.as_deref(), Some("abc123"));
1574        assert!(first_chunk.is_last_chunk);
1575        assert!(first_chunk.is_last_file);
1576    }
1577}