1use crate::database::ModelDatabase;
5use modelexpress_common::{
6 cache::CacheConfig,
7 constants, download,
8 grpc::{
9 api::{ApiRequest, ApiResponse, api_service_server::ApiService},
10 health::{HealthRequest, HealthResponse, health_service_server::HealthService},
11 model::{
12 FileChunk, ModelDownloadRequest, ModelFileInfo, ModelFileList, ModelFilesRequest,
13 ModelStatusUpdate, model_service_server::ModelService,
14 },
15 },
16 models::{ModelProvider, ModelStatus},
17};
18use std::{
19 collections::HashMap,
20 path::{Path, PathBuf},
21 sync::{Arc, Mutex},
22 time::SystemTime,
23};
24use tokio::io::AsyncReadExt;
25use tokio_stream::wrappers::ReceiverStream;
26use tonic::{Request, Response, Status};
27use tracing::{debug, error, info};
28
/// Reference instant used when reporting uptime.
///
/// Stored in a `OnceLock`, so it is set at most once. The health handler in
/// this file fills it with the current time on first access if it is still
/// unset at that point.
static START_TIME: std::sync::OnceLock<SystemTime> = std::sync::OnceLock::new();
30
31fn get_server_cache_dir() -> Option<std::path::PathBuf> {
33 if let Ok(config) = CacheConfig::discover() {
35 Some(config.local_path)
36 } else {
37 std::env::var("HF_HUB_CACHE")
39 .ok()
40 .map(std::path::PathBuf::from)
41 }
42}
43
44fn convert_provider(grpc_provider: i32) -> ModelProvider {
49 match modelexpress_common::grpc::model::ModelProvider::try_from(grpc_provider) {
50 Ok(provider) => provider.into(),
51 Err(_) => {
52 tracing::warn!(
53 "Invalid provider value {}, falling back to HuggingFace",
54 grpc_provider
55 );
56 ModelProvider::HuggingFace
57 }
58 }
59}
60
61#[derive(Debug, Default)]
63pub struct HealthServiceImpl;
64
65#[tonic::async_trait]
66impl HealthService for HealthServiceImpl {
67 async fn get_health(
68 &self,
69 _request: Request<HealthRequest>,
70 ) -> Result<Response<HealthResponse>, Status> {
71 let start_time = START_TIME.get_or_init(SystemTime::now);
72 let uptime = SystemTime::now()
73 .duration_since(*start_time)
74 .unwrap_or_default()
75 .as_secs();
76
77 let response = HealthResponse {
78 version: env!("CARGO_PKG_VERSION").to_string(),
79 status: "ok".to_string(),
80 uptime,
81 };
82
83 Ok(Response::new(response))
84 }
85}
86
87#[derive(Debug, Default)]
89pub struct ApiServiceImpl;
90
91#[tonic::async_trait]
92impl ApiService for ApiServiceImpl {
93 async fn send_request(
94 &self,
95 request: Request<ApiRequest>,
96 ) -> Result<Response<ApiResponse>, Status> {
97 let api_request = request.into_inner();
98 info!("Received gRPC request: {:?}", api_request);
99
100 if api_request.action.as_str() == "ping" {
102 info!("Processing ping request");
103 let response_data = serde_json::json!({ "message": "pong" });
104 let data_bytes = serde_json::to_vec(&response_data)
105 .map_err(|e| Status::internal(format!("Serialization error: {e}")))?;
106
107 Ok(Response::new(ApiResponse {
108 success: true,
109 data: Some(data_bytes),
110 error: None,
111 }))
112 } else {
113 error!("Unknown action: {}", api_request.action);
114 Ok(Response::new(ApiResponse {
115 success: false,
116 data: None,
117 error: Some(format!("Unknown action: {}", api_request.action)),
118 }))
119 }
120 }
121}
122
/// Stateless implementation of the gRPC model service (download status
/// streaming, file streaming and file listing).
#[derive(Debug, Default)]
pub struct ModelServiceImpl;
126
127fn collect_model_files(base_path: &Path, current_path: &Path) -> Vec<(PathBuf, u64)> {
129 let mut files = Vec::new();
130
131 if let Ok(entries) = std::fs::read_dir(current_path) {
132 for entry in entries.flatten() {
133 let path = entry.path();
134 if path.is_file() {
135 if let Ok(metadata) = std::fs::metadata(&path) {
136 if let Ok(relative) = path.strip_prefix(base_path) {
138 let mut is_safe = true;
140 for comp in relative.components() {
141 use std::path::Component;
142 match comp {
143 Component::ParentDir
144 | Component::RootDir
145 | Component::Prefix(_) => {
146 is_safe = false;
147 break;
148 }
149 _ => {}
150 }
151 }
152 if is_safe {
153 files.push((relative.to_path_buf(), metadata.len()));
154 } else {
155 tracing::warn!(
156 "Skipping potentially unsafe file path: {:?} (relative: {:?})",
157 path,
158 relative
159 );
160 }
161 }
162 }
163 } else if path.is_dir() {
164 files.extend(collect_model_files(base_path, &path));
165 }
166 }
167 }
168
169 files
170}
171
172#[tonic::async_trait]
173impl ModelService for ModelServiceImpl {
174 type EnsureModelDownloadedStream = ReceiverStream<Result<ModelStatusUpdate, Status>>;
175 type StreamModelFilesStream = ReceiverStream<Result<FileChunk, Status>>;
176
177 async fn ensure_model_downloaded(
178 &self,
179 request: Request<ModelDownloadRequest>,
180 ) -> Result<Response<Self::EnsureModelDownloadedStream>, Status> {
181 info!("Starting model download stream");
182 let model_request = request.into_inner();
183 let (tx, rx) = tokio::sync::mpsc::channel(4);
184 let model_name = model_request.model_name.clone();
185
186 let provider: ModelProvider =
188 modelexpress_common::grpc::model::ModelProvider::try_from(model_request.provider)
189 .unwrap_or(modelexpress_common::grpc::model::ModelProvider::HuggingFace)
190 .into();
191 let ignore_weights = model_request.ignore_weights;
192
193 tokio::spawn(async move {
195 if let Some(status) = MODEL_TRACKER.get_status(&model_name) {
197 let update = ModelStatusUpdate {
198 model_name: model_name.clone(),
199 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
200 message: match status {
201 ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
202 ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
203 ModelStatus::ERROR => {
204 Some("Previous download failed - retrying".to_string())
205 }
206 },
207 provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
208 as i32,
209 };
210
211 if tx.send(Ok(update)).await.is_err() {
212 return; }
214
215 if status == ModelStatus::DOWNLOADED {
217 return;
218 }
219 }
220
221 let final_status = MODEL_TRACKER
223 .ensure_model_downloaded(&model_name, provider, &tx, ignore_weights)
224 .await;
225
226 let final_update = ModelStatusUpdate {
228 model_name: model_name.clone(),
229 status: modelexpress_common::grpc::model::ModelStatus::from(final_status) as i32,
230 message: match final_status {
231 ModelStatus::DOWNLOADED => {
232 Some("Model download completed successfully".to_string())
233 }
234 ModelStatus::ERROR => Some("Model download failed".to_string()),
235 ModelStatus::DOWNLOADING => Some("Download still in progress".to_string()),
236 },
237 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
238 };
239
240 let _ = tx.send(Ok(final_update)).await;
241 });
242
243 Ok(Response::new(ReceiverStream::new(rx)))
244 }
245
246 async fn stream_model_files(
247 &self,
248 request: Request<ModelFilesRequest>,
249 ) -> Result<Response<Self::StreamModelFilesStream>, Status> {
250 let files_request = request.into_inner();
251 let model_name = files_request.model_name.clone();
252 let chunk_size = if files_request.chunk_size == 0 {
253 constants::DEFAULT_TRANSFER_CHUNK_SIZE
254 } else {
255 files_request.chunk_size as usize
256 };
257
258 let provider = convert_provider(files_request.provider);
260
261 info!(
262 "Starting file stream for model: {} with chunk size: {} bytes",
263 model_name, chunk_size
264 );
265
266 let cache_dir = get_server_cache_dir()
268 .ok_or_else(|| Status::internal("Server cache directory not configured"))?;
269
270 let provider_impl = download::get_provider(provider);
272 let model_path = provider_impl
273 .get_model_path(&model_name, cache_dir.clone())
274 .await
275 .map_err(|e| Status::not_found(format!("Model not found: {e}")))?;
276
277 debug!("Model path resolved to: {:?}", model_path);
278
279 let commit_hash = if provider == ModelProvider::HuggingFace {
280 model_path
281 .file_name()
282 .and_then(|name| name.to_str())
283 .map(String::from)
284 } else {
285 None
286 };
287
288 if provider == ModelProvider::HuggingFace && commit_hash.is_none() {
289 return Err(Status::internal(
290 "Resolved Hugging Face model path did not contain a revision",
291 ));
292 }
293
294 let files = collect_model_files(&model_path, &model_path);
296
297 if files.is_empty() {
298 return Err(Status::not_found("No files found in model directory"));
299 }
300
301 let total_files = files.len();
302 info!(
303 "Found {} files to stream for model {}",
304 total_files, model_name
305 );
306
307 let (tx, rx) = tokio::sync::mpsc::channel(16);
308
309 tokio::spawn(async move {
311 let mut buffer = vec![0u8; chunk_size];
313 let mut is_first_chunk = true;
314
315 for (file_idx, (relative_path, total_size)) in files.iter().enumerate() {
316 let file_path = model_path.join(relative_path);
317 let is_last_file = file_idx == total_files.saturating_sub(1);
318
319 debug!("Streaming file: {:?} ({} bytes)", relative_path, total_size);
320
321 let file = match tokio::fs::File::open(&file_path).await {
323 Ok(f) => f,
324 Err(e) => {
325 error!("Failed to open file {:?}: {}", file_path, e);
326 let _ = tx
327 .send(Err(Status::internal(format!("Failed to open file: {e}"))))
328 .await;
329 return;
330 }
331 };
332
333 let mut reader = tokio::io::BufReader::new(file);
334 let mut offset: u64 = 0;
335
336 loop {
337 let bytes_read = match reader.read(&mut buffer).await {
338 Ok(0) => break, Ok(n) => n,
340 Err(e) => {
341 error!("Failed to read file {:?}: {}", file_path, e);
342 let _ = tx
343 .send(Err(Status::internal(format!("Failed to read file: {e}"))))
344 .await;
345 return;
346 }
347 };
348
349 let is_last_chunk = offset.saturating_add(bytes_read as u64) >= *total_size;
350
351 let first_chunk = std::mem::replace(&mut is_first_chunk, false);
352
353 let chunk = FileChunk {
354 relative_path: relative_path.to_string_lossy().to_string(),
355 data: buffer[..bytes_read].to_vec(),
356 offset,
357 total_size: *total_size,
358 is_last_chunk,
359 is_last_file: is_last_file && is_last_chunk,
360 commit_hash: if first_chunk {
361 commit_hash.clone()
362 } else {
363 None
364 },
365 };
366
367 if tx.send(Ok(chunk)).await.is_err() {
368 debug!("Client disconnected during file stream");
369 return;
370 }
371
372 offset = offset.saturating_add(bytes_read as u64);
373 }
374 }
375
376 info!("File streaming completed for model");
377 });
378
379 Ok(Response::new(ReceiverStream::new(rx)))
380 }
381
382 async fn list_model_files(
383 &self,
384 request: Request<ModelFilesRequest>,
385 ) -> Result<Response<ModelFileList>, Status> {
386 let files_request = request.into_inner();
387 let model_name = files_request.model_name.clone();
388
389 let provider = convert_provider(files_request.provider);
391
392 info!("Listing files for model: {}", model_name);
393
394 let cache_dir = get_server_cache_dir()
396 .ok_or_else(|| Status::internal("Server cache directory not configured"))?;
397
398 let provider_impl = download::get_provider(provider);
400 let model_path = provider_impl
401 .get_model_path(&model_name, cache_dir)
402 .await
403 .map_err(|e| Status::not_found(format!("Model not found: {e}")))?;
404
405 let files = collect_model_files(&model_path, &model_path);
407
408 let file_infos: Vec<ModelFileInfo> = files
409 .iter()
410 .map(|(path, size)| ModelFileInfo {
411 relative_path: path.to_string_lossy().to_string(),
412 size: *size,
413 })
414 .collect();
415
416 let total_size: u64 = files.iter().map(|(_, size)| size).sum();
417
418 Ok(Response::new(ModelFileList {
419 model_name,
420 files: file_infos,
421 total_size,
422 }))
423 }
424}
425
/// Shared, mutex-protected map from a model name to the status-update senders
/// registered for that model.
type WaitingChannels =
    Arc<Mutex<HashMap<String, Vec<tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>>>>>;
429
/// Tracks per-model download status and the channels waiting for updates.
#[derive(Debug, Clone)]
pub struct ModelDownloadTracker {
    // Store that is queried and updated for per-model status records.
    database: ModelDatabase,
    // Senders, grouped by model name, that are notified on status changes.
    waiting_channels: WaitingChannels,
}
438
439impl Default for ModelDownloadTracker {
440 fn default() -> Self {
441 Self::new()
442 }
443}
444
445impl ModelDownloadTracker {
446 #[must_use]
447 pub fn new() -> Self {
448 let database = match ModelDatabase::new("./models.db") {
450 Ok(db) => db,
451 Err(e) => {
452 error!("Critical error: Could not initialize model database at ./models.db: {e}");
453 panic!("Critical error: Could not initialize model database at ./models.db");
454 }
455 };
456
457 Self {
458 database,
459 waiting_channels: Arc::new(Mutex::new(HashMap::new())),
460 }
461 }
462
463 pub fn get_status(&self, model_name: &str) -> Option<ModelStatus> {
466 match self.database.get_status(model_name) {
467 Ok(status) => {
468 if status.is_some() {
470 let _ = self.database.touch_model(model_name);
471 }
472 status
473 }
474 Err(e) => {
475 error!("Failed to get model status from database: {}", e);
476 None
477 }
478 }
479 }
480
481 pub fn set_status_and_notify(
483 &self,
484 model_name: String,
485 status: ModelStatus,
486 provider: ModelProvider,
487 message: Option<String>,
488 ) {
489 if let Err(e) = self
491 .database
492 .set_status(&model_name, provider, status, message.clone())
493 {
494 error!("Failed to update model status in database: {}", e);
495 return;
496 }
497
498 let mut waiting = match self.waiting_channels.lock() {
500 Ok(guard) => guard,
501 Err(poisoned) => {
502 error!("Waiting channels mutex is poisoned, recovering");
503 poisoned.into_inner()
504 }
505 };
506 if let Some(channels) = waiting.get(&model_name) {
507 let update = ModelStatusUpdate {
508 model_name: model_name.clone(),
509 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
510 message,
511 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
512 };
513
514 for channel in channels {
515 let _ = channel.try_send(Ok(update.clone()));
516 }
517
518 if status == ModelStatus::DOWNLOADED || status == ModelStatus::ERROR {
520 waiting.remove(&model_name);
521 }
522 }
523 }
524
525 pub fn set_status(&self, model_name: String, status: ModelStatus, provider: ModelProvider) {
527 self.set_status_and_notify(model_name, status, provider, None);
528 }
529
530 pub fn add_waiting_channel(
532 &self,
533 model_name: &str,
534 tx: tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
535 ) {
536 let mut waiting = match self.waiting_channels.lock() {
537 Ok(guard) => guard,
538 Err(poisoned) => {
539 error!("Waiting channels mutex is poisoned, recovering");
540 poisoned.into_inner()
541 }
542 };
543 waiting.entry(model_name.to_string()).or_default().push(tx);
544 }
545
546 pub fn delete_status(&self, model_name: &str) {
549 if let Err(e) = self.database.delete_model(model_name) {
550 error!("Failed to delete model from database: {}", e);
551 }
552 let mut waiting = match self.waiting_channels.lock() {
553 Ok(guard) => guard,
554 Err(poisoned) => {
555 error!("Waiting channels mutex is poisoned, recovering");
556 poisoned.into_inner()
557 }
558 };
559 waiting.remove(model_name);
560 }
561
562 pub async fn ensure_model_downloaded(
564 &self,
565 model_name: &str,
566 provider: ModelProvider,
567 tx: &tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
568 ignore_weights: bool,
569 ) -> ModelStatus {
570 let status = match self.database.try_claim_for_download(model_name, provider) {
572 Ok(status) => status,
573 Err(e) => {
574 error!("Failed to claim model for download: {}", e);
575 let error_update = ModelStatusUpdate {
577 model_name: model_name.to_string(),
578 status: modelexpress_common::grpc::model::ModelStatus::from(ModelStatus::ERROR)
579 as i32,
580 message: Some("Database error occurred".to_string()),
581 provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
582 as i32,
583 };
584 let _ = tx.send(Ok(error_update)).await;
585 return ModelStatus::ERROR;
586 }
587 };
588
589 let update = ModelStatusUpdate {
591 model_name: model_name.to_string(),
592 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
593 message: match status {
594 ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
595 ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
596 ModelStatus::ERROR => Some("Previous download failed - retrying".to_string()),
597 },
598 provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
599 };
600
601 let _ = tx.send(Ok(update)).await;
602
603 if status == ModelStatus::DOWNLOADING {
605 self.add_waiting_channel(model_name, tx.clone());
606
607 let should_start_download = {
611 let waiting = match self.waiting_channels.lock() {
612 Ok(guard) => guard,
613 Err(poisoned) => {
614 error!("Waiting channels mutex is poisoned, recovering");
615 poisoned.into_inner()
616 }
617 };
618 waiting
619 .get(model_name)
620 .is_none_or(|channels| channels.len() <= 1)
621 };
622
623 if should_start_download {
624 let tracker = self.clone();
626 let model_name_owned = model_name.to_string();
627
628 tokio::spawn(async move {
630 let cache_dir = get_server_cache_dir();
631 match download::download_model(
632 &model_name_owned,
633 provider,
634 cache_dir,
635 ignore_weights,
636 )
637 .await
638 {
639 Ok(_path) => {
640 tracker.set_status_and_notify(
642 model_name_owned,
643 ModelStatus::DOWNLOADED,
644 provider,
645 Some("Model download completed successfully".to_string()),
646 );
647 }
648 Err(e) => {
649 error!("Failed to download model {model_name_owned}: {e}");
651 tracker.set_status_and_notify(
652 model_name_owned,
653 ModelStatus::ERROR,
654 provider,
655 Some(format!("Download failed: {e}")),
656 );
657 }
658 }
659 });
660 }
661
662 loop {
664 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
665 if let Some(current_status) = self.get_status(model_name)
666 && current_status != ModelStatus::DOWNLOADING
667 {
668 return current_status;
669 }
670 }
671 } else if status == ModelStatus::ERROR {
672 if let Err(e) = self.database.set_status(
675 model_name,
676 provider,
677 ModelStatus::DOWNLOADING,
678 Some("Retrying download...".to_string()),
679 ) {
680 error!("Failed to reset status for retry: {}", e);
681 return ModelStatus::ERROR;
682 }
683
684 self.add_waiting_channel(model_name, tx.clone());
686
687 let tracker = self.clone();
689 let model_name_owned = model_name.to_string();
690
691 tokio::spawn(async move {
692 let cache_dir = get_server_cache_dir();
693 match download::download_model(
694 &model_name_owned,
695 provider,
696 cache_dir,
697 ignore_weights,
698 )
699 .await
700 {
701 Ok(_path) => {
702 tracker.set_status_and_notify(
704 model_name_owned,
705 ModelStatus::DOWNLOADED,
706 provider,
707 Some("Model download completed successfully".to_string()),
708 );
709 }
710 Err(e) => {
711 error!("Failed to download model {model_name_owned} on retry: {e}");
713 tracker.set_status_and_notify(
714 model_name_owned,
715 ModelStatus::ERROR,
716 provider,
717 Some(format!("Download failed on retry: {e}")),
718 );
719 }
720 }
721 });
722
723 loop {
725 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
726 if let Some(current_status) = self.get_status(model_name)
727 && current_status != ModelStatus::DOWNLOADING
728 {
729 return current_status;
730 }
731 }
732 }
733
734 status
735 }
736}
737
/// Process-wide tracker instance, constructed with
/// `ModelDownloadTracker::new` the first time it is accessed.
pub static MODEL_TRACKER: std::sync::LazyLock<ModelDownloadTracker> =
    std::sync::LazyLock::new(ModelDownloadTracker::new);
741
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use modelexpress_common::grpc::{
        api::ApiRequest, health::HealthRequest, model::ModelDownloadRequest,
    };
    use modelexpress_common::test_support::{EnvVarGuard, acquire_env_mutex};
    use tempfile::TempDir;
    use tokio_stream::StreamExt;
    use tonic::Request;

    /// Returns `prefix` followed by `-` and the current time in nanoseconds
    /// since the Unix epoch, so repeated runs use distinct model names.
    fn unique_model_name(prefix: &str) -> String {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards")
            .as_nanos();
        format!("{prefix}-{nanos}")
    }

    /// The gRPC wire value of the Hugging Face provider.
    fn hf_provider_value() -> i32 {
        modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32
    }

    /// Number of channels currently registered for `model_name`.
    fn waiting_count(tracker: &ModelDownloadTracker, model_name: &str) -> usize {
        let waiting = tracker
            .waiting_channels
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        waiting.get(model_name).map_or(0, Vec::len)
    }

    /// Relative paths of collected files rendered as strings.
    fn path_strings(files: &[(PathBuf, u64)]) -> Vec<String> {
        files
            .iter()
            .map(|(path, _)| path.to_string_lossy().to_string())
            .collect()
    }

    #[tokio::test]
    async fn test_health_service() {
        let response = HealthServiceImpl
            .get_health(Request::new(HealthRequest {}))
            .await;
        assert!(response.is_ok());

        let health = response.expect("Health response should be ok").into_inner();
        assert_eq!(health.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(health.status, "ok");
        // Only checks that the field exists; any value is acceptable.
        let _uptime = health.uptime;
    }

    #[tokio::test]
    async fn test_api_service_ping() {
        let ping = ApiRequest {
            id: "test-id".to_string(),
            action: "ping".to_string(),
            payload: None,
        };

        let response = ApiServiceImpl.send_request(Request::new(ping)).await;
        assert!(response.is_ok());

        let reply = response.expect("API response should be ok").into_inner();
        assert!(reply.success);
        assert!(reply.data.is_some());
        assert!(reply.error.is_none());

        let raw = reply.data.expect("Data should be present");
        let decoded: serde_json::Value =
            serde_json::from_slice(&raw).expect("Data should be valid JSON");
        assert_eq!(decoded["message"], "pong");
    }

    #[tokio::test]
    async fn test_api_service_unknown_action() {
        let unknown = ApiRequest {
            id: "test-id".to_string(),
            action: "unknown-action".to_string(),
            payload: None,
        };

        let response = ApiServiceImpl.send_request(Request::new(unknown)).await;
        assert!(response.is_ok());

        let reply = response.expect("API response should be ok").into_inner();
        assert!(!reply.success);
        assert!(reply.data.is_none());
        assert!(reply.error.is_some());

        let message = reply.error.expect("Error should be present");
        assert!(message.contains("Unknown action"));
    }

    #[test]
    fn test_model_download_tracker_new() {
        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
        let tracker = ModelDownloadTracker::new();

        assert!(tracker.get_status("non-existent-model").is_none());
    }

    #[test]
    fn test_model_download_tracker_set_and_get_status() {
        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
        let tracker = ModelDownloadTracker::new();
        let name = unique_model_name("test-model");

        // Nothing is recorded for a fresh name.
        assert!(tracker.get_status(&name).is_none());

        tracker.set_status(
            name.clone(),
            ModelStatus::DOWNLOADING,
            ModelProvider::HuggingFace,
        );

        let recorded = tracker.get_status(&name);
        assert!(recorded.is_some());
        assert_eq!(
            recorded.expect("Status should be present"),
            ModelStatus::DOWNLOADING
        );

        tracker.delete_status(&name);
    }

    #[test]
    fn test_model_download_tracker_delete_status() {
        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
        let tracker = ModelDownloadTracker::new();
        let name = unique_model_name("test-delete-model");

        tracker.set_status(
            name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
        );
        assert!(tracker.get_status(&name).is_some());

        tracker.delete_status(&name);
        assert!(tracker.get_status(&name).is_none());
    }

    #[tokio::test]
    async fn test_model_service_already_downloaded() {
        let name = unique_model_name("test-already-downloaded-model");

        // Seed the shared tracker so the service sees a finished download.
        MODEL_TRACKER.set_status(
            name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
        );

        let request = Request::new(ModelDownloadRequest {
            model_name: name.clone(),
            provider: hf_provider_value(),
            ignore_weights: false,
        });

        let response = ModelServiceImpl.ensure_model_downloaded(request).await;
        assert!(response.is_ok());

        let mut updates = response.expect("Response should be ok").into_inner();

        let first = updates.next().await;
        assert!(first.is_some());

        let first = first.expect("Update should be present");
        assert!(first.is_ok());

        let status_update = first.expect("Status update should be ok");
        assert_eq!(status_update.model_name, name);
        assert_eq!(
            status_update.status,
            modelexpress_common::grpc::model::ModelStatus::Downloaded as i32
        );

        MODEL_TRACKER.delete_status(&name);
    }

    #[test]
    fn test_model_download_tracker_set_status_and_notify() {
        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
        let tracker = ModelDownloadTracker::new();
        let name = String::from("test-notify-model");

        tracker.set_status_and_notify(
            name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
            Some("Download completed".to_string()),
        );

        let recorded = tracker.get_status(&name);
        assert!(recorded.is_some());
        assert_eq!(
            recorded.expect("Status should be present"),
            ModelStatus::DOWNLOADED
        );
    }

    #[test]
    fn test_waiting_channels_management() {
        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
        let tracker = ModelDownloadTracker::new();
        let name = "test-channels-model";

        // The receiver stays alive for the whole test.
        let (tx, _rx) = tokio::sync::mpsc::channel(4);

        tracker.add_waiting_channel(name, tx);
        assert_eq!(waiting_count(&tracker, name), 1);

        // A terminal status clears the registered channels.
        tracker.set_status_and_notify(
            name.to_string(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
            None,
        );
        assert_eq!(waiting_count(&tracker, name), 0);
    }

    #[tokio::test]
    async fn test_model_service_stream_closes_properly() {
        let name = "test-stream-model";

        let request = Request::new(ModelDownloadRequest {
            model_name: name.to_string(),
            provider: hf_provider_value(),
            ignore_weights: false,
        });

        let response = ModelServiceImpl.ensure_model_downloaded(request).await;
        assert!(response.is_ok());

        let mut updates = response.expect("Response should be ok").into_inner();

        let mut seen = 0;
        while let Some(update) = updates.next().await {
            assert!(update.is_ok());
            seen += 1;

            // Safety valve so a misbehaving stream cannot hang the test.
            if seen > 10 {
                break;
            }
        }

        assert!(seen > 0);

        MODEL_TRACKER.delete_status(name);
    }

    #[tokio::test]
    async fn test_concurrent_model_download_no_race_condition() {
        let _temp_dir = TempDir::new().expect("Failed to create temp dir");
        let tracker = ModelDownloadTracker::new();
        let name = "test-concurrent-model";
        let provider = ModelProvider::HuggingFace;

        let first_claim = tracker
            .database
            .try_claim_for_download(name, provider)
            .expect("Failed to claim for download 1");
        assert_eq!(first_claim, ModelStatus::DOWNLOADING);

        let second_claim = tracker
            .database
            .try_claim_for_download(name, provider)
            .expect("Failed to claim for download 2");
        assert_eq!(second_claim, ModelStatus::DOWNLOADING);

        let record = tracker
            .database
            .get_model_record(name)
            .expect("Failed to get model record")
            .expect("Record should exist");
        assert_eq!(record.status, ModelStatus::DOWNLOADING);

        tracker.delete_status(name);
    }

    #[test]
    fn test_collect_model_files_empty_dir() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let root = temp_dir.path();

        assert!(collect_model_files(root, root).is_empty());
    }

    #[test]
    fn test_collect_model_files_with_files() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let root = temp_dir.path();

        std::fs::write(root.join("config.json"), r#"{"test": "data"}"#)
            .expect("Failed to write file1");
        std::fs::write(root.join("model.bin"), vec![0u8; 100]).expect("Failed to write file2");

        let files = collect_model_files(root, root);
        assert_eq!(files.len(), 2);

        let total_size: u64 = files.iter().map(|(_, size)| size).sum();
        assert!(total_size > 0);

        let paths = path_strings(&files);
        assert!(paths.iter().any(|p| p == "config.json"));
        assert!(paths.iter().any(|p| p == "model.bin"));
    }

    #[test]
    fn test_collect_model_files_nested() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let root = temp_dir.path();

        let subdir = root.join("subdir");
        std::fs::create_dir(&subdir).expect("Failed to create subdir");

        std::fs::write(root.join("root_file.txt"), "root content").expect("Failed to write file1");
        std::fs::write(subdir.join("nested_file.txt"), "nested content")
            .expect("Failed to write file2");

        let files = collect_model_files(root, root);
        assert_eq!(files.len(), 2);

        let paths = path_strings(&files);
        assert!(paths.iter().any(|p| p.contains("nested_file")));
    }

    #[tokio::test]
    async fn test_list_model_files_not_found() {
        let request = Request::new(ModelFilesRequest {
            model_name: "non-existent-model-12345".to_string(),
            provider: hf_provider_value(),
            chunk_size: 0,
        });

        let outcome = ModelServiceImpl.list_model_files(request).await;
        assert!(outcome.is_err());

        let status = outcome.expect_err("Should return error");
        assert_eq!(status.code(), tonic::Code::NotFound);
    }

    #[tokio::test]
    async fn test_stream_model_files_not_found() {
        let request = Request::new(ModelFilesRequest {
            model_name: "non-existent-model-12345".to_string(),
            provider: hf_provider_value(),
            chunk_size: 1024,
        });

        let outcome = ModelServiceImpl.stream_model_files(request).await;
        assert!(outcome.is_err());

        let status = outcome.expect_err("Should return error");
        assert_eq!(status.code(), tonic::Code::NotFound);
    }

    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_stream_model_files_hf_first_chunk_includes_commit_hash() {
        // Declaration order matters: the guards are dropped before the
        // temporary directory and the environment lock.
        let env_lock = acquire_env_mutex();
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let _cache_dir_guard = EnvVarGuard::set(
            &env_lock,
            "MODEL_EXPRESS_CACHE_DIRECTORY",
            temp_dir.path().to_str().expect("Expected temp dir path"),
        );
        let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");

        // Lay out a minimal snapshot directory containing one small file.
        let snapshot_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
        std::fs::create_dir_all(&snapshot_dir).expect("Failed to create model dir");
        std::fs::write(snapshot_dir.join("config.json"), br#"{"model":"test"}"#)
            .expect("Failed to write model file");

        let request = Request::new(ModelFilesRequest {
            model_name: "test/model".to_string(),
            provider: hf_provider_value(),
            chunk_size: 1024,
        });

        let mut chunks = ModelServiceImpl
            .stream_model_files(request)
            .await
            .expect("Expected stream response")
            .into_inner();

        let first_chunk = chunks
            .next()
            .await
            .expect("Expected stream item")
            .expect("Expected first chunk");

        assert_eq!(first_chunk.relative_path, "config.json");
        assert_eq!(first_chunk.commit_hash.as_deref(), Some("abc123"));
        assert!(first_chunk.is_last_chunk);
        assert!(first_chunk.is_last_file);
    }
}