use sha2::{Digest, Sha256};
use super::index::{TierPolicy, TtlTier};
/// Maximum number of `cache_control` markers accepted in one request; exceeding it
/// produces `ParseError::TooManyBreakpoints`.
pub const MAX_BREAKPOINTS: usize = 4;
/// Upper bound on the number of cumulative hashes `ParsedPrompt::read_candidates`
/// returns for a single breakpoint.
pub const WALK_BACK: usize = 20;
/// Errors produced while parsing a request body or validating its `cache_control` markers.
#[derive(Debug, thiserror::Error)]
pub enum ParseError {
/// The request body could not be deserialized as JSON.
#[error("request body is not valid JSON: {0}")]
Json(#[from] serde_json::Error),
/// The request carries more than `MAX_BREAKPOINTS` non-null `cache_control` markers.
#[error("too many cache_control breakpoints (max {MAX_BREAKPOINTS})")]
TooManyBreakpoints,
/// The `ttl` string was rejected by `TtlTier::parse`; carries the offending value.
#[error("invalid cache_control ttl: {0:?}")]
InvalidTtl(String),
/// The `type` string is something other than `"ephemeral"`; carries the offending value.
#[error("unsupported cache_control type: {0:?} (only \"ephemeral\")")]
UnsupportedType(String),
/// The requested (or default) tier is not enabled by the active `TierPolicy`.
#[error("cache_control ttl tier '{}' is not currently available", .0.as_str())]
DisabledTier(TtlTier),
/// `cache_control` is not an object, its `type` is missing or not a string, or its
/// `ttl` is present but not a string.
#[error("cache_control must be an object with a string \"type\": \"ephemeral\" (and an optional string \"ttl\")")]
MalformedCacheControl,
}
/// One prompt block: either a whole string-content message or one element of an
/// array-content message.
#[derive(Debug, Clone)]
pub struct Block {
/// Role of the message the block came from; empty when the message has no string `role`.
pub role: String,
/// The message string, or the block's `text` field (empty when that field is absent
/// or not a string).
pub text: String,
}
/// A `cache_control` marker found on a content block.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Breakpoint {
/// Position of the marked block within `ParsedPrompt::blocks`.
pub block_index: usize,
/// Tier resolved from the marker's `ttl`, or the policy default when `ttl` is omitted.
pub ttl_tier: TtlTier,
}
/// Result of `parse_chat_completions`: the prompt's blocks, their running hashes, and
/// the cache breakpoints found on them.
#[derive(Debug, Clone)]
pub struct ParsedPrompt {
/// Blocks in request order.
pub blocks: Vec<Block>,
/// One entry per block; entry `i` is the SHA-256 digest of the canonical bytes of
/// blocks `0..=i` fed into a single running hasher.
pub cumulative_hashes: Vec<Vec<u8>>,
/// Breakpoints in the order their blocks appear in the request.
pub breakpoints: Vec<Breakpoint>,
}
impl ParsedPrompt {
    /// Returns the candidate prefix hashes for `bp`, longest prefix first.
    ///
    /// The first entry is the cumulative hash at `bp.block_index`; each following entry
    /// covers one block fewer. At most [`WALK_BACK`] hashes are returned, fewer when the
    /// breakpoint sits closer than that to the start of the prompt.
    ///
    /// `Breakpoint` and `ParsedPrompt` both have public fields, so `bp` is not guaranteed
    /// to belong to this prompt. An index past the end of `cumulative_hashes` yields an
    /// empty list rather than a panic.
    pub fn read_candidates(&self, bp: &Breakpoint) -> Vec<Vec<u8>> {
        self.cumulative_hashes
            // Checked slicing: `None` when `block_index` is out of range.
            .get(..=bp.block_index)
            .map(|prefix| {
                prefix
                    .iter()
                    .rev()
                    .take(WALK_BACK)
                    .cloned()
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default()
    }
}
/// Parses a chat-completions request body into prompt blocks, cumulative prefix hashes,
/// and `cache_control` breakpoints.
///
/// A message whose `content` is a string contributes one block; each element of an array
/// `content` contributes one block. Messages with any other `content` (or none) are
/// skipped, and a body without a `messages` array yields an empty prompt. A missing or
/// non-string `role` is treated as the empty string.
///
/// Every block is hashed as its role, a NUL byte, and the block's JSON with
/// `cache_control` removed, so adding or removing a marker never changes a hash. String
/// content is hashed as the text block `{"type": "text", "text": <content>}`.
///
/// # Errors
///
/// * [`ParseError::Json`] when `body` is not valid JSON.
/// * Any error from validating a non-null `cache_control` marker (malformed marker,
///   unsupported type, invalid or disabled ttl tier).
/// * [`ParseError::TooManyBreakpoints`] as soon as a valid marker beyond
///   [`MAX_BREAKPOINTS`] is encountered. The check runs right after that marker has been
///   validated — the same order `validate_markers` uses — so both functions report the
///   same error for the same body and no further blocks are hashed for a request that
///   is going to be rejected.
pub fn parse_chat_completions(body: &[u8], policy: &TierPolicy) -> Result<ParsedPrompt, ParseError> {
    let v: serde_json::Value = serde_json::from_slice(body)?;
    let mut blocks: Vec<Block> = Vec::new();
    let mut breakpoints: Vec<Breakpoint> = Vec::new();
    let mut cumulative_hashes: Vec<Vec<u8>> = Vec::new();
    let mut hasher = Sha256::new();

    // Feeds one block into the running digest, snapshots the digest, stores the block,
    // and returns the block's index.
    let mut record_block = |role: &str, text: String, hashed: &serde_json::Value| -> usize {
        let canonical = canonical_block_bytes(role, hashed);
        hasher.update(&canonical);
        // Clone before finalizing so the running hasher keeps accumulating.
        cumulative_hashes.push(hasher.clone().finalize().to_vec());
        blocks.push(Block { role: role.to_owned(), text });
        blocks.len() - 1
    };

    let messages = v.get("messages").and_then(|m| m.as_array());
    for msg in messages.into_iter().flatten() {
        let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("");
        match msg.get("content") {
            Some(serde_json::Value::String(s)) => {
                // Hash plain-string content exactly like the equivalent text block.
                let as_text_block = serde_json::json!({ "type": "text", "text": s });
                record_block(role, s.clone(), &as_text_block);
            }
            Some(serde_json::Value::Array(arr)) => {
                for block in arr {
                    // A null `cache_control` is treated as "no marker".
                    let ttl = match block.get("cache_control") {
                        Some(cc) if !cc.is_null() => Some(parse_ttl(cc, policy)?),
                        _ => None,
                    };
                    // Fail fast on the first marker over the limit.
                    if ttl.is_some() && breakpoints.len() >= MAX_BREAKPOINTS {
                        return Err(ParseError::TooManyBreakpoints);
                    }
                    let stripped = strip_cache_control(block);
                    let text = stripped.get("text").and_then(|t| t.as_str()).unwrap_or("").to_string();
                    let block_index = record_block(role, text, &stripped);
                    if let Some(ttl_tier) = ttl {
                        breakpoints.push(Breakpoint { block_index, ttl_tier });
                    }
                }
            }
            _ => {}
        }
    }

    Ok(ParsedPrompt {
        blocks,
        cumulative_hashes,
        breakpoints,
    })
}
/// Validates one non-null `cache_control` value and resolves its TTL tier.
///
/// Checks run in this order, and the first failure is returned:
/// 1. the value must be an object (`MalformedCacheControl`);
/// 2. `type` must be a string (`MalformedCacheControl`) equal to `"ephemeral"`
///    (`UnsupportedType` otherwise);
/// 3. `ttl`, when present, must be a string (`MalformedCacheControl`) accepted by
///    `TtlTier::parse` (`InvalidTtl` otherwise); when absent the policy default is used;
/// 4. the resulting tier must be enabled by `policy` (`DisabledTier`).
fn parse_ttl(cache_control: &serde_json::Value, policy: &TierPolicy) -> Result<TtlTier, ParseError> {
    use serde_json::Value;

    let fields = cache_control
        .as_object()
        .ok_or(ParseError::MalformedCacheControl)?;

    let kind = fields
        .get("type")
        .and_then(|t| t.as_str())
        .ok_or(ParseError::MalformedCacheControl)?;
    if kind != "ephemeral" {
        return Err(ParseError::UnsupportedType(kind.to_string()));
    }

    let tier = match fields.get("ttl") {
        Some(Value::String(label)) => match TtlTier::parse(label) {
            Some(parsed) => parsed,
            None => return Err(ParseError::InvalidTtl(label.clone())),
        },
        Some(_) => return Err(ParseError::MalformedCacheControl),
        None => policy.default_ttl(),
    };

    if policy.is_enabled(tier) {
        Ok(tier)
    } else {
        Err(ParseError::DisabledTier(tier))
    }
}
/// Checks every `cache_control` marker in an already-parsed request body without
/// building blocks or hashes.
///
/// Only elements of array-valued `content` are inspected; a null `cache_control` is
/// ignored. Markers are visited in request order. Each is validated first, then counted,
/// so the first invalid marker's error is returned, or
/// [`ParseError::TooManyBreakpoints`] once a valid marker beyond [`MAX_BREAKPOINTS`]
/// is reached.
pub fn validate_markers(body: &serde_json::Value, policy: &TierPolicy) -> Result<(), ParseError> {
    let messages = match body.get("messages").and_then(|m| m.as_array()) {
        Some(list) => list,
        None => return Ok(()),
    };

    let markers = messages
        .iter()
        .filter_map(|msg| msg.get("content").and_then(|c| c.as_array()))
        .flatten()
        .filter_map(|block| block.get("cache_control"))
        .filter(|cc| !cc.is_null());

    // `already_seen` counts the valid markers before this one.
    for (already_seen, cc) in markers.enumerate() {
        parse_ttl(cc, policy)?;
        if already_seen >= MAX_BREAKPOINTS {
            return Err(ParseError::TooManyBreakpoints);
        }
    }
    Ok(())
}
/// Returns a copy of `block` without its top-level `cache_control` key.
///
/// Non-object values are returned as an unchanged copy.
fn strip_cache_control(block: &serde_json::Value) -> serde_json::Value {
    match block {
        serde_json::Value::Object(fields) => {
            let mut kept = fields.clone();
            kept.remove("cache_control");
            serde_json::Value::Object(kept)
        }
        other => other.clone(),
    }
}
/// Builds the byte string hashed for one block: the role, a single NUL separator, then
/// the JSON serialization of the marker-free block.
///
/// If serialization fails, the JSON part is left empty.
fn canonical_block_bytes(role: &str, stripped_block: &serde_json::Value) -> Vec<u8> {
    let payload = serde_json::to_vec(stripped_block).unwrap_or_default();
    let separator = [0x00u8];
    [role.as_bytes(), &separator[..], payload.as_slice()].concat()
}
// Unit tests for marker validation, breakpoint extraction, and cumulative hashing.
#[cfg(test)]
mod tests {
use super::*;
// Policy built from the tiers "5m", "1h" and "24h" with "5m" passed as the second
// argument (the default tier, as `single_marker_with_default_ttl` relies on).
fn all_tiers() -> TierPolicy {
TierPolicy::from_config(&["5m".to_string(), "1h".to_string(), "24h".to_string()], "5m")
}
// Serializes `body` and parses it under `all_tiers()`, panicking on any parse error.
fn parse(body: serde_json::Value) -> ParsedPrompt {
parse_chat_completions(body.to_string().as_bytes(), &all_tiers()).unwrap()
}
// String-content messages each become one block with a distinct running hash and no breakpoints.
#[test]
fn no_markers_no_breakpoints() {
let p = parse(serde_json::json!({
"model": "m",
"messages": [
{"role": "system", "content": "you are helpful"},
{"role": "user", "content": "hi"}
]
}));
assert_eq!(p.blocks.len(), 2);
assert_eq!(p.cumulative_hashes.len(), 2);
assert!(p.breakpoints.is_empty());
assert_ne!(p.cumulative_hashes[0], p.cumulative_hashes[1]);
}
// A marker without `ttl` resolves to the policy's default tier.
#[test]
fn single_marker_with_default_ttl() {
let p = parse(serde_json::json!({
"messages": [
{"role": "system", "content": [
{"type": "text", "text": "long ctx", "cache_control": {"type": "ephemeral"}}
]},
{"role": "user", "content": "q"}
]
}));
assert_eq!(p.breakpoints.len(), 1);
assert_eq!(p.breakpoints[0].block_index, 0);
assert_eq!(p.breakpoints[0].ttl_tier, TtlTier::FiveMinutes);
}
// Each supported `ttl` string maps to its tier.
#[test]
fn ttl_tiers_parse() {
for (ttl, tier) in [
("5m", TtlTier::FiveMinutes),
("1h", TtlTier::OneHour),
("24h", TtlTier::TwentyFourHours),
] {
let p = parse(serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": {"type": "ephemeral", "ttl": ttl}}
]}]
}));
assert_eq!(p.breakpoints[0].ttl_tier, tier);
}
}
// An unrecognized `ttl` string is reported with the offending value.
#[test]
fn invalid_ttl_errors() {
let err = parse_chat_completions(
serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": {"type": "ephemeral", "ttl": "2h"}}
]}]
})
.to_string()
.as_bytes(),
&all_tiers(),
)
.unwrap_err();
assert!(matches!(err, ParseError::InvalidTtl(t) if t == "2h"));
}
// A string `type` other than "ephemeral" is reported with the offending value.
#[test]
fn unsupported_cache_control_type_errors() {
let err = parse_chat_completions(
serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": {"type": "persistent"}}
]}]
})
.to_string()
.as_bytes(),
&all_tiers(),
)
.unwrap_err();
assert!(matches!(err, ParseError::UnsupportedType(t) if t == "persistent"));
}
// Five markers exceed MAX_BREAKPOINTS (4) and are rejected by the parser.
#[test]
fn more_than_four_breakpoints_errors() {
let blocks: Vec<_> = (0..5)
.map(|i| serde_json::json!({"type": "text", "text": format!("b{i}"), "cache_control": {"type": "ephemeral"}}))
.collect();
let err = parse_chat_completions(
serde_json::json!({ "messages": [{"role": "user", "content": blocks}] })
.to_string()
.as_bytes(),
&all_tiers(),
)
.unwrap_err();
assert!(matches!(err, ParseError::TooManyBreakpoints));
}
// A tier left out of the policy is rejected identically by validation and parsing.
#[test]
fn disabled_tier_rejected_by_validate_and_parse() {
let policy = TierPolicy::from_config(&["5m".to_string(), "1h".to_string()], "5m");
let body = serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": {"type": "ephemeral", "ttl": "24h"}}
]}]
});
let err = validate_markers(&body, &policy).unwrap_err();
assert!(matches!(err, ParseError::DisabledTier(TtlTier::TwentyFourHours)));
let err2 = parse_chat_completions(body.to_string().as_bytes(), &policy).unwrap_err();
assert!(matches!(err2, ParseError::DisabledTier(TtlTier::TwentyFourHours)));
}
// With a policy whose only tier and default is "1h", a marker without `ttl` validates
// and resolves to OneHour.
#[test]
fn validate_markers_default_ttl_honours_policy() {
let policy = TierPolicy::from_config(&["1h".to_string()], "1h");
let body = serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": {"type": "ephemeral"}}
]}]
});
assert!(validate_markers(&body, &policy).is_ok());
let p = parse_chat_completions(body.to_string().as_bytes(), &policy).unwrap();
assert_eq!(p.breakpoints[0].ttl_tier, TtlTier::OneHour);
}
// validate_markers accepts a body with one marker and rejects one with five.
#[test]
fn validate_markers_ok_and_counts_breakpoints() {
let ok = serde_json::json!({
"messages": [{"role": "user", "content": [
{"type": "text", "text": "a", "cache_control": {"type": "ephemeral", "ttl": "1h"}},
{"type": "text", "text": "q"}
]}]
});
assert!(validate_markers(&ok, &all_tiers()).is_ok());
let blocks: Vec<_> = (0..5)
.map(|i| serde_json::json!({"type": "text", "text": format!("b{i}"), "cache_control": {"type": "ephemeral"}}))
.collect();
let too_many = serde_json::json!({ "messages": [{"role": "user", "content": blocks}] });
assert!(matches!(
validate_markers(&too_many, &all_tiers()).unwrap_err(),
ParseError::TooManyBreakpoints
));
}
// A `cache_control` that is a bare string (not an object) is malformed for both entry points.
#[test]
fn non_object_cache_control_is_malformed() {
let body = serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": "persistent"}
]}]
});
assert!(matches!(
validate_markers(&body, &all_tiers()).unwrap_err(),
ParseError::MalformedCacheControl
));
assert!(matches!(
parse_chat_completions(body.to_string().as_bytes(), &all_tiers()).unwrap_err(),
ParseError::MalformedCacheControl
));
}
// A numeric `ttl` or a boolean `type` is malformed rather than invalid/unsupported.
#[test]
fn non_string_type_or_ttl_is_malformed() {
let bad_ttl = serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": {"type": "ephemeral", "ttl": 123}}
]}]
});
assert!(matches!(
validate_markers(&bad_ttl, &all_tiers()).unwrap_err(),
ParseError::MalformedCacheControl
));
let bad_type = serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": {"type": true}}
]}]
});
assert!(matches!(
validate_markers(&bad_type, &all_tiers()).unwrap_err(),
ParseError::MalformedCacheControl
));
}
// An object without `type` (including the empty object) is malformed.
#[test]
fn missing_type_is_malformed() {
let no_type = serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": {"ttl": "1h"}}
]}]
});
assert!(matches!(
validate_markers(&no_type, &all_tiers()).unwrap_err(),
ParseError::MalformedCacheControl
));
let empty = serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": {}}
]}]
});
assert!(matches!(
parse_chat_completions(empty.to_string().as_bytes(), &all_tiers()).unwrap_err(),
ParseError::MalformedCacheControl
));
}
// An explicit `"cache_control": null` is treated as no marker at all.
#[test]
fn null_cache_control_is_no_marker() {
let body = serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "x", "cache_control": null}
]}]
});
assert!(validate_markers(&body, &all_tiers()).is_ok());
let p = parse_chat_completions(body.to_string().as_bytes(), &all_tiers()).unwrap();
assert!(p.breakpoints.is_empty());
}
// The marker is stripped before hashing: marked and unmarked copies of the same block
// hash equal, while different text hashes differently.
#[test]
fn cache_control_excluded_from_hash() {
let marked = parse(serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "shared prefix", "cache_control": {"type": "ephemeral", "ttl": "1h"}}
]}]
}));
let unmarked = parse(serde_json::json!({
"messages": [{"role": "system", "content": [
{"type": "text", "text": "shared prefix"}
]}]
}));
assert_eq!(marked.cumulative_hashes[0], unmarked.cumulative_hashes[0]);
let other = parse(serde_json::json!({
"messages": [{"role": "system", "content": [{"type": "text", "text": "different"}]}]
}));
assert_ne!(marked.cumulative_hashes[0], other.cumulative_hashes[0]);
}
// With 25 blocks and a marker on the last one (index 24), the candidates start at the
// breakpoint's own hash, step back one block at a time, and stop after WALK_BACK entries.
#[test]
fn walk_back_candidates_longest_first_bounded() {
let blocks: Vec<_> = (0..25)
.map(|i| {
let mut b = serde_json::json!({"type": "text", "text": format!("b{i}")});
if i == 24 {
b["cache_control"] = serde_json::json!({"type": "ephemeral"});
}
b
})
.collect();
let p = parse(serde_json::json!({ "messages": [{"role": "user", "content": blocks}] }));
let cands = p.read_candidates(&p.breakpoints[0]);
assert_eq!(cands.len(), WALK_BACK);
// Longest prefix first: index 24, then 23, ...
assert_eq!(cands[0], p.cumulative_hashes[24]); assert_eq!(cands[1], p.cumulative_hashes[23]);
// ... down to index 24 - (WALK_BACK - 1) = 5.
assert_eq!(cands[WALK_BACK - 1], p.cumulative_hashes[5]); }
// Parsing the same body twice yields identical cumulative hashes.
#[test]
fn deterministic() {
let body = serde_json::json!({
"messages": [{"role": "system", "content": [{"type": "text", "text": "abc"}]}]
});
assert_eq!(parse(body.clone()).cumulative_hashes, parse(body).cumulative_hashes);
}
}