Skip to main content

modelexpress_server/
services.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::database::ModelDatabase;
5use modelexpress_common::{
6    cache::CacheConfig,
7    constants, download,
8    grpc::{
9        api::{ApiRequest, ApiResponse, api_service_server::ApiService},
10        health::{HealthRequest, HealthResponse, health_service_server::HealthService},
11        model::{
12            FileChunk, ModelDownloadRequest, ModelFileInfo, ModelFileList, ModelFilesRequest,
13            ModelStatusUpdate, model_service_server::ModelService,
14        },
15    },
16    models::{ModelProvider, ModelStatus},
17};
18use std::{
19    collections::HashMap,
20    path::{Path, PathBuf},
21    sync::{Arc, Mutex},
22    time::SystemTime,
23};
24use tokio::io::AsyncReadExt;
25use tokio_stream::wrappers::ReceiverStream;
26use tonic::{Request, Response, Status};
27use tracing::{debug, error, info};
28
/// Reference instant from which the health endpoint computes the reported uptime.
///
/// The value is set by the first `get_or_init` call. In the code visible here
/// that call happens inside `HealthServiceImpl::get_health`, so unless it is
/// initialized elsewhere at startup, uptime is measured from the first health
/// check rather than from process start.
static START_TIME: std::sync::OnceLock<SystemTime> = std::sync::OnceLock::new();
30
31/// Get the configured cache directory for model downloads
32fn get_server_cache_dir() -> Option<std::path::PathBuf> {
33    // Try to get cache configuration
34    if let Ok(config) = CacheConfig::discover() {
35        Some(config.local_path)
36    } else {
37        // Fall back to environment variable
38        std::env::var("HF_HUB_CACHE")
39            .ok()
40            .map(std::path::PathBuf::from)
41    }
42}
43
44/// Convert gRPC provider to internal ModelProvider enum
45///
46/// Falls back to HuggingFace provider if the conversion fails or an invalid
47/// provider value is provided. A warning is logged when fallback occurs.
48fn convert_provider(grpc_provider: i32) -> ModelProvider {
49    match modelexpress_common::grpc::model::ModelProvider::try_from(grpc_provider) {
50        Ok(provider) => provider.into(),
51        Err(_) => {
52            tracing::warn!(
53                "Invalid provider value {}, falling back to HuggingFace",
54                grpc_provider
55            );
56            ModelProvider::HuggingFace
57        }
58    }
59}
60
61/// Health service implementation
62#[derive(Debug, Default)]
63pub struct HealthServiceImpl;
64
65#[tonic::async_trait]
66impl HealthService for HealthServiceImpl {
67    async fn get_health(
68        &self,
69        _request: Request<HealthRequest>,
70    ) -> Result<Response<HealthResponse>, Status> {
71        let start_time = START_TIME.get_or_init(SystemTime::now);
72        let uptime = SystemTime::now()
73            .duration_since(*start_time)
74            .unwrap_or_default()
75            .as_secs();
76
77        let response = HealthResponse {
78            version: env!("CARGO_PKG_VERSION").to_string(),
79            status: "ok".to_string(),
80            uptime,
81        };
82
83        Ok(Response::new(response))
84    }
85}
86
87/// API service implementation
88#[derive(Debug, Default)]
89pub struct ApiServiceImpl;
90
91#[tonic::async_trait]
92impl ApiService for ApiServiceImpl {
93    async fn send_request(
94        &self,
95        request: Request<ApiRequest>,
96    ) -> Result<Response<ApiResponse>, Status> {
97        let api_request = request.into_inner();
98        info!("Received gRPC request: {:?}", api_request);
99
100        // Process the request based on the action
101        if api_request.action.as_str() == "ping" {
102            info!("Processing ping request");
103            let response_data = serde_json::json!({ "message": "pong" });
104            let data_bytes = serde_json::to_vec(&response_data)
105                .map_err(|e| Status::internal(format!("Serialization error: {e}")))?;
106
107            Ok(Response::new(ApiResponse {
108                success: true,
109                data: Some(data_bytes),
110                error: None,
111            }))
112        } else {
113            error!("Unknown action: {}", api_request.action);
114            Ok(Response::new(ApiResponse {
115                success: false,
116                data: None,
117                error: Some(format!("Unknown action: {}", api_request.action)),
118            }))
119        }
120    }
121}
122
/// Model service implementation.
///
/// A stateless unit struct: the `ModelService` methods implemented for it in
/// this file keep their state in the global `MODEL_TRACKER` and on disk.
#[derive(Debug, Default)]
pub struct ModelServiceImpl;
126
127/// Helper function to collect all files in a model directory recursively
128fn collect_model_files(base_path: &Path, current_path: &Path) -> Vec<(PathBuf, u64)> {
129    let mut files = Vec::new();
130
131    if let Ok(entries) = std::fs::read_dir(current_path) {
132        for entry in entries.flatten() {
133            let path = entry.path();
134            if path.is_file() {
135                if let Ok(metadata) = std::fs::metadata(&path) {
136                    // Get relative path from base_path
137                    if let Ok(relative) = path.strip_prefix(base_path) {
138                        // Validate that the relative path does not contain any '..' components or is absolute
139                        let mut is_safe = true;
140                        for comp in relative.components() {
141                            use std::path::Component;
142                            match comp {
143                                Component::ParentDir
144                                | Component::RootDir
145                                | Component::Prefix(_) => {
146                                    is_safe = false;
147                                    break;
148                                }
149                                _ => {}
150                            }
151                        }
152                        if is_safe {
153                            files.push((relative.to_path_buf(), metadata.len()));
154                        } else {
155                            tracing::warn!(
156                                "Skipping potentially unsafe file path: {:?} (relative: {:?})",
157                                path,
158                                relative
159                            );
160                        }
161                    }
162                }
163            } else if path.is_dir() {
164                files.extend(collect_model_files(base_path, &path));
165            }
166        }
167    }
168
169    files
170}
171
172#[tonic::async_trait]
173impl ModelService for ModelServiceImpl {
174    type EnsureModelDownloadedStream = ReceiverStream<Result<ModelStatusUpdate, Status>>;
175    type StreamModelFilesStream = ReceiverStream<Result<FileChunk, Status>>;
176
177    async fn ensure_model_downloaded(
178        &self,
179        request: Request<ModelDownloadRequest>,
180    ) -> Result<Response<Self::EnsureModelDownloadedStream>, Status> {
181        info!("Starting model download stream");
182        let model_request = request.into_inner();
183        let (tx, rx) = tokio::sync::mpsc::channel(4);
184        let model_name = model_request.model_name.clone();
185
186        // Convert gRPC provider to our enum
187        let provider: ModelProvider =
188            modelexpress_common::grpc::model::ModelProvider::try_from(model_request.provider)
189                .unwrap_or(modelexpress_common::grpc::model::ModelProvider::HuggingFace)
190                .into();
191        let ignore_weights = model_request.ignore_weights;
192
193        // Spawn a task to handle the streaming download updates
194        tokio::spawn(async move {
195            // Check if the model is already downloaded
196            if let Some(status) = MODEL_TRACKER.get_status(&model_name) {
197                let update = ModelStatusUpdate {
198                    model_name: model_name.clone(),
199                    status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
200                    message: match status {
201                        ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
202                        ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
203                        ModelStatus::ERROR => {
204                            Some("Previous download failed - retrying".to_string())
205                        }
206                    },
207                    provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
208                        as i32,
209                };
210
211                if tx.send(Ok(update)).await.is_err() {
212                    return; // Client disconnected
213                }
214
215                // If already downloaded, we're done
216                if status == ModelStatus::DOWNLOADED {
217                    return;
218                }
219            }
220
221            // Start or monitor the download process
222            let final_status = MODEL_TRACKER
223                .ensure_model_downloaded(&model_name, provider, &tx, ignore_weights)
224                .await;
225
226            // Send final status update
227            let final_update = ModelStatusUpdate {
228                model_name: model_name.clone(),
229                status: modelexpress_common::grpc::model::ModelStatus::from(final_status) as i32,
230                message: match final_status {
231                    ModelStatus::DOWNLOADED => {
232                        Some("Model download completed successfully".to_string())
233                    }
234                    ModelStatus::ERROR => Some("Model download failed".to_string()),
235                    ModelStatus::DOWNLOADING => Some("Download still in progress".to_string()),
236                },
237                provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
238            };
239
240            let _ = tx.send(Ok(final_update)).await;
241        });
242
243        Ok(Response::new(ReceiverStream::new(rx)))
244    }
245
246    async fn stream_model_files(
247        &self,
248        request: Request<ModelFilesRequest>,
249    ) -> Result<Response<Self::StreamModelFilesStream>, Status> {
250        let files_request = request.into_inner();
251        let model_name = files_request.model_name.clone();
252        let chunk_size = if files_request.chunk_size == 0 {
253            constants::DEFAULT_TRANSFER_CHUNK_SIZE
254        } else {
255            files_request.chunk_size as usize
256        };
257
258        // Convert gRPC provider to our enum
259        let provider = convert_provider(files_request.provider);
260
261        info!(
262            "Starting file stream for model: {} with chunk size: {} bytes",
263            model_name, chunk_size
264        );
265
266        // Get the cache directory
267        let cache_dir = get_server_cache_dir()
268            .ok_or_else(|| Status::internal("Server cache directory not configured"))?;
269
270        // Get the model path using the provider from the request
271        let provider_impl = download::get_provider(provider);
272        let model_path = provider_impl
273            .get_model_path(&model_name, cache_dir.clone())
274            .await
275            .map_err(|e| Status::not_found(format!("Model not found: {e}")))?;
276
277        debug!("Model path resolved to: {:?}", model_path);
278
279        let commit_hash = if provider == ModelProvider::HuggingFace {
280            model_path
281                .file_name()
282                .and_then(|name| name.to_str())
283                .map(String::from)
284        } else {
285            None
286        };
287
288        if provider == ModelProvider::HuggingFace && commit_hash.is_none() {
289            return Err(Status::internal(
290                "Resolved Hugging Face model path did not contain a revision",
291            ));
292        }
293
294        // Collect all files to stream
295        let files = collect_model_files(&model_path, &model_path);
296
297        if files.is_empty() {
298            return Err(Status::not_found("No files found in model directory"));
299        }
300
301        let total_files = files.len();
302        info!(
303            "Found {} files to stream for model {}",
304            total_files, model_name
305        );
306
307        let (tx, rx) = tokio::sync::mpsc::channel(16);
308
309        // Spawn a task to stream files
310        tokio::spawn(async move {
311            // Allocate buffer once and reuse across all files
312            let mut buffer = vec![0u8; chunk_size];
313            let mut is_first_chunk = true;
314
315            for (file_idx, (relative_path, total_size)) in files.iter().enumerate() {
316                let file_path = model_path.join(relative_path);
317                let is_last_file = file_idx == total_files.saturating_sub(1);
318
319                debug!("Streaming file: {:?} ({} bytes)", relative_path, total_size);
320
321                // Open the file
322                let file = match tokio::fs::File::open(&file_path).await {
323                    Ok(f) => f,
324                    Err(e) => {
325                        error!("Failed to open file {:?}: {}", file_path, e);
326                        let _ = tx
327                            .send(Err(Status::internal(format!("Failed to open file: {e}"))))
328                            .await;
329                        return;
330                    }
331                };
332
333                let mut reader = tokio::io::BufReader::new(file);
334                let mut offset: u64 = 0;
335
336                loop {
337                    let bytes_read = match reader.read(&mut buffer).await {
338                        Ok(0) => break, // EOF
339                        Ok(n) => n,
340                        Err(e) => {
341                            error!("Failed to read file {:?}: {}", file_path, e);
342                            let _ = tx
343                                .send(Err(Status::internal(format!("Failed to read file: {e}"))))
344                                .await;
345                            return;
346                        }
347                    };
348
349                    let is_last_chunk = offset.saturating_add(bytes_read as u64) >= *total_size;
350
351                    let first_chunk = std::mem::replace(&mut is_first_chunk, false);
352
353                    let chunk = FileChunk {
354                        relative_path: relative_path.to_string_lossy().to_string(),
355                        data: buffer[..bytes_read].to_vec(),
356                        offset,
357                        total_size: *total_size,
358                        is_last_chunk,
359                        is_last_file: is_last_file && is_last_chunk,
360                        commit_hash: if first_chunk {
361                            commit_hash.clone()
362                        } else {
363                            None
364                        },
365                    };
366
367                    if tx.send(Ok(chunk)).await.is_err() {
368                        debug!("Client disconnected during file stream");
369                        return;
370                    }
371
372                    offset = offset.saturating_add(bytes_read as u64);
373                }
374            }
375
376            info!("File streaming completed for model");
377        });
378
379        Ok(Response::new(ReceiverStream::new(rx)))
380    }
381
382    async fn list_model_files(
383        &self,
384        request: Request<ModelFilesRequest>,
385    ) -> Result<Response<ModelFileList>, Status> {
386        let files_request = request.into_inner();
387        let model_name = files_request.model_name.clone();
388
389        // Convert gRPC provider to our enum
390        let provider = convert_provider(files_request.provider);
391
392        info!("Listing files for model: {}", model_name);
393
394        // Get the cache directory
395        let cache_dir = get_server_cache_dir()
396            .ok_or_else(|| Status::internal("Server cache directory not configured"))?;
397
398        // Get the model path using the provider from the request
399        let provider_impl = download::get_provider(provider);
400        let model_path = provider_impl
401            .get_model_path(&model_name, cache_dir)
402            .await
403            .map_err(|e| Status::not_found(format!("Model not found: {e}")))?;
404
405        // Collect all files
406        let files = collect_model_files(&model_path, &model_path);
407
408        let file_infos: Vec<ModelFileInfo> = files
409            .iter()
410            .map(|(path, size)| ModelFileInfo {
411                relative_path: path.to_string_lossy().to_string(),
412                size: *size,
413            })
414            .collect();
415
416        let total_size: u64 = files.iter().map(|(_, size)| size).sum();
417
418        Ok(Response::new(ModelFileList {
419            model_name,
420            files: file_infos,
421            total_size,
422        }))
423    }
424}
425
/// Shared registry of clients waiting for status updates.
///
/// Maps a model name to the senders of every status stream that is currently
/// waiting on that model. Wrapped in `Arc<Mutex<..>>` so that clones of the
/// tracker operate on the same registry.
type WaitingChannels =
    Arc<Mutex<HashMap<String, Vec<tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>>>>>;
429
/// Tracks the status of model downloads using `SQLite` for persistence.
///
/// Status lives in the database; the in-memory part is only the list of
/// client channels to notify when a model's status changes. Cloning the
/// tracker shares the waiting-channel registry (it is behind an `Arc`).
#[derive(Debug, Clone)]
pub struct ModelDownloadTracker {
    /// `SQLite` database for persistent model status tracking
    database: ModelDatabase,
    /// Maps model names to list of channels waiting for updates
    waiting_channels: WaitingChannels,
}
438
impl Default for ModelDownloadTracker {
    /// Equivalent to [`ModelDownloadTracker::new`], including its panic when
    /// the model database cannot be initialized.
    fn default() -> Self {
        Self::new()
    }
}
444
445impl ModelDownloadTracker {
446    #[must_use]
447    pub fn new() -> Self {
448        // Initialize database in the current directory
449        let database = match ModelDatabase::new("./models.db") {
450            Ok(db) => db,
451            Err(e) => {
452                error!("Critical error: Could not initialize model database at ./models.db: {e}");
453                panic!("Critical error: Could not initialize model database at ./models.db");
454            }
455        };
456
457        Self {
458            database,
459            waiting_channels: Arc::new(Mutex::new(HashMap::new())),
460        }
461    }
462
463    /// Gets the status of a model from the database
464    /// If the model is not in the database, it returns None
465    pub fn get_status(&self, model_name: &str) -> Option<ModelStatus> {
466        match self.database.get_status(model_name) {
467            Ok(status) => {
468                // Update last_used_at when checking status
469                if status.is_some() {
470                    let _ = self.database.touch_model(model_name);
471                }
472                status
473            }
474            Err(e) => {
475                error!("Failed to get model status from database: {}", e);
476                None
477            }
478        }
479    }
480
481    /// Sets the status of a model and notifies all waiting channels
482    pub fn set_status_and_notify(
483        &self,
484        model_name: String,
485        status: ModelStatus,
486        provider: ModelProvider,
487        message: Option<String>,
488    ) {
489        // Update status in database
490        if let Err(e) = self
491            .database
492            .set_status(&model_name, provider, status, message.clone())
493        {
494            error!("Failed to update model status in database: {}", e);
495            return;
496        }
497
498        // Notify all waiting channels
499        let mut waiting = match self.waiting_channels.lock() {
500            Ok(guard) => guard,
501            Err(poisoned) => {
502                error!("Waiting channels mutex is poisoned, recovering");
503                poisoned.into_inner()
504            }
505        };
506        if let Some(channels) = waiting.get(&model_name) {
507            let update = ModelStatusUpdate {
508                model_name: model_name.clone(),
509                status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
510                message,
511                provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
512            };
513
514            for channel in channels {
515                let _ = channel.try_send(Ok(update.clone()));
516            }
517
518            // If the model is downloaded or errored, remove all waiting channels
519            if status == ModelStatus::DOWNLOADED || status == ModelStatus::ERROR {
520                waiting.remove(&model_name);
521            }
522        }
523    }
524
    /// Sets the status of a model without an accompanying message.
    ///
    /// Thin wrapper around `set_status_and_notify` with `message` set to
    /// `None`; waiting channels are notified exactly as they are there.
    pub fn set_status(&self, model_name: String, status: ModelStatus, provider: ModelProvider) {
        self.set_status_and_notify(model_name, status, provider, None);
    }
529
530    /// Adds a channel to wait for updates on a specific model
531    pub fn add_waiting_channel(
532        &self,
533        model_name: &str,
534        tx: tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
535    ) {
536        let mut waiting = match self.waiting_channels.lock() {
537            Ok(guard) => guard,
538            Err(poisoned) => {
539                error!("Waiting channels mutex is poisoned, recovering");
540                poisoned.into_inner()
541            }
542        };
543        waiting.entry(model_name.to_string()).or_default().push(tx);
544    }
545
546    /// Deletes the status of a model from the database
547    /// This is used when a model is removed from the tracker
548    pub fn delete_status(&self, model_name: &str) {
549        if let Err(e) = self.database.delete_model(model_name) {
550            error!("Failed to delete model from database: {}", e);
551        }
552        let mut waiting = match self.waiting_channels.lock() {
553            Ok(guard) => guard,
554            Err(poisoned) => {
555                error!("Waiting channels mutex is poisoned, recovering");
556                poisoned.into_inner()
557            }
558        };
559        waiting.remove(model_name);
560    }
561
562    /// Initiates a download for a model and streams status updates
563    pub async fn ensure_model_downloaded(
564        &self,
565        model_name: &str,
566        provider: ModelProvider,
567        tx: &tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
568        ignore_weights: bool,
569    ) -> ModelStatus {
570        // Atomically try to claim this model for download using compare-and-swap
571        let status = match self.database.try_claim_for_download(model_name, provider) {
572            Ok(status) => status,
573            Err(e) => {
574                error!("Failed to claim model for download: {}", e);
575                // Send error and return
576                let error_update = ModelStatusUpdate {
577                    model_name: model_name.to_string(),
578                    status: modelexpress_common::grpc::model::ModelStatus::from(ModelStatus::ERROR)
579                        as i32,
580                    message: Some("Database error occurred".to_string()),
581                    provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
582                        as i32,
583                };
584                let _ = tx.send(Ok(error_update)).await;
585                return ModelStatus::ERROR;
586            }
587        };
588
589        // Send current status
590        let update = ModelStatusUpdate {
591            model_name: model_name.to_string(),
592            status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
593            message: match status {
594                ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
595                ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
596                ModelStatus::ERROR => Some("Previous download failed - retrying".to_string()),
597            },
598            provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
599        };
600
601        let _ = tx.send(Ok(update)).await;
602
603        // If the model already existed and is downloading, add this channel to wait for updates
604        if status == ModelStatus::DOWNLOADING {
605            self.add_waiting_channel(model_name, tx.clone());
606
607            // Check if we were the ones who just claimed it vs. it was already downloading
608            // If we just claimed it, we need to start the actual download
609            // We can determine this by checking if there are any waiting channels yet
610            let should_start_download = {
611                let waiting = match self.waiting_channels.lock() {
612                    Ok(guard) => guard,
613                    Err(poisoned) => {
614                        error!("Waiting channels mutex is poisoned, recovering");
615                        poisoned.into_inner()
616                    }
617                };
618                waiting
619                    .get(model_name)
620                    .is_none_or(|channels| channels.len() <= 1)
621            };
622
623            if should_start_download {
624                // We claimed the model, so we're responsible for downloading it
625                let tracker = self.clone();
626                let model_name_owned = model_name.to_string();
627
628                // Perform the download in the background
629                tokio::spawn(async move {
630                    let cache_dir = get_server_cache_dir();
631                    match download::download_model(
632                        &model_name_owned,
633                        provider,
634                        cache_dir,
635                        ignore_weights,
636                    )
637                    .await
638                    {
639                        Ok(_path) => {
640                            // Download completed successfully
641                            tracker.set_status_and_notify(
642                                model_name_owned,
643                                ModelStatus::DOWNLOADED,
644                                provider,
645                                Some("Model download completed successfully".to_string()),
646                            );
647                        }
648                        Err(e) => {
649                            // Download failed
650                            error!("Failed to download model {model_name_owned}: {e}");
651                            tracker.set_status_and_notify(
652                                model_name_owned,
653                                ModelStatus::ERROR,
654                                provider,
655                                Some(format!("Download failed: {e}")),
656                            );
657                        }
658                    }
659                });
660            }
661
662            // Wait for completion by monitoring the status
663            loop {
664                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
665                if let Some(current_status) = self.get_status(model_name)
666                    && current_status != ModelStatus::DOWNLOADING
667                {
668                    return current_status;
669                }
670            }
671        } else if status == ModelStatus::ERROR {
672            // If the model is in ERROR status, try to retry the download
673            // First, reset the status to DOWNLOADING
674            if let Err(e) = self.database.set_status(
675                model_name,
676                provider,
677                ModelStatus::DOWNLOADING,
678                Some("Retrying download...".to_string()),
679            ) {
680                error!("Failed to reset status for retry: {}", e);
681                return ModelStatus::ERROR;
682            }
683
684            // Add this channel to wait for updates
685            self.add_waiting_channel(model_name, tx.clone());
686
687            // Start the download
688            let tracker = self.clone();
689            let model_name_owned = model_name.to_string();
690
691            tokio::spawn(async move {
692                let cache_dir = get_server_cache_dir();
693                match download::download_model(
694                    &model_name_owned,
695                    provider,
696                    cache_dir,
697                    ignore_weights,
698                )
699                .await
700                {
701                    Ok(_path) => {
702                        // Download completed successfully
703                        tracker.set_status_and_notify(
704                            model_name_owned,
705                            ModelStatus::DOWNLOADED,
706                            provider,
707                            Some("Model download completed successfully".to_string()),
708                        );
709                    }
710                    Err(e) => {
711                        // Download failed again
712                        error!("Failed to download model {model_name_owned} on retry: {e}");
713                        tracker.set_status_and_notify(
714                            model_name_owned,
715                            ModelStatus::ERROR,
716                            provider,
717                            Some(format!("Download failed on retry: {e}")),
718                        );
719                    }
720                }
721            });
722
723            // Wait for completion by monitoring the status
724            loop {
725                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
726                if let Some(current_status) = self.get_status(model_name)
727                    && current_status != ModelStatus::DOWNLOADING
728                {
729                    return current_status;
730                }
731            }
732        }
733
734        status
735    }
736}
737
/// Global model download tracker.
///
/// Constructed lazily via `ModelDownloadTracker::new` on first access, so the
/// first access panics if the model database cannot be initialized.
pub static MODEL_TRACKER: std::sync::LazyLock<ModelDownloadTracker> =
    std::sync::LazyLock::new(ModelDownloadTracker::new);
741
742#[cfg(test)]
743#[allow(clippy::expect_used)]
744mod tests {
745    use super::*;
746    use modelexpress_common::grpc::{
747        api::ApiRequest, health::HealthRequest, model::ModelDownloadRequest,
748    };
749    use modelexpress_common::test_support::{EnvVarGuard, acquire_env_mutex};
750    use tempfile::TempDir;
751    use tokio_stream::StreamExt;
752    use tonic::Request;
753
754    #[tokio::test]
755    async fn test_health_service() {
756        let service = HealthServiceImpl;
757        let request = Request::new(HealthRequest {});
758
759        let response = service.get_health(request).await;
760        assert!(response.is_ok());
761
762        let health_response = response.expect("Health response should be ok").into_inner();
763        assert_eq!(health_response.version, env!("CARGO_PKG_VERSION"));
764        assert_eq!(health_response.status, "ok");
765        // uptime is u64, always >= 0, so just verify it exists
766        let _uptime = health_response.uptime;
767    }
768
769    #[tokio::test]
770    async fn test_api_service_ping() {
771        let service = ApiServiceImpl;
772        let request = Request::new(ApiRequest {
773            id: "test-id".to_string(),
774            action: "ping".to_string(),
775            payload: None,
776        });
777
778        let response = service.send_request(request).await;
779        assert!(response.is_ok());
780
781        let api_response = response.expect("API response should be ok").into_inner();
782        assert!(api_response.success);
783        assert!(api_response.data.is_some());
784        assert!(api_response.error.is_none());
785
786        // Check that the response data contains "pong"
787        let data_bytes = api_response.data.expect("Data should be present");
788        let data: serde_json::Value =
789            serde_json::from_slice(&data_bytes).expect("Data should be valid JSON");
790        assert_eq!(data["message"], "pong");
791    }
792
793    #[tokio::test]
794    async fn test_api_service_unknown_action() {
795        let service = ApiServiceImpl;
796        let request = Request::new(ApiRequest {
797            id: "test-id".to_string(),
798            action: "unknown-action".to_string(),
799            payload: None,
800        });
801
802        let response = service.send_request(request).await;
803        assert!(response.is_ok());
804
805        let api_response = response.expect("API response should be ok").into_inner();
806        assert!(!api_response.success);
807        assert!(api_response.data.is_none());
808        assert!(api_response.error.is_some());
809
810        let error_message = api_response.error.expect("Error should be present");
811        assert!(error_message.contains("Unknown action"));
812    }
813
814    #[test]
815    fn test_model_download_tracker_new() {
816        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
817        let tracker = ModelDownloadTracker::new();
818
819        // Test that we can get status for a non-existent model
820        let status = tracker.get_status("non-existent-model");
821        assert!(status.is_none());
822    }
823
824    #[test]
825    fn test_model_download_tracker_set_and_get_status() {
826        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
827        let tracker = ModelDownloadTracker::new();
828
829        // Use a unique model name based on current time to avoid conflicts
830        let timestamp = std::time::SystemTime::now()
831            .duration_since(std::time::UNIX_EPOCH)
832            .expect("Time went backwards")
833            .as_nanos();
834        let model_name = format!("test-model-{timestamp}");
835        let provider = ModelProvider::HuggingFace;
836
837        // Initially should return None
838        assert!(tracker.get_status(&model_name).is_none());
839
840        // Set status
841        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADING, provider);
842
843        // Should now return the status
844        let status = tracker.get_status(&model_name);
845        assert!(status.is_some());
846        assert_eq!(
847            status.expect("Status should be present"),
848            ModelStatus::DOWNLOADING
849        );
850
851        // Cleanup
852        tracker.delete_status(&model_name);
853    }
854
855    #[test]
856    fn test_model_download_tracker_delete_status() {
857        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
858        let tracker = ModelDownloadTracker::new();
859        let timestamp = std::time::SystemTime::now()
860            .duration_since(std::time::UNIX_EPOCH)
861            .expect("Time went backwards")
862            .as_nanos();
863        let model_name = format!("test-delete-model-{timestamp}");
864        let provider = ModelProvider::HuggingFace;
865
866        // Set status
867        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADED, provider);
868        assert!(tracker.get_status(&model_name).is_some());
869
870        // Delete status
871        tracker.delete_status(&model_name);
872        assert!(tracker.get_status(&model_name).is_none());
873    }
874
875    #[tokio::test]
876    async fn test_model_service_already_downloaded() {
877        let service = ModelServiceImpl;
878        let timestamp = std::time::SystemTime::now()
879            .duration_since(std::time::UNIX_EPOCH)
880            .expect("Time went backwards")
881            .as_nanos();
882        let model_name = format!("test-already-downloaded-model-{timestamp}");
883
884        // Pre-populate the model as downloaded
885        MODEL_TRACKER.set_status(
886            model_name.clone(),
887            ModelStatus::DOWNLOADED,
888            ModelProvider::HuggingFace,
889        );
890
891        let request = Request::new(ModelDownloadRequest {
892            model_name: model_name.clone(),
893            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
894            ignore_weights: false,
895        });
896
897        let response = service.ensure_model_downloaded(request).await;
898        assert!(response.is_ok());
899
900        let mut stream = response.expect("Response should be ok").into_inner();
901
902        // Should get at least one update indicating it's already downloaded
903        let update = stream.next().await;
904        assert!(update.is_some());
905
906        let update = update.expect("Update should be present");
907        assert!(update.is_ok());
908
909        let status_update = update.expect("Status update should be ok");
910        assert_eq!(status_update.model_name, model_name);
911        assert_eq!(
912            status_update.status,
913            modelexpress_common::grpc::model::ModelStatus::Downloaded as i32
914        );
915
916        // Cleanup
917        MODEL_TRACKER.delete_status(&model_name);
918    }
919
920    #[test]
921    fn test_model_download_tracker_set_status_and_notify() {
922        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
923        let tracker = ModelDownloadTracker::new();
924        let model_name = "test-notify-model".to_string();
925        let provider = ModelProvider::HuggingFace;
926
927        // Test set_status_and_notify doesn't panic
928        tracker.set_status_and_notify(
929            model_name.clone(),
930            ModelStatus::DOWNLOADED,
931            provider,
932            Some("Download completed".to_string()),
933        );
934
935        // Verify status was set
936        let status = tracker.get_status(&model_name);
937        assert!(status.is_some());
938        assert_eq!(
939            status.expect("Status should be present"),
940            ModelStatus::DOWNLOADED
941        );
942    }
943
944    #[test]
945    fn test_waiting_channels_management() {
946        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
947        let tracker = ModelDownloadTracker::new();
948        let model_name = "test-channels-model";
949
950        let (tx, _rx) = tokio::sync::mpsc::channel(4);
951
952        // Add a waiting channel
953        tracker.add_waiting_channel(model_name, tx);
954
955        // Verify the channel was added by checking internal state
956        let waiting_count = {
957            let waiting = match tracker.waiting_channels.lock() {
958                Ok(guard) => guard,
959                Err(poisoned) => poisoned.into_inner(),
960            };
961            waiting.get(model_name).map_or(0, std::vec::Vec::len)
962        };
963        assert_eq!(waiting_count, 1);
964
965        // Clean up by setting final status
966        tracker.set_status_and_notify(
967            model_name.to_string(),
968            ModelStatus::DOWNLOADED,
969            ModelProvider::HuggingFace,
970            None,
971        );
972
973        // Channels should be cleared for final statuses
974        let waiting_count_after = {
975            let waiting = match tracker.waiting_channels.lock() {
976                Ok(guard) => guard,
977                Err(poisoned) => poisoned.into_inner(),
978            };
979            waiting.get(model_name).map_or(0, std::vec::Vec::len)
980        };
981        assert_eq!(waiting_count_after, 0);
982    }
983
984    #[tokio::test]
985    async fn test_model_service_stream_closes_properly() {
986        let service = ModelServiceImpl;
987        let model_name = "test-stream-model";
988
989        let request = Request::new(ModelDownloadRequest {
990            model_name: model_name.to_string(),
991            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
992            ignore_weights: false,
993        });
994
995        let response = service.ensure_model_downloaded(request).await;
996        assert!(response.is_ok());
997
998        let mut stream = response.expect("Response should be ok").into_inner();
999
1000        // Read a few updates (may include initial status and progress)
1001        let mut update_count = 0;
1002        while let Some(update) = stream.next().await {
1003            assert!(update.is_ok());
1004            update_count += 1;
1005
1006            // Prevent infinite loop in case of issues
1007            if update_count > 10 {
1008                break;
1009            }
1010        }
1011
1012        assert!(update_count > 0);
1013
1014        // Cleanup
1015        MODEL_TRACKER.delete_status(model_name);
1016    }
1017
1018    #[tokio::test]
1019    async fn test_concurrent_model_download_no_race_condition() {
1020        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
1021        let tracker = ModelDownloadTracker::new();
1022        let model_name = "test-concurrent-model";
1023        let provider = ModelProvider::HuggingFace;
1024
1025        // Test that the compare-and-swap mechanism works
1026        // First attempt should claim the model
1027        let status1 = tracker
1028            .database
1029            .try_claim_for_download(model_name, provider)
1030            .expect("Failed to claim for download 1");
1031        assert_eq!(status1, ModelStatus::DOWNLOADING);
1032
1033        // Second attempt should see it's already claimed
1034        let status2 = tracker
1035            .database
1036            .try_claim_for_download(model_name, provider)
1037            .expect("Failed to claim for download 2");
1038        assert_eq!(status2, ModelStatus::DOWNLOADING);
1039
1040        // Verify only one record exists
1041        let record = tracker
1042            .database
1043            .get_model_record(model_name)
1044            .expect("Failed to get model record")
1045            .expect("Record should exist");
1046        assert_eq!(record.status, ModelStatus::DOWNLOADING);
1047
1048        // Cleanup
1049        tracker.delete_status(model_name);
1050    }
1051
1052    #[test]
1053    fn test_collect_model_files_empty_dir() {
1054        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1055        let files = collect_model_files(temp_dir.path(), temp_dir.path());
1056        assert!(files.is_empty());
1057    }
1058
1059    #[test]
1060    fn test_collect_model_files_with_files() {
1061        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1062
1063        // Create some test files
1064        let file1_path = temp_dir.path().join("config.json");
1065        std::fs::write(&file1_path, r#"{"test": "data"}"#).expect("Failed to write file1");
1066
1067        let file2_path = temp_dir.path().join("model.bin");
1068        std::fs::write(&file2_path, vec![0u8; 100]).expect("Failed to write file2");
1069
1070        let files = collect_model_files(temp_dir.path(), temp_dir.path());
1071
1072        assert_eq!(files.len(), 2);
1073
1074        // Check file sizes
1075        let total_size: u64 = files.iter().map(|(_, size)| size).sum();
1076        assert!(total_size > 0);
1077
1078        // Check that relative paths are correct
1079        let paths: Vec<_> = files
1080            .iter()
1081            .map(|(p, _)| p.to_string_lossy().to_string())
1082            .collect();
1083        assert!(paths.contains(&"config.json".to_string()));
1084        assert!(paths.contains(&"model.bin".to_string()));
1085    }
1086
1087    #[test]
1088    fn test_collect_model_files_nested() {
1089        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1090
1091        // Create nested directory structure
1092        let subdir = temp_dir.path().join("subdir");
1093        std::fs::create_dir(&subdir).expect("Failed to create subdir");
1094
1095        let file1_path = temp_dir.path().join("root_file.txt");
1096        std::fs::write(&file1_path, "root content").expect("Failed to write file1");
1097
1098        let file2_path = subdir.join("nested_file.txt");
1099        std::fs::write(&file2_path, "nested content").expect("Failed to write file2");
1100
1101        let files = collect_model_files(temp_dir.path(), temp_dir.path());
1102
1103        assert_eq!(files.len(), 2);
1104
1105        // Check that nested path is correct
1106        let paths: Vec<_> = files
1107            .iter()
1108            .map(|(p, _)| p.to_string_lossy().to_string())
1109            .collect();
1110        assert!(paths.iter().any(|p| p.contains("nested_file")));
1111    }
1112
1113    #[tokio::test]
1114    async fn test_list_model_files_not_found() {
1115        let service = ModelServiceImpl;
1116
1117        let request = Request::new(ModelFilesRequest {
1118            model_name: "non-existent-model-12345".to_string(),
1119            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1120            chunk_size: 0,
1121        });
1122
1123        let result = service.list_model_files(request).await;
1124        assert!(result.is_err());
1125        let status = result.expect_err("Should return error");
1126        assert_eq!(status.code(), tonic::Code::NotFound);
1127    }
1128
1129    #[tokio::test]
1130    async fn test_stream_model_files_not_found() {
1131        let service = ModelServiceImpl;
1132
1133        let request = Request::new(ModelFilesRequest {
1134            model_name: "non-existent-model-12345".to_string(),
1135            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1136            chunk_size: 1024,
1137        });
1138
1139        let result = service.stream_model_files(request).await;
1140        assert!(result.is_err());
1141        let status = result.expect_err("Should return error");
1142        assert_eq!(status.code(), tonic::Code::NotFound);
1143    }
1144
1145    #[tokio::test]
1146    #[allow(clippy::await_holding_lock)]
1147    async fn test_stream_model_files_hf_first_chunk_includes_commit_hash() {
1148        let env_lock = acquire_env_mutex();
1149        let temp_dir = TempDir::new().expect("Failed to create temp dir");
1150        let _cache_dir_guard = EnvVarGuard::set(
1151            &env_lock,
1152            "MODEL_EXPRESS_CACHE_DIRECTORY",
1153            temp_dir.path().to_str().expect("Expected temp dir path"),
1154        );
1155        let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1156
1157        let model_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
1158        std::fs::create_dir_all(&model_dir).expect("Failed to create model dir");
1159        std::fs::write(model_dir.join("config.json"), br#"{"model":"test"}"#)
1160            .expect("Failed to write model file");
1161
1162        let service = ModelServiceImpl;
1163        let request = Request::new(ModelFilesRequest {
1164            model_name: "test/model".to_string(),
1165            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1166            chunk_size: 1024,
1167        });
1168
1169        let response = service
1170            .stream_model_files(request)
1171            .await
1172            .expect("Expected stream response");
1173        let mut stream = response.into_inner();
1174        let first_chunk = stream
1175            .next()
1176            .await
1177            .expect("Expected stream item")
1178            .expect("Expected first chunk");
1179
1180        assert_eq!(first_chunk.relative_path, "config.json");
1181        assert_eq!(first_chunk.commit_hash.as_deref(), Some("abc123"));
1182        assert!(first_chunk.is_last_chunk);
1183        assert!(first_chunk.is_last_file);
1184    }
1185}