#![allow(dead_code)]
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
/// A single timed invocation: how long it took and when it began.
#[derive(Debug, Clone, Copy)]
pub struct TimingSample {
    /// How long the invocation took.
    pub duration: Duration,
    /// Instant at which the invocation started.
    pub started_at: Instant,
}

impl TimingSample {
    /// Creates a sample from an elapsed `duration` and its start instant.
    #[must_use]
    pub fn new(duration: Duration, started_at: Instant) -> Self {
        Self {
            duration,
            started_at,
        }
    }

    /// Returns the sample's duration in milliseconds.
    #[must_use]
    pub fn duration_ms(&self) -> f64 {
        self.duration.as_secs_f64() * 1000.0
    }

    /// Returns the sample's duration in microseconds.
    #[must_use]
    pub fn duration_us(&self) -> f64 {
        self.duration.as_secs_f64() * 1_000_000.0
    }
}

/// Aggregated timing statistics for one named operation.
///
/// Lifetime totals (count, sum, min, max) cover every recorded invocation,
/// while a bounded window of the most recent samples backs
/// [`OperationStats::recent_median_ms`].
#[derive(Debug, Clone)]
pub struct OperationStats {
    /// Name of the operation these statistics describe.
    pub name: String,
    /// Number of invocations recorded since creation or the last reset.
    pub invocation_count: u64,
    /// Sum of all recorded durations; saturates at `Duration::MAX` instead of panicking.
    pub total_duration: Duration,
    /// Shortest recorded duration, or `None` if nothing has been recorded.
    pub min_duration: Option<Duration>,
    /// Longest recorded duration, or `None` if nothing has been recorded.
    pub max_duration: Option<Duration>,
    /// Most recent samples, oldest at the front; a deque makes eviction O(1).
    recent_samples: VecDeque<TimingSample>,
    /// Capacity of `recent_samples`; always at least 1.
    max_samples: usize,
}

impl OperationStats {
    /// Size of the recent-sample window used by [`OperationStats::new`].
    const DEFAULT_MAX_SAMPLES: usize = 100;

    /// Creates empty statistics for `name` with a 100-sample recent window.
    #[must_use]
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            invocation_count: 0,
            total_duration: Duration::ZERO,
            min_duration: None,
            max_duration: None,
            recent_samples: VecDeque::new(),
            max_samples: Self::DEFAULT_MAX_SAMPLES,
        }
    }

    /// Creates empty statistics for `name` keeping at most `max_samples`
    /// recent samples (values below 1 are clamped to 1).
    #[must_use]
    pub fn with_max_samples(name: &str, max_samples: usize) -> Self {
        Self {
            max_samples: max_samples.max(1),
            ..Self::new(name)
        }
    }

    /// Records one invocation that took `duration` and began at `started_at`.
    ///
    /// Updates the lifetime totals and appends to the recent window, evicting
    /// the oldest sample once the window is full.
    pub fn record(&mut self, duration: Duration, started_at: Instant) {
        self.invocation_count += 1;
        // `+=` on Duration panics on overflow; a stats counter should saturate instead.
        self.total_duration = self.total_duration.saturating_add(duration);
        self.min_duration = Some(self.min_duration.map_or(duration, |min| min.min(duration)));
        self.max_duration = Some(self.max_duration.map_or(duration, |max| max.max(duration)));
        if self.recent_samples.len() >= self.max_samples {
            self.recent_samples.pop_front();
        }
        self.recent_samples
            .push_back(TimingSample::new(duration, started_at));
    }

    /// Returns the mean duration over all recorded invocations, or `None` if
    /// nothing has been recorded.
    ///
    /// Computed in u128 nanoseconds so it stays correct for any `u64` count
    /// (a `u32` divisor would truncate counts >= 2^32 and could divide by zero).
    #[must_use]
    pub fn average_duration(&self) -> Option<Duration> {
        const NANOS_PER_SEC: u128 = 1_000_000_000;
        if self.invocation_count == 0 {
            return None;
        }
        let avg_nanos = self.total_duration.as_nanos() / u128::from(self.invocation_count);
        // The mean never exceeds the total, so whole seconds fit in u64 and the
        // sub-second remainder (< 1e9) fits in u32; the fallbacks are unreachable.
        let secs = u64::try_from(avg_nanos / NANOS_PER_SEC).unwrap_or(u64::MAX);
        let nanos = u32::try_from(avg_nanos % NANOS_PER_SEC).unwrap_or(0);
        Some(Duration::new(secs, nanos))
    }

    /// Returns [`OperationStats::average_duration`] in milliseconds.
    #[must_use]
    pub fn average_ms(&self) -> Option<f64> {
        self.average_duration().map(|d| d.as_secs_f64() * 1000.0)
    }

    /// Returns the shortest recorded duration in milliseconds.
    #[must_use]
    pub fn min_ms(&self) -> Option<f64> {
        self.min_duration.map(|d| d.as_secs_f64() * 1000.0)
    }

    /// Returns the longest recorded duration in milliseconds.
    #[must_use]
    pub fn max_ms(&self) -> Option<f64> {
        self.max_duration.map(|d| d.as_secs_f64() * 1000.0)
    }

    /// Returns how many samples are currently held in the recent window.
    #[must_use]
    pub fn recent_count(&self) -> usize {
        self.recent_samples.len()
    }

    /// Returns the median duration (in milliseconds) of the recent window, or
    /// `None` if it is empty. Even-sized windows average the two middle values.
    #[must_use]
    pub fn recent_median_ms(&self) -> Option<f64> {
        if self.recent_samples.is_empty() {
            return None;
        }
        // Durations are totally ordered, so sort them directly instead of
        // comparing floats with a fallback ordering.
        let mut sorted: Vec<Duration> = self.recent_samples.iter().map(|s| s.duration).collect();
        sorted.sort_unstable();
        let to_ms = |d: Duration| d.as_secs_f64() * 1000.0;
        let mid = sorted.len() / 2;
        if sorted.len() % 2 == 0 {
            Some((to_ms(sorted[mid - 1]) + to_ms(sorted[mid])) / 2.0)
        } else {
            Some(to_ms(sorted[mid]))
        }
    }

    /// Clears all recorded data; the name and window capacity are kept.
    pub fn reset(&mut self) {
        self.invocation_count = 0;
        self.total_duration = Duration::ZERO;
        self.min_duration = None;
        self.max_duration = None;
        self.recent_samples.clear();
    }
}
/// Direction of a recorded memory transfer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransferDirection {
    /// Host to device.
    HostToDevice,
    /// Device to host.
    DeviceToHost,
    /// Device to device.
    DeviceToDevice,
}

impl TransferDirection {
    /// Returns a short display label: `"H2D"`, `"D2H"`, or `"D2D"`.
    #[must_use]
    pub fn label(self) -> &'static str {
        match self {
            Self::HostToDevice => "H2D",
            Self::DeviceToHost => "D2H",
            Self::DeviceToDevice => "D2D",
        }
    }
}

/// Cumulative byte counts and timing for transfers in one direction.
#[derive(Debug, Clone)]
pub struct TransferStats {
    /// Direction these totals describe.
    pub direction: TransferDirection,
    /// Total bytes transferred; saturates at `u64::MAX` instead of overflowing.
    pub total_bytes: u64,
    /// Number of transfers recorded.
    pub transfer_count: u64,
    /// Total time spent transferring; saturates at `Duration::MAX`.
    pub total_duration: Duration,
}

impl TransferStats {
    /// Creates empty totals for `direction`.
    #[must_use]
    pub fn new(direction: TransferDirection) -> Self {
        Self {
            direction,
            total_bytes: 0,
            transfer_count: 0,
            total_duration: Duration::ZERO,
        }
    }

    /// Records one transfer of `bytes` that took `duration`.
    pub fn record(&mut self, bytes: u64, duration: Duration) {
        // Plain `+=` panics in debug builds and wraps in release, corrupting
        // throughput figures; saturating keeps the totals monotonic.
        self.total_bytes = self.total_bytes.saturating_add(bytes);
        self.transfer_count += 1;
        self.total_duration = self.total_duration.saturating_add(duration);
    }

    /// Returns average throughput in bytes per second, or `None` when no
    /// transfer time has been recorded.
    // u64 -> f64 may lose precision above 2^53 bytes; acceptable for a rate.
    #[allow(clippy::cast_precision_loss)]
    #[must_use]
    pub fn throughput_bps(&self) -> Option<f64> {
        let secs = self.total_duration.as_secs_f64();
        if secs <= 0.0 {
            return None;
        }
        Some(self.total_bytes as f64 / secs)
    }

    /// Returns average throughput in mebibytes per second (bytes / 2^20 per
    /// second) — note: MiB/s, not megabits per second.
    #[must_use]
    pub fn throughput_mbps(&self) -> Option<f64> {
        self.throughput_bps().map(|bps| bps / (1024.0 * 1024.0))
    }

    /// Returns the mean bytes per transfer, or `None` if none were recorded.
    // u64 -> f64 may lose precision above 2^53; acceptable for an average.
    #[allow(clippy::cast_precision_loss)]
    #[must_use]
    pub fn average_transfer_size(&self) -> Option<f64> {
        if self.transfer_count == 0 {
            return None;
        }
        Some(self.total_bytes as f64 / self.transfer_count as f64)
    }

    /// Clears all totals; the direction is kept.
    pub fn reset(&mut self) {
        self.total_bytes = 0;
        self.transfer_count = 0;
        self.total_duration = Duration::ZERO;
    }
}
/// Aggregate accelerator statistics: per-operation timings, per-direction
/// transfer totals, and task outcome counters.
#[derive(Debug)]
pub struct AccelStatistics {
    /// Per-operation timing statistics, keyed by operation name.
    operations: HashMap<String, OperationStats>,
    /// Transfer totals, keyed by direction.
    transfers: HashMap<TransferDirection, TransferStats>,
    /// Number of tasks submitted.
    pub tasks_submitted: u64,
    /// Number of tasks that completed successfully.
    pub tasks_completed: u64,
    /// Number of tasks that failed.
    pub tasks_failed: u64,
    /// Instant at which this collector was created or last reset.
    pub started_at: Instant,
}

impl AccelStatistics {
    /// Creates an empty collector whose clock starts now.
    #[must_use]
    pub fn new() -> Self {
        Self {
            operations: HashMap::new(),
            transfers: HashMap::new(),
            tasks_submitted: 0,
            tasks_completed: 0,
            tasks_failed: 0,
            started_at: Instant::now(),
        }
    }

    /// Records one invocation of operation `name`.
    ///
    /// The name is copied into an owned `String` key only the first time it is
    /// seen; recording an existing operation does not allocate a new key.
    pub fn record_operation(&mut self, name: &str, duration: Duration, started_at: Instant) {
        if let Some(stats) = self.operations.get_mut(name) {
            stats.record(duration, started_at);
            return;
        }
        let mut stats = OperationStats::new(name);
        stats.record(duration, started_at);
        self.operations.insert(name.to_string(), stats);
    }

    /// Records one transfer of `bytes` in `direction` that took `duration`.
    pub fn record_transfer(
        &mut self,
        direction: TransferDirection,
        bytes: u64,
        duration: Duration,
    ) {
        self.transfers
            .entry(direction)
            .or_insert_with(|| TransferStats::new(direction))
            .record(bytes, duration);
    }

    /// Increments the submitted-task counter.
    pub fn record_task_submitted(&mut self) {
        self.tasks_submitted += 1;
    }

    /// Increments the completed-task counter.
    pub fn record_task_completed(&mut self) {
        self.tasks_completed += 1;
    }

    /// Increments the failed-task counter.
    pub fn record_task_failed(&mut self) {
        self.tasks_failed += 1;
    }

    /// Returns the time elapsed since `started_at`.
    #[must_use]
    pub fn uptime(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Returns `completed / (completed + failed)`, or `None` if no task has
    /// finished yet. Submitted-but-unfinished tasks are not counted.
    // u64 -> f64 may lose precision above 2^53; acceptable for a ratio.
    #[allow(clippy::cast_precision_loss)]
    #[must_use]
    pub fn success_rate(&self) -> Option<f64> {
        let total = self.tasks_completed + self.tasks_failed;
        if total == 0 {
            return None;
        }
        Some(self.tasks_completed as f64 / total as f64)
    }

    /// Looks up the statistics recorded for operation `name`.
    #[must_use]
    pub fn get_operation(&self, name: &str) -> Option<&OperationStats> {
        self.operations.get(name)
    }

    /// Looks up the totals recorded for `direction`.
    #[must_use]
    pub fn get_transfer(&self, direction: TransferDirection) -> Option<&TransferStats> {
        self.transfers.get(&direction)
    }

    /// Returns the names of all recorded operations in ascending order.
    #[must_use]
    pub fn operation_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.operations.keys().map(String::as_str).collect();
        // HashMap iteration order is unspecified; sort so output is deterministic.
        names.sort_unstable();
        names
    }

    /// Discards all recorded data and restarts the uptime clock.
    pub fn reset(&mut self) {
        self.operations.clear();
        self.transfers.clear();
        self.tasks_submitted = 0;
        self.tasks_completed = 0;
        self.tasks_failed = 0;
        self.started_at = Instant::now();
    }
}
impl Default for AccelStatistics {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timing_sample() {
        let now = Instant::now();
        let sample = TimingSample::new(Duration::from_millis(50), now);
        assert!((sample.duration_ms() - 50.0).abs() < 0.01);
        assert!((sample.duration_us() - 50000.0).abs() < 10.0);
    }

    #[test]
    fn test_operation_stats_empty() {
        let stats = OperationStats::new("test");
        assert_eq!(stats.invocation_count, 0);
        assert!(stats.average_duration().is_none());
        assert!(stats.min_ms().is_none());
        assert!(stats.max_ms().is_none());
        assert!(stats.recent_median_ms().is_none());
    }

    #[test]
    fn test_operation_stats_record() {
        let mut stats = OperationStats::new("scale");
        let now = Instant::now();
        stats.record(Duration::from_millis(10), now);
        stats.record(Duration::from_millis(20), now);
        stats.record(Duration::from_millis(30), now);
        assert_eq!(stats.invocation_count, 3);
        assert_eq!(stats.total_duration, Duration::from_millis(60));
        let avg = stats.average_ms().unwrap();
        assert!((avg - 20.0).abs() < 0.1);
    }

    #[test]
    fn test_operation_stats_min_max() {
        let mut stats = OperationStats::new("conv");
        let now = Instant::now();
        stats.record(Duration::from_millis(5), now);
        stats.record(Duration::from_millis(15), now);
        stats.record(Duration::from_millis(10), now);
        assert!((stats.min_ms().unwrap() - 5.0).abs() < 0.01);
        assert!((stats.max_ms().unwrap() - 15.0).abs() < 0.01);
    }

    #[test]
    fn test_operation_stats_median_odd() {
        let mut stats = OperationStats::new("test");
        let now = Instant::now();
        stats.record(Duration::from_millis(10), now);
        stats.record(Duration::from_millis(30), now);
        stats.record(Duration::from_millis(20), now);
        let median = stats.recent_median_ms().unwrap();
        assert!((median - 20.0).abs() < 0.01);
    }

    #[test]
    fn test_operation_stats_median_even() {
        let mut stats = OperationStats::new("test");
        let now = Instant::now();
        stats.record(Duration::from_millis(10), now);
        stats.record(Duration::from_millis(20), now);
        let median = stats.recent_median_ms().unwrap();
        assert!((median - 15.0).abs() < 0.01);
    }

    #[test]
    fn test_operation_stats_reset() {
        let mut stats = OperationStats::new("op");
        let now = Instant::now();
        stats.record(Duration::from_millis(1), now);
        stats.reset();
        assert_eq!(stats.invocation_count, 0);
        assert_eq!(stats.total_duration, Duration::ZERO);
        assert!(stats.average_duration().is_none());
        assert!(stats.min_ms().is_none());
        assert!(stats.max_ms().is_none());
        assert_eq!(stats.recent_count(), 0);
        assert!(stats.recent_median_ms().is_none());
    }

    #[test]
    fn test_operation_stats_sample_eviction() {
        let mut stats = OperationStats::with_max_samples("op", 3);
        let now = Instant::now();
        for i in 0..5 {
            stats.record(Duration::from_millis(i * 10), now);
        }
        assert_eq!(stats.recent_count(), 3);
        // The oldest samples (0 ms, 10 ms) were evicted, leaving 20/30/40 ms.
        assert!((stats.recent_median_ms().unwrap() - 30.0).abs() < 0.01);
        // Lifetime totals are not limited by the window.
        assert_eq!(stats.invocation_count, 5);
        assert!(stats.min_ms().unwrap().abs() < 0.01);
    }

    #[test]
    fn test_operation_stats_zero_capacity_clamped() {
        let mut stats = OperationStats::with_max_samples("op", 0);
        let now = Instant::now();
        stats.record(Duration::from_millis(1), now);
        stats.record(Duration::from_millis(2), now);
        assert_eq!(stats.recent_count(), 1);
    }

    #[test]
    fn test_transfer_direction_label() {
        assert_eq!(TransferDirection::HostToDevice.label(), "H2D");
        assert_eq!(TransferDirection::DeviceToHost.label(), "D2H");
        assert_eq!(TransferDirection::DeviceToDevice.label(), "D2D");
    }

    #[test]
    fn test_transfer_stats_record() {
        let mut stats = TransferStats::new(TransferDirection::HostToDevice);
        stats.record(1024, Duration::from_millis(1));
        stats.record(2048, Duration::from_millis(2));
        assert_eq!(stats.total_bytes, 3072);
        assert_eq!(stats.transfer_count, 2);
        let avg = stats.average_transfer_size().unwrap();
        assert!((avg - 1536.0).abs() < 1e-9);
    }

    #[test]
    fn test_transfer_stats_throughput() {
        let mut stats = TransferStats::new(TransferDirection::DeviceToHost);
        stats.record(1_000_000, Duration::from_secs(1));
        let bps = stats.throughput_bps().unwrap();
        assert!((bps - 1_000_000.0).abs() < 1.0);
        // MiB/s: 1_000_000 / 2^20.
        let mbps = stats.throughput_mbps().unwrap();
        assert!((mbps - 0.953_674_316_406_25).abs() < 1e-12);
    }

    #[test]
    fn test_transfer_stats_empty_throughput() {
        let stats = TransferStats::new(TransferDirection::HostToDevice);
        assert!(stats.throughput_bps().is_none());
        assert!(stats.average_transfer_size().is_none());
    }

    #[test]
    fn test_transfer_stats_reset() {
        let mut stats = TransferStats::new(TransferDirection::HostToDevice);
        stats.record(100, Duration::from_millis(1));
        stats.reset();
        assert_eq!(stats.total_bytes, 0);
        assert_eq!(stats.transfer_count, 0);
        assert_eq!(stats.total_duration, Duration::ZERO);
    }

    #[test]
    fn test_accel_statistics_record_operations() {
        let mut s = AccelStatistics::new();
        let now = Instant::now();
        s.record_operation("scale", Duration::from_millis(5), now);
        s.record_operation("scale", Duration::from_millis(10), now);
        let op = s.get_operation("scale").unwrap();
        assert_eq!(op.invocation_count, 2);
        assert_eq!(op.name, "scale");
        assert!((op.min_ms().unwrap() - 5.0).abs() < 0.01);
        assert!((op.max_ms().unwrap() - 10.0).abs() < 0.01);
        assert!(s.get_operation("missing").is_none());
    }

    #[test]
    fn test_accel_statistics_record_transfers() {
        let mut s = AccelStatistics::new();
        s.record_transfer(TransferDirection::HostToDevice, 512, Duration::from_millis(1));
        s.record_transfer(TransferDirection::HostToDevice, 512, Duration::from_millis(1));
        let h2d = s.get_transfer(TransferDirection::HostToDevice).unwrap();
        assert_eq!(h2d.direction, TransferDirection::HostToDevice);
        assert_eq!(h2d.total_bytes, 1024);
        assert_eq!(h2d.transfer_count, 2);
        assert!(s.get_transfer(TransferDirection::DeviceToHost).is_none());
    }

    #[test]
    fn test_accel_statistics_tasks() {
        let mut s = AccelStatistics::new();
        s.record_task_submitted();
        s.record_task_submitted();
        s.record_task_completed();
        s.record_task_failed();
        assert_eq!(s.tasks_submitted, 2);
        assert_eq!(s.tasks_completed, 1);
        assert_eq!(s.tasks_failed, 1);
        assert!((s.success_rate().unwrap() - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_accel_statistics_success_rate_none() {
        let s = AccelStatistics::new();
        assert!(s.success_rate().is_none());
    }

    #[test]
    fn test_accel_statistics_reset() {
        let mut s = AccelStatistics::new();
        let now = Instant::now();
        s.record_operation("op", Duration::from_millis(1), now);
        s.record_transfer(TransferDirection::DeviceToDevice, 8, Duration::from_millis(1));
        s.record_task_submitted();
        s.record_task_completed();
        s.record_task_failed();
        s.reset();
        assert!(s.operation_names().is_empty());
        assert!(s.get_transfer(TransferDirection::DeviceToDevice).is_none());
        assert_eq!(s.tasks_submitted, 0);
        assert_eq!(s.tasks_completed, 0);
        assert_eq!(s.tasks_failed, 0);
        assert!(s.success_rate().is_none());
    }

    #[test]
    fn test_accel_statistics_operation_names() {
        let mut s = AccelStatistics::new();
        let now = Instant::now();
        s.record_operation("alpha", Duration::from_millis(1), now);
        s.record_operation("beta", Duration::from_millis(1), now);
        let names = s.operation_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha"));
        assert!(names.contains(&"beta"));
    }
}