use crate::config::CachePricingConfig;
use crate::db::errors::Result;
use crate::types::DeploymentId;
use rust_decimal::Decimal;
use sqlx::{Connection, PgConnection};
use tracing::instrument;
/// Optional per-model values consumed by `CacheTariffs::enable`.
///
/// Every field that is `Some` replaces the matching `default_*` field of
/// [`CachePricingConfig`] in the newly inserted tariff row; a `None` field keeps
/// the configured default. `Default` yields "no overrides at all".
#[derive(Clone, Debug, Default)]
pub struct CacheTariffOverrides {
    /// Replaces `CachePricingConfig::default_write_multiplier_5m` when set.
    pub write_multiplier_5m: Option<Decimal>,
    /// Replaces `CachePricingConfig::default_write_multiplier_1h` when set.
    pub write_multiplier_1h: Option<Decimal>,
    /// Replaces `CachePricingConfig::default_write_multiplier_24h` when set.
    pub write_multiplier_24h: Option<Decimal>,
    /// Replaces `CachePricingConfig::default_read_multiplier` when set.
    pub read_multiplier: Option<Decimal>,
    /// Replaces `CachePricingConfig::default_min_prefix_tokens` when set.
    pub min_prefix_tokens: Option<i32>,
}
/// Data-access handle for the `model_cache_tariffs` table.
///
/// It borrows a Postgres connection mutably for its whole lifetime; every
/// statement issued by its methods runs on that connection.
pub struct CacheTariffs<'conn> {
    db: &'conn mut PgConnection,
}
impl<'c> CacheTariffs<'c> {
    /// Creates a repository that runs all of its statements on `db`.
    pub fn new(db: &'c mut PgConnection) -> Self {
        Self { db }
    }

    /// Replaces the currently active cache tariff of `model_id` with a new version.
    ///
    /// Inside one transaction this first closes every row that is active right now
    /// (via [`Self::expire_active`]) and then inserts a new row. Each value of the new
    /// row is taken from `overrides` when it is `Some`, otherwise from the matching
    /// `default_*` field of `defaults`. Superseded rows are closed, not deleted, so
    /// the tariff history stays queryable.
    ///
    /// # Errors
    ///
    /// Returns an error if opening the transaction, either statement, or the commit
    /// fails. In that case the transaction is dropped without being committed.
    #[instrument(skip(self, defaults, overrides), fields(deployed_model_id = %model_id), err)]
    pub async fn enable(
        &mut self,
        model_id: DeploymentId,
        defaults: &CachePricingConfig,
        overrides: CacheTariffOverrides,
    ) -> Result<()> {
        let mut tx = self.db.begin().await?;

        // PostgreSQL's `now()` is fixed at the start of the enclosing transaction, so the
        // `valid_until` written here matches any `now()`-based default on the new row's
        // `valid_from`. (That column default is not visible in this file; confirm it
        // against the schema.)
        Self::expire_active(&mut *tx, model_id).await?;

        // NOTE(review): two concurrent `enable` calls for the same model cannot see each
        // other's uncommitted INSERT. The slower one will likely fail with a unique
        // violation on the active-version index exercised by the tests
        // (idx_model_cache_tariffs_unique_active). Callers may need to retry.
        sqlx::query!(
            r#"INSERT INTO model_cache_tariffs
(deployed_model_id, write_multiplier_5m, write_multiplier_1h, write_multiplier_24h,
read_multiplier, min_prefix_tokens)
VALUES ($1, $2, $3, $4, $5, $6)"#,
            model_id,
            overrides.write_multiplier_5m.unwrap_or(defaults.default_write_multiplier_5m),
            overrides.write_multiplier_1h.unwrap_or(defaults.default_write_multiplier_1h),
            overrides.write_multiplier_24h.unwrap_or(defaults.default_write_multiplier_24h),
            overrides.read_multiplier.unwrap_or(defaults.default_read_multiplier),
            overrides.min_prefix_tokens.unwrap_or(defaults.default_min_prefix_tokens),
        )
        .execute(&mut *tx)
        .await?;

        tx.commit().await?;
        Ok(())
    }

    /// Closes the active cache tariff of `model_id`, keeping the row for history.
    ///
    /// Returns `true` if at least one active row was closed, and `false` if the model
    /// had no active tariff. This issues a single statement on the connection and
    /// does not open a transaction of its own.
    ///
    /// # Errors
    ///
    /// Returns an error if the UPDATE fails.
    #[instrument(skip(self), fields(deployed_model_id = %model_id), err)]
    pub async fn disable(&mut self, model_id: DeploymentId) -> Result<bool> {
        let closed = Self::expire_active(&mut *self.db, model_id).await?;
        Ok(closed > 0)
    }

    /// Sets `valid_until = now()` on every row of `model_id` that is active at `now()`
    /// and returns how many rows were closed.
    ///
    /// A row is active when `valid_from <= now()` and `valid_until` is either unset or
    /// still in the future. Rows whose `valid_from` lies in the future are left
    /// untouched. Both [`Self::enable`] and [`Self::disable`] go through this single
    /// statement, so they cannot disagree on what "active" means.
    async fn expire_active(conn: &mut PgConnection, model_id: DeploymentId) -> Result<u64> {
        let res = sqlx::query!(
            r#"UPDATE model_cache_tariffs SET valid_until = now()
WHERE deployed_model_id = $1
AND valid_from <= now()
AND (valid_until IS NULL OR valid_until > now())"#,
            model_id,
        )
        .execute(conn)
        .await?;
        Ok(res.rows_affected())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompt_cache::ModelConfigResolver;
    use crate::test::utils::{create_test_endpoint, create_test_model, create_test_user};
    use sqlx::PgPool;

    /// Enabling makes the resolver report caching as enabled, with the default
    /// `min_prefix_tokens`. Disabling turns it off again and reports that a row was closed.
    #[sqlx::test]
    async fn enable_then_disable_toggles_resolver(pool: PgPool) {
        let user = create_test_user(&pool, crate::api::models::users::Role::StandardUser).await;
        let endpoint = create_test_endpoint(&pool, "ep", user.id).await;
        let id = create_test_model(&pool, "m", "enable-alias", endpoint, user.id).await;
        let defaults = CachePricingConfig::default();
        let resolver = ModelConfigResolver::new(pool.clone());
        assert!(!resolver.resolve("enable-alias").await.unwrap().enabled);
        {
            let mut conn = pool.acquire().await.unwrap();
            CacheTariffs::new(&mut conn)
                .enable(id, &defaults, CacheTariffOverrides::default())
                .await
                .unwrap();
        }
        // A fresh resolver is built after each write (presumably so nothing cached
        // earlier can mask the change; confirm against ModelConfigResolver).
        let cfg = ModelConfigResolver::new(pool.clone()).resolve("enable-alias").await.unwrap();
        assert!(cfg.enabled);
        assert_eq!(cfg.min_prefix_tokens, defaults.default_min_prefix_tokens.max(0) as u32);
        {
            let mut conn = pool.acquire().await.unwrap();
            assert!(CacheTariffs::new(&mut conn).disable(id).await.unwrap());
        }
        assert!(!ModelConfigResolver::new(pool).resolve("enable-alias").await.unwrap().enabled);
    }

    /// A second `enable` closes the first version instead of deleting it. The
    /// surviving active version carries the new override, and the superseded version
    /// keeps the values it was created with.
    #[sqlx::test]
    async fn enable_twice_supersedes_keeping_history(pool: PgPool) {
        let user = create_test_user(&pool, crate::api::models::users::Role::StandardUser).await;
        let endpoint = create_test_endpoint(&pool, "ep", user.id).await;
        let id = create_test_model(&pool, "m", "reprice-alias", endpoint, user.id).await;
        let defaults = CachePricingConfig::default();
        let mut conn = pool.acquire().await.unwrap();
        let mut repo = CacheTariffs::new(&mut conn);
        repo.enable(id, &defaults, CacheTariffOverrides::default()).await.unwrap();
        repo.enable(
            id,
            &defaults,
            CacheTariffOverrides {
                min_prefix_tokens: Some(2048),
                ..Default::default()
            },
        )
        .await
        .unwrap();
        let total = sqlx::query_scalar!("SELECT COUNT(*) FROM model_cache_tariffs WHERE deployed_model_id = $1", id)
            .fetch_one(&mut *conn)
            .await
            .unwrap();
        let active = sqlx::query_scalar!(
            "SELECT COUNT(*) FROM model_cache_tariffs WHERE deployed_model_id = $1 AND valid_until IS NULL",
            id
        )
        .fetch_one(&mut *conn)
        .await
        .unwrap();
        assert_eq!(total, Some(2), "old version retained for audit");
        assert_eq!(active, Some(1), "exactly one active version");

        // Runtime (non-macro) queries, so no new offline query-cache entries are needed.
        let active_min_prefix = sqlx::query_scalar::<_, i32>(
            "SELECT min_prefix_tokens FROM model_cache_tariffs WHERE deployed_model_id = $1 AND valid_until IS NULL",
        )
        .bind(id)
        .fetch_one(&mut *conn)
        .await
        .unwrap();
        assert_eq!(active_min_prefix, 2048, "active version carries the latest override");

        let superseded_min_prefix = sqlx::query_scalar::<_, i32>(
            "SELECT min_prefix_tokens FROM model_cache_tariffs WHERE deployed_model_id = $1 AND valid_until IS NOT NULL",
        )
        .bind(id)
        .fetch_one(&mut *conn)
        .await
        .unwrap();
        assert_eq!(
            superseded_min_prefix, defaults.default_min_prefix_tokens,
            "superseded version keeps its original values"
        );
    }

    /// The database itself (not only the repository) refuses a second row for the
    /// same model while one with `valid_until IS NULL` already exists.
    #[sqlx::test]
    async fn partial_unique_index_rejects_two_active_versions(pool: PgPool) {
        let user = create_test_user(&pool, crate::api::models::users::Role::StandardUser).await;
        let endpoint = create_test_endpoint(&pool, "ep", user.id).await;
        let id = create_test_model(&pool, "m", "dup-active-alias", endpoint, user.id).await;
        const INSERT_ACTIVE: &str = "INSERT INTO model_cache_tariffs
(deployed_model_id, write_multiplier_5m, write_multiplier_1h, write_multiplier_24h, min_prefix_tokens)
VALUES ($1, 1.25, 2.0, 2.5, 1024)";
        sqlx::query(INSERT_ACTIVE)
            .bind(id)
            .execute(&pool)
            .await
            .expect("first active row inserts");
        let err = sqlx::query(INSERT_ACTIVE)
            .bind(id)
            .execute(&pool)
            .await
            .expect_err("second active row must be rejected");
        assert!(
            err.as_database_error().is_some_and(|e| e.is_unique_violation()),
            "expected a unique violation from idx_model_cache_tariffs_unique_active, got: {err:?}"
        );
    }
}