use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use async_trait::async_trait;
use bytesize::ByteSize;
use concerto_core::GpuId;
use crate::monitor::{GpuMonitor, GpuSnapshot};
/// Mock GPU monitor that serves snapshots from an in-memory list.
///
/// The list sits behind an `Arc<RwLock<..>>`, so cloning the monitor yields a
/// handle onto the same list: a change made through one clone is visible
/// through every other clone. The derived `Default` starts with an empty list.
#[derive(Debug, Clone, Default)]
pub struct MockGpuMonitor {
// Shared, lock-protected snapshot list; one entry per simulated GPU.
snapshots: Arc<RwLock<Vec<GpuSnapshot>>>,
}
impl MockGpuMonitor {
    /// Builds a monitor that reports exactly the given `snapshots`.
    pub fn new(snapshots: Vec<GpuSnapshot>) -> Self {
        let shared = Arc::new(RwLock::new(snapshots));
        Self { snapshots: shared }
    }

    /// Builds a monitor with `count` idle GPUs whose ids run from `0` to
    /// `count - 1`.
    ///
    /// Each GPU reports `ByteSize::gb(memory_per_gpu_gb)` of total memory,
    /// zero bytes used, a temperature of 40, zero utilisation and zero
    /// uncorrected ECC errors.
    pub fn with_healthy_gpus(count: usize, memory_per_gpu_gb: u64) -> Self {
        let mut gpus = Vec::with_capacity(count);
        for index in 0..count {
            gpus.push(GpuSnapshot {
                id: GpuId(index),
                memory_total: ByteSize::gb(memory_per_gpu_gb),
                memory_used: ByteSize::b(0),
                temperature_celsius: 40,
                utilisation_percent: 0,
                ecc_errors_uncorrected: 0,
            });
        }
        Self::new(gpus)
    }

    /// Overwrites the used-memory reading of the GPU identified by `gpu_id`.
    ///
    /// Does nothing when no such GPU exists or the lock is poisoned. The body
    /// never awaits; it completes on the first poll.
    pub async fn set_memory_used(&self, gpu_id: GpuId, bytes: ByteSize) {
        self.update(gpu_id, |gpu| {
            gpu.memory_used = bytes;
        });
    }

    /// Overwrites the temperature reading of the GPU identified by `gpu_id`.
    ///
    /// Does nothing when no such GPU exists or the lock is poisoned.
    pub async fn set_temperature(&self, gpu_id: GpuId, celsius: u32) {
        self.update(gpu_id, |gpu| {
            gpu.temperature_celsius = celsius;
        });
    }

    /// Adds one to the uncorrected-ECC-error counter of the GPU identified by
    /// `gpu_id`, saturating at the counter's maximum instead of wrapping.
    ///
    /// Does nothing when no such GPU exists or the lock is poisoned.
    pub async fn inject_ecc_error(&self, gpu_id: GpuId) {
        self.update(gpu_id, |gpu| {
            let bumped = gpu.ecc_errors_uncorrected.saturating_add(1);
            gpu.ecc_errors_uncorrected = bumped;
        });
    }

    /// Drops every snapshot whose id equals `gpu_id`, so the GPU no longer
    /// shows up in later reads.
    ///
    /// Does nothing when the lock is poisoned.
    pub async fn remove_gpu(&self, gpu_id: GpuId) {
        let Some(mut gpus) = self.write_guard("remove_gpu") else {
            return;
        };
        gpus.retain(|gpu| gpu.id != gpu_id);
    }

    /// Runs `f` on the first snapshot whose id equals `gpu_id` while holding
    /// the write lock.
    ///
    /// `f` is not called when the lock is poisoned or no snapshot matches.
    fn update(&self, gpu_id: GpuId, f: impl FnOnce(&mut GpuSnapshot)) {
        if let Some(mut gpus) = self.write_guard("update") {
            let target = gpus.iter_mut().find(|gpu| gpu.id == gpu_id);
            if let Some(gpu) = target {
                f(gpu);
            }
        }
    }

    /// Acquires the read lock, or logs an error tagged with `op` and returns
    /// `None` when the lock is poisoned.
    fn read_guard(&self, op: &'static str) -> Option<RwLockReadGuard<'_, Vec<GpuSnapshot>>> {
        let acquired = self.snapshots.read();
        if acquired.is_err() {
            tracing::error!(op, "MockGpuMonitor lock poisoned");
        }
        acquired.ok()
    }

    /// Acquires the write lock, or logs an error tagged with `op` and returns
    /// `None` when the lock is poisoned.
    fn write_guard(&self, op: &'static str) -> Option<RwLockWriteGuard<'_, Vec<GpuSnapshot>>> {
        let acquired = self.snapshots.write();
        if acquired.is_err() {
            tracing::error!(op, "MockGpuMonitor lock poisoned");
        }
        acquired.ok()
    }
}
#[async_trait]
impl GpuMonitor for MockGpuMonitor {
    /// Returns how many GPU snapshots are currently held, or `0` when the
    /// lock is poisoned.
    fn gpu_count(&self) -> usize {
        match self.read_guard("gpu_count") {
            Some(gpus) => gpus.len(),
            None => 0,
        }
    }

    /// Returns a copy of the current snapshot list, or an empty list when the
    /// lock is poisoned.
    async fn snapshot(&self) -> Vec<GpuSnapshot> {
        match self.read_guard("snapshot") {
            Some(gpus) => gpus.clone(),
            None => Vec::new(),
        }
    }
}