use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use chrono::Utc;
use super::index::{CacheEntry, CacheIndex, CacheResult, IndexScope, PrefixHash};
use super::model_config::ModelConfigResolver;
use super::parse::parse_chat_completions;
use super::principal::PrincipalResolver;
use super::stats::{CacheStats, PendingWrite};
use super::tokenizer::TokenizerClient;
/// Borrowed view of one inbound chat-completions request, as needed by
/// [`Classifier::classify`].
///
/// Deliberately does not derive `Debug`: it carries the caller's raw API key.
pub struct ClassifyRequest<'a> {
    /// Virtual model alias the request targets.
    pub virtual_model: &'a str,
    /// Raw JSON request body.
    pub body: &'a [u8],
    /// API key presented by the caller; `None` makes the request ineligible for caching.
    pub api_key: Option<&'a str>,
}
/// Result of [`Classifier::classify`]: the cache accounting for one request
/// plus the index mutations to apply later via `Classifier::commit`.
pub struct ClassifyOutcome {
    /// Cached-read and cache-creation token counts for this request.
    pub stats: CacheStats,
    /// Index writes and/or TTL refresh, not yet persisted.
    pub pending: PendingWrite,
    /// `false` when caching does not apply at all (no API key, unknown
    /// principal, or caching disabled for the model); `true` otherwise, even
    /// when `stats` are all zero.
    pub active: bool,
}

impl ClassifyOutcome {
    /// Caching does not apply: zero stats, nothing pending, `active == false`.
    pub(crate) fn inactive() -> Self {
        Self {
            active: false,
            ..Self::zero_active()
        }
    }

    /// Caching applies, but this request neither reads nor writes anything.
    fn zero_active() -> Self {
        Self::active(CacheStats::default(), PendingWrite::default())
    }

    /// Caching applies, with the given accounting and pending index mutations.
    fn active(stats: CacheStats, pending: PendingWrite) -> Self {
        Self {
            active: true,
            stats,
            pending,
        }
    }
}
/// Decides, per request, how many prompt tokens are served from the prefix
/// cache and which new prefix entries should be recorded in the index.
#[derive(Clone)]
pub struct Classifier {
    /// Resolves an API key to a principal id.
    principal: PrincipalResolver,
    /// Per-virtual-model cache configuration (enabled flag, minimum prefix size).
    model_config: ModelConfigResolver,
    /// Tokenizer service client: token counts and per-model tokenizer versions.
    tokenizer: TokenizerClient,
    /// Store of cached prefix entries (lookup / write / refresh).
    index: Arc<dyn CacheIndex>,
    /// Alias -> tokenizer version; `None` records an alias the tokenizer
    /// service did not list. Bounded and time-limited (configured in `new`).
    versions: moka::future::Cache<String, Option<String>>,
}
impl Classifier {
pub fn new(
principal: PrincipalResolver,
model_config: ModelConfigResolver,
tokenizer: TokenizerClient,
index: Arc<dyn CacheIndex>,
) -> Self {
let versions = moka::future::Cache::builder()
.max_capacity(10_000)
.time_to_live(std::time::Duration::from_secs(300))
.build();
Self {
principal,
model_config,
tokenizer,
index,
versions,
}
}
pub async fn classify(&self, req: ClassifyRequest<'_>) -> CacheResult<ClassifyOutcome> {
let Some(api_key) = req.api_key else {
return Ok(ClassifyOutcome::inactive());
};
let Some(principal_id) = self.principal.resolve(api_key).await? else {
return Ok(ClassifyOutcome::inactive());
};
let cfg = self.model_config.resolve(req.virtual_model).await?;
if !cfg.enabled {
return Ok(ClassifyOutcome::inactive());
}
let parsed = match parse_chat_completions(req.body) {
Ok(p) => p,
Err(e) => {
tracing::debug!(error = %e, virtual_model = req.virtual_model, "cache classify: body not cacheable (unparseable / >4 breakpoints)");
return Ok(ClassifyOutcome::zero_active());
}
};
if parsed.breakpoints.is_empty() {
return Ok(ClassifyOutcome::zero_active()); }
let Some(tokenizer_version) = self.tokenizer_version(req.virtual_model).await? else {
return Ok(ClassifyOutcome::zero_active()); };
let scope = IndexScope {
principal_id,
virtual_model: req.virtual_model.to_string(),
tokenizer_version,
};
let read = self.find_longest_read(&scope, &parsed).await?;
let read_block = read.as_ref().map(|r| r.block); let read_tokens = read.as_ref().map(|r| r.tokens).unwrap_or(0);
let Some(deepest_bp) = parsed.breakpoints.last() else {
return Ok(ClassifyOutcome::zero_active());
};
let deepest = deepest_bp.block_index;
let mut stats = CacheStats {
read: read_tokens as u64,
..Default::default()
};
let mut pending = PendingWrite::default();
if let Some(r) = &read {
pending.refresh = Some((scope.clone(), r.hash.clone(), Utc::now() + r.duration));
}
if read_block == Some(deepest) {
return Ok(ClassifyOutcome::active(stats, pending));
}
let write_start = read_block.map(|b| b + 1).unwrap_or(0);
let segments: Vec<String> = parsed.blocks[write_start..=deepest].iter().map(|b| b.text.clone()).collect();
let tok = match self.tokenizer.tokenize(req.virtual_model, &segments).await {
Ok(tok) => tok,
Err(e) => {
tracing::debug!(error = %e, virtual_model = req.virtual_model, "cache classify: tokenize failed, degrading to no write");
return Ok(ClassifyOutcome::zero_active());
}
};
if tok.cumulative.len() != segments.len() {
tracing::debug!(
segments = segments.len(),
cumulative = tok.cumulative.len(),
virtual_model = req.virtual_model,
"cache classify: tokenizer segment-count mismatch, degrading to no write"
);
return Ok(ClassifyOutcome::zero_active());
}
let cumulative_at = |block: usize| -> u64 { read_tokens as u64 + tok.cumulative[block - write_start] as u64 };
let total_prefix = cumulative_at(deepest);
if total_prefix < cfg.min_prefix_tokens as u64 {
return Ok(ClassifyOutcome::zero_active()); }
let mut prev_boundary = read_tokens as u64;
let now = Utc::now();
let read_block_idx: isize = read_block.map(|b| b as isize).unwrap_or(-1);
for bp in parsed.breakpoints.iter().filter(|bp| bp.block_index as isize > read_block_idx) {
let bp_cumulative = cumulative_at(bp.block_index);
let segment_tokens = bp_cumulative.saturating_sub(prev_boundary);
stats.add_creation(bp.ttl_tier, segment_tokens);
pending.writes.push(CacheEntry {
scope: scope.clone(),
prefix_hash: parsed.cumulative_hashes[bp.block_index].clone(),
cumulative_token_count: bp_cumulative.min(u32::MAX as u64) as u32,
ttl_tier: bp.ttl_tier,
expires_at: now + bp.ttl_tier.duration(),
});
prev_boundary = bp_cumulative;
}
Ok(ClassifyOutcome::active(stats, pending))
}
pub async fn commit(&self, pending: &PendingWrite) -> CacheResult<()> {
for entry in &pending.writes {
self.index.write(entry).await?;
}
if let Some((scope, hash, new_expires_at)) = &pending.refresh {
self.index.refresh(scope, hash, *new_expires_at).await?;
}
Ok(())
}
async fn tokenizer_version(&self, alias: &str) -> CacheResult<Option<String>> {
if let Some(v) = self.versions.get(alias).await {
return Ok(v);
}
let Ok(models) = self.tokenizer.models().await else {
return Ok(None);
};
let mut found = None;
for m in models {
if m.alias == alias {
found = Some(m.tokenizer_version.clone());
}
self.versions.insert(m.alias, Some(m.tokenizer_version)).await;
}
if found.is_none() {
self.versions.insert(alias.to_string(), None).await;
}
Ok(found)
}
async fn find_longest_read(&self, scope: &IndexScope, parsed: &super::parse::ParsedPrompt) -> CacheResult<Option<ReadHit>> {
let mut candidates: Vec<PrefixHash> = Vec::new();
let mut seen: HashSet<PrefixHash> = HashSet::new();
for bp in &parsed.breakpoints {
for h in parsed.read_candidates(bp) {
if seen.insert(h.clone()) {
candidates.push(h);
}
}
}
let matches = self.index.lookup(scope, &candidates).await?;
if matches.is_empty() {
return Ok(None);
}
let hash_to_block: HashMap<&[u8], usize> = parsed
.cumulative_hashes
.iter()
.enumerate()
.map(|(i, h)| (h.as_slice(), i))
.collect();
let mut best: Option<ReadHit> = None;
for m in matches {
if let Some(&block) = hash_to_block.get(m.prefix_hash.as_slice())
&& best.as_ref().is_none_or(|b| block > b.block)
{
best = Some(ReadHit {
block,
tokens: m.cumulative_token_count,
hash: m.prefix_hash,
duration: m.ttl_tier.duration(),
});
}
}
Ok(best)
}
}
/// The deepest prefix of the current prompt found in the cache index.
struct ReadHit {
    /// Index (into the parsed prompt's blocks / cumulative hashes) of the
    /// last block covered by the matched prefix.
    block: usize,
    /// Cumulative token count stored with the matched entry.
    tokens: u32,
    /// Prefix hash of the matched entry; used as the key for its TTL refresh.
    hash: PrefixHash,
    /// TTL of the matched entry's tier; the refresh extends expiry by this much from "now".
    duration: chrono::Duration,
}
// Integration tests for `Classifier`. Each `#[sqlx::test]` receives its own
// `PgPool`; a wiremock `MockServer` stands in for the tokenizer service.
#[cfg(test)]
mod tests {
use super::*;
use crate::api::models::users::Role;
use crate::prompt_cache::{CacheEntry, IndexScope, PostgresIndex, TtlTier, parse_chat_completions};
use crate::test::utils::{create_test_api_key_for_user, create_test_endpoint, create_test_model, create_test_user};
use sqlx::PgPool;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
// Virtual model alias used for every request and mock response.
const ALIAS: &str = "cache-model";
// Tokenizer version the mocked `/v1/models` endpoint reports for `ALIAS`.
const TOK_VER: &str = "sha256:v1";
// Request body: a system block carrying a 1h `cache_control` breakpoint,
// followed by an unmarked user turn.
fn body() -> Vec<u8> {
serde_json::json!({
"model": ALIAS,
"messages": [
{"role":"system","content":[
{"type":"text","text":"a long static system prompt","cache_control":{"type":"ephemeral","ttl":"1h"}}
]},
{"role":"user","content":"hello"}
]
})
.to_string()
.into_bytes()
}
// Cumulative hash of block 0 (the marked system prompt): the key a write
// for `body()` is expected to use.
fn prefix_hash() -> PrefixHash {
parse_chat_completions(&body()).unwrap().cumulative_hashes[0].clone()
}
// Per-test fixture. `_server` is held so the mock tokenizer stays up for
// as long as the classifier that points at it.
struct H {
classifier: Classifier,
secret: String,
principal_id: uuid::Uuid,
pool: PgPool,
_server: MockServer,
}
// Builds a user, API key, endpoint and model. When `enabled` is true it also
// inserts a `model_cache_tariffs` row with floor `min_prefix`; that row
// appears to be what enables caching (compare `disabled_model_is_inactive`).
// The mocked `/v1/tokenize` always answers with a single cumulative count of
// `tokenize_total`, whatever the input, so only single-segment writes are
// exercised here.
async fn harness(pool: &PgPool, enabled: bool, tokenize_total: u32, min_prefix: i32) -> H {
let user = create_test_user(pool, Role::StandardUser).await;
let key = create_test_api_key_for_user(pool, user.id).await;
let endpoint = create_test_endpoint(pool, "ep", user.id).await;
let id = create_test_model(pool, "m", ALIAS, endpoint, user.id).await;
if enabled {
sqlx::query!(
r#"INSERT INTO model_cache_tariffs
(deployed_model_id, write_multiplier_5m, write_multiplier_1h, write_multiplier_24h, min_prefix_tokens)
VALUES ($1, 1.25, 2.0, 2.5, $2)"#,
id,
min_prefix
)
.execute(pool)
.await
.unwrap();
}
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/v1/models"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"models": [{"alias": ALIAS, "hf_repo": "org/m", "tokenizer_version": TOK_VER}]
})))
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/v1/tokenize"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"virtual_model": ALIAS, "tokenizer_version": TOK_VER,
"segment_counts": [tokenize_total], "cumulative": [tokenize_total], "total": tokenize_total
})))
.mount(&server)
.await;
let classifier = Classifier::new(
PrincipalResolver::new(pool.clone()),
ModelConfigResolver::new(pool.clone()),
TokenizerClient::new(server.uri()),
Arc::new(PostgresIndex::new(pool.clone())),
);
H {
classifier,
secret: key.secret,
principal_id: user.id,
pool: pool.clone(),
_server: server,
}
}
// Builds a `ClassifyRequest` for `ALIAS` authenticated with `secret`.
fn req<'a>(secret: &'a str, body: &'a [u8]) -> ClassifyRequest<'a> {
ClassifyRequest {
virtual_model: ALIAS,
body,
api_key: Some(secret),
}
}
// Empty index, prefix above the floor: every token is a 1h creation and a
// single write keyed by `prefix_hash()` is queued, with no refresh.
#[sqlx::test]
async fn no_prior_entry_is_all_creation(pool: PgPool) {
let h = harness(&pool, true, 1500, 1024).await;
let b = body();
let ClassifyOutcome { stats, pending, active } = h.classifier.classify(req(&h.secret, &b)).await.unwrap();
assert!(active, "enabled model is active");
assert_eq!(stats.read, 0);
assert_eq!(stats.creation_1h, 1500);
assert_eq!(stats.creation_total(), 1500);
assert_eq!(pending.writes.len(), 1);
assert_eq!(pending.writes[0].cumulative_token_count, 1500);
assert_eq!(pending.writes[0].ttl_tier, TtlTier::OneHour);
assert_eq!(pending.writes[0].prefix_hash, prefix_hash());
assert!(pending.refresh.is_none());
}
// An entry for the deepest breakpoint is pre-seeded: the whole prefix is a
// read, nothing is written, and a TTL refresh is queued.
#[sqlx::test]
async fn read_hit_is_pure_read(pool: PgPool) {
let h = harness(&pool, true, 1500, 1024).await;
let scope = IndexScope {
principal_id: h.principal_id,
virtual_model: ALIAS.to_string(),
tokenizer_version: TOK_VER.to_string(),
};
PostgresIndex::new(h.pool.clone())
.write(&CacheEntry {
scope: scope.clone(),
prefix_hash: prefix_hash(),
cumulative_token_count: 1500,
ttl_tier: TtlTier::OneHour,
expires_at: Utc::now() + chrono::Duration::hours(1),
})
.await
.unwrap();
let b = body();
let ClassifyOutcome { stats, pending, active } = h.classifier.classify(req(&h.secret, &b)).await.unwrap();
assert!(active);
assert_eq!(stats.read, 1500);
assert_eq!(stats.creation_total(), 0, "a full read writes nothing");
assert!(pending.writes.is_empty());
assert!(pending.refresh.is_some(), "a read slides the entry's TTL");
}
// 500 tokens against a 1024-token floor: still active, but no stats and nothing pending.
#[sqlx::test]
async fn below_floor_is_no_cache(pool: PgPool) {
let h = harness(&pool, true, 500, 1024).await; let b = body();
let out = h.classifier.classify(req(&h.secret, &b)).await.unwrap();
assert!(out.active, "an enabled model stays active even below the floor");
assert!(out.stats.is_zero());
assert!(out.pending.is_empty());
}
// No tariff row inserted: the outcome is inactive.
#[sqlx::test]
async fn disabled_model_is_inactive(pool: PgPool) {
let h = harness(&pool, false, 1500, 1024).await; let b = body();
let out = h.classifier.classify(req(&h.secret, &b)).await.unwrap();
assert!(!out.active, "a disabled model is inactive → response left untouched");
assert!(out.stats.is_zero());
assert!(out.pending.is_empty());
}
// A body without any `cache_control` markers: active, but zero stats and nothing pending.
#[sqlx::test]
async fn no_markers_is_zero_active(pool: PgPool) {
let h = harness(&pool, true, 1500, 1024).await;
let b = serde_json::json!({
"model": ALIAS,
"messages": [{"role":"user","content":"hi, no markers here"}]
})
.to_string()
.into_bytes();
let out = h.classifier.classify(req(&h.secret, &b)).await.unwrap();
assert!(out.active, "enabled model with no markers still presents zero cache fields");
assert!(out.stats.is_zero());
assert!(out.pending.is_empty());
}
}