Skip to main content

concerto_gpu/
mock.rs

1//! An in-memory, configurable [`GpuMonitor`] used for testing and development.
2
3use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
4
5use async_trait::async_trait;
6use bytesize::ByteSize;
7use concerto_core::GpuId;
8
9use crate::monitor::{GpuMonitor, GpuSnapshot};
10
11/// A [`GpuMonitor`] backed by an in-memory `Vec<GpuSnapshot>`.
12///
13/// `MockGpuMonitor` is the workhorse of Concerto's test suite: it lets tests
14/// declare the exact fleet they want to see (including edge cases like
15/// overheating GPUs, GPUs with ECC errors, or GPUs that disappear mid-run) and
16/// then drive the system under test through those scenarios.
17///
18/// Cloning is cheap — the monitor is backed by an `Arc<RwLock<_>>` so clones
19/// share the same underlying state. This lets a test hold a handle for
20/// mutation while also passing the monitor into the system under test.
21///
22/// The lock is a `std::sync::RwLock` (not `tokio::sync::RwLock`) so that
23/// [`GpuMonitor::gpu_count`], which is synchronous, can read the length
24/// without entering the async runtime. Critical sections are always tiny
25/// (a clone, a field write, or a `retain`) so blocking here is harmless.
26#[derive(Debug, Clone, Default)]
27pub struct MockGpuMonitor {
28    snapshots: Arc<RwLock<Vec<GpuSnapshot>>>,
29}
30
31impl MockGpuMonitor {
32    /// Create a mock monitor from an explicit list of snapshots.
33    pub fn new(snapshots: Vec<GpuSnapshot>) -> Self {
34        Self {
35            snapshots: Arc::new(RwLock::new(snapshots)),
36        }
37    }
38
39    /// Create a mock monitor reporting `count` healthy GPUs, each with
40    /// `memory_per_gpu_gb` gigabytes of VRAM, zero utilisation, zero memory
41    /// used, 40 degrees Celsius, and no ECC errors.
42    pub fn with_healthy_gpus(count: usize, memory_per_gpu_gb: u64) -> Self {
43        let snapshots = (0..count)
44            .map(|i| GpuSnapshot {
45                id: GpuId(i),
46                memory_total: ByteSize::gb(memory_per_gpu_gb),
47                memory_used: ByteSize::b(0),
48                temperature_celsius: 40,
49                utilisation_percent: 0,
50                ecc_errors_uncorrected: 0,
51            })
52            .collect();
53        Self::new(snapshots)
54    }
55
56    /// Overwrite the `memory_used` field of the GPU with the given id.
57    ///
58    /// No-op if the GPU is not present (e.g. it has been removed via
59    /// [`MockGpuMonitor::remove_gpu`]).
60    pub async fn set_memory_used(&self, gpu_id: GpuId, bytes: ByteSize) {
61        self.update(gpu_id, |snap| snap.memory_used = bytes);
62    }
63
64    /// Overwrite the `temperature_celsius` field of the GPU with the given id.
65    pub async fn set_temperature(&self, gpu_id: GpuId, celsius: u32) {
66        self.update(gpu_id, |snap| snap.temperature_celsius = celsius);
67    }
68
69    /// Increment the uncorrected ECC error count of the GPU with the given id
70    /// by one.
71    pub async fn inject_ecc_error(&self, gpu_id: GpuId) {
72        self.update(gpu_id, |snap| {
73            snap.ecc_errors_uncorrected = snap.ecc_errors_uncorrected.saturating_add(1);
74        });
75    }
76
77    /// Remove a GPU from the monitor's view, simulating a GPU that has
78    /// dropped off the bus (driver crash, hardware fault, hot-unplug).
79    pub async fn remove_gpu(&self, gpu_id: GpuId) {
80        if let Some(mut guard) = self.write_guard("remove_gpu") {
81            guard.retain(|s| s.id != gpu_id);
82        }
83    }
84
85    /// Apply `f` to the snapshot with the given id, if one exists. No-op
86    /// otherwise — callers use this to drive specific GPUs from tests without
87    /// first checking whether the GPU is still present.
88    fn update(&self, gpu_id: GpuId, f: impl FnOnce(&mut GpuSnapshot)) {
89        let Some(mut guard) = self.write_guard("update") else {
90            return;
91        };
92        if let Some(snap) = guard.iter_mut().find(|s| s.id == gpu_id) {
93            f(snap);
94        }
95    }
96
97    fn read_guard(&self, op: &'static str) -> Option<RwLockReadGuard<'_, Vec<GpuSnapshot>>> {
98        match self.snapshots.read() {
99            Ok(guard) => Some(guard),
100            Err(_) => {
101                tracing::error!(op, "MockGpuMonitor lock poisoned");
102                None
103            }
104        }
105    }
106
107    fn write_guard(&self, op: &'static str) -> Option<RwLockWriteGuard<'_, Vec<GpuSnapshot>>> {
108        match self.snapshots.write() {
109            Ok(guard) => Some(guard),
110            Err(_) => {
111                tracing::error!(op, "MockGpuMonitor lock poisoned");
112                None
113            }
114        }
115    }
116}
117
118#[async_trait]
119impl GpuMonitor for MockGpuMonitor {
120    fn gpu_count(&self) -> usize {
121        self.read_guard("gpu_count").map_or(0, |g| g.len())
122    }
123
124    async fn snapshot(&self) -> Vec<GpuSnapshot> {
125        self.read_guard("snapshot")
126            .map_or_else(Vec::new, |g| g.clone())
127    }
128}