use crate::database::ModelDatabase;
use modelexpress_common::{
cache::CacheConfig,
download,
grpc::{
api::{ApiRequest, ApiResponse, api_service_server::ApiService},
health::{HealthRequest, HealthResponse, health_service_server::HealthService},
model::{ModelDownloadRequest, ModelStatusUpdate, model_service_server::ModelService},
},
models::{ModelProvider, ModelStatus},
};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::SystemTime,
};
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status};
use tracing::{error, info};
/// Reference point used to compute the uptime reported by the health endpoint.
///
/// Within this file it is only initialized by `get_health` (through `get_or_init`), so the
/// value is the time of the first health request rather than the time the process started.
static START_TIME: std::sync::OnceLock<SystemTime> = std::sync::OnceLock::new();
/// Resolves the directory the server should use as its model cache.
///
/// The `local_path` of the configuration returned by `CacheConfig::discover()` wins when
/// discovery succeeds. Otherwise the `HF_HUB_CACHE` environment variable is consulted, and
/// `None` is returned when that variable cannot be read.
fn get_server_cache_dir() -> Option<std::path::PathBuf> {
    match CacheConfig::discover() {
        Ok(config) => Some(config.local_path),
        Err(_) => std::env::var("HF_HUB_CACHE")
            .ok()
            .map(std::path::PathBuf::from),
    }
}
/// Stateless implementation of the gRPC health service.
#[derive(Debug, Default)]
pub struct HealthServiceImpl;

#[tonic::async_trait]
impl HealthService for HealthServiceImpl {
    /// Reports the crate version, a fixed `"ok"` status and the uptime in whole seconds.
    ///
    /// Uptime is measured against `START_TIME`, which is initialized on first use here, so
    /// the very first call reports (close to) zero. If the system clock moved backwards the
    /// uptime is reported as `0` instead of failing the request.
    async fn get_health(
        &self,
        _request: Request<HealthRequest>,
    ) -> Result<Response<HealthResponse>, Status> {
        let uptime = START_TIME
            .get_or_init(SystemTime::now)
            .elapsed()
            .map_or(0, |elapsed| elapsed.as_secs());

        Ok(Response::new(HealthResponse {
            version: env!("CARGO_PKG_VERSION").to_string(),
            status: "ok".to_string(),
            uptime,
        }))
    }
}
#[derive(Debug, Default)]
pub struct ApiServiceImpl;
#[tonic::async_trait]
impl ApiService for ApiServiceImpl {
async fn send_request(
&self,
request: Request<ApiRequest>,
) -> Result<Response<ApiResponse>, Status> {
let api_request = request.into_inner();
info!("Received gRPC request: {:?}", api_request);
if api_request.action.as_str() == "ping" {
info!("Processing ping request");
let response_data = serde_json::json!({ "message": "pong" });
let data_bytes = serde_json::to_vec(&response_data)
.map_err(|e| Status::internal(format!("Serialization error: {e}")))?;
Ok(Response::new(ApiResponse {
success: true,
data: Some(data_bytes),
error: None,
}))
} else {
error!("Unknown action: {}", api_request.action);
Ok(Response::new(ApiResponse {
success: false,
data: None,
error: Some(format!("Unknown action: {}", api_request.action)),
}))
}
}
}
#[derive(Debug, Default)]
pub struct ModelServiceImpl;
#[tonic::async_trait]
impl ModelService for ModelServiceImpl {
type EnsureModelDownloadedStream = ReceiverStream<Result<ModelStatusUpdate, Status>>;
async fn ensure_model_downloaded(
&self,
request: Request<ModelDownloadRequest>,
) -> Result<Response<Self::EnsureModelDownloadedStream>, Status> {
info!("Starting model download stream");
let model_request = request.into_inner();
let (tx, rx) = tokio::sync::mpsc::channel(4);
let model_name = model_request.model_name.clone();
let provider: ModelProvider =
modelexpress_common::grpc::model::ModelProvider::try_from(model_request.provider)
.unwrap_or(modelexpress_common::grpc::model::ModelProvider::HuggingFace)
.into();
let ignore_weights = model_request.ignore_weights;
tokio::spawn(async move {
if let Some(status) = MODEL_TRACKER.get_status(&model_name) {
let update = ModelStatusUpdate {
model_name: model_name.clone(),
status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
message: match status {
ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
ModelStatus::ERROR => {
Some("Previous download failed - retrying".to_string())
}
},
provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
as i32,
};
if tx.send(Ok(update)).await.is_err() {
return; }
if status == ModelStatus::DOWNLOADED {
return;
}
}
let final_status = MODEL_TRACKER
.ensure_model_downloaded(&model_name, provider, &tx, ignore_weights)
.await;
let final_update = ModelStatusUpdate {
model_name: model_name.clone(),
status: modelexpress_common::grpc::model::ModelStatus::from(final_status) as i32,
message: match final_status {
ModelStatus::DOWNLOADED => {
Some("Model download completed successfully".to_string())
}
ModelStatus::ERROR => Some("Model download failed".to_string()),
ModelStatus::DOWNLOADING => Some("Download still in progress".to_string()),
},
provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
};
let _ = tx.send(Ok(final_update)).await;
});
Ok(Response::new(ReceiverStream::new(rx)))
}
}
/// Map from model name to the senders of every stream registered as waiting on that model.
///
/// The map sits behind `Arc<Mutex<..>>`, so every clone of the `Arc` refers to the same map.
type WaitingChannels =
Arc<Mutex<HashMap<String, Vec<tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>>>>>;
/// Tracks the download status of models and fans status changes out to waiting streams.
///
/// Cloning the tracker clones the `Arc` around the waiting-channel map, so clones share the
/// same set of waiters. NOTE(review): whether clones of `database` share one underlying
/// connection depends on `ModelDatabase`'s `Clone` impl, which is not visible here.
#[derive(Debug, Clone)]
pub struct ModelDownloadTracker {
// Store for per-model status records (read, written, claimed and deleted by the methods below).
database: ModelDatabase,
// Senders to notify when a model's status changes, keyed by model name.
waiting_channels: WaitingChannels,
}
impl Default for ModelDownloadTracker {
fn default() -> Self {
Self::new()
}
}
impl ModelDownloadTracker {
#[must_use]
pub fn new() -> Self {
let database = match ModelDatabase::new("./models.db") {
Ok(db) => db,
Err(e) => {
error!("Critical error: Could not initialize model database at ./models.db: {e}");
panic!("Critical error: Could not initialize model database at ./models.db");
}
};
Self {
database,
waiting_channels: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn get_status(&self, model_name: &str) -> Option<ModelStatus> {
match self.database.get_status(model_name) {
Ok(status) => {
if status.is_some() {
let _ = self.database.touch_model(model_name);
}
status
}
Err(e) => {
error!("Failed to get model status from database: {}", e);
None
}
}
}
pub fn set_status_and_notify(
&self,
model_name: String,
status: ModelStatus,
provider: ModelProvider,
message: Option<String>,
) {
if let Err(e) = self
.database
.set_status(&model_name, provider, status, message.clone())
{
error!("Failed to update model status in database: {}", e);
return;
}
let mut waiting = match self.waiting_channels.lock() {
Ok(guard) => guard,
Err(poisoned) => {
error!("Waiting channels mutex is poisoned, recovering");
poisoned.into_inner()
}
};
if let Some(channels) = waiting.get(&model_name) {
let update = ModelStatusUpdate {
model_name: model_name.clone(),
status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
message,
provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
};
for channel in channels {
let _ = channel.try_send(Ok(update.clone()));
}
if status == ModelStatus::DOWNLOADED || status == ModelStatus::ERROR {
waiting.remove(&model_name);
}
}
}
pub fn set_status(&self, model_name: String, status: ModelStatus, provider: ModelProvider) {
self.set_status_and_notify(model_name, status, provider, None);
}
pub fn add_waiting_channel(
&self,
model_name: &str,
tx: tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
) {
let mut waiting = match self.waiting_channels.lock() {
Ok(guard) => guard,
Err(poisoned) => {
error!("Waiting channels mutex is poisoned, recovering");
poisoned.into_inner()
}
};
waiting.entry(model_name.to_string()).or_default().push(tx);
}
pub fn delete_status(&self, model_name: &str) {
if let Err(e) = self.database.delete_model(model_name) {
error!("Failed to delete model from database: {}", e);
}
let mut waiting = match self.waiting_channels.lock() {
Ok(guard) => guard,
Err(poisoned) => {
error!("Waiting channels mutex is poisoned, recovering");
poisoned.into_inner()
}
};
waiting.remove(model_name);
}
pub async fn ensure_model_downloaded(
&self,
model_name: &str,
provider: ModelProvider,
tx: &tokio::sync::mpsc::Sender<Result<ModelStatusUpdate, Status>>,
ignore_weights: bool,
) -> ModelStatus {
let status = match self.database.try_claim_for_download(model_name, provider) {
Ok(status) => status,
Err(e) => {
error!("Failed to claim model for download: {}", e);
let error_update = ModelStatusUpdate {
model_name: model_name.to_string(),
status: modelexpress_common::grpc::model::ModelStatus::from(ModelStatus::ERROR)
as i32,
message: Some("Database error occurred".to_string()),
provider: modelexpress_common::grpc::model::ModelProvider::from(provider)
as i32,
};
let _ = tx.send(Ok(error_update)).await;
return ModelStatus::ERROR;
}
};
let update = ModelStatusUpdate {
model_name: model_name.to_string(),
status: modelexpress_common::grpc::model::ModelStatus::from(status) as i32,
message: match status {
ModelStatus::DOWNLOADED => Some("Model already downloaded".to_string()),
ModelStatus::DOWNLOADING => Some("Model download in progress".to_string()),
ModelStatus::ERROR => Some("Previous download failed - retrying".to_string()),
},
provider: modelexpress_common::grpc::model::ModelProvider::from(provider) as i32,
};
let _ = tx.send(Ok(update)).await;
if status == ModelStatus::DOWNLOADING {
self.add_waiting_channel(model_name, tx.clone());
let should_start_download = {
let waiting = match self.waiting_channels.lock() {
Ok(guard) => guard,
Err(poisoned) => {
error!("Waiting channels mutex is poisoned, recovering");
poisoned.into_inner()
}
};
waiting
.get(model_name)
.is_none_or(|channels| channels.len() <= 1)
};
if should_start_download {
let tracker = self.clone();
let model_name_owned = model_name.to_string();
tokio::spawn(async move {
let cache_dir = get_server_cache_dir();
match download::download_model(
&model_name_owned,
provider,
cache_dir,
ignore_weights,
)
.await
{
Ok(_path) => {
tracker.set_status_and_notify(
model_name_owned,
ModelStatus::DOWNLOADED,
provider,
Some("Model download completed successfully".to_string()),
);
}
Err(e) => {
error!("Failed to download model {model_name_owned}: {e}");
tracker.set_status_and_notify(
model_name_owned,
ModelStatus::ERROR,
provider,
Some(format!("Download failed: {e}")),
);
}
}
});
}
loop {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
if let Some(current_status) = self.get_status(model_name)
&& current_status != ModelStatus::DOWNLOADING
{
return current_status;
}
}
} else if status == ModelStatus::ERROR {
if let Err(e) = self.database.set_status(
model_name,
provider,
ModelStatus::DOWNLOADING,
Some("Retrying download...".to_string()),
) {
error!("Failed to reset status for retry: {}", e);
return ModelStatus::ERROR;
}
self.add_waiting_channel(model_name, tx.clone());
let tracker = self.clone();
let model_name_owned = model_name.to_string();
tokio::spawn(async move {
let cache_dir = get_server_cache_dir();
match download::download_model(
&model_name_owned,
provider,
cache_dir,
ignore_weights,
)
.await
{
Ok(_path) => {
tracker.set_status_and_notify(
model_name_owned,
ModelStatus::DOWNLOADED,
provider,
Some("Model download completed successfully".to_string()),
);
}
Err(e) => {
error!("Failed to download model {model_name_owned} on retry: {e}");
tracker.set_status_and_notify(
model_name_owned,
ModelStatus::ERROR,
provider,
Some(format!("Download failed on retry: {e}")),
);
}
}
});
loop {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
if let Some(current_status) = self.get_status(model_name)
&& current_status != ModelStatus::DOWNLOADING
{
return current_status;
}
}
}
status
}
}
/// Process-wide tracker shared by the gRPC model service.
///
/// Built on first access with `ModelDownloadTracker::new`, which panics if the model database
/// cannot be initialized.
pub static MODEL_TRACKER: std::sync::LazyLock<ModelDownloadTracker> =
std::sync::LazyLock::new(ModelDownloadTracker::new);
/// Unit tests for the gRPC services and the download tracker.
///
/// Every tracker created here (and the global `MODEL_TRACKER`) is backed by `./models.db`,
/// because `ModelDownloadTracker::new` hard-codes that path. Tests therefore remove the
/// records they create so that runs do not leave rows behind.
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use modelexpress_common::grpc::{
        api::ApiRequest, health::HealthRequest, model::ModelDownloadRequest,
    };
    use tokio_stream::StreamExt;
    use tonic::Request;

    /// Returns `"{prefix}-{nanoseconds since the Unix epoch}"` so runs do not share records.
    fn unique_model_name(prefix: &str) -> String {
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("Time went backwards")
            .as_nanos();
        format!("{prefix}-{timestamp}")
    }

    /// Number of senders currently registered as waiting on `model_name`.
    fn waiting_count(tracker: &ModelDownloadTracker, model_name: &str) -> usize {
        let waiting = match tracker.waiting_channels.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        waiting.get(model_name).map_or(0, std::vec::Vec::len)
    }

    #[tokio::test]
    async fn test_health_service() {
        let service = HealthServiceImpl;
        let request = Request::new(HealthRequest {});
        let response = service.get_health(request).await;
        assert!(response.is_ok());
        let health_response = response.expect("Health response should be ok").into_inner();
        assert_eq!(health_response.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(health_response.status, "ok");
        // Only checks that the field is present; its value depends on timing.
        let _uptime = health_response.uptime;
    }

    #[tokio::test]
    async fn test_api_service_ping() {
        let service = ApiServiceImpl;
        let request = Request::new(ApiRequest {
            id: "test-id".to_string(),
            action: "ping".to_string(),
            payload: None,
        });
        let response = service.send_request(request).await;
        assert!(response.is_ok());
        let api_response = response.expect("API response should be ok").into_inner();
        assert!(api_response.success);
        assert!(api_response.data.is_some());
        assert!(api_response.error.is_none());
        let data_bytes = api_response.data.expect("Data should be present");
        let data: serde_json::Value =
            serde_json::from_slice(&data_bytes).expect("Data should be valid JSON");
        assert_eq!(data["message"], "pong");
    }

    #[tokio::test]
    async fn test_api_service_unknown_action() {
        let service = ApiServiceImpl;
        let request = Request::new(ApiRequest {
            id: "test-id".to_string(),
            action: "unknown-action".to_string(),
            payload: None,
        });
        let response = service.send_request(request).await;
        assert!(response.is_ok());
        let api_response = response.expect("API response should be ok").into_inner();
        assert!(!api_response.success);
        assert!(api_response.data.is_none());
        assert!(api_response.error.is_some());
        let error_message = api_response.error.expect("Error should be present");
        assert!(error_message.contains("Unknown action"));
    }

    #[test]
    fn test_model_download_tracker_new() {
        let tracker = ModelDownloadTracker::new();
        let status = tracker.get_status("non-existent-model");
        assert!(status.is_none());
    }

    #[test]
    fn test_model_download_tracker_set_and_get_status() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-model");
        let provider = ModelProvider::HuggingFace;
        assert!(tracker.get_status(&model_name).is_none());
        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADING, provider);
        let status = tracker.get_status(&model_name);
        assert!(status.is_some());
        assert_eq!(
            status.expect("Status should be present"),
            ModelStatus::DOWNLOADING
        );
        tracker.delete_status(&model_name);
    }

    #[test]
    fn test_model_download_tracker_delete_status() {
        let tracker = ModelDownloadTracker::new();
        let model_name = unique_model_name("test-delete-model");
        let provider = ModelProvider::HuggingFace;
        tracker.set_status(model_name.clone(), ModelStatus::DOWNLOADED, provider);
        assert!(tracker.get_status(&model_name).is_some());
        tracker.delete_status(&model_name);
        assert!(tracker.get_status(&model_name).is_none());
    }

    #[tokio::test]
    async fn test_model_service_already_downloaded() {
        let service = ModelServiceImpl;
        let model_name = unique_model_name("test-already-downloaded-model");
        MODEL_TRACKER.set_status(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
        );
        let request = Request::new(ModelDownloadRequest {
            model_name: model_name.clone(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            ignore_weights: false,
        });
        let response = service.ensure_model_downloaded(request).await;
        assert!(response.is_ok());
        let mut stream = response.expect("Response should be ok").into_inner();
        let update = stream.next().await;
        assert!(update.is_some());
        let update = update.expect("Update should be present");
        assert!(update.is_ok());
        let status_update = update.expect("Status update should be ok");
        assert_eq!(status_update.model_name, model_name);
        assert_eq!(
            status_update.status,
            modelexpress_common::grpc::model::ModelStatus::Downloaded as i32
        );
        MODEL_TRACKER.delete_status(&model_name);
    }

    #[test]
    fn test_model_download_tracker_set_status_and_notify() {
        let tracker = ModelDownloadTracker::new();
        let model_name = "test-notify-model".to_string();
        let provider = ModelProvider::HuggingFace;
        tracker.set_status_and_notify(
            model_name.clone(),
            ModelStatus::DOWNLOADED,
            provider,
            Some("Download completed".to_string()),
        );
        let status = tracker.get_status(&model_name);
        // Clean up before asserting so a failure does not leave the record behind.
        tracker.delete_status(&model_name);
        assert!(status.is_some());
        assert_eq!(
            status.expect("Status should be present"),
            ModelStatus::DOWNLOADED
        );
    }

    #[test]
    fn test_waiting_channels_management() {
        let tracker = ModelDownloadTracker::new();
        let model_name = "test-channels-model";
        let (tx, _rx) = tokio::sync::mpsc::channel(4);
        tracker.add_waiting_channel(model_name, tx);
        assert_eq!(waiting_count(&tracker, model_name), 1);
        tracker.set_status_and_notify(
            model_name.to_string(),
            ModelStatus::DOWNLOADED,
            ModelProvider::HuggingFace,
            None,
        );
        // A final status must drop the waiter list.
        let waiting_count_after = waiting_count(&tracker, model_name);
        // Remove the record written above so it does not persist in ./models.db.
        tracker.delete_status(model_name);
        assert_eq!(waiting_count_after, 0);
    }

    /// Drives the real service path, which ends up calling `download::download_model`; the
    /// loop is capped at a handful of updates so the test cannot run unbounded on the stream.
    #[tokio::test]
    async fn test_model_service_stream_closes_properly() {
        let service = ModelServiceImpl;
        let model_name = "test-stream-model";
        let request = Request::new(ModelDownloadRequest {
            model_name: model_name.to_string(),
            provider: modelexpress_common::grpc::model::ModelProvider::HuggingFace as i32,
            ignore_weights: false,
        });
        let response = service.ensure_model_downloaded(request).await;
        assert!(response.is_ok());
        let mut stream = response.expect("Response should be ok").into_inner();
        let mut update_count = 0;
        while let Some(update) = stream.next().await {
            assert!(update.is_ok());
            update_count += 1;
            if update_count > 10 {
                break;
            }
        }
        assert!(update_count > 0);
        MODEL_TRACKER.delete_status(model_name);
    }

    #[tokio::test]
    async fn test_concurrent_model_download_no_race_condition() {
        let tracker = ModelDownloadTracker::new();
        let model_name = "test-concurrent-model";
        let provider = ModelProvider::HuggingFace;
        let status1 = tracker
            .database
            .try_claim_for_download(model_name, provider)
            .expect("Failed to claim for download 1");
        assert_eq!(status1, ModelStatus::DOWNLOADING);
        let status2 = tracker
            .database
            .try_claim_for_download(model_name, provider)
            .expect("Failed to claim for download 2");
        assert_eq!(status2, ModelStatus::DOWNLOADING);
        let record = tracker
            .database
            .get_model_record(model_name)
            .expect("Failed to get model record")
            .expect("Record should exist");
        assert_eq!(record.status, ModelStatus::DOWNLOADING);
        tracker.delete_status(model_name);
    }
}