Skip to main content

concerto_core/
state.rs

1use crate::types::*;
2use bytesize::ByteSize;
3use std::collections::HashMap;
4
5/// The complete state of the GPU cluster at a point in time.
6/// This is the input to all routing decisions.
7#[derive(Debug, Clone)]
8pub struct ClusterState {
9    pub gpus: Vec<GpuState>,
10    pub model_registry: HashMap<ModelId, ModelSpec>,
11}
12
13impl ClusterState {
14    pub fn new(gpus: Vec<GpuState>, model_registry: HashMap<ModelId, ModelSpec>) -> Self {
15        Self {
16            gpus,
17            model_registry,
18        }
19    }
20
21    /// Find which GPU (if any) currently has the given model loaded.
22    pub fn find_loaded_model(&self, model_id: &ModelId) -> Option<(GpuId, u16)> {
23        for gpu in &self.gpus {
24            for model in &gpu.loaded_models {
25                if &model.model_id == model_id {
26                    return Some((gpu.id, model.backend_port));
27                }
28            }
29        }
30        None
31    }
32
33    /// Get the spec for a model from the registry.
34    pub fn get_model_spec(&self, model_id: &ModelId) -> Option<&ModelSpec> {
35        self.model_registry.get(model_id)
36    }
37
38    /// Get all healthy GPUs, sorted by available memory (most available first).
39    pub fn healthy_gpus_by_available_memory(&self) -> Vec<&GpuState> {
40        let mut gpus: Vec<&GpuState> = self
41            .gpus
42            .iter()
43            .filter(|g| g.health != GpuHealth::Unhealthy)
44            .collect();
45        gpus.sort_by_key(|g| std::cmp::Reverse(g.memory_available));
46        gpus
47    }
48
49    /// Get all GPUs that could fit a model of the given size (with headroom).
50    pub fn gpus_with_space_for(
51        &self,
52        vram_required: ByteSize,
53        headroom: ByteSize,
54    ) -> Vec<&GpuState> {
55        let total_needed = ByteSize::b(vram_required.as_u64() + headroom.as_u64());
56        self.gpus
57            .iter()
58            .filter(|g| g.health != GpuHealth::Unhealthy && g.memory_available >= total_needed)
59            .collect()
60    }
61}
62
#[cfg(test)]
pub mod test_helpers {
    //! Builders and fixtures for constructing cluster state in tests.

    use super::*;
    use chrono::{DateTime, Utc};

    /// Builder for test `GpuState` values.
    ///
    /// Defaults set by [`GpuStateBuilder::new`]: 24 GB total memory, nothing used,
    /// 45 °C, 0 % utilisation, `GpuHealth::Healthy`, and no loaded models.
    pub struct GpuStateBuilder {
        id: GpuId,
        memory_total: ByteSize,
        /// Running total of the VRAM of every model added via `with_model*`.
        memory_used: ByteSize,
        temperature: u32,
        utilisation: u32,
        health: GpuHealth,
        loaded_models: Vec<LoadedModel>,
    }

    impl GpuStateBuilder {
        /// Start a builder for the GPU with index `id`, using the defaults listed on the type.
        pub fn new(id: usize) -> Self {
            Self {
                id: GpuId(id),
                memory_total: ByteSize::gb(24),
                memory_used: ByteSize::b(0),
                temperature: 45,
                utilisation: 0,
                health: GpuHealth::Healthy,
                loaded_models: Vec::new(),
            }
        }

        /// Set the total memory, in units of `ByteSize::gb`.
        pub fn memory_total_gb(mut self, gb: u64) -> Self {
            self.memory_total = ByteSize::gb(gb);
            self
        }

        /// Set the reported temperature in degrees Celsius.
        pub fn temperature(mut self, celsius: u32) -> Self {
            self.temperature = celsius;
            self
        }

        /// Set the reported utilisation percentage (otherwise it stays at 0).
        pub fn utilisation(mut self, percent: u32) -> Self {
            self.utilisation = percent;
            self
        }

        /// Set the GPU's health state.
        pub fn health(mut self, health: GpuHealth) -> Self {
            self.health = health;
            self
        }

        /// Add a loaded model whose `last_request_at` is the current time.
        ///
        /// The model's VRAM is added to the GPU's used memory.
        pub fn with_model(self, model_id: &str, vram_gb: u64, port: u16) -> Self {
            self.with_model_last_used(model_id, vram_gb, port, Utc::now())
        }

        /// Add a model with a specific `last_request_at`, for LRU testing.
        ///
        /// The model's VRAM is added to the GPU's used memory.
        pub fn with_model_last_used(
            mut self,
            model_id: &str,
            vram_gb: u64,
            port: u16,
            last_request_at: DateTime<Utc>,
        ) -> Self {
            let vram = ByteSize::gb(vram_gb);
            // Saturate rather than overflow; `build` already clamps available memory at zero.
            self.memory_used =
                ByteSize::b(self.memory_used.as_u64().saturating_add(vram.as_u64()));
            self.loaded_models.push(LoadedModel {
                model_id: ModelId(model_id.to_string()),
                vram_usage: vram,
                last_request_at,
                request_count: 0,
                backend_port: port,
            });
            self
        }

        /// Produce the `GpuState`, deriving `memory_available` as total minus used.
        pub fn build(self) -> GpuState {
            // An over-committed builder (models larger than total memory) reports zero
            // available memory instead of underflowing.
            let memory_available = ByteSize::b(
                self.memory_total
                    .as_u64()
                    .saturating_sub(self.memory_used.as_u64()),
            );
            GpuState {
                id: self.id,
                memory_total: self.memory_total,
                memory_used: self.memory_used,
                memory_available,
                temperature_celsius: self.temperature,
                utilisation_percent: self.utilisation,
                health: self.health,
                loaded_models: self.loaded_models,
            }
        }
    }

    /// Create a registry entry for a mock-engine, unpinned model named `id`.
    ///
    /// The weight path is `/models/{id}`, and the model requires `ByteSize::gb(vram_gb)` of VRAM.
    pub fn test_model_spec(id: &str, vram_gb: u64) -> (ModelId, ModelSpec) {
        let model_id = ModelId(id.to_string());
        let spec = ModelSpec {
            id: model_id.clone(),
            name: id.to_string(),
            weight_path: format!("/models/{}", id),
            vram_required: ByteSize::gb(vram_gb),
            engine: EngineType::Mock,
            engine_args: vec![],
            pin: false,
            max_vram_fraction: None,
        };
        (model_id, spec)
    }
}