use crate::error::{Result, RuvLLMError};
use crate::types::{ErrorInfo, ModelSize, QualityMetrics};
use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use ruvector_core::types::DbOptions;
use ruvector_core::{AgenticDB, SearchQuery, VectorEntry};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use uuid::Uuid;
#[cfg(feature = "async-runtime")]
use tokio::sync::{oneshot, Notify};
#[cfg(feature = "async-runtime")]
use tokio::time::{interval, Duration};
/// Per-stage latency measurements, in milliseconds, for a single request.
///
/// `total_ms` is not maintained automatically; call [`LatencyBreakdown::compute_total`]
/// after filling in the stage fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LatencyBreakdown {
    /// Milliseconds spent in the embedding stage.
    pub embedding_ms: f32,
    /// Milliseconds spent in the retrieval stage.
    pub retrieval_ms: f32,
    /// Milliseconds spent in the routing stage.
    pub routing_ms: f32,
    /// Milliseconds spent in the attention stage.
    pub attention_ms: f32,
    /// Milliseconds spent in the generation stage.
    pub generation_ms: f32,
    /// Sum of the five stage timings (see `compute_total`).
    pub total_ms: f32,
}
impl LatencyBreakdown {
    /// Creates a zeroed breakdown; identical to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Recomputes `total_ms` as the sum of the five stage timings,
    /// accumulated in pipeline order.
    pub fn compute_total(&mut self) {
        let stages = [
            self.embedding_ms,
            self.retrieval_ms,
            self.routing_ms,
            self.attention_ms,
            self.generation_ms,
        ];
        self.total_ms = stages.iter().sum();
    }

    /// True when the end-to-end latency exceeds `threshold_ms`.
    pub fn exceeds_threshold(&self, threshold_ms: f32) -> bool {
        self.total_ms > threshold_ms
    }

    /// Returns the name and duration of the slowest pipeline stage.
    ///
    /// Ties (and incomparable NaN values) resolve to the stage that comes
    /// later in pipeline order, matching `Iterator::max_by` semantics.
    pub fn slowest_component(&self) -> (&'static str, f32) {
        let mut slowest = ("embedding", self.embedding_ms);
        for candidate in [
            ("retrieval", self.retrieval_ms),
            ("routing", self.routing_ms),
            ("attention", self.attention_ms),
            ("generation", self.generation_ms),
        ] {
            let ord = candidate
                .1
                .partial_cmp(&slowest.1)
                .unwrap_or(std::cmp::Ordering::Equal);
            if ord != std::cmp::Ordering::Less {
                slowest = candidate;
            }
        }
        slowest
    }
}
/// The router's chosen model and sampling parameters for one request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
    /// Model size selected for this request.
    pub model: ModelSize,
    /// Context size to use for the request (units not established here —
    /// presumably tokens or documents; confirm against the router).
    pub context_size: usize,
    /// Sampling temperature to use for generation.
    pub temperature: f32,
    /// Nucleus-sampling cutoff to use for generation.
    pub top_p: f32,
    /// Router's confidence in this decision, in [0, 1] by convention.
    pub confidence: f32,
    /// Probability assigned to each of the four model sizes
    /// (assumed to follow `ModelSize` variant order — TODO confirm).
    pub model_probs: [f32; 4],
}
impl Default for RoutingDecision {
fn default() -> Self {
Self {
model: ModelSize::Small,
context_size: 0,
temperature: 0.7,
top_p: 0.9,
confidence: 0.5,
model_probs: [0.25, 0.25, 0.25, 0.25],
}
}
}
/// One recorded "witness" of a request: the routing decision that was made,
/// the outcome, and enough context to search for similar requests later.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WitnessEntry {
    /// Unique id for this request; also used as the DB record id.
    pub request_id: Uuid,
    /// Caller-supplied session identifier.
    pub session_id: String,
    /// Embedding of the query; stored as the searchable vector.
    pub query_embedding: Vec<f32>,
    /// The decision the router produced for this request.
    pub routing_decision: RoutingDecision,
    /// Model actually used (initialized from `routing_decision.model`).
    pub model_used: ModelSize,
    /// Scalar quality score for the response (0.0 until set).
    pub quality_score: f32,
    /// Per-stage latency for this request.
    pub latency: LatencyBreakdown,
    /// Ids of context documents used (not persisted by `flush_entries`).
    pub context_doc_ids: Vec<Uuid>,
    /// Embedding of the response (not persisted by `flush_entries`).
    pub response_embedding: Vec<f32>,
    /// Wall-clock time the entry was created.
    pub timestamp: DateTime<Utc>,
    /// Error information when the request failed; `None` means success.
    pub error: Option<ErrorInfo>,
    /// Optional detailed quality metrics.
    pub quality_metrics: Option<QualityMetrics>,
    /// Free-form tags attached to the entry.
    pub tags: Vec<String>,
}
impl WitnessEntry {
pub fn new(
session_id: String,
query_embedding: Vec<f32>,
routing_decision: RoutingDecision,
) -> Self {
Self {
request_id: Uuid::new_v4(),
session_id,
query_embedding,
routing_decision: routing_decision.clone(),
model_used: routing_decision.model,
quality_score: 0.0,
latency: LatencyBreakdown::default(),
context_doc_ids: Vec::new(),
response_embedding: Vec::new(),
timestamp: Utc::now(),
error: None,
quality_metrics: None,
tags: Vec::new(),
}
}
pub fn with_quality(mut self, score: f32) -> Self {
self.quality_score = score;
self
}
pub fn with_latency(mut self, latency: LatencyBreakdown) -> Self {
self.latency = latency;
self
}
pub fn with_error(mut self, error: ErrorInfo) -> Self {
self.error = Some(error);
self
}
pub fn is_success(&self) -> bool {
self.error.is_none()
}
pub fn meets_quality_threshold(&self, threshold: f32) -> bool {
self.quality_score >= threshold
}
}
/// Tuning knobs for asynchronous, batched witness-log writes.
#[derive(Debug, Clone)]
pub struct AsyncWriteConfig {
    /// Entries per write batch before a flush becomes due.
    pub max_batch_size: usize,
    /// Maximum time a non-empty queue may wait before a flush is due.
    pub max_wait_ms: u64,
    /// Queue capacity; entries beyond this are dropped (backpressure).
    pub max_queue_depth: usize,
    /// When true, `record_critical` fsyncs the storage file after writing.
    pub fsync_critical: bool,
    /// Tick period of the background flush task.
    pub flush_interval_ms: u64,
}

impl Default for AsyncWriteConfig {
    /// Defaults: batches of 100 entries, 1 s maximum wait and flush
    /// interval, a 10 000-entry queue, and no fsync for critical writes.
    fn default() -> Self {
        AsyncWriteConfig {
            flush_interval_ms: 1000,
            fsync_critical: false,
            max_queue_depth: 10_000,
            max_wait_ms: 1000,
            max_batch_size: 100,
        }
    }
}
/// In-memory buffer of witness entries awaiting a batched DB write.
struct WritebackQueue {
    /// Pending entries, drained as one batch.
    entries: Vec<WitnessEntry>,
    /// Batching and backpressure configuration.
    config: AsyncWriteConfig,
    /// Wall-clock time of the last drain; drives the `max_wait_ms` timeout.
    last_flush: DateTime<Utc>,
    /// Entries rejected because the queue hit `max_queue_depth`.
    dropped_count: usize,
}
impl WritebackQueue {
    /// Creates an empty queue pre-sized to hold one full batch.
    fn new(config: AsyncWriteConfig) -> Self {
        Self {
            entries: Vec::with_capacity(config.max_batch_size),
            config,
            last_flush: Utc::now(),
            dropped_count: 0,
        }
    }

    /// True when a flush is due: either a full batch has accumulated, or
    /// the queue is non-empty and `max_wait_ms` has elapsed since the
    /// last drain.
    fn should_flush(&self) -> bool {
        if self.entries.len() >= self.config.max_batch_size {
            return true;
        }
        // Fix: `Utc::now()` is wall-clock time and can step backwards
        // (e.g. NTP adjustment). A negative delta cast straight to u64
        // wraps to a huge value and forces a spurious flush, so clamp at
        // zero before the cast.
        let elapsed_ms = (Utc::now() - self.last_flush).num_milliseconds().max(0) as u64;
        elapsed_ms >= self.config.max_wait_ms && !self.entries.is_empty()
    }

    /// Enqueues `entry`. Returns `false` (and counts a drop) when the
    /// queue is already at `max_queue_depth` — backpressure rather than
    /// blocking.
    fn push(&mut self, entry: WitnessEntry) -> bool {
        if self.entries.len() >= self.config.max_queue_depth {
            self.dropped_count += 1;
            return false;
        }
        self.entries.push(entry);
        true
    }

    /// Takes all pending entries and resets the flush timer.
    fn drain(&mut self) -> Vec<WitnessEntry> {
        self.last_flush = Utc::now();
        std::mem::take(&mut self.entries)
    }

    /// Number of entries currently waiting to be flushed.
    fn len(&self) -> usize {
        self.entries.len()
    }

    /// True when nothing is pending.
    fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Total entries rejected by backpressure since construction.
    fn dropped_count(&self) -> usize {
        self.dropped_count
    }
}
/// Persistent, vector-searchable log of request witness records.
///
/// Writes are buffered in `writeback_queue` and flushed in batches,
/// either synchronously from `record`, or by an optional background task
/// when built with the `async-runtime` feature.
pub struct WitnessLog {
    /// Backing vector database; one `VectorEntry` per witness record.
    db: AgenticDB,
    /// Dimensionality configured for stored embeddings.
    embedding_dim: usize,
    /// Shared buffer of pending writes.
    writeback_queue: Arc<Mutex<WritebackQueue>>,
    /// All entries ever submitted (incremented even for entries that are
    /// subsequently dropped by backpressure).
    total_entries: AtomicUsize,
    /// Entries submitted without an error attached.
    success_count: AtomicUsize,
    /// Entries submitted with an error attached.
    error_count: AtomicUsize,
    /// Batching configuration shared with the queue and background task.
    async_config: AsyncWriteConfig,
    /// Path handed to the DB; also used for best-effort fsync.
    storage_path: String,
    /// True while the background flush task is alive.
    background_running: Arc<AtomicBool>,
    /// Wakes the background task when new entries arrive.
    #[cfg(feature = "async-runtime")]
    flush_notify: Arc<Notify>,
    /// One-shot channel used to ask the background task to shut down.
    #[cfg(feature = "async-runtime")]
    shutdown_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
}
impl WitnessLog {
    /// Opens (or creates) a witness log at `storage_path` using the
    /// default [`AsyncWriteConfig`].
    pub fn new(storage_path: &str, embedding_dim: usize) -> Result<Self> {
        Self::with_config(storage_path, embedding_dim, AsyncWriteConfig::default())
    }

    /// Opens (or creates) a witness log with an explicit batching config.
    ///
    /// # Errors
    /// Returns [`RuvLLMError::Storage`] when the backing vector DB fails
    /// to open.
    pub fn with_config(
        storage_path: &str,
        embedding_dim: usize,
        async_config: AsyncWriteConfig,
    ) -> Result<Self> {
        let mut options = DbOptions::default();
        options.storage_path = storage_path.to_string();
        options.dimensions = embedding_dim;
        let db = AgenticDB::new(options).map_err(|e| RuvLLMError::Storage(e.to_string()))?;
        Ok(Self {
            db,
            embedding_dim,
            writeback_queue: Arc::new(Mutex::new(WritebackQueue::new(async_config.clone()))),
            total_entries: AtomicUsize::new(0),
            success_count: AtomicUsize::new(0),
            error_count: AtomicUsize::new(0),
            async_config,
            storage_path: storage_path.to_string(),
            background_running: Arc::new(AtomicBool::new(false)),
            #[cfg(feature = "async-runtime")]
            flush_notify: Arc::new(Notify::new()),
            #[cfg(feature = "async-runtime")]
            shutdown_tx: Arc::new(Mutex::new(None)),
        })
    }

    /// Updates the total/success/error counters for one submitted entry.
    ///
    /// Note: `total_entries` is incremented before the queue accepts the
    /// entry, so it also counts entries later dropped by backpressure.
    fn count_entry(&self, entry: &WitnessEntry) {
        self.total_entries.fetch_add(1, Ordering::SeqCst);
        if entry.is_success() {
            self.success_count.fetch_add(1, Ordering::SeqCst);
        } else {
            self.error_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Records one entry.
    ///
    /// Counters are updated first; the entry is then queued for a batched
    /// write. When no background task is running and the queue reports a
    /// flush is due, the flush happens synchronously on this thread;
    /// otherwise the background task is woken to do it.
    ///
    /// # Errors
    /// Returns [`RuvLLMError::OutOfMemory`] when the queue is at capacity
    /// and the entry was dropped (backpressure), or a storage error from a
    /// synchronous flush.
    pub fn record(&self, entry: WitnessEntry) -> Result<()> {
        self.count_entry(&entry);
        let mut queue = self.writeback_queue.lock();
        if !queue.push(entry) {
            return Err(RuvLLMError::OutOfMemory(
                "Witness log queue full, entry dropped due to backpressure".to_string(),
            ));
        }
        if !self.background_running.load(Ordering::SeqCst) && queue.should_flush() {
            let entries = queue.drain();
            // Release the lock before the potentially slow DB writes.
            drop(queue);
            self.flush_entries(entries)?;
        }
        #[cfg(feature = "async-runtime")]
        if self.background_running.load(Ordering::SeqCst) {
            self.flush_notify.notify_one();
        }
        Ok(())
    }

    /// Records an entry that must not wait in the queue: it is written to
    /// the DB immediately, bypassing batching, and optionally fsynced when
    /// `fsync_critical` is set.
    pub fn record_critical(&self, entry: WitnessEntry) -> Result<()> {
        self.count_entry(&entry);
        self.flush_entries(vec![entry])?;
        if self.async_config.fsync_critical {
            self.fsync()?;
        }
        Ok(())
    }

    /// Best-effort fsync of the path backing the log. Open/sync failures
    /// are deliberately ignored — durability here is advisory.
    ///
    /// Fix: this body was previously compiled only under
    /// `#[cfg(feature = "async-runtime")]`, silently turning
    /// `fsync_critical` into a no-op in sync-only builds even though it
    /// only uses `std::fs`. The gate is removed so the flag works in every
    /// build.
    fn fsync(&self) -> Result<()> {
        use std::fs::OpenOptions;
        if let Ok(file) = OpenOptions::new().read(true).open(&self.storage_path) {
            let _ = file.sync_all();
        }
        Ok(())
    }

    /// Serializes each entry's routing/quality/latency data into metadata
    /// and inserts it into the vector DB keyed by `request_id`, with the
    /// query embedding as the searchable vector.
    ///
    /// `context_doc_ids` and `response_embedding` are not persisted here.
    ///
    /// # Errors
    /// Propagates the first DB insert failure; later entries in the batch
    /// are then not written.
    fn flush_entries(&self, entries: Vec<WitnessEntry>) -> Result<()> {
        for entry in entries {
            let mut metadata = HashMap::new();
            metadata.insert(
                "request_id".to_string(),
                serde_json::json!(entry.request_id.to_string()),
            );
            metadata.insert(
                "session_id".to_string(),
                serde_json::json!(entry.session_id),
            );
            metadata.insert(
                "model_used".to_string(),
                serde_json::to_value(&entry.model_used).unwrap_or_default(),
            );
            metadata.insert(
                "quality_score".to_string(),
                serde_json::json!(entry.quality_score),
            );
            metadata.insert(
                "routing_decision".to_string(),
                serde_json::to_value(&entry.routing_decision).unwrap_or_default(),
            );
            metadata.insert(
                "latency".to_string(),
                serde_json::to_value(&entry.latency).unwrap_or_default(),
            );
            metadata.insert(
                "timestamp".to_string(),
                serde_json::json!(entry.timestamp.to_rfc3339()),
            );
            metadata.insert(
                "is_success".to_string(),
                serde_json::json!(entry.is_success()),
            );
            metadata.insert("tags".to_string(), serde_json::json!(entry.tags));
            // Optional payloads are only stored when present, so their
            // absence can be distinguished from an empty value.
            if let Some(error) = &entry.error {
                metadata.insert(
                    "error".to_string(),
                    serde_json::to_value(error).unwrap_or_default(),
                );
            }
            if let Some(qm) = &entry.quality_metrics {
                metadata.insert(
                    "quality_metrics".to_string(),
                    serde_json::to_value(qm).unwrap_or_default(),
                );
            }
            let vector_entry = VectorEntry {
                id: Some(entry.request_id.to_string()),
                vector: entry.query_embedding,
                metadata: Some(metadata),
            };
            self.db
                .insert(vector_entry)
                .map_err(|e| RuvLLMError::Storage(e.to_string()))?;
        }
        Ok(())
    }

    /// Synchronously writes out everything currently queued.
    pub fn flush(&self) -> Result<()> {
        let mut queue = self.writeback_queue.lock();
        if !queue.is_empty() {
            let entries = queue.drain();
            // Release the lock before the potentially slow DB writes.
            drop(queue);
            self.flush_entries(entries)?;
        }
        Ok(())
    }

    /// K-nearest-neighbor search over recorded entries by query embedding.
    ///
    /// Entries are reconstructed from stored metadata; the caller's search
    /// vector is substituted for each result's `query_embedding`, and
    /// results lacking usable metadata are silently skipped.
    pub fn search(&self, query_embedding: &[f32], limit: usize) -> Result<Vec<WitnessEntry>> {
        let query = SearchQuery {
            vector: query_embedding.to_vec(),
            k: limit,
            filter: None,
            ef_search: None,
        };
        let results = self
            .db
            .search(query)
            .map_err(|e| RuvLLMError::Storage(e.to_string()))?;
        let mut entries = Vec::with_capacity(results.len());
        for result in results {
            if let Some(metadata) = &result.metadata {
                if let Some(entry) = self.entry_from_metadata(&result.id, query_embedding, metadata)
                {
                    entries.push(entry);
                }
            }
        }
        Ok(entries)
    }

    /// Snapshot of counters and queue state.
    ///
    /// Counters are loaded individually, so the snapshot is not atomic
    /// across fields.
    pub fn stats(&self) -> WitnessLogStats {
        let total = self.total_entries.load(Ordering::SeqCst);
        let success = self.success_count.load(Ordering::SeqCst);
        let errors = self.error_count.load(Ordering::SeqCst);
        let queue = self.writeback_queue.lock();
        WitnessLogStats {
            total_entries: total,
            success_count: success,
            error_count: errors,
            success_rate: if total > 0 {
                success as f32 / total as f32
            } else {
                0.0
            },
            pending_writes: queue.len(),
            dropped_entries: queue.dropped_count(),
            background_running: self.background_running.load(Ordering::SeqCst),
        }
    }

    /// The batching configuration this log was created with.
    pub fn async_config(&self) -> &AsyncWriteConfig {
        &self.async_config
    }

    /// True when at least one entry has been dropped by backpressure.
    pub fn has_dropped_entries(&self) -> bool {
        self.writeback_queue.lock().dropped_count() > 0
    }

    /// Rebuilds a [`WitnessEntry`] from stored metadata.
    ///
    /// Returns `None` only when `request_id` or `session_id` is missing or
    /// malformed; every other field falls back to its default. `embedding`
    /// is the caller-supplied search vector, not the stored one, and
    /// `context_doc_ids` / `response_embedding` are not round-tripped.
    fn entry_from_metadata(
        &self,
        _id: &str,
        embedding: &[f32],
        metadata: &HashMap<String, serde_json::Value>,
    ) -> Option<WitnessEntry> {
        let request_id = metadata
            .get("request_id")
            .and_then(|v| v.as_str())
            .and_then(|s| Uuid::parse_str(s).ok())?;
        let session_id = metadata
            .get("session_id")
            .and_then(|v| v.as_str())?
            .to_string();
        let model_used: ModelSize = metadata
            .get("model_used")
            .and_then(|v| serde_json::from_value(v.clone()).ok())
            .unwrap_or_default();
        let quality_score = metadata
            .get("quality_score")
            .and_then(|v| v.as_f64())
            .unwrap_or(0.0) as f32;
        let routing_decision: RoutingDecision = metadata
            .get("routing_decision")
            .and_then(|v| serde_json::from_value(v.clone()).ok())
            .unwrap_or_default();
        let latency: LatencyBreakdown = metadata
            .get("latency")
            .and_then(|v| serde_json::from_value(v.clone()).ok())
            .unwrap_or_default();
        let timestamp = metadata
            .get("timestamp")
            .and_then(|v| v.as_str())
            .and_then(|s| DateTime::parse_from_rfc3339(s).ok())
            .map(|dt| dt.with_timezone(&Utc))
            .unwrap_or_else(Utc::now);
        let error: Option<ErrorInfo> = metadata
            .get("error")
            .and_then(|v| serde_json::from_value(v.clone()).ok());
        let quality_metrics: Option<QualityMetrics> = metadata
            .get("quality_metrics")
            .and_then(|v| serde_json::from_value(v.clone()).ok());
        let tags: Vec<String> = metadata
            .get("tags")
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(String::from))
                    .collect()
            })
            .unwrap_or_default();
        Some(WitnessEntry {
            request_id,
            session_id,
            query_embedding: embedding.to_vec(),
            routing_decision,
            model_used,
            quality_score,
            latency,
            context_doc_ids: Vec::new(),
            response_embedding: Vec::new(),
            timestamp,
            error,
            quality_metrics,
            tags,
        })
    }
}
/// Point-in-time snapshot of witness-log counters and queue state.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WitnessLogStats {
    /// All entries ever submitted (including ones dropped by backpressure).
    pub total_entries: usize,
    /// Entries submitted without an error attached.
    pub success_count: usize,
    /// Entries submitted with an error attached.
    pub error_count: usize,
    /// `success_count / total_entries`, or 0.0 when nothing was recorded.
    pub success_rate: f32,
    /// Entries currently queued but not yet written to the DB.
    pub pending_writes: usize,
    /// Entries dropped by backpressure since the log was created.
    pub dropped_entries: usize,
    /// Whether the background flush task is currently running.
    pub background_running: bool,
}
#[cfg(feature = "async-runtime")]
impl WitnessLog {
    /// Spawns the background flush task on the current tokio runtime.
    ///
    /// Idempotent: the `swap` on `background_running` makes a second call
    /// while the task is alive a no-op. The task flushes on a fixed tick,
    /// on explicit wake-ups from `record`/`record_batch`, and does a final
    /// drain-everything flush on shutdown.
    pub fn start_background_flush(self: &Arc<Self>) {
        if self.background_running.swap(true, Ordering::SeqCst) {
            return;
        }
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
        *self.shutdown_tx.lock() = Some(shutdown_tx);
        let log = Arc::clone(self);
        let flush_interval = Duration::from_millis(self.async_config.flush_interval_ms);
        tokio::spawn(async move {
            let mut ticker = interval(flush_interval);
            loop {
                tokio::select! {
                    // Periodic wake-up: flush if the queue says one is due.
                    _ = ticker.tick() => {
                        log.flush_if_needed_internal();
                    }
                    // Explicit wake-up after new entries were queued.
                    _ = log.flush_notify.notified() => {
                        log.flush_if_needed_internal();
                    }
                    // Shutdown: drain everything, then mark the task stopped.
                    _ = &mut shutdown_rx => {
                        if let Err(e) = log.flush() {
                            tracing::error!("Error during final witness log flush: {}", e);
                        }
                        log.background_running.store(false, Ordering::SeqCst);
                        break;
                    }
                }
            }
        });
    }
    /// Signals the background task to shut down (triggering its final
    /// flush) and waits briefly.
    ///
    /// NOTE(review): this sleeps a fixed 100 ms rather than awaiting task
    /// completion, so on a heavily loaded runtime the final flush may
    /// still be in flight when this returns.
    pub async fn stop_background_flush(&self) {
        if !self.background_running.load(Ordering::SeqCst) {
            return;
        }
        if let Some(tx) = self.shutdown_tx.lock().take() {
            let _ = tx.send(());
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
    /// Async-friendly wrapper around [`WitnessLog::record`].
    ///
    /// NOTE(review): this delegates to the synchronous `record` and never
    /// yields; a synchronous flush inside `record` runs on this task.
    pub async fn record_async(&self, entry: WitnessEntry) -> Result<()> {
        self.record(entry)
    }
    /// Drains and writes everything queued, regardless of batch/timeout
    /// state. The lock is released before the DB writes happen.
    pub async fn flush_async(&self) -> Result<()> {
        let queue = Arc::clone(&self.writeback_queue);
        let entries = {
            let mut q = queue.lock();
            if q.is_empty() {
                return Ok(());
            }
            q.drain()
        };
        self.flush_entries(entries)
    }
    /// Drains and writes the queue only when it reports a flush is due
    /// (batch size reached or max wait elapsed). Errors are logged, not
    /// propagated — this runs on the background task.
    fn flush_if_needed_internal(&self) {
        let entries = {
            let mut queue = self.writeback_queue.lock();
            if queue.should_flush() {
                queue.drain()
            } else {
                return;
            }
        };
        if let Err(e) = self.flush_entries(entries) {
            tracing::error!("Background witness log flush failed: {}", e);
        }
    }
    /// Queues many entries at once, returning how many were accepted.
    ///
    /// Entries rejected by backpressure are counted in the stats but do
    /// not produce an error here (unlike `record`). The background task is
    /// notified once after the whole batch.
    pub async fn record_batch(&self, entries: Vec<WitnessEntry>) -> Result<usize> {
        let mut accepted = 0;
        for entry in entries {
            self.total_entries.fetch_add(1, Ordering::SeqCst);
            if entry.is_success() {
                self.success_count.fetch_add(1, Ordering::SeqCst);
            } else {
                self.error_count.fetch_add(1, Ordering::SeqCst);
            }
            let mut queue = self.writeback_queue.lock();
            if queue.push(entry) {
                accepted += 1;
            }
        }
        self.flush_notify.notify_one();
        Ok(accepted)
    }
    /// Snapshot of counters and queue state.
    ///
    /// NOTE(review): duplicates the sync `stats`; kept identical so the
    /// two cannot drift apart behaviorally.
    pub fn stats_async(&self) -> WitnessLogStats {
        let total = self.total_entries.load(Ordering::SeqCst);
        let success = self.success_count.load(Ordering::SeqCst);
        let errors = self.error_count.load(Ordering::SeqCst);
        let queue = self.writeback_queue.lock();
        WitnessLogStats {
            total_entries: total,
            success_count: success,
            error_count: errors,
            success_rate: if total > 0 {
                success as f32 / total as f32
            } else {
                0.0
            },
            pending_writes: queue.len(),
            dropped_entries: queue.dropped_count(),
            background_running: self.background_running.load(Ordering::SeqCst),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Totals sum all five stages; the slowest stage wins by duration.
    #[test]
    fn test_latency_breakdown() {
        let mut latency = LatencyBreakdown {
            embedding_ms: 10.0,
            retrieval_ms: 5.0,
            routing_ms: 2.0,
            attention_ms: 50.0,
            generation_ms: 100.0,
            total_ms: 0.0,
        };
        latency.compute_total();
        assert_eq!(latency.total_ms, 167.0);
        let (name, _) = latency.slowest_component();
        assert_eq!(name, "generation");
    }
    // A fresh entry is a success with score 0.0; with_quality sets the score.
    #[test]
    fn test_witness_entry() {
        let entry = WitnessEntry::new(
            "session-1".to_string(),
            vec![0.1; 768],
            RoutingDecision::default(),
        );
        assert!(entry.is_success());
        assert!(!entry.meets_quality_threshold(0.5));
        let entry = entry.with_quality(0.8);
        assert!(entry.meets_quality_threshold(0.5));
    }
    // Default routing decision targets the small model with temperature 0.7.
    #[test]
    fn test_routing_decision() {
        let decision = RoutingDecision::default();
        assert_eq!(decision.model, ModelSize::Small);
        assert_eq!(decision.temperature, 0.7);
    }
    // Pins the documented default batching configuration.
    #[test]
    fn test_async_write_config_default() {
        let config = AsyncWriteConfig::default();
        assert_eq!(config.max_batch_size, 100);
        assert_eq!(config.max_wait_ms, 1000);
        assert_eq!(config.max_queue_depth, 10000);
        assert!(!config.fsync_critical);
        assert_eq!(config.flush_interval_ms, 1000);
    }
    // should_flush only trips once a full batch (5 entries) accumulates;
    // drain empties the queue.
    #[test]
    fn test_writeback_queue_batching() {
        let config = AsyncWriteConfig {
            max_batch_size: 5,
            max_wait_ms: 1000,
            max_queue_depth: 100,
            fsync_critical: false,
            flush_interval_ms: 1000,
        };
        let mut queue = WritebackQueue::new(config);
        assert!(!queue.should_flush());
        assert!(queue.is_empty());
        // Four entries: still under the batch size.
        for i in 0..4 {
            let entry = WitnessEntry::new(
                format!("session-{}", i),
                vec![0.1; 768],
                RoutingDecision::default(),
            );
            assert!(queue.push(entry));
        }
        assert_eq!(queue.len(), 4);
        assert!(!queue.should_flush());
        // The fifth entry completes the batch and makes a flush due.
        let entry = WitnessEntry::new(
            "session-4".to_string(),
            vec![0.1; 768],
            RoutingDecision::default(),
        );
        assert!(queue.push(entry));
        assert!(queue.should_flush());
        let entries = queue.drain();
        assert_eq!(entries.len(), 5);
        assert!(queue.is_empty());
    }
    // Pushes beyond max_queue_depth are rejected and counted as drops.
    #[test]
    fn test_writeback_queue_backpressure() {
        let config = AsyncWriteConfig {
            max_batch_size: 5,
            max_wait_ms: 1000,
            max_queue_depth: 10,
            fsync_critical: false,
            flush_interval_ms: 1000,
        };
        let mut queue = WritebackQueue::new(config);
        // Fill to capacity (10 entries).
        for i in 0..10 {
            let entry = WitnessEntry::new(
                format!("session-{}", i),
                vec![0.1; 768],
                RoutingDecision::default(),
            );
            assert!(queue.push(entry), "Entry {} should be accepted", i);
        }
        let entry = WitnessEntry::new(
            "session-overflow".to_string(),
            vec![0.1; 768],
            RoutingDecision::default(),
        );
        assert!(
            !queue.push(entry),
            "Entry should be dropped due to backpressure"
        );
        assert_eq!(queue.dropped_count(), 1);
        let entry2 = WitnessEntry::new(
            "session-overflow-2".to_string(),
            vec![0.1; 768],
            RoutingDecision::default(),
        );
        assert!(!queue.push(entry2));
        assert_eq!(queue.dropped_count(), 2);
    }
    // Counters reflect recorded entries; no background task is running.
    #[test]
    fn test_witness_log_stats() {
        let config = AsyncWriteConfig {
            max_batch_size: 100,
            max_wait_ms: 1000,
            max_queue_depth: 5,
            fsync_critical: false,
            flush_interval_ms: 1000,
        };
        let temp_dir = tempfile::tempdir().unwrap();
        let storage_path = temp_dir.path().join("witness_test");
        let log = WitnessLog::with_config(storage_path.to_str().unwrap(), 64, config).unwrap();
        for i in 0..3 {
            let entry = WitnessEntry::new(
                format!("session-{}", i),
                vec![0.1; 64],
                RoutingDecision::default(),
            );
            log.record(entry).unwrap();
        }
        let stats = log.stats();
        assert_eq!(stats.total_entries, 3);
        assert_eq!(stats.success_count, 3);
        assert_eq!(stats.error_count, 0);
        assert!(!stats.background_running);
    }
    #[cfg(feature = "async-runtime")]
    mod async_tests {
        use super::*;
        use std::sync::Arc;
        // The background task should drain the queue within a few ticks
        // and report not-running after shutdown.
        #[tokio::test]
        async fn test_background_flush_task() {
            let config = AsyncWriteConfig {
                max_batch_size: 5,
                max_wait_ms: 100,
                max_queue_depth: 1000,
                fsync_critical: false,
                flush_interval_ms: 50,
            };
            let temp_dir = tempfile::tempdir().unwrap();
            let storage_path = temp_dir.path().join("async_witness_test");
            let log = Arc::new(
                WitnessLog::with_config(storage_path.to_str().unwrap(), 64, config).unwrap(),
            );
            log.start_background_flush();
            let stats = log.stats_async();
            assert!(stats.background_running);
            for i in 0..10 {
                let entry = WitnessEntry::new(
                    format!("async-session-{}", i),
                    vec![0.1; 64],
                    RoutingDecision::default(),
                );
                log.record_async(entry).await.unwrap();
            }
            // Give the 50 ms ticker time to fire; timing-dependent, hence
            // the loose "< 10" assertion rather than an exact count.
            tokio::time::sleep(Duration::from_millis(200)).await;
            let stats = log.stats_async();
            assert!(stats.pending_writes < 10);
            log.stop_background_flush().await;
            let stats = log.stats_async();
            assert!(!stats.background_running);
        }
        // record_batch accepts everything under capacity and counts totals.
        #[tokio::test]
        async fn test_record_batch() {
            let temp_dir = tempfile::tempdir().unwrap();
            let storage_path = temp_dir.path().join("batch_witness_test");
            let log = Arc::new(WitnessLog::new(storage_path.to_str().unwrap(), 64).unwrap());
            log.start_background_flush();
            let entries: Vec<_> = (0..50)
                .map(|i| {
                    WitnessEntry::new(
                        format!("batch-session-{}", i),
                        vec![0.1; 64],
                        RoutingDecision::default(),
                    )
                })
                .collect();
            let accepted = log.record_batch(entries).await.unwrap();
            assert_eq!(accepted, 50);
            let stats = log.stats_async();
            assert_eq!(stats.total_entries, 50);
            log.stop_background_flush().await;
        }
        // flush_async drains the queue even when no flush is "due".
        #[tokio::test]
        async fn test_flush_async() {
            let temp_dir = tempfile::tempdir().unwrap();
            let storage_path = temp_dir.path().join("flush_async_test");
            let log = WitnessLog::new(storage_path.to_str().unwrap(), 64).unwrap();
            for i in 0..5 {
                let entry = WitnessEntry::new(
                    format!("flush-session-{}", i),
                    vec![0.1; 64],
                    RoutingDecision::default(),
                );
                log.record(entry).unwrap();
            }
            log.flush_async().await.unwrap();
            let stats = log.stats();
            assert_eq!(stats.pending_writes, 0);
        }
    }
}