use std::borrow::Cow;
use std::sync::LazyLock;
#[allow(unused_imports)]
use zeph_db::sql;
use regex::Regex;
use crate::error::MemoryError;
use crate::store::SqliteStore;
use crate::types::ConversationId;
/// Matches well-known secret/token prefixes — e.g. `sk-`/`sk_live_`/`sk_test_`
/// style API keys, `AKIA` AWS access-key ids, `ghp_`/`gho_` GitHub tokens,
/// `-----BEGIN` PEM blocks, `xoxb-`/`xoxp-` Slack tokens, `AIza`/`ya29.`
/// Google credentials, `glpat-` GitLab, `hf_` Hugging Face, `npm_` npm, and
/// `dckr_pat_` Docker tokens — together with the remainder of the token, up
/// to the next whitespace or delimiter character.
static SECRET_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r#"(?:sk-|sk_live_|sk_test_|AKIA|ghp_|gho_|-----BEGIN|xoxb-|xoxp-|AIza|ya29\.|glpat-|hf_|npm_|dckr_pat_)[^\s"'`,;\{\}\[\]]*"#,
    )
    .expect("secret regex")
});
/// Matches absolute filesystem paths rooted at common user/system prefixes
/// (`/home/`, `/Users/`, `/root/`, `/tmp/`, `/var/`), consuming up to the
/// next whitespace or delimiter character.
static PATH_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r#"(?:/home/|/Users/|/root/|/tmp/|/var/)[^\s"'`,;\{\}\[\]]*"#).expect("path regex")
});
/// Matches an `Authorization: Bearer <token>` header, case-insensitively.
/// Group 1 captures the header prefix so a replacement can keep the header
/// name while dropping the token itself.
static BEARER_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(?i)(Authorization:\s*Bearer\s+)\S+").expect("bearer regex"));
/// Matches standalone JWT-shaped strings: three dot-separated base64url
/// segments whose header starts with `eyJ` (base64 of `{"`). The final
/// (signature) segment may be empty, so unsigned `alg=none` tokens match too.
static JWT_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]*").expect("jwt regex")
});
pub(crate) fn redact_sensitive(text: &str) -> Cow<'_, str> {
let s0: Cow<'_, str> = SECRET_RE.replace_all(text, "[REDACTED]");
let s1: Cow<'_, str> = match PATH_RE.replace_all(s0.as_ref(), "[PATH]") {
Cow::Borrowed(_) => s0,
Cow::Owned(o) => Cow::Owned(o),
};
let s2: Cow<'_, str> = match BEARER_RE.replace_all(s1.as_ref(), "${1}[REDACTED]") {
Cow::Borrowed(_) => s1,
Cow::Owned(o) => Cow::Owned(o),
};
match JWT_RE.replace_all(s2.as_ref(), "[REDACTED_JWT]") {
Cow::Borrowed(_) => s2,
Cow::Owned(o) => Cow::Owned(o),
}
}
/// One logged compression-failure example, as stored in the
/// `compression_failure_pairs` table.
#[derive(Debug, Clone)]
pub struct CompressionFailurePair {
    /// Row id (`compression_failure_pairs.id`).
    pub id: i64,
    /// Conversation the failure occurred in.
    pub conversation_id: ConversationId,
    /// Compressed context that failed — redacted and truncated at insert time.
    pub compressed_context: String,
    /// Why the compression was judged a failure — also redacted and truncated.
    pub failure_reason: String,
    /// Failure category label (e.g. `"tool_output"`, `"user_context"`).
    pub category: String,
    /// Row creation timestamp, as stored by the database.
    pub created_at: String,
}
/// Maximum stored size per field. NOTE(review): despite the `_CHARS` name,
/// this is a cap in BYTES (it is compared against `str::len`) — confirm and
/// consider renaming.
const MAX_FIELD_CHARS: usize = 4096;

/// Truncates `s` to at most `MAX_FIELD_CHARS` bytes without splitting a
/// UTF-8 code point, by scanning backwards from the byte cap to the nearest
/// char boundary.
fn truncate_field(s: &str) -> &str {
    // Fast path: nothing to cut. This also avoids the O(MAX_FIELD_CHARS)
    // backward scan the loop below would perform for short strings, since
    // `is_char_boundary(i)` is false for every i > s.len().
    if s.len() <= MAX_FIELD_CHARS {
        return s;
    }
    let mut idx = MAX_FIELD_CHARS;
    while idx > 0 && !s.is_char_boundary(idx) {
        idx -= 1;
    }
    // idx <= MAX_FIELD_CHARS < s.len() here, and idx is a char boundary.
    &s[..idx]
}
impl SqliteStore {
    /// Returns the best-matching compression guidelines as
    /// `(version, guidelines_text)`.
    ///
    /// Candidate rows are either scoped to `conversation_id` or global
    /// (`conversation_id IS NULL`); conversation-scoped rows win, then the
    /// highest `version`. Returns `(0, "")` when no row matches.
    ///
    /// NOTE(review): this query does not filter on `category`, so rows
    /// written via `save_compression_guidelines_with_category` can also be
    /// returned here — confirm that is intended.
    pub async fn load_compression_guidelines(
        &self,
        conversation_id: Option<ConversationId>,
    ) -> Result<(i64, String), MemoryError> {
        let row = zeph_db::query_as::<_, (i64, String)>(sql!(
            "SELECT version, guidelines FROM compression_guidelines \
             WHERE conversation_id = ? OR conversation_id IS NULL \
             ORDER BY CASE WHEN conversation_id IS NOT NULL THEN 0 ELSE 1 END, \
             version DESC \
             LIMIT 1"
        ))
        .bind(conversation_id.map(|c| c.0))
        .fetch_optional(&self.pool)
        .await?;
        Ok(row.unwrap_or((0, String::new())))
    }

    /// Same scope-resolution rules as `load_compression_guidelines`, but
    /// returns `(version, created_at)` instead of the guideline text.
    /// Returns `(0, "")` when no row matches.
    pub async fn load_compression_guidelines_meta(
        &self,
        conversation_id: Option<ConversationId>,
    ) -> Result<(i64, String), MemoryError> {
        let row = zeph_db::query_as::<_, (i64, String)>(sql!(
            "SELECT version, created_at FROM compression_guidelines \
             WHERE conversation_id = ? OR conversation_id IS NULL \
             ORDER BY CASE WHEN conversation_id IS NOT NULL THEN 0 ELSE 1 END, \
             version DESC \
             LIMIT 1"
        ))
        .bind(conversation_id.map(|c| c.0))
        .fetch_optional(&self.pool)
        .await?;
        Ok(row.unwrap_or((0, String::new())))
    }

    /// Inserts a new guidelines row and returns its version.
    ///
    /// The version counter is table-global: `MAX(version) + 1` over ALL rows
    /// regardless of scope, computed inside the same INSERT..SELECT statement
    /// so there is no separate read-then-write race window.
    ///
    /// Passing `Some(conversation_id)` for a conversation that does not exist
    /// fails with a foreign-key violation.
    pub async fn save_compression_guidelines(
        &self,
        guidelines: &str,
        token_count: i64,
        conversation_id: Option<ConversationId>,
    ) -> Result<i64, MemoryError> {
        let new_version: i64 = zeph_db::query_scalar(
            sql!("INSERT INTO compression_guidelines (version, guidelines, token_count, conversation_id) \
                  SELECT COALESCE(MAX(version), 0) + 1, ?, ?, ? \
                  FROM compression_guidelines \
                  RETURNING version"),
        )
        .bind(guidelines)
        .bind(token_count)
        .bind(conversation_id.map(|c| c.0))
        .fetch_one(&self.pool)
        .await?;
        Ok(new_version)
    }

    /// Records a compression-failure example for later guideline updates and
    /// returns the new row id.
    ///
    /// Both `compressed_context` and `failure_reason` are passed through
    /// `redact_sensitive` and capped with `truncate_field` before storage;
    /// `category` is stored as-is.
    pub async fn log_compression_failure(
        &self,
        conversation_id: ConversationId,
        compressed_context: &str,
        failure_reason: &str,
        category: &str,
    ) -> Result<i64, MemoryError> {
        let ctx = redact_sensitive(compressed_context);
        let ctx = truncate_field(&ctx);
        let reason = redact_sensitive(failure_reason);
        let reason = truncate_field(&reason);
        let id = zeph_db::query_scalar(sql!(
            "INSERT INTO compression_failure_pairs \
             (conversation_id, compressed_context, failure_reason, category) \
             VALUES (?, ?, ?, ?) RETURNING id"
        ))
        .bind(conversation_id.0)
        .bind(ctx)
        .bind(reason)
        .bind(category)
        .fetch_one(&self.pool)
        .await?;
        Ok(id)
    }

    /// Fetches up to `limit` failure pairs that have not yet been consumed by
    /// a guideline update (`used_in_update = 0`), oldest first.
    pub async fn get_unused_failure_pairs(
        &self,
        limit: usize,
    ) -> Result<Vec<CompressionFailurePair>, MemoryError> {
        // usize can exceed i64 on 64-bit targets; saturate rather than panic.
        let limit = i64::try_from(limit).unwrap_or(i64::MAX);
        let rows = zeph_db::query_as::<_, (i64, i64, String, String, String, String)>(sql!(
            "SELECT id, conversation_id, compressed_context, failure_reason, category, created_at \
             FROM compression_failure_pairs \
             WHERE used_in_update = 0 \
             ORDER BY created_at ASC \
             LIMIT ?"
        ))
        .bind(limit)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows
            .into_iter()
            .map(
                |(id, cid, ctx, reason, category, created_at)| CompressionFailurePair {
                    id,
                    conversation_id: ConversationId(cid),
                    compressed_context: ctx,
                    failure_reason: reason,
                    category,
                    created_at,
                },
            )
            .collect())
    }

    /// Like `get_unused_failure_pairs`, additionally filtered to a single
    /// `category`.
    pub async fn get_unused_failure_pairs_by_category(
        &self,
        category: &str,
        limit: usize,
    ) -> Result<Vec<CompressionFailurePair>, MemoryError> {
        let limit = i64::try_from(limit).unwrap_or(i64::MAX);
        let rows = zeph_db::query_as::<_, (i64, i64, String, String, String, String)>(sql!(
            "SELECT id, conversation_id, compressed_context, failure_reason, category, created_at \
             FROM compression_failure_pairs \
             WHERE used_in_update = 0 AND category = ? \
             ORDER BY created_at ASC \
             LIMIT ?"
        ))
        .bind(category)
        .bind(limit)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows
            .into_iter()
            .map(
                |(id, cid, ctx, reason, cat, created_at)| CompressionFailurePair {
                    id,
                    conversation_id: ConversationId(cid),
                    compressed_context: ctx,
                    failure_reason: reason,
                    category: cat,
                    created_at,
                },
            )
            .collect())
    }

    /// Counts unconsumed failure pairs in one `category`.
    pub async fn count_unused_failure_pairs_by_category(
        &self,
        category: &str,
    ) -> Result<i64, MemoryError> {
        let count = zeph_db::query_scalar(sql!(
            "SELECT COUNT(*) FROM compression_failure_pairs \
             WHERE used_in_update = 0 AND category = ?"
        ))
        .bind(category)
        .fetch_one(&self.pool)
        .await?;
        Ok(count)
    }

    /// Returns `(version, guidelines_text)` for a `category`, with the same
    /// conversation-over-global preference as `load_compression_guidelines`.
    /// Returns `(0, "")` when no row matches.
    pub async fn load_compression_guidelines_by_category(
        &self,
        category: &str,
        conversation_id: Option<ConversationId>,
    ) -> Result<(i64, String), MemoryError> {
        let row = zeph_db::query_as::<_, (i64, String)>(sql!(
            "SELECT version, guidelines FROM compression_guidelines \
             WHERE category = ? \
             AND (conversation_id = ? OR conversation_id IS NULL) \
             ORDER BY CASE WHEN conversation_id IS NOT NULL THEN 0 ELSE 1 END, \
             version DESC \
             LIMIT 1"
        ))
        .bind(category)
        .bind(conversation_id.map(|c| c.0))
        .fetch_optional(&self.pool)
        .await?;
        Ok(row.unwrap_or((0, String::new())))
    }

    /// Inserts a category-scoped guidelines row and returns its version.
    /// The version counter is the same table-global `MAX(version) + 1` used
    /// by `save_compression_guidelines` — it is NOT per-category.
    pub async fn save_compression_guidelines_with_category(
        &self,
        guidelines: &str,
        token_count: i64,
        category: &str,
        conversation_id: Option<ConversationId>,
    ) -> Result<i64, MemoryError> {
        let new_version: i64 = zeph_db::query_scalar(sql!(
            "INSERT INTO compression_guidelines \
             (version, category, guidelines, token_count, conversation_id) \
             SELECT COALESCE(MAX(version), 0) + 1, ?, ?, ?, ? \
             FROM compression_guidelines \
             RETURNING version"
        ))
        .bind(category)
        .bind(guidelines)
        .bind(token_count)
        .bind(conversation_id.map(|c| c.0))
        .fetch_one(&self.pool)
        .await?;
        Ok(new_version)
    }

    /// Marks the given failure-pair ids as consumed (`used_in_update = 1`).
    /// No-op for an empty slice.
    ///
    /// The SQL is built dynamically only to size the `IN (?,?,…)` list; the
    /// interpolated text is exclusively `?` placeholders and every id is
    /// bound, so there is no injection surface.
    pub async fn mark_failure_pairs_used(&self, ids: &[i64]) -> Result<(), MemoryError> {
        if ids.is_empty() {
            return Ok(());
        }
        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
        let query = format!(
            "UPDATE compression_failure_pairs SET used_in_update = 1 WHERE id IN ({placeholders})"
        );
        let mut q = zeph_db::query(&query);
        for id in ids {
            q = q.bind(id);
        }
        q.execute(&self.pool).await?;
        Ok(())
    }

    /// Counts all unconsumed failure pairs, across every category.
    pub async fn count_unused_failure_pairs(&self) -> Result<i64, MemoryError> {
        let count = zeph_db::query_scalar(sql!(
            "SELECT COUNT(*) FROM compression_failure_pairs WHERE used_in_update = 0"
        ))
        .fetch_one(&self.pool)
        .await?;
        Ok(count)
    }

    /// Deletes all already-consumed pairs, then trims the unconsumed set down
    /// to the `keep_recent` newest rows (by `created_at`).
    ///
    /// NOTE(review): the two DELETEs run as separate statements, not inside a
    /// transaction; an interruption between them leaves a partial cleanup
    /// (harmless, but confirm this is acceptable).
    pub async fn cleanup_old_failure_pairs(&self, keep_recent: usize) -> Result<(), MemoryError> {
        zeph_db::query(sql!(
            "DELETE FROM compression_failure_pairs WHERE used_in_update = 1"
        ))
        .execute(&self.pool)
        .await?;
        let keep = i64::try_from(keep_recent).unwrap_or(i64::MAX);
        zeph_db::query(sql!(
            "DELETE FROM compression_failure_pairs \
             WHERE used_in_update = 0 \
             AND id NOT IN ( \
                 SELECT id FROM compression_failure_pairs \
                 WHERE used_in_update = 0 \
                 ORDER BY created_at DESC \
                 LIMIT ? \
             )"
        ))
        .bind(keep)
        .execute(&self.pool)
        .await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    //! Exercises guideline versioning and scope resolution, failure-pair
    //! bookkeeping, redaction/truncation, and schema constraints against an
    //! in-memory (or, for concurrency, file-backed) store.
    use super::*;

    // Single-connection pool: every query hits the same `:memory:` database.
    async fn make_store() -> SqliteStore {
        SqliteStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory SqliteStore")
    }

    #[tokio::test]
    async fn load_guidelines_meta_returns_defaults_when_empty() {
        let store = make_store().await;
        let (version, created_at) = store.load_compression_guidelines_meta(None).await.unwrap();
        assert_eq!(version, 0);
        assert!(created_at.is_empty());
    }

    #[tokio::test]
    async fn load_guidelines_meta_returns_version_and_created_at() {
        let store = make_store().await;
        store
            .save_compression_guidelines("keep file paths", 4, None)
            .await
            .unwrap();
        let (version, created_at) = store.load_compression_guidelines_meta(None).await.unwrap();
        assert_eq!(version, 1);
        assert!(!created_at.is_empty(), "created_at should be populated");
    }

    #[tokio::test]
    async fn load_guidelines_returns_defaults_when_empty() {
        let store = make_store().await;
        let (version, text) = store.load_compression_guidelines(None).await.unwrap();
        assert_eq!(version, 0);
        assert!(text.is_empty());
    }

    // Versions must increment monotonically and the latest text must win.
    #[tokio::test]
    async fn save_and_load_guidelines() {
        let store = make_store().await;
        let v1 = store
            .save_compression_guidelines("always preserve file paths", 4, None)
            .await
            .unwrap();
        assert_eq!(v1, 1);
        let v2 = store
            .save_compression_guidelines(
                "always preserve file paths\nalways preserve errors",
                8,
                None,
            )
            .await
            .unwrap();
        assert_eq!(v2, 2);
        let (v, text) = store.load_compression_guidelines(None).await.unwrap();
        assert_eq!(v, 2);
        assert!(text.contains("errors"));
    }

    #[tokio::test]
    async fn load_guidelines_prefers_conversation_specific() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .save_compression_guidelines("global rule", 2, None)
            .await
            .unwrap();
        store
            .save_compression_guidelines("conversation rule", 2, Some(cid))
            .await
            .unwrap();
        let (_, text) = store.load_compression_guidelines(Some(cid)).await.unwrap();
        assert_eq!(text, "conversation rule");
    }

    #[tokio::test]
    async fn load_guidelines_falls_back_to_global() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .save_compression_guidelines("global rule", 2, None)
            .await
            .unwrap();
        let (_, text) = store.load_compression_guidelines(Some(cid)).await.unwrap();
        assert_eq!(text, "global rule");
    }

    // A None scope must never leak conversation-scoped guidelines.
    #[tokio::test]
    async fn load_guidelines_none_returns_global_only() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .save_compression_guidelines("conversation rule", 2, Some(cid))
            .await
            .unwrap();
        let (version, text) = store.load_compression_guidelines(None).await.unwrap();
        assert_eq!(version, 0);
        assert!(text.is_empty());
    }

    // Conversation A's guideline must not be visible to conversation B.
    #[tokio::test]
    async fn load_guidelines_scope_isolation() {
        let store = make_store().await;
        let cid_a = ConversationId(store.create_conversation().await.unwrap().0);
        let cid_b = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .save_compression_guidelines("Use bullet points", 1, None)
            .await
            .unwrap();
        store
            .save_compression_guidelines("Be concise", 2, Some(cid_a))
            .await
            .unwrap();
        let (_, text_b) = store
            .load_compression_guidelines(Some(cid_b))
            .await
            .unwrap();
        assert_eq!(
            text_b, "Use bullet points",
            "conversation B must see global guideline"
        );
        let (_, text_a) = store
            .load_compression_guidelines(Some(cid_a))
            .await
            .unwrap();
        assert_eq!(
            text_a, "Be concise",
            "conversation A must prefer its own guideline over global"
        );
        let (_, text_global) = store.load_compression_guidelines(None).await.unwrap();
        assert_eq!(
            text_global, "Use bullet points",
            "None scope must see only the global guideline"
        );
    }

    #[tokio::test]
    async fn save_with_nonexistent_conversation_id_fails() {
        let store = make_store().await;
        let nonexistent = ConversationId(99999);
        let result = store
            .save_compression_guidelines("rule", 1, Some(nonexistent))
            .await;
        assert!(
            result.is_err(),
            "FK violation expected for nonexistent conversation_id"
        );
    }

    // Deleting a conversation should cascade-delete its scoped guidelines.
    #[tokio::test]
    async fn cascade_delete_removes_conversation_guidelines() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .save_compression_guidelines("rule", 1, Some(cid))
            .await
            .unwrap();
        zeph_db::query(sql!("DELETE FROM conversations WHERE id = ?"))
            .bind(cid.0)
            .execute(store.pool())
            .await
            .unwrap();
        let (version, text) = store.load_compression_guidelines(Some(cid)).await.unwrap();
        assert_eq!(version, 0);
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn log_and_count_failure_pairs() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "compressed ctx", "i don't recall that", "unknown")
            .await
            .unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn get_unused_pairs_sorted_oldest_first() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "ctx A", "reason A", "unknown")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx B", "reason B", "unknown")
            .await
            .unwrap();
        let pairs = store.get_unused_failure_pairs(10).await.unwrap();
        assert_eq!(pairs.len(), 2);
        assert_eq!(pairs[0].compressed_context, "ctx A");
    }

    #[tokio::test]
    async fn mark_pairs_used_reduces_count() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        let id = store
            .log_compression_failure(cid, "ctx", "reason", "unknown")
            .await
            .unwrap();
        store.mark_failure_pairs_used(&[id]).await.unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 0);
    }

    // 3 pairs logged, 1 marked used: cleanup deletes the used one and trims
    // the 2 unused down to keep_recent = 1.
    #[tokio::test]
    async fn cleanup_deletes_used_and_trims_unused() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        let id1 = store
            .log_compression_failure(cid, "ctx1", "r1", "tool_output")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx2", "r2", "tool_output")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx3", "r3", "unknown")
            .await
            .unwrap();
        store.mark_failure_pairs_used(&[id1]).await.unwrap();
        store.cleanup_old_failure_pairs(1).await.unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 1, "only 1 unused pair should remain");
    }

    #[test]
    fn redact_sensitive_api_key_is_redacted() {
        let result = redact_sensitive("token sk-abc123def456 used for auth");
        assert!(result.contains("[REDACTED]"), "API key must be redacted");
        assert!(
            !result.contains("sk-abc123"),
            "original key must not appear"
        );
    }

    #[test]
    fn redact_sensitive_plain_text_borrows() {
        let text = "safe text, no secrets here";
        let result = redact_sensitive(text);
        assert!(
            matches!(result, Cow::Borrowed(_)),
            "plain text must return Cow::Borrowed (zero-alloc)"
        );
    }

    #[test]
    fn redact_sensitive_filesystem_path_is_redacted() {
        let result = redact_sensitive("config loaded from /Users/dev/project/config.toml");
        assert!(
            result.contains("[PATH]"),
            "filesystem path must be redacted"
        );
        assert!(
            !result.contains("/Users/dev/"),
            "original path must not appear"
        );
    }

    #[test]
    fn redact_sensitive_combined_secret_and_path() {
        let result = redact_sensitive("key sk-abc at /home/user/file");
        assert!(result.contains("[REDACTED]"), "secret must be redacted");
        assert!(result.contains("[PATH]"), "path must be redacted");
    }

    // End-to-end: redaction must be applied by log_compression_failure
    // before the row hits the database.
    #[tokio::test]
    async fn log_compression_failure_redacts_secrets() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(
                cid,
                "token sk-abc123def456 used for auth",
                "context lost",
                "unknown",
            )
            .await
            .unwrap();
        let pairs = store.get_unused_failure_pairs(10).await.unwrap();
        assert_eq!(pairs.len(), 1);
        assert!(
            pairs[0].compressed_context.contains("[REDACTED]"),
            "stored context must have redacted secret"
        );
        assert!(
            !pairs[0].compressed_context.contains("sk-abc123"),
            "stored context must not contain raw secret"
        );
    }

    #[tokio::test]
    async fn log_compression_failure_redacts_paths() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(
                cid,
                "/Users/dev/project/config.toml was loaded",
                "lost",
                "unknown",
            )
            .await
            .unwrap();
        let pairs = store.get_unused_failure_pairs(10).await.unwrap();
        assert!(
            pairs[0].compressed_context.contains("[PATH]"),
            "stored context must have redacted path"
        );
        assert!(
            !pairs[0].compressed_context.contains("/Users/dev/"),
            "stored context must not contain raw path"
        );
    }

    #[tokio::test]
    async fn log_compression_failure_reason_also_redacted() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(
                cid,
                "some context",
                "secret ghp_abc123xyz was leaked",
                "unknown",
            )
            .await
            .unwrap();
        let pairs = store.get_unused_failure_pairs(10).await.unwrap();
        assert!(
            pairs[0].failure_reason.contains("[REDACTED]"),
            "failure_reason must also be redacted"
        );
        assert!(
            !pairs[0].failure_reason.contains("ghp_abc123xyz"),
            "raw secret must not appear in failure_reason"
        );
    }

    // NOTE(review): this is a sync test under #[tokio::test]; the repeated
    // char appears to be a multi-byte Cyrillic 'а' so the 4096-byte cut can
    // fall mid-codepoint — confirm the char really is non-ASCII.
    #[tokio::test]
    async fn truncate_field_respects_char_boundary() {
        let s = "а".repeat(5000);
        let truncated = truncate_field(&s);
        assert!(truncated.len() <= MAX_FIELD_CHARS);
        assert!(s.is_char_boundary(truncated.len()));
    }

    #[tokio::test]
    async fn unique_constraint_prevents_duplicate_version() {
        let store = make_store().await;
        store
            .save_compression_guidelines("first", 1, None)
            .await
            .unwrap();
        let result = zeph_db::query(
            sql!("INSERT INTO compression_guidelines (version, guidelines, token_count) VALUES (1, 'dup', 0)"),
        )
        .execute(store.pool())
        .await;
        assert!(
            result.is_err(),
            "duplicate version insert should violate UNIQUE constraint"
        );
    }

    #[test]
    fn redact_sensitive_bearer_token_is_redacted() {
        let result =
            redact_sensitive("Authorization: Bearer eyJhbGciOiJSUzI1NiJ9.payload.signature");
        assert!(
            result.contains("[REDACTED]"),
            "Bearer token must be redacted: {result}"
        );
        assert!(
            !result.contains("eyJhbGciOiJSUzI1NiJ9"),
            "raw JWT header must not appear: {result}"
        );
        assert!(
            result.contains("Authorization:"),
            "header name must be preserved: {result}"
        );
    }

    #[test]
    fn redact_sensitive_bearer_token_case_insensitive() {
        let result =
            redact_sensitive("authorization: bearer eyJhbGciOiJSUzI1NiJ9.payload.signature");
        assert!(
            result.contains("[REDACTED]"),
            "Bearer header match must be case-insensitive: {result}"
        );
    }

    #[test]
    fn redact_sensitive_standalone_jwt_is_redacted() {
        let jwt = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIn0.SflKxwRJSMeKKF2";
        let input = format!("token value: {jwt} was found in logs");
        let result = redact_sensitive(&input);
        assert!(
            result.contains("[REDACTED_JWT]"),
            "standalone JWT must be replaced with [REDACTED_JWT]: {result}"
        );
        assert!(
            !result.contains("eyJhbGci"),
            "raw JWT must not appear: {result}"
        );
    }

    #[test]
    fn redact_sensitive_mixed_content_all_redacted() {
        let input =
            "key sk-abc123 at /home/user/f with Authorization: Bearer eyJhbG.pay.sig and eyJx.b.c";
        let result = redact_sensitive(input);
        assert!(result.contains("[REDACTED]"), "API key must be redacted");
        assert!(result.contains("[PATH]"), "path must be redacted");
        assert!(!result.contains("sk-abc123"), "raw API key must not appear");
        assert!(!result.contains("eyJhbG"), "raw JWT must not appear");
    }

    // A JWT needs three segments; two-part eyJ strings stay untouched.
    #[test]
    fn redact_sensitive_partial_jwt_not_redacted() {
        let input = "eyJhbGciOiJSUzI1NiJ9.onlytwoparts";
        let result = redact_sensitive(input);
        assert!(
            !result.contains("[REDACTED_JWT]"),
            "two-part eyJ string must not be treated as JWT: {result}"
        );
        assert!(
            matches!(result, Cow::Borrowed(_)),
            "no-match input must return Cow::Borrowed: {result}"
        );
    }

    // The signature segment may be empty (alg=none tokens).
    #[test]
    fn redact_sensitive_alg_none_jwt_empty_signature_redacted() {
        let input =
            "token: eyJhbGciOiJub25lIn0.eyJzdWIiOiJ1c2VyIn0. was submitted without signature";
        let result = redact_sensitive(input);
        assert!(
            result.contains("[REDACTED_JWT]"),
            "alg=none JWT with empty signature must be redacted: {result}"
        );
        assert!(
            !result.contains("eyJhbGciOiJub25lIn0"),
            "raw alg=none JWT header must not appear: {result}"
        );
    }

    #[tokio::test]
    async fn get_unused_pairs_by_category_filters_correctly() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "tool ctx", "lost tool output", "tool_output")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "user ctx", "lost user context", "user_context")
            .await
            .unwrap();
        let tool_pairs = store
            .get_unused_failure_pairs_by_category("tool_output", 10)
            .await
            .unwrap();
        assert_eq!(tool_pairs.len(), 1);
        assert_eq!(tool_pairs[0].category, "tool_output");
        assert_eq!(tool_pairs[0].compressed_context, "tool ctx");
        let user_pairs = store
            .get_unused_failure_pairs_by_category("user_context", 10)
            .await
            .unwrap();
        assert_eq!(user_pairs.len(), 1);
        assert_eq!(user_pairs[0].category, "user_context");
        let unknown_pairs = store
            .get_unused_failure_pairs_by_category("assistant_reasoning", 10)
            .await
            .unwrap();
        assert!(unknown_pairs.is_empty());
    }

    #[tokio::test]
    async fn count_unused_pairs_by_category_returns_correct_count() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "ctx A", "reason", "tool_output")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx B", "reason", "tool_output")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx C", "reason", "user_context")
            .await
            .unwrap();
        let tool_count = store
            .count_unused_failure_pairs_by_category("tool_output")
            .await
            .unwrap();
        assert_eq!(tool_count, 2);
        let user_count = store
            .count_unused_failure_pairs_by_category("user_context")
            .await
            .unwrap();
        assert_eq!(user_count, 1);
        let unknown_count = store
            .count_unused_failure_pairs_by_category("assistant_reasoning")
            .await
            .unwrap();
        assert_eq!(unknown_count, 0);
    }

    #[tokio::test]
    async fn save_and_load_guidelines_by_category() {
        let store = make_store().await;
        store
            .save_compression_guidelines_with_category(
                "preserve tool names",
                3,
                "tool_output",
                None,
            )
            .await
            .unwrap();
        store
            .save_compression_guidelines_with_category("keep user intent", 3, "user_context", None)
            .await
            .unwrap();
        let (_, tool_text) = store
            .load_compression_guidelines_by_category("tool_output", None)
            .await
            .unwrap();
        assert_eq!(tool_text, "preserve tool names");
        let (_, user_text) = store
            .load_compression_guidelines_by_category("user_context", None)
            .await
            .unwrap();
        assert_eq!(user_text, "keep user intent");
    }

    #[tokio::test]
    async fn load_guidelines_by_category_returns_defaults_when_empty() {
        let store = make_store().await;
        let (version, text) = store
            .load_compression_guidelines_by_category("tool_output", None)
            .await
            .unwrap();
        assert_eq!(version, 0, "version must be 0 when no entries exist");
        assert!(text.is_empty(), "text must be empty when no entries exist");
    }

    // Uses a file-backed store with a pool of 4 so saves genuinely race;
    // every concurrent save must still get a distinct version.
    #[tokio::test]
    async fn concurrent_saves_produce_unique_versions() {
        use std::collections::HashSet;
        use std::sync::Arc;
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let store = Arc::new(
            SqliteStore::with_pool_size(db_path.to_str().expect("utf8 path"), 4)
                .await
                .expect("file-backed SqliteStore"),
        );
        let tasks: Vec<_> = (0..8_i64)
            .map(|i| {
                let s = Arc::clone(&store);
                tokio::spawn(async move {
                    s.save_compression_guidelines(&format!("guideline {i}"), i, None)
                        .await
                        .expect("concurrent save must succeed")
                })
            })
            .collect();
        let mut versions = HashSet::new();
        for task in tasks {
            let v = task.await.expect("task must not panic");
            assert!(versions.insert(v), "version {v} appeared more than once");
        }
        assert_eq!(
            versions.len(),
            8,
            "all 8 saves must produce distinct versions"
        );
    }
}