use std::{
cmp::{Ordering, Reverse},
collections::{BinaryHeap, HashMap, HashSet},
sync::Arc,
};
use ordered_float::OrderedFloat;
use polars::{
df,
frame::DataFrame,
prelude::{DataType, Field, PlSmallStr, Schema, SchemaRef},
};
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use strum::{Display, EnumCount, EnumIter, EnumString, IntoEnumIterator, IntoStaticStr};
use crate::{
error::{ChapatyError, ChapatyResult, DataError, IoError, SystemError},
report::{
io::{Report, ReportName, ToSchema, generate_dynamic_base_name},
portfolio_performance::PortfolioPerformanceCol,
},
sorted_vec_map::SortedVecMap,
};
const METRIC_COUNT: usize = PortfolioPerformanceCol::COUNT;
/// Columns of the leaderboard report; names are the snake_case form of each variant.
///
/// Variant order is significant: it drives `LeaderboardCol::iter()` (and therefore the
/// column order produced by `Leaderboard::to_schema`) as well as the derived `Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Serialize, Deserialize)]
#[derive(Display, EnumString, EnumIter, EnumCount, IntoStaticStr)]
#[strum(serialize_all = "snake_case")]
pub enum LeaderboardCol {
    /// Name of the portfolio-performance metric the row ranks.
    PortfolioPerformanceMetric,
    /// 1-based position within that metric.
    Rank,
    /// Metric value of the ranked agent.
    Value,
    /// UID of the ranked agent.
    AgentUid,
    /// Serialized parameterization of the ranked agent.
    AgentParameterization,
}
impl From<LeaderboardCol> for PlSmallStr {
    /// Converts a column identifier into its snake_case polars column name.
    fn from(value: LeaderboardCol) -> Self {
        let name: &'static str = value.into();
        Self::from(name)
    }
}
impl LeaderboardCol {
pub fn name(&self) -> PlSmallStr {
(*self).into()
}
pub fn as_str(&self) -> &'static str {
self.into()
}
}
impl ToSchema for Leaderboard {
fn to_schema() -> SchemaRef {
let fields = LeaderboardCol::iter()
.map(|col| {
let dtype = match col {
LeaderboardCol::Rank => DataType::UInt32,
LeaderboardCol::AgentUid => DataType::UInt64,
LeaderboardCol::AgentParameterization
| LeaderboardCol::PortfolioPerformanceMetric => DataType::String,
LeaderboardCol::Value => DataType::Float64,
};
Field::new(col.into(), dtype)
})
.collect::<Vec<_>>();
Arc::new(Schema::from_iter(fields))
}
}
/// Leaderboard report: one row per (metric, rank), carrying the agent UID, the metric
/// value and the agent's JSON-serialized parameterization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Leaderboard {
    /// Backing frame whose columns are named after the `LeaderboardCol` variants.
    df: DataFrame,
}
impl ReportName for Leaderboard {
    /// Delegates to `generate_dynamic_base_name` with the `"leaderboard"` label.
    fn base_name(&self) -> String {
        let label = "leaderboard";
        generate_dynamic_base_name(&self.df, label)
    }
}
impl Report for Leaderboard {
    /// Borrows the underlying frame.
    fn as_df(&self) -> &DataFrame {
        let Self { df } = self;
        df
    }

    /// Mutably borrows the underlying frame.
    fn as_df_mut(&mut self) -> &mut DataFrame {
        let Self { df } = self;
        df
    }
}
/// Bounded "top-k per metric" leaderboard over agents of type `T`.
///
/// Each metric keeps a heap of `Reverse<LeaderboardEntry>`, so the heap root is the
/// weakest entry still on the board and at most `k` entries are retained per metric.
/// `agent_data` is intended to hold the agent value for every UID that appears in at
/// least one heap.
#[derive(Debug, Clone)]
pub(crate) struct AgentLeaderboard<T> {
    /// Best entries per metric; the heap root is the current worst kept entry.
    pub top_per_metric:
        SortedVecMap<PortfolioPerformanceCol, BinaryHeap<Reverse<LeaderboardEntry>>>,
    /// Maximum number of entries retained per metric.
    pub k: usize,
    /// Agent values keyed by agent UID.
    pub agent_data: HashMap<u64, T>,
}
impl<T> TryFrom<AgentLeaderboard<T>> for Leaderboard
where
T: Serialize,
{
type Error = ChapatyError;
fn try_from(value: AgentLeaderboard<T>) -> Result<Self, Self::Error> {
Ok(Self {
df: value.leaderboard_soa()?.try_into()?,
})
}
}
impl<T> AgentLeaderboard<T> {
pub(crate) fn new(k: usize) -> Self {
Self {
top_per_metric: PortfolioPerformanceCol::iter()
.map(|metric| (metric, BinaryHeap::with_capacity(k)))
.collect(),
k,
agent_data: HashMap::with_capacity(k * METRIC_COUNT),
}
}
pub(crate) fn update(&mut self, new_entries: &[LeaderboardEntry], agent: T) {
if new_entries.is_empty() {
return;
}
let mut is_global_winner = false;
let mut potentially_evicted = SmallVec::<[u64; METRIC_COUNT]>::new();
for entry in new_entries {
match self.process_entry(entry) {
HeapAction::Rejected => {}
HeapAction::Added => {
is_global_winner = true;
}
HeapAction::Swapped(evicted_uid) => {
is_global_winner = true;
if !potentially_evicted.contains(&evicted_uid) {
potentially_evicted.push(evicted_uid);
}
}
}
}
if is_global_winner {
self.agent_data
.entry(new_entries[0].agent_uid)
.or_insert(agent);
}
self.garbage_collect(new_entries[0].agent_uid, &potentially_evicted);
}
pub(crate) fn merge(mut self, other: Self) -> Self {
for (metric, other_heap) in other.top_per_metric {
let heap = self.top_per_metric.get_mut(&metric).expect(
"Critical Logic Error: Leaderboard metric missing. This should be unreachable.",
);
heap.extend(other_heap);
while heap.len() > self.k {
heap.pop();
}
}
self.agent_data.extend(other.agent_data);
let surviving_uids = self.top_per_metric.values().fold(
HashSet::with_capacity(self.k * METRIC_COUNT),
|mut set, heap| {
set.extend(heap.iter().map(|entry| entry.0.agent_uid));
set
},
);
self.agent_data
.retain(|uid, _| surviving_uids.contains(uid));
self
}
}
impl<T> AgentLeaderboard<T> {
    /// Returns `true` if any metric heap currently holds an entry for `uid`.
    fn is_agent_tracked(&self, uid: u64) -> bool {
        self.top_per_metric
            .values()
            .flat_map(|heap| heap.iter())
            .any(|Reverse(entry)| entry.agent_uid == uid)
    }

    /// Offers `entry` to the heap of its metric and reports what happened.
    ///
    /// # Panics
    ///
    /// Panics if the entry's metric has no heap.
    fn process_entry(&mut self, entry: &LeaderboardEntry) -> HeapAction {
        let capacity = self.k;
        let heap = self
            .top_per_metric
            .get_mut(&entry.metric())
            .expect("Metric missing");

        if heap.len() < capacity {
            heap.push(Reverse(*entry));
            return HeapAction::Added;
        }

        // Full heap: the root is the worst kept entry. Only a strict improvement replaces
        // it, so a tie keeps the incumbent.
        let beats_worst = match heap.peek() {
            Some(Reverse(worst)) => entry > worst,
            None => true,
        };
        if !beats_worst {
            return HeapAction::Rejected;
        }
        match heap.pop() {
            Some(Reverse(evicted)) => {
                heap.push(Reverse(*entry));
                HeapAction::Swapped(evicted.agent_uid)
            }
            // Only reachable when the heap is empty at capacity, i.e. `k == 0`.
            None => HeapAction::Rejected,
        }
    }

    /// Drops cached data for each candidate UID that no longer appears in any heap.
    /// `safe_uid` is never dropped.
    fn garbage_collect(&mut self, safe_uid: u64, candidates: &[u64]) {
        for uid in candidates.iter().copied() {
            if uid == safe_uid || self.is_agent_tracked(uid) {
                continue;
            }
            self.agent_data.remove(&uid);
        }
    }
}
impl<T> AgentLeaderboard<T>
where
T: Serialize,
{
fn leaderboard_soa(self) -> ChapatyResult<LeaderboardSoA> {
let capacity = self.top_per_metric.len() * self.k;
let mut metric_col = Vec::with_capacity(capacity);
let mut rank_col = Vec::with_capacity(capacity);
let mut agent_uid_col = Vec::with_capacity(capacity);
let mut value_col = Vec::with_capacity(capacity);
let mut agent_parameterization_col = Vec::with_capacity(capacity);
for (metric, heap) in self.top_per_metric {
let top_k = heap.into_sorted_vec().into_iter().map(|rev| rev.0);
for (i, entry) in top_k.enumerate() {
let uid = entry.agent_uid();
let agent = self.agent_data.get(&uid).ok_or_else(|| {
SystemError::MissingField(format!("Agent UID {} missing from cache", uid))
})?;
let param_str = serde_json::to_string(agent).map_err(IoError::Json)?;
metric_col.push(metric.to_string());
rank_col.push((i + 1) as u32);
agent_uid_col.push(uid);
value_col.push(entry.denormalized_reward());
agent_parameterization_col.push(param_str);
}
}
Ok(LeaderboardSoA {
portofolio_performance_metric: metric_col,
rank: rank_col,
value: value_col,
agent_uid: agent_uid_col,
agent_parameterization: agent_parameterization_col,
})
}
}
/// Outcome of offering one entry to its metric heap.
enum HeapAction {
    /// The heap had spare capacity and the entry was pushed.
    Added,
    /// The entry displaced the previous worst entry. Carries the displaced agent's UID.
    Swapped(u64),
    /// The entry was not kept, e.g. it did not strictly beat the current worst entry.
    Rejected,
}
/// Column-oriented (structure-of-arrays) form of a leaderboard. Every vector holds one
/// value per row, all in the same row order.
struct LeaderboardSoA {
    /// Metric display name for each row.
    portofolio_performance_metric: Vec<String>,
    /// 1-based rank within the row's metric.
    rank: Vec<u32>,
    /// Denormalized metric value.
    value: Vec<f64>,
    /// UID of the ranked agent.
    agent_uid: Vec<u64>,
    /// JSON-serialized agent parameterization.
    agent_parameterization: Vec<String>,
}
impl TryFrom<LeaderboardSoA> for DataFrame {
    type Error = ChapatyError;

    /// Builds the leaderboard `DataFrame` from its column vectors.
    ///
    /// Columns are emitted in `LeaderboardCol` declaration order, the same order that
    /// `Leaderboard::to_schema` declares. The frame therefore matches its schema by
    /// position as well as by name.
    ///
    /// # Errors
    ///
    /// Returns `DataError::DataFrame` if polars rejects the columns.
    fn try_from(value: LeaderboardSoA) -> Result<Self, Self::Error> {
        df!(
            LeaderboardCol::PortfolioPerformanceMetric.as_str() => value.portofolio_performance_metric,
            LeaderboardCol::Rank.as_str() => value.rank,
            LeaderboardCol::Value.as_str() => value.value,
            LeaderboardCol::AgentUid.as_str() => value.agent_uid,
            LeaderboardCol::AgentParameterization.as_str() => value.agent_parameterization,
        )
        .map_err(|e| DataError::DataFrame(e.to_string()).into())
    }
}
/// One agent's score on one portfolio-performance metric.
///
/// `reward` holds the heap-oriented ("normalized") score, where higher ranks better.
/// `LeaderboardEntry::denormalized_reward` maps it back through the metric.
#[derive(Clone, Copy, Debug)]
pub struct LeaderboardEntry {
    /// UID of the agent that produced this score.
    pub agent_uid: u64,
    /// Metric this score belongs to.
    pub metric: PortfolioPerformanceCol,
    /// Heap-oriented score.
    pub reward: OrderedFloat<f64>,
}
impl LeaderboardEntry {
    /// UID of the agent this entry belongs to.
    pub fn agent_uid(&self) -> u64 {
        self.agent_uid
    }

    /// Metric this entry scores.
    pub fn metric(&self) -> PortfolioPerformanceCol {
        self.metric
    }

    /// Raw heap-oriented score as a plain `f64`.
    pub fn normalized_reward(&self) -> f64 {
        self.reward.into_inner()
    }

    /// Score mapped back through `PortfolioPerformanceCol::from_heap_score`.
    pub fn denormalized_reward(&self) -> f64 {
        let heap_score = self.normalized_reward();
        self.metric().from_heap_score(heap_score)
    }
}
// Entries compare by `reward` alone. `agent_uid` and `metric` are ignored, so two agents
// with the same score are "equal". The heap logic relies on this: a tie with the current
// worst entry does not qualify for a swap.
impl PartialEq for LeaderboardEntry {
    fn eq(&self, other: &Self) -> bool {
        other.reward.eq(&self.reward)
    }
}

impl Eq for LeaderboardEntry {}

impl PartialOrd for LeaderboardEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Total order: delegate to `Ord` so the two can never disagree.
        Some(Ord::cmp(self, other))
    }
}

impl Ord for LeaderboardEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.reward, &other.reward)
    }
}
#[cfg(test)]
mod tests {
    use polars::prelude::{IntoLazy, col, lit};
    use serde::{Deserialize, Serialize};

    use super::*;

    /// Minimal serializable agent used as the cached `T`. `id` mirrors the agent UID so
    /// the serialized parameterization can be checked against the UID column.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestAgent {
        id: u64,
    }

    impl TestAgent {
        fn new(id: u64) -> Self {
            Self { id }
        }
    }

    /// Builds a `LeaderboardEntry` with the given heap-oriented reward.
    fn make_entry(
        agent_uid: u64,
        metric: PortfolioPerformanceCol,
        reward: f64,
    ) -> LeaderboardEntry {
        LeaderboardEntry {
            agent_uid,
            metric,
            reward: OrderedFloat(reward),
        }
    }

    /// `new` must create exactly one empty heap per metric, each pre-sized for `k`.
    #[test]
    fn test_leaderboard_initializes_all_metrics() {
        let k = 5;
        let leaderboard = AgentLeaderboard::<TestAgent>::new(k);
        for metric in PortfolioPerformanceCol::iter() {
            assert!(
                leaderboard.top_per_metric.contains_key(&metric),
                "Leaderboard missing initialization for metric: {:?}",
                metric
            );
        }
        assert_eq!(
            leaderboard.top_per_metric.len(),
            METRIC_COUNT,
            "Leaderboard map length does not match Enum variant count"
        );
        for (_, heap) in leaderboard.top_per_metric.iter() {
            assert!(heap.is_empty(), "New heaps must be empty");
            assert!(heap.capacity() >= k, "Heap capacity should be at least k");
        }
    }

    /// An empty entry slice is a no-op: nothing is pushed and the agent is not cached.
    #[test]
    fn test_update_with_empty_entries_does_not_panic() {
        let k = 3;
        let mut board = AgentLeaderboard::<TestAgent>::new(k);
        let metric = PortfolioPerformanceCol::TradeSharpeRatio;
        let valid_entry = make_entry(1, metric, 10.0);
        board.update(&[valid_entry], TestAgent::new(1));
        board.update(&[], TestAgent::new(2));
        let heap = board.top_per_metric.get(&metric).unwrap();
        assert_eq!(heap.len(), 1, "Heap should only contain the valid agent");
        let uids = heap.iter().map(|r| r.0.agent_uid).collect::<Vec<_>>();
        assert!(
            !uids.contains(&2),
            "Zero-entry agent should not be in the heap"
        );
        assert!(
            !board.agent_data.contains_key(&2),
            "Zero-entry agent should not be cached"
        );
    }

    /// The first `k` entries for a metric are always accepted and their agents cached.
    #[test]
    fn test_update_fills_capacity() {
        let k = 3;
        let mut board = AgentLeaderboard::<TestAgent>::new(k);
        let metric = PortfolioPerformanceCol::TradeSharpeRatio;
        for i in 0..k {
            let uid = i as u64;
            let entry = make_entry(uid, metric, (i + 1) as f64 * 10.0);
            let agent = TestAgent::new(uid);
            board.update(&[entry], agent);
        }
        let heap = board.top_per_metric.get(&metric).unwrap();
        assert_eq!(heap.len(), k, "Heap should be exactly at capacity K");
        assert_eq!(board.agent_data.len(), k, "All K agents should be cached");
    }

    /// On a full heap, a worse entry is rejected and not cached. A better entry evicts the
    /// current worst, and the evicted agent's cached data is dropped.
    #[test]
    fn test_update_respects_ordering_and_eviction() {
        let k = 3;
        let mut board = AgentLeaderboard::<TestAgent>::new(k);
        let metric = PortfolioPerformanceCol::TradeSharpeRatio;
        // Fill the heap with rewards 10, 20, 30 (agent 1 is the worst).
        for i in 1..=k {
            let uid = i as u64;
            let entry = make_entry(uid, metric, i as f64 * 10.0);
            board.update(&[entry], TestAgent::new(uid));
        }
        let worse_entry = make_entry(100, metric, 5.0);
        board.update(&[worse_entry], TestAgent::new(100));
        let heap = board.top_per_metric.get(&metric).unwrap();
        assert_eq!(heap.len(), k, "Heap size should remain unchanged");
        let uids = heap.iter().map(|r| r.0.agent_uid).collect::<Vec<_>>();
        assert!(
            !uids.contains(&100),
            "Worse agent should not be in the heap"
        );
        assert!(
            !board.agent_data.contains_key(&100),
            "Worse agent should not be cached"
        );
        let better_entry = make_entry(200, metric, 100.0);
        board.update(&[better_entry], TestAgent::new(200));
        let heap = board.top_per_metric.get(&metric).unwrap();
        assert_eq!(heap.len(), k, "Heap size should remain at K");
        let uids = heap.iter().map(|r| r.0.agent_uid).collect::<Vec<_>>();
        assert!(uids.contains(&200), "Better agent should be in the heap");
        assert!(
            !uids.contains(&1),
            "Worst agent should be evicted after better insertion"
        );
        assert!(
            !board.agent_data.contains_key(&1),
            "Evicted agent should be garbage collected"
        );
        assert!(
            board.agent_data.contains_key(&200),
            "New winner should be cached"
        );
    }

    /// An entry that only ties the current worst score does not displace it.
    #[test]
    fn test_update_rejects_ties() {
        let k = 2;
        let mut board = AgentLeaderboard::<TestAgent>::new(k);
        let metric = PortfolioPerformanceCol::TradeSharpeRatio;
        board.update(&[make_entry(1, metric, 10.0)], TestAgent::new(1));
        board.update(&[make_entry(2, metric, 20.0)], TestAgent::new(2));
        // Same reward as the current worst entry (agent 1).
        let tie_entry = make_entry(999, metric, 10.0);
        board.update(&[tie_entry], TestAgent::new(999));
        let heap = board.top_per_metric.get(&metric).unwrap();
        let uids = heap.iter().map(|r| r.0.agent_uid).collect::<Vec<_>>();
        assert!(!uids.contains(&999), "Tied agent should not be in the heap");
        assert_eq!(heap.len(), k, "Heap size should remain unchanged");
        assert!(
            !board.agent_data.contains_key(&999),
            "Tied agent should not be cached"
        );
    }

    /// Merging keeps the best `k` entries across both boards.
    #[test]
    fn test_merge_correctness() {
        let k = 3;
        let metric = PortfolioPerformanceCol::TradeSharpeRatio;
        // Board A: agents 1..=3 with rewards 10, 20, 30.
        let mut board_a = AgentLeaderboard::<TestAgent>::new(k);
        for i in 1..=3u64 {
            let entry = make_entry(i, metric, i as f64 * 10.0);
            board_a.update(&[entry], TestAgent::new(i));
        }
        // Board B: agents 101..=103 with rewards 25, 35, 45.
        let mut board_b = AgentLeaderboard::<TestAgent>::new(k);
        for i in 1..=3u64 {
            let uid = 100 + i;
            let entry = make_entry(uid, metric, (i as f64 * 10.0) + 15.0);
            board_b.update(&[entry], TestAgent::new(uid));
        }
        let merged = board_a.merge(board_b);
        let heap = merged.top_per_metric.get(&metric).unwrap();
        assert_eq!(heap.len(), k, "Merged heap should have exactly K entries");
        let uids = heap.iter().map(|r| r.0.agent_uid).collect::<Vec<_>>();
        assert!(uids.contains(&3), "Agent 3 (reward=30) should survive");
        assert!(uids.contains(&102), "Agent 102 (reward=35) should survive");
        assert!(uids.contains(&103), "Agent 103 (reward=45) should survive");
        assert!(!uids.contains(&1), "Agent 1 (reward=10) should be evicted");
        assert!(!uids.contains(&2), "Agent 2 (reward=20) should be evicted");
        assert!(
            !uids.contains(&101),
            "Agent 101 (reward=25) should be evicted"
        );
    }

    /// After merging, cached agents that no longer rank anywhere are dropped.
    #[test]
    fn test_merge_garbage_collection() {
        let k = 2;
        let metric = PortfolioPerformanceCol::TradeSharpeRatio;
        let mut board_a = AgentLeaderboard::<TestAgent>::new(k);
        board_a.update(&[make_entry(1, metric, 10.0)], TestAgent::new(1));
        board_a.update(&[make_entry(2, metric, 20.0)], TestAgent::new(2));
        let mut board_b = AgentLeaderboard::<TestAgent>::new(k);
        board_b.update(&[make_entry(101, metric, 100.0)], TestAgent::new(101));
        board_b.update(&[make_entry(102, metric, 200.0)], TestAgent::new(102));
        let merged = board_a.merge(board_b);
        let heap = merged.top_per_metric.get(&metric).unwrap();
        let surviving_uids = heap.iter().map(|r| r.0.agent_uid).collect::<Vec<_>>();
        assert!(surviving_uids.contains(&101));
        assert!(surviving_uids.contains(&102));
        assert!(!surviving_uids.contains(&1));
        assert!(!surviving_uids.contains(&2));
        assert!(
            !merged.agent_data.contains_key(&1),
            "Agent 1's data should be garbage collected"
        );
        assert!(
            !merged.agent_data.contains_key(&2),
            "Agent 2's data should be garbage collected"
        );
        assert!(
            merged.agent_data.contains_key(&101),
            "Agent 101's data should survive"
        );
        assert!(
            merged.agent_data.contains_key(&102),
            "Agent 102's data should survive"
        );
        assert_eq!(
            merged.agent_data.len(),
            k,
            "Only K agent data entries should remain"
        );
    }

    /// Conversion to `Leaderboard` emits rows best-first with 1-based ranks, descending
    /// values, matching UIDs and JSON parameterizations. The typed column accessors
    /// (`u32()`, `u64()`, `f64()`, `str()`) also pin each column's dtype.
    #[test]
    fn test_into_dataframe_schema_and_sorting() {
        let k = 3;
        let metric = PortfolioPerformanceCol::TradeSharpeRatio;
        let mut board = AgentLeaderboard::<TestAgent>::new(k);
        // Inserted out of reward order, so the ranking must come from sorting.
        let entries = [
            (42u64, 50.0),
            (17u64, 100.0),
            (99u64, 75.0),
        ];
        for (uid, reward) in entries {
            board.update(&[make_entry(uid, metric, reward)], TestAgent::new(uid));
        }
        let leaderboard: Leaderboard = board.try_into().expect("Conversion should succeed");
        let df = leaderboard.as_df();
        let metric_str = metric.to_string();
        let filtered = df
            .clone()
            .lazy()
            .filter(col(LeaderboardCol::PortfolioPerformanceMetric).eq(lit(metric_str)))
            .collect()
            .unwrap();
        assert_eq!(filtered.height(), k, "Should have K rows for the metric");
        let ranks = filtered
            .column(LeaderboardCol::Rank.as_str())
            .unwrap()
            .u32()
            .unwrap()
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(ranks, vec![1, 2, 3], "Ranks should be sequential 1, 2, 3");
        let values = filtered
            .column(LeaderboardCol::Value.as_str())
            .unwrap()
            .f64()
            .unwrap()
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(
            values[0], 100.0,
            "Rank 1 should have the highest reward value"
        );
        assert!(
            values[0] > values[1] && values[1] > values[2],
            "Values should be in descending order"
        );
        let uids = filtered
            .column(LeaderboardCol::AgentUid.as_str())
            .unwrap()
            .u64()
            .unwrap()
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(uids[0], 17, "Rank 1 should be agent 17");
        assert_eq!(uids[1], 99, "Rank 2 should be agent 99");
        assert_eq!(uids[2], 42, "Rank 3 should be agent 42");
        let params = filtered
            .column(LeaderboardCol::AgentParameterization.as_str())
            .unwrap()
            .str()
            .unwrap()
            .into_no_null_iter()
            .collect::<Vec<_>>();
        // Each row's JSON must round-trip to the agent whose UID sits in the same row.
        for (i, param_str) in params.iter().enumerate() {
            let parsed: TestAgent = serde_json::from_str(param_str).expect("Should be valid JSON");
            assert_eq!(parsed.id, uids[i], "Serialized agent id should match UID");
        }
    }

    /// A ranked agent without cached data makes the conversion fail with
    /// `SystemError::MissingField`, and the message names the UID.
    #[test]
    fn test_soa_generation_missing_data() {
        let k = 2;
        let metric = PortfolioPerformanceCol::TradeSharpeRatio;
        let mut board = AgentLeaderboard::<TestAgent>::new(k);
        board.update(&[make_entry(1, metric, 10.0)], TestAgent::new(1));
        board.update(&[make_entry(2, metric, 20.0)], TestAgent::new(2));
        // Break the cache invariant on purpose: agent 2 is ranked but has no data.
        board.agent_data.remove(&2);
        let result = board.try_into() as ChapatyResult<Leaderboard>;
        assert!(result.is_err(), "Should fail when agent data is missing");
        let err = result.unwrap_err();
        match err {
            ChapatyError::System(SystemError::MissingField(msg)) => {
                assert!(
                    msg.contains("2"),
                    "Error should mention the missing UID: {}",
                    msg
                );
            }
            other => panic!("Expected SystemError::MissingField, got: {:?}", other),
        }
    }

    /// One `update` with entries for several metrics lands in each metric's heap and
    /// caches the agent once.
    #[test]
    fn test_update_with_multiple_metrics() {
        let k = 2;
        let mut board = AgentLeaderboard::<TestAgent>::new(k);
        let metric_a = PortfolioPerformanceCol::TradeSharpeRatio;
        let metric_b = PortfolioPerformanceCol::NetProfit;
        let entries = vec![
            make_entry(1, metric_a, 10.0),
            make_entry(1, metric_b, 1000.0),
        ];
        board.update(&entries, TestAgent::new(1));
        let heap_a = board.top_per_metric.get(&metric_a).unwrap();
        let heap_b = board.top_per_metric.get(&metric_b).unwrap();
        assert_eq!(heap_a.len(), 1);
        assert_eq!(heap_b.len(), 1);
        assert_eq!(heap_a.peek().unwrap().0.agent_uid, 1);
        assert_eq!(heap_b.peek().unwrap().0.agent_uid, 1);
        assert_eq!(board.agent_data.len(), 1);
        assert!(board.agent_data.contains_key(&1));
    }

    /// A board with no entries converts to an empty frame rather than failing.
    #[test]
    fn test_empty_leaderboard_to_dataframe() {
        let k = 3;
        let board = AgentLeaderboard::<TestAgent>::new(k);
        let leaderboard: Leaderboard = board.try_into().expect("Empty conversion should succeed");
        assert_eq!(leaderboard.as_df().height(), 0, "Empty board = empty df");
    }

    /// With `k = 1`, a better entry evicts the incumbent and drops its cached data in the
    /// same `update` call.
    #[test]
    fn test_update_garbage_collects_during_eviction() {
        let k = 1;
        let mut board = AgentLeaderboard::<TestAgent>::new(k);
        let metric = PortfolioPerformanceCol::TradeSharpeRatio;
        board.update(&[make_entry(1, metric, 10.0)], TestAgent::new(1));
        assert!(board.agent_data.contains_key(&1));
        board.update(&[make_entry(2, metric, 20.0)], TestAgent::new(2));
        assert!(
            !board.agent_data.contains_key(&1),
            "Evicted agent should be garbage collected during update"
        );
        assert!(
            board.agent_data.contains_key(&2),
            "New winner should be cached"
        );
        assert_eq!(board.agent_data.len(), 1, "Only one agent should be cached");
    }

    /// An agent evicted from one metric keeps its cached data while it still ranks in
    /// another metric.
    #[test]
    fn test_garbage_collection_protects_multi_metric_winners() {
        let k = 1;
        let mut board = AgentLeaderboard::<TestAgent>::new(k);
        let metric_a = PortfolioPerformanceCol::TradeSharpeRatio;
        let metric_b = PortfolioPerformanceCol::NetProfit;
        let entries = vec![
            make_entry(1, metric_a, 10.0),
            make_entry(1, metric_b, 100.0),
        ];
        board.update(&entries, TestAgent::new(1));
        // Agent 2 takes metric A from agent 1; agent 1 still leads metric B.
        board.update(&[make_entry(2, metric_a, 20.0)], TestAgent::new(2));
        let heap_a = board.top_per_metric.get(&metric_a).unwrap();
        assert_eq!(heap_a.peek().unwrap().0.agent_uid, 2);
        let heap_b = board.top_per_metric.get(&metric_b).unwrap();
        assert_eq!(heap_b.peek().unwrap().0.agent_uid, 1);
        assert!(
            board.agent_data.contains_key(&1),
            "Agent 1 should survive because it leads in metric B"
        );
        assert!(board.agent_data.contains_key(&2), "Agent 2 should be added");
        assert_eq!(board.agent_data.len(), 2);
    }

    /// An agent's cached data is dropped once it has been evicted from every metric.
    #[test]
    fn test_garbage_collection_removes_agent_lost_all_titles() {
        let k = 1;
        let mut board = AgentLeaderboard::<TestAgent>::new(k);
        let metric_a = PortfolioPerformanceCol::TradeSharpeRatio;
        let metric_b = PortfolioPerformanceCol::NetProfit;
        let entries = vec![
            make_entry(1, metric_a, 10.0),
            make_entry(1, metric_b, 100.0),
        ];
        board.update(&entries, TestAgent::new(1));
        board.update(&[make_entry(2, metric_a, 20.0)], TestAgent::new(2));
        assert!(board.agent_data.contains_key(&1));
        // Agent 3 takes metric B, the last metric agent 1 was ranked in.
        board.update(&[make_entry(3, metric_b, 200.0)], TestAgent::new(3));
        assert!(
            !board.agent_data.contains_key(&1),
            "Agent 1 lost all metrics and should be GC'd"
        );
        assert!(board.agent_data.contains_key(&2));
        assert!(board.agent_data.contains_key(&3));
    }
}