#[cfg(feature = "alloc")]
use alloc::{string::String, vec::Vec};
use hashbrown::HashMap;
use serde::{Deserialize, Serialize};
/// Entry cap used by [`QueryLog::new`] (and therefore `QueryLog::default()`): the
/// number of most recent routing decisions retained before the oldest are evicted.
pub const DEFAULT_MAX_LOG_SIZE: usize = 10_000;
/// One routing decision recorded by [`QueryLog::record_routing`].
///
/// Entries are created without an outcome; [`QueryLog::record_outcome`] later fills
/// `outcome` on the most recent matching entry that does not have one yet.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingLogEntry {
    /// Caller-supplied identifier of the query that was routed.
    pub query_id: u64,
    /// Identifier of the source the query was routed to.
    pub source_id: String,
    /// Confidence value supplied by the caller; stored as-is, not read by `QueryLog`.
    pub confidence: f32,
    /// Caller-supplied flag; presumably whether an ML model made this routing
    /// decision — NOTE(review): confirm against callers.
    pub ml_used: bool,
    /// Milliseconds since the UNIX epoch at record time, or `0` when no clock is
    /// available (no `std` feature, `wasm32`, or a system clock before the epoch).
    pub timestamp_ms: u64,
    /// Outcome attached by `QueryLog::record_outcome`, or `None` until one arrives.
    pub outcome: Option<RoutingOutcome>,
    /// Optional feature vector captured at routing time, retrievable via
    /// [`QueryLog::find_entry_features`]; deserializes to `None` when absent.
    #[serde(default)]
    pub feature_vector: Option<Vec<f32>>,
}
/// Reported result of a routed query.
///
/// It is attached to a [`RoutingLogEntry`] and folded into the per-source
/// [`SourceLogStats`] by `QueryLog::record_outcome`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingOutcome {
    /// Whether the source handled the query successfully; counted in `successes`.
    pub success: bool,
    /// Reported latency (milliseconds, per the field name); averaged by
    /// `SourceLogStats::avg_latency_ms`.
    pub latency_ms: u32,
    /// Number of results returned; stored on the entry only, not aggregated.
    pub result_count: u32,
    /// Reward signal for this decision; averaged by `SourceLogStats::avg_reward`.
    /// NOTE(review): scoring defaults to 0.5 with no data, suggesting a 0..=1
    /// scale — confirm with producers of this value.
    pub reward: f32,
}
/// Running per-source aggregates maintained by [`QueryLog`].
///
/// `total_routed` counts routing decisions. The remaining counters only cover
/// decisions for which an outcome was reported.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SourceLogStats {
    /// Number of times a query was routed to this source.
    pub total_routed: u64,
    /// Number of reported outcomes with `success == true`.
    pub successes: u64,
    /// Sum of reported latencies; exposed only through `avg_latency_ms`.
    latency_sum: f64,
    /// Sum of reported rewards; exposed only through `avg_reward`.
    reward_sum: f64,
    /// Number of outcomes folded in via [`SourceLogStats::record`].
    pub outcomes_received: u64,
}
impl SourceLogStats {
    /// Folds one reported outcome into the running totals.
    ///
    /// `total_routed` is not touched here; `QueryLog::record_routing` maintains it.
    pub fn record(&mut self, outcome: &RoutingOutcome) {
        self.successes += u64::from(outcome.success);
        self.latency_sum += f64::from(outcome.latency_ms);
        self.reward_sum += f64::from(outcome.reward);
        self.outcomes_received += 1;
    }

    /// Mean reported latency, or `0.0` when no outcome has been recorded.
    #[must_use]
    pub fn avg_latency_ms(&self) -> f64 {
        match self.outcomes_received {
            0 => 0.0,
            count => self.latency_sum / count as f64,
        }
    }

    /// Mean reported reward, or `0.5` when no outcome has been recorded.
    #[must_use]
    pub fn avg_reward(&self) -> f64 {
        match self.outcomes_received {
            0 => 0.5,
            count => self.reward_sum / count as f64,
        }
    }

    /// Fraction of reported outcomes that succeeded, or `1.0` when no outcome has
    /// been recorded.
    #[must_use]
    pub fn success_rate(&self) -> f64 {
        match self.outcomes_received {
            0 => 1.0,
            count => self.successes as f64 / count as f64,
        }
    }
}
/// Log of recent routing decisions plus per-source statistics.
///
/// `record_routing` caps `entries` at `max_size`, evicting the oldest first.
/// `source_stats` keeps aggregates for every source routed to. The aggregates
/// survive entry eviction and [`QueryLog::clear_entries`], and only
/// [`QueryLog::clear_all`] drops them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryLog {
    /// Retained routing decisions in insertion order, oldest first.
    entries: Vec<RoutingLogEntry>,
    /// Maximum number of retained entries, enforced by `record_routing`.
    max_size: usize,
    /// Aggregates keyed by `source_id`.
    source_stats: HashMap<String, SourceLogStats>,
    /// Total routing decisions recorded.
    ///
    /// Eviction and `clear_entries` leave it unchanged; `clear_all` resets it.
    pub total_recorded: u64,
}
impl QueryLog {
    /// Creates an empty log that retains up to [`DEFAULT_MAX_LOG_SIZE`] entries.
    #[must_use]
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_MAX_LOG_SIZE)
    }

    /// Creates an empty log that retains up to `max_size` entries.
    ///
    /// At most 1024 slots are preallocated, so a large cap does not reserve memory
    /// up front. A `max_size` of `0` stores no entries at all, but per-source
    /// statistics and `total_recorded` are still maintained.
    #[must_use]
    pub fn with_capacity(max_size: usize) -> Self {
        Self {
            entries: Vec::with_capacity(max_size.min(1024)),
            max_size,
            source_stats: HashMap::new(),
            total_recorded: 0,
        }
    }

    /// Records that query `query_id` was routed to `source_id`.
    ///
    /// This increments the source's `total_routed` counter and `total_recorded`. It
    /// then appends a new entry with no outcome, timestamped with the current clock.
    /// The oldest entries are evicted first so the log never exceeds `max_size`.
    pub fn record_routing(
        &mut self,
        query_id: u64,
        source_id: impl Into<String>,
        confidence: f32,
        ml_used: bool,
        feature_vector: Option<Vec<f32>>,
    ) {
        let source_id = source_id.into();
        self.source_stats
            .entry(source_id.clone())
            .or_default()
            .total_routed += 1;
        self.total_recorded += 1;

        // A zero-capacity log only keeps statistics. Without this guard the eviction
        // step would try to remove from an empty Vec and panic.
        if self.max_size == 0 {
            return;
        }

        // Make room for exactly one more entry. This is normally a single eviction,
        // but a deserialized log may already be over its cap. (Front removal shifts
        // the remaining entries; a Vec is kept because `recent_entries` returns a
        // contiguous slice.)
        let excess = (self.entries.len() + 1).saturating_sub(self.max_size);
        if excess > 0 {
            self.entries.drain(..excess);
        }

        self.entries.push(RoutingLogEntry {
            query_id,
            source_id,
            confidence,
            ml_used,
            timestamp_ms: get_time_ms(),
            outcome: None,
            feature_vector,
        });
    }

    /// Returns a copy of the feature vector stored on the most recent retained entry
    /// for (`query_id`, `source_id`).
    ///
    /// Only that most recent matching entry is consulted. `None` means no such entry
    /// is retained, or it was recorded without a feature vector.
    #[must_use]
    pub fn find_entry_features(&self, query_id: u64, source_id: &str) -> Option<Vec<f32>> {
        self.entries
            .iter()
            .rev()
            .find(|e| e.query_id == query_id && e.source_id == source_id)
            .and_then(|e| e.feature_vector.clone())
    }

    /// Records the outcome of a previously routed query.
    ///
    /// The source's statistics are updated whenever the source has been routed to
    /// before. This holds even if the matching entry was already evicted or already
    /// has an outcome, so repeated reports are each counted. Outcomes for unknown
    /// sources are ignored.
    ///
    /// The outcome is also attached to the most recent retained entry for
    /// (`query_id`, `source_id`) that has no outcome yet, if any.
    pub fn record_outcome(
        &mut self,
        query_id: u64,
        source_id: &str,
        success: bool,
        latency_ms: u32,
        result_count: u32,
        reward: f32,
    ) {
        let outcome = RoutingOutcome {
            success,
            latency_ms,
            result_count,
            reward,
        };
        if let Some(stats) = self.source_stats.get_mut(source_id) {
            stats.record(&outcome);
        }
        if let Some(entry) = self
            .entries
            .iter_mut()
            .rev()
            .find(|e| e.query_id == query_id && e.source_id == source_id && e.outcome.is_none())
        {
            entry.outcome = Some(outcome);
        }
    }

    /// Returns the aggregates for `source_id`, if it has ever been routed to.
    #[must_use]
    pub fn source_stats(&self, source_id: &str) -> Option<&SourceLogStats> {
        self.source_stats.get(source_id)
    }

    /// Blended score for `source_id`: 60% success rate plus 40% average reward.
    ///
    /// Returns `None` for unknown sources and for sources with no reported outcome.
    #[must_use]
    pub fn routing_score(&self, source_id: &str) -> Option<f32> {
        let stats = self.source_stats.get(source_id)?;
        if stats.outcomes_received == 0 {
            return None;
        }
        Some(Self::blended_score(stats) as f32)
    }

    /// Source with the highest blended score among those with at least one outcome.
    ///
    /// NaN scores (e.g. from a NaN reward) rank below every real score.
    #[must_use]
    pub fn best_source(&self) -> Option<&str> {
        self.source_stats
            .iter()
            .filter(|(_, s)| s.outcomes_received > 0)
            .map(|(id, s)| (id, Self::blended_score(s)))
            .max_by(|(_, a), (_, b)| Self::cmp_scores(*a, *b))
            .map(|(id, _)| id.as_str())
    }

    /// All sources with at least one outcome, with their blended scores, best first.
    ///
    /// NaN scores sort last.
    #[must_use]
    pub fn ranked_sources(&self) -> Vec<(String, f32)> {
        let mut ranked: Vec<(String, f32)> = self
            .source_stats
            .iter()
            .filter(|(_, s)| s.outcomes_received > 0)
            .map(|(id, s)| (id.clone(), Self::blended_score(s) as f32))
            .collect();
        // f32 -> f64 is lossless, so this matches ordering on the f32 scores.
        ranked.sort_by(|a, b| Self::cmp_scores(f64::from(b.1), f64::from(a.1)));
        ranked
    }

    /// Returns the last `limit` retained entries, oldest first.
    ///
    /// If fewer than `limit` entries are retained, all of them are returned.
    #[must_use]
    pub fn recent_entries(&self, limit: usize) -> &[RoutingLogEntry] {
        &self.entries[self.entries.len().saturating_sub(limit)..]
    }

    /// Number of retained entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether no entries are retained.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Drops all retained entries.
    ///
    /// Per-source statistics and `total_recorded` are kept.
    pub fn clear_entries(&mut self) {
        self.entries.clear();
    }

    /// Drops entries and statistics and resets `total_recorded`.
    ///
    /// `max_size` is kept.
    pub fn clear_all(&mut self) {
        self.entries.clear();
        self.source_stats.clear();
        self.total_recorded = 0;
    }

    /// Removes entries older than `max_age_ms` and returns how many were removed.
    ///
    /// An entry's age is measured against the current clock. Entries timestamped in
    /// the future count as age 0 and are kept. Without a clock both `now` and all
    /// timestamps are 0, so nothing is evicted.
    pub fn evict_older_than(&mut self, max_age_ms: u64) -> usize {
        let now = get_time_ms();
        let before = self.entries.len();
        self.entries
            .retain(|e| now.saturating_sub(e.timestamp_ms) <= max_age_ms);
        before - self.entries.len()
    }

    /// Number of distinct sources with recorded statistics.
    #[must_use]
    pub fn tracked_source_count(&self) -> usize {
        self.source_stats.len()
    }

    /// Blends an externally observed `current_rate` with this log's score.
    ///
    /// The result is `0.4 * current_rate + 0.6 * routing_score`. When the log has no
    /// score for `source_id`, `current_rate` is returned unchanged.
    #[must_use]
    pub fn combined_reliability(&self, source_id: &str, current_rate: f32) -> f32 {
        match self.routing_score(source_id) {
            Some(log_score) => 0.4 * current_rate + 0.6 * log_score,
            None => current_rate,
        }
    }

    /// Shared ranking formula: 60% success rate, 40% average reward.
    fn blended_score(stats: &SourceLogStats) -> f64 {
        stats.success_rate() * 0.6 + stats.avg_reward() * 0.4
    }

    /// Total order over scores that places NaN below every real number.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` is not a total order once a NaN appears,
    /// and `slice::sort_by` may panic when given such a comparator.
    fn cmp_scores(a: f64, b: f64) -> core::cmp::Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| b.is_nan().cmp(&a.is_nan()))
    }
}
impl Default for QueryLog {
fn default() -> Self {
Self::new()
}
}
/// Current wall-clock time in milliseconds since the UNIX epoch.
///
/// Returns `0` when no clock is available: when built without the `std` feature,
/// when targeting `wasm32`, or when the system clock reads earlier than the epoch.
fn get_time_ms() -> u64 {
    #[cfg(all(feature = "std", not(target_arch = "wasm32")))]
    {
        use std::time::{SystemTime, UNIX_EPOCH};
        match SystemTime::now().duration_since(UNIX_EPOCH) {
            // u128 -> u64 only truncates after ~584 million years of uptime.
            Ok(since_epoch) => since_epoch.as_millis() as u64,
            Err(_) => 0,
        }
    }
    #[cfg(any(not(feature = "std"), target_arch = "wasm32"))]
    {
        0
    }
}
// Unit tests for `QueryLog` and the per-source statistics it maintains.
#[cfg(test)]
mod tests {
    use super::*;
    // In no_std + alloc builds, `format!` and `vec!` must be imported from `alloc`.
    #[cfg(all(not(feature = "std"), feature = "alloc"))]
    use alloc::{format, vec};

    // Each `record_routing` call appends its own entry, even for a repeated query id.
    #[test]
    fn test_query_log_basic() {
        let mut log = QueryLog::new();
        assert!(log.is_empty());
        log.record_routing(1, "src1", 0.9, false, None);
        log.record_routing(1, "src2", 0.7, false, None);
        assert_eq!(log.len(), 2);
    }

    // A reported outcome is folded into the source's aggregates.
    #[test]
    fn test_record_outcome() {
        let mut log = QueryLog::new();
        log.record_routing(42, "src1", 0.8, false, None);
        log.record_outcome(42, "src1", true, 100, 50, 0.9);
        let stats = log.source_stats("src1").unwrap();
        assert_eq!(stats.outcomes_received, 1);
        assert_eq!(stats.successes, 1);
    }

    // "good" scores 0.6 * 1.0 + 0.4 * 0.9 = 0.96; "bad" scores 0.6 * 0.0 + 0.4 * 0.0 = 0.0.
    #[test]
    fn test_routing_score() {
        let mut log = QueryLog::new();
        for _ in 0..5 {
            log.record_routing(1, "good", 0.9, false, None);
            log.record_outcome(1, "good", true, 100, 50, 0.9);
        }
        for _ in 0..5 {
            log.record_routing(2, "bad", 0.5, false, None);
            log.record_outcome(2, "bad", false, 5000, 0, 0.0);
        }
        let good_score = log.routing_score("good").unwrap();
        let bad_score = log.routing_score("bad").unwrap();
        assert!(good_score > bad_score);
    }

    // Both succeed, so reward decides: "fast" scores 1.0 and "slow" scores 0.76.
    // Latency does not enter the score.
    #[test]
    fn test_best_source() {
        let mut log = QueryLog::new();
        log.record_routing(1, "fast", 0.9, false, None);
        log.record_outcome(1, "fast", true, 50, 100, 1.0);
        log.record_routing(2, "slow", 0.5, false, None);
        log.record_outcome(2, "slow", true, 3000, 10, 0.4);
        assert_eq!(log.best_source(), Some("fast"));
    }

    // Ranking is best-first: "a" (0.96) precedes "b" (0.04).
    #[test]
    fn test_ranked_sources() {
        let mut log = QueryLog::new();
        log.record_routing(1, "a", 0.9, false, None);
        log.record_outcome(1, "a", true, 100, 50, 0.9);
        log.record_routing(2, "b", 0.5, false, None);
        log.record_outcome(2, "b", false, 2000, 0, 0.1);
        let ranked = log.ranked_sources();
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].0, "a");
    }

    // A capacity-3 log given 5 entries evicts the two oldest.
    #[test]
    fn test_eviction() {
        let mut log = QueryLog::with_capacity(3);
        for i in 0..5u64 {
            log.record_routing(i, format!("src{}", i), 0.5, false, None);
        }
        assert_eq!(log.len(), 3);
    }

    // Unknown sources pass the caller's rate through unchanged.
    // For "known" the log score is 0.96, so the blend is 0.4 * 0.8 + 0.6 * 0.96 ≈ 0.896.
    #[test]
    fn test_combined_reliability() {
        let mut log = QueryLog::new();
        let score = log.combined_reliability("unknown", 0.8);
        assert!((score - 0.8).abs() < 0.001);
        log.record_routing(1, "known", 0.9, false, None);
        log.record_outcome(1, "known", true, 100, 50, 0.9);
        let score = log.combined_reliability("known", 0.8);
        assert!(score > 0.7);
    }

    // `clear_entries` keeps per-source stats; `clear_all` drops them too.
    #[test]
    fn test_clear() {
        let mut log = QueryLog::new();
        log.record_routing(1, "src1", 0.8, false, None);
        log.record_outcome(1, "src1", true, 100, 50, 0.9);
        log.clear_entries();
        assert!(log.is_empty());
        assert!(log.source_stats("src1").is_some());
        log.clear_all();
        assert!(log.source_stats("src1").is_none());
    }

    // A feature vector passed at routing time is returned unchanged by lookup.
    #[test]
    fn test_feature_vector_storage_and_retrieval() {
        let mut log = QueryLog::new();
        let fv = vec![0.1_f32, 0.2, 0.3, 0.4];
        log.record_routing(99, "src1", 0.8, true, Some(fv.clone()));
        let retrieved = log.find_entry_features(99, "src1");
        assert_eq!(retrieved, Some(fv));
    }

    // With no feature vector stored on the entry, lookup yields `None`.
    #[test]
    fn test_find_entry_features_none_when_not_ml() {
        let mut log = QueryLog::new();
        log.record_routing(100, "src1", 0.8, false, None);
        let retrieved = log.find_entry_features(100, "src1");
        assert!(retrieved.is_none());
    }
}