modelexpress_server/
services.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::database::ModelDatabase;
5use modelexpress_common::{
6    cache::CacheConfig,
7    download,
8    grpc::{
9        api::{ApiRequest, ApiResponse, api_service_server::ApiService},
10        health::{HealthRequest, HealthResponse, health_service_server::HealthService},
11        model::{ModelDownloadRequest, ModelStatusUpdate, model_service_server::ModelService},
12    },
13    models::{ModelProvider, ModelStatus},
14};
15use std::{
16    collections::HashMap,
17    sync::{Arc, Mutex},
18    time::SystemTime,
19};
20use tokio_stream::wrappers::ReceiverStream;
21use tonic::{Request, Response, Status};
22use tracing::{error, info};
23
/// Origin for the uptime reported by the health endpoint.
///
/// Empty until the first `get_or_init` call; that first call stores "now" and every later
/// read returns the same instant.
static START_TIME: std::sync::OnceLock<SystemTime> = std::sync::OnceLock::<SystemTime>::new();
25
26/// Get the configured cache directory for model downloads
27fn get_server_cache_dir() -> Option<std::path::PathBuf> {
28    // Try to get cache configuration
29    if let Ok(config) = CacheConfig::discover() {
30        Some(config.local_path)
31    } else {
32        // Fall back to environment variable
33        std::env::var("HF_HUB_CACHE")
34            .ok()
35            .map(std::path::PathBuf::from)
36    }
37}
38
39/// Health service implementation
40#[derive(Debug, Default)]
41pub struct HealthServiceImpl;
42
43#[tonic::async_trait]
44impl HealthService for HealthServiceImpl {
45    async fn get_health(
46        &self,
47        _request: Request<HealthRequest>,
48    ) -> Result<Response<HealthResponse>, Status> {
49        let start_time = START_TIME.get_or_init(SystemTime::now);
50        let uptime = SystemTime::now()
51            .duration_since(*start_time)
52            .unwrap_or_default()
53            .as_secs();
54
55        let response = HealthResponse {
56            version: env!("CARGO_PKG_VERSION").to_string(),
57            status: "ok".to_string(),
58            uptime,
59        };
60
61        Ok(Response::new(response))
62    }
63}
64
65/// API service implementation
66#[derive(Debug, Default)]
67pub struct ApiServiceImpl;
68
69#[tonic::async_trait]
70impl ApiService for ApiServiceImpl {
71    async fn send_request(
72        &self,
73        request: Request<ApiRequest>,
74    ) -> Result<Response<ApiResponse>, Status> {
75        let api_request = request.into_inner();
76        info!("Received gRPC request: {:?}", api_request);
77
78        // Process the request based on the action
79        if api_request.action.as_str() == "ping" {
80            info!("Processing ping request");
81            let response_data = serde_json::json!({ "message": "pong" });
82            let data_bytes = serde_json::to_vec(&response_data)
83                .map_err(|e| Status::internal(format!("Serialization error: {e}")))?;
84
85            Ok(Response::new(ApiResponse {
86                success: true,
87                data: Some(data_bytes),
88                error: None,
89            }))
90        } else {
91            error!("Unknown action: {}", api_request.action);
92            Ok(Response::new(ApiResponse {
93                success: false,
94                data: None,
95                error: Some(format!("Unknown action: {}", api_request.action)),
96            }))
97        }
98    }
99}
100
101/// Model service implementation
102#[derive(Debug, Default)]
103pub struct ModelServiceImpl;
104
105#[tonic::async_trait]
106impl ModelService for ModelServiceImpl {
107    type EnsureModelDownloadedStream = ReceiverStream<Result<ModelStatusUpdate, Status>>;
108
109    async fn ensure_model_downloaded(
110        &self,
111        request: Request<ModelDownloadRequest>,
112    ) -> Result<Response<Self::EnsureModelDownloadedStream>, Status> {
113        let model_request = request.into_inner();
114        info!(
115            "Starting model download stream for: {} from provider: {:?}",
116            model_request.model_name, model_request.provider
117        );
118
119        let (tx, rx) = tokio::sync::mpsc::channel(4);
120        let model_name = model_request.model_name.clone();
121
122        // Convert gRPC provider to our enum
123        let provider: ModelProvider =
124            modelexpress_common::grpc::model::ModelProvider::try_from(model_request.provider)
125                .unwrap_or(modelexpress_common::grpc::model::ModelProvider::HuggingFace)
126                .into();
127        let ignore_weights = model_request.ignore_weights;
128
129        // Spawn a task to handle the streaming download updates
130        tokio::spawn(async move {
131            // Check if the model is already downloaded
132            if let Some(status) = MODEL_TRACKER.get_status(&model_name) {
133                let update = ModelStatusUpdate {
134                    model_name: model_name.clone(),
135                    status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
136                    message: match status {
137                        ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
138                        ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
139                        ModelStatus::ERROR => {
140                            Some("Previous download failed - retrying".to_string())
141                        }
142                    },
143                    provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
144                        as i32,
145                };
146
147                if tx.send(Ok(update)).await.is_err() {
148                    return; // Client disconnected
149                }
150
151                // If already downloaded, we're done
152                if status == ModelStatus::DOWNLOADED {
153                    return;
154                }
155            }
156
157            // Start or monitor the download process
158            let final_status = MODEL_TRACKER
159                .ensure_model_downloaded(&model_name, provider, &tx, ignore_weights)
160                .await;
161
162            // Send final status update
163            let final_update = ModelStatusUpdate {
164                model_name: model_name.clone(),
165                status: modelexpress_common::grpc::model::ModelStatus::from(final_status) as i32,
166                message: match final_status {
167                    ModelStatus::DOWNLOADED => {
168                        Some("Model download completed successfully".to_string())
169                    }
170                    ModelStatus::ERROR => Some("Model download failed".to_string()),
171                    ModelStatus::DOWNLOADING => Some("Download still in progress".to_string()),
172                },
173                provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
174            };
175
176            let _ = tx.send(Ok(final_update)).await;
177        });
178
179        Ok(Response::new(ReceiverStream::new(rx)))
180    }
181}
182
183/// Type alias for the complex waiting channels type
184type WaitingChannels =
185    Arc<Mutex<HashMap<String, Vec<tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>>>>>;
186
187/// Tracks the status of model downloads using `SQLite` for persistence
188#[derive(Debug, Clone)]
189pub struct ModelDownloadTracker {
190    /// `SQLite` database for persistent model status tracking
191    database: ModelDatabase,
192    /// Maps model names to list of channels waiting for updates
193    waiting_channels: WaitingChannels,
194}
195
196impl Default for ModelDownloadTracker {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
202impl ModelDownloadTracker {
203    #[must_use]
204    pub fn new() -> Self {
205        // Initialize database in the current directory
206        let database = match ModelDatabase::new("./models.db") {
207            Ok(db) => db,
208            Err(e) => {
209                error!("Critical error: Could not initialize model database at ./models.db: {e}");
210                panic!("Critical error: Could not initialize model database at ./models.db");
211            }
212        };
213
214        Self {
215            database,
216            waiting_channels: Arc::new(Mutex::new(HashMap::new())),
217        }
218    }
219
220    /// Gets the status of a model from the database
221    /// If the model is not in the database, it returns None
222    pub fn get_status(&self, model_name: &str) -> Option<ModelStatus> {
223        match self.database.get_status(model_name) {
224            Ok(status) => {
225                // Update last_used_at when checking status
226                if status.is_some() {
227                    let _ = self.database.touch_model(model_name);
228                }
229                status
230            }
231            Err(e) => {
232                error!("Failed to get model status from database: {}", e);
233                None
234            }
235        }
236    }
237
238    /// Sets the status of a model and notifies all waiting channels
239    pub fn set_status_and_notify(
240        &self,
241        model_name: String,
242        status: ModelStatus,
243        provider: ModelProvider,
244        message: Option<String>,
245    ) {
246        // Update status in database
247        if let Err(e) = self
248            .database
249            .set_status(&model_name, provider, status, message.clone())
250        {
251            error!("Failed to update model status in database: {}", e);
252            return;
253        }
254
255        // Notify all waiting channels
256        let mut waiting = match self.waiting_channels.lock() {
257            Ok(guard) => guard,
258            Err(poisoned) => {
259                error!("Waiting channels mutex is poisoned, recovering");
260                poisoned.into_inner()
261            }
262        };
263        if let Some(channels) = waiting.get(&model_name) {
264            let update = ModelStatusUpdate {
265                model_name: model_name.clone(),
266                status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
267                message,
268                provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
269            };
270
271            for channel in channels {
272                let _ = channel.try_send(Ok(update.clone()));
273            }
274
275            // If the model is downloaded or errored, remove all waiting channels
276            if status == ModelStatus::DOWNLOADED || status == ModelStatus::ERROR {
277                waiting.remove(&model_name);
278            }
279        }
280    }
281
282    /// Sets the status of a model
283    pub fn set_status(&self, model_name: String, status: ModelStatus, provider: ModelProvider) {
284        self.set_status_and_notify(model_name, status, provider, None);
285    }
286
287    /// Adds a channel to wait for updates on a specific model
288    pub fn add_waiting_channel(
289        &self,
290        model_name: &str,
291        tx: tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
292    ) {
293        let mut waiting = match self.waiting_channels.lock() {
294            Ok(guard) => guard,
295            Err(poisoned) => {
296                error!("Waiting channels mutex is poisoned, recovering");
297                poisoned.into_inner()
298            }
299        };
300        waiting.entry(model_name.to_string()).or_default().push(tx);
301    }
302
303    /// Deletes the status of a model from the database
304    /// This is used when a model is removed from the tracker
305    pub fn delete_status(&self, model_name: &str) {
306        if let Err(e) = self.database.delete_model(model_name) {
307            error!("Failed to delete model from database: {}", e);
308        }
309        let mut waiting = match self.waiting_channels.lock() {
310            Ok(guard) => guard,
311            Err(poisoned) => {
312                error!("Waiting channels mutex is poisoned, recovering");
313                poisoned.into_inner()
314            }
315        };
316        waiting.remove(model_name);
317    }
318
319    /// Initiates a download for a model and streams status updates
320    pub async fn ensure_model_downloaded(
321        &self,
322        model_name: &str,
323        provider: ModelProvider,
324        tx: &tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
325        ignore_weights: bool,
326    ) -> ModelStatus {
327        // Atomically try to claim this model for download using compare-and-swap
328        let status = match self.database.try_claim_for_download(model_name, provider) {
329            Ok(status) => status,
330            Err(e) => {
331                error!("Failed to claim model for download: {}", e);
332                // Send error and return
333                let error_update = ModelStatusUpdate {
334                    model_name: model_name.to_string(),
335                    status: modelexpress_common::grpc::model::ModelStatus::from(ModelStatus::ERROR)
336                        as i32,
337                    message: Some("Database error occurred".to_string()),
338                    provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
339                        as i32,
340                };
341                let _ = tx.send(Ok(error_update)).await;
342                return ModelStatus::ERROR;
343            }
344        };
345
346        // Send current status
347        let update = ModelStatusUpdate {
348            model_name: model_name.to_string(),
349            status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
350            message: match status {
351                ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
352                ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
353                ModelStatus::ERROR => Some("Previous download failed - retrying".to_string()),
354            },
355            provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
356        };
357
358        let _ = tx.send(Ok(update)).await;
359
360        // If the model already existed and is downloading, add this channel to wait for updates
361        if status == ModelStatus::DOWNLOADING {
362            self.add_waiting_channel(model_name, tx.clone());
363
364            // Check if we were the ones who just claimed it vs. it was already downloading
365            // If we just claimed it, we need to start the actual download
366            // We can determine this by checking if there are any waiting channels yet
367            let should_start_download = {
368                let waiting = match self.waiting_channels.lock() {
369                    Ok(guard) => guard,
370                    Err(poisoned) => {
371                        error!("Waiting channels mutex is poisoned, recovering");
372                        poisoned.into_inner()
373                    }
374                };
375                waiting
376                    .get(model_name)
377                    .is_none_or(|channels| channels.len() <= 1)
378            };
379
380            if should_start_download {
381                // We claimed the model, so we're responsible for downloading it
382                let tracker = self.clone();
383                let model_name_owned = model_name.to_string();
384
385                // Perform the download in the background
386                tokio::spawn(async move {
387                    let cache_dir = get_server_cache_dir();
388                    match download::download_model(
389                        &model_name_owned,
390                        provider,
391                        cache_dir,
392                        ignore_weights,
393                    )
394                    .await
395                    {
396                        Ok(_path) => {
397                            // Download completed successfully
398                            tracker.set_status_and_notify(
399                                model_name_owned,
400                                ModelStatus::DOWNLOADED,
401                                provider,
402                                Some("Model download completed successfully".to_string()),
403                            );
404                        }
405                        Err(e) => {
406                            // Download failed
407                            error!("Failed to download model {model_name_owned}: {e}");
408                            tracker.set_status_and_notify(
409                                model_name_owned,
410                                ModelStatus::ERROR,
411                                provider,
412                                Some(format!("Download failed: {e}")),
413                            );
414                        }
415                    }
416                });
417            }
418
419            // Wait for completion by monitoring the status
420            loop {
421                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
422                if let Some(current_status) = self.get_status(model_name) {
423                    if current_status != ModelStatus::DOWNLOADING {
424                        return current_status;
425                    }
426                }
427            }
428        } else if status == ModelStatus::ERROR {
429            // If the model is in ERROR status, try to retry the download
430            // First, reset the status to DOWNLOADING
431            if let Err(e) = self.database.set_status(
432                model_name,
433                provider,
434                ModelStatus::DOWNLOADING,
435                Some("Retrying download...".to_string()),
436            ) {
437                error!("Failed to reset status for retry: {}", e);
438                return ModelStatus::ERROR;
439            }
440
441            // Add this channel to wait for updates
442            self.add_waiting_channel(model_name, tx.clone());
443
444            // Start the download
445            let tracker = self.clone();
446            let model_name_owned = model_name.to_string();
447
448            tokio::spawn(async move {
449                let cache_dir = get_server_cache_dir();
450                match download::download_model(
451                    &model_name_owned,
452                    provider,
453                    cache_dir,
454                    ignore_weights,
455                )
456                .await
457                {
458                    Ok(_path) => {
459                        // Download completed successfully
460                        tracker.set_status_and_notify(
461                            model_name_owned,
462                            ModelStatus::DOWNLOADED,
463                            provider,
464                            Some("Model download completed successfully".to_string()),
465                        );
466                    }
467                    Err(e) => {
468                        // Download failed again
469                        error!("Failed to download model {model_name_owned} on retry: {e}");
470                        tracker.set_status_and_notify(
471                            model_name_owned,
472                            ModelStatus::ERROR,
473                            provider,
474                            Some(format!("Download failed on retry: {e}")),
475                        );
476                    }
477                }
478            });
479
480            // Wait for completion by monitoring the status
481            loop {
482                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
483                if let Some(current_status) = self.get_status(model_name) {
484                    if current_status != ModelStatus::DOWNLOADING {
485                        return current_status;
486                    }
487                }
488            }
489        }
490
491        status
492    }
493}
494
495/// Global model download tracker
496pub static MODEL_TRACKER: std::sync::LazyLock<ModelDownloadTracker> =
497    std::sync::LazyLock::new(ModelDownloadTracker::new);
498
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    // Every `ModelDownloadTracker::new()` (and therefore `MODEL_TRACKER`) opens the same
    // `./models.db` in the working directory, so all of these tests share one real database.
    // Tests that write rows use `unique_model_name` and delete their rows afterwards. This
    // keeps parallel tests and repeated runs from observing each other's data.

    use super::*;
    use modelexpress_common::grpc::{
        api::ApiRequest, health::HealthRequest, model::ModelDownloadRequest,
    };
    use tokio_stream::StreamExt;
    use tonic::Request;

    /// Returns `"{prefix}-{nanos since epoch}-{counter}"`.
    ///
    /// The counter makes the name unique within this process; the timestamp makes it unique
    /// across runs.
    fn unique_model_name(prefix: &str) -> String {
        static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards")
            .as_nanos();
        let sequence = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        format!("{prefix}-{timestamp}-{sequence}")
    }

    #[tokio::test]
    async fn test_health_service() {
        let service = HealthServiceImpl;
        let health_response = service
            .get_health(Request::new(HealthRequest {}))
            .await
            .expect("Health response should be ok")
            .into_inner();

        assert_eq!(health_response.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(health_response.status, "ok");
        // `uptime` depends on when START_TIME was first initialised, so it is not asserted.
    }

    #[tokio::test]
    async fn test_api_service_ping() {
        let service = ApiServiceImpl;
        let request = Request::new(ApiRequest {
            id: "test-id".to_string(),
            action: "ping".to_string(),
            payload: None,
        });

        let api_response = service
            .send_request(request)
            .await
            .expect("API response should be ok")
            .into_inner();
        assert!(api_response.success);
        assert!(api_response.error.is_none());

        // Check that the response data is JSON containing "pong"
        let data_bytes = api_response.data.expect("Data should be present");
        let data: serde_json::Value =
            serde_json::from_slice(&data_bytes).expect("Data should be valid JSON");
        assert_eq!(data["message"], "pong");
    }

    #[tokio::test]
    async fn test_api_service_unknown_action() {
        let service = ApiServiceImpl;
        let request = Request::new(ApiRequest {
            id: "test-id".to_string(),
            action: "unknown-action".to_string(),
            payload: None,
        });

        let api_response = service
            .send_request(request)
            .await
            .expect("API response should be ok")
            .into_inner();
        assert!(!api_response.success);
        assert!(api_response.data.is_none());

        let error_message = api_response.error.expect("Error should be present");
        assert!(error_message.contains("Unknown action"));
    }

    #[test]
    fn test_model_download_tracker_new() {
        let tracker = ModelDownloadTracker::new();

        // A freshly generated name cannot exist in the shared database.
        let status = tracker.get_status(&unique_model_name("non-existent-model"));
        assert!(status.is_none());
    }

    #[test]
    fn test_model_download_tracker_set_and_get_status() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-model");
        let provider = ModelProvider::HuggingFace;

        assert!(tracker.get_status(&model_name).is_none());

        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADING, provider);
        assert_eq!(
            tracker.get_status(&model_name).expect("Status should be present"),
            ModelStatus::DOWNLOADING
        );

        tracker.delete_status(&model_name);
    }

    #[test]
    fn test_model_download_tracker_delete_status() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-delete-model");
        let provider = ModelProvider::HuggingFace;

        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADED, provider);
        assert!(tracker.get_status(&model_name).is_some());

        tracker.delete_status(&model_name);
        assert!(tracker.get_status(&model_name).is_none());
    }

    #[tokio::test]
    async fn test_model_service_already_downloaded() {
        let service = ModelServiceImpl;
        let model_name = unique_model_name("test-already-downloaded-model");

        // Pre-populate the model as downloaded
        MODEL_TRACKER.set_status(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
        );

        let request = Request::new(ModelDownloadRequest {
            model_name: model_name.clone(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            ignore_weights: false,
        });

        let mut stream = service
            .ensure_model_downloaded(request)
            .await
            .expect("Response should be ok")
            .into_inner();

        // The first update must report the model as already downloaded.
        let status_update = stream
            .next()
            .await
            .expect("Update should be present")
            .expect("Status update should be ok");
        assert_eq!(status_update.model_name, model_name);
        assert_eq!(
            status_update.status,
            modelexpress_common::grpc::model::ModelStatus::Downloaded as i32
        );

        MODEL_TRACKER.delete_status(&model_name);
    }

    #[test]
    fn test_model_download_tracker_set_status_and_notify() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-notify-model");
        let provider = ModelProvider::HuggingFace;

        // With no waiting channels, this must simply persist the status.
        tracker.set_status_and_notify(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            provider,
            Some("Download completed".to_string()),
        );

        assert_eq!(
            tracker.get_status(&model_name).expect("Status should be present"),
            ModelStatus::DOWNLOADED
        );

        tracker.delete_status(&model_name);
    }

    #[test]
    fn test_waiting_channels_management() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-channels-model");

        // Counts waiters for `model_name`, reading the tracker's internal map directly.
        let waiting_count = |tracker: &ModelDownloadTracker| {
            let waiting = tracker
                .waiting_channels
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            waiting.get(&model_name).map_or(0, std::vec::Vec::len)
        };

        let (tx, _rx) = tokio::sync::mpsc::channel(4);
        tracker.add_waiting_channel(&model_name, tx);
        assert_eq!(waiting_count(&tracker), 1);

        // A terminal status clears the model's waiting channels.
        tracker.set_status_and_notify(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
            None,
        );
        assert_eq!(waiting_count(&tracker), 0);

        tracker.delete_status(&model_name);
    }

    #[tokio::test]
    async fn test_model_service_stream_closes_properly() {
        // NOTE: this drives the real download path for a made-up model name, so it relies on
        // `download::download_model` failing (rather than hanging) for an unknown model.
        let service = ModelServiceImpl;
        let model_name = unique_model_name("test-stream-model");

        let request = Request::new(ModelDownloadRequest {
            model_name: model_name.clone(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            ignore_weights: false,
        });

        let mut stream = service
            .ensure_model_downloaded(request)
            .await
            .expect("Response should be ok")
            .into_inner();

        // Read updates until the stream ends (capped to avoid looping forever on a bug).
        let mut update_count = 0;
        while let Some(update) = stream.next().await {
            assert!(update.is_ok());
            update_count += 1;
            if update_count > 10 {
                break;
            }
        }
        assert!(update_count > 0);

        MODEL_TRACKER.delete_status(&model_name);
    }

    #[tokio::test]
    async fn test_concurrent_model_download_no_race_condition() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-concurrent-model");
        let provider = ModelProvider::HuggingFace;

        // First attempt claims the model.
        let status1 = tracker
            .database
            .try_claim_for_download(&model_name, provider)
            .expect("Failed to claim for download 1");
        assert_eq!(status1, ModelStatus::DOWNLOADING);

        // Second attempt sees it is already claimed.
        let status2 = tracker
            .database
            .try_claim_for_download(&model_name, provider)
            .expect("Failed to claim for download 2");
        assert_eq!(status2, ModelStatus::DOWNLOADING);

        // Exactly one record exists, still in DOWNLOADING.
        let record = tracker
            .database
            .get_model_record(&model_name)
            .expect("Failed to get model record")
            .expect("Record should exist");
        assert_eq!(record.status, ModelStatus::DOWNLOADING);

        tracker.delete_status(&model_name);
    }
}