1use crate::registry::backend::ClaimOutcome;
5use crate::registry::state::RegistryManager;
6use modelexpress_common::{
7 cache::{CacheConfig, resolve_model_path},
8 constants, download,
9 grpc::{
10 api::{ApiRequest, ApiResponse, api_service_server::ApiService},
11 health::{HealthRequest, HealthResponse, health_service_server::HealthService},
12 model::{
13 DeleteModelRequest, DeleteModelResponse, FileChunk, ModelDownloadRequest,
14 ModelFileInfo, ModelFileList, ModelFileSelector, ModelFilesRequest,
15 ModelProvider as GrpcModelProvider, ModelStatusUpdate,
16 model_service_server::ModelService,
17 },
18 },
19 models::{ModelProvider, ModelStatus},
20};
21use std::{
22 collections::HashMap,
23 path::{Path, PathBuf},
24 sync::{Arc, Mutex},
25 time::SystemTime,
26};
27use tokio::io::AsyncReadExt;
28use tokio_stream::wrappers::ReceiverStream;
29use tonic::{Request, Response, Status};
30use tracing::{debug, error, info};
31
/// Reference timestamp from which the health endpoint computes uptime.
///
/// Within this file it is populated lazily through `get_or_init` in
/// `HealthServiceImpl::get_health`, so unless something else sets it earlier
/// the uptime counts from the first health request rather than process start.
static START_TIME: std::sync::OnceLock<SystemTime> = std::sync::OnceLock::new();
33
34fn get_server_cache_dir() -> Option<std::path::PathBuf> {
36 if let Ok(config) = CacheConfig::discover() {
38 Some(config.local_path)
39 } else {
40 std::env::var("HF_HUB_CACHE")
42 .ok()
43 .map(std::path::PathBuf::from)
44 }
45}
46
47async fn model_files_present(
53 cache_dir: Option<std::path::PathBuf>,
54 model_name: &str,
55 provider: ModelProvider,
56) -> bool {
57 let Some(cache_dir) = cache_dir else {
58 return true;
59 };
60 download::get_provider(provider)
61 .get_model_path(model_name, cache_dir)
62 .await
63 .is_ok()
64}
65
/// Stateless implementation of the gRPC `HealthService`.
#[derive(Debug, Default)]
pub struct HealthServiceImpl;
69
70#[tonic::async_trait]
71impl HealthService for HealthServiceImpl {
72 async fn get_health(
73 &self,
74 _request: Request<HealthRequest>,
75 ) -> Result<Response<HealthResponse>, Status> {
76 let start_time = START_TIME.get_or_init(SystemTime::now);
77 let uptime = SystemTime::now()
78 .duration_since(*start_time)
79 .unwrap_or_default()
80 .as_secs();
81
82 let response = HealthResponse {
83 version: env!("CARGO_PKG_VERSION").to_string(),
84 status: "ok".to_string(),
85 uptime,
86 };
87
88 Ok(Response::new(response))
89 }
90}
91
/// Stateless implementation of the gRPC `ApiService`.
#[derive(Debug, Default)]
pub struct ApiServiceImpl;
95
96#[tonic::async_trait]
97impl ApiService for ApiServiceImpl {
98 async fn send_request(
99 &self,
100 request: Request<ApiRequest>,
101 ) -> Result<Response<ApiResponse>, Status> {
102 let api_request = request.into_inner();
103 info!("Received gRPC request: {:?}", api_request);
104
105 if api_request.action.as_str() == "ping" {
107 info!("Processing ping request");
108 let response_data = serde_json::json!({ "message": "pong" });
109 let data_bytes = serde_json::to_vec(&response_data)
110 .map_err(|e| Status::internal(format!("Serialization error: {e}")))?;
111
112 Ok(Response::new(ApiResponse {
113 success: true,
114 data: Some(data_bytes),
115 error: None,
116 }))
117 } else {
118 error!("Unknown action: {}", api_request.action);
119 Ok(Response::new(ApiResponse {
120 success: false,
121 data: None,
122 error: Some(format!("Unknown action: {}", api_request.action)),
123 }))
124 }
125 }
126}
127
/// Stateless implementation of the gRPC `ModelService`.
#[derive(Debug, Default)]
pub struct ModelServiceImpl;
131
132fn collect_model_files(
134 base_path: &Path,
135 current_path: &Path,
136 file_selector: Option<&ModelFileSelector>,
137) -> Vec<(PathBuf, u64)> {
138 let mut files = Vec::new();
139
140 if let Ok(entries) = std::fs::read_dir(current_path) {
141 for entry in entries.flatten() {
142 let path = entry.path();
143 if path.is_file() {
144 if let Ok(metadata) = std::fs::metadata(&path) {
145 if let Ok(relative) = path.strip_prefix(base_path) {
147 let mut is_safe = true;
149 for comp in relative.components() {
150 use std::path::Component;
151 match comp {
152 Component::ParentDir
153 | Component::RootDir
154 | Component::Prefix(_) => {
155 is_safe = false;
156 break;
157 }
158 _ => {}
159 }
160 }
161 if !is_safe {
162 tracing::warn!(
163 "Skipping potentially unsafe file path: {:?} (relative: {:?})",
164 path,
165 relative
166 );
167 } else if file_selector.is_none_or(|selector| {
168 selector
169 .paths
170 .iter()
171 .any(|selector_path| Path::new(selector_path) == relative)
172 }) {
173 files.push((relative.to_path_buf(), metadata.len()));
174 }
175 }
176 }
177 } else if path.is_dir() {
178 files.extend(collect_model_files(base_path, &path, file_selector));
179 }
180 }
181 }
182
183 files
184}
185
186fn ensure_selected_files_exist(
187 files: &[(PathBuf, u64)],
188 file_selector: Option<&ModelFileSelector>,
189) -> Result<(), String> {
190 let Some(selector) = file_selector else {
191 return Ok(());
192 };
193
194 if let Some(missing_path) = selector.paths.iter().find(|selector_path| {
195 !files
196 .iter()
197 .any(|(path, _)| Path::new(selector_path) == path.as_path())
198 }) {
199 Err(format!(
200 "Selected file not found in model directory: {missing_path}"
201 ))
202 } else {
203 Ok(())
204 }
205}
206
207#[tonic::async_trait]
208impl ModelService for ModelServiceImpl {
209 type EnsureModelDownloadedStream = ReceiverStream<Result<ModelStatusUpdate, Status>>;
210 type StreamModelFilesStream = ReceiverStream<Result<FileChunk, Status>>;
211
212 async fn ensure_model_downloaded(
213 &self,
214 request: Request<ModelDownloadRequest>,
215 ) -> Result<Response<Self::EnsureModelDownloadedStream>, Status> {
216 info!("Starting model download stream");
217 let model_request = request.into_inner();
218 let (tx, rx) = tokio::sync::mpsc::channel(4);
219
220 let grpc_provider = GrpcModelProvider::try_from(model_request.provider).map_err(|_| {
222 Status::invalid_argument(format!(
223 "Invalid provider value: {}",
224 model_request.provider
225 ))
226 })?;
227 let provider = ModelProvider::from(grpc_provider);
228 let model_name = download::canonical_model_name(&model_request.model_name, provider)
229 .map_err(|e| Status::invalid_argument(e.to_string()))?;
230 let ignore_weights = model_request.ignore_weights;
231
232 tokio::spawn(async move {
234 let Some(tracker) = model_tracker() else {
235 let _ = tx
236 .send(Err(Status::unavailable(
237 "server startup incomplete: model tracker not initialized",
238 )))
239 .await;
240 return;
241 };
242 let final_status = tracker
249 .ensure_model_downloaded(&model_name, provider, &tx, ignore_weights)
250 .await;
251
252 let final_update = ModelStatusUpdate {
254 model_name: model_name.clone(),
255 status: modelexpress_common::grpc::model::ModelStatus::from(final_status) as i32,
256 message: match final_status {
257 ModelStatus::DOWNLOADED => {
258 Some("Model download completed successfully".to_string())
259 }
260 ModelStatus::ERROR => Some("Model download failed".to_string()),
261 ModelStatus::DOWNLOADING => Some("Download still in progress".to_string()),
262 },
263 provider: grpc_provider as i32,
264 };
265
266 let _ = tx.send(Ok(final_update)).await;
267 });
268
269 Ok(Response::new(ReceiverStream::new(rx)))
270 }
271
272 async fn stream_model_files(
273 &self,
274 request: Request<ModelFilesRequest>,
275 ) -> Result<Response<Self::StreamModelFilesStream>, Status> {
276 let files_request = request.into_inner();
277 let chunk_size = if files_request.chunk_size == 0 {
278 constants::DEFAULT_TRANSFER_CHUNK_SIZE
279 } else {
280 files_request.chunk_size as usize
281 };
282
283 let grpc_provider = GrpcModelProvider::try_from(files_request.provider).map_err(|_| {
285 Status::invalid_argument(format!(
286 "Invalid provider value: {}",
287 files_request.provider
288 ))
289 })?;
290 let provider = ModelProvider::from(grpc_provider);
291 let model_name = download::canonical_model_name(&files_request.model_name, provider)
292 .map_err(|e| Status::invalid_argument(e.to_string()))?;
293 let provider_impl = download::get_provider(provider);
294
295 info!(
296 "Starting file stream for model: {} with chunk size: {} bytes",
297 model_name, chunk_size
298 );
299
300 let cache_dir = get_server_cache_dir()
302 .ok_or_else(|| Status::internal("Server cache directory not configured"))?;
303
304 let model_path = provider_impl
306 .get_model_path(&model_name, cache_dir.clone())
307 .await
308 .map_err(|e| Status::not_found(format!("Model not found: {e}")))?;
309
310 debug!("Model path resolved to: {:?}", model_path);
311
312 let commit_hash = if provider == ModelProvider::HuggingFace {
313 model_path
314 .file_name()
315 .and_then(|name| name.to_str())
316 .map(String::from)
317 } else {
318 None
319 };
320
321 if provider == ModelProvider::HuggingFace && commit_hash.is_none() {
322 return Err(Status::internal(
323 "Resolved Hugging Face model path did not contain a revision",
324 ));
325 }
326
327 let expected_model_path =
328 resolve_model_path(&cache_dir, provider, &model_name, commit_hash.as_deref()).map_err(
329 |e| Status::internal(format!("Failed to resolve expected cache layout: {e}")),
330 )?;
331
332 if model_path != expected_model_path {
333 error!(
334 "Resolved model path '{}' does not match expected cache layout '{}' for model '{}'",
335 model_path.display(),
336 expected_model_path.display(),
337 model_name
338 );
339 return Err(Status::internal(
340 "Resolved model path does not match expected cache layout",
341 ));
342 }
343
344 let files = collect_model_files(
346 &model_path,
347 &model_path,
348 files_request.file_selector.as_ref(),
349 );
350 ensure_selected_files_exist(&files, files_request.file_selector.as_ref())
351 .map_err(Status::not_found)?;
352
353 if files.is_empty() {
354 return Err(Status::not_found("No files found in model directory"));
355 }
356
357 let total_files = files.len();
358 info!(
359 "Found {} files to stream for model {}",
360 total_files, model_name
361 );
362
363 let (tx, rx) = tokio::sync::mpsc::channel(16);
364
365 tokio::spawn(async move {
367 let mut buffer = vec![0u8; chunk_size];
369 let mut is_first_chunk = true;
370
371 for (file_idx, (relative_path, total_size)) in files.iter().enumerate() {
372 let file_path = model_path.join(relative_path);
373 let is_last_file = file_idx == total_files.saturating_sub(1);
374
375 debug!("Streaming file: {:?} ({} bytes)", relative_path, total_size);
376
377 let file = match tokio::fs::File::open(&file_path).await {
379 Ok(f) => f,
380 Err(e) => {
381 error!("Failed to open file {:?}: {}", file_path, e);
382 let _ = tx
383 .send(Err(Status::internal(format!("Failed to open file: {e}"))))
384 .await;
385 return;
386 }
387 };
388
389 let mut reader = tokio::io::BufReader::new(file);
390 let mut offset: u64 = 0;
391
392 if *total_size == 0 {
393 let first_chunk = std::mem::replace(&mut is_first_chunk, false);
394 let chunk = FileChunk {
395 relative_path: relative_path.to_string_lossy().to_string(),
396 data: Vec::new(),
397 offset: 0,
398 total_size: 0,
399 is_last_chunk: true,
400 is_last_file,
401 commit_hash: if first_chunk {
402 commit_hash.clone()
403 } else {
404 None
405 },
406 };
407
408 if tx.send(Ok(chunk)).await.is_err() {
409 debug!("Client disconnected during file stream");
410 return;
411 }
412
413 continue;
414 }
415
416 loop {
417 let bytes_read = match reader.read(&mut buffer).await {
418 Ok(0) => break, Ok(n) => n,
420 Err(e) => {
421 error!("Failed to read file {:?}: {}", file_path, e);
422 let _ = tx
423 .send(Err(Status::internal(format!("Failed to read file: {e}"))))
424 .await;
425 return;
426 }
427 };
428
429 let is_last_chunk = offset.saturating_add(bytes_read as u64) >= *total_size;
430
431 let first_chunk = std::mem::replace(&mut is_first_chunk, false);
432
433 let chunk = FileChunk {
434 relative_path: relative_path.to_string_lossy().to_string(),
435 data: buffer[..bytes_read].to_vec(),
436 offset,
437 total_size: *total_size,
438 is_last_chunk,
439 is_last_file: is_last_file && is_last_chunk,
440 commit_hash: if first_chunk {
441 commit_hash.clone()
442 } else {
443 None
444 },
445 };
446
447 if tx.send(Ok(chunk)).await.is_err() {
448 debug!("Client disconnected during file stream");
449 return;
450 }
451
452 offset = offset.saturating_add(bytes_read as u64);
453 }
454 }
455
456 info!("File streaming completed for model");
457 });
458
459 Ok(Response::new(ReceiverStream::new(rx)))
460 }
461
462 async fn list_model_files(
463 &self,
464 request: Request<ModelFilesRequest>,
465 ) -> Result<Response<ModelFileList>, Status> {
466 let files_request = request.into_inner();
467
468 let grpc_provider = GrpcModelProvider::try_from(files_request.provider).map_err(|_| {
470 Status::invalid_argument(format!(
471 "Invalid provider value: {}",
472 files_request.provider
473 ))
474 })?;
475 let provider = ModelProvider::from(grpc_provider);
476 let model_name = download::canonical_model_name(&files_request.model_name, provider)
477 .map_err(|e| Status::invalid_argument(e.to_string()))?;
478 let provider_impl = download::get_provider(provider);
479
480 info!("Listing files for model: {}", model_name);
481
482 let cache_dir = get_server_cache_dir()
484 .ok_or_else(|| Status::internal("Server cache directory not configured"))?;
485
486 let model_path = provider_impl
488 .get_model_path(&model_name, cache_dir)
489 .await
490 .map_err(|e| Status::not_found(format!("Model not found: {e}")))?;
491
492 let files = collect_model_files(
494 &model_path,
495 &model_path,
496 files_request.file_selector.as_ref(),
497 );
498 ensure_selected_files_exist(&files, files_request.file_selector.as_ref())
499 .map_err(Status::not_found)?;
500
501 let file_infos: Vec<ModelFileInfo> = files
502 .iter()
503 .map(|(path, size)| ModelFileInfo {
504 relative_path: path.to_string_lossy().to_string(),
505 size: *size,
506 })
507 .collect();
508
509 let total_size: u64 = files.iter().map(|(_, size)| size).sum();
510
511 Ok(Response::new(ModelFileList {
512 model_name,
513 files: file_infos,
514 total_size,
515 }))
516 }
517
518 async fn delete_model(
519 &self,
520 request: Request<DeleteModelRequest>,
521 ) -> Result<Response<DeleteModelResponse>, Status> {
522 let delete_request = request.into_inner();
523
524 let grpc_provider = GrpcModelProvider::try_from(delete_request.provider).map_err(|_| {
525 Status::invalid_argument(format!(
526 "Invalid provider value: {}",
527 delete_request.provider
528 ))
529 })?;
530 let provider = ModelProvider::from(grpc_provider);
531 let model_name = download::canonical_model_name(&delete_request.model_name, provider)
532 .map_err(|e| Status::invalid_argument(e.to_string()))?;
533
534 let Some(tracker) = model_tracker() else {
535 return Err(Status::unavailable(
536 "server startup incomplete: model tracker not initialized",
537 ));
538 };
539
540 tracker.delete_status(&model_name).await;
541 info!("Deleted registry record for model '{model_name}'");
542
543 Ok(Response::new(DeleteModelResponse {
544 success: true,
545 message: Some(format!("Model '{model_name}' removed from registry")),
546 }))
547 }
548}
549
/// Shared, mutex-protected map from model name to the senders of all clients
/// currently waiting for status updates on that model.
type WaitingChannels =
    Arc<Mutex<HashMap<String, Vec<tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>>>>>;
553
/// Coordinates model downloads: status lives in the registry, while the
/// channels of waiting clients are kept in memory.
///
/// Cloning is cheap; both fields are reference-counted handles, so clones
/// share the same registry and the same waiter map.
#[derive(Clone)]
pub struct ModelDownloadTracker {
    /// Registry used to read, claim, update and delete model status records.
    registry: Arc<RegistryManager>,
    /// Per-model senders that receive a notification on status changes.
    waiting_channels: WaitingChannels,
}
562
563impl ModelDownloadTracker {
564 pub fn new(registry: Arc<RegistryManager>) -> Self {
565 Self {
566 registry,
567 waiting_channels: Arc::new(Mutex::new(HashMap::new())),
568 }
569 }
570
571 pub async fn get_status(&self, model_name: &str) -> Option<ModelStatus> {
574 match self.registry.get_status(model_name).await {
575 Ok(Some(status)) => {
576 if let Err(e) = self.registry.touch_model(model_name).await {
577 error!("Failed to touch model {model_name}: {e}");
578 }
579 Some(status)
580 }
581 Ok(None) => None,
582 Err(e) => {
583 error!("Failed to get model status from registry: {e}");
584 None
585 }
586 }
587 }
588
589 pub async fn set_status_and_notify(
591 &self,
592 model_name: String,
593 status: ModelStatus,
594 provider: ModelProvider,
595 message: Option<String>,
596 ) {
597 if let Err(e) = self
598 .registry
599 .set_status(&model_name, provider, status, message.clone())
600 .await
601 {
602 error!("Failed to update model status in registry: {e}");
603 return;
604 }
605
606 let mut waiting = match self.waiting_channels.lock() {
607 Ok(guard) => guard,
608 Err(poisoned) => {
609 error!("Waiting channels mutex is poisoned, recovering");
610 poisoned.into_inner()
611 }
612 };
613 if let Some(channels) = waiting.get(&model_name) {
614 let update = ModelStatusUpdate {
615 model_name: model_name.clone(),
616 status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
617 message,
618 provider: GrpcModelProvider::from(provider) as i32,
619 };
620 for channel in channels {
621 let _ = channel.try_send(Ok(update.clone()));
622 }
623 if status == ModelStatus::DOWNLOADED || status == ModelStatus::ERROR {
624 waiting.remove(&model_name);
625 }
626 }
627 }
628
    /// Convenience wrapper around [`Self::set_status_and_notify`] that sets
    /// `status` for `model_name` without an accompanying message.
    pub async fn set_status(
        &self,
        model_name: String,
        status: ModelStatus,
        provider: ModelProvider,
    ) {
        self.set_status_and_notify(model_name, status, provider, None)
            .await;
    }
639
640 pub fn add_waiting_channel(
642 &self,
643 model_name: &str,
644 tx: tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
645 ) {
646 let mut waiting = match self.waiting_channels.lock() {
647 Ok(guard) => guard,
648 Err(poisoned) => {
649 error!("Waiting channels mutex is poisoned, recovering");
650 poisoned.into_inner()
651 }
652 };
653 waiting.entry(model_name.to_string()).or_default().push(tx);
654 }
655
656 pub async fn delete_status(&self, model_name: &str) {
658 if let Err(e) = self.registry.delete_model(model_name).await {
659 error!("Failed to delete model from registry: {e}");
660 }
661 let mut waiting = match self.waiting_channels.lock() {
662 Ok(guard) => guard,
663 Err(poisoned) => {
664 error!("Waiting channels mutex is poisoned, recovering");
665 poisoned.into_inner()
666 }
667 };
668 waiting.remove(model_name);
669 }
670
671 fn spawn_download_task(
674 &self,
675 model_name: String,
676 provider: ModelProvider,
677 ignore_weights: bool,
678 retry: bool,
679 ) {
680 let tracker = self.clone();
681 tokio::spawn(async move {
682 let cache_dir = get_server_cache_dir();
683 match download::download_model(&model_name, provider, cache_dir, ignore_weights).await {
684 Ok(_path) => {
685 tracker
686 .set_status_and_notify(
687 model_name,
688 ModelStatus::DOWNLOADED,
689 provider,
690 Some("Model download completed successfully".to_string()),
691 )
692 .await;
693 }
694 Err(e) => {
695 if retry {
696 error!("Failed to download model {model_name} on retry: {e}");
697 } else {
698 error!("Failed to download model {model_name}: {e}");
699 }
700 let msg = if retry {
701 format!("Download failed on retry: {e}")
702 } else {
703 format!("Download failed: {e}")
704 };
705 tracker
706 .set_status_and_notify(model_name, ModelStatus::ERROR, provider, Some(msg))
707 .await;
708 }
709 }
710 });
711 }
712
    /// Makes sure `model_name` is downloaded, starting the download if this
    /// caller wins the claim, and reports progress on `tx`.
    ///
    /// Steps, in order:
    /// 1. Try to claim the model in the registry. If the registry already
    ///    says `DOWNLOADED` but the files cannot be found in the cache, the
    ///    stale record is deleted and the claim is attempted once more.
    /// 2. If the existing status is `ERROR`, try to reset it for a retry; the
    ///    caller that wins the reset becomes the retry owner.
    /// 3. Send one update describing the current state to `tx`.
    /// 4. When the model is downloading, register `tx` as a waiter, spawn the
    ///    download task if this caller is the owner, and poll the registry
    ///    every 500 ms until the status is no longer `DOWNLOADING`.
    ///
    /// Returns the status that ended the wait, or `ERROR` when a registry
    /// call in step 1 or 2 fails (an error update is sent to `tx` first).
    ///
    /// NOTE(review): the polling loop has no timeout and keeps running while
    /// `get_status` yields `None` or `DOWNLOADING`, even after the receiver of
    /// `tx` is gone — confirm this is acceptable.
    pub async fn ensure_model_downloaded(
        &self,
        model_name: &str,
        provider: ModelProvider,
        tx: &tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
        ignore_weights: bool,
    ) -> ModelStatus {
        // Two attempts: the second one only happens after a stale DOWNLOADED
        // record was cleared.
        const MAX_CLAIM_ATTEMPTS: usize = 2;
        let mut attempt: usize = 0;
        let (status, is_owner) = loop {
            attempt = attempt.saturating_add(1);
            match self
                .registry
                .try_claim_for_download(model_name, provider)
                .await
            {
                // This caller created the record and owns the download.
                Ok(ClaimOutcome::Claimed) => break (ModelStatus::DOWNLOADING, true),
                Ok(ClaimOutcome::AlreadyExists(existing)) => {
                    // Only trust a DOWNLOADED record if the files are really there;
                    // the attempt bound prevents looping on this check.
                    if existing == ModelStatus::DOWNLOADED
                        && attempt < MAX_CLAIM_ATTEMPTS
                        && !model_files_present(get_server_cache_dir(), model_name, provider).await
                    {
                        error!(
                            "Registry reports model '{model_name}' as DOWNLOADED but its files \
                             are missing from the cache; clearing the stale record and \
                             re-downloading"
                        );
                        self.delete_status(model_name).await;
                        continue;
                    }
                    break (existing, false);
                }
                Err(e) => {
                    error!("Failed to claim model for download: {e}");
                    let error_update = ModelStatusUpdate {
                        model_name: model_name.to_string(),
                        status: modelexpress_common::grpc::model::ModelStatus::from(
                            ModelStatus::ERROR,
                        ) as i32,
                        message: Some("Registry error occurred".to_string()),
                        provider: GrpcModelProvider::from(provider) as i32,
                    };
                    let _ = tx.send(Ok(error_update)).await;
                    return ModelStatus::ERROR;
                }
            }
        };

        // A previous failure is turned into a fresh DOWNLOADING state; `won`
        // tells whether this caller is the one that must run the retry.
        let (effective_status, is_retry_owner) = if status == ModelStatus::ERROR {
            let won = match self
                .registry
                .try_reset_error_for_retry(model_name, provider)
                .await
            {
                Ok(won) => won,
                Err(e) => {
                    error!("Failed to CAS status for retry: {e}");
                    let _ = tx
                        .send(Ok(ModelStatusUpdate {
                            model_name: model_name.to_string(),
                            status: modelexpress_common::grpc::model::ModelStatus::from(
                                ModelStatus::ERROR,
                            ) as i32,
                            message: Some("Registry error occurred during retry".to_string()),
                            provider: GrpcModelProvider::from(provider) as i32,
                        }))
                        .await;
                    return ModelStatus::ERROR;
                }
            };
            (ModelStatus::DOWNLOADING, won)
        } else {
            (status, false)
        };

        // Initial update describing the state the caller found.
        let update = ModelStatusUpdate {
            model_name: model_name.to_string(),
            status: modelexpress_common::grpc::model::ModelStatus::from(effective_status) as i32,
            message: match (status, effective_status) {
                (_, ModelStatus::DOWNLOADED) => Some("Model already downloaded".to_string()),
                (ModelStatus::ERROR, _) => Some("Previous download failed, retrying".to_string()),
                (_, ModelStatus::DOWNLOADING) => Some("Model download in progress".to_string()),
                (_, ModelStatus::ERROR) => Some("Download error".to_string()),
            },
            provider: GrpcModelProvider::from(provider) as i32,
        };
        let _ = tx.send(Ok(update)).await;

        if effective_status == ModelStatus::DOWNLOADING {
            // Register for notifications sent by `set_status_and_notify`.
            self.add_waiting_channel(model_name, tx.clone());

            // Only the claim winner or the retry winner starts the download.
            if is_owner || is_retry_owner {
                let retry = status == ModelStatus::ERROR;
                self.spawn_download_task(model_name.to_string(), provider, ignore_weights, retry);
            }

            // Poll the registry until the status leaves DOWNLOADING.
            loop {
                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
                if let Some(current_status) = self.get_status(model_name).await
                    && current_status != ModelStatus::DOWNLOADING
                {
                    return current_status;
                }
            }
        }

        effective_status
    }
841}
842
/// Process-wide download tracker, set once through [`init_model_tracker`] and
/// read through [`model_tracker`].
static MODEL_TRACKER: std::sync::OnceLock<Arc<ModelDownloadTracker>> = std::sync::OnceLock::new();
846
847pub fn init_model_tracker(
850 tracker: Arc<ModelDownloadTracker>,
851) -> Result<Arc<ModelDownloadTracker>, &'static str> {
852 MODEL_TRACKER
853 .set(tracker.clone())
854 .map(|()| tracker)
855 .map_err(|_| "init_model_tracker called more than once")
856}
857
/// Returns the process-wide download tracker, or `None` if
/// [`init_model_tracker`] has not been called yet.
pub fn model_tracker() -> Option<&'static Arc<ModelDownloadTracker>> {
    MODEL_TRACKER.get()
}
863
864#[cfg(test)]
865#[allow(clippy::expect_used)]
866mod tests {
867 use super::*;
868 use modelexpress_common::grpc::{api::ApiRequest, health::HealthRequest};
869 use modelexpress_common::test_support::{EnvVarGuard, acquire_env_mutex};
870 use tempfile::TempDir;
871 use tokio_stream::StreamExt;
872 use tonic::Request;
873
874 #[tokio::test]
875 async fn test_health_service() {
876 let service = HealthServiceImpl;
877 let request = Request::new(HealthRequest {});
878
879 let response = service.get_health(request).await;
880 assert!(response.is_ok());
881
882 let health_response = response.expect("Health response should be ok").into_inner();
883 assert_eq!(health_response.version, env!("CARGO_PKG_VERSION"));
884 assert_eq!(health_response.status, "ok");
885 let _uptime = health_response.uptime;
887 }
888
889 #[tokio::test]
890 async fn test_api_service_ping() {
891 let service = ApiServiceImpl;
892 let request = Request::new(ApiRequest {
893 id: "test-id".to_string(),
894 action: "ping".to_string(),
895 payload: None,
896 });
897
898 let response = service.send_request(request).await;
899 assert!(response.is_ok());
900
901 let api_response = response.expect("API response should be ok").into_inner();
902 assert!(api_response.success);
903 assert!(api_response.data.is_some());
904 assert!(api_response.error.is_none());
905
906 let data_bytes = api_response.data.expect("Data should be present");
908 let data: serde_json::Value =
909 serde_json::from_slice(&data_bytes).expect("Data should be valid JSON");
910 assert_eq!(data["message"], "pong");
911 }
912
913 #[tokio::test]
914 async fn test_api_service_unknown_action() {
915 let service = ApiServiceImpl;
916 let request = Request::new(ApiRequest {
917 id: "test-id".to_string(),
918 action: "unknown-action".to_string(),
919 payload: None,
920 });
921
922 let response = service.send_request(request).await;
923 assert!(response.is_ok());
924
925 let api_response = response.expect("API response should be ok").into_inner();
926 assert!(!api_response.success);
927 assert!(api_response.data.is_none());
928 assert!(api_response.error.is_some());
929
930 let error_message = api_response.error.expect("Error should be present");
931 assert!(error_message.contains("Unknown action"));
932 }
933
934 fn tracker_with_mock(
939 mock: crate::registry::backend::MockRegistryBackend,
940 ) -> ModelDownloadTracker {
941 let registry = Arc::new(RegistryManager::with_backend(Arc::new(mock)));
942 ModelDownloadTracker::new(registry)
943 }
944
945 #[tokio::test]
946 async fn test_tracker_get_status_missing_returns_none() {
947 let mut mock = crate::registry::backend::MockRegistryBackend::new();
948 mock.expect_get_status().once().returning(|_| Ok(None));
949 let tracker = tracker_with_mock(mock);
951 assert!(tracker.get_status("unknown").await.is_none());
952 }
953
954 #[tokio::test]
955 async fn test_tracker_get_status_hit_bumps_last_used_at() {
956 let mut mock = crate::registry::backend::MockRegistryBackend::new();
957 mock.expect_get_status()
958 .once()
959 .returning(|_| Ok(Some(ModelStatus::DOWNLOADED)));
960 mock.expect_touch_model().once().returning(|_| Ok(()));
961 let tracker = tracker_with_mock(mock);
962 assert_eq!(tracker.get_status("m").await, Some(ModelStatus::DOWNLOADED));
963 }
964
965 #[tokio::test]
966 async fn test_tracker_set_status_notifies_waiting_channel() {
967 let mut mock = crate::registry::backend::MockRegistryBackend::new();
968 mock.expect_set_status()
969 .once()
970 .returning(|_, _, _, _| Ok(()));
971 let tracker = tracker_with_mock(mock);
972
973 let (tx, mut rx) = tokio::sync::mpsc::channel(4);
974 tracker.add_waiting_channel("m", tx);
975
976 tracker
977 .set_status_and_notify(
978 "m".to_string(),
979 ModelStatus::DOWNLOADED,
980 ModelProvider::HuggingFace,
981 Some("done".to_string()),
982 )
983 .await;
984
985 let update = rx.recv().await.expect("waiter should receive update");
986 let update = update.expect("notify should send Ok");
987 assert_eq!(update.model_name, "m");
988 assert_eq!(
989 update.status,
990 modelexpress_common::grpc::model::ModelStatus::Downloaded as i32
991 );
992 assert_eq!(update.message.as_deref(), Some("done"));
993
994 let waiters = tracker
996 .waiting_channels
997 .lock()
998 .expect("waiters lock")
999 .get("m")
1000 .map_or(0, std::vec::Vec::len);
1001 assert_eq!(waiters, 0);
1002 }
1003
1004 #[tokio::test]
1005 async fn test_tracker_delete_status_clears_backend_and_waiters() {
1006 let mut mock = crate::registry::backend::MockRegistryBackend::new();
1007 mock.expect_delete_model().once().returning(|_| Ok(()));
1008 let tracker = tracker_with_mock(mock);
1009
1010 let (tx, _rx) = tokio::sync::mpsc::channel(1);
1011 tracker.add_waiting_channel("m", tx);
1012 tracker.delete_status("m").await;
1013
1014 let waiters = tracker
1015 .waiting_channels
1016 .lock()
1017 .expect("waiters lock")
1018 .contains_key("m");
1019 assert!(!waiters);
1020 }
1021
1022 #[tokio::test]
1023 async fn test_tracker_set_status_delegates_without_message() {
1024 let mut mock = crate::registry::backend::MockRegistryBackend::new();
1025 mock.expect_set_status()
1026 .withf(|_, _, status, msg| *status == ModelStatus::DOWNLOADING && msg.is_none())
1027 .once()
1028 .returning(|_, _, _, _| Ok(()));
1029 let tracker = tracker_with_mock(mock);
1030 tracker
1031 .set_status(
1032 "m".to_string(),
1033 ModelStatus::DOWNLOADING,
1034 ModelProvider::HuggingFace,
1035 )
1036 .await;
1037 }
1038
1039 #[tokio::test]
1040 async fn test_tracker_error_status_clears_waiters() {
1041 let mut mock = crate::registry::backend::MockRegistryBackend::new();
1042 mock.expect_set_status()
1043 .once()
1044 .returning(|_, _, _, _| Ok(()));
1045 let tracker = tracker_with_mock(mock);
1046 let (tx, _rx) = tokio::sync::mpsc::channel(1);
1047 tracker.add_waiting_channel("m", tx);
1048 tracker
1049 .set_status_and_notify(
1050 "m".to_string(),
1051 ModelStatus::ERROR,
1052 ModelProvider::HuggingFace,
1053 Some("fail".to_string()),
1054 )
1055 .await;
1056 let waiters = tracker
1057 .waiting_channels
1058 .lock()
1059 .expect("waiters lock")
1060 .get("m")
1061 .map_or(0, std::vec::Vec::len);
1062 assert_eq!(waiters, 0, "ERROR is terminal, waiters must be cleared");
1063 }
1064
1065 #[tokio::test]
1066 async fn test_tracker_downloading_status_keeps_waiters() {
1067 let mut mock = crate::registry::backend::MockRegistryBackend::new();
1068 mock.expect_set_status()
1069 .once()
1070 .returning(|_, _, _, _| Ok(()));
1071 let tracker = tracker_with_mock(mock);
1072 let (tx, _rx) = tokio::sync::mpsc::channel(1);
1073 tracker.add_waiting_channel("m", tx);
1074 tracker
1075 .set_status_and_notify(
1076 "m".to_string(),
1077 ModelStatus::DOWNLOADING,
1078 ModelProvider::HuggingFace,
1079 None,
1080 )
1081 .await;
1082 let waiters = tracker
1083 .waiting_channels
1084 .lock()
1085 .expect("waiters lock")
1086 .get("m")
1087 .map_or(0, std::vec::Vec::len);
1088 assert_eq!(
1089 waiters, 1,
1090 "DOWNLOADING is non-terminal, waiter must remain"
1091 );
1092 }
1093
1094 #[tokio::test]
1095 async fn test_tracker_set_status_swallows_backend_error() {
1096 let mut mock = crate::registry::backend::MockRegistryBackend::new();
1097 mock.expect_set_status()
1098 .once()
1099 .returning(|_, _, _, _| Err("redis down".into()));
1100 let tracker = tracker_with_mock(mock);
1101 let (tx, mut rx) = tokio::sync::mpsc::channel(1);
1102 tracker.add_waiting_channel("m", tx);
1103 tracker
1105 .set_status_and_notify(
1106 "m".to_string(),
1107 ModelStatus::DOWNLOADED,
1108 ModelProvider::HuggingFace,
1109 None,
1110 )
1111 .await;
1112 assert!(
1114 rx.try_recv().is_err(),
1115 "waiter shouldn't receive on backend error"
1116 );
1117 }
1118
1119 #[tokio::test]
1120 async fn test_model_tracker_uninitialized_returns_none() {
1121 assert!(model_tracker().is_none());
1125 }
1126
1127 #[test]
1128 fn test_collect_model_files_empty_dir() {
1129 let temp_dir = TempDir::new().expect("Failed to create temp dir");
1130 let files = collect_model_files(temp_dir.path(), temp_dir.path(), None);
1131 assert!(files.is_empty());
1132 }
1133
1134 #[test]
1135 fn test_collect_model_files_with_files() {
1136 let temp_dir = TempDir::new().expect("Failed to create temp dir");
1137
1138 let file1_path = temp_dir.path().join("config.json");
1140 std::fs::write(&file1_path, r#"{"test": "data"}"#).expect("Failed to write file1");
1141
1142 let file2_path = temp_dir.path().join("model.bin");
1143 std::fs::write(&file2_path, vec![0u8; 100]).expect("Failed to write file2");
1144
1145 let files = collect_model_files(temp_dir.path(), temp_dir.path(), None);
1146
1147 assert_eq!(files.len(), 2);
1148
1149 let total_size: u64 = files.iter().map(|(_, size)| size).sum();
1151 assert!(total_size > 0);
1152
1153 let paths: Vec<_> = files
1155 .iter()
1156 .map(|(p, _)| p.to_string_lossy().to_string())
1157 .collect();
1158 assert!(paths.contains(&"config.json".to_string()));
1159 assert!(paths.contains(&"model.bin".to_string()));
1160 }
1161
1162 #[test]
1163 fn test_collect_model_files_nested() {
1164 let temp_dir = TempDir::new().expect("Failed to create temp dir");
1165
1166 let subdir = temp_dir.path().join("subdir");
1168 std::fs::create_dir(&subdir).expect("Failed to create subdir");
1169
1170 let file1_path = temp_dir.path().join("root_file.txt");
1171 std::fs::write(&file1_path, "root content").expect("Failed to write file1");
1172
1173 let file2_path = subdir.join("nested_file.txt");
1174 std::fs::write(&file2_path, "nested content").expect("Failed to write file2");
1175
1176 let files = collect_model_files(temp_dir.path(), temp_dir.path(), None);
1177
1178 assert_eq!(files.len(), 2);
1179
1180 let paths: Vec<_> = files
1182 .iter()
1183 .map(|(p, _)| p.to_string_lossy().to_string())
1184 .collect();
1185 assert!(paths.iter().any(|p| p.contains("nested_file")));
1186 }
1187
1188 #[test]
1189 fn test_collect_model_files_with_selector_filters_exact_paths() {
1190 let temp_dir = TempDir::new().expect("Failed to create temp dir");
1191 let subdir = temp_dir.path().join("subdir");
1192 std::fs::create_dir(&subdir).expect("Failed to create subdir");
1193 std::fs::write(temp_dir.path().join("config.json"), "{}").expect("Failed to write config");
1194 std::fs::write(temp_dir.path().join("model.bin"), vec![0u8; 100])
1195 .expect("Failed to write model");
1196 std::fs::write(temp_dir.path().join("ignored.txt"), "ignore")
1197 .expect("Failed to write ignored");
1198 std::fs::write(subdir.join("nested.txt"), "nested").expect("Failed to write nested");
1199
1200 let selector = ModelFileSelector {
1201 paths: vec!["config.json".to_string(), "subdir/nested.txt".to_string()],
1202 };
1203 let files = collect_model_files(temp_dir.path(), temp_dir.path(), Some(&selector));
1204
1205 let mut paths: Vec<_> = files
1206 .iter()
1207 .map(|(p, _)| p.to_string_lossy().to_string())
1208 .collect();
1209 paths.sort();
1210 assert_eq!(
1211 paths,
1212 vec!["config.json".to_string(), "subdir/nested.txt".to_string()]
1213 );
1214 }
1215
1216 #[test]
1217 fn test_collect_model_files_with_selector_empty_and_nonmatching_paths() {
1218 let temp_dir = TempDir::new().expect("Failed to create temp dir");
1219 std::fs::write(temp_dir.path().join("config.json"), "{}").expect("Failed to write config");
1220
1221 let empty_selector = ModelFileSelector { paths: vec![] };
1222 assert!(
1223 collect_model_files(temp_dir.path(), temp_dir.path(), Some(&empty_selector)).is_empty()
1224 );
1225
1226 let nonmatching_selector = ModelFileSelector {
1227 paths: vec!["missing.json".to_string(), "../config.json".to_string()],
1228 };
1229 assert!(
1230 collect_model_files(
1231 temp_dir.path(),
1232 temp_dir.path(),
1233 Some(&nonmatching_selector)
1234 )
1235 .is_empty()
1236 );
1237 }
1238
1239 #[tokio::test]
1240 #[allow(clippy::await_holding_lock)]
1241 async fn test_list_model_files_hf_honors_file_selector() {
1242 let env_lock = acquire_env_mutex();
1243 let temp_dir = TempDir::new().expect("Failed to create temp dir");
1244 let _cache_dir_guard = EnvVarGuard::set(
1245 &env_lock,
1246 "MODEL_EXPRESS_CACHE_DIRECTORY",
1247 temp_dir.path().to_str().expect("Expected temp dir path"),
1248 );
1249 let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1250
1251 let model_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
1252 std::fs::create_dir_all(model_dir.join("subdir")).expect("Failed to create model dir");
1253 std::fs::write(model_dir.join("config.json"), br#"{"model":"test"}"#)
1254 .expect("Failed to write config");
1255 std::fs::write(model_dir.join("model.bin"), vec![0u8; 100]).expect("Failed to write model");
1256 std::fs::write(model_dir.join("subdir/nested.txt"), b"nested")
1257 .expect("Failed to write nested");
1258
1259 let service = ModelServiceImpl;
1260 let request = Request::new(ModelFilesRequest {
1261 model_name: "test/model".to_string(),
1262 provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1263 chunk_size: 0,
1264 file_selector: Some(ModelFileSelector {
1265 paths: vec!["config.json".to_string(), "subdir/nested.txt".to_string()],
1266 }),
1267 });
1268
1269 let response = service
1270 .list_model_files(request)
1271 .await
1272 .expect("Expected file list")
1273 .into_inner();
1274 let mut paths: Vec<_> = response
1275 .files
1276 .iter()
1277 .map(|file| file.relative_path.clone())
1278 .collect();
1279 paths.sort();
1280
1281 assert_eq!(
1282 paths,
1283 vec!["config.json".to_string(), "subdir/nested.txt".to_string()]
1284 );
1285 assert_eq!(
1286 response.total_size,
1287 br#"{"model":"test"}"#.len() as u64 + b"nested".len() as u64
1288 );
1289 }
1290
1291 #[tokio::test]
1292 #[allow(clippy::await_holding_lock)]
1293 async fn test_model_files_present_reflects_disk_state() {
1294 let env_lock = acquire_env_mutex();
1295 let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1296 let temp_dir = TempDir::new().expect("Failed to create temp dir");
1297 let cache_dir = temp_dir.path().to_path_buf();
1298
1299 assert!(
1301 !model_files_present(
1302 Some(cache_dir.clone()),
1303 "test/model",
1304 ModelProvider::HuggingFace
1305 )
1306 .await
1307 );
1308
1309 let model_dir = cache_dir.join("models--test--model/snapshots/abc123");
1311 std::fs::create_dir_all(&model_dir).expect("Failed to create model dir");
1312 std::fs::write(model_dir.join("config.json"), b"{}").expect("Failed to write config");
1313 assert!(
1314 model_files_present(Some(cache_dir), "test/model", ModelProvider::HuggingFace).await
1315 );
1316 }
1317
1318 #[tokio::test]
1319 async fn test_model_files_present_assumes_present_without_cache_dir() {
1320 assert!(model_files_present(None, "test/model", ModelProvider::HuggingFace).await);
1323 }
1324
1325 #[tokio::test]
1326 #[allow(clippy::await_holding_lock)]
1327 async fn test_stream_model_files_hf_honors_file_selector() {
1328 let env_lock = acquire_env_mutex();
1329 let temp_dir = TempDir::new().expect("Failed to create temp dir");
1330 let _cache_dir_guard = EnvVarGuard::set(
1331 &env_lock,
1332 "MODEL_EXPRESS_CACHE_DIRECTORY",
1333 temp_dir.path().to_str().expect("Expected temp dir path"),
1334 );
1335 let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1336
1337 let model_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
1338 std::fs::create_dir_all(&model_dir).expect("Failed to create model dir");
1339 std::fs::write(model_dir.join("config.json"), br#"{"model":"test"}"#)
1340 .expect("Failed to write config");
1341 std::fs::write(model_dir.join("model.bin"), vec![0u8; 100]).expect("Failed to write model");
1342
1343 let service = ModelServiceImpl;
1344 let request = Request::new(ModelFilesRequest {
1345 model_name: "test/model".to_string(),
1346 provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1347 chunk_size: 1024,
1348 file_selector: Some(ModelFileSelector {
1349 paths: vec!["config.json".to_string()],
1350 }),
1351 });
1352
1353 let response = service
1354 .stream_model_files(request)
1355 .await
1356 .expect("Expected stream response");
1357 let chunks: Vec<_> = response
1358 .into_inner()
1359 .map(|chunk| chunk.expect("Expected chunk"))
1360 .collect()
1361 .await;
1362
1363 assert_eq!(chunks.len(), 1);
1364 assert_eq!(chunks[0].relative_path, "config.json");
1365 assert_eq!(chunks[0].commit_hash.as_deref(), Some("abc123"));
1366 assert!(chunks[0].is_last_file);
1367 }
1368
1369 #[tokio::test]
1370 #[allow(clippy::await_holding_lock)]
1371 async fn test_stream_model_files_hf_returns_not_found_for_missing_selector_path() {
1372 let env_lock = acquire_env_mutex();
1373 let temp_dir = TempDir::new().expect("Failed to create temp dir");
1374 let _cache_dir_guard = EnvVarGuard::set(
1375 &env_lock,
1376 "MODEL_EXPRESS_CACHE_DIRECTORY",
1377 temp_dir.path().to_str().expect("Expected temp dir path"),
1378 );
1379 let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1380
1381 let model_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
1382 std::fs::create_dir_all(&model_dir).expect("Failed to create model dir");
1383 std::fs::write(model_dir.join("config.json"), br#"{"model":"test"}"#)
1384 .expect("Failed to write config");
1385
1386 let service = ModelServiceImpl;
1387 let request = Request::new(ModelFilesRequest {
1388 model_name: "test/model".to_string(),
1389 provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1390 chunk_size: 1024,
1391 file_selector: Some(ModelFileSelector {
1392 paths: vec!["config.json".to_string(), "missing.json".to_string()],
1393 }),
1394 });
1395
1396 let result = service.stream_model_files(request).await;
1397 let status = result.expect_err("Expected not found");
1398 assert_eq!(status.code(), tonic::Code::NotFound);
1399 assert_eq!(
1400 status.message(),
1401 "Selected file not found in model directory: missing.json"
1402 );
1403 }
1404
1405 #[tokio::test]
1406 async fn test_list_model_files_not_found() {
1407 let service = ModelServiceImpl;
1408
1409 let request = Request::new(ModelFilesRequest {
1410 model_name: "non-existent-model-12345".to_string(),
1411 provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1412 chunk_size: 0,
1413 file_selector: None,
1414 });
1415
1416 let result = service.list_model_files(request).await;
1417 assert!(result.is_err());
1418 let status = result.expect_err("Should return error");
1419 assert_eq!(status.code(), tonic::Code::NotFound);
1420 }
1421
1422 #[tokio::test]
1423 async fn test_stream_model_files_not_found() {
1424 let service = ModelServiceImpl;
1425
1426 let request = Request::new(ModelFilesRequest {
1427 model_name: "non-existent-model-12345".to_string(),
1428 provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1429 chunk_size: 1024,
1430 file_selector: None,
1431 });
1432
1433 let result = service.stream_model_files(request).await;
1434 assert!(result.is_err());
1435 let status = result.expect_err("Should return error");
1436 assert_eq!(status.code(), tonic::Code::NotFound);
1437 }
1438
1439 #[tokio::test]
1440 async fn test_ensure_model_downloaded_rejects_invalid_provider() {
1441 let service = ModelServiceImpl;
1442
1443 let request = Request::new(ModelDownloadRequest {
1444 model_name: "test/model".to_string(),
1445 provider: 99,
1446 ignore_weights: false,
1447 });
1448
1449 let result = service.ensure_model_downloaded(request).await;
1450 assert!(result.is_err());
1451 let status = result.expect_err("Should return error");
1452 assert_eq!(status.code(), tonic::Code::InvalidArgument);
1453 assert!(status.message().contains("Invalid provider value"));
1454 }
1455
1456 #[tokio::test]
1457 async fn test_list_model_files_rejects_invalid_provider() {
1458 let service = ModelServiceImpl;
1459
1460 let request = Request::new(ModelFilesRequest {
1461 model_name: "test/model".to_string(),
1462 provider: 99,
1463 chunk_size: 0,
1464 file_selector: None,
1465 });
1466
1467 let result = service.list_model_files(request).await;
1468 assert!(result.is_err());
1469 let status = result.expect_err("Should return error");
1470 assert_eq!(status.code(), tonic::Code::InvalidArgument);
1471 assert!(status.message().contains("Invalid provider value"));
1472 }
1473
1474 #[tokio::test]
1475 async fn test_stream_model_files_rejects_invalid_provider() {
1476 let service = ModelServiceImpl;
1477
1478 let request = Request::new(ModelFilesRequest {
1479 model_name: "test/model".to_string(),
1480 provider: 99,
1481 chunk_size: 1024,
1482 file_selector: None,
1483 });
1484
1485 let result = service.stream_model_files(request).await;
1486 assert!(result.is_err());
1487 let status = result.expect_err("Should return error");
1488 assert_eq!(status.code(), tonic::Code::InvalidArgument);
1489 assert!(status.message().contains("Invalid provider value"));
1490 }
1491
1492 #[tokio::test]
1493 #[allow(clippy::await_holding_lock)]
1494 async fn test_stream_model_files_hf_first_chunk_includes_commit_hash() {
1495 let env_lock = acquire_env_mutex();
1496 let temp_dir = TempDir::new().expect("Failed to create temp dir");
1497 let _cache_dir_guard = EnvVarGuard::set(
1498 &env_lock,
1499 "MODEL_EXPRESS_CACHE_DIRECTORY",
1500 temp_dir.path().to_str().expect("Expected temp dir path"),
1501 );
1502 let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1503
1504 let model_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
1505 std::fs::create_dir_all(&model_dir).expect("Failed to create model dir");
1506 std::fs::write(model_dir.join("config.json"), br#"{"model":"test"}"#)
1507 .expect("Failed to write model file");
1508
1509 let service = ModelServiceImpl;
1510 let request = Request::new(ModelFilesRequest {
1511 model_name: "test/model".to_string(),
1512 provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1513 chunk_size: 1024,
1514 file_selector: None,
1515 });
1516
1517 let response = service
1518 .stream_model_files(request)
1519 .await
1520 .expect("Expected stream response");
1521 let mut stream = response.into_inner();
1522 let first_chunk = stream
1523 .next()
1524 .await
1525 .expect("Expected stream item")
1526 .expect("Expected first chunk");
1527
1528 assert_eq!(first_chunk.relative_path, "config.json");
1529 assert_eq!(first_chunk.commit_hash.as_deref(), Some("abc123"));
1530 assert!(first_chunk.is_last_chunk);
1531 assert!(first_chunk.is_last_file);
1532 }
1533
1534 #[tokio::test]
1535 #[allow(clippy::await_holding_lock)]
1536 async fn test_stream_model_files_hf_emits_chunk_for_zero_byte_file() {
1537 let env_lock = acquire_env_mutex();
1538 let temp_dir = TempDir::new().expect("Failed to create temp dir");
1539 let _cache_dir_guard = EnvVarGuard::set(
1540 &env_lock,
1541 "MODEL_EXPRESS_CACHE_DIRECTORY",
1542 temp_dir.path().to_str().expect("Expected temp dir path"),
1543 );
1544 let _offline_guard = EnvVarGuard::set(&env_lock, "HF_HUB_OFFLINE", "1");
1545
1546 let model_dir = temp_dir.path().join("models--test--model/snapshots/abc123");
1547 std::fs::create_dir_all(&model_dir).expect("Failed to create model dir");
1548 std::fs::write(model_dir.join("empty.bin"), []).expect("Failed to write empty file");
1549
1550 let service = ModelServiceImpl;
1551 let request = Request::new(ModelFilesRequest {
1552 model_name: "test/model".to_string(),
1553 provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
1554 chunk_size: 1024,
1555 file_selector: None,
1556 });
1557
1558 let response = service
1559 .stream_model_files(request)
1560 .await
1561 .expect("Expected stream response");
1562 let mut stream = response.into_inner();
1563 let first_chunk = stream
1564 .next()
1565 .await
1566 .expect("Expected stream item")
1567 .expect("Expected first chunk");
1568
1569 assert_eq!(first_chunk.relative_path, "empty.bin");
1570 assert_eq!(first_chunk.total_size, 0);
1571 assert_eq!(first_chunk.data.len(), 0);
1572 assert_eq!(first_chunk.offset, 0);
1573 assert_eq!(first_chunk.commit_hash.as_deref(), Some("abc123"));
1574 assert!(first_chunk.is_last_chunk);
1575 assert!(first_chunk.is_last_file);
1576 }
1577}