1use crate::database::ModelDatabase;
5use modelexpress_common::{
6 cache::CacheConfig,
7 download,
8 grpc::{
9 api::{ApiRequest, ApiResponse, api_service_server::ApiService},
10 health::{HealthRequest, HealthResponse, health_service_server::HealthService},
11 model::{ModelDownloadRequest, ModelStatusUpdate, model_service_server::ModelService},
12 },
13 models::{ModelProvider, ModelStatus},
14};
15use std::{
16 collections::HashMap,
17 sync::{Arc, Mutex},
18 time::SystemTime,
19};
20use tokio_stream::wrappers::ReceiverStream;
21use tonic::{Request, Response, Status};
22use tracing::{error, info};
23
/// Reference point in time used to compute the uptime reported by the health endpoint.
///
/// The cell is filled on first use via `get_or_init`. In the code visible here that
/// happens inside `HealthServiceImpl::get_health`, so the reported uptime counts from
/// the first health check rather than from process start.
/// NOTE(review): confirm whether anything else initializes this earlier at startup.
static START_TIME: std::sync::OnceLock<SystemTime> = std::sync::OnceLock::new();
25
26fn get_server_cache_dir() -> Option<std::path::PathBuf> {
28 if let Ok(config) = CacheConfig::discover() {
30 Some(config.local_path)
31 } else {
32 std::env::var("HF_HUB_CACHE")
34 .ok()
35 .map(std::path::PathBuf::from)
36 }
37}
38
39#[derive(Debug, Default)]
41pub struct HealthServiceImpl;
42
43#[tonic::async_trait]
44impl HealthService for HealthServiceImpl {
45 async fn get_health(
46 &self,
47 _request: Request<HealthRequest>,
48 ) -> Result<Response<HealthResponse>, Status> {
49 let start_time = START_TIME.get_or_init(SystemTime::now);
50 let uptime = SystemTime::now()
51 .duration_since(*start_time)
52 .unwrap_or_default()
53 .as_secs();
54
55 let response = HealthResponse {
56 version: env!("CARGO_PKG_VERSION").to_string(),
57 status: "ok".to_string(),
58 uptime,
59 };
60
61 Ok(Response::new(response))
62 }
63}
64
65#[derive(Debug, Default)]
67pub struct ApiServiceImpl;
68
69#[tonic::async_trait]
70impl ApiService for ApiServiceImpl {
71 async fn send_request(
72 &self,
73 request: Request<ApiRequest>,
74 ) -> Result<Response<ApiResponse>, Status> {
75 let api_request = request.into_inner();
76 info!("Received gRPC request: {:?}", api_request);
77
78 if api_request.action.as_str() == "ping" {
80 info!("Processing ping request");
81 let response_data = serde_json::json!({ "message": "pong" });
82 let data_bytes = serde_json::to_vec(&response_data)
83 .map_err(|e| Status::internal(format!("Serialization error: {e}")))?;
84
85 Ok(Response::new(ApiResponse {
86 success: true,
87 data: Some(data_bytes),
88 error: None,
89 }))
90 } else {
91 error!("Unknown action: {}", api_request.action);
92 Ok(Response::new(ApiResponse {
93 success: false,
94 data: None,
95 error: Some(format!("Unknown action: {}", api_request.action)),
96 }))
97 }
98 }
99}
100
101#[derive(Debug, Default)]
103pub struct ModelServiceImpl;
104
105#[tonic::async_trait]
106impl ModelService for ModelServiceImpl {
107 type EnsureModelDownloadedStream = ReceiverStream<Result<ModelStatusUpdate, Status>>;
108
109 async fn ensure_model_downloaded(
110 &self,
111 request: Request<ModelDownloadRequest>,
112 ) -> Result<Response<Self::EnsureModelDownloadedStream>, Status> {
113 let model_request = request.into_inner();
114 info!(
115 "Starting model download stream for: {} from provider: {:?}",
116 model_request.model_name, model_request.provider
117 );
118
119 let (tx, rx) = tokio::sync::mpsc::channel(4);
120 let model_name = model_request.model_name.clone();
121
122 let provider: ModelProvider =
124 modelexpress_common::grpc::model::ModelProvider::try_from(model_request.provider)
125 .unwrap_or(modelexpress_common::grpc::model::ModelProvider::HuggingFace)
126 .into();
127 let ignore_weights = model_request.ignore_weights;
128
129 tokio::spawn(async move {
131 if let Some(status) = MODEL_TRACKER.get_status(&model_name) {
133 let update = ModelStatusUpdate {
134 model_name: model_name.clone(),
135 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
136 message: match status {
137 ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
138 ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
139 ModelStatus::ERROR => {
140 Some("Previous download failed - retrying".to_string())
141 }
142 },
143 provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
144 as i32,
145 };
146
147 if tx.send(Ok(update)).await.is_err() {
148 return; }
150
151 if status == ModelStatus::DOWNLOADED {
153 return;
154 }
155 }
156
157 let final_status = MODEL_TRACKER
159 .ensure_model_downloaded(&model_name, provider, &tx, ignore_weights)
160 .await;
161
162 let final_update = ModelStatusUpdate {
164 model_name: model_name.clone(),
165 status: modelexpress_common::grpc::model::ModelStatus::from(final_status) as i32,
166 message: match final_status {
167 ModelStatus::DOWNLOADED => {
168 Some("Model download completed successfully".to_string())
169 }
170 ModelStatus::ERROR => Some("Model download failed".to_string()),
171 ModelStatus::DOWNLOADING => Some("Download still in progress".to_string()),
172 },
173 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
174 };
175
176 let _ = tx.send(Ok(final_update)).await;
177 });
178
179 Ok(Response::new(ReceiverStream::new(rx)))
180 }
181}
182
/// Shared, mutex-protected registry of stream subscribers: maps a model name to the
/// senders of every status stream currently waiting on that model.
type WaitingChannels =
    Arc<Mutex<HashMap<String, Vec<tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>>>>>;
186
/// Tracks model download status and fans status changes out to waiting streams.
///
/// Cloning yields a handle that shares the same waiting-channel registry (it lives
/// behind an `Arc`).
/// NOTE(review): whether a cloned `ModelDatabase` shares the same underlying
/// connection is not visible here — confirm.
#[derive(Debug, Clone)]
pub struct ModelDownloadTracker {
    /// Persistent store for per-model status records.
    database: ModelDatabase,
    /// Senders to notify when a model's status changes, keyed by model name.
    waiting_channels: WaitingChannels,
}
195
196impl Default for ModelDownloadTracker {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
202impl ModelDownloadTracker {
203 #[must_use]
204 pub fn new() -> Self {
205 let database = match ModelDatabase::new("./models.db") {
207 Ok(db) => db,
208 Err(e) => {
209 error!("Critical error: Could not initialize model database at ./models.db: {e}");
210 panic!("Critical error: Could not initialize model database at ./models.db");
211 }
212 };
213
214 Self {
215 database,
216 waiting_channels: Arc::new(Mutex::new(HashMap::new())),
217 }
218 }
219
220 pub fn get_status(&self, model_name: &str) -> Option<ModelStatus> {
223 match self.database.get_status(model_name) {
224 Ok(status) => {
225 if status.is_some() {
227 let _ = self.database.touch_model(model_name);
228 }
229 status
230 }
231 Err(e) => {
232 error!("Failed to get model status from database: {}", e);
233 None
234 }
235 }
236 }
237
238 pub fn set_status_and_notify(
240 &self,
241 model_name: String,
242 status: ModelStatus,
243 provider: ModelProvider,
244 message: Option<String>,
245 ) {
246 if let Err(e) = self
248 .database
249 .set_status(&model_name, provider, status, message.clone())
250 {
251 error!("Failed to update model status in database: {}", e);
252 return;
253 }
254
255 let mut waiting = match self.waiting_channels.lock() {
257 Ok(guard) => guard,
258 Err(poisoned) => {
259 error!("Waiting channels mutex is poisoned, recovering");
260 poisoned.into_inner()
261 }
262 };
263 if let Some(channels) = waiting.get(&model_name) {
264 let update = ModelStatusUpdate {
265 model_name: model_name.clone(),
266 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
267 message,
268 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
269 };
270
271 for channel in channels {
272 let _ = channel.try_send(Ok(update.clone()));
273 }
274
275 if status == ModelStatus::DOWNLOADED || status == ModelStatus::ERROR {
277 waiting.remove(&model_name);
278 }
279 }
280 }
281
282 pub fn set_status(&self, model_name: String, status: ModelStatus, provider: ModelProvider) {
284 self.set_status_and_notify(model_name, status, provider, None);
285 }
286
287 pub fn add_waiting_channel(
289 &self,
290 model_name: &str,
291 tx: tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
292 ) {
293 let mut waiting = match self.waiting_channels.lock() {
294 Ok(guard) => guard,
295 Err(poisoned) => {
296 error!("Waiting channels mutex is poisoned, recovering");
297 poisoned.into_inner()
298 }
299 };
300 waiting.entry(model_name.to_string()).or_default().push(tx);
301 }
302
303 pub fn delete_status(&self, model_name: &str) {
306 if let Err(e) = self.database.delete_model(model_name) {
307 error!("Failed to delete model from database: {}", e);
308 }
309 let mut waiting = match self.waiting_channels.lock() {
310 Ok(guard) => guard,
311 Err(poisoned) => {
312 error!("Waiting channels mutex is poisoned, recovering");
313 poisoned.into_inner()
314 }
315 };
316 waiting.remove(model_name);
317 }
318
319 pub async fn ensure_model_downloaded(
321 &self,
322 model_name: &str,
323 provider: ModelProvider,
324 tx: &tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
325 ignore_weights: bool,
326 ) -> ModelStatus {
327 let status = match self.database.try_claim_for_download(model_name, provider) {
329 Ok(status) => status,
330 Err(e) => {
331 error!("Failed to claim model for download: {}", e);
332 let error_update = ModelStatusUpdate {
334 model_name: model_name.to_string(),
335 status: modelexpress_common::grpc::model::ModelStatus::from(ModelStatus::ERROR)
336 as i32,
337 message: Some("Database error occurred".to_string()),
338 provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
339 as i32,
340 };
341 let _ = tx.send(Ok(error_update)).await;
342 return ModelStatus::ERROR;
343 }
344 };
345
346 let update = ModelStatusUpdate {
348 model_name: model_name.to_string(),
349 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
350 message: match status {
351 ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
352 ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
353 ModelStatus::ERROR => Some("Previous download failed - retrying".to_string()),
354 },
355 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
356 };
357
358 let _ = tx.send(Ok(update)).await;
359
360 if status == ModelStatus::DOWNLOADING {
362 self.add_waiting_channel(model_name, tx.clone());
363
364 let should_start_download = {
368 let waiting = match self.waiting_channels.lock() {
369 Ok(guard) => guard,
370 Err(poisoned) => {
371 error!("Waiting channels mutex is poisoned, recovering");
372 poisoned.into_inner()
373 }
374 };
375 waiting
376 .get(model_name)
377 .is_none_or(|channels| channels.len() <= 1)
378 };
379
380 if should_start_download {
381 let tracker = self.clone();
383 let model_name_owned = model_name.to_string();
384
385 tokio::spawn(async move {
387 let cache_dir = get_server_cache_dir();
388 match download::download_model(
389 &model_name_owned,
390 provider,
391 cache_dir,
392 ignore_weights,
393 )
394 .await
395 {
396 Ok(_path) => {
397 tracker.set_status_and_notify(
399 model_name_owned,
400 ModelStatus::DOWNLOADED,
401 provider,
402 Some("Model download completed successfully".to_string()),
403 );
404 }
405 Err(e) => {
406 error!("Failed to download model {model_name_owned}: {e}");
408 tracker.set_status_and_notify(
409 model_name_owned,
410 ModelStatus::ERROR,
411 provider,
412 Some(format!("Download failed: {e}")),
413 );
414 }
415 }
416 });
417 }
418
419 loop {
421 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
422 if let Some(current_status) = self.get_status(model_name) {
423 if current_status != ModelStatus::DOWNLOADING {
424 return current_status;
425 }
426 }
427 }
428 } else if status == ModelStatus::ERROR {
429 if let Err(e) = self.database.set_status(
432 model_name,
433 provider,
434 ModelStatus::DOWNLOADING,
435 Some("Retrying download...".to_string()),
436 ) {
437 error!("Failed to reset status for retry: {}", e);
438 return ModelStatus::ERROR;
439 }
440
441 self.add_waiting_channel(model_name, tx.clone());
443
444 let tracker = self.clone();
446 let model_name_owned = model_name.to_string();
447
448 tokio::spawn(async move {
449 let cache_dir = get_server_cache_dir();
450 match download::download_model(
451 &model_name_owned,
452 provider,
453 cache_dir,
454 ignore_weights,
455 )
456 .await
457 {
458 Ok(_path) => {
459 tracker.set_status_and_notify(
461 model_name_owned,
462 ModelStatus::DOWNLOADED,
463 provider,
464 Some("Model download completed successfully".to_string()),
465 );
466 }
467 Err(e) => {
468 error!("Failed to download model {model_name_owned} on retry: {e}");
470 tracker.set_status_and_notify(
471 model_name_owned,
472 ModelStatus::ERROR,
473 provider,
474 Some(format!("Download failed on retry: {e}")),
475 );
476 }
477 }
478 });
479
480 loop {
482 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
483 if let Some(current_status) = self.get_status(model_name) {
484 if current_status != ModelStatus::DOWNLOADING {
485 return current_status;
486 }
487 }
488 }
489 }
490
491 status
492 }
493}
494
/// Process-wide download tracker, built with `ModelDownloadTracker::new` on first
/// access.
///
/// Because `new` panics when the model database cannot be initialized, the first
/// access to this static can panic.
pub static MODEL_TRACKER: std::sync::LazyLock<ModelDownloadTracker> =
    std::sync::LazyLock::new(ModelDownloadTracker::new);
498
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    //! Unit tests for the gRPC services and the download tracker.
    //!
    //! Every tracker created here opens the same `./models.db`, so tests use
    //! per-run unique model names where they assert on fresh state and delete the
    //! records they create.

    use super::*;
    use modelexpress_common::grpc::{
        api::ApiRequest, health::HealthRequest, model::ModelDownloadRequest,
    };
    use tokio_stream::StreamExt;
    use tonic::Request;

    /// Builds `"{prefix}-{nanoseconds since the Unix epoch}"` so that records from
    /// earlier or parallel runs do not collide with this one.
    fn unique_model_name(prefix: &str) -> String {
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards")
            .as_nanos();
        format!("{prefix}-{timestamp}")
    }

    /// Number of senders currently registered for `model_name`.
    fn waiting_count(tracker: &ModelDownloadTracker, model_name: &str) -> usize {
        let waiting = tracker
            .waiting_channels
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        waiting.get(model_name).map_or(0, Vec::len)
    }

    #[tokio::test]
    async fn test_health_service() {
        let service = HealthServiceImpl;
        let request = Request::new(HealthRequest {});

        let response = service.get_health(request).await;
        assert!(response.is_ok());

        let health_response = response.expect("Health response should be ok").into_inner();
        assert_eq!(health_response.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(health_response.status, "ok");
        // Uptime is unsigned, so reading it is all that can be checked here.
        let _uptime = health_response.uptime;
    }

    #[tokio::test]
    async fn test_api_service_ping() {
        let service = ApiServiceImpl;
        let request = Request::new(ApiRequest {
            id: "test-id".to_string(),
            action: "ping".to_string(),
            payload: None,
        });

        let response = service.send_request(request).await;
        assert!(response.is_ok());

        let api_response = response.expect("API response should be ok").into_inner();
        assert!(api_response.success);
        assert!(api_response.data.is_some());
        assert!(api_response.error.is_none());

        let data_bytes = api_response.data.expect("Data should be present");
        let data: serde_json::Value =
            serde_json::from_slice(&data_bytes).expect("Data should be valid JSON");
        assert_eq!(data["message"], "pong");
    }

    #[tokio::test]
    async fn test_api_service_unknown_action() {
        let service = ApiServiceImpl;
        let request = Request::new(ApiRequest {
            id: "test-id".to_string(),
            action: "unknown-action".to_string(),
            payload: None,
        });

        let response = service.send_request(request).await;
        assert!(response.is_ok());

        let api_response = response.expect("API response should be ok").into_inner();
        assert!(!api_response.success);
        assert!(api_response.data.is_none());
        assert!(api_response.error.is_some());

        let error_message = api_response.error.expect("Error should be present");
        assert!(error_message.contains("Unknown action"));
    }

    #[test]
    fn test_model_download_tracker_new() {
        let tracker = ModelDownloadTracker::new();

        assert!(tracker.get_status("non-existent-model").is_none());
    }

    #[test]
    fn test_model_download_tracker_set_and_get_status() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-model");
        let provider = ModelProvider::HuggingFace;

        assert!(tracker.get_status(&model_name).is_none());

        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADING, provider);

        let status = tracker.get_status(&model_name);
        assert!(status.is_some());
        assert_eq!(
            status.expect("Status should be present"),
            ModelStatus::DOWNLOADING
        );

        tracker.delete_status(&model_name);
    }

    #[test]
    fn test_model_download_tracker_delete_status() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-delete-model");
        let provider = ModelProvider::HuggingFace;

        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADED, provider);
        assert!(tracker.get_status(&model_name).is_some());

        tracker.delete_status(&model_name);
        assert!(tracker.get_status(&model_name).is_none());
    }

    #[tokio::test]
    async fn test_model_service_already_downloaded() {
        let service = ModelServiceImpl;
        let model_name = unique_model_name("test-already-downloaded-model");

        MODEL_TRACKER.set_status(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
        );

        let request = Request::new(ModelDownloadRequest {
            model_name: model_name.clone(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            ignore_weights: false,
        });

        let response = service.ensure_model_downloaded(request).await;
        assert!(response.is_ok());

        let mut stream = response.expect("Response should be ok").into_inner();

        let update = stream.next().await;
        assert!(update.is_some());

        let update = update.expect("Update should be present");
        assert!(update.is_ok());

        let status_update = update.expect("Status update should be ok");
        assert_eq!(status_update.model_name, model_name);
        assert_eq!(
            status_update.status,
            modelexpress_common::grpc::model::ModelStatus::Downloaded as i32
        );

        MODEL_TRACKER.delete_status(&model_name);
    }

    #[test]
    fn test_model_download_tracker_set_status_and_notify() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-notify-model");
        let provider = ModelProvider::HuggingFace;

        tracker.set_status_and_notify(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            provider,
            Some("Download completed".to_string()),
        );

        let status = tracker.get_status(&model_name);
        assert!(status.is_some());
        assert_eq!(
            status.expect("Status should be present"),
            ModelStatus::DOWNLOADED
        );

        // Do not leave the record behind in the shared database.
        tracker.delete_status(&model_name);
    }

    #[test]
    fn test_waiting_channels_management() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-channels-model");

        let (tx, _rx) = tokio::sync::mpsc::channel(4);

        tracker.add_waiting_channel(&model_name, tx);
        assert_eq!(waiting_count(&tracker, &model_name), 1);

        // A terminal status must clear the model's subscribers.
        tracker.set_status_and_notify(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
            None,
        );
        assert_eq!(waiting_count(&tracker, &model_name), 0);

        // Do not leave the record behind in the shared database.
        tracker.delete_status(&model_name);
    }

    #[tokio::test]
    async fn test_model_service_stream_closes_properly() {
        let service = ModelServiceImpl;
        let model_name = "test-stream-model";

        let request = Request::new(ModelDownloadRequest {
            model_name: model_name.to_string(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            ignore_weights: false,
        });

        let response = service.ensure_model_downloaded(request).await;
        assert!(response.is_ok());

        let mut stream = response.expect("Response should be ok").into_inner();

        let mut update_count = 0;
        while let Some(update) = stream.next().await {
            assert!(update.is_ok());
            update_count += 1;

            // Safety valve so a misbehaving stream cannot hang the test.
            if update_count > 10 {
                break;
            }
        }

        assert!(update_count > 0);

        MODEL_TRACKER.delete_status(model_name);
    }

    #[tokio::test]
    async fn test_concurrent_model_download_no_race_condition() {
        let tracker = ModelDownloadTracker::new();
        // A unique name guarantees the first claim sees no pre-existing record.
        let model_name = unique_model_name("test-concurrent-model");
        let provider = ModelProvider::HuggingFace;

        let status1 = tracker
            .database
            .try_claim_for_download(&model_name, provider)
            .expect("Failed to claim for download 1");
        assert_eq!(status1, ModelStatus::DOWNLOADING);

        let status2 = tracker
            .database
            .try_claim_for_download(&model_name, provider)
            .expect("Failed to claim for download 2");
        assert_eq!(status2, ModelStatus::DOWNLOADING);

        let record = tracker
            .database
            .get_model_record(&model_name)
            .expect("Failed to get model record")
            .expect("Record should exist");
        assert_eq!(record.status, ModelStatus::DOWNLOADING);

        tracker.delete_status(&model_name);
    }
}