1use crate::expert_pool::gpu_expert_budget_from_vram;
41use crate::weight_registry::WeightRegistry;
42use rlx_ir::Graph;
43use rlx_opt::memory::plan_memory;
44
/// Per-category breakdown of an estimated memory footprint, in bytes.
///
/// The categories are kept separate so a caller can report which one
/// dominates; `peak_bytes` combines them into a single figure.
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
    /// Bytes attributed to activations.
    pub activation_bytes: usize,
    /// Bytes attributed to weights.
    pub weight_bytes: usize,
    /// Bytes attributed to graph inputs.
    pub input_bytes: usize,
}
57
58impl MemoryEstimate {
59 pub fn peak_bytes(&self) -> usize {
60 self.activation_bytes + self.weight_bytes + self.input_bytes
61 }
62
63 pub fn fits_in(&self, budget_bytes: usize) -> Result<(), MemoryDeficit> {
67 let peak = self.peak_bytes();
68 if peak <= budget_bytes {
69 Ok(())
70 } else {
71 Err(MemoryDeficit {
72 budget_bytes,
73 peak_bytes: peak,
74 activation_bytes: self.activation_bytes,
75 weight_bytes: self.weight_bytes,
76 input_bytes: self.input_bytes,
77 })
78 }
79 }
80}
81
/// Error describing an estimate that does not fit in a memory budget.
///
/// Carries the budget, the offending peak, and the per-category breakdown
/// so the failure can be reported without re-running the estimate.
#[derive(Debug, Clone)]
pub struct MemoryDeficit {
    /// The budget the estimate was checked against, in bytes.
    pub budget_bytes: usize,
    /// The estimated peak that was compared to the budget, in bytes.
    pub peak_bytes: usize,
    /// Activation share of the estimate, in bytes.
    pub activation_bytes: usize,
    /// Weight share of the estimate, in bytes.
    pub weight_bytes: usize,
    /// Input share of the estimate, in bytes.
    pub input_bytes: usize,
}
90
91impl std::fmt::Display for MemoryDeficit {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 let over = self.peak_bytes - self.budget_bytes;
94 write!(
95 f,
96 "estimated peak {peak_mb:.1} MiB exceeds budget {budget_mb:.1} MiB by {over_mb:.1} MiB \
97 (activation {act_mb:.1}, weights {w_mb:.1}, inputs {in_mb:.1})",
98 peak_mb = self.peak_bytes as f64 / 1024.0 / 1024.0,
99 budget_mb = self.budget_bytes as f64 / 1024.0 / 1024.0,
100 over_mb = over as f64 / 1024.0 / 1024.0,
101 act_mb = self.activation_bytes as f64 / 1024.0 / 1024.0,
102 w_mb = self.weight_bytes as f64 / 1024.0 / 1024.0,
103 in_mb = self.input_bytes as f64 / 1024.0 / 1024.0,
104 )
105 }
106}
107
108impl std::error::Error for MemoryDeficit {}
109
/// Sizing summary for a mixture-of-experts model whose experts are only
/// partly kept resident.
///
/// The field descriptions below reflect how `estimate_moe_offload` fills
/// the struct; since every field is public, other producers may differ.
#[derive(Debug, Clone)]
pub struct MoeOffloadEstimate {
    /// Parameter bytes of one expert in one MoE layer.
    pub expert_param_bytes: usize,
    /// Number of MoE layers.
    pub num_moe_layers: usize,
    /// Number of experts per MoE layer.
    pub num_experts: usize,
    /// Experts per layer granted residency, as returned by the budget helper.
    pub gpu_expert_budget_per_layer: usize,
    /// `expert_param_bytes * num_experts * num_moe_layers` (saturating).
    pub all_expert_weight_bytes: usize,
    /// `expert_param_bytes * gpu_expert_budget_per_layer * num_moe_layers` (saturating).
    pub resident_expert_weight_bytes: usize,
}
128
129impl MoeOffloadEstimate {
130 pub fn peak_with_offload(&self, base: &MemoryEstimate) -> usize {
132 base.activation_bytes
133 + base.input_bytes
134 + (base.weight_bytes - self.all_expert_weight_bytes)
135 + self.resident_expert_weight_bytes
136 }
137}
138
139pub fn estimate_moe_offload(
141 expert_param_bytes: usize,
142 num_moe_layers: usize,
143 num_experts: usize,
144 max_gpu_experts_per_layer: usize,
145 memory_budget_bytes: usize,
146 reserve_fraction: f32,
147) -> MoeOffloadEstimate {
148 let reserve_bytes = (memory_budget_bytes as f64 * reserve_fraction as f64) as usize;
149 let gpu_budget = gpu_expert_budget_from_vram(
150 memory_budget_bytes,
151 reserve_bytes,
152 expert_param_bytes,
153 num_moe_layers,
154 max_gpu_experts_per_layer,
155 num_experts,
156 );
157 let all_expert = expert_param_bytes
158 .saturating_mul(num_experts)
159 .saturating_mul(num_moe_layers);
160 let resident_expert = expert_param_bytes
161 .saturating_mul(gpu_budget)
162 .saturating_mul(num_moe_layers);
163 MoeOffloadEstimate {
164 expert_param_bytes,
165 num_moe_layers,
166 num_experts,
167 gpu_expert_budget_per_layer: gpu_budget,
168 all_expert_weight_bytes: all_expert,
169 resident_expert_weight_bytes: resident_expert,
170 }
171}
172
173pub fn estimate(graph: &Graph, registry: &WeightRegistry) -> MemoryEstimate {
174 let plan = plan_memory(graph);
175 let mut input_bytes = 0usize;
177 for node in graph.nodes() {
178 if matches!(node.op, rlx_ir::Op::Input { .. }) {
179 input_bytes += node.shape.size_bytes().unwrap_or(0);
180 }
181 }
182 MemoryEstimate {
183 activation_bytes: plan.arena_size,
184 weight_bytes: registry.total_bytes(),
185 input_bytes,
186 }
187}
188
/// Returns the machine's memory size in bytes on macOS, `None` elsewhere.
///
/// On macOS this reads the `hw.memsize` sysctl. Note that `hw.memsize` is
/// the total physical memory of the machine, not the amount currently free,
/// so the value is an upper bound for budgeting purposes.
///
/// Returns `None` when the sysctl call fails, when the kernel writes back a
/// value of unexpected size, when the value does not fit in `usize`, or when
/// compiled for any non-macOS target.
pub fn available_unified_memory() -> Option<usize> {
    #[cfg(target_os = "macos")]
    {
        use std::ffi::CString;
        use std::os::raw::{c_char, c_int, c_void};

        unsafe extern "C" {
            fn sysctlbyname(
                name: *const c_char,
                oldp: *mut c_void,
                oldlenp: *mut usize,
                newp: *mut c_void,
                newlen: usize,
            ) -> c_int;
        }

        let name = CString::new("hw.memsize").ok()?;
        let mut mem_bytes: u64 = 0;
        let mut len = std::mem::size_of::<u64>();

        // SAFETY: `name` is a valid NUL-terminated C string that outlives the
        // call; `oldp` points to a live, writable `u64` and `len` holds exactly
        // its size, so the kernel cannot write past it; `newp` is null with
        // `newlen == 0`, which makes this a read-only query.
        let rc = unsafe {
            sysctlbyname(
                name.as_ptr(),
                (&mut mem_bytes as *mut u64).cast::<c_void>(),
                &mut len,
                std::ptr::null_mut(),
                0,
            )
        };

        // Reject failures and short writes: a partially written value would
        // otherwise be reported as a (wrong) memory size.
        if rc != 0 || len != std::mem::size_of::<u64>() {
            return None;
        }
        // Checked conversion instead of a silently truncating `as` cast.
        usize::try_from(mem_bytes).ok()
    }
    #[cfg(not(target_os = "macos"))]
    {
        None
    }
}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::weight_registry::{WeightKind, WeightRegistry};
    use rlx_ir::*;
    use std::sync::Arc;

    /// Builds a minimal graph: a `[2, 16]` f32 input multiplied by a
    /// `[16, 32]` f32 parameter, with the matmul result as the only output.
    fn small_graph() -> Graph {
        let f = DType::F32;
        let mut g = Graph::new("est");
        let x = g.input("x", Shape::new(&[2, 16], f));
        let w = g.param("w", Shape::new(&[16, 32], f));
        let mm = g.matmul(x, w, Shape::new(&[2, 32], f));
        g.set_outputs(vec![mm]);
        g
    }

    /// Registry holding a single base weight `"w"` backed by `byte_len`
    /// zero bytes (the backing length is chosen per test).
    fn registry_with_w(byte_len: usize) -> WeightRegistry {
        let mut reg = WeightRegistry::new();
        reg.register(
            "w",
            Shape::new(&[16, 32], DType::F32),
            Arc::from(vec![0u8; byte_len]),
            WeightKind::Base,
        );
        reg
    }

    #[test]
    fn estimate_sums_components() {
        let g = small_graph();
        let reg = registry_with_w(16 * 32 * 4);
        let est = estimate(&g, &reg);
        assert!(
            est.activation_bytes >= 256,
            "activation arena should hold mm output"
        );
        assert_eq!(est.weight_bytes, 16 * 32 * 4);
        assert_eq!(est.input_bytes, 2 * 16 * 4);
        // The peak is exactly the sum of the three components.
        assert_eq!(
            est.peak_bytes(),
            est.activation_bytes + est.weight_bytes + est.input_bytes
        );
    }

    #[test]
    fn fits_in_passes_with_room() {
        let g = small_graph();
        let reg = registry_with_w(2048);
        let est = estimate(&g, &reg);
        assert!(
            est.fits_in(1 << 30).is_ok(),
            "1 GiB budget should fit a tiny graph"
        );
    }

    #[test]
    fn fits_in_reports_deficit() {
        let g = small_graph();
        // Deliberately oversized backing buffer to force a deficit.
        let reg = registry_with_w(100_000_000);
        let est = estimate(&g, &reg);
        let err = est.fits_in(1024).unwrap_err();
        assert!(err.peak_bytes > err.budget_bytes);
        assert!(format!("{err}").contains("exceeds"));
    }

    #[test]
    fn available_memory_returns_something_on_macos() {
        // Only meaningful on macOS; other targets always return `None`.
        if cfg!(target_os = "macos") {
            assert!(matches!(available_unified_memory(), Some(bytes) if bytes > 0));
        }
    }
}