1use crate::types::*;
2use bytesize::ByteSize;
3use std::collections::HashMap;
4
5#[derive(Debug, Clone)]
8pub struct ClusterState {
9 pub gpus: Vec<GpuState>,
10 pub model_registry: HashMap<ModelId, ModelSpec>,
11}
12
13impl ClusterState {
14 pub fn new(gpus: Vec<GpuState>, model_registry: HashMap<ModelId, ModelSpec>) -> Self {
15 Self {
16 gpus,
17 model_registry,
18 }
19 }
20
21 pub fn find_loaded_model(&self, model_id: &ModelId) -> Option<(GpuId, u16)> {
23 for gpu in &self.gpus {
24 for model in &gpu.loaded_models {
25 if &model.model_id == model_id {
26 return Some((gpu.id, model.backend_port));
27 }
28 }
29 }
30 None
31 }
32
33 pub fn get_model_spec(&self, model_id: &ModelId) -> Option<&ModelSpec> {
35 self.model_registry.get(model_id)
36 }
37
38 pub fn healthy_gpus_by_available_memory(&self) -> Vec<&GpuState> {
40 let mut gpus: Vec<&GpuState> = self
41 .gpus
42 .iter()
43 .filter(|g| g.health != GpuHealth::Unhealthy)
44 .collect();
45 gpus.sort_by_key(|g| std::cmp::Reverse(g.memory_available));
46 gpus
47 }
48
49 pub fn gpus_with_space_for(
51 &self,
52 vram_required: ByteSize,
53 headroom: ByteSize,
54 ) -> Vec<&GpuState> {
55 let total_needed = ByteSize::b(vram_required.as_u64() + headroom.as_u64());
56 self.gpus
57 .iter()
58 .filter(|g| g.health != GpuHealth::Unhealthy && g.memory_available >= total_needed)
59 .collect()
60 }
61}
62
/// Builders and fixtures for constructing GPU state and model specs in tests.
#[cfg(test)]
pub mod test_helpers {
    use super::*;
    use chrono::{DateTime, Utc};

    /// Fluent builder that produces a [`GpuState`] for tests.
    pub struct GpuStateBuilder {
        id: GpuId,
        memory_total: ByteSize,
        memory_used: ByteSize,
        temperature: u32,
        utilisation: u32,
        health: GpuHealth,
        loaded_models: Vec<LoadedModel>,
    }

    impl GpuStateBuilder {
        /// Starts a builder for GPU `id` with these defaults:
        /// `ByteSize::gb(24)` total memory, nothing used, temperature 45,
        /// utilisation 0, [`GpuHealth::Healthy`], and no loaded models.
        pub fn new(id: usize) -> Self {
            Self {
                id: GpuId(id),
                memory_total: ByteSize::gb(24),
                memory_used: ByteSize::b(0),
                temperature: 45,
                utilisation: 0,
                health: GpuHealth::Healthy,
                loaded_models: Vec::new(),
            }
        }

        /// Overrides the total memory with `ByteSize::gb(gb)`.
        pub fn memory_total_gb(mut self, gb: u64) -> Self {
            self.memory_total = ByteSize::gb(gb);
            self
        }

        /// Overrides the temperature, in degrees Celsius.
        pub fn temperature(mut self, celsius: u32) -> Self {
            self.temperature = celsius;
            self
        }

        /// Overrides the health status.
        pub fn health(mut self, health: GpuHealth) -> Self {
            self.health = health;
            self
        }

        /// Adds a loaded model whose last request time is the current UTC time.
        ///
        /// See [`GpuStateBuilder::with_model_last_used`] for the memory
        /// accounting performed.
        pub fn with_model(self, model_id: &str, vram_gb: u64, port: u16) -> Self {
            self.with_model_last_used(model_id, vram_gb, port, Utc::now())
        }

        /// Adds a loaded model with an explicit `last_request_at` timestamp.
        ///
        /// The model's VRAM usage (`ByteSize::gb(vram_gb)`) is added to the
        /// builder's used memory, and the model starts with a request count
        /// of zero.
        pub fn with_model_last_used(
            mut self,
            model_id: &str,
            vram_gb: u64,
            port: u16,
            last_request_at: DateTime<Utc>,
        ) -> Self {
            let vram_usage = ByteSize::gb(vram_gb);
            let used_bytes = self.memory_used.as_u64() + vram_usage.as_u64();
            self.memory_used = ByteSize::b(used_bytes);
            self.loaded_models.push(LoadedModel {
                model_id: ModelId(model_id.to_owned()),
                vram_usage,
                last_request_at,
                request_count: 0,
                backend_port: port,
            });
            self
        }

        /// Finishes the builder.
        ///
        /// Available memory is total minus used, clamped at zero when the
        /// loaded models exceed the total.
        pub fn build(self) -> GpuState {
            let free_bytes = self
                .memory_total
                .as_u64()
                .saturating_sub(self.memory_used.as_u64());
            GpuState {
                id: self.id,
                memory_total: self.memory_total,
                memory_used: self.memory_used,
                memory_available: ByteSize::b(free_bytes),
                temperature_celsius: self.temperature,
                utilisation_percent: self.utilisation,
                health: self.health,
                loaded_models: self.loaded_models,
            }
        }
    }

    /// Creates a `(ModelId, ModelSpec)` pair for tests.
    ///
    /// The spec uses `id` as both id and name, `/models/<id>` as the weight
    /// path, `ByteSize::gb(vram_gb)` as required VRAM, the mock engine with no
    /// arguments, no pinning, and no VRAM fraction limit.
    pub fn test_model_spec(id: &str, vram_gb: u64) -> (ModelId, ModelSpec) {
        let key = ModelId(id.to_owned());
        let spec = ModelSpec {
            id: key.clone(),
            name: id.to_owned(),
            weight_path: format!("/models/{}", id),
            vram_required: ByteSize::gb(vram_gb),
            engine: EngineType::Mock,
            engine_args: Vec::new(),
            pin: false,
            max_vram_fraction: None,
        };
        (key, spec)
    }
}