1use crate::database::ModelDatabase;
5use modelexpress_common::{
6 cache::CacheConfig,
7 download,
8 grpc::{
9 api::{ApiRequest, ApiResponse, api_service_server::ApiService},
10 health::{HealthRequest, HealthResponse, health_service_server::HealthService},
11 model::{ModelDownloadRequest, ModelStatusUpdate, model_service_server::ModelService},
12 },
13 models::{ModelProvider, ModelStatus},
14};
15use std::{
16 collections::HashMap,
17 sync::{Arc, Mutex},
18 time::SystemTime,
19};
20use tokio_stream::wrappers::ReceiverStream;
21use tonic::{Request, Response, Status};
22use tracing::{error, info};
23
/// Reference instant for the uptime reported by the health endpoint.
///
/// Starts out unset; `get_health` fills it with the current time on first use if nothing
/// has initialized it earlier.
static START_TIME: std::sync::OnceLock<SystemTime> = std::sync::OnceLock::new();
25
26fn get_server_cache_dir() -> Option<std::path::PathBuf> {
28 if let Ok(config) = CacheConfig::discover() {
30 Some(config.local_path)
31 } else {
32 std::env::var("HF_HUB_CACHE")
34 .ok()
35 .map(std::path::PathBuf::from)
36 }
37}
38
39#[derive(Debug, Default)]
41pub struct HealthServiceImpl;
42
43#[tonic::async_trait]
44impl HealthService for HealthServiceImpl {
45 async fn get_health(
46 &self,
47 _request: Request<HealthRequest>,
48 ) -> Result<Response<HealthResponse>, Status> {
49 let start_time = START_TIME.get_or_init(SystemTime::now);
50 let uptime = SystemTime::now()
51 .duration_since(*start_time)
52 .unwrap_or_default()
53 .as_secs();
54
55 let response = HealthResponse {
56 version: env!("CARGO_PKG_VERSION").to_string(),
57 status: "ok".to_string(),
58 uptime,
59 };
60
61 Ok(Response::new(response))
62 }
63}
64
65#[derive(Debug, Default)]
67pub struct ApiServiceImpl;
68
69#[tonic::async_trait]
70impl ApiService for ApiServiceImpl {
71 async fn send_request(
72 &self,
73 request: Request<ApiRequest>,
74 ) -> Result<Response<ApiResponse>, Status> {
75 let api_request = request.into_inner();
76 info!("Received gRPC request: {:?}", api_request);
77
78 if api_request.action.as_str() == "ping" {
80 info!("Processing ping request");
81 let response_data = serde_json::json!({ "message": "pong" });
82 let data_bytes = serde_json::to_vec(&response_data)
83 .map_err(|e| Status::internal(format!("Serialization error: {e}")))?;
84
85 Ok(Response::new(ApiResponse {
86 success: true,
87 data: Some(data_bytes),
88 error: None,
89 }))
90 } else {
91 error!("Unknown action: {}", api_request.action);
92 Ok(Response::new(ApiResponse {
93 success: false,
94 data: None,
95 error: Some(format!("Unknown action: {}", api_request.action)),
96 }))
97 }
98 }
99}
100
101#[derive(Debug, Default)]
103pub struct ModelServiceImpl;
104
105#[tonic::async_trait]
106impl ModelService for ModelServiceImpl {
107 type EnsureModelDownloadedStream = ReceiverStream<Result<ModelStatusUpdate, Status>>;
108
109 async fn ensure_model_downloaded(
110 &self,
111 request: Request<ModelDownloadRequest>,
112 ) -> Result<Response<Self::EnsureModelDownloadedStream>, Status> {
113 info!("Starting model download stream");
114 let model_request = request.into_inner();
115 let (tx, rx) = tokio::sync::mpsc::channel(4);
116 let model_name = model_request.model_name.clone();
117
118 let provider: ModelProvider =
120 modelexpress_common::grpc::model::ModelProvider::try_from(model_request.provider)
121 .unwrap_or(modelexpress_common::grpc::model::ModelProvider::HuggingFace)
122 .into();
123 let ignore_weights = model_request.ignore_weights;
124
125 tokio::spawn(async move {
127 if let Some(status) = MODEL_TRACKER.get_status(&model_name) {
129 let update = ModelStatusUpdate {
130 model_name: model_name.clone(),
131 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
132 message: match status {
133 ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
134 ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
135 ModelStatus::ERROR => {
136 Some("Previous download failed - retrying".to_string())
137 }
138 },
139 provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
140 as i32,
141 };
142
143 if tx.send(Ok(update)).await.is_err() {
144 return; }
146
147 if status == ModelStatus::DOWNLOADED {
149 return;
150 }
151 }
152
153 let final_status = MODEL_TRACKER
155 .ensure_model_downloaded(&model_name, provider, &tx, ignore_weights)
156 .await;
157
158 let final_update = ModelStatusUpdate {
160 model_name: model_name.clone(),
161 status: modelexpress_common::grpc::model::ModelStatus::from(final_status) as i32,
162 message: match final_status {
163 ModelStatus::DOWNLOADED => {
164 Some("Model download completed successfully".to_string())
165 }
166 ModelStatus::ERROR => Some("Model download failed".to_string()),
167 ModelStatus::DOWNLOADING => Some("Download still in progress".to_string()),
168 },
169 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
170 };
171
172 let _ = tx.send(Ok(final_update)).await;
173 });
174
175 Ok(Response::new(ReceiverStream::new(rx)))
176 }
177}
178
/// Shared, mutex-guarded map from a model name to the senders registered to receive that
/// model's status updates.
type WaitingChannels =
    Arc<Mutex<HashMap<String, Vec<tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>>>>>;
182
/// Tracks per-model download status in a [`ModelDatabase`] and forwards status changes to
/// the channels registered for each model.
#[derive(Debug, Clone)]
pub struct ModelDownloadTracker {
    /// Store that is queried and updated for per-model status records.
    database: ModelDatabase,
    /// Senders to notify on status changes, keyed by model name. The map sits behind an
    /// `Arc`, so clones of the tracker share it.
    waiting_channels: WaitingChannels,
}

impl Default for ModelDownloadTracker {
    /// Equivalent to [`ModelDownloadTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
197
198impl ModelDownloadTracker {
199 #[must_use]
200 pub fn new() -> Self {
201 let database = match ModelDatabase::new("./models.db") {
203 Ok(db) => db,
204 Err(e) => {
205 error!("Critical error: Could not initialize model database at ./models.db: {e}");
206 panic!("Critical error: Could not initialize model database at ./models.db");
207 }
208 };
209
210 Self {
211 database,
212 waiting_channels: Arc::new(Mutex::new(HashMap::new())),
213 }
214 }
215
216 pub fn get_status(&self, model_name: &str) -> Option<ModelStatus> {
219 match self.database.get_status(model_name) {
220 Ok(status) => {
221 if status.is_some() {
223 let _ = self.database.touch_model(model_name);
224 }
225 status
226 }
227 Err(e) => {
228 error!("Failed to get model status from database: {}", e);
229 None
230 }
231 }
232 }
233
234 pub fn set_status_and_notify(
236 &self,
237 model_name: String,
238 status: ModelStatus,
239 provider: ModelProvider,
240 message: Option<String>,
241 ) {
242 if let Err(e) = self
244 .database
245 .set_status(&model_name, provider, status, message.clone())
246 {
247 error!("Failed to update model status in database: {}", e);
248 return;
249 }
250
251 let mut waiting = match self.waiting_channels.lock() {
253 Ok(guard) => guard,
254 Err(poisoned) => {
255 error!("Waiting channels mutex is poisoned, recovering");
256 poisoned.into_inner()
257 }
258 };
259 if let Some(channels) = waiting.get(&model_name) {
260 let update = ModelStatusUpdate {
261 model_name: model_name.clone(),
262 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
263 message,
264 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
265 };
266
267 for channel in channels {
268 let _ = channel.try_send(Ok(update.clone()));
269 }
270
271 if status == ModelStatus::DOWNLOADED || status == ModelStatus::ERROR {
273 waiting.remove(&model_name);
274 }
275 }
276 }
277
/// Stores `status` for `model_name` and notifies registered channels, without attaching a
/// message. Shorthand for `set_status_and_notify` with `message` set to `None`.
pub fn set_status(&self, model_name: String, status: ModelStatus, provider: ModelProvider) {
    self.set_status_and_notify(model_name, status, provider, None);
}
282
283 pub fn add_waiting_channel(
285 &self,
286 model_name: &str,
287 tx: tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
288 ) {
289 let mut waiting = match self.waiting_channels.lock() {
290 Ok(guard) => guard,
291 Err(poisoned) => {
292 error!("Waiting channels mutex is poisoned, recovering");
293 poisoned.into_inner()
294 }
295 };
296 waiting.entry(model_name.to_string()).or_default().push(tx);
297 }
298
299 pub fn delete_status(&self, model_name: &str) {
302 if let Err(e) = self.database.delete_model(model_name) {
303 error!("Failed to delete model from database: {}", e);
304 }
305 let mut waiting = match self.waiting_channels.lock() {
306 Ok(guard) => guard,
307 Err(poisoned) => {
308 error!("Waiting channels mutex is poisoned, recovering");
309 poisoned.into_inner()
310 }
311 };
312 waiting.remove(model_name);
313 }
314
315 pub async fn ensure_model_downloaded(
317 &self,
318 model_name: &str,
319 provider: ModelProvider,
320 tx: &tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
321 ignore_weights: bool,
322 ) -> ModelStatus {
323 let status = match self.database.try_claim_for_download(model_name, provider) {
325 Ok(status) => status,
326 Err(e) => {
327 error!("Failed to claim model for download: {}", e);
328 let error_update = ModelStatusUpdate {
330 model_name: model_name.to_string(),
331 status: modelexpress_common::grpc::model::ModelStatus::from(ModelStatus::ERROR)
332 as i32,
333 message: Some("Database error occurred".to_string()),
334 provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
335 as i32,
336 };
337 let _ = tx.send(Ok(error_update)).await;
338 return ModelStatus::ERROR;
339 }
340 };
341
342 let update = ModelStatusUpdate {
344 model_name: model_name.to_string(),
345 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
346 message: match status {
347 ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
348 ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
349 ModelStatus::ERROR => Some("Previous download failed - retrying".to_string()),
350 },
351 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
352 };
353
354 let _ = tx.send(Ok(update)).await;
355
356 if status == ModelStatus::DOWNLOADING {
358 self.add_waiting_channel(model_name, tx.clone());
359
360 let should_start_download = {
364 let waiting = match self.waiting_channels.lock() {
365 Ok(guard) => guard,
366 Err(poisoned) => {
367 error!("Waiting channels mutex is poisoned, recovering");
368 poisoned.into_inner()
369 }
370 };
371 waiting
372 .get(model_name)
373 .is_none_or(|channels| channels.len() <= 1)
374 };
375
376 if should_start_download {
377 let tracker = self.clone();
379 let model_name_owned = model_name.to_string();
380
381 tokio::spawn(async move {
383 let cache_dir = get_server_cache_dir();
384 match download::download_model(
385 &model_name_owned,
386 provider,
387 cache_dir,
388 ignore_weights,
389 )
390 .await
391 {
392 Ok(_path) => {
393 tracker.set_status_and_notify(
395 model_name_owned,
396 ModelStatus::DOWNLOADED,
397 provider,
398 Some("Model download completed successfully".to_string()),
399 );
400 }
401 Err(e) => {
402 error!("Failed to download model {model_name_owned}: {e}");
404 tracker.set_status_and_notify(
405 model_name_owned,
406 ModelStatus::ERROR,
407 provider,
408 Some(format!("Download failed: {e}")),
409 );
410 }
411 }
412 });
413 }
414
415 loop {
417 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
418 if let Some(current_status) = self.get_status(model_name)
419 && current_status != ModelStatus::DOWNLOADING
420 {
421 return current_status;
422 }
423 }
424 } else if status == ModelStatus::ERROR {
425 if let Err(e) = self.database.set_status(
428 model_name,
429 provider,
430 ModelStatus::DOWNLOADING,
431 Some("Retrying download...".to_string()),
432 ) {
433 error!("Failed to reset status for retry: {}", e);
434 return ModelStatus::ERROR;
435 }
436
437 self.add_waiting_channel(model_name, tx.clone());
439
440 let tracker = self.clone();
442 let model_name_owned = model_name.to_string();
443
444 tokio::spawn(async move {
445 let cache_dir = get_server_cache_dir();
446 match download::download_model(
447 &model_name_owned,
448 provider,
449 cache_dir,
450 ignore_weights,
451 )
452 .await
453 {
454 Ok(_path) => {
455 tracker.set_status_and_notify(
457 model_name_owned,
458 ModelStatus::DOWNLOADED,
459 provider,
460 Some("Model download completed successfully".to_string()),
461 );
462 }
463 Err(e) => {
464 error!("Failed to download model {model_name_owned} on retry: {e}");
466 tracker.set_status_and_notify(
467 model_name_owned,
468 ModelStatus::ERROR,
469 provider,
470 Some(format!("Download failed on retry: {e}")),
471 );
472 }
473 }
474 });
475
476 loop {
478 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
479 if let Some(current_status) = self.get_status(model_name)
480 && current_status != ModelStatus::DOWNLOADING
481 {
482 return current_status;
483 }
484 }
485 }
486
487 status
488 }
489}
490
/// Process-wide tracker instance, built with [`ModelDownloadTracker::new`] on first access.
pub static MODEL_TRACKER: std::sync::LazyLock<ModelDownloadTracker> =
    std::sync::LazyLock::new(ModelDownloadTracker::new);
494
495#[cfg(test)]
496#[allow(clippy::expect_used)]
497mod tests {
498 use super::*;
499 use modelexpress_common::grpc::{
500 api::ApiRequest, health::HealthRequest, model::ModelDownloadRequest,
501 };
502 use tempfile::TempDir;
503 use tokio_stream::StreamExt;
504 use tonic::Request;
505
506 #[tokio::test]
507 async fn test_health_service() {
508 let service = HealthServiceImpl;
509 let request = Request::new(HealthRequest {});
510
511 let response = service.get_health(request).await;
512 assert!(response.is_ok());
513
514 let health_response = response.expect("Health response should be ok").into_inner();
515 assert_eq!(health_response.version, env!("CARGO_PKG_VERSION"));
516 assert_eq!(health_response.status, "ok");
517 let _uptime = health_response.uptime;
519 }
520
521 #[tokio::test]
522 async fn test_api_service_ping() {
523 let service = ApiServiceImpl;
524 let request = Request::new(ApiRequest {
525 id: "test-id".to_string(),
526 action: "ping".to_string(),
527 payload: None,
528 });
529
530 let response = service.send_request(request).await;
531 assert!(response.is_ok());
532
533 let api_response = response.expect("API response should be ok").into_inner();
534 assert!(api_response.success);
535 assert!(api_response.data.is_some());
536 assert!(api_response.error.is_none());
537
538 let data_bytes = api_response.data.expect("Data should be present");
540 let data: serde_json::Value =
541 serde_json::from_slice(&data_bytes).expect("Data should be valid JSON");
542 assert_eq!(data["message"], "pong");
543 }
544
545 #[tokio::test]
546 async fn test_api_service_unknown_action() {
547 let service = ApiServiceImpl;
548 let request = Request::new(ApiRequest {
549 id: "test-id".to_string(),
550 action: "unknown-action".to_string(),
551 payload: None,
552 });
553
554 let response = service.send_request(request).await;
555 assert!(response.is_ok());
556
557 let api_response = response.expect("API response should be ok").into_inner();
558 assert!(!api_response.success);
559 assert!(api_response.data.is_none());
560 assert!(api_response.error.is_some());
561
562 let error_message = api_response.error.expect("Error should be present");
563 assert!(error_message.contains("Unknown action"));
564 }
565
/// A freshly created tracker reports no status for a model it has never seen.
#[test]
fn test_model_download_tracker_new() {
    let _temp_dir = TempDir::new().expect("Failed to create temp dir");
    let tracker = ModelDownloadTracker::new();

    assert!(tracker.get_status("non-existent-model").is_none());
}
575
/// A status written with `set_status` is returned by `get_status`.
#[test]
fn test_model_download_tracker_set_and_get_status() {
    let _temp_dir = TempDir::new().expect("Failed to create temp dir");
    let tracker = ModelDownloadTracker::new();

    // A nanosecond timestamp keeps the model name unique across runs.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("Time went backwards")
        .as_nanos();
    let model_name = format!("test-model-{nanos}");

    assert!(tracker.get_status(&model_name).is_none());

    tracker.set_status(
        model_name.clone(),
        ModelStatus::DOWNLOADING,
        ModelProvider::HuggingFace,
    );

    let stored = tracker.get_status(&model_name);
    assert!(stored.is_some());
    assert_eq!(
        stored.expect("Status should be present"),
        ModelStatus::DOWNLOADING
    );

    tracker.delete_status(&model_name);
}
606
/// After `delete_status`, a previously stored model has no status any more.
#[test]
fn test_model_download_tracker_delete_status() {
    let _temp_dir = TempDir::new().expect("Failed to create temp dir");
    let tracker = ModelDownloadTracker::new();

    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("Time went backwards")
        .as_nanos();
    let model_name = format!("test-delete-model-{nanos}");

    tracker.set_status(
        model_name.clone(),
        ModelStatus::DOWNLOADED,
        ModelProvider::HuggingFace,
    );
    assert!(tracker.get_status(&model_name).is_some());

    tracker.delete_status(&model_name);
    assert!(tracker.get_status(&model_name).is_none());
}
626
627 #[tokio::test]
628 async fn test_model_service_already_downloaded() {
629 let service = ModelServiceImpl;
630 let timestamp = std::time::SystemTime::now()
631 .duration_since(std::time::UNIX_EPOCH)
632 .expect("Time went backwards")
633 .as_nanos();
634 let model_name = format!("test-already-downloaded-model-{timestamp}");
635
636 MODEL_TRACKER.set_status(
638 model_name.clone(),
639 ModelStatus::DOWNLOADED,
640 ModelProvider::HuggingFace,
641 );
642
643 let request = Request::new(ModelDownloadRequest {
644 model_name: model_name.clone(),
645 provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
646 ignore_weights: false,
647 });
648
649 let response = service.ensure_model_downloaded(request).await;
650 assert!(response.is_ok());
651
652 let mut stream = response.expect("Response should be ok").into_inner();
653
654 let update = stream.next().await;
656 assert!(update.is_some());
657
658 let update = update.expect("Update should be present");
659 assert!(update.is_ok());
660
661 let status_update = update.expect("Status update should be ok");
662 assert_eq!(status_update.model_name, model_name);
663 assert_eq!(
664 status_update.status,
665 modelexpress_common::grpc::model::ModelStatus::Downloaded as i32
666 );
667
668 MODEL_TRACKER.delete_status(&model_name);
670 }
671
/// `set_status_and_notify` stores the given status even when no channel is registered.
#[test]
fn test_model_download_tracker_set_status_and_notify() {
    let _temp_dir = TempDir::new().expect("Failed to create temp dir");
    let tracker = ModelDownloadTracker::new();
    let model_name = "test-notify-model".to_string();
    let provider = ModelProvider::HuggingFace;

    tracker.set_status_and_notify(
        model_name.clone(),
        ModelStatus::DOWNLOADED,
        provider,
        Some("Download completed".to_string()),
    );

    let status = tracker.get_status(&model_name);
    // Remove the row before asserting so a failed assertion cannot leave it behind in the
    // database file shared with the other tests.
    tracker.delete_status(&model_name);

    assert!(status.is_some());
    assert_eq!(
        status.expect("Status should be present"),
        ModelStatus::DOWNLOADED
    );
}
695
/// A registered channel is counted for its model and is dropped again once a terminal
/// status has been published.
#[test]
fn test_waiting_channels_management() {
    let _temp_dir = TempDir::new().expect("Failed to create temp dir");
    let tracker = ModelDownloadTracker::new();
    let model_name = "test-channels-model";

    // Number of channels currently registered for `model_name`; tolerates a poisoned lock.
    let registered_count = || {
        let waiting = tracker
            .waiting_channels
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        waiting.get(model_name).map_or(0, Vec::len)
    };

    let (tx, _rx) = tokio::sync::mpsc::channel(4);
    tracker.add_waiting_channel(model_name, tx);
    assert_eq!(registered_count(), 1);

    tracker.set_status_and_notify(
        model_name.to_string(),
        ModelStatus::DOWNLOADED,
        ModelProvider::HuggingFace,
        None,
    );
    let remaining = registered_count();

    // The notify call above also wrote a row to the database file shared with the other
    // tests; remove it so the fixed model name does not leak into later runs.
    tracker.delete_status(model_name);

    assert_eq!(remaining, 0);
}
735
736 #[tokio::test]
737 async fn test_model_service_stream_closes_properly() {
738 let service = ModelServiceImpl;
739 let model_name = "test-stream-model";
740
741 let request = Request::new(ModelDownloadRequest {
742 model_name: model_name.to_string(),
743 provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
744 ignore_weights: false,
745 });
746
747 let response = service.ensure_model_downloaded(request).await;
748 assert!(response.is_ok());
749
750 let mut stream = response.expect("Response should be ok").into_inner();
751
752 let mut update_count = 0;
754 while let Some(update) = stream.next().await {
755 assert!(update.is_ok());
756 update_count += 1;
757
758 if update_count > 10 {
760 break;
761 }
762 }
763
764 assert!(update_count > 0);
765
766 MODEL_TRACKER.delete_status(model_name);
768 }
769
770 #[tokio::test]
771 async fn test_concurrent_model_download_no_race_condition() {
772 let _temp_dir = TempDir::new().expect("Failed to create temp dir");
773 let tracker = ModelDownloadTracker::new();
774 let model_name = "test-concurrent-model";
775 let provider = ModelProvider::HuggingFace;
776
777 let status1 = tracker
780 .database
781 .try_claim_for_download(model_name, provider)
782 .expect("Failed to claim for download 1");
783 assert_eq!(status1, ModelStatus::DOWNLOADING);
784
785 let status2 = tracker
787 .database
788 .try_claim_for_download(model_name, provider)
789 .expect("Failed to claim for download 2");
790 assert_eq!(status2, ModelStatus::DOWNLOADING);
791
792 let record = tracker
794 .database
795 .get_model_record(model_name)
796 .expect("Failed to get model record")
797 .expect("Record should exist");
798 assert_eq!(record.status, ModelStatus::DOWNLOADING);
799
800 tracker.delete_status(model_name);
802 }
803}