use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::context::InstructionExecutionContext;
/// Identifies the kind of benchmark a set of statistics belongs to and carries
/// the benchmark's name.
///
/// Serialized by serde in adjacently tagged form: the variant is written under
/// the `benchmark_type` key (as `"instruction"` or `"transaction"`) and the
/// wrapped string under the `benchmark_name` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "benchmark_type", content = "benchmark_name")]
pub enum StatType {
/// An instruction benchmark; the `String` is the benchmark name.
#[serde(rename = "instruction")]
Instruction(String),
/// A transaction benchmark; the `String` is the benchmark name.
#[serde(rename = "transaction")]
Transaction(String),
}
/// Serializable record of one instruction benchmark: the instruction's name,
/// its compute-unit statistics, the context it was executed in, and two
/// provenance strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstructionBenchmarkResult {
/// Name of the benchmarked instruction.
pub instruction_name: String,
/// Compute-unit statistics recorded for this instruction.
pub cu_estimate: ComputeUnitStats,
/// Execution context associated with the benchmark (type declared in `super::context`).
pub execution_context: InstructionExecutionContext,
/// When this result was generated.
/// NOTE(review): stored as a free-form `String`; presumably an RFC 3339
/// timestamp like `ComputeUnitDatabase::generated_at` — confirm with the producer.
pub generated_at: String,
/// What produced this result.
/// NOTE(review): the expected content (tool name, version, ...) is not
/// visible in this file — confirm with the producer.
pub generated_by: String,
}
/// Selects which compute-unit figure to read from a [`ComputeUnitStats`].
///
/// Resolved to a concrete `u64` by [`ComputeUnitStats::get_cu_for_level`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ComputeUnitLevel {
/// The stored `min` value.
Min,
/// The stored `conservative` value.
Conservative,
/// The stored `balanced` value.
Balanced,
/// The stored `safe` value.
Safe,
/// The stored `very_high` value.
VeryHigh,
/// The stored `unsafe_max` value.
UnsafeMax,
/// A caller-supplied value that is returned as-is; the stored statistics are ignored.
Custom(u64),
/// The stored `balanced` value scaled by this factor; the product is
/// computed in `f32` and converted back to `u64`.
Multiplier(f32),
}
/// Summary statistics over a set of compute-unit measurements.
///
/// The per-field descriptions below state what
/// [`ComputeUnitStats::from_measurements`] stores; because every field is
/// public, values built by hand are not guaranteed to follow them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputeUnitStats {
/// Benchmark kind and name. Flattened during (de)serialization, so its
/// `benchmark_type` / `benchmark_name` keys sit next to the numeric fields.
#[serde(flatten)]
pub stat_type: StatType,
/// Smallest measurement.
pub min: u64,
/// Sorted measurement at index `(n - 1) * 25 / 100` (integer division).
pub conservative: u64,
/// Sorted measurement at index `(n - 1) * 50 / 100` (integer division).
pub balanced: u64,
/// Sorted measurement at index `(n - 1) * 75 / 100` (integer division).
pub safe: u64,
/// Sorted measurement at index `(n - 1) * 95 / 100` (integer division).
pub very_high: u64,
/// Largest measurement.
pub unsafe_max: u64,
/// Number of measurements the statistics were computed from (`n` above).
pub sample_size: usize,
}
impl ComputeUnitStats {
    /// Returns the compute-unit figure selected by `level`.
    ///
    /// The named levels return the matching stored field. `Custom(cu)` returns
    /// `cu` unchanged and ignores the stored statistics. `Multiplier(m)` scales
    /// `balanced` by `m` using `f32` arithmetic and converts the product back
    /// with `as u64`, which drops the fractional part and saturates (a negative
    /// or NaN product yields 0).
    pub fn get_cu_for_level(&self, level: ComputeUnitLevel) -> u64 {
        match level {
            ComputeUnitLevel::Min => self.min,
            ComputeUnitLevel::Conservative => self.conservative,
            ComputeUnitLevel::Balanced => self.balanced,
            ComputeUnitLevel::Safe => self.safe,
            ComputeUnitLevel::VeryHigh => self.very_high,
            ComputeUnitLevel::UnsafeMax => self.unsafe_max,
            ComputeUnitLevel::Custom(cu) => cu,
            ComputeUnitLevel::Multiplier(mult) => (self.balanced as f32 * mult) as u64,
        }
    }

    /// Builds statistics from raw `measurements`, which may be in any order.
    ///
    /// The values are copied and sorted ascending. `min` and `unsafe_max` are
    /// the first and last sorted values; the 25th/50th/75th/95th percentile
    /// fields are the sorted value at index `(n - 1) * pct / 100` using integer
    /// division, where `n` is the number of measurements.
    ///
    /// An empty slice does not panic: every figure is reported as `0` and
    /// `sample_size` is `0`, so callers can detect the "no data" case by
    /// checking `sample_size`.
    pub fn from_measurements(stat_type: StatType, measurements: &[u64]) -> Self {
        if measurements.is_empty() {
            // There is nothing to index into; indexing `sorted[0]` or computing
            // `len - 1` would panic here.
            return Self {
                stat_type,
                min: 0,
                conservative: 0,
                balanced: 0,
                safe: 0,
                very_high: 0,
                unsafe_max: 0,
                sample_size: 0,
            };
        }

        let mut sorted = measurements.to_vec();
        sorted.sort_unstable();

        // Non-empty is guaranteed above, so the subtraction cannot underflow.
        let last = sorted.len() - 1;
        // `pct <= 100`, hence `last * pct / 100 <= last` is always in bounds.
        let at_percentile = |pct: usize| sorted[last * pct / 100];

        Self {
            stat_type,
            min: sorted[0],
            conservative: at_percentile(25),
            balanced: at_percentile(50),
            safe: at_percentile(75),
            very_high: at_percentile(95),
            unsafe_max: sorted[last],
            sample_size: sorted.len(),
        }
    }
}
/// Serializable collection of compute-unit statistics addressed by string key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputeUnitDatabase {
/// Statistics per key; the key is the string that `get_estimate` receives
/// as `instruction_type`.
pub estimates: HashMap<String, ComputeUnitStats>,
/// Creation timestamp; `new` fills it with the current UTC time in RFC 3339 form.
pub generated_at: String, }
impl ComputeUnitDatabase {
pub fn new() -> Self {
Self {
estimates: HashMap::new(),
generated_at: chrono::Utc::now().to_rfc3339(),
}
}
pub fn get_estimate(&self, instruction_type: &str) -> Option<&ComputeUnitStats> {
self.estimates.get(instruction_type)
}
pub fn get_cu_estimate(&self, instruction_type: &str, level: ComputeUnitLevel) -> Option<u64> {
self.get_estimate(instruction_type)
.map(|est| est.get_cu_for_level(level))
}
}
impl Default for ComputeUnitDatabase {
fn default() -> Self {
Self::new()
}
}
// Unit tests for `ComputeUnitStats`: percentile selection, serde layout of the
// flattened `StatType`, and level resolution.
//
// Percentile fields are read from the sorted data at index
// `(n - 1) * pct / 100` (integer division); the per-test comments list the
// indices that formula produces for each data set.
#[cfg(test)]
mod tests {
use super::*;
// n = 100, values 1..=100: indices 24, 49, 74, 94 -> values 25, 50, 75, 95.
#[test]
fn test_percentiles_simple_case() {
let measurements: Vec<u64> = (1..=100).collect();
let stats = ComputeUnitStats::from_measurements(
StatType::Instruction("test".to_string()),
&measurements,
);
assert_eq!(stats.min, 1);
assert_eq!(stats.conservative, 25);
assert_eq!(stats.balanced, 50);
assert_eq!(stats.safe, 75);
assert_eq!(stats.very_high, 95);
assert_eq!(stats.unsafe_max, 100);
assert_eq!(stats.sample_size, 100);
}
// n = 4: indices 0, 1, 2, 2 -- integer division makes the 75th and 95th
// percentiles land on the same element.
#[test]
fn test_percentiles_small_dataset() {
let measurements = vec![10, 20, 30, 40];
let stats = ComputeUnitStats::from_measurements(
StatType::Transaction("small_test".to_string()),
&measurements,
);
assert_eq!(stats.min, 10);
assert_eq!(stats.conservative, 10);
assert_eq!(stats.balanced, 20);
assert_eq!(stats.safe, 30);
assert_eq!(stats.very_high, 30);
assert_eq!(stats.unsafe_max, 40);
assert_eq!(stats.sample_size, 4);
}
// n = 1: every index is 0, so all figures equal the single measurement.
#[test]
fn test_percentiles_single_value() {
let measurements = vec![42];
let stats = ComputeUnitStats::from_measurements(
StatType::Instruction("single".to_string()),
&measurements,
);
assert_eq!(stats.min, 42);
assert_eq!(stats.conservative, 42);
assert_eq!(stats.balanced, 42);
assert_eq!(stats.safe, 42);
assert_eq!(stats.very_high, 42);
assert_eq!(stats.unsafe_max, 42);
assert_eq!(stats.sample_size, 1);
}
// n = 10 with repeated values: indices 2, 4, 6, 8 -> values 5, 10, 20, 20.
#[test]
fn test_percentiles_duplicate_values() {
let measurements = vec![5, 5, 5, 10, 10, 15, 20, 20, 20, 20];
let stats = ComputeUnitStats::from_measurements(
StatType::Transaction("duplicates".to_string()),
&measurements,
);
assert_eq!(stats.min, 5);
assert_eq!(stats.conservative, 5);
assert_eq!(stats.balanced, 10);
assert_eq!(stats.safe, 20);
assert_eq!(stats.very_high, 20);
assert_eq!(stats.unsafe_max, 20);
assert_eq!(stats.sample_size, 10);
}
// Input order must not matter: the data is sorted to 10, 20, ..., 100 first,
// then indices 2, 4, 6, 8 -> values 30, 50, 70, 90.
#[test]
fn test_percentiles_unsorted_input() {
let measurements = vec![100, 10, 50, 30, 80, 20, 90, 40, 70, 60];
let stats = ComputeUnitStats::from_measurements(
StatType::Instruction("unsorted".to_string()),
&measurements,
);
assert_eq!(stats.min, 10);
assert_eq!(stats.conservative, 30);
assert_eq!(stats.balanced, 50);
assert_eq!(stats.safe, 70);
assert_eq!(stats.very_high, 90);
assert_eq!(stats.unsafe_max, 100);
assert_eq!(stats.sample_size, 10);
}
// `StatType` is flattened into the stats object, so its `benchmark_type`
// tag and `benchmark_name` content keys are expected in the stats JSON.
#[test]
fn test_stat_type_serialization() {
let instruction_stats = ComputeUnitStats::from_measurements(
StatType::Instruction("test_instruction".to_string()),
&[100, 200, 300],
);
let transaction_stats = ComputeUnitStats::from_measurements(
StatType::Transaction("test_transaction".to_string()),
&[100, 200, 300],
);
let instruction_json = serde_json::to_string(&instruction_stats).unwrap();
let transaction_json = serde_json::to_string(&transaction_stats).unwrap();
assert!(instruction_json.contains("\"benchmark_type\":\"instruction\""));
assert!(instruction_json.contains("\"benchmark_name\":\"test_instruction\""));
assert!(transaction_json.contains("\"benchmark_type\":\"transaction\""));
assert!(transaction_json.contains("\"benchmark_name\":\"test_transaction\""));
}
// Each named level must return its matching field; `Custom` returns its
// payload and `Multiplier` scales `balanced` using `f32` arithmetic.
#[test]
fn test_get_cu_for_level() {
let measurements = vec![10, 20, 30, 40, 50];
let stats = ComputeUnitStats::from_measurements(
StatType::Instruction("level_test".to_string()),
&measurements,
);
assert_eq!(stats.get_cu_for_level(ComputeUnitLevel::Min), stats.min);
assert_eq!(
stats.get_cu_for_level(ComputeUnitLevel::Conservative),
stats.conservative
);
assert_eq!(
stats.get_cu_for_level(ComputeUnitLevel::Balanced),
stats.balanced
);
assert_eq!(stats.get_cu_for_level(ComputeUnitLevel::Safe), stats.safe);
assert_eq!(
stats.get_cu_for_level(ComputeUnitLevel::VeryHigh),
stats.very_high
);
assert_eq!(
stats.get_cu_for_level(ComputeUnitLevel::UnsafeMax),
stats.unsafe_max
);
assert_eq!(stats.get_cu_for_level(ComputeUnitLevel::Custom(999)), 999);
// Mirror the implementation's f32 round-trip so the expectation cannot
// drift from it.
let expected_multiplied = (stats.balanced as f32 * 2.0) as u64;
assert_eq!(
stats.get_cu_for_level(ComputeUnitLevel::Multiplier(2.0)),
expected_multiplied
);
}
}