1use crate::database::ModelDatabase;
5use modelexpress_common::{
6 cache::CacheConfig,
7 constants, download,
8 grpc::{
9 api::{ApiRequest, ApiResponse, api_service_server::ApiService},
10 health::{HealthRequest, HealthResponse, health_service_server::HealthService},
11 model::{
12 FileChunk, ModelDownloadRequest, ModelFileInfo, ModelFileList, ModelFilesRequest,
13 ModelStatusUpdate, model_service_server::ModelService,
14 },
15 },
16 models::{ModelProvider, ModelStatus},
17};
18use std::{
19 collections::HashMap,
20 path::{Path, PathBuf},
21 sync::{Arc, Mutex},
22 time::SystemTime,
23};
24use tokio::io::AsyncReadExt;
25use tokio_stream::wrappers::ReceiverStream;
26use tonic::{Request, Response, Status};
27use tracing::{debug, error, info};
28
/// Zero point for the uptime reported by the health endpoint.
///
/// Set once by whichever caller first runs `get_or_init`; later calls observe the same value.
static START_TIME: std::sync::OnceLock<std::time::SystemTime> = std::sync::OnceLock::new();
30
31fn get_server_cache_dir() -> Option<std::path::PathBuf> {
33 if let Ok(config) = CacheConfig::discover() {
35 Some(config.local_path)
36 } else {
37 std::env::var("HF_HUB_CACHE")
39 .ok()
40 .map(std::path::PathBuf::from)
41 }
42}
43
44fn convert_provider(grpc_provider: i32) -> ModelProvider {
49 match modelexpress_common::grpc::model::ModelProvider::try_from(grpc_provider) {
50 Ok(provider) => provider.into(),
51 Err(_) => {
52 tracing::warn!(
53 "Invalid provider value {}, falling back to HuggingFace",
54 grpc_provider
55 );
56 ModelProvider::HuggingFace
57 }
58 }
59}
60
61#[derive(Debug, Default)]
63pub struct HealthServiceImpl;
64
65#[tonic::async_trait]
66impl HealthService for HealthServiceImpl {
67 async fn get_health(
68 &self,
69 _request: Request<HealthRequest>,
70 ) -> Result<Response<HealthResponse>, Status> {
71 let start_time = START_TIME.get_or_init(SystemTime::now);
72 let uptime = SystemTime::now()
73 .duration_since(*start_time)
74 .unwrap_or_default()
75 .as_secs();
76
77 let response = HealthResponse {
78 version: env!("CARGO_PKG_VERSION").to_string(),
79 status: "ok".to_string(),
80 uptime,
81 };
82
83 Ok(Response::new(response))
84 }
85}
86
87#[derive(Debug, Default)]
89pub struct ApiServiceImpl;
90
91#[tonic::async_trait]
92impl ApiService for ApiServiceImpl {
93 async fn send_request(
94 &self,
95 request: Request<ApiRequest>,
96 ) -> Result<Response<ApiResponse>, Status> {
97 let api_request = request.into_inner();
98 info!("Received gRPC request: {:?}", api_request);
99
100 if api_request.action.as_str() == "ping" {
102 info!("Processing ping request");
103 let response_data = serde_json::json!({ "message": "pong" });
104 let data_bytes = serde_json::to_vec(&response_data)
105 .map_err(|e| Status::internal(format!("Serialization error: {e}")))?;
106
107 Ok(Response::new(ApiResponse {
108 success: true,
109 data: Some(data_bytes),
110 error: None,
111 }))
112 } else {
113 error!("Unknown action: {}", api_request.action);
114 Ok(Response::new(ApiResponse {
115 success: false,
116 data: None,
117 error: Some(format!("Unknown action: {}", api_request.action)),
118 }))
119 }
120 }
121}
122
/// gRPC model service: download orchestration plus listing and streaming of cached model files.
#[derive(Default, Debug)]
pub struct ModelServiceImpl;
126
127fn collect_model_files(base_path: &Path, current_path: &Path) -> Vec<(PathBuf, u64)> {
129 let mut files = Vec::new();
130
131 if let Ok(entries) = std::fs::read_dir(current_path) {
132 for entry in entries.flatten() {
133 let path = entry.path();
134 if path.is_file() {
135 if let Ok(metadata) = std::fs::metadata(&path) {
136 if let Ok(relative) = path.strip_prefix(base_path) {
138 let mut is_safe = true;
140 for comp in relative.components() {
141 use std::path::Component;
142 match comp {
143 Component::ParentDir
144 | Component::RootDir
145 | Component::Prefix(_) => {
146 is_safe = false;
147 break;
148 }
149 _ => {}
150 }
151 }
152 if is_safe {
153 files.push((relative.to_path_buf(), metadata.len()));
154 } else {
155 tracing::warn!(
156 "Skipping potentially unsafe file path: {:?} (relative: {:?})",
157 path,
158 relative
159 );
160 }
161 }
162 }
163 } else if path.is_dir() {
164 files.extend(collect_model_files(base_path, &path));
165 }
166 }
167 }
168
169 files
170}
171
172#[tonic::async_trait]
173impl ModelService for ModelServiceImpl {
174 type EnsureModelDownloadedStream = ReceiverStream<Result<ModelStatusUpdate, Status>>;
175 type StreamModelFilesStream = ReceiverStream<Result<FileChunk, Status>>;
176
177 async fn ensure_model_downloaded(
178 &self,
179 request: Request<ModelDownloadRequest>,
180 ) -> Result<Response<Self::EnsureModelDownloadedStream>, Status> {
181 info!("Starting model download stream");
182 let model_request = request.into_inner();
183 let (tx, rx) = tokio::sync::mpsc::channel(4);
184 let model_name = model_request.model_name.clone();
185
186 let provider: ModelProvider =
188 modelexpress_common::grpc::model::ModelProvider::try_from(model_request.provider)
189 .unwrap_or(modelexpress_common::grpc::model::ModelProvider::HuggingFace)
190 .into();
191 let ignore_weights = model_request.ignore_weights;
192
193 tokio::spawn(async move {
195 if let Some(status) = MODEL_TRACKER.get_status(&model_name) {
197 let update = ModelStatusUpdate {
198 model_name: model_name.clone(),
199 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
200 message: match status {
201 ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
202 ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
203 ModelStatus::ERROR => {
204 Some("Previous download failed - retrying".to_string())
205 }
206 },
207 provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
208 as i32,
209 };
210
211 if tx.send(Ok(update)).await.is_err() {
212 return; }
214
215 if status == ModelStatus::DOWNLOADED {
217 return;
218 }
219 }
220
221 let final_status = MODEL_TRACKER
223 .ensure_model_downloaded(&model_name, provider, &tx, ignore_weights)
224 .await;
225
226 let final_update = ModelStatusUpdate {
228 model_name: model_name.clone(),
229 status: modelexpress_common::grpc::model::ModelStatus::from(final_status) as i32,
230 message: match final_status {
231 ModelStatus::DOWNLOADED => {
232 Some("Model download completed successfully".to_string())
233 }
234 ModelStatus::ERROR => Some("Model download failed".to_string()),
235 ModelStatus::DOWNLOADING => Some("Download still in progress".to_string()),
236 },
237 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
238 };
239
240 let _ = tx.send(Ok(final_update)).await;
241 });
242
243 Ok(Response::new(ReceiverStream::new(rx)))
244 }
245
246 async fn stream_model_files(
247 &self,
248 request: Request<ModelFilesRequest>,
249 ) -> Result<Response<Self::StreamModelFilesStream>, Status> {
250 let files_request = request.into_inner();
251 let model_name = files_request.model_name.clone();
252 let chunk_size = if files_request.chunk_size == 0 {
253 constants::DEFAULT_TRANSFER_CHUNK_SIZE
254 } else {
255 files_request.chunk_size as usize
256 };
257
258 let provider = convert_provider(files_request.provider);
260
261 info!(
262 "Starting file stream for model: {} with chunk size: {} bytes",
263 model_name, chunk_size
264 );
265
266 let cache_dir = get_server_cache_dir()
268 .ok_or_else(|| Status::internal("Server cache directory not configured"))?;
269
270 let provider_impl = download::get_provider(provider);
272 let model_path = provider_impl
273 .get_model_path(&model_name, cache_dir)
274 .await
275 .map_err(|e| Status::not_found(format!("Model not found: {e}")))?;
276
277 debug!("Model path resolved to: {:?}", model_path);
278
279 let files = collect_model_files(&model_path, &model_path);
281
282 if files.is_empty() {
283 return Err(Status::not_found("No files found in model directory"));
284 }
285
286 let total_files = files.len();
287 info!(
288 "Found {} files to stream for model {}",
289 total_files, model_name
290 );
291
292 let commit_hash = model_path
294 .file_name()
295 .and_then(|name| name.to_str())
296 .map(String::from);
297
298 let (tx, rx) = tokio::sync::mpsc::channel(16);
299
300 tokio::spawn(async move {
302 let mut buffer = vec![0u8; chunk_size];
304 let mut is_first_chunk = true;
305
306 for (file_idx, (relative_path, total_size)) in files.iter().enumerate() {
307 let file_path = model_path.join(relative_path);
308 let is_last_file = file_idx == total_files.saturating_sub(1);
309
310 debug!("Streaming file: {:?} ({} bytes)", relative_path, total_size);
311
312 let file = match tokio::fs::File::open(&file_path).await {
314 Ok(f) => f,
315 Err(e) => {
316 error!("Failed to open file {:?}: {}", file_path, e);
317 let _ = tx
318 .send(Err(Status::internal(format!("Failed to open file: {e}"))))
319 .await;
320 return;
321 }
322 };
323
324 let mut reader = tokio::io::BufReader::new(file);
325 let mut offset: u64 = 0;
326
327 loop {
328 let bytes_read = match reader.read(&mut buffer).await {
329 Ok(0) => break, Ok(n) => n,
331 Err(e) => {
332 error!("Failed to read file {:?}: {}", file_path, e);
333 let _ = tx
334 .send(Err(Status::internal(format!("Failed to read file: {e}"))))
335 .await;
336 return;
337 }
338 };
339
340 let is_last_chunk = offset.saturating_add(bytes_read as u64) >= *total_size;
341
342 let chunk = FileChunk {
343 relative_path: relative_path.to_string_lossy().to_string(),
344 data: buffer[..bytes_read].to_vec(),
345 offset,
346 total_size: *total_size,
347 is_last_chunk,
348 is_last_file: is_last_file && is_last_chunk,
349 commit_hash: if is_first_chunk {
350 is_first_chunk = false;
351 commit_hash.clone()
352 } else {
353 None
354 },
355 };
356
357 if tx.send(Ok(chunk)).await.is_err() {
358 debug!("Client disconnected during file stream");
359 return;
360 }
361
362 offset = offset.saturating_add(bytes_read as u64);
363 }
364 }
365
366 info!("File streaming completed for model");
367 });
368
369 Ok(Response::new(ReceiverStream::new(rx)))
370 }
371
372 async fn list_model_files(
373 &self,
374 request: Request<ModelFilesRequest>,
375 ) -> Result<Response<ModelFileList>, Status> {
376 let files_request = request.into_inner();
377 let model_name = files_request.model_name.clone();
378
379 let provider = convert_provider(files_request.provider);
381
382 info!("Listing files for model: {}", model_name);
383
384 let cache_dir = get_server_cache_dir()
386 .ok_or_else(|| Status::internal("Server cache directory not configured"))?;
387
388 let provider_impl = download::get_provider(provider);
390 let model_path = provider_impl
391 .get_model_path(&model_name, cache_dir)
392 .await
393 .map_err(|e| Status::not_found(format!("Model not found: {e}")))?;
394
395 let files = collect_model_files(&model_path, &model_path);
397
398 let file_infos: Vec<ModelFileInfo> = files
399 .iter()
400 .map(|(path, size)| ModelFileInfo {
401 relative_path: path.to_string_lossy().to_string(),
402 size: *size,
403 })
404 .collect();
405
406 let total_size: u64 = files.iter().map(|(_, size)| size).sum();
407
408 Ok(Response::new(ModelFileList {
409 model_name,
410 files: file_infos,
411 total_size,
412 }))
413 }
414}
415
416type WaitingChannels =
418 Arc<Mutex<HashMap<String, Vec<tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>>>>>;
419
420#[derive(Debug, Clone)]
422pub struct ModelDownloadTracker {
423 database: ModelDatabase,
425 waiting_channels: WaitingChannels,
427}
428
429impl Default for ModelDownloadTracker {
430 fn default() -> Self {
431 Self::new()
432 }
433}
434
435impl ModelDownloadTracker {
436 #[must_use]
437 pub fn new() -> Self {
438 let database = match ModelDatabase::new("./models.db") {
440 Ok(db) => db,
441 Err(e) => {
442 error!("Critical error: Could not initialize model database at ./models.db: {e}");
443 panic!("Critical error: Could not initialize model database at ./models.db");
444 }
445 };
446
447 Self {
448 database,
449 waiting_channels: Arc::new(Mutex::new(HashMap::new())),
450 }
451 }
452
453 pub fn get_status(&self, model_name: &str) -> Option<ModelStatus> {
456 match self.database.get_status(model_name) {
457 Ok(status) => {
458 if status.is_some() {
460 let _ = self.database.touch_model(model_name);
461 }
462 status
463 }
464 Err(e) => {
465 error!("Failed to get model status from database: {}", e);
466 None
467 }
468 }
469 }
470
471 pub fn set_status_and_notify(
473 &self,
474 model_name: String,
475 status: ModelStatus,
476 provider: ModelProvider,
477 message: Option<String>,
478 ) {
479 if let Err(e) = self
481 .database
482 .set_status(&model_name, provider, status, message.clone())
483 {
484 error!("Failed to update model status in database: {}", e);
485 return;
486 }
487
488 let mut waiting = match self.waiting_channels.lock() {
490 Ok(guard) => guard,
491 Err(poisoned) => {
492 error!("Waiting channels mutex is poisoned, recovering");
493 poisoned.into_inner()
494 }
495 };
496 if let Some(channels) = waiting.get(&model_name) {
497 let update = ModelStatusUpdate {
498 model_name: model_name.clone(),
499 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
500 message,
501 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
502 };
503
504 for channel in channels {
505 let _ = channel.try_send(Ok(update.clone()));
506 }
507
508 if status == ModelStatus::DOWNLOADED || status == ModelStatus::ERROR {
510 waiting.remove(&model_name);
511 }
512 }
513 }
514
515 pub fn set_status(&self, model_name: String, status: ModelStatus, provider: ModelProvider) {
517 self.set_status_and_notify(model_name, status, provider, None);
518 }
519
520 pub fn add_waiting_channel(
522 &self,
523 model_name: &str,
524 tx: tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
525 ) {
526 let mut waiting = match self.waiting_channels.lock() {
527 Ok(guard) => guard,
528 Err(poisoned) => {
529 error!("Waiting channels mutex is poisoned, recovering");
530 poisoned.into_inner()
531 }
532 };
533 waiting.entry(model_name.to_string()).or_default().push(tx);
534 }
535
536 pub fn delete_status(&self, model_name: &str) {
539 if let Err(e) = self.database.delete_model(model_name) {
540 error!("Failed to delete model from database: {}", e);
541 }
542 let mut waiting = match self.waiting_channels.lock() {
543 Ok(guard) => guard,
544 Err(poisoned) => {
545 error!("Waiting channels mutex is poisoned, recovering");
546 poisoned.into_inner()
547 }
548 };
549 waiting.remove(model_name);
550 }
551
552 pub async fn ensure_model_downloaded(
554 &self,
555 model_name: &str,
556 provider: ModelProvider,
557 tx: &tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
558 ignore_weights: bool,
559 ) -> ModelStatus {
560 let status = match self.database.try_claim_for_download(model_name, provider) {
562 Ok(status) => status,
563 Err(e) => {
564 error!("Failed to claim model for download: {}", e);
565 let error_update = ModelStatusUpdate {
567 model_name: model_name.to_string(),
568 status: modelexpress_common::grpc::model::ModelStatus::from(ModelStatus::ERROR)
569 as i32,
570 message: Some("Database error occurred".to_string()),
571 provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
572 as i32,
573 };
574 let _ = tx.send(Ok(error_update)).await;
575 return ModelStatus::ERROR;
576 }
577 };
578
579 let update = ModelStatusUpdate {
581 model_name: model_name.to_string(),
582 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
583 message: match status {
584 ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
585 ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
586 ModelStatus::ERROR => Some("Previous download failed - retrying".to_string()),
587 },
588 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
589 };
590
591 let _ = tx.send(Ok(update)).await;
592
593 if status == ModelStatus::DOWNLOADING {
595 self.add_waiting_channel(model_name, tx.clone());
596
597 let should_start_download = {
601 let waiting = match self.waiting_channels.lock() {
602 Ok(guard) => guard,
603 Err(poisoned) => {
604 error!("Waiting channels mutex is poisoned, recovering");
605 poisoned.into_inner()
606 }
607 };
608 waiting
609 .get(model_name)
610 .is_none_or(|channels| channels.len() <= 1)
611 };
612
613 if should_start_download {
614 let tracker = self.clone();
616 let model_name_owned = model_name.to_string();
617
618 tokio::spawn(async move {
620 let cache_dir = get_server_cache_dir();
621 match download::download_model(
622 &model_name_owned,
623 provider,
624 cache_dir,
625 ignore_weights,
626 )
627 .await
628 {
629 Ok(_path) => {
630 tracker.set_status_and_notify(
632 model_name_owned,
633 ModelStatus::DOWNLOADED,
634 provider,
635 Some("Model download completed successfully".to_string()),
636 );
637 }
638 Err(e) => {
639 error!("Failed to download model {model_name_owned}: {e}");
641 tracker.set_status_and_notify(
642 model_name_owned,
643 ModelStatus::ERROR,
644 provider,
645 Some(format!("Download failed: {e}")),
646 );
647 }
648 }
649 });
650 }
651
652 loop {
654 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
655 if let Some(current_status) = self.get_status(model_name)
656 && current_status != ModelStatus::DOWNLOADING
657 {
658 return current_status;
659 }
660 }
661 } else if status == ModelStatus::ERROR {
662 if let Err(e) = self.database.set_status(
665 model_name,
666 provider,
667 ModelStatus::DOWNLOADING,
668 Some("Retrying download...".to_string()),
669 ) {
670 error!("Failed to reset status for retry: {}", e);
671 return ModelStatus::ERROR;
672 }
673
674 self.add_waiting_channel(model_name, tx.clone());
676
677 let tracker = self.clone();
679 let model_name_owned = model_name.to_string();
680
681 tokio::spawn(async move {
682 let cache_dir = get_server_cache_dir();
683 match download::download_model(
684 &model_name_owned,
685 provider,
686 cache_dir,
687 ignore_weights,
688 )
689 .await
690 {
691 Ok(_path) => {
692 tracker.set_status_and_notify(
694 model_name_owned,
695 ModelStatus::DOWNLOADED,
696 provider,
697 Some("Model download completed successfully".to_string()),
698 );
699 }
700 Err(e) => {
701 error!("Failed to download model {model_name_owned} on retry: {e}");
703 tracker.set_status_and_notify(
704 model_name_owned,
705 ModelStatus::ERROR,
706 provider,
707 Some(format!("Download failed on retry: {e}")),
708 );
709 }
710 }
711 });
712
713 loop {
715 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
716 if let Some(current_status) = self.get_status(model_name)
717 && current_status != ModelStatus::DOWNLOADING
718 {
719 return current_status;
720 }
721 }
722 }
723
724 status
725 }
726}
727
728pub static MODEL_TRACKER: std::sync::LazyLock<ModelDownloadTracker> =
730 std::sync::LazyLock::new(ModelDownloadTracker::new);
731
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use modelexpress_common::grpc::{
        api::ApiRequest, health::HealthRequest, model::ModelDownloadRequest,
    };
    use std::sync::atomic::{AtomicU64, Ordering};
    use tempfile::TempDir;
    use tokio_stream::StreamExt;
    use tonic::Request;

    /// Returns a model name that no other call returns.
    ///
    /// Every `ModelDownloadTracker`, including `MODEL_TRACKER`, opens the same
    /// `./models.db`. A fixed name would let one test, or an aborted earlier run, leak
    /// state into another. The timestamp separates runs; the counter separates calls
    /// within a run.
    fn unique_model_name(prefix: &str) -> String {
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards")
            .as_nanos();
        // Relaxed suffices: only uniqueness of the fetched values matters.
        let sequence = COUNTER.fetch_add(1, Ordering::Relaxed);
        format!("{prefix}-{timestamp}-{sequence}")
    }

    /// Deletes the record for `model_name` when dropped.
    ///
    /// Because cleanup runs in `Drop`, the shared database is cleaned up even if an
    /// assertion fails part-way through a test.
    struct RecordCleanup<'a> {
        tracker: &'a ModelDownloadTracker,
        model_name: String,
    }

    impl Drop for RecordCleanup<'_> {
        fn drop(&mut self) {
            self.tracker.delete_status(&self.model_name);
        }
    }

    #[tokio::test]
    async fn test_health_service() {
        let service = HealthServiceImpl;
        let request = Request::new(HealthRequest {});

        let response = service.get_health(request).await;
        assert!(response.is_ok());

        let health_response = response.expect("Health response should be ok").into_inner();
        assert_eq!(health_response.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(health_response.status, "ok");
        // Only presence is checked: the value depends on when START_TIME was first set.
        let _uptime = health_response.uptime;
    }

    #[tokio::test]
    async fn test_api_service_ping() {
        let service = ApiServiceImpl;
        let request = Request::new(ApiRequest {
            id: "test-id".to_string(),
            action: "ping".to_string(),
            payload: None,
        });

        let response = service.send_request(request).await;
        assert!(response.is_ok());

        let api_response = response.expect("API response should be ok").into_inner();
        assert!(api_response.success);
        assert!(api_response.data.is_some());
        assert!(api_response.error.is_none());

        let data_bytes = api_response.data.expect("Data should be present");
        let data: serde_json::Value =
            serde_json::from_slice(&data_bytes).expect("Data should be valid JSON");
        assert_eq!(data["message"], "pong");
    }

    #[tokio::test]
    async fn test_api_service_unknown_action() {
        let service = ApiServiceImpl;
        let request = Request::new(ApiRequest {
            id: "test-id".to_string(),
            action: "unknown-action".to_string(),
            payload: None,
        });

        let response = service.send_request(request).await;
        assert!(response.is_ok());

        let api_response = response.expect("API response should be ok").into_inner();
        assert!(!api_response.success);
        assert!(api_response.data.is_none());
        assert!(api_response.error.is_some());

        let error_message = api_response.error.expect("Error should be present");
        assert!(error_message.contains("Unknown action"));
    }

    #[test]
    fn test_model_download_tracker_new() {
        let tracker = ModelDownloadTracker::new();

        let status = tracker.get_status(&unique_model_name("non-existent-model"));
        assert!(status.is_none());
    }

    #[test]
    fn test_model_download_tracker_set_and_get_status() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-model");
        let _cleanup = RecordCleanup {
            tracker: &tracker,
            model_name: model_name.clone(),
        };
        let provider = ModelProvider::HuggingFace;

        assert!(tracker.get_status(&model_name).is_none());

        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADING, provider);

        let status = tracker.get_status(&model_name);
        assert!(status.is_some());
        assert_eq!(
            status.expect("Status should be present"),
            ModelStatus::DOWNLOADING
        );
    }

    #[test]
    fn test_model_download_tracker_delete_status() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-delete-model");
        let provider = ModelProvider::HuggingFace;

        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADED, provider);
        assert!(tracker.get_status(&model_name).is_some());

        tracker.delete_status(&model_name);
        assert!(tracker.get_status(&model_name).is_none());
    }

    #[tokio::test]
    async fn test_model_service_already_downloaded() {
        let service = ModelServiceImpl;
        let model_name = unique_model_name("test-already-downloaded-model");
        let _cleanup = RecordCleanup {
            tracker: &*MODEL_TRACKER,
            model_name: model_name.clone(),
        };

        MODEL_TRACKER.set_status(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
        );

        let request = Request::new(ModelDownloadRequest {
            model_name: model_name.clone(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            ignore_weights: false,
        });

        let response = service.ensure_model_downloaded(request).await;
        assert!(response.is_ok());

        let mut stream = response.expect("Response should be ok").into_inner();

        let update = stream.next().await;
        assert!(update.is_some());

        let update = update.expect("Update should be present");
        assert!(update.is_ok());

        let status_update = update.expect("Status update should be ok");
        assert_eq!(status_update.model_name, model_name);
        assert_eq!(
            status_update.status,
            modelexpress_common::grpc::model::ModelStatus::Downloaded as i32
        );
    }

    #[test]
    fn test_model_download_tracker_set_status_and_notify() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-notify-model");
        let _cleanup = RecordCleanup {
            tracker: &tracker,
            model_name: model_name.clone(),
        };
        let provider = ModelProvider::HuggingFace;

        tracker.set_status_and_notify(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            provider,
            Some("Download completed".to_string()),
        );

        let status = tracker.get_status(&model_name);
        assert!(status.is_some());
        assert_eq!(
            status.expect("Status should be present"),
            ModelStatus::DOWNLOADED
        );
    }

    #[test]
    fn test_waiting_channels_management() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-channels-model");
        let _cleanup = RecordCleanup {
            tracker: &tracker,
            model_name: model_name.clone(),
        };

        let (tx, _rx) = tokio::sync::mpsc::channel(4);
        tracker.add_waiting_channel(&model_name, tx);

        let waiting_count = {
            let waiting = match tracker.waiting_channels.lock() {
                Ok(guard) => guard,
                Err(poisoned) => poisoned.into_inner(),
            };
            waiting.get(&model_name).map_or(0, std::vec::Vec::len)
        };
        assert_eq!(waiting_count, 1);

        tracker.set_status_and_notify(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
            None,
        );

        let waiting_count_after = {
            let waiting = match tracker.waiting_channels.lock() {
                Ok(guard) => guard,
                Err(poisoned) => poisoned.into_inner(),
            };
            waiting.get(&model_name).map_or(0, std::vec::Vec::len)
        };
        assert_eq!(waiting_count_after, 0);
    }

    #[tokio::test]
    async fn test_model_service_stream_closes_properly() {
        let service = ModelServiceImpl;
        let model_name = unique_model_name("test-stream-model");
        let _cleanup = RecordCleanup {
            tracker: &*MODEL_TRACKER,
            model_name: model_name.clone(),
        };

        let request = Request::new(ModelDownloadRequest {
            model_name: model_name.clone(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            ignore_weights: false,
        });

        let response = service.ensure_model_downloaded(request).await;
        assert!(response.is_ok());

        let mut stream = response.expect("Response should be ok").into_inner();

        let mut update_count = 0;
        while let Some(update) = stream.next().await {
            assert!(update.is_ok());
            update_count += 1;

            if update_count > 10 {
                break;
            }
        }

        assert!(update_count > 0);
    }

    #[tokio::test]
    async fn test_concurrent_model_download_no_race_condition() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-concurrent-model");
        let _cleanup = RecordCleanup {
            tracker: &tracker,
            model_name: model_name.clone(),
        };
        let provider = ModelProvider::HuggingFace;

        let status1 = tracker
            .database
            .try_claim_for_download(&model_name, provider)
            .expect("Failed to claim for download 1");
        assert_eq!(status1, ModelStatus::DOWNLOADING);

        let status2 = tracker
            .database
            .try_claim_for_download(&model_name, provider)
            .expect("Failed to claim for download 2");
        assert_eq!(status2, ModelStatus::DOWNLOADING);

        let record = tracker
            .database
            .get_model_record(&model_name)
            .expect("Failed to get model record")
            .expect("Record should exist");
        assert_eq!(record.status, ModelStatus::DOWNLOADING);
    }

    #[test]
    fn test_collect_model_files_empty_dir() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let files = collect_model_files(temp_dir.path(), temp_dir.path());
        assert!(files.is_empty());
    }

    #[test]
    fn test_collect_model_files_with_files() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");

        let file1_path = temp_dir.path().join("config.json");
        std::fs::write(&file1_path, r#"{"test": "data"}"#).expect("Failed to write file1");

        let file2_path = temp_dir.path().join("model.bin");
        std::fs::write(&file2_path, vec![0u8; 100]).expect("Failed to write file2");

        let files = collect_model_files(temp_dir.path(), temp_dir.path());

        assert_eq!(files.len(), 2);

        let total_size: u64 = files.iter().map(|(_, size)| size).sum();
        assert!(total_size > 0);

        let paths: Vec<_> = files
            .iter()
            .map(|(p, _)| p.to_string_lossy().to_string())
            .collect();
        assert!(paths.contains(&"config.json".to_string()));
        assert!(paths.contains(&"model.bin".to_string()));
    }

    #[test]
    fn test_collect_model_files_nested() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");

        let subdir = temp_dir.path().join("subdir");
        std::fs::create_dir(&subdir).expect("Failed to create subdir");

        let file1_path = temp_dir.path().join("root_file.txt");
        std::fs::write(&file1_path, "root content").expect("Failed to write file1");

        let file2_path = subdir.join("nested_file.txt");
        std::fs::write(&file2_path, "nested content").expect("Failed to write file2");

        let files = collect_model_files(temp_dir.path(), temp_dir.path());

        assert_eq!(files.len(), 2);

        let paths: Vec<_> = files
            .iter()
            .map(|(p, _)| p.to_string_lossy().to_string())
            .collect();
        assert!(paths.iter().any(|p| p.contains("nested_file")));
    }

    #[tokio::test]
    async fn test_list_model_files_not_found() {
        let service = ModelServiceImpl;

        let request = Request::new(ModelFilesRequest {
            model_name: "non-existent-model-12345".to_string(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            chunk_size: 0,
        });

        let result = service.list_model_files(request).await;
        assert!(result.is_err());
        let status = result.expect_err("Should return error");
        assert_eq!(status.code(), tonic::Code::NotFound);
    }

    #[tokio::test]
    async fn test_stream_model_files_not_found() {
        let service = ModelServiceImpl;

        let request = Request::new(ModelFilesRequest {
            model_name: "non-existent-model-12345".to_string(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            chunk_size: 1024,
        });

        let result = service.stream_model_files(request).await;
        assert!(result.is_err());
        let status = result.expect_err("Should return error");
        assert_eq!(status.code(), tonic::Code::NotFound);
    }
}