use crate::api::models::deployments::{DeployedModelCreate, DeployedModelUpdate};
use crate::types::{DeploymentId, InferenceEndpointId, UserId};
use bon::Builder;
use chrono::{DateTime, NaiveDate, Utc};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use serde_with::rust::double_option;
use utoipa::ToSchema;
use uuid::Uuid;
// Values of the `mode` discriminator in the flattened pricing representations
// (`ProviderPricingFields::mode` and `ProviderPricingUpdateParams::mode`).
// They equal the snake_case serde tags of the `PerToken` / `Hourly` variants.
/// `mode` value for per-token pricing.
const MODE_PER_TOKEN: &str = "per_token";
/// `mode` value for hourly pricing.
const MODE_HOURLY: &str = "hourly";
/// Provider pricing configuration of a deployment.
///
/// Serialized as an internally tagged object: the `mode` field is either
/// `"per_token"` or `"hourly"`, and the remaining fields depend on the mode.
// Decimal fields are declared as strings in the generated OpenAPI schema via
// `#[schema(value_type = ...)]`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ToSchema)]
#[serde(tag = "mode", rename_all = "snake_case")]
pub enum ProviderPricing {
    /// Per-token pricing; either price may be absent.
    PerToken {
        /// Price per input token.
        #[schema(value_type = Option<String>)]
        input_price_per_token: Option<Decimal>,
        /// Price per output token.
        #[schema(value_type = Option<String>)]
        output_price_per_token: Option<Decimal>,
    },
    /// Hourly pricing; both values are required.
    Hourly {
        /// Hourly rate.
        #[schema(value_type = String)]
        rate: Decimal,
        // NOTE(review): the meaning of this ratio is not defined in this file —
        // presumably the share of the hourly cost attributed to input tokens;
        // confirm.
        #[schema(value_type = String)]
        input_token_cost_ratio: Decimal,
    },
}
/// Flattened form of [`ProviderPricing`], with one optional field per value.
///
/// `mode` holds the discriminator (`"per_token"` or `"hourly"`). When produced
/// by [`ProviderPricing::to_flat_fields`], the fields belonging to the other
/// mode are always `None`.
// NOTE(review): presumably this mirrors the pricing columns of the deployment
// table — confirm against the repository layer.
#[derive(Debug, Clone, Default)]
pub struct ProviderPricingFields {
    /// Pricing mode discriminator. `None` or an unrecognised value makes
    /// [`ProviderPricing::from_flat_fields`] return `None`.
    pub mode: Option<String>,
    /// Per-token input price (per-token mode only).
    pub input_price_per_token: Option<Decimal>,
    /// Per-token output price (per-token mode only).
    pub output_price_per_token: Option<Decimal>,
    /// Hourly rate (hourly mode only); maps to `ProviderPricing::Hourly::rate`.
    pub hourly_rate: Option<Decimal>,
    /// Input token cost ratio (hourly mode only).
    pub input_token_cost_ratio: Option<Decimal>,
}
impl ProviderPricing {
pub fn from_flat_fields(fields: ProviderPricingFields) -> Option<Self> {
match fields.mode.as_deref() {
Some(MODE_HOURLY) => match (fields.hourly_rate, fields.input_token_cost_ratio) {
(Some(rate), Some(input_token_cost_ratio)) => Some(ProviderPricing::Hourly {
rate,
input_token_cost_ratio,
}),
_ => None,
},
Some(MODE_PER_TOKEN) => Some(ProviderPricing::PerToken {
input_price_per_token: fields.input_price_per_token,
output_price_per_token: fields.output_price_per_token,
}),
_ => None,
}
}
pub fn to_flat_fields(&self) -> ProviderPricingFields {
match self {
ProviderPricing::PerToken {
input_price_per_token,
output_price_per_token,
} => ProviderPricingFields {
mode: Some(MODE_PER_TOKEN.to_string()),
input_price_per_token: *input_price_per_token,
output_price_per_token: *output_price_per_token,
hourly_rate: None,
input_token_cost_ratio: None,
},
ProviderPricing::Hourly {
rate,
input_token_cost_ratio,
} => ProviderPricingFields {
mode: Some(MODE_HOURLY.to_string()),
input_price_per_token: None,
output_price_per_token: None,
hourly_rate: Some(*rate),
input_token_cost_ratio: Some(*input_token_cost_ratio),
},
}
}
}
/// Partial update to a deployment's provider pricing.
///
/// Tagged by `mode`. `"no_change"` (the default) leaves pricing untouched.
/// `"per_token"` and `"hourly"` switch to that mode and update only the
/// values that are supplied.
#[derive(Debug, Clone, Default, Serialize, Deserialize, ToSchema)]
#[serde(tag = "mode", rename_all = "snake_case")]
pub enum ProviderPricingUpdate {
    /// Leave pricing unchanged.
    #[default]
    NoChange,
    /// Switch to per-token pricing.
    ///
    /// Each price has three cases: omitted (left as is), `null` (set to no
    /// value), or a value (set to that value).
    PerToken {
        // `double_option` + `default` map: absent -> None, null -> Some(None),
        // value -> Some(Some(v)).
        #[serde(default, skip_serializing_if = "Option::is_none", with = "double_option")]
        #[schema(value_type = Option<Option<String>>)]
        input_price_per_token: Option<Option<Decimal>>,
        #[serde(default, skip_serializing_if = "Option::is_none", with = "double_option")]
        #[schema(value_type = Option<Option<String>>)]
        output_price_per_token: Option<Option<Decimal>>,
    },
    /// Switch to hourly pricing.
    ///
    /// An omitted or `null` field is left as is.
    Hourly {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        #[schema(value_type = Option<String>)]
        rate: Option<Decimal>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        #[schema(value_type = Option<String>)]
        input_token_cost_ratio: Option<Decimal>,
    },
}
/// Flattened form of a [`ProviderPricingUpdate`], produced by
/// [`ProviderPricingUpdate::to_update_params`].
///
/// Each `should_update_*` flag says whether its paired value is to be written.
/// When a flag is `true`, the value may still be `None`.
// NOTE(review): presumably a `true` flag with a `None` value is written as
// NULL by the update query — confirm in the repository layer.
#[derive(Debug, Clone, Default)]
pub struct ProviderPricingUpdateParams {
    /// Whether `mode` should be written.
    pub should_update_mode: bool,
    /// New pricing mode (`"per_token"` or `"hourly"`).
    pub mode: Option<String>,
    /// Whether `input` (per-token input price) should be written.
    pub should_update_input: bool,
    /// New per-token input price.
    pub input: Option<Decimal>,
    /// Whether `output` (per-token output price) should be written.
    pub should_update_output: bool,
    /// New per-token output price.
    pub output: Option<Decimal>,
    /// Whether `hourly` (hourly rate) should be written.
    pub should_update_hourly: bool,
    /// New hourly rate.
    pub hourly: Option<Decimal>,
    /// Whether `ratio` (input token cost ratio) should be written.
    pub should_update_ratio: bool,
    /// New input token cost ratio.
    pub ratio: Option<Decimal>,
}
impl ProviderPricingUpdate {
    /// Flattens this update into [`ProviderPricingUpdateParams`].
    ///
    /// * `NoChange` -> every `should_update_*` flag is `false`.
    /// * `PerToken` -> mode is set to `"per_token"`. Each price is flagged for
    ///   update when its outer `Option` is `Some`, and the value is the inner
    ///   `Option`. Hourly fields are not flagged.
    /// * `Hourly` -> mode is set to `"hourly"`. Rate and ratio are flagged when
    ///   present. Per-token fields are not flagged.
    pub fn to_update_params(&self) -> ProviderPricingUpdateParams {
        match *self {
            Self::NoChange => ProviderPricingUpdateParams::default(),
            Self::PerToken {
                input_price_per_token: input,
                output_price_per_token: output,
            } => ProviderPricingUpdateParams {
                should_update_mode: true,
                mode: Some(MODE_PER_TOKEN.to_owned()),
                should_update_input: input.is_some(),
                input: input.flatten(),
                should_update_output: output.is_some(),
                output: output.flatten(),
                ..ProviderPricingUpdateParams::default()
            },
            Self::Hourly {
                rate,
                input_token_cost_ratio: ratio,
            } => ProviderPricingUpdateParams {
                should_update_mode: true,
                mode: Some(MODE_HOURLY.to_owned()),
                should_update_hourly: rate.is_some(),
                hourly: rate,
                should_update_ratio: ratio.is_some(),
                ratio,
                ..ProviderPricingUpdateParams::default()
            },
        }
    }
}
/// Broad kind of model served by a deployment.
///
/// Serialized in upper case: `"CHAT"`, `"EMBEDDINGS"`, `"RERANKER"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ToSchema)]
#[serde(rename_all = "UPPERCASE")]
pub enum ModelType {
    Chat,
    Embeddings,
    Reranker,
}
impl ModelType {
    /// Guesses a [`ModelType`] from a model name using name heuristics.
    ///
    /// Matching is case-insensitive. Reranker patterns are checked before
    /// embedding patterns, so e.g. `bge-reranker-large` is classified as
    /// [`ModelType::Reranker`] even though it also contains the `bge-`
    /// embedding pattern. Names matching neither set fall back to
    /// [`ModelType::Chat`].
    ///
    /// The short `ada` marker (e.g. `ada-002`) is recognised only as a whole
    /// name segment; segments are separated by any non-alphanumeric character.
    /// A plain substring test would also match unrelated names such as
    /// `llama-adapter` or `canada-7b`.
    pub fn detect_from_name(model_name: &str) -> Self {
        let name_lower = model_name.to_lowercase();

        // Substring markers for rerankers. Checked first because several
        // reranker names also contain embedding markers (e.g. `bge-`).
        let reranker_patterns = [
            "rerank",
            "reranker",
            "cross-encoder",
            "bge-reranker",
            "mixedbread-reranker",
            "mxbai-rerank",
        ];
        // Substring markers for embedding models.
        let embedding_patterns = [
            "embed",
            "embedding",
            "text-embedding",
            "sentence-transformer",
            "all-minilm",
            "bge-",
            "e5-",
        ];
        // Markers too short to test as substrings without false positives;
        // they must equal a whole segment of the name.
        let embedding_segments = ["ada"];

        if reranker_patterns
            .iter()
            .any(|pattern| name_lower.contains(pattern))
        {
            return Self::Reranker;
        }

        let has_embedding_substring = embedding_patterns
            .iter()
            .any(|pattern| name_lower.contains(pattern));
        let has_embedding_segment = name_lower
            .split(|c: char| !c.is_alphanumeric())
            .any(|segment| embedding_segments.iter().any(|marker| *marker == segment));
        if has_embedding_substring || has_embedding_segment {
            return Self::Embeddings;
        }

        Self::Chat
    }
}
/// Display category stored in [`ModelCatalogMetadata::display_category`].
///
/// Serialized in snake_case: `"generation"`, `"embedding"`, `"ocr"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum ModelDisplayCategory {
    Generation,
    Embedding,
    Ocr,
}
// NOTE(review): neither limit is enforced in this file; the checks presumably
// live in the API/validation layer — confirm.
/// Intended maximum size of a model's catalog metadata, in bytes (16 KiB).
pub const MODEL_CATALOG_METADATA_MAX_BYTES: usize = 16_384;
/// Intended maximum number of keys in the free-form `extra` metadata object.
pub const MODEL_CATALOG_METADATA_MAX_EXTRA_KEYS: usize = 50;
/// Descriptive catalog metadata for a deployment.
///
/// All fields are optional and are omitted from the serialized JSON when unset.
#[derive(Debug, Clone, Default, Serialize, Deserialize, ToSchema, PartialEq)]
pub struct ModelCatalogMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<String>,
    /// Display category of the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_category: Option<ModelDisplayCategory>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub intelligence_index: Option<f64>,
    // NOTE(review): units are not stated in this file — presumably tokens;
    // confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_window: Option<i64>,
    /// Release date (a calendar date without time of day or time zone).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub released_at: Option<NaiveDate>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantization: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub attribution: Option<String>,
    /// Free-form additional metadata.
    // Declared as an object in the OpenAPI schema, but `serde_json::Value`
    // accepts any JSON value here.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schema(value_type = Option<Object>)]
    pub extra: Option<serde_json::Value>,
}
/// Status of a deployment.
///
/// Serialized in lowercase (`"active"` / `"inactive"`). These are the same
/// strings produced by [`ModelStatus::to_db_string`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ToSchema)]
#[serde(rename_all = "lowercase")]
pub enum ModelStatus {
    Active,
    Inactive,
}
impl ModelStatus {
    /// Returns the stored string form of this status (same as the serde form).
    pub fn to_db_string(&self) -> &'static str {
        match self {
            ModelStatus::Active => "active",
            ModelStatus::Inactive => "inactive",
        }
    }

    /// Parses a stored status string.
    ///
    /// Returns `None` unless `s` is exactly one of the values produced by
    /// [`ModelStatus::to_db_string`]. Use this instead of
    /// [`ModelStatus::from_db_string`] when an unexpected value should be
    /// detected rather than silently treated as active.
    pub fn try_from_db_string(s: &str) -> Option<ModelStatus> {
        match s {
            "active" => Some(ModelStatus::Active),
            "inactive" => Some(ModelStatus::Inactive),
            _ => None,
        }
    }

    /// Parses a stored status string leniently.
    ///
    /// Unrecognised values fall back to [`ModelStatus::Active`], which is the
    /// long-standing behaviour. Use [`ModelStatus::try_from_db_string`] to
    /// detect such values instead.
    pub fn from_db_string(s: &str) -> ModelStatus {
        Self::try_from_db_string(s).unwrap_or(ModelStatus::Active)
    }
}
/// Load-balancing strategy of a composite deployment.
///
/// Serialized in snake_case (`"weighted_random"` / `"priority"`). Defaults to
/// `WeightedRandom`.
// NOTE(review): presumably `WeightedRandom` uses `DeploymentComponent::weight`
// and `Priority` uses `DeploymentComponent::sort_order` — the routing code is
// not in this file; confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum LoadBalancingStrategy {
    #[default]
    WeightedRandom,
    Priority,
}
impl LoadBalancingStrategy {
pub fn as_str(&self) -> &'static str {
match self {
Self::WeightedRandom => "weighted_random",
Self::Priority => "priority",
}
}
pub fn try_parse(s: &str) -> Option<Self> {
match s {
"weighted_random" => Some(Self::WeightedRandom),
"priority" => Some(Self::Priority),
_ => None,
}
}
}
/// Serde `default` helper: boolean fields that use it become `true` when omitted.
fn default_true() -> bool {
    true
}

/// Default `on_status` list for `FallbackConfig`, shared by its serde field
/// default and by `FallbackConfig::new`.
///
/// HTTP 429 (Too Many Requests) plus 500, 502, 503 and 504.
fn default_fallback_status_codes() -> Vec<i32> {
    [429, 500, 502, 503, 504].to_vec()
}
/// Fallback configuration for a deployment.
///
/// Omitted fields default to `enabled = true`, `on_rate_limit = true`,
/// `on_status = [429, 500, 502, 503, 504]`, `with_replacement = false`, and no
/// `max_attempts` limit.
// `Default` is implemented by hand (below) rather than derived. A derived impl
// would give `enabled: false`, `on_rate_limit: false` and an empty
// `on_status`, which contradicts both `new()` and the serde field defaults.
// That would make `FallbackConfig::default()` differ from deserializing `{}`.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct FallbackConfig {
    /// Whether fallback is enabled.
    #[serde(default = "default_true")]
    pub enabled: bool,
    #[serde(default = "default_true")]
    pub on_rate_limit: bool,
    /// Status codes that trigger fallback.
    #[serde(default = "default_fallback_status_codes")]
    pub on_status: Vec<i32>,
    #[serde(default)]
    pub with_replacement: bool,
    /// Optional cap on attempts; omitted from serialized output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_attempts: Option<i32>,
}

impl FallbackConfig {
    /// Creates the default configuration. It is identical to
    /// [`FallbackConfig::default`] and to deserializing an empty object.
    pub fn new() -> Self {
        Self {
            enabled: true,
            on_rate_limit: true,
            on_status: default_fallback_status_codes(),
            with_replacement: false,
            max_attempts: None,
        }
    }
}

impl Default for FallbackConfig {
    /// Same as [`FallbackConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// One component of a composite deployment (serializable form).
// NOTE(review): presumably `weight` is used by the `weighted_random` strategy
// and `sort_order` by `priority` — confirm against the routing code.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct DeploymentComponent {
    /// Deployment used as this component.
    #[schema(value_type = String, format = "uuid")]
    pub deployed_model_id: DeploymentId,
    pub weight: i32,
    pub enabled: bool,
    /// Ordering key; defaults to `0` when omitted.
    #[serde(default)]
    pub sort_order: i32,
}
/// Input for linking a member deployment into a composite deployment.
#[derive(Debug, Clone)]
pub struct DeploymentComponentCreateDBRequest {
    /// The composite deployment that gains the component.
    pub composite_model_id: DeploymentId,
    /// The deployment being added as a component.
    pub deployed_model_id: DeploymentId,
    pub weight: i32,
    pub enabled: bool,
    pub sort_order: i32,
}
/// A composite deployment's component link, as returned by the database layer.
///
/// Alongside the link itself it carries details of the member deployment
/// (`model_*` fields) and of its endpoint (`endpoint_*` fields).
// NOTE(review): the `model_*` / `endpoint_*` fields are presumably joined from
// the deployments and endpoints tables; `endpoint_*` being `None` presumably
// means the member has no hosting endpoint — confirm in the query.
#[derive(Debug, Clone)]
pub struct DeploymentComponentDBResponse {
    pub id: uuid::Uuid,
    pub composite_model_id: DeploymentId,
    pub deployed_model_id: DeploymentId,
    pub weight: i32,
    pub enabled: bool,
    pub sort_order: i32,
    pub created_at: DateTime<Utc>,
    pub model_alias: String,
    pub model_name: String,
    pub model_description: Option<String>,
    pub model_type: Option<String>,
    pub endpoint_id: Option<InferenceEndpointId>,
    pub endpoint_name: Option<String>,
    pub model_trusted: bool,
    pub model_open_responses_adapter: bool,
}
/// Database insert request for a new deployment (standard or composite).
///
/// Built with the `bon` builder; `Option` fields also get `maybe_*` setters.
/// Builder defaults when unset: `is_composite = false`,
/// `sanitize_responses = false`, `trusted = false`,
/// `open_responses_adapter = true`.
#[derive(Debug, Clone, Builder)]
pub struct DeploymentCreateDBRequest {
    pub created_by: UserId,
    pub model_name: String,
    pub alias: String,
    pub display_name: Option<String>,
    pub description: Option<String>,
    pub model_type: Option<ModelType>,
    pub capabilities: Option<Vec<String>>,
    /// Hosting endpoint; left unset for composite deployments by
    /// [`DeploymentCreateDBRequest::from_api_create`].
    pub hosted_on: Option<InferenceEndpointId>,
    pub requests_per_second: Option<f32>,
    pub burst_size: Option<i32>,
    pub capacity: Option<i32>,
    pub batch_capacity: Option<i32>,
    pub throughput: Option<f32>,
    pub provider_pricing: Option<ProviderPricing>,
    #[builder(default)]
    pub is_composite: bool,
    // Load-balancing / fallback settings: only populated for composite
    // deployments by `from_api_create`.
    pub lb_strategy: Option<LoadBalancingStrategy>,
    pub fallback_enabled: Option<bool>,
    pub fallback_on_rate_limit: Option<bool>,
    pub fallback_on_status: Option<Vec<i32>>,
    pub fallback_with_replacement: Option<bool>,
    pub fallback_max_attempts: Option<i32>,
    #[builder(default = false)]
    pub sanitize_responses: bool,
    #[builder(default = false)]
    pub trusted: bool,
    #[builder(default = true)]
    pub open_responses_adapter: bool,
    pub allowed_batch_completion_windows: Option<Vec<String>>,
    pub metadata: Option<ModelCatalogMetadata>,
}
impl DeploymentCreateDBRequest {
    /// Converts an API create payload into a database insert request.
    ///
    /// * `alias` falls back to `model_name` when the payload omits it.
    /// * Standard deployments carry `hosted_on` and `provider_pricing`; the
    ///   load-balancing and fallback fields stay unset.
    /// * Composite deployments set `is_composite` and the load-balancing and
    ///   fallback fields; `hosted_on` and `provider_pricing` stay unset.
    /// * Missing optional flags resolve to `trusted = false` and
    ///   `open_responses_adapter = true`. For standard deployments,
    ///   `sanitize_responses` also resolves to `false`.
    pub fn from_api_create(created_by: UserId, create: DeployedModelCreate) -> Self {
        match create {
            DeployedModelCreate::Standard(standard) => {
                let model_name = standard.model_name;
                let alias = standard.alias.unwrap_or_else(|| model_name.clone());
                let sanitize_responses = standard.sanitize_responses.unwrap_or(false);
                let trusted = standard.trusted.unwrap_or(false);
                let open_responses_adapter = standard.open_responses_adapter.unwrap_or(true);

                Self::builder()
                    .created_by(created_by)
                    .model_name(model_name)
                    .alias(alias)
                    .maybe_display_name(standard.display_name)
                    .maybe_description(standard.description)
                    .maybe_model_type(standard.model_type)
                    .maybe_capabilities(standard.capabilities)
                    .hosted_on(standard.hosted_on)
                    .maybe_requests_per_second(standard.requests_per_second)
                    .maybe_burst_size(standard.burst_size)
                    .maybe_capacity(standard.capacity)
                    .maybe_batch_capacity(standard.batch_capacity)
                    .maybe_throughput(standard.throughput)
                    .maybe_provider_pricing(standard.provider_pricing)
                    .is_composite(false)
                    .sanitize_responses(sanitize_responses)
                    .trusted(trusted)
                    .open_responses_adapter(open_responses_adapter)
                    .maybe_allowed_batch_completion_windows(standard.allowed_batch_completion_windows)
                    .maybe_metadata(standard.metadata)
                    .build()
            }
            DeployedModelCreate::Composite(composite) => {
                let model_name = composite.model_name;
                let alias = composite.alias.unwrap_or_else(|| model_name.clone());
                let trusted = composite.trusted.unwrap_or(false);
                let open_responses_adapter = composite.open_responses_adapter.unwrap_or(true);

                Self::builder()
                    .created_by(created_by)
                    .model_name(model_name)
                    .alias(alias)
                    .maybe_display_name(composite.display_name)
                    .maybe_description(composite.description)
                    .maybe_model_type(composite.model_type)
                    .maybe_capabilities(composite.capabilities)
                    .maybe_requests_per_second(composite.requests_per_second)
                    .maybe_burst_size(composite.burst_size)
                    .maybe_capacity(composite.capacity)
                    .maybe_batch_capacity(composite.batch_capacity)
                    .maybe_throughput(composite.throughput)
                    .is_composite(true)
                    .lb_strategy(composite.lb_strategy)
                    .fallback_enabled(composite.fallback_enabled)
                    .fallback_on_rate_limit(composite.fallback_on_rate_limit)
                    .fallback_on_status(composite.fallback_on_status)
                    .fallback_with_replacement(composite.fallback_with_replacement)
                    .maybe_fallback_max_attempts(composite.fallback_max_attempts)
                    .sanitize_responses(composite.sanitize_responses)
                    .trusted(trusted)
                    .open_responses_adapter(open_responses_adapter)
                    .maybe_allowed_batch_completion_windows(composite.allowed_batch_completion_windows)
                    .maybe_metadata(composite.metadata)
                    .build()
            }
        }
    }
}
/// Database update request for an existing deployment.
///
/// Every field is optional; an unset field is not part of the update.
// NOTE(review): presumably `None` leaves the column unchanged and, for the
// `Option<Option<T>>` fields, `Some(None)` clears the stored value — confirm
// in the update query. Note that `display_name` and `metadata` are single
// `Option`s, so this type cannot express clearing them.
#[derive(Debug, Clone, Builder)]
pub struct DeploymentUpdateDBRequest {
    pub model_name: Option<String>,
    pub alias: Option<String>,
    pub display_name: Option<String>,
    pub description: Option<Option<String>>,
    pub model_type: Option<Option<ModelType>>,
    pub capabilities: Option<Option<Vec<String>>>,
    pub status: Option<ModelStatus>,
    pub last_sync: Option<Option<DateTime<Utc>>>,
    pub deleted: Option<bool>,
    pub requests_per_second: Option<Option<f32>>,
    pub burst_size: Option<Option<i32>>,
    pub capacity: Option<Option<i32>>,
    pub batch_capacity: Option<Option<i32>>,
    pub throughput: Option<Option<f32>>,
    pub provider_pricing: Option<ProviderPricingUpdate>,
    pub lb_strategy: Option<LoadBalancingStrategy>,
    pub fallback_enabled: Option<bool>,
    pub fallback_on_rate_limit: Option<bool>,
    pub fallback_on_status: Option<Vec<i32>>,
    pub fallback_with_replacement: Option<bool>,
    pub fallback_max_attempts: Option<Option<i32>>,
    pub sanitize_responses: Option<bool>,
    pub trusted: Option<bool>,
    pub open_responses_adapter: Option<bool>,
    pub allowed_batch_completion_windows: Option<Option<Vec<String>>>,
    pub metadata: Option<ModelCatalogMetadata>,
}
impl From<DeployedModelUpdate> for DeploymentUpdateDBRequest {
fn from(update: DeployedModelUpdate) -> Self {
Self::builder()
.maybe_alias(update.alias)
.maybe_display_name(update.display_name)
.maybe_description(update.description)
.maybe_model_type(update.model_type)
.maybe_capabilities(update.capabilities)
.maybe_requests_per_second(update.requests_per_second)
.maybe_burst_size(update.burst_size)
.maybe_capacity(update.capacity)
.maybe_batch_capacity(update.batch_capacity)
.maybe_throughput(update.throughput)
.maybe_provider_pricing(update.provider_pricing)
.maybe_lb_strategy(update.lb_strategy)
.maybe_fallback_enabled(update.fallback_enabled)
.maybe_fallback_on_rate_limit(update.fallback_on_rate_limit)
.maybe_fallback_on_status(update.fallback_on_status)
.maybe_fallback_with_replacement(update.fallback_with_replacement)
.maybe_fallback_max_attempts(update.fallback_max_attempts)
.maybe_sanitize_responses(update.sanitize_responses)
.maybe_trusted(update.trusted)
.maybe_open_responses_adapter(update.open_responses_adapter)
.maybe_allowed_batch_completion_windows(update.allowed_batch_completion_windows)
.maybe_metadata(update.metadata)
.build()
}
}
impl DeploymentUpdateDBRequest {
    /// Builds an update that always sets `last_sync` to `Some(last_sync)`.
    /// It also sets `status` only when `status` is `Some`.
    pub fn status_update(status: Option<ModelStatus>, last_sync: DateTime<Utc>) -> Self {
        let synced_at = Some(last_sync);
        Self::builder()
            .last_sync(synced_at)
            .maybe_status(status)
            .build()
    }

    /// Builds an update that only sets the `deleted` flag.
    pub fn visibility_update(deleted: bool) -> Self {
        Self::builder().deleted(deleted).build()
    }

    /// Builds an update that only changes the alias.
    pub fn alias_update(new_alias: String) -> Self {
        Self::builder().alias(new_alias).build()
    }
}
/// Deployment record as returned by the database layer.
// NOTE(review): `dead_code` is allowed without explanation — presumably some
// fields are not read yet; document which ones or drop the allow.
// `metadata` is raw JSON here, whereas the create/update requests use
// `ModelCatalogMetadata`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct DeploymentDBResponse {
    pub id: DeploymentId,
    pub model_name: String,
    pub alias: String,
    pub display_name: Option<String>,
    pub description: Option<String>,
    pub model_type: Option<ModelType>,
    pub capabilities: Option<Vec<String>>,
    pub created_by: UserId,
    pub hosted_on: Option<InferenceEndpointId>,
    pub status: ModelStatus,
    pub last_sync: Option<DateTime<Utc>>,
    pub deleted: bool,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    pub requests_per_second: Option<f32>,
    pub burst_size: Option<i32>,
    pub capacity: Option<i32>,
    pub batch_capacity: Option<i32>,
    pub throughput: Option<f32>,
    pub provider_pricing: Option<ProviderPricing>,
    pub is_composite: bool,
    pub lb_strategy: LoadBalancingStrategy,
    pub fallback_enabled: bool,
    pub fallback_on_rate_limit: bool,
    pub fallback_on_status: Vec<i32>,
    pub fallback_with_replacement: bool,
    pub fallback_max_attempts: Option<i32>,
    pub sanitize_responses: bool,
    pub trusted: bool,
    pub open_responses_adapter: bool,
    pub allowed_batch_completion_windows: Option<Vec<String>>,
    pub metadata: serde_json::Value,
}
/// Action taken by a traffic rule.
#[derive(Debug, Clone)]
pub enum TrafficRuleAction {
    /// Deny action (presumably rejects the matching request).
    Deny,
    /// Redirect to the given deployment.
    Redirect(DeploymentId),
}
/// Raw traffic-rule row as read from the database.
// NOTE(review): `action` is stored as a string — presumably the textual form
// of `TrafficRuleAction`, with `redirect_target_*` set only for redirects;
// confirm against the schema.
#[derive(Debug, Clone)]
pub struct TrafficRuleDBRow {
    pub id: Uuid,
    pub deployed_model_id: DeploymentId,
    pub api_key_purpose: String,
    pub action: String,
    pub redirect_target_id: Option<DeploymentId>,
    pub redirect_target_alias: Option<String>,
    pub created_at: DateTime<Utc>,
}