use std::time::{Duration, Instant};
use tacet_core::adaptive::AdaptiveState as CoreAdaptiveState;
use tacet_core::adaptive::{CalibrationSnapshot, Posterior};
use tacet_core::statistics::StatsSnapshot;
pub struct AdaptiveState {
core: CoreAdaptiveState,
start_time: Instant,
}
impl AdaptiveState {
pub fn new() -> Self {
Self {
core: CoreAdaptiveState::new(),
start_time: Instant::now(),
}
}
pub fn with_capacity(expected_samples: usize) -> Self {
Self {
core: CoreAdaptiveState::with_capacity(expected_samples),
start_time: Instant::now(),
}
}
pub fn core(&self) -> &CoreAdaptiveState {
&self.core
}
pub fn core_mut(&mut self) -> &mut CoreAdaptiveState {
&mut self.core
}
pub fn n_total(&self) -> usize {
self.core.n_total()
}
pub fn elapsed(&self) -> Duration {
self.start_time.elapsed()
}
pub fn reset_start_time(&mut self) {
self.start_time = Instant::now();
}
pub fn batch_count(&self) -> usize {
self.core.batch_count
}
pub fn baseline_samples(&self) -> &[u64] {
&self.core.baseline_samples
}
pub fn sample_samples(&self) -> &[u64] {
&self.core.sample_samples
}
pub fn add_batch(&mut self, baseline: Vec<u64>, sample: Vec<u64>) {
self.core.add_batch(baseline, sample);
}
pub fn add_batch_with_conversion(
&mut self,
baseline: Vec<u64>,
sample: Vec<u64>,
ns_per_tick: f64,
) {
self.core
.add_batch_with_conversion(baseline, sample, ns_per_tick);
}
pub fn update_kl(&mut self, kl: f64) {
self.core.update_kl(kl);
}
pub fn recent_kl_sum(&self) -> f64 {
self.core.recent_kl_sum()
}
pub fn has_kl_history(&self) -> bool {
self.core.has_kl_history()
}
pub fn update_posterior(&mut self, new_posterior: Posterior) -> f64 {
self.core.update_posterior(new_posterior)
}
pub fn current_posterior(&self) -> Option<&Posterior> {
self.core.current_posterior()
}
pub fn baseline_ns(&self, ns_per_tick: f64) -> Vec<f64> {
self.core.baseline_ns(ns_per_tick)
}
pub fn sample_ns(&self, ns_per_tick: f64) -> Vec<f64> {
self.core.sample_ns(ns_per_tick)
}
pub fn baseline_stats(&self) -> Option<StatsSnapshot> {
self.core.baseline_stats()
}
pub fn sample_stats(&self) -> Option<StatsSnapshot> {
self.core.sample_stats()
}
pub fn get_stats_snapshot(&self) -> Option<CalibrationSnapshot> {
self.core.get_stats_snapshot()
}
pub fn has_stats_tracking(&self) -> bool {
self.core.has_stats_tracking()
}
pub fn reset(&mut self) {
self.core.reset();
self.start_time = Instant::now();
}
}
impl Default for AdaptiveState {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{Matrix9, Vector9};
#[test]
fn test_adaptive_state_new() {
let state = AdaptiveState::new();
assert_eq!(state.n_total(), 0);
assert_eq!(state.batch_count(), 0);
assert!(state.current_posterior().is_none());
assert!(!state.has_kl_history());
}
#[test]
fn test_add_batch() {
let mut state = AdaptiveState::new();
state.add_batch(vec![100, 101, 102], vec![200, 201, 202]);
assert_eq!(state.n_total(), 3);
assert_eq!(state.batch_count(), 1);
assert_eq!(state.baseline_samples(), &[100, 101, 102]);
assert_eq!(state.sample_samples(), &[200, 201, 202]);
}
#[test]
fn test_elapsed() {
let state = AdaptiveState::new();
std::thread::sleep(Duration::from_millis(10));
let elapsed = state.elapsed();
assert!(elapsed >= Duration::from_millis(10));
}
#[test]
fn test_kl_history() {
let mut state = AdaptiveState::new();
for i in 0..5 {
state.update_kl(0.1 * (i + 1) as f64);
}
assert!(state.has_kl_history());
assert!((state.recent_kl_sum() - 1.5).abs() < 1e-10);
state.update_kl(1.0);
assert!((state.recent_kl_sum() - 2.4).abs() < 1e-10);
}
#[test]
fn test_posterior_update() {
let mut state = AdaptiveState::new();
let posterior = Posterior::new(
Vector9::zeros(), Matrix9::identity(), Vec::new(), 0.75, 100.0, 1000, );
let kl = state.update_posterior(posterior.clone());
assert_eq!(kl, 0.0);
assert!(state.current_posterior().is_some());
}
#[test]
fn test_add_batch_with_conversion() {
let mut state = AdaptiveState::new();
state.add_batch_with_conversion(vec![100, 110, 120], vec![200, 210, 220], 2.0);
assert_eq!(state.n_total(), 3);
assert_eq!(state.batch_count(), 1);
assert!(state.has_stats_tracking());
}
#[test]
fn test_online_stats_tracking() {
let mut state = AdaptiveState::new();
let baseline: Vec<u64> = (0..100).map(|i| 1000 + (i % 10)).collect();
let sample: Vec<u64> = (0..100).map(|i| 1100 + (i % 10)).collect();
state.add_batch_with_conversion(baseline, sample, 1.0);
let baseline_stats = state.baseline_stats().expect("Should have baseline stats");
assert_eq!(baseline_stats.count, 100);
assert!(
(baseline_stats.mean - 1004.5).abs() < 1.0,
"Baseline mean {} should be near 1004.5",
baseline_stats.mean
);
let sample_stats = state.sample_stats().expect("Should have sample stats");
assert_eq!(sample_stats.count, 100);
assert!(
(sample_stats.mean - 1104.5).abs() < 1.0,
"Sample mean {} should be near 1104.5",
sample_stats.mean
);
}
#[test]
fn test_reset() {
let mut state = AdaptiveState::new();
state.add_batch_with_conversion(vec![100, 110], vec![200, 210], 1.0);
state.update_kl(0.5);
let posterior = Posterior::new(
Vector9::zeros(), Matrix9::identity(), Vec::new(), 0.75, 100.0, 100, );
state.update_posterior(posterior);
assert!(state.n_total() > 0);
state.reset();
assert_eq!(state.n_total(), 0);
assert_eq!(state.batch_count(), 0);
assert!(state.current_posterior().is_none());
assert!(!state.has_kl_history());
assert!(!state.has_stats_tracking());
}
#[test]
fn test_stats_not_tracked_without_conversion() {
let mut state = AdaptiveState::new();
state.add_batch(vec![100, 110, 120], vec![200, 210, 220]);
assert!(!state.has_stats_tracking());
assert!(state.baseline_stats().is_none());
assert!(state.sample_stats().is_none());
}
}