use std::time::Duration;
use moka::future::Cache;
use sqlx::PgPool;
use super::index::CacheResult;
/// Prompt-cache settings for a single model alias.
///
/// `ModelConfigResolver::resolve` builds this from the model's currently
/// valid cache tariffs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelCacheConfig {
    /// `true` when at least one active cache tariff exists for the model.
    pub enabled: bool,
    /// Prefix floor in tokens, taken from the tariffs' `min_prefix_tokens`
    /// column. Presumably the shortest prompt prefix that is eligible for
    /// caching; confirm against callers.
    pub min_prefix_tokens: u32,
}

impl ModelCacheConfig {
    /// Configuration used when caching is unavailable: unknown alias or no
    /// active tariff.
    ///
    /// The floor is pinned to `u32::MAX` as a sentinel that no realistic
    /// prefix length reaches.
    pub const DISABLED: Self = Self {
        min_prefix_tokens: u32::MAX,
        enabled: false,
    };
}
/// Looks up per-alias prompt-cache settings in Postgres and memoises them.
///
/// moka's `Cache` and sqlx's `PgPool` are handle types, so a cloned resolver
/// shares the same connection pool and the same in-memory cache.
#[derive(Clone)]
pub struct ModelConfigResolver {
    /// Connection pool used for tariff lookups on a cache miss.
    pool: PgPool,
    /// Alias -> resolved configuration, including cached `DISABLED` results.
    cache: Cache<String, ModelCacheConfig>,
}
impl ModelConfigResolver {
    /// Entry bound used by [`ModelConfigResolver::new`].
    const DEFAULT_MAX_CAPACITY: u64 = 10_000;
    /// Time-to-live used by [`ModelConfigResolver::new`].
    const DEFAULT_TTL: Duration = Duration::from_secs(60);

    /// Creates a resolver backed by `pool` with the default cache limits.
    ///
    /// The defaults are at most 10 000 aliases, each kept for 60 seconds
    /// after insertion.
    pub fn new(pool: PgPool) -> Self {
        Self::with_cache_limits(pool, Self::DEFAULT_MAX_CAPACITY, Self::DEFAULT_TTL)
    }

    /// Creates a resolver backed by `pool` with explicit cache limits.
    ///
    /// The lookup cache holds at most `max_capacity` aliases and keeps each
    /// resolved configuration for `ttl` after it was inserted.
    pub fn with_cache_limits(pool: PgPool, max_capacity: u64, ttl: Duration) -> Self {
        let cache = Cache::builder()
            .max_capacity(max_capacity)
            .time_to_live(ttl)
            .build();
        Self { pool, cache }
    }

    /// Returns the prompt-cache configuration for the model exposed as
    /// `virtual_model`.
    ///
    /// A cached value is returned when present. Otherwise the database is
    /// asked for the smallest `min_prefix_tokens` among the currently valid
    /// `model_cache_tariffs` rows of the non-deleted model with that alias:
    ///
    /// - If such a value exists, caching is enabled with that floor.
    /// - If the alias is unknown or the model has no active tariff, the
    ///   result is [`ModelCacheConfig::DISABLED`].
    ///
    /// Both outcomes are cached.
    ///
    /// NOTE: lookups are not coalesced. Concurrent misses for the same alias
    /// each query the database before the first result is inserted. moka's
    /// `try_get_with` could deduplicate them if that ever matters.
    ///
    /// # Errors
    ///
    /// Returns the crate's cache error, converted with `?`, when the database
    /// query fails.
    pub async fn resolve(&self, virtual_model: &str) -> CacheResult<ModelCacheConfig> {
        if let Some(cached) = self.cache.get(virtual_model).await {
            return Ok(cached);
        }

        // `MIN` without `GROUP BY` always yields exactly one row. `min_prefix`
        // is NULL when no tariff row matches.
        let row = sqlx::query!(
            r#"
SELECT MIN(mct.min_prefix_tokens) AS min_prefix
FROM deployed_models dm
JOIN model_cache_tariffs mct
ON mct.deployed_model_id = dm.id
AND mct.valid_from <= now()
AND (mct.valid_until IS NULL OR mct.valid_until > now())
WHERE dm.alias = $1 AND dm.deleted = false
"#,
            virtual_model,
        )
        .fetch_one(&self.pool)
        .await?;

        let config = match row.min_prefix {
            // Negative floors clamp to 0. Values too large for `u32` saturate
            // at `u32::MAX` instead of wrapping to an arbitrary small number.
            Some(floor) => ModelCacheConfig {
                enabled: true,
                min_prefix_tokens: u32::try_from(floor.max(0)).unwrap_or(u32::MAX),
            },
            // No active tariff, or no non-deleted model with this alias.
            None => ModelCacheConfig::DISABLED,
        };

        self.cache.insert(virtual_model.to_owned(), config).await;
        Ok(config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::models::users::Role;
    use crate::test::utils::{create_test_endpoint, create_test_model, create_test_user};

    /// Inserts one cache tariff row for `model_id` with the given prefix floor.
    ///
    /// With `expired == true` the row's `valid_until` is set one hour in the
    /// past. Otherwise it is left open-ended (`NULL`).
    async fn add_tariff(pool: &PgPool, model_id: uuid::Uuid, min_prefix: i32, expired: bool) {
        sqlx::query!(
            r#"INSERT INTO model_cache_tariffs
(deployed_model_id, write_multiplier_5m, write_multiplier_1h, write_multiplier_24h, min_prefix_tokens, valid_until)
VALUES ($1, 1.25, 2.0, 2.5, $2, CASE WHEN $3 THEN now() - interval '1 hour' ELSE NULL END)"#,
            model_id,
            min_prefix,
            expired,
        )
        .execute(pool)
        .await
        .unwrap();
    }

    /// Seeds a standard user, an endpoint named `"ep"`, and a deployed model
    /// `name` published under `alias`. Returns the model's id.
    async fn seed_model(pool: &PgPool, name: &'static str, alias: &'static str) -> uuid::Uuid {
        let owner = create_test_user(pool, Role::StandardUser).await;
        let endpoint = create_test_endpoint(pool, "ep", owner.id).await;
        create_test_model(pool, name, alias, endpoint, owner.id).await
    }

    #[sqlx::test]
    async fn disabled_without_tariff_and_unknown_model(pool: PgPool) {
        let _ = seed_model(&pool, "m1", "alias-default").await;
        let resolver = ModelConfigResolver::new(pool);

        let known = resolver.resolve("alias-default").await.unwrap();
        assert!(!known.enabled);

        let unknown = resolver.resolve("nope").await.unwrap();
        assert_eq!(unknown, ModelCacheConfig::DISABLED);
    }

    #[sqlx::test]
    async fn active_tariff_enables_with_its_floor(pool: PgPool) {
        let model_id = seed_model(&pool, "m2", "alias-on").await;
        add_tariff(&pool, model_id, 2048, false).await;

        let resolver = ModelConfigResolver::new(pool);
        let cfg = resolver.resolve("alias-on").await.unwrap();
        assert!(cfg.enabled);
        assert_eq!(cfg.min_prefix_tokens, 2048);
    }

    #[sqlx::test]
    async fn expired_tariff_is_disabled(pool: PgPool) {
        let model_id = seed_model(&pool, "m3", "alias-expired").await;
        add_tariff(&pool, model_id, 1024, true).await;

        let resolver = ModelConfigResolver::new(pool);
        let cfg = resolver.resolve("alias-expired").await.unwrap();
        assert!(!cfg.enabled, "an expired tariff version no longer enables caching");
    }
}