1use crate::expert_pool::gpu_expert_budget_from_vram;
41use crate::weight_registry::WeightRegistry;
42use rlx_ir::Graph;
43use rlx_opt::memory::plan_memory;
44
/// Byte-level breakdown of the memory a graph is estimated to need.
///
/// The three components are kept separate so a caller can see which one
/// dominates; `peak_bytes` combines them into a single figure.
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
    /// Bytes needed for intermediate activations.
    pub activation_bytes: usize,
    /// Bytes occupied by model weights.
    pub weight_bytes: usize,
    /// Bytes occupied by the graph's input tensors.
    pub input_bytes: usize,
}
57
58impl MemoryEstimate {
59 pub fn peak_bytes(&self) -> usize {
60 self.activation_bytes + self.weight_bytes + self.input_bytes
61 }
62
63 pub fn fits_in(&self, budget_bytes: usize) -> Result<(), MemoryDeficit> {
67 let peak = self.peak_bytes();
68 if peak <= budget_bytes {
69 Ok(())
70 } else {
71 Err(MemoryDeficit {
72 budget_bytes,
73 peak_bytes: peak,
74 activation_bytes: self.activation_bytes,
75 weight_bytes: self.weight_bytes,
76 input_bytes: self.input_bytes,
77 })
78 }
79 }
80}
81
/// Error describing an estimate that does not fit its memory budget.
///
/// Besides the budget and the offending peak, it keeps the per-component
/// figures so the message can show where the memory goes.
#[derive(Debug, Clone)]
pub struct MemoryDeficit {
    /// The budget the estimate was checked against, in bytes.
    pub budget_bytes: usize,
    /// The estimated peak that exceeded the budget, in bytes.
    pub peak_bytes: usize,
    /// Activation share of the peak, in bytes.
    pub activation_bytes: usize,
    /// Weight share of the peak, in bytes.
    pub weight_bytes: usize,
    /// Input share of the peak, in bytes.
    pub input_bytes: usize,
}
90
91impl std::fmt::Display for MemoryDeficit {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 let over = self.peak_bytes - self.budget_bytes;
94 write!(
95 f,
96 "estimated peak {peak_mb:.1} MiB exceeds budget {budget_mb:.1} MiB by {over_mb:.1} MiB \
97 (activation {act_mb:.1}, weights {w_mb:.1}, inputs {in_mb:.1})",
98 peak_mb = self.peak_bytes as f64 / 1024.0 / 1024.0,
99 budget_mb = self.budget_bytes as f64 / 1024.0 / 1024.0,
100 over_mb = over as f64 / 1024.0 / 1024.0,
101 act_mb = self.activation_bytes as f64 / 1024.0 / 1024.0,
102 w_mb = self.weight_bytes as f64 / 1024.0 / 1024.0,
103 in_mb = self.input_bytes as f64 / 1024.0 / 1024.0,
104 )
105 }
106}
107
108impl std::error::Error for MemoryDeficit {}
109
/// Sizing summary for a mixture-of-experts model whose experts are only
/// partly kept resident.
///
/// The first three fields echo the inputs of the sizing computation; the
/// remaining ones are derived from them.
#[derive(Debug, Clone)]
pub struct MoeOffloadEstimate {
    /// Size of a single expert's parameters, in bytes.
    pub expert_param_bytes: usize,
    /// Number of MoE layers in the model.
    pub num_moe_layers: usize,
    /// Number of experts per MoE layer.
    pub num_experts: usize,
    /// How many experts per layer are budgeted to stay resident.
    pub gpu_expert_budget_per_layer: usize,
    /// Bytes of every expert in every MoE layer.
    pub all_expert_weight_bytes: usize,
    /// Bytes of only the experts budgeted to stay resident.
    pub resident_expert_weight_bytes: usize,
}
128
129impl MoeOffloadEstimate {
130 pub fn peak_with_offload(&self, base: &MemoryEstimate) -> usize {
132 base.activation_bytes
133 + base.input_bytes
134 + (base.weight_bytes - self.all_expert_weight_bytes)
135 + self.resident_expert_weight_bytes
136 }
137}
138
139pub fn estimate_moe_offload(
141 expert_param_bytes: usize,
142 num_moe_layers: usize,
143 num_experts: usize,
144 max_gpu_experts_per_layer: usize,
145 memory_budget_bytes: usize,
146 reserve_fraction: f32,
147) -> MoeOffloadEstimate {
148 let reserve_bytes = (memory_budget_bytes as f64 * reserve_fraction as f64) as usize;
149 let gpu_budget = gpu_expert_budget_from_vram(
150 memory_budget_bytes,
151 reserve_bytes,
152 expert_param_bytes,
153 num_moe_layers,
154 max_gpu_experts_per_layer,
155 num_experts,
156 );
157 let all_expert = expert_param_bytes
158 .saturating_mul(num_experts)
159 .saturating_mul(num_moe_layers);
160 let resident_expert = expert_param_bytes
161 .saturating_mul(gpu_budget)
162 .saturating_mul(num_moe_layers);
163 MoeOffloadEstimate {
164 expert_param_bytes,
165 num_moe_layers,
166 num_experts,
167 gpu_expert_budget_per_layer: gpu_budget,
168 all_expert_weight_bytes: all_expert,
169 resident_expert_weight_bytes: resident_expert,
170 }
171}
172
173pub fn estimate(graph: &Graph, registry: &WeightRegistry) -> MemoryEstimate {
174 let plan = plan_memory(graph);
175 let mut input_bytes = 0usize;
177 for node in graph.nodes() {
178 if matches!(node.op, rlx_ir::Op::Input { .. }) {
179 input_bytes += node.shape.size_bytes().unwrap_or(0);
180 }
181 }
182 MemoryEstimate {
183 activation_bytes: plan.arena_size,
184 weight_bytes: registry.total_bytes(),
185 input_bytes,
186 }
187}
188
/// Reports the machine's memory size in bytes, where it can be queried.
///
/// On macOS this reads the `hw.memsize` sysctl and returns `None` if the call
/// fails. On every other target it always returns `None`.
pub fn available_unified_memory() -> Option<usize> {
    #[cfg(target_os = "macos")]
    fn query() -> Option<usize> {
        use std::ffi::CStr;
        use std::os::raw::{c_char, c_int, c_void};

        unsafe extern "C" {
            fn sysctlbyname(
                name: *const c_char,
                oldp: *mut c_void,
                oldlenp: *mut usize,
                newp: *mut c_void,
                newlen: usize,
            ) -> c_int;
        }

        let key = CStr::from_bytes_with_nul(b"hw.memsize\0").ok()?;
        let mut mem_bytes: u64 = 0;
        let mut out_len = std::mem::size_of::<u64>();

        // SAFETY: `key` is NUL-terminated, `mem_bytes` is a live u64 whose size
        // is passed in `out_len`, and a null `newp` with length 0 requests a
        // read-only query.
        let status = unsafe {
            sysctlbyname(
                key.as_ptr(),
                (&mut mem_bytes as *mut u64).cast::<c_void>(),
                &mut out_len,
                std::ptr::null_mut(),
                0,
            )
        };

        if status != 0 {
            return None;
        }
        Some(mem_bytes as usize)
    }

    #[cfg(not(target_os = "macos"))]
    fn query() -> Option<usize> {
        None
    }

    query()
}
224
/// Share of total memory used for the soft budget when
/// `RLX_SOFT_MEMORY_FRACTION` is unset or not a value in `(0, 1]`.
pub const DEFAULT_SOFT_MEMORY_FRACTION: f64 = 0.8;
227
228pub fn soft_memory_fraction() -> f64 {
231 std::env::var("RLX_SOFT_MEMORY_FRACTION")
232 .ok()
233 .and_then(|s| s.parse::<f64>().ok())
234 .filter(|&f| f > 0.0 && f <= 1.0)
235 .unwrap_or(DEFAULT_SOFT_MEMORY_FRACTION)
236}
237
238pub fn soft_memory_budget_bytes() -> Option<usize> {
241 if let Ok(v) = std::env::var("RLX_SOFT_MEMORY_BUDGET_BYTES") {
242 if let Ok(n) = v.parse::<usize>() {
243 return Some(n);
244 }
245 }
246 available_unified_memory()
247 .map(|total| ((total as f64) * soft_memory_fraction()).floor() as usize)
248}
249
250pub fn process_rss_bytes() -> Option<usize> {
252 #[cfg(target_os = "macos")]
253 {
254 use std::mem::MaybeUninit;
255 unsafe extern "C" {
256 fn mach_task_self() -> u32;
257 fn task_info(
258 target_task: u32,
259 flavor: u32,
260 task_info_out: *mut std::os::raw::c_char,
261 task_info_out_cnt: *mut u32,
262 ) -> i32;
263 }
264 const TASK_BASIC_INFO: u32 = 5;
265 #[repr(C)]
266 struct TaskBasicInfo {
267 suspend_count: u32,
268 virtual_size: u64,
269 resident_size: u64,
270 user_time: u64,
271 system_time: u64,
272 policy: i32,
273 }
274 let mut info = MaybeUninit::<TaskBasicInfo>::uninit();
275 let mut count = (std::mem::size_of::<TaskBasicInfo>() / std::mem::size_of::<u32>()) as u32;
276 let rc = unsafe {
277 task_info(
278 mach_task_self(),
279 TASK_BASIC_INFO,
280 info.as_mut_ptr().cast(),
281 &mut count,
282 )
283 };
284 if rc == 0 {
285 Some(unsafe { info.assume_init() }.resident_size as usize)
286 } else {
287 None
288 }
289 }
290 #[cfg(target_os = "linux")]
291 {
292 let statm = std::fs::read_to_string("/proc/self/statm").ok()?;
293 let resident_pages = statm.split_whitespace().nth(1)?.parse::<usize>().ok()?;
294 Some(resident_pages.saturating_mul(page_size()))
295 }
296 #[cfg(not(any(target_os = "macos", target_os = "linux")))]
297 {
298 None
299 }
300}
301
/// System page size in bytes, as reported by `sysconf(_SC_PAGESIZE)`.
///
/// Falls back to 4096 if the call reports an error or a non-positive value.
#[cfg(target_os = "linux")]
fn page_size() -> usize {
    use std::os::raw::{c_int, c_long};

    unsafe extern "C" {
        // C prototype: `long sysconf(int name);`. `c_long` is 32 bits wide on
        // 32-bit Linux, so the return type must not be hard-coded as `i64`.
        fn sysconf(name: c_int) -> c_long;
    }
    // Value of `_SC_PAGESIZE` on Linux.
    const SC_PAGESIZE: c_int = 30;

    // SAFETY: `sysconf` takes a plain integer and touches no caller memory.
    let reported = unsafe { sysconf(SC_PAGESIZE) };
    if reported > 0 {
        reported as usize
    } else {
        4096
    }
}
311
312pub fn memory_headroom_bytes() -> Option<usize> {
314 let budget = soft_memory_budget_bytes()?;
315 let rss = process_rss_bytes().unwrap_or(0);
316 Some(budget.saturating_sub(rss))
317}
318
319pub fn would_exceed_soft_budget(additional_bytes: usize) -> bool {
322 let Some(budget) = soft_memory_budget_bytes() else {
323 return false;
324 };
325 let rss = process_rss_bytes().unwrap_or(0);
326 rss.saturating_add(additional_bytes) > budget
327}
328
/// Peak-memory figure, in bytes, associated with `RLX_DECODE_BUCKET_PEAK_BYTES`.
///
/// Returns the environment variable's value when it is set and parses as a
/// `usize`; otherwise returns the built-in default of 12 GiB.
pub fn llama_decode_bucket_compile_peak_bytes() -> usize {
    const DEFAULT_PEAK_BYTES: usize = 12 * 1024 * 1024 * 1024;

    match std::env::var("RLX_DECODE_BUCKET_PEAK_BYTES") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_PEAK_BYTES),
        Err(_) => DEFAULT_PEAK_BYTES,
    }
}
336
/// Peak-memory figure, in bytes, associated with `RLX_DECODE_ONESHOT_PEAK_BYTES`.
///
/// Returns the environment variable's value when it is set and parses as a
/// `usize`; otherwise returns the built-in default of 4 GiB.
pub fn llama_decode_oneshot_compile_peak_bytes() -> usize {
    const DEFAULT_PEAK_BYTES: usize = 4 * 1024 * 1024 * 1024;

    match std::env::var("RLX_DECODE_ONESHOT_PEAK_BYTES") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_PEAK_BYTES),
        Err(_) => DEFAULT_PEAK_BYTES,
    }
}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::weight_registry::{WeightKind, WeightRegistry};
    use rlx_ir::*;
    use std::sync::Arc;

    /// Minimal graph: one `[2, 16]` f32 input times one `[16, 32]` f32 param.
    fn small_graph() -> Graph {
        let f = DType::F32;
        let mut g = Graph::new("est");
        let x = g.input("x", Shape::new(&[2, 16], f));
        let w = g.param("w", Shape::new(&[16, 32], f));
        let mm = g.matmul(x, w, Shape::new(&[2, 32], f));
        g.set_outputs(vec![mm]);
        g
    }

    #[test]
    fn estimate_sums_components() {
        let g = small_graph();
        let mut reg = WeightRegistry::new();
        reg.register(
            "w",
            Shape::new(&[16, 32], DType::F32),
            Arc::from(vec![0u8; 16 * 32 * 4]),
            WeightKind::Base,
        );
        let est = estimate(&g, &reg);
        assert!(
            est.activation_bytes >= 256,
            "activation arena should hold mm output"
        );
        assert_eq!(est.weight_bytes, 16 * 32 * 4);
        assert_eq!(est.input_bytes, 2 * 16 * 4);
        assert_eq!(
            est.peak_bytes(),
            est.activation_bytes + est.weight_bytes + est.input_bytes
        );
    }

    #[test]
    fn fits_in_passes_with_room() {
        let g = small_graph();
        let mut reg = WeightRegistry::new();
        reg.register(
            "w",
            Shape::new(&[16, 32], DType::F32),
            Arc::from(vec![0u8; 2048]),
            WeightKind::Base,
        );
        let est = estimate(&g, &reg);
        assert!(
            est.fits_in(1 << 30).is_ok(),
            "1 GiB budget should fit a tiny graph"
        );
    }

    #[test]
    fn fits_in_reports_deficit() {
        let g = small_graph();
        let mut reg = WeightRegistry::new();
        reg.register(
            "w",
            Shape::new(&[16, 32], DType::F32),
            Arc::from(vec![0u8; 100_000_000]),
            WeightKind::Base,
        );
        let est = estimate(&g, &reg);
        let err = est.fits_in(1024).unwrap_err();
        assert!(err.peak_bytes > err.budget_bytes);
        assert!(format!("{err}").contains("exceeds"));
    }

    #[test]
    fn available_memory_returns_something_on_macos() {
        if cfg!(target_os = "macos") {
            let mem = available_unified_memory();
            assert!(mem.is_some());
            assert!(mem.unwrap() > 0);
        }
    }

    #[test]
    fn soft_budget_is_fraction_of_physical() {
        // An explicit byte override bypasses the fraction computation, so the
        // relationship checked below does not apply in that environment.
        if std::env::var_os("RLX_SOFT_MEMORY_BUDGET_BYTES").is_some() {
            return;
        }
        if let Some(total) = available_unified_memory() {
            let budget = soft_memory_budget_bytes().unwrap();
            // Use the effective fraction so an RLX_SOFT_MEMORY_FRACTION
            // override in the environment does not fail the test.
            let expect = ((total as f64) * soft_memory_fraction()).floor() as usize;
            assert_eq!(budget, expect);
        }
    }

    #[test]
    fn would_exceed_respects_headroom() {
        let Some(budget) = soft_memory_budget_bytes() else {
            return;
        };
        let rss = process_rss_bytes().unwrap_or(0);
        // RSS is re-sampled inside `would_exceed_soft_budget` and moves while
        // sibling tests allocate, so probe well away from the exact boundary.
        if rss < budget / 2 {
            assert!(!would_exceed_soft_budget(0));
        }
        if budget < usize::MAX {
            // rss + (budget + 1) is above the budget whatever RSS is.
            assert!(would_exceed_soft_budget(budget + 1));
        }
    }
}