use std::time::Duration;
use tokio::sync::mpsc;
use tracing::Instrument as _;
use crate::store::SqliteStore;
use crate::store::retrieval_failures::RetrievalFailureRecord;
/// Queries longer than this (measured in characters) are truncated before
/// being enqueued, so unbounded query text cannot bloat the database.
const QUERY_TEXT_MAX_CHARS: usize = 512;
/// Error-context strings longer than this (in characters) are truncated.
const ERROR_CONTEXT_MAX_CHARS: usize = 256;
/// The retention purge runs once every this many completed batch flushes.
const CLEANUP_FLUSH_INTERVAL: u32 = 500;
/// Fire-and-forget logger for memory-retrieval failures.
///
/// Records are pushed over a bounded channel to a background tokio task
/// that batches writes into SQLite. Both fields are `Option` so that
/// `shutdown` can take ownership of them while `Drop` still compiles
/// against `&mut self`.
pub struct RetrievalFailureLogger {
    // Sender half of the record channel; `None` once `shutdown` has
    // closed the channel.
    tx: Option<mpsc::Sender<RetrievalFailureRecord>>,
    // Handle to the background writer task; awaited in `shutdown`,
    // aborted in `Drop` if still present.
    handle: Option<tokio::task::JoinHandle<()>>,
}
impl RetrievalFailureLogger {
#[must_use]
pub fn new(
sqlite: SqliteStore,
channel_capacity: usize,
batch_size: usize,
flush_interval: Duration,
retention_days: u32,
) -> Self {
let (tx, rx) = mpsc::channel(channel_capacity);
let handle = tokio::spawn(writer_task(
sqlite,
rx,
batch_size,
flush_interval,
retention_days,
));
Self {
tx: Some(tx),
handle: Some(handle),
}
}
pub fn log(&self, mut record: RetrievalFailureRecord) {
let _span = tracing::debug_span!("memory.retrieval_failure.log").entered();
if record.query_text.chars().count() > QUERY_TEXT_MAX_CHARS {
record.query_text = record
.query_text
.chars()
.take(QUERY_TEXT_MAX_CHARS)
.collect();
}
if let Some(ref mut ctx) = record.error_context
&& ctx.chars().count() > ERROR_CONTEXT_MAX_CHARS
{
*ctx = ctx.chars().take(ERROR_CONTEXT_MAX_CHARS).collect();
}
if let Some(tx) = &self.tx
&& tx.try_send(record).is_err()
{
tracing::debug!("retrieval_failure_logger: channel full, dropping record");
}
}
pub async fn shutdown(mut self) {
drop(self.tx.take());
if let Some(handle) = self.handle.take() {
let _ = handle.await;
}
}
}
impl Drop for RetrievalFailureLogger {
    /// Abort the background writer if `shutdown` was never called; after
    /// a graceful `shutdown` the handle has already been taken, so this
    /// is a no-op.
    fn drop(&mut self) {
        if let Some(task) = self.handle.take() {
            task.abort();
        }
    }
}
/// Background writer: drains `rx` into an in-memory batch and flushes it
/// to SQLite when either `batch_size` records have accumulated or
/// `flush_interval` has elapsed since the last flush. Performs a final
/// flush and returns once every sender has been dropped.
async fn writer_task(
    sqlite: SqliteStore,
    mut rx: mpsc::Receiver<RetrievalFailureRecord>,
    batch_size: usize,
    flush_interval: Duration,
    retention_days: u32,
) {
    let mut batch: Vec<RetrievalFailureRecord> = Vec::with_capacity(batch_size);
    // Counts completed flushes so flush_batch can run the retention purge
    // on every CLEANUP_FLUSH_INTERVAL-th flush.
    let mut flush_counter: u32 = 0;
    loop {
        // Fresh deadline per outer iteration: the interval is measured
        // from the previous flush, not from when the first record of the
        // next batch arrives.
        let deadline = tokio::time::sleep(flush_interval);
        tokio::pin!(deadline);
        loop {
            tokio::select! {
                // `biased` polls recv first, preferring records over the
                // timer; the batch-size break below bounds how long the
                // deadline can be starved under continuous load.
                biased;
                msg = rx.recv() => {
                    if let Some(record) = msg {
                        batch.push(record);
                        if batch.len() >= batch_size {
                            break;
                        }
                    } else {
                        // Channel closed (logger shut down): flush whatever
                        // is buffered, then terminate the task.
                        flush_batch(&sqlite, &mut batch, &mut flush_counter, retention_days).await;
                        return;
                    }
                }
                () = &mut deadline => break,
            }
        }
        // Reached batch_size or the deadline fired: flush what we have.
        flush_batch(&sqlite, &mut batch, &mut flush_counter, retention_days).await;
    }
}
/// Write the buffered records to SQLite in one batch, clear the buffer,
/// and — every `CLEANUP_FLUSH_INTERVAL`-th flush — purge rows older than
/// `retention_days`. Write and purge failures are logged, never
/// propagated: losing diagnostics must not take down the writer task.
async fn flush_batch(
    sqlite: &SqliteStore,
    batch: &mut Vec<RetrievalFailureRecord>,
    flush_counter: &mut u32,
    retention_days: u32,
) {
    let count = batch.len();
    if count == 0 {
        return;
    }
    tracing::debug!(count, "retrieval_failure_logger: flushing batch");
    let flush_span = tracing::info_span!("memory.retrieval_failure.flush", count);
    if let Err(e) = sqlite
        .record_retrieval_failures_batch(batch)
        .instrument(flush_span)
        .await
    {
        tracing::warn!("retrieval_failure_logger: batch write failed: {e:#}");
    }
    // The batch is cleared even on failure so a poison record cannot
    // wedge the writer in a permanent retry of the same rows.
    batch.clear();
    // Wrapping add: the counter only gates the periodic cleanup, so
    // wrap-around is harmless.
    *flush_counter = flush_counter.wrapping_add(1);
    if flush_counter.is_multiple_of(CLEANUP_FLUSH_INTERVAL) {
        if let Err(e) = sqlite.purge_old_retrieval_failures(retention_days).await {
            tracing::debug!("retrieval_failure_logger: cleanup failed: {e:#}");
        }
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::store::SqliteStore;
    use crate::store::retrieval_failures::{RetrievalFailureRecord, RetrievalFailureType};

    /// Minimal `no_hit` fixture: zero results, no optional fields set.
    fn no_hit_record() -> RetrievalFailureRecord {
        RetrievalFailureRecord {
            conversation_id: None,
            turn_index: 0,
            failure_type: RetrievalFailureType::NoHit,
            retrieval_strategy: "semantic".into(),
            query_text: "hello world".into(),
            query_len: 11,
            top_score: None,
            confidence_threshold: None,
            result_count: 0,
            latency_ms: 5,
            edge_types: None,
            error_context: None,
        }
    }

    /// `low_confidence` fixture with a configurable score/threshold pair.
    fn low_confidence_record(score: f32, threshold: f32) -> RetrievalFailureRecord {
        RetrievalFailureRecord {
            conversation_id: None,
            turn_index: 0,
            failure_type: RetrievalFailureType::LowConfidence,
            retrieval_strategy: "semantic".into(),
            query_text: "low confidence query".into(),
            query_len: 20,
            top_score: Some(score),
            confidence_threshold: Some(threshold),
            result_count: 3,
            latency_ms: 10,
            edge_types: None,
            error_context: None,
        }
    }

    /// A logged no-hit record survives shutdown and lands in SQLite.
    #[tokio::test]
    async fn no_hit_failure_is_persisted() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        logger.log(no_hit_record());
        // shutdown() drains the channel and flushes, so the row must be
        // visible immediately afterwards.
        logger.shutdown().await;
        let rows: Vec<(String,)> = sqlx::query_as(
            "SELECT failure_type FROM memory_retrieval_failures WHERE failure_type = 'no_hit'",
        )
        .fetch_all(sqlite.pool())
        .await
        .unwrap();
        assert_eq!(rows.len(), 1, "no_hit record must be persisted");
    }

    /// Low-confidence records persist their score and threshold columns.
    #[tokio::test]
    async fn low_confidence_failure_is_persisted() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        logger.log(low_confidence_record(0.3, 0.7));
        logger.shutdown().await;
        let rows: Vec<(String, f32, f32)> = sqlx::query_as(
            "SELECT failure_type, top_score, confidence_threshold \
             FROM memory_retrieval_failures WHERE failure_type = 'low_confidence'",
        )
        .fetch_all(sqlite.pool())
        .await
        .unwrap();
        assert_eq!(rows.len(), 1, "low_confidence record must be persisted");
        let (_, top_score, threshold) = &rows[0];
        // Approximate float comparison: the values round-trip through
        // SQLite's REAL storage.
        assert!((*top_score - 0.3_f32).abs() < 1e-5, "top_score must match");
        assert!(
            (*threshold - 0.7_f32).abs() < 1e-5,
            "confidence_threshold must match"
        );
    }

    /// With a capacity-1 channel and a one-minute flush interval, a second
    /// `log` call hits a (likely) full channel and must return immediately
    /// instead of blocking.
    #[tokio::test]
    async fn log_does_not_block_when_channel_is_full() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        // `Duration::from_secs(60)` rather than `Duration::from_mins(1)`:
        // the `from_mins` constructor comes from the `duration_constructors`
        // API and is not available on older stable toolchains; the value is
        // identical.
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 1, 16, Duration::from_secs(60), 90);
        logger.log(no_hit_record());
        let start = std::time::Instant::now();
        logger.log(no_hit_record());
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(100),
            "log() must be non-blocking even when channel is full, elapsed={elapsed:?}"
        );
        logger.shutdown().await;
    }

    /// Query text beyond QUERY_TEXT_MAX_CHARS is truncated before storage.
    #[tokio::test]
    async fn query_text_truncated_to_512_chars() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        let long_query = "x".repeat(1000);
        let mut record = no_hit_record();
        record.query_text = long_query;
        record.query_len = 1000;
        logger.log(record);
        logger.shutdown().await;
        let rows: Vec<(String,)> =
            sqlx::query_as("SELECT query_text FROM memory_retrieval_failures")
                .fetch_all(sqlite.pool())
                .await
                .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(
            rows[0].0.chars().count(),
            512,
            "query_text must be truncated to 512 chars"
        );
    }

    /// Callers hold `Option<RetrievalFailureLogger>`; when it is `None`
    /// nothing is logged and the table stays empty.
    #[tokio::test]
    async fn logger_disabled_when_option_is_none() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger: Option<RetrievalFailureLogger> = None;
        if let Some(l) = &logger {
            l.log(no_hit_record());
        }
        let rows: Vec<(i64,)> = sqlx::query_as("SELECT COUNT(*) FROM memory_retrieval_failures")
            .fetch_all(sqlite.pool())
            .await
            .unwrap();
        assert_eq!(
            rows[0].0, 0,
            "no records must be written when logger is None"
        );
    }

    /// Several records logged before shutdown are all flushed as a batch.
    #[tokio::test]
    async fn multiple_records_batch_flushed() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let logger =
            RetrievalFailureLogger::new(sqlite.clone(), 256, 16, Duration::from_millis(10), 90);
        for _ in 0..5 {
            logger.log(no_hit_record());
        }
        logger.log(low_confidence_record(0.2, 0.8));
        logger.shutdown().await;
        let rows: Vec<(i64,)> = sqlx::query_as("SELECT COUNT(*) FROM memory_retrieval_failures")
            .fetch_all(sqlite.pool())
            .await
            .unwrap();
        assert_eq!(rows[0].0, 6, "all 6 records must be persisted in batch");
    }
}