use crate::{
api::models::{
deployments::{
ComponentEndpointSummary, ComponentModelSummary, DeployedModelResponse, ModelComponentResponse, ModelMetrics, ModelProbeStatus,
ModelType,
},
inference_endpoints::InferenceEndpointResponse,
},
db::{
handlers::{Groups, InferenceEndpoints, Repository, analytics::get_model_metrics},
models::{deployments::DeploymentComponentDBResponse, groups::GroupDBResponse},
},
errors::{Error, Result},
types::{DeploymentId, GroupId, InferenceEndpointId},
};
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use std::collections::HashMap;
use uuid::Uuid;
/// Options controlling how `DeployedModelResponse` values are enriched and masked.
///
/// The `include_*` flags select which related data `enrich_many` looks up and
/// attaches; the `can_read_*` flags decide which fields are masked on every
/// returned model.
pub struct DeployedModelEnricher<'a> {
/// Connection pool used for every lookup performed during enrichment.
pub db: &'a PgPool,
/// Attach the groups each deployment belongs to.
pub include_groups: bool,
/// Attach usage metrics, looked up by model alias.
pub include_metrics: bool,
/// Attach probe status; models without a probe row get an inactive placeholder.
pub include_status: bool,
/// Attach the current tariffs of each model.
pub include_pricing: bool,
/// Attach the inference endpoint referenced by `hosted_on`.
pub include_endpoints: bool,
/// Attach components of composite models (only honoured when
/// `can_read_composite_info` is also set).
pub include_components: bool,
/// When false, `filter_disabled_batch_tariffs` is applied to each model.
pub can_read_pricing: bool,
/// When false, rate-limiting and capacity fields are masked.
pub can_read_rate_limits: bool,
/// When false, the `created_by` field is masked.
pub can_read_users: bool,
/// When false, composite fields and response configuration are masked and
/// components are never fetched.
pub can_read_composite_info: bool,
}
/// Probe status values stored per deployment, in positional order:
/// `(probe_id, active, interval_seconds, last_check, last_success, uptime_percentage)`.
type ProbeStatusTuple = (Option<Uuid>, bool, Option<i32>, Option<DateTime<Utc>>, Option<bool>, Option<f64>);
impl<'a> DeployedModelEnricher<'a> {
/// Enriches a batch of deployed models with the related data selected by the
/// `include_*` flags, then masks fields according to the `can_read_*` flags.
///
/// All lookups run concurrently through `tokio::join!`. Every lookup is
/// best-effort: a failure is logged with `tracing::warn!` and the affected
/// enrichment is simply not applied, so the remaining data is still returned.
///
/// # Errors
///
/// Lookup failures are not propagated; as written, this function always
/// returns `Ok`. An empty input yields an empty output without touching the
/// database.
#[tracing::instrument(skip(self, models), fields(count = models.len()))]
pub async fn enrich_many(&self, models: Vec<DeployedModelResponse>) -> Result<Vec<DeployedModelResponse>> {
    if models.is_empty() {
        return Ok(Vec::new());
    }

    let model_ids: Vec<DeploymentId> = models.iter().map(|m| m.id).collect();
    let model_aliases: Vec<String> = models.iter().map(|m| m.alias.clone()).collect();

    let (groups_result, status_map, metrics_map, endpoints_map, pricing_tariffs_map, components_map) = tokio::join!(
        // Groups: deployment -> group ids, plus the group records themselves.
        async {
            if self.include_groups {
                let mut groups_conn = self
                    .db
                    .acquire()
                    .await
                    .inspect_err(|e| tracing::warn!("Failed to acquire connection for groups enrichment: {:?}", e))
                    .ok()?;
                let mut groups_repo = Groups::new(&mut groups_conn);
                let model_groups_map = groups_repo
                    .get_deployments_groups_bulk(&model_ids)
                    .await
                    .inspect_err(|e| tracing::warn!("Failed to fetch bulk deployment groups: {:?}", e))
                    .ok()?;
                // Deduplicate so each group record is requested only once.
                let all_group_ids: Vec<GroupId> = model_groups_map
                    .values()
                    .flatten()
                    .copied()
                    .collect::<std::collections::HashSet<_>>()
                    .into_iter()
                    .collect();
                let groups_map = groups_repo
                    .get_bulk(all_group_ids)
                    .await
                    .inspect_err(|e| tracing::warn!("Failed to fetch bulk groups: {:?}", e))
                    .ok()?;
                Some((model_groups_map, groups_map))
            } else {
                None
            }
        },
        // Probe status.
        async {
            if self.include_status {
                use crate::probes::db::ProbeManager;
                ProbeManager::get_deployment_statuses(self.db, &model_ids)
                    .await
                    .inspect_err(|e| tracing::warn!("Failed to fetch bulk probe statuses: {:?}", e))
                    .ok()
            } else {
                None
            }
        },
        // Metrics, keyed by alias.
        async {
            if self.include_metrics {
                get_model_metrics(self.db, model_aliases)
                    .await
                    .inspect_err(|e| tracing::warn!("Failed to fetch bulk metrics: {:?}", e))
                    .ok()
            } else {
                None
            }
        },
        // Inference endpoints referenced by `hosted_on`.
        async {
            if self.include_endpoints {
                let mut endpoints_conn = self
                    .db
                    .acquire()
                    .await
                    .inspect_err(|e| tracing::warn!("Failed to acquire connection for endpoints enrichment: {:?}", e))
                    .ok()?;
                let mut endpoints_repo = InferenceEndpoints::new(&mut endpoints_conn);
                let endpoint_ids: Vec<InferenceEndpointId> = models
                    .iter()
                    .filter_map(|m| m.hosted_on)
                    .collect::<std::collections::HashSet<_>>()
                    .into_iter()
                    .collect();
                let endpoints_db = endpoints_repo
                    .get_bulk(endpoint_ids)
                    .await
                    .inspect_err(|e| tracing::warn!("Failed to fetch bulk inference endpoints: {:?}", e))
                    .ok()?;
                let endpoints_map: HashMap<InferenceEndpointId, InferenceEndpointResponse> =
                    endpoints_db.into_iter().map(|(id, endpoint)| (id, endpoint.into())).collect();
                Some(endpoints_map)
            } else {
                None
            }
        },
        // Current tariffs per model.
        async {
            if self.include_pricing {
                use crate::{api::models::tariffs::TariffResponse, db::handlers::Tariffs};
                // One connection serves the whole loop; acquiring inside the loop
                // would take a pool round-trip per model.
                let mut tariffs_conn = self
                    .db
                    .acquire()
                    .await
                    .inspect_err(|e| tracing::warn!("Failed to acquire connection for tariffs enrichment: {:?}", e))
                    .ok()?;
                let mut tariffs_repo = Tariffs::new(&mut tariffs_conn);
                let mut tariffs_map: HashMap<DeploymentId, Vec<TariffResponse>> = HashMap::with_capacity(model_ids.len());
                for model_id in &model_ids {
                    // A failure for one model must not discard the tariffs of the others.
                    match tariffs_repo.list_current_by_model(*model_id).await {
                        Ok(tariffs) => {
                            tariffs_map.insert(*model_id, tariffs.into_iter().map(TariffResponse::from).collect());
                        }
                        Err(e) => {
                            tracing::warn!("Failed to fetch tariffs for model {:?}: {:?}", model_id, e);
                        }
                    }
                }
                Some(tariffs_map)
            } else {
                None
            }
        },
        // Components of composite models.
        async {
            if self.include_components && self.can_read_composite_info {
                use crate::db::handlers::Deployments;
                let composite_ids: Vec<DeploymentId> = models.iter().filter(|m| m.is_composite == Some(true)).map(|m| m.id).collect();
                if composite_ids.is_empty() {
                    // Nothing to look up; skip the connection entirely.
                    return Some(HashMap::new());
                }
                let mut conn = self
                    .db
                    .acquire()
                    .await
                    .inspect_err(|e| tracing::warn!("Failed to acquire connection for components enrichment: {:?}", e))
                    .ok()?;
                let mut repo = Deployments::new(&mut conn);
                repo.get_components_bulk(composite_ids)
                    .await
                    .inspect_err(|e| tracing::warn!("Failed to fetch bulk components: {:?}", e))
                    .ok()
            } else {
                None
            }
        }
    );

    let (model_groups_map, groups_map) = match groups_result {
        Some((model_groups_map, groups_map)) => (Some(model_groups_map), Some(groups_map)),
        None => (None, None),
    };

    let mut enriched_models = Vec::with_capacity(models.len());
    for mut model_response in models {
        // Enrichment: attach whatever was requested and successfully fetched.
        if self.include_groups {
            model_response = Self::apply_groups(model_response, &model_groups_map, &groups_map);
        }
        if self.include_metrics {
            model_response = Self::apply_metrics(model_response, &metrics_map);
        }
        if self.include_status {
            model_response = Self::apply_status(model_response, &status_map);
        }
        if self.include_endpoints {
            model_response = Self::apply_endpoint(model_response, &endpoints_map);
        }
        if self.include_pricing {
            model_response = Self::apply_tariffs(model_response, &pricing_tariffs_map);
        }
        if self.include_components && self.can_read_composite_info {
            model_response = Self::apply_components(model_response, &components_map);
        }

        // Masking: applied after enrichment so attached data is filtered too.
        if !self.can_read_rate_limits {
            model_response = model_response.mask_rate_limiting();
            model_response = model_response.mask_capacity();
        }
        if !self.can_read_users {
            model_response = model_response.mask_created_by();
        }
        if !self.can_read_composite_info {
            model_response = model_response.mask_composite_fields();
            model_response = model_response.mask_response_config();
        }
        if !self.can_read_pricing {
            model_response = model_response.filter_disabled_batch_tariffs();
        }

        enriched_models.push(model_response);
    }

    Ok(enriched_models)
}
#[tracing::instrument(skip(self, model))]
pub async fn enrich_one(&self, model: DeployedModelResponse) -> Result<DeployedModelResponse> {
let enriched = self.enrich_many(vec![model]).await?;
enriched.into_iter().next().ok_or_else(|| Error::BadRequest {
message: "No model returned from enrichment".to_string(),
})
}
/// Attaches group information to `model`.
///
/// When either map is `None` the model is returned untouched. Otherwise the
/// model's group ids are resolved against `groups_map` (ids without a record
/// are skipped); a model with no membership entry receives an empty list.
fn apply_groups(
    model: DeployedModelResponse,
    model_groups_map: &Option<HashMap<DeploymentId, Vec<GroupId>>>,
    groups_map: &Option<HashMap<GroupId, GroupDBResponse>>,
) -> DeployedModelResponse {
    let (Some(memberships), Some(group_lookup)) = (model_groups_map, groups_map) else {
        return model;
    };

    let resolved: Vec<_> = match memberships.get(&model.id) {
        Some(ids) => ids
            .iter()
            .filter_map(|id| group_lookup.get(id))
            .cloned()
            .map(|record| record.into())
            .collect(),
        None => vec![],
    };

    model.with_groups(resolved)
}
fn apply_metrics(mut model: DeployedModelResponse, metrics_map: &Option<HashMap<String, ModelMetrics>>) -> DeployedModelResponse {
if let Some(metrics_map) = metrics_map
&& let Some(metrics) = metrics_map.get(&model.alias)
{
model = model.with_metrics(metrics.clone());
}
model
}
/// Attaches probe status to `model`.
///
/// With no status map the model is returned unchanged. With a map, the
/// model's row is converted into a `ModelProbeStatus`; a model without a row
/// gets an inactive status whose optional fields are all `None`.
fn apply_status(
    model: DeployedModelResponse,
    status_map: &Option<HashMap<DeploymentId, ProbeStatusTuple>>,
) -> DeployedModelResponse {
    let Some(statuses) = status_map else {
        return model;
    };

    let status = match statuses.get(&model.id) {
        Some(&(probe_id, active, interval_seconds, last_check, last_success, uptime_percentage)) => ModelProbeStatus {
            probe_id,
            active,
            interval_seconds,
            last_check,
            last_success,
            uptime_percentage,
        },
        // No probe row for this deployment: report an inactive placeholder.
        None => ModelProbeStatus {
            probe_id: None,
            active: false,
            interval_seconds: None,
            last_check: None,
            last_success: None,
            uptime_percentage: None,
        },
    };

    model.with_status(status)
}
fn apply_endpoint(
mut model: DeployedModelResponse,
endpoints_map: &Option<HashMap<InferenceEndpointId, InferenceEndpointResponse>>,
) -> DeployedModelResponse {
if let Some(endpoints_map) = endpoints_map
&& let Some(hosted_on) = model.hosted_on
&& let Some(endpoint) = endpoints_map.get(&hosted_on)
{
model = model.with_endpoint(endpoint.clone());
}
model
}
/// Attaches the tariffs recorded for the model's id.
///
/// The model is returned unchanged when `tariffs_map` is `None` or has no
/// entry for the model.
fn apply_tariffs(
    model: DeployedModelResponse,
    tariffs_map: &Option<HashMap<DeploymentId, Vec<crate::api::models::tariffs::TariffResponse>>>,
) -> DeployedModelResponse {
    let Some(by_model) = tariffs_map else {
        return model;
    };

    let found = by_model.get(&model.id).cloned();
    match found {
        Some(tariffs) => model.with_tariffs(tariffs),
        None => model,
    }
}
/// Attaches the components recorded for the model's id, converting each
/// database row into its API representation.
///
/// The model is returned unchanged when `components_map` is `None` or has
/// no entry for the model.
fn apply_components(
    model: DeployedModelResponse,
    components_map: &Option<HashMap<DeploymentId, Vec<DeploymentComponentDBResponse>>>,
) -> DeployedModelResponse {
    let Some(by_model) = components_map else {
        return model;
    };
    let Some(rows) = by_model.get(&model.id) else {
        return model;
    };

    let mut converted: Vec<ModelComponentResponse> = Vec::with_capacity(rows.len());
    for row in rows {
        converted.push(Self::db_component_to_response(row.clone()));
    }

    model.with_components(converted)
}
/// Converts a component database row into its API response shape.
///
/// The stored model type string is mapped to `ModelType` for the exact
/// values `"CHAT"`, `"EMBEDDINGS"` and `"RERANKER"`; anything else becomes
/// `None`. An endpoint summary is produced only when the row carries an
/// endpoint id, with a missing endpoint name falling back to its default.
fn db_component_to_response(c: DeploymentComponentDBResponse) -> ModelComponentResponse {
    let model_type = match c.model_type {
        Some(raw) => match raw.as_str() {
            "CHAT" => Some(ModelType::Chat),
            "EMBEDDINGS" => Some(ModelType::Embeddings),
            "RERANKER" => Some(ModelType::Reranker),
            _ => None,
        },
        None => None,
    };

    let endpoint_name = c.endpoint_name.unwrap_or_default();
    let endpoint = c.endpoint_id.map(|id| ComponentEndpointSummary { id, name: endpoint_name });

    let model = ComponentModelSummary {
        id: c.deployed_model_id,
        alias: c.model_alias,
        model_name: c.model_name,
        description: c.model_description,
        model_type,
        endpoint,
        trusted: c.model_trusted,
        open_responses_adapter: c.model_open_responses_adapter,
    };

    ModelComponentResponse {
        weight: c.weight,
        enabled: c.enabled,
        sort_order: c.sort_order,
        created_at: c.created_at,
        model,
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
api::models::deployments::ModelMetrics,
db::models::{api_keys::ApiKeyPurpose, groups::GroupDBResponse},
};
use chrono::Utc;
use std::collections::HashMap;
use uuid::Uuid;
/// Builds a non-composite fixture model with fresh random ids, rate limits
/// set, and every enrichment field left empty.
fn create_test_model() -> DeployedModelResponse {
    DeployedModelResponse {
        // Identity and descriptive fields.
        id: Uuid::new_v4(),
        model_name: String::from("test-model"),
        alias: String::from("test-alias"),
        display_name: None,
        description: None,
        model_type: None,
        capabilities: None,
        metadata: None,

        // Ownership, hosting and timestamps.
        created_by: Some(Uuid::new_v4()),
        hosted_on: Some(Uuid::new_v4()),
        created_at: Utc::now(),
        updated_at: Utc::now(),

        // Rate limiting and capacity.
        requests_per_second: Some(100.0),
        burst_size: Some(200),
        capacity: None,
        batch_capacity: None,
        throughput: None,
        allowed_batch_completion_windows: None,

        // Enrichment targets start out empty.
        groups: None,
        metrics: None,
        status: None,
        provider_pricing: None,
        endpoint: None,
        tariffs: None,
        components: None,

        // Composite and response configuration.
        is_composite: Some(false),
        lb_strategy: None,
        fallback: None,
        sanitize_responses: None,
        trusted: None,
        open_responses_adapter: None,
        traffic_routing_rules: None,
    }
}
/// A model with one known group membership gets that group attached.
#[test]
fn test_apply_groups_with_data() {
    let model = create_test_model();
    let group_id: GroupId = Uuid::new_v4();
    let group = GroupDBResponse {
        id: group_id,
        name: "Test Group".to_string(),
        description: Some("Test description".to_string()),
        created_by: Uuid::new_v4(),
        created_at: Utc::now(),
        updated_at: Utc::now(),
        source: "native".to_string(),
    };
    let memberships = HashMap::from([(model.id, vec![group_id])]);
    let group_lookup = HashMap::from([(group_id, group)]);

    let enriched = DeployedModelEnricher::apply_groups(model, &Some(memberships), &Some(group_lookup));

    assert!(enriched.groups.is_some());
    let attached = enriched.groups.unwrap();
    assert_eq!(attached.len(), 1);
    assert_eq!(attached[0].name, "Test Group");
}
/// With maps present but no membership entry, the model gets an empty list.
#[test]
fn test_apply_groups_empty() {
    let enriched = DeployedModelEnricher::apply_groups(create_test_model(), &Some(HashMap::new()), &Some(HashMap::new()));

    assert!(enriched.groups.is_some());
    assert_eq!(enriched.groups.unwrap().len(), 0);
}
/// Without any group maps the `groups` field stays unset.
#[test]
fn test_apply_groups_not_requested() {
    let enriched = DeployedModelEnricher::apply_groups(create_test_model(), &None, &None);

    assert!(enriched.groups.is_none());
}
/// Metrics stored under the model's alias are attached verbatim.
#[test]
fn test_apply_metrics_with_data() {
    let model = create_test_model();
    let recorded = ModelMetrics {
        avg_latency_ms: Some(123.45),
        total_requests: 100,
        total_input_tokens: 1000,
        total_output_tokens: 2000,
        last_active_at: Some(Utc::now()),
        time_series: None,
    };
    let by_alias = HashMap::from([(model.alias.clone(), recorded)]);

    let enriched = DeployedModelEnricher::apply_metrics(model, &Some(by_alias));

    assert!(enriched.metrics.is_some());
    let attached = enriched.metrics.unwrap();
    assert_eq!(attached.total_requests, 100);
    assert_eq!(attached.total_input_tokens, 1000);
    assert_eq!(attached.avg_latency_ms, Some(123.45));
}
/// An empty metrics map leaves the `metrics` field unset.
#[test]
fn test_apply_metrics_no_data() {
    let enriched = DeployedModelEnricher::apply_metrics(create_test_model(), &Some(HashMap::new()));

    assert!(enriched.metrics.is_none());
}
/// A probe row for the model is converted field-by-field into its status.
#[test]
fn test_apply_status_with_data() {
    let model = create_test_model();
    let probe_id = Uuid::new_v4();
    let checked_at = Utc::now();
    let statuses = HashMap::from([(
        model.id,
        (Some(probe_id), true, Some(60), Some(checked_at), Some(true), Some(99.5)),
    )]);

    let enriched = DeployedModelEnricher::apply_status(model, &Some(statuses));

    assert!(enriched.status.is_some());
    let attached = enriched.status.unwrap();
    assert_eq!(attached.probe_id, Some(probe_id));
    assert!(attached.active);
    assert_eq!(attached.interval_seconds, Some(60));
    assert_eq!(attached.uptime_percentage, Some(99.5));
}
/// A model without a probe row still receives an inactive placeholder status.
#[test]
fn test_apply_status_no_probe() {
    let enriched = DeployedModelEnricher::apply_status(create_test_model(), &Some(HashMap::new()));

    assert!(enriched.status.is_some());
    let placeholder = enriched.status.unwrap();
    assert_eq!(placeholder.probe_id, None);
    assert!(!placeholder.active);
    assert_eq!(placeholder.interval_seconds, None);
}
/// Masking rate limits clears rate and burst but leaves capacity alone.
#[test]
fn test_mask_rate_limiting() {
    let model = DeployedModelResponse {
        requests_per_second: Some(100.0),
        burst_size: Some(200),
        capacity: Some(50),
        ..create_test_model()
    };

    let masked = model.mask_rate_limiting();

    assert_eq!(masked.requests_per_second, None);
    assert_eq!(masked.burst_size, None);
    assert_eq!(masked.capacity, Some(50));
}
/// Masking response config clears all three response-related settings.
#[test]
fn test_mask_response_config() {
    let model = DeployedModelResponse {
        sanitize_responses: Some(true),
        trusted: Some(false),
        open_responses_adapter: Some(true),
        ..create_test_model()
    };

    let masked = model.mask_response_config();

    assert_eq!(masked.sanitize_responses, None);
    assert_eq!(masked.trusted, None);
    assert_eq!(masked.open_responses_adapter, None);
}
/// The endpoint matching the model's `hosted_on` id is attached.
#[test]
fn test_apply_endpoint_with_data() {
    let model = create_test_model();
    let endpoint_id = model.hosted_on.expect("test model should have hosted_on");
    let endpoint = InferenceEndpointResponse {
        id: endpoint_id,
        name: "Test Endpoint".to_string(),
        description: Some("Test endpoint description".to_string()),
        url: "https://api.example.com".to_string(),
        model_filter: None,
        requires_api_key: true,
        auth_header_name: "Authorization".to_string(),
        auth_header_prefix: "Bearer ".to_string(),
        created_by: Uuid::new_v4(),
        created_at: Utc::now(),
        updated_at: Utc::now(),
    };
    let by_id = HashMap::from([(endpoint_id, endpoint)]);

    let enriched = DeployedModelEnricher::apply_endpoint(model, &Some(by_id));

    assert!(enriched.endpoint.is_some());
    let attached = enriched.endpoint.unwrap();
    assert_eq!(attached.name, "Test Endpoint");
    assert_eq!(attached.url, "https://api.example.com");
}
/// An endpoint map without the model's host leaves `endpoint` unset.
#[test]
fn test_apply_endpoint_no_data() {
    let enriched = DeployedModelEnricher::apply_endpoint(create_test_model(), &Some(HashMap::new()));

    assert!(enriched.endpoint.is_none());
}
/// Without an endpoint map the `endpoint` field stays unset.
#[test]
fn test_apply_endpoint_not_requested() {
    let enriched = DeployedModelEnricher::apply_endpoint(create_test_model(), &None);

    assert!(enriched.endpoint.is_none());
}
/// Tariffs stored under the model's id are attached with prices intact.
#[test]
fn test_apply_tariffs_with_data() {
    use crate::api::models::tariffs::TariffResponse;
    use rust_decimal::Decimal;
    use std::str::FromStr;

    let price = |text: &str| Decimal::from_str(text).unwrap();

    let model = create_test_model();
    let standard = TariffResponse {
        id: Uuid::new_v4(),
        deployed_model_id: model.id,
        name: "Standard Tariff".to_string(),
        input_price_per_token: price("0.001"),
        output_price_per_token: price("0.002"),
        api_key_purpose: None,
        completion_window: None,
        valid_from: Utc::now(),
        valid_until: None,
        is_active: true,
    };
    let by_model = HashMap::from([(model.id, vec![standard])]);

    let enriched = DeployedModelEnricher::apply_tariffs(model, &Some(by_model));

    assert!(enriched.tariffs.is_some());
    let attached = enriched.tariffs.unwrap();
    assert_eq!(attached.len(), 1);
    assert_eq!(attached[0].name, "Standard Tariff");
    assert_eq!(attached[0].input_price_per_token, price("0.001"));
    assert_eq!(attached[0].output_price_per_token, price("0.002"));
}
/// An empty tariff map leaves the `tariffs` field unset.
#[test]
fn test_apply_tariffs_no_data() {
    let enriched = DeployedModelEnricher::apply_tariffs(create_test_model(), &Some(HashMap::new()));

    assert!(enriched.tariffs.is_none());
}
/// Without a tariff map the `tariffs` field stays unset.
#[test]
fn test_apply_tariffs_not_requested() {
    let enriched = DeployedModelEnricher::apply_tariffs(create_test_model(), &None);

    assert!(enriched.tariffs.is_none());
}
/// Builds an active fixture tariff with fixed prices for the given API key
/// purpose and optional completion window.
fn create_test_tariff(purpose: Option<ApiKeyPurpose>, window: Option<&str>) -> crate::api::models::tariffs::TariffResponse {
    use crate::api::models::tariffs::TariffResponse;
    use rust_decimal::Decimal;
    use std::str::FromStr;

    // The name embeds both inputs so a failing assertion identifies the tariff.
    let name = format!("Tariff {:?} {:?}", purpose, window.unwrap_or("none"));

    TariffResponse {
        id: Uuid::new_v4(),
        deployed_model_id: Uuid::new_v4(),
        name,
        input_price_per_token: Decimal::from_str("0.001").unwrap(),
        output_price_per_token: Decimal::from_str("0.002").unwrap(),
        api_key_purpose: purpose,
        completion_window: window.map(|w| w.to_string()),
        valid_from: Utc::now(),
        valid_until: None,
        is_active: true,
    }
}
/// Batch tariffs for windows outside the allow-list are dropped; the allowed
/// batch tariff and the realtime tariff survive in order.
#[test]
fn test_filter_disabled_batch_tariffs_with_allowed_windows() {
    let model = DeployedModelResponse {
        allowed_batch_completion_windows: Some(vec!["24h".to_string()]),
        tariffs: Some(vec![
            create_test_tariff(Some(ApiKeyPurpose::Batch), Some("24h")),
            create_test_tariff(Some(ApiKeyPurpose::Batch), Some("1h")),
            create_test_tariff(Some(ApiKeyPurpose::Realtime), None),
        ]),
        ..create_test_model()
    };

    let filtered = model.filter_disabled_batch_tariffs();

    let remaining = filtered.tariffs.unwrap();
    assert_eq!(remaining.len(), 2);
    assert_eq!(remaining[0].api_key_purpose, Some(ApiKeyPurpose::Batch));
    assert_eq!(remaining[0].completion_window.as_deref(), Some("24h"));
    assert_eq!(remaining[1].api_key_purpose, Some(ApiKeyPurpose::Realtime));
}
/// With no allow-list configured (`None`), every batch tariff is kept.
#[test]
fn test_filter_disabled_batch_tariffs_none_allowed() {
    let model = DeployedModelResponse {
        allowed_batch_completion_windows: None,
        tariffs: Some(vec![
            create_test_tariff(Some(ApiKeyPurpose::Batch), Some("24h")),
            create_test_tariff(Some(ApiKeyPurpose::Batch), Some("1h")),
        ]),
        ..create_test_model()
    };

    let filtered = model.filter_disabled_batch_tariffs();

    assert_eq!(filtered.tariffs.unwrap().len(), 2);
}
/// An empty allow-list removes every batch tariff, leaving only realtime.
#[test]
fn test_filter_disabled_batch_tariffs_empty_allowed() {
    let model = DeployedModelResponse {
        allowed_batch_completion_windows: Some(vec![]),
        tariffs: Some(vec![
            create_test_tariff(Some(ApiKeyPurpose::Batch), Some("24h")),
            create_test_tariff(Some(ApiKeyPurpose::Batch), Some("1h")),
            create_test_tariff(Some(ApiKeyPurpose::Realtime), None),
        ]),
        ..create_test_model()
    };

    let filtered = model.filter_disabled_batch_tariffs();

    let remaining = filtered.tariffs.unwrap();
    assert_eq!(remaining.len(), 1);
    assert_eq!(remaining[0].api_key_purpose, Some(ApiKeyPurpose::Realtime));
}
/// A batch tariff without a completion window is kept alongside the allowed
/// window, while the disallowed window is dropped.
#[test]
fn test_filter_disabled_batch_tariffs_generic_fallback_kept() {
    let model = DeployedModelResponse {
        allowed_batch_completion_windows: Some(vec!["24h".to_string()]),
        tariffs: Some(vec![
            create_test_tariff(Some(ApiKeyPurpose::Batch), Some("24h")),
            create_test_tariff(Some(ApiKeyPurpose::Batch), None),
            create_test_tariff(Some(ApiKeyPurpose::Batch), Some("1h")),
        ]),
        ..create_test_model()
    };

    let filtered = model.filter_disabled_batch_tariffs();

    let remaining = filtered.tariffs.unwrap();
    assert_eq!(remaining.len(), 2);
    assert_eq!(remaining[0].completion_window.as_deref(), Some("24h"));
    assert_eq!(remaining[1].completion_window, None);
}
/// With an empty allow-list even the window-less batch tariff is removed.
#[test]
fn test_filter_disabled_batch_tariffs_generic_fallback_removed_when_empty() {
    let model = DeployedModelResponse {
        allowed_batch_completion_windows: Some(vec![]),
        tariffs: Some(vec![
            create_test_tariff(Some(ApiKeyPurpose::Batch), None),
            create_test_tariff(Some(ApiKeyPurpose::Realtime), None),
        ]),
        ..create_test_model()
    };

    let filtered = model.filter_disabled_batch_tariffs();

    let remaining = filtered.tariffs.unwrap();
    assert_eq!(remaining.len(), 1);
    assert_eq!(remaining[0].api_key_purpose, Some(ApiKeyPurpose::Realtime));
}
/// Realtime tariffs and tariffs without a purpose pass through the filter.
#[test]
fn test_filter_disabled_batch_tariffs_realtime_unaffected() {
    let model = DeployedModelResponse {
        allowed_batch_completion_windows: Some(vec!["24h".to_string()]),
        tariffs: Some(vec![
            create_test_tariff(Some(ApiKeyPurpose::Realtime), None),
            create_test_tariff(None, None),
        ]),
        ..create_test_model()
    };

    let filtered = model.filter_disabled_batch_tariffs();

    assert_eq!(filtered.tariffs.unwrap().len(), 2);
}
}