use crate::{
db::{
errors::Result,
models::tariffs::{ModelTariff, TariffCreateDBRequest, TariffDBResponse},
},
types::DeploymentId,
};
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use sqlx::PgConnection;
use tracing::instrument;
use uuid::Uuid;
/// Data-access layer for the `model_tariffs` table.
///
/// Holds a mutable borrow of a Postgres connection for its whole lifetime, so every
/// query issued through it runs on that connection (the tests in this file hand it an
/// open transaction).
pub struct Tariffs<'c> {
    db: &'c mut PgConnection,
}

impl<'c> Tariffs<'c> {
    /// Creates a repository that issues all of its queries on `db`.
    pub fn new(db: &'c mut PgConnection) -> Self {
        Self { db }
    }

    /// Converts an `ApiKeyPurpose` into the text value bound for the
    /// `model_tariffs.api_key_purpose` column.
    ///
    /// Both the insert path ([`Self::create`]) and the pricing lookup
    /// ([`Self::get_pricing_at_timestamp`]) bind this value, so a single mapping keeps
    /// them in agreement. The match is deliberately exhaustive (no `_` arm): adding a
    /// new purpose variant becomes a compile error here instead of a silent mismatch.
    fn purpose_as_db_str(purpose: &crate::db::models::api_keys::ApiKeyPurpose) -> &'static str {
        use crate::db::models::api_keys::ApiKeyPurpose;

        match purpose {
            ApiKeyPurpose::Realtime => "realtime",
            ApiKeyPurpose::Batch => "batch",
            ApiKeyPurpose::Playground => "playground",
            ApiKeyPurpose::Platform => "platform",
        }
    }

    /// Inserts a new tariff and returns the row as stored.
    ///
    /// When `request.valid_from` is `None`, the database's `NOW()` is used as the start
    /// of validity. The insert does not supply `valid_until`; presumably the column
    /// defaults to NULL, making the new row a "current" tariff (confirm against the
    /// schema).
    ///
    /// # Errors
    ///
    /// Returns an error if the insert fails. This includes rows the database rejects:
    /// this module's tests expect a second current tariff with the same purpose and
    /// completion window, and a batch tariff without a completion window, to fail.
    #[instrument(skip(self, request), fields(deployed_model_id = %request.deployed_model_id, name = %request.name), err)]
    pub async fn create(&mut self, request: &TariffCreateDBRequest) -> Result<TariffDBResponse> {
        let purpose_str = request.api_key_purpose.as_ref().map(Self::purpose_as_db_str);

        let tariff = sqlx::query_as!(
            ModelTariff,
            r#"
INSERT INTO model_tariffs (
deployed_model_id, name, input_price_per_token, output_price_per_token,
api_key_purpose, completion_window, valid_from
)
VALUES ($1, $2, $3, $4, $5, $6, COALESCE($7, NOW()))
RETURNING id, deployed_model_id, name, input_price_per_token, output_price_per_token,
valid_from, valid_until, api_key_purpose as "api_key_purpose: _", completion_window
"#,
            request.deployed_model_id,
            request.name,
            request.input_price_per_token,
            request.output_price_per_token,
            purpose_str,
            request.completion_window,
            request.valid_from,
        )
        .fetch_one(&mut *self.db)
        .await?;

        Ok(tariff)
    }

    /// Fetches a single tariff by primary key, or `Ok(None)` when no row has that id.
    #[instrument(skip(self), err)]
    pub async fn get_by_id(&mut self, id: Uuid) -> Result<Option<TariffDBResponse>> {
        let tariff = sqlx::query_as!(
            ModelTariff,
            r#"
SELECT id, deployed_model_id, name, input_price_per_token, output_price_per_token,
valid_from, valid_until, api_key_purpose as "api_key_purpose: _", completion_window
FROM model_tariffs
WHERE id = $1
"#,
            id
        )
        .fetch_optional(&mut *self.db)
        .await?;

        Ok(tariff)
    }

    /// Lists a deployment's tariffs whose `valid_until` is NULL (not yet closed).
    ///
    /// Rows are ordered by purpose, then completion window (NULLs last in both), then
    /// name.
    #[instrument(skip(self), err)]
    pub async fn list_current_by_model(&mut self, deployed_model_id: DeploymentId) -> Result<Vec<TariffDBResponse>> {
        let tariffs = sqlx::query_as!(
            ModelTariff,
            r#"
SELECT id, deployed_model_id, name, input_price_per_token, output_price_per_token,
valid_from, valid_until, api_key_purpose as "api_key_purpose: _", completion_window
FROM model_tariffs
WHERE deployed_model_id = $1 AND valid_until IS NULL
ORDER BY api_key_purpose ASC NULLS LAST, completion_window ASC NULLS LAST, name ASC
"#,
            deployed_model_id
        )
        .fetch_all(&mut *self.db)
        .await?;

        Ok(tariffs)
    }

    /// Lists every tariff of a deployment, both current and closed.
    ///
    /// Rows come newest `valid_from` first. Ties are broken by purpose, completion
    /// window (NULLs last) and name.
    #[instrument(skip(self), err)]
    pub async fn list_all_by_model(&mut self, deployed_model_id: DeploymentId) -> Result<Vec<TariffDBResponse>> {
        let tariffs = sqlx::query_as!(
            ModelTariff,
            r#"
SELECT id, deployed_model_id, name, input_price_per_token, output_price_per_token,
valid_from, valid_until, api_key_purpose as "api_key_purpose: _", completion_window
FROM model_tariffs
WHERE deployed_model_id = $1
ORDER BY valid_from DESC, api_key_purpose ASC NULLS LAST, completion_window ASC NULLS LAST, name ASC
"#,
            deployed_model_id
        )
        .fetch_all(&mut *self.db)
        .await?;

        Ok(tariffs)
    }

    /// Resolves `(input_price_per_token, output_price_per_token)` at `timestamp`, with
    /// a fallback purpose.
    ///
    /// If `preferred_purpose` is `Some` and a tariff for it is found (honouring
    /// `completion_window`), that price is returned. Otherwise the lookup is repeated
    /// for `fallback_purpose` with no completion-window filter. The result is
    /// `Ok(None)` when neither lookup finds a tariff.
    #[instrument(skip(self), err)]
    pub async fn get_pricing_at_timestamp_with_fallback(
        &mut self,
        deployed_model_id: DeploymentId,
        preferred_purpose: Option<&crate::db::models::api_keys::ApiKeyPurpose>,
        fallback_purpose: &crate::db::models::api_keys::ApiKeyPurpose,
        timestamp: DateTime<Utc>,
        completion_window: Option<&str>,
    ) -> Result<Option<(Decimal, Decimal)>> {
        if let Some(preferred) = preferred_purpose
            && let Some(pricing) = self
                .get_pricing_at_timestamp(deployed_model_id, preferred, timestamp, completion_window)
                .await?
        {
            return Ok(Some(pricing));
        }

        // The fallback purpose is looked up without a completion-window filter.
        self.get_pricing_at_timestamp(deployed_model_id, fallback_purpose, timestamp, None)
            .await
    }

    /// Returns the `(input_price_per_token, output_price_per_token)` in effect at
    /// `timestamp` for one deployment and purpose, or `Ok(None)` if no tariff covers
    /// that instant.
    ///
    /// `completion_window` only narrows batch tariffs. When it is `Some`, batch rows
    /// must have a matching `completion_window`, while rows for other purposes ignore
    /// it. `None` disables the filter.
    ///
    /// Lookup order:
    /// 1. The current (`valid_until IS NULL`) tariff, if its `valid_from` is at or
    ///    before `timestamp`.
    /// 2. Otherwise, the row whose `[valid_from, valid_until)` interval contains
    ///    `timestamp`, preferring the latest `valid_from`.
    #[instrument(skip(self), err)]
    pub async fn get_pricing_at_timestamp(
        &mut self,
        deployed_model_id: DeploymentId,
        api_key_purpose: &crate::db::models::api_keys::ApiKeyPurpose,
        timestamp: DateTime<Utc>,
        completion_window: Option<&str>,
    ) -> Result<Option<(Decimal, Decimal)>> {
        let purpose_str = Self::purpose_as_db_str(api_key_purpose);

        // Fast path: most lookups are for "now", which the open-ended tariff answers.
        // NOTE(review): this query has `LIMIT 1` with no `ORDER BY`. If several current
        // rows match (e.g. purpose = batch with `completion_window = None`, since each
        // completion window may have its own current batch tariff), which row is
        // returned is unspecified. Consider a deterministic tie-break.
        let current_tariff = sqlx::query!(
            r#"
SELECT input_price_per_token, output_price_per_token, valid_from
FROM model_tariffs
WHERE deployed_model_id = $1
AND api_key_purpose = $2
AND valid_until IS NULL
AND ($3::VARCHAR IS NULL OR completion_window = $3 OR api_key_purpose != 'batch')
LIMIT 1
"#,
            deployed_model_id,
            purpose_str,
            completion_window
        )
        .fetch_optional(&mut *self.db)
        .await?;

        if let Some(current) = current_tariff
            && timestamp >= current.valid_from
        {
            return Ok(Some((current.input_price_per_token, current.output_price_per_token)));
        }

        // Historical path: the timestamp predates the current tariff (or none exists),
        // so search every row whose validity interval contains it.
        let result = sqlx::query!(
            r#"
SELECT input_price_per_token, output_price_per_token
FROM model_tariffs
WHERE deployed_model_id = $1
AND api_key_purpose = $2
AND valid_from <= $3
AND (valid_until IS NULL OR valid_until > $3)
AND ($4::VARCHAR IS NULL OR completion_window = $4 OR api_key_purpose != 'batch')
ORDER BY valid_from DESC
LIMIT 1
"#,
            deployed_model_id,
            purpose_str,
            timestamp,
            completion_window
        )
        .fetch_optional(&mut *self.db)
        .await?;

        Ok(result.map(|r| (r.input_price_per_token, r.output_price_per_token)))
    }

    /// Closes the given tariffs by setting `valid_until = NOW()`.
    ///
    /// Returns the number of rows updated. An empty `ids` slice returns `Ok(0)`
    /// without touching the database.
    ///
    /// NOTE(review): the update does not filter on `valid_until IS NULL`, so passing
    /// the id of an already-closed tariff moves its `valid_until` forward. Confirm
    /// callers only pass current tariffs.
    #[instrument(skip(self), fields(count = ids.len()), err)]
    pub async fn close_tariffs_batch(&mut self, ids: &[Uuid]) -> Result<u64> {
        if ids.is_empty() {
            return Ok(0);
        }

        let result = sqlx::query!("UPDATE model_tariffs SET valid_until = NOW() WHERE id = ANY($1)", ids)
            .execute(&mut *self.db)
            .await?;

        Ok(result.rows_affected())
    }

    /// Hard-deletes a tariff row. Returns `true` if a row with `id` was removed.
    #[instrument(skip(self), err)]
    pub async fn delete(&mut self, id: Uuid) -> Result<bool> {
        let result = sqlx::query!(
            r#"
DELETE FROM model_tariffs
WHERE id = $1
"#,
            id
        )
        .execute(&mut *self.db)
        .await?;

        Ok(result.rows_affected() > 0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::models::api_keys::ApiKeyPurpose;
    use crate::types::DeploymentId;
    use rust_decimal::Decimal;
    use sqlx::PgPool;
    use std::str::FromStr;

    /// Shared fixture for every test in this module.
    ///
    /// Seeds the database and creates a standard user. It then opens a transaction and
    /// inserts a fresh `deployed_models` row inside it. The open transaction is
    /// returned, so the caller's tariff writes run in the same transaction, together
    /// with the new deployment's id.
    async fn setup_deployment(
        pool: &PgPool,
    ) -> (sqlx::Transaction<'static, sqlx::Postgres>, DeploymentId) {
        let base_url = url::Url::parse("http://localhost:8080").unwrap();
        let sources = vec![crate::config::ModelSource {
            name: "test".to_string(),
            url: base_url,
            api_key: None,
            sync_interval: std::time::Duration::from_secs(3600),
            default_models: None,
        }];
        crate::seed_database(&sources, pool).await.unwrap();

        let user = crate::test::utils::create_test_user(pool, crate::api::models::users::Role::StandardUser).await;
        let test_endpoint_id = crate::test::utils::get_test_endpoint_id(pool).await;
        let deployment_id = DeploymentId::new_v4();

        let mut tx = pool.begin().await.unwrap();
        sqlx::query!(
            "INSERT INTO deployed_models (id, model_name, alias, hosted_on, created_by) VALUES ($1, 'test-model', 'test-alias', $2, $3)",
            deployment_id,
            test_endpoint_id,
            user.id
        )
        .execute(&mut *tx)
        .await
        .unwrap();

        (tx, deployment_id)
    }

    /// Two current batch tariffs may coexist when their completion windows differ.
    #[sqlx::test]
    async fn test_multiple_batch_tariffs_per_sla(pool: PgPool) {
        let (mut tx, deployment_id) = setup_deployment(&pool).await;
        let mut tariffs = Tariffs::new(&mut tx);

        let tariff_24h = TariffCreateDBRequest {
            deployed_model_id: deployment_id,
            name: "Batch 24h".to_string(),
            input_price_per_token: Decimal::from_str("0.001").unwrap(),
            output_price_per_token: Decimal::from_str("0.002").unwrap(),
            api_key_purpose: Some(ApiKeyPurpose::Batch),
            completion_window: Some("24h".to_string()),
            valid_from: None,
        };
        let created_24h = tariffs.create(&tariff_24h).await.unwrap();
        assert_eq!(created_24h.completion_window.as_deref(), Some("24h"));

        let tariff_1h = TariffCreateDBRequest {
            deployed_model_id: deployment_id,
            name: "Batch 1h".to_string(),
            input_price_per_token: Decimal::from_str("0.002").unwrap(),
            output_price_per_token: Decimal::from_str("0.004").unwrap(),
            api_key_purpose: Some(ApiKeyPurpose::Batch),
            completion_window: Some("1h".to_string()),
            valid_from: None,
        };
        let created_1h = tariffs.create(&tariff_1h).await.unwrap();
        assert_eq!(created_1h.completion_window.as_deref(), Some("1h"));

        let current_tariffs = tariffs.list_current_by_model(deployment_id).await.unwrap();
        assert_eq!(current_tariffs.len(), 2);

        let tariff_24h_found = current_tariffs
            .iter()
            .find(|t| t.completion_window.as_deref() == Some("24h"))
            .expect("24h batch tariff should be listed as current");
        assert_eq!(tariff_24h_found.name, "Batch 24h");

        let tariff_1h_found = current_tariffs
            .iter()
            .find(|t| t.completion_window.as_deref() == Some("1h"))
            .expect("1h batch tariff should be listed as current");
        assert_eq!(tariff_1h_found.name, "Batch 1h");
    }

    /// A second current batch tariff with the same completion window must be rejected.
    #[sqlx::test]
    async fn test_duplicate_batch_tariff_same_sla_rejected(pool: PgPool) {
        let (mut tx, deployment_id) = setup_deployment(&pool).await;
        let mut tariffs = Tariffs::new(&mut tx);

        let tariff_24h = TariffCreateDBRequest {
            deployed_model_id: deployment_id,
            name: "Batch 24h".to_string(),
            input_price_per_token: Decimal::from_str("0.001").unwrap(),
            output_price_per_token: Decimal::from_str("0.002").unwrap(),
            api_key_purpose: Some(ApiKeyPurpose::Batch),
            completion_window: Some("24h".to_string()),
            valid_from: None,
        };
        tariffs.create(&tariff_24h).await.unwrap();

        let duplicate_tariff = TariffCreateDBRequest {
            deployed_model_id: deployment_id,
            name: "Batch 24h Duplicate".to_string(),
            input_price_per_token: Decimal::from_str("0.003").unwrap(),
            output_price_per_token: Decimal::from_str("0.006").unwrap(),
            api_key_purpose: Some(ApiKeyPurpose::Batch),
            completion_window: Some("24h".to_string()),
            valid_from: None,
        };
        let result = tariffs.create(&duplicate_tariff).await;
        assert!(result.is_err(), "Should not allow duplicate batch tariff with same SLA");
    }

    /// Only one current realtime tariff may exist per model.
    #[sqlx::test]
    async fn test_single_realtime_tariff_still_enforced(pool: PgPool) {
        let (mut tx, deployment_id) = setup_deployment(&pool).await;
        let mut tariffs = Tariffs::new(&mut tx);

        let realtime_tariff = TariffCreateDBRequest {
            deployed_model_id: deployment_id,
            name: "Realtime".to_string(),
            input_price_per_token: Decimal::from_str("0.001").unwrap(),
            output_price_per_token: Decimal::from_str("0.002").unwrap(),
            api_key_purpose: Some(ApiKeyPurpose::Realtime),
            completion_window: None,
            valid_from: None,
        };
        tariffs.create(&realtime_tariff).await.unwrap();

        let duplicate_realtime = TariffCreateDBRequest {
            deployed_model_id: deployment_id,
            name: "Realtime 2".to_string(),
            input_price_per_token: Decimal::from_str("0.003").unwrap(),
            output_price_per_token: Decimal::from_str("0.006").unwrap(),
            api_key_purpose: Some(ApiKeyPurpose::Realtime),
            completion_window: None,
            valid_from: None,
        };
        let result = tariffs.create(&duplicate_realtime).await;
        assert!(result.is_err(), "Should still enforce single realtime tariff per model");
    }

    /// A batch tariff without a completion window must be rejected by the CHECK constraint.
    #[sqlx::test]
    async fn test_batch_tariff_without_completion_window_rejected(pool: PgPool) {
        let (mut tx, deployment_id) = setup_deployment(&pool).await;
        let mut tariffs = Tariffs::new(&mut tx);

        let batch_without_sla = TariffCreateDBRequest {
            deployed_model_id: deployment_id,
            name: "Batch No SLA".to_string(),
            input_price_per_token: Decimal::from_str("0.001").unwrap(),
            output_price_per_token: Decimal::from_str("0.002").unwrap(),
            api_key_purpose: Some(ApiKeyPurpose::Batch),
            completion_window: None,
            valid_from: None,
        };
        let result = tariffs.create(&batch_without_sla).await;

        let Err(e) = result else {
            panic!("Should not allow batch tariff without completion_window");
        };
        let error_msg = format!("{e:?}");
        assert!(
            error_msg.contains("batch_tariffs_must_have_completion_window") || error_msg.contains("constraint"),
            "Error should be due to CHECK constraint violation, got: {}",
            error_msg
        );
    }
}