use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use tdigest::TDigest;
use super::data_models::PredictionMetrics;
mod tdigest_serde {
use serde::{self, Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;
use tdigest::TDigest;
fn sanitize_nulls(value: &mut Value) {
match value {
Value::Null => *value = Value::Number(serde_json::Number::from_f64(0.0).unwrap()),
Value::Array(arr) => {
for item in arr.iter_mut() {
sanitize_nulls(item);
}
}
Value::Object(map) => {
for v in map.values_mut() {
sanitize_nulls(v);
}
}
_ => {}
}
}
pub fn serialize<S>(digest: &TDigest, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut value = serde_json::to_value(digest).map_err(serde::ser::Error::custom)?;
sanitize_nulls(&mut value);
value.serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<TDigest, D::Error>
where
D: Deserializer<'de>,
{
let mut value = Value::deserialize(deserializer)?;
sanitize_nulls(&mut value);
serde_json::from_value(value).map_err(serde::de::Error::custom)
}
}
/// Streaming summary statistics for a series of `f64` samples.
///
/// `count`, `mean` and `m2` are maintained incrementally by `add_sample` and
/// `merge`; `digest` is queried by `compute_metrics` for quantile estimates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunningStats {
/// Number of samples recorded so far.
pub count: u64,
/// Running mean of the recorded samples.
pub mean: f64,
/// Running sum of squared deviations from the mean (the `M2` term of
/// Welford's online algorithm, as updated in `add_sample`).
pub m2: f64,
/// Quantile sketch of the samples. (De)serialized through `tdigest_serde`,
/// which rewrites `null` entries to `0.0`.
#[serde(with = "tdigest_serde")]
pub digest: TDigest,
}
/// Collection of [`RunningStats`] for four tracked quantities: remaining
/// calls, inter-arrival time, output tokens and sensitivity.
///
/// Each quantity has a `HashMap<u32, RunningStats>` plus an `all_*`
/// counterpart holding a single `RunningStats`.
// NOTE(review): the meaning of the `u32` map key is not visible in this file;
// the `all_*` fields presumably aggregate across all keys — confirm against
// the code that populates this struct.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct NodeAccumulators {
/// Remaining-call statistics, bucketed by `u32` key.
pub remaining_calls: HashMap<u32, RunningStats>,
/// Inter-arrival statistics, bucketed by `u32` key (name suggests milliseconds).
pub interarrival_ms: HashMap<u32, RunningStats>,
/// Output-token statistics, bucketed by `u32` key.
pub output_tokens: HashMap<u32, RunningStats>,
/// Sensitivity statistics, bucketed by `u32` key.
pub sensitivity: HashMap<u32, RunningStats>,
/// Un-bucketed counterpart of `remaining_calls`.
pub all_remaining_calls: RunningStats,
/// Un-bucketed counterpart of `interarrival_ms`.
pub all_interarrival_ms: RunningStats,
/// Un-bucketed counterpart of `output_tokens`.
pub all_output_tokens: RunningStats,
/// Un-bucketed counterpart of `sensitivity`.
pub all_sensitivity: RunningStats,
}
/// Serializable top-level container holding one [`NodeAccumulators`] per
/// `String` key.
// NOTE(review): the key is presumably a node name or identifier — confirm
// against the code that builds this map.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AccumulatorState {
/// Accumulators indexed by `String` key.
pub nodes: HashMap<String, NodeAccumulators>,
}
impl RunningStats {
    /// Creates an empty accumulator: zero samples, zero `mean`/`m2`, and a
    /// fresh digest built with `TDigest::new_with_size(100)`.
    pub fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
            digest: TDigest::new_with_size(100),
        }
    }

    /// Returns `true` once at least one sample has been recorded.
    pub fn has_samples(&self) -> bool {
        self.count > 0
    }

    /// Folds a single observation into the running statistics.
    ///
    /// `count`, `mean` and `m2` are updated with Welford's online algorithm,
    /// and `value` is merged into the quantile digest.
    pub fn add_sample(&mut self, value: f64) {
        self.count += 1;
        let delta = value - self.mean;
        self.mean += delta / self.count as f64;
        // The second delta deliberately uses the already-updated mean.
        let delta2 = value - self.mean;
        self.m2 += delta * delta2;
        self.digest = self.digest.merge_unsorted(vec![value]);
    }

    /// Combines `other` into `self` as if all of `other`'s samples had been
    /// added to `self`.
    ///
    /// Merging an empty `other` is a no-op. The count-weighted mean and the
    /// pairwise `m2` update are computed in `f64`, so large sample counts
    /// cannot overflow the integer product `self.count * other.count`.
    pub fn merge(&mut self, other: &RunningStats) {
        if other.count == 0 {
            return;
        }
        // `other.count > 0` guarantees the combined count is non-zero, so the
        // divisions below are always well defined.
        let combined_count = self.count + other.count;
        let combined = combined_count as f64;
        let self_n = self.count as f64;
        let other_n = other.count as f64;

        // Must be taken before `self.mean` is overwritten.
        let delta = other.mean - self.mean;
        self.mean = (self.mean * self_n + other.mean * other_n) / combined;
        self.m2 += other.m2 + delta * delta * (self_n * other_n) / combined;
        self.count = combined_count;
        self.digest = TDigest::merge_digests(vec![self.digest.clone(), other.digest.clone()]);
    }

    /// Produces a [`PredictionMetrics`] snapshot of the current state.
    ///
    /// Returns `PredictionMetrics::default()` when no samples have been
    /// recorded. Otherwise reports the sample count, the mean, and the
    /// digest's quantile estimates at 0.50, 0.90 and 0.95. A count larger
    /// than `u32::MAX` is reported as `u32::MAX` rather than wrapping.
    pub fn compute_metrics(&self) -> PredictionMetrics {
        if self.count == 0 {
            return PredictionMetrics::default();
        }
        // Saturate instead of truncating: the value is clamped to the `u32`
        // range first, so the cast below is lossless.
        let sample_count = self.count.min(u64::from(u32::MAX)) as u32;
        PredictionMetrics {
            sample_count,
            mean: self.mean,
            p50: self.digest.estimate_quantile(0.50),
            p90: self.digest.estimate_quantile(0.90),
            p95: self.digest.estimate_quantile(0.95),
        }
    }
}
impl Default for RunningStats {
fn default() -> Self {
Self::new()
}
}
/// Exact percentile of an already-sorted sample slice, using linear
/// interpolation between the two closest ranks.
///
/// Test-only helper (compiled under `cfg(test)`).
///
/// * `sorted_samples` — samples in ascending order; the function does not
///   sort or validate them.
/// * `pct` — percentile on a 0–100 scale. Values outside that range are
///   clamped, so `pct < 0` yields the first sample and `pct > 100` the last
///   instead of indexing past the end of the slice. A NaN `pct` yields the
///   first sample.
///
/// Returns `0.0` for an empty slice and the sole element for a one-element
/// slice.
#[cfg(test)]
pub(crate) fn nat_exact_percentile(sorted_samples: &[f64], pct: f64) -> f64 {
    let last_index = match sorted_samples.len() {
        0 => return 0.0,
        1 => return sorted_samples[0],
        len => len - 1,
    };

    // Fractional rank in [0, last_index]; clamping keeps both neighbouring
    // indices inside the slice.
    let rank = last_index as f64 * (pct.clamp(0.0, 100.0) / 100.0);
    let lower = rank.floor() as usize;
    let upper = rank.ceil() as usize;
    if lower == upper {
        return sorted_samples[lower];
    }

    let fraction = rank - lower as f64;
    sorted_samples[lower] + (sorted_samples[upper] - sorted_samples[lower]) * fraction
}
// Unit tests for this module are kept in a separate file at the path given
// below and are compiled only for test builds.
#[cfg(test)]
#[path = "../../tests/unit/trie/accumulator_tests.rs"]
mod tests;