modelexpress_server/
services.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::database::ModelDatabase;
5use modelexpress_common::{
6    cache::CacheConfig,
7    download,
8    grpc::{
9        api::{ApiRequest, ApiResponse, api_service_server::ApiService},
10        health::{HealthRequest, HealthResponse, health_service_server::HealthService},
11        model::{ModelDownloadRequest, ModelStatusUpdate, model_service_server::ModelService},
12    },
13    models::{ModelProvider, ModelStatus},
14};
15use std::{
16    collections::HashMap,
17    sync::{Arc, Mutex},
18    time::SystemTime,
19};
20use tokio_stream::wrappers::ReceiverStream;
21use tonic::{Request, Response, Status};
22use tracing::{error, info};
23
/// Reference instant for the uptime reported by [`HealthServiceImpl`].
///
/// NOTE(review): the only initialization visible in this module is
/// `START_TIME.get_or_init(SystemTime::now)` inside `get_health`, so uptime
/// appears to be measured from the first health request rather than from
/// process start — confirm whether startup code seeds it earlier.
static START_TIME: std::sync::OnceLock<SystemTime> = std::sync::OnceLock::new();
25
26/// Get the configured cache directory for model downloads
27fn get_server_cache_dir() -> Option<std::path::PathBuf> {
28    // Try to get cache configuration
29    if let Ok(config) = CacheConfig::discover() {
30        Some(config.local_path)
31    } else {
32        // Fall back to environment variable
33        std::env::var("HF_HUB_CACHE")
34            .ok()
35            .map(std::path::PathBuf::from)
36    }
37}
38
39/// Health service implementation
40#[derive(Debug, Default)]
41pub struct HealthServiceImpl;
42
43#[tonic::async_trait]
44impl HealthService for HealthServiceImpl {
45    async fn get_health(
46        &self,
47        _request: Request<HealthRequest>,
48    ) -> Result<Response<HealthResponse>, Status> {
49        let start_time = START_TIME.get_or_init(SystemTime::now);
50        let uptime = SystemTime::now()
51            .duration_since(*start_time)
52            .unwrap_or_default()
53            .as_secs();
54
55        let response = HealthResponse {
56            version: env!("CARGO_PKG_VERSION").to_string(),
57            status: "ok".to_string(),
58            uptime,
59        };
60
61        Ok(Response::new(response))
62    }
63}
64
65/// API service implementation
66#[derive(Debug, Default)]
67pub struct ApiServiceImpl;
68
69#[tonic::async_trait]
70impl ApiService for ApiServiceImpl {
71    async fn send_request(
72        &self,
73        request: Request<ApiRequest>,
74    ) -> Result<Response<ApiResponse>, Status> {
75        let api_request = request.into_inner();
76        info!("Received gRPC request: {:?}", api_request);
77
78        // Process the request based on the action
79        if api_request.action.as_str() == "ping" {
80            info!("Processing ping request");
81            let response_data = serde_json::json!({ "message": "pong" });
82            let data_bytes = serde_json::to_vec(&response_data)
83                .map_err(|e| Status::internal(format!("Serialization error: {e}")))?;
84
85            Ok(Response::new(ApiResponse {
86                success: true,
87                data: Some(data_bytes),
88                error: None,
89            }))
90        } else {
91            error!("Unknown action: {}", api_request.action);
92            Ok(Response::new(ApiResponse {
93                success: false,
94                data: None,
95                error: Some(format!("Unknown action: {}", api_request.action)),
96            }))
97        }
98    }
99}
100
101/// Model service implementation
102#[derive(Debug, Default)]
103pub struct ModelServiceImpl;
104
105#[tonic::async_trait]
106impl ModelService for ModelServiceImpl {
107    type EnsureModelDownloadedStream = ReceiverStream<Result<ModelStatusUpdate, Status>>;
108
109    async fn ensure_model_downloaded(
110        &self,
111        request: Request<ModelDownloadRequest>,
112    ) -> Result<Response<Self::EnsureModelDownloadedStream>, Status> {
113        info!("Starting model download stream");
114        let model_request = request.into_inner();
115        let (tx, rx) = tokio::sync::mpsc::channel(4);
116        let model_name = model_request.model_name.clone();
117
118        // Convert gRPC provider to our enum
119        let provider: ModelProvider =
120            modelexpress_common::grpc::model::ModelProvider::try_from(model_request.provider)
121                .unwrap_or(modelexpress_common::grpc::model::ModelProvider::HuggingFace)
122                .into();
123        let ignore_weights = model_request.ignore_weights;
124
125        // Spawn a task to handle the streaming download updates
126        tokio::spawn(async move {
127            // Check if the model is already downloaded
128            if let Some(status) = MODEL_TRACKER.get_status(&model_name) {
129                let update = ModelStatusUpdate {
130                    model_name: model_name.clone(),
131                    status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
132                    message: match status {
133                        ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
134                        ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
135                        ModelStatus::ERROR => {
136                            Some("Previous download failed - retrying".to_string())
137                        }
138                    },
139                    provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
140                        as i32,
141                };
142
143                if tx.send(Ok(update)).await.is_err() {
144                    return; // Client disconnected
145                }
146
147                // If already downloaded, we're done
148                if status == ModelStatus::DOWNLOADED {
149                    return;
150                }
151            }
152
153            // Start or monitor the download process
154            let final_status = MODEL_TRACKER
155                .ensure_model_downloaded(&model_name, provider, &tx, ignore_weights)
156                .await;
157
158            // Send final status update
159            let final_update = ModelStatusUpdate {
160                model_name: model_name.clone(),
161                status: modelexpress_common::grpc::model::ModelStatus::from(final_status) as i32,
162                message: match final_status {
163                    ModelStatus::DOWNLOADED => {
164                        Some("Model download completed successfully".to_string())
165                    }
166                    ModelStatus::ERROR => Some("Model download failed".to_string()),
167                    ModelStatus::DOWNLOADING => Some("Download still in progress".to_string()),
168                },
169                provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
170            };
171
172            let _ = tx.send(Ok(final_update)).await;
173        });
174
175        Ok(Response::new(ReceiverStream::new(rx)))
176    }
177}
178
/// Type alias for the complex waiting channels type
///
/// Maps a model name to the senders of the client streams waiting for that
/// model's status updates. The map sits behind `Arc<Mutex<..>>`, so every clone
/// of a [`ModelDownloadTracker`] shares the same map.
type WaitingChannels =
    Arc<Mutex<HashMap<String, Vec<tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>>>>>;

/// Tracks the status of model downloads using `SQLite` for persistence
///
/// Statuses are read from and written to the database; the in-memory
/// `waiting_channels` map only holds client streams that should receive live
/// updates for an in-progress download.
#[derive(Debug, Clone)]
pub struct ModelDownloadTracker {
    /// `SQLite` database for persistent model status tracking
    database: ModelDatabase,
    /// Maps model names to list of channels waiting for updates
    waiting_channels: WaitingChannels,
}
191
192impl Default for ModelDownloadTracker {
193    fn default() -> Self {
194        Self::new()
195    }
196}
197
198impl ModelDownloadTracker {
199    #[must_use]
200    pub fn new() -> Self {
201        // Initialize database in the current directory
202        let database = match ModelDatabase::new("./models.db") {
203            Ok(db) => db,
204            Err(e) => {
205                error!("Critical error: Could not initialize model database at ./models.db: {e}");
206                panic!("Critical error: Could not initialize model database at ./models.db");
207            }
208        };
209
210        Self {
211            database,
212            waiting_channels: Arc::new(Mutex::new(HashMap::new())),
213        }
214    }
215
216    /// Gets the status of a model from the database
217    /// If the model is not in the database, it returns None
218    pub fn get_status(&self, model_name: &str) -> Option<ModelStatus> {
219        match self.database.get_status(model_name) {
220            Ok(status) => {
221                // Update last_used_at when checking status
222                if status.is_some() {
223                    let _ = self.database.touch_model(model_name);
224                }
225                status
226            }
227            Err(e) => {
228                error!("Failed to get model status from database: {}", e);
229                None
230            }
231        }
232    }
233
234    /// Sets the status of a model and notifies all waiting channels
235    pub fn set_status_and_notify(
236        &self,
237        model_name: String,
238        status: ModelStatus,
239        provider: ModelProvider,
240        message: Option<String>,
241    ) {
242        // Update status in database
243        if let Err(e) = self
244            .database
245            .set_status(&model_name, provider, status, message.clone())
246        {
247            error!("Failed to update model status in database: {}", e);
248            return;
249        }
250
251        // Notify all waiting channels
252        let mut waiting = match self.waiting_channels.lock() {
253            Ok(guard) => guard,
254            Err(poisoned) => {
255                error!("Waiting channels mutex is poisoned, recovering");
256                poisoned.into_inner()
257            }
258        };
259        if let Some(channels) = waiting.get(&model_name) {
260            let update = ModelStatusUpdate {
261                model_name: model_name.clone(),
262                status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
263                message,
264                provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
265            };
266
267            for channel in channels {
268                let _ = channel.try_send(Ok(update.clone()));
269            }
270
271            // If the model is downloaded or errored, remove all waiting channels
272            if status == ModelStatus::DOWNLOADED || status == ModelStatus::ERROR {
273                waiting.remove(&model_name);
274            }
275        }
276    }
277
278    /// Sets the status of a model
279    pub fn set_status(&self, model_name: String, status: ModelStatus, provider: ModelProvider) {
280        self.set_status_and_notify(model_name, status, provider, None);
281    }
282
283    /// Adds a channel to wait for updates on a specific model
284    pub fn add_waiting_channel(
285        &self,
286        model_name: &str,
287        tx: tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
288    ) {
289        let mut waiting = match self.waiting_channels.lock() {
290            Ok(guard) => guard,
291            Err(poisoned) => {
292                error!("Waiting channels mutex is poisoned, recovering");
293                poisoned.into_inner()
294            }
295        };
296        waiting.entry(model_name.to_string()).or_default().push(tx);
297    }
298
299    /// Deletes the status of a model from the database
300    /// This is used when a model is removed from the tracker
301    pub fn delete_status(&self, model_name: &str) {
302        if let Err(e) = self.database.delete_model(model_name) {
303            error!("Failed to delete model from database: {}", e);
304        }
305        let mut waiting = match self.waiting_channels.lock() {
306            Ok(guard) => guard,
307            Err(poisoned) => {
308                error!("Waiting channels mutex is poisoned, recovering");
309                poisoned.into_inner()
310            }
311        };
312        waiting.remove(model_name);
313    }
314
315    /// Initiates a download for a model and streams status updates
316    pub async fn ensure_model_downloaded(
317        &self,
318        model_name: &str,
319        provider: ModelProvider,
320        tx: &tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
321        ignore_weights: bool,
322    ) -> ModelStatus {
323        // Atomically try to claim this model for download using compare-and-swap
324        let status = match self.database.try_claim_for_download(model_name, provider) {
325            Ok(status) => status,
326            Err(e) => {
327                error!("Failed to claim model for download: {}", e);
328                // Send error and return
329                let error_update = ModelStatusUpdate {
330                    model_name: model_name.to_string(),
331                    status: modelexpress_common::grpc::model::ModelStatus::from(ModelStatus::ERROR)
332                        as i32,
333                    message: Some("Database error occurred".to_string()),
334                    provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
335                        as i32,
336                };
337                let _ = tx.send(Ok(error_update)).await;
338                return ModelStatus::ERROR;
339            }
340        };
341
342        // Send current status
343        let update = ModelStatusUpdate {
344            model_name: model_name.to_string(),
345            status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
346            message: match status {
347                ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
348                ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
349                ModelStatus::ERROR => Some("Previous download failed - retrying".to_string()),
350            },
351            provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
352        };
353
354        let _ = tx.send(Ok(update)).await;
355
356        // If the model already existed and is downloading, add this channel to wait for updates
357        if status == ModelStatus::DOWNLOADING {
358            self.add_waiting_channel(model_name, tx.clone());
359
360            // Check if we were the ones who just claimed it vs. it was already downloading
361            // If we just claimed it, we need to start the actual download
362            // We can determine this by checking if there are any waiting channels yet
363            let should_start_download = {
364                let waiting = match self.waiting_channels.lock() {
365                    Ok(guard) => guard,
366                    Err(poisoned) => {
367                        error!("Waiting channels mutex is poisoned, recovering");
368                        poisoned.into_inner()
369                    }
370                };
371                waiting
372                    .get(model_name)
373                    .is_none_or(|channels| channels.len() <= 1)
374            };
375
376            if should_start_download {
377                // We claimed the model, so we're responsible for downloading it
378                let tracker = self.clone();
379                let model_name_owned = model_name.to_string();
380
381                // Perform the download in the background
382                tokio::spawn(async move {
383                    let cache_dir = get_server_cache_dir();
384                    match download::download_model(
385                        &model_name_owned,
386                        provider,
387                        cache_dir,
388                        ignore_weights,
389                    )
390                    .await
391                    {
392                        Ok(_path) => {
393                            // Download completed successfully
394                            tracker.set_status_and_notify(
395                                model_name_owned,
396                                ModelStatus::DOWNLOADED,
397                                provider,
398                                Some("Model download completed successfully".to_string()),
399                            );
400                        }
401                        Err(e) => {
402                            // Download failed
403                            error!("Failed to download model {model_name_owned}: {e}");
404                            tracker.set_status_and_notify(
405                                model_name_owned,
406                                ModelStatus::ERROR,
407                                provider,
408                                Some(format!("Download failed: {e}")),
409                            );
410                        }
411                    }
412                });
413            }
414
415            // Wait for completion by monitoring the status
416            loop {
417                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
418                if let Some(current_status) = self.get_status(model_name)
419                    && current_status != ModelStatus::DOWNLOADING
420                {
421                    return current_status;
422                }
423            }
424        } else if status == ModelStatus::ERROR {
425            // If the model is in ERROR status, try to retry the download
426            // First, reset the status to DOWNLOADING
427            if let Err(e) = self.database.set_status(
428                model_name,
429                provider,
430                ModelStatus::DOWNLOADING,
431                Some("Retrying download...".to_string()),
432            ) {
433                error!("Failed to reset status for retry: {}", e);
434                return ModelStatus::ERROR;
435            }
436
437            // Add this channel to wait for updates
438            self.add_waiting_channel(model_name, tx.clone());
439
440            // Start the download
441            let tracker = self.clone();
442            let model_name_owned = model_name.to_string();
443
444            tokio::spawn(async move {
445                let cache_dir = get_server_cache_dir();
446                match download::download_model(
447                    &model_name_owned,
448                    provider,
449                    cache_dir,
450                    ignore_weights,
451                )
452                .await
453                {
454                    Ok(_path) => {
455                        // Download completed successfully
456                        tracker.set_status_and_notify(
457                            model_name_owned,
458                            ModelStatus::DOWNLOADED,
459                            provider,
460                            Some("Model download completed successfully".to_string()),
461                        );
462                    }
463                    Err(e) => {
464                        // Download failed again
465                        error!("Failed to download model {model_name_owned} on retry: {e}");
466                        tracker.set_status_and_notify(
467                            model_name_owned,
468                            ModelStatus::ERROR,
469                            provider,
470                            Some(format!("Download failed on retry: {e}")),
471                        );
472                    }
473                }
474            });
475
476            // Wait for completion by monitoring the status
477            loop {
478                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
479                if let Some(current_status) = self.get_status(model_name)
480                    && current_status != ModelStatus::DOWNLOADING
481                {
482                    return current_status;
483                }
484            }
485        }
486
487        status
488    }
489}
490
491/// Global model download tracker
492pub static MODEL_TRACKER: std::sync::LazyLock<ModelDownloadTracker> =
493    std::sync::LazyLock::new(ModelDownloadTracker::new);
494
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    // `ModelDownloadTracker::new` always opens `./models.db` in the current
    // working directory, so every test here (including those that go through
    // the global `MODEL_TRACKER`) shares one database file. Tests therefore use
    // unique model names and delete the records they create.

    use super::*;
    use modelexpress_common::grpc::{
        api::ApiRequest, health::HealthRequest, model::ModelDownloadRequest,
    };
    use tokio_stream::StreamExt;
    use tonic::Request;

    /// Returns a model name unique within this test process. The timestamp makes
    /// it very unlikely to match a record left behind by an earlier run.
    fn unique_model_name(prefix: &str) -> String {
        static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards")
            .as_nanos();
        let seq = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        format!("{prefix}-{timestamp}-{seq}")
    }

    /// Returns how many channels are currently waiting on `model_name`.
    fn waiting_count(tracker: &ModelDownloadTracker, model_name: &str) -> usize {
        let waiting = match tracker.waiting_channels.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        waiting.get(model_name).map_or(0, Vec::len)
    }

    #[tokio::test]
    async fn test_health_service() {
        let service = HealthServiceImpl;
        let request = Request::new(HealthRequest {});

        let response = service.get_health(request).await;
        assert!(response.is_ok());

        let health_response = response.expect("Health response should be ok").into_inner();
        assert_eq!(health_response.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(health_response.status, "ok");
    }

    #[tokio::test]
    async fn test_api_service_ping() {
        let service = ApiServiceImpl;
        let request = Request::new(ApiRequest {
            id: "test-id".to_string(),
            action: "ping".to_string(),
            payload: None,
        });

        let response = service.send_request(request).await;
        assert!(response.is_ok());

        let api_response = response.expect("API response should be ok").into_inner();
        assert!(api_response.success);
        assert!(api_response.data.is_some());
        assert!(api_response.error.is_none());

        // Check that the response data contains "pong"
        let data_bytes = api_response.data.expect("Data should be present");
        let data: serde_json::Value =
            serde_json::from_slice(&data_bytes).expect("Data should be valid JSON");
        assert_eq!(data["message"], "pong");
    }

    #[tokio::test]
    async fn test_api_service_unknown_action() {
        let service = ApiServiceImpl;
        let request = Request::new(ApiRequest {
            id: "test-id".to_string(),
            action: "unknown-action".to_string(),
            payload: None,
        });

        let response = service.send_request(request).await;
        assert!(response.is_ok());

        let api_response = response.expect("API response should be ok").into_inner();
        assert!(!api_response.success);
        assert!(api_response.data.is_none());
        assert!(api_response.error.is_some());

        let error_message = api_response.error.expect("Error should be present");
        assert!(error_message.contains("Unknown action"));
    }

    #[test]
    fn test_model_download_tracker_new() {
        let tracker = ModelDownloadTracker::new();

        // A model that was never recorded has no status.
        let model_name = unique_model_name("non-existent-model");
        assert!(tracker.get_status(&model_name).is_none());
    }

    #[test]
    fn test_model_download_tracker_set_and_get_status() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-model");
        let provider = ModelProvider::HuggingFace;

        // Initially should return None
        assert!(tracker.get_status(&model_name).is_none());

        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADING, provider);

        let status = tracker.get_status(&model_name);
        assert_eq!(
            status.expect("Status should be present"),
            ModelStatus::DOWNLOADING
        );

        // Cleanup
        tracker.delete_status(&model_name);
    }

    #[test]
    fn test_model_download_tracker_delete_status() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-delete-model");
        let provider = ModelProvider::HuggingFace;

        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADED, provider);
        assert!(tracker.get_status(&model_name).is_some());

        tracker.delete_status(&model_name);
        assert!(tracker.get_status(&model_name).is_none());
    }

    #[tokio::test]
    async fn test_model_service_already_downloaded() {
        let service = ModelServiceImpl;
        let model_name = unique_model_name("test-already-downloaded-model");

        // Pre-populate the model as downloaded
        MODEL_TRACKER.set_status(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
        );

        let request = Request::new(ModelDownloadRequest {
            model_name: model_name.clone(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            ignore_weights: false,
        });

        let response = service.ensure_model_downloaded(request).await;
        assert!(response.is_ok());

        let mut stream = response.expect("Response should be ok").into_inner();

        let status_update = stream
            .next()
            .await
            .expect("Update should be present")
            .expect("Status update should be ok");
        assert_eq!(status_update.model_name, model_name);
        assert_eq!(
            status_update.status,
            modelexpress_common::grpc::model::ModelStatus::Downloaded as i32
        );

        // An already-downloaded model produces exactly one update, then the stream ends.
        assert!(stream.next().await.is_none());

        // Cleanup
        MODEL_TRACKER.delete_status(&model_name);
    }

    #[test]
    fn test_model_download_tracker_set_status_and_notify() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-notify-model");
        let provider = ModelProvider::HuggingFace;

        // With no waiters registered this only persists the status.
        tracker.set_status_and_notify(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            provider,
            Some("Download completed".to_string()),
        );

        let status = tracker.get_status(&model_name);
        assert_eq!(
            status.expect("Status should be present"),
            ModelStatus::DOWNLOADED
        );

        // Cleanup
        tracker.delete_status(&model_name);
    }

    #[test]
    fn test_waiting_channels_management() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-channels-model");

        let (tx, _rx) = tokio::sync::mpsc::channel(4);

        tracker.add_waiting_channel(&model_name, tx);
        assert_eq!(waiting_count(&tracker, &model_name), 1);

        // A terminal status notifies and clears the waiters.
        tracker.set_status_and_notify(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
            None,
        );
        assert_eq!(waiting_count(&tracker, &model_name), 0);

        // Cleanup
        tracker.delete_status(&model_name);
    }

    #[tokio::test]
    async fn test_model_service_stream_closes_properly() {
        // This goes through the real download path with a made-up model name;
        // the test only checks that updates arrive and the stream terminates.
        let service = ModelServiceImpl;
        let model_name = unique_model_name("test-stream-model");

        let request = Request::new(ModelDownloadRequest {
            model_name: model_name.clone(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            ignore_weights: false,
        });

        let response = service.ensure_model_downloaded(request).await;
        assert!(response.is_ok());

        let mut stream = response.expect("Response should be ok").into_inner();

        // Read a few updates (may include initial status and progress)
        let mut update_count = 0;
        while let Some(update) = stream.next().await {
            assert!(update.is_ok());
            update_count += 1;

            // Prevent infinite loop in case of issues
            if update_count > 10 {
                break;
            }
        }

        assert!(update_count > 0);

        // Cleanup
        MODEL_TRACKER.delete_status(&model_name);
    }

    #[tokio::test]
    async fn test_concurrent_model_download_no_race_condition() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-concurrent-model");
        let provider = ModelProvider::HuggingFace;

        // First attempt should claim the model
        let status1 = tracker
            .database
            .try_claim_for_download(&model_name, provider)
            .expect("Failed to claim for download 1");
        assert_eq!(status1, ModelStatus::DOWNLOADING);

        // Second attempt should see it's already claimed
        let status2 = tracker
            .database
            .try_claim_for_download(&model_name, provider)
            .expect("Failed to claim for download 2");
        assert_eq!(status2, ModelStatus::DOWNLOADING);

        // Verify only one record exists
        let record = tracker
            .database
            .get_model_record(&model_name)
            .expect("Failed to get model record")
            .expect("Record should exist");
        assert_eq!(record.status, ModelStatus::DOWNLOADING);

        // Cleanup
        tracker.delete_status(&model_name);
    }
}