Skip to main content

rlx_runtime/
memory_estimate.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Pre-load memory estimation (plan #35).
17//!
18//! Borrowed from MAX's `max/python/max/pipelines/` pattern: model
19//! peak memory is estimated *before* weights load. On Apple
20//! Silicon this matters disproportionately — unified memory is
21//! shared with the OS, so a model that "would fit on a 96 GB
22//! Mac" can still OOM if you spawn it during a heavy Spotlight
23//! re-index.
24//!
25//! Three components:
26//!
27//!   - **Activation working set** — peak arena bytes. Already
28//!     computed by `rlx_opt::memory::plan_memory(&graph)`; we
29//!     just expose it.
30//!   - **Weight bytes** — sum of registered weights from a
31//!     [`WeightRegistry`]. Aliases (tied embeddings) don't
32//!     double-count.
33//!   - **Per-batch input bytes** — bytes the user is going to
34//!     hand in via `compiled.run()`. Driven by graph inputs.
35//!
//! [`MemoryEstimate::peak_bytes`] is the sum.
//! [`MemoryEstimate::fits_in`] takes a budget and returns the gating
//! decision: `Ok(())`, or a structured [`MemoryDeficit`] explaining why not.
39
40use crate::expert_pool::gpu_expert_budget_from_vram;
41use crate::weight_registry::WeightRegistry;
42use rlx_ir::Graph;
43use rlx_opt::memory::plan_memory;
44
/// Pre-load estimate of the memory needed to run one forward pass.
///
/// Built by [`estimate`]; gate a load on it with [`MemoryEstimate::fits_in`].
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
    /// Peak working-set during one forward pass — output of the
    /// memory planner.
    pub activation_bytes: usize,
    /// Total weight bytes, deduplicated against tied aliases.
    pub weight_bytes: usize,
    /// Sum of input tensor sizes for one call (computed from
    /// graph inputs, ignoring outputs since they overlap with
    /// the activation arena).
    pub input_bytes: usize,
}

impl MemoryEstimate {
    /// Total estimated peak: activations + weights + inputs.
    ///
    /// Saturates at `usize::MAX` instead of overflowing. A wrapped sum would
    /// be tiny and make [`MemoryEstimate::fits_in`] accept a model that
    /// cannot possibly fit (and plain `+` panics in debug builds).
    pub fn peak_bytes(&self) -> usize {
        self.activation_bytes
            .saturating_add(self.weight_bytes)
            .saturating_add(self.input_bytes)
    }

    /// `Ok(())` if the estimate fits within `budget_bytes` (inclusive).
    ///
    /// # Errors
    ///
    /// Returns a [`MemoryDeficit`] carrying the full breakdown when the peak
    /// exceeds the budget, so callers can surface a useful error message.
    pub fn fits_in(&self, budget_bytes: usize) -> Result<(), MemoryDeficit> {
        let peak = self.peak_bytes();
        if peak <= budget_bytes {
            return Ok(());
        }
        Err(MemoryDeficit {
            budget_bytes,
            peak_bytes: peak,
            activation_bytes: self.activation_bytes,
            weight_bytes: self.weight_bytes,
            input_bytes: self.input_bytes,
        })
    }
}

/// Why an estimate did not fit: the budget, the peak, and its components.
#[derive(Debug, Clone)]
pub struct MemoryDeficit {
    /// Budget the estimate was checked against.
    pub budget_bytes: usize,
    /// Estimated peak (`activation + weights + inputs`).
    pub peak_bytes: usize,
    /// Activation-arena component of the peak.
    pub activation_bytes: usize,
    /// Weight component of the peak.
    pub weight_bytes: usize,
    /// Per-call input component of the peak.
    pub input_bytes: usize,
}

impl std::fmt::Display for MemoryDeficit {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The fields are public, so a hand-built deficit may have
        // peak <= budget. Saturate rather than underflow: a debug panic
        // inside `Display` would abort error reporting itself.
        let over = self.peak_bytes.saturating_sub(self.budget_bytes);
        let mib = |bytes: usize| bytes as f64 / 1024.0 / 1024.0;
        write!(
            f,
            "estimated peak {peak_mb:.1} MiB exceeds budget {budget_mb:.1} MiB by {over_mb:.1} MiB \
             (activation {act_mb:.1}, weights {w_mb:.1}, inputs {in_mb:.1})",
            peak_mb = mib(self.peak_bytes),
            budget_mb = mib(self.budget_bytes),
            over_mb = mib(over),
            act_mb = mib(self.activation_bytes),
            w_mb = mib(self.weight_bytes),
            in_mb = mib(self.input_bytes),
        )
    }
}

impl std::error::Error for MemoryDeficit {}
109
110/// Estimate peak memory for running `graph` on a session bound to
111/// `registry`. Pure analysis — runs the memory planner internally
112/// and queries the registry for weight bytes; doesn't compile or
113/// execute.
114/// MoE offload sizing (TIDE `enable_predictive_expert_offload`).
115#[derive(Debug, Clone)]
116pub struct MoeOffloadEstimate {
117    /// Bytes for one expert FFN (gate+up+down) at runtime dtype.
118    pub expert_param_bytes: usize,
119    pub num_moe_layers: usize,
120    pub num_experts: usize,
121    /// Experts pinned on device per layer after budget clamp.
122    pub gpu_expert_budget_per_layer: usize,
123    /// All experts resident on host+device (upper bound).
124    pub all_expert_weight_bytes: usize,
125    /// Only `gpu_expert_budget_per_layer` experts per layer on device.
126    pub resident_expert_weight_bytes: usize,
127}
128
129impl MoeOffloadEstimate {
130    /// Resident expert weights + non-expert peak from [`MemoryEstimate`].
131    pub fn peak_with_offload(&self, base: &MemoryEstimate) -> usize {
132        base.activation_bytes
133            + base.input_bytes
134            + (base.weight_bytes - self.all_expert_weight_bytes)
135            + self.resident_expert_weight_bytes
136    }
137}
138
139/// Compute GPU expert budget from a memory budget (unified RAM or VRAM).
140pub fn estimate_moe_offload(
141    expert_param_bytes: usize,
142    num_moe_layers: usize,
143    num_experts: usize,
144    max_gpu_experts_per_layer: usize,
145    memory_budget_bytes: usize,
146    reserve_fraction: f32,
147) -> MoeOffloadEstimate {
148    let reserve_bytes = (memory_budget_bytes as f64 * reserve_fraction as f64) as usize;
149    let gpu_budget = gpu_expert_budget_from_vram(
150        memory_budget_bytes,
151        reserve_bytes,
152        expert_param_bytes,
153        num_moe_layers,
154        max_gpu_experts_per_layer,
155        num_experts,
156    );
157    let all_expert = expert_param_bytes
158        .saturating_mul(num_experts)
159        .saturating_mul(num_moe_layers);
160    let resident_expert = expert_param_bytes
161        .saturating_mul(gpu_budget)
162        .saturating_mul(num_moe_layers);
163    MoeOffloadEstimate {
164        expert_param_bytes,
165        num_moe_layers,
166        num_experts,
167        gpu_expert_budget_per_layer: gpu_budget,
168        all_expert_weight_bytes: all_expert,
169        resident_expert_weight_bytes: resident_expert,
170    }
171}
172
173pub fn estimate(graph: &Graph, registry: &WeightRegistry) -> MemoryEstimate {
174    let plan = plan_memory(graph);
175    // Inputs: walk the graph once; sum each Op::Input's shape.
176    let mut input_bytes = 0usize;
177    for node in graph.nodes() {
178        if matches!(node.op, rlx_ir::Op::Input { .. }) {
179            input_bytes += node.shape.size_bytes().unwrap_or(0);
180        }
181    }
182    MemoryEstimate {
183        activation_bytes: plan.arena_size,
184        weight_bytes: registry.total_bytes(),
185        input_bytes,
186    }
187}
188
/// Physical (unified) memory installed on the running machine, in bytes.
///
/// On macOS this reads the `hw.memsize` sysctl, which reports total RAM,
/// not what is currently free. Every other platform returns `None`, so
/// callers can fall back to a user-supplied budget.
pub fn available_unified_memory() -> Option<usize> {
    #[cfg(target_os = "macos")]
    {
        use std::os::raw::{c_char, c_int, c_void};

        unsafe extern "C" {
            fn sysctlbyname(
                name: *const c_char,
                oldp: *mut c_void,
                oldlenp: *mut usize,
                newp: *mut c_void,
                newlen: usize,
            ) -> c_int;
        }

        let key = std::ffi::CString::new("hw.memsize").ok()?;
        let mut mem_bytes: u64 = 0;
        let mut mem_len = std::mem::size_of::<u64>();
        // SAFETY: `key` is NUL-terminated and outlives the call.
        // `mem_bytes`/`mem_len` describe a writable buffer of exactly
        // `mem_len` bytes. A null `newp` with `newlen == 0` means no value
        // is being set.
        let status = unsafe {
            sysctlbyname(
                key.as_ptr(),
                (&mut mem_bytes as *mut u64).cast::<c_void>(),
                &mut mem_len,
                std::ptr::null_mut(),
                0,
            )
        };
        (status == 0).then_some(mem_bytes as usize)
    }
    #[cfg(not(target_os = "macos"))]
    {
        None
    }
}
224
/// Default soft working-set cap, as a fraction of physical RAM.
///
/// Set below 1.0 to stay clear of OOM.
pub const DEFAULT_SOFT_MEMORY_FRACTION: f64 = 0.80;

/// Fraction of physical RAM treated as a soft working-set cap.
///
/// Reads `RLX_SOFT_MEMORY_FRACTION` (e.g. `0.8`). The value is ignored in
/// favour of [`DEFAULT_SOFT_MEMORY_FRACTION`] when the variable is unset,
/// not valid Unicode, unparsable as `f64`, or outside `(0, 1]`.
pub fn soft_memory_fraction() -> f64 {
    match std::env::var("RLX_SOFT_MEMORY_FRACTION").map(|raw| raw.parse::<f64>()) {
        Ok(Ok(fraction)) if fraction > 0.0 && fraction <= 1.0 => fraction,
        _ => DEFAULT_SOFT_MEMORY_FRACTION,
    }
}
237
238/// Soft RSS budget: `physical_ram * soft_memory_fraction()`.
239/// Returns `None` when physical RAM is unknown (non-macOS without override).
240pub fn soft_memory_budget_bytes() -> Option<usize> {
241    if let Ok(v) = std::env::var("RLX_SOFT_MEMORY_BUDGET_BYTES") {
242        if let Ok(n) = v.parse::<usize>() {
243            return Some(n);
244        }
245    }
246    available_unified_memory()
247        .map(|total| ((total as f64) * soft_memory_fraction()).floor() as usize)
248}
249
/// Current process resident set size in bytes, when available.
///
/// - macOS: `resident_size` from `task_info(MACH_TASK_BASIC_INFO)`.
/// - Linux: the resident-pages field of `/proc/self/statm` times the page
///   size.
/// - Any other platform, or any failure: `None`.
pub fn process_rss_bytes() -> Option<usize> {
    #[cfg(target_os = "macos")]
    {
        unsafe extern "C" {
            fn mach_task_self() -> u32;
            fn task_info(
                target_task: u32,
                flavor: u32,
                task_info_out: *mut i32,
                task_info_out_cnt: *mut u32,
            ) -> i32;
        }
        // `MACH_TASK_BASIC_INFO` from <mach/task_info.h>.
        //
        // Why not the legacy `TASK_BASIC_INFO` (5): it fills a
        // `#pragma pack(4)` struct whose u64 fields sit at offsets 4/12/…,
        // so a `#[repr(C)]` mirror reads `resident_size` from the wrong
        // bytes.
        //
        // `mach_task_basic_info` has only naturally aligned fields, so
        // `#[repr(C)]` reproduces its packed layout exactly (48 bytes,
        // 12 words).
        const MACH_TASK_BASIC_INFO: u32 = 20;
        #[repr(C)]
        #[derive(Default)]
        #[allow(dead_code)] // mirrors the C struct; only `resident_size` is read
        struct MachTaskBasicInfo {
            virtual_size: u64,
            resident_size: u64,
            resident_size_max: u64,
            user_time: [i32; 2],   // time_value_t { seconds, microseconds }
            system_time: [i32; 2], // time_value_t { seconds, microseconds }
            policy: i32,
            suspend_count: i32,
        }
        // Zero-initialised, so every byte is valid even if the kernel
        // writes fewer words than requested (no `assume_init` needed).
        let mut info = MachTaskBasicInfo::default();
        let mut count =
            (std::mem::size_of::<MachTaskBasicInfo>() / std::mem::size_of::<i32>()) as u32;
        // SAFETY: `info` is a live, writable buffer of exactly `count`
        // 32-bit words, and `count` tells the kernel that size; the kernel
        // writes at most that many words.
        let rc = unsafe {
            task_info(
                mach_task_self(),
                MACH_TASK_BASIC_INFO,
                (&mut info as *mut MachTaskBasicInfo).cast(),
                &mut count,
            )
        };
        if rc == 0 {
            Some(info.resident_size as usize)
        } else {
            None
        }
    }
    #[cfg(target_os = "linux")]
    {
        // statm fields: size resident shared text lib data dt (in pages).
        let statm = std::fs::read_to_string("/proc/self/statm").ok()?;
        let resident_pages = statm.split_whitespace().nth(1)?.parse::<usize>().ok()?;
        Some(resident_pages.saturating_mul(page_size()))
    }
    #[cfg(not(any(target_os = "macos", target_os = "linux")))]
    {
        None
    }
}

/// System page size in bytes via `sysconf(_SC_PAGESIZE)`, falling back to
/// 4096 if the call reports an error.
#[cfg(target_os = "linux")]
fn page_size() -> usize {
    unsafe extern "C" {
        // C signature is `long sysconf(int)`. `c_long` is 32-bit on 32-bit
        // Linux, so a hard-coded `i64` return would break the ABI there.
        fn sysconf(name: std::os::raw::c_int) -> std::os::raw::c_long;
    }
    // Value of `_SC_PAGESIZE` on Linux.
    const _SC_PAGESIZE: std::os::raw::c_int = 30;
    // SAFETY: `sysconf` takes a plain integer and touches no caller memory.
    let ps = unsafe { sysconf(_SC_PAGESIZE) };
    if ps > 0 { ps as usize } else { 4096 }
}
311
312/// Bytes remaining before the soft budget (`budget - current RSS`).
313pub fn memory_headroom_bytes() -> Option<usize> {
314    let budget = soft_memory_budget_bytes()?;
315    let rss = process_rss_bytes().unwrap_or(0);
316    Some(budget.saturating_sub(rss))
317}
318
319/// True when `current_rss + additional` would exceed the soft budget.
320/// Unknown budget → false (do not block).
321pub fn would_exceed_soft_budget(additional_bytes: usize) -> bool {
322    let Some(budget) = soft_memory_budget_bytes() else {
323        return false;
324    };
325    let rss = process_rss_bytes().unwrap_or(0);
326    rss.saturating_add(additional_bytes) > budget
327}
328
/// Conservative peak for one LLaMA decode graph compile with F32 params (3B class).
///
/// Defaults to 12 GiB; override with `RLX_DECODE_BUCKET_PEAK_BYTES`.
pub fn llama_decode_bucket_compile_peak_bytes() -> usize {
    env_usize_or("RLX_DECODE_BUCKET_PEAK_BYTES", 12 * 1024 * 1024 * 1024)
}

/// Lazy per-step decode compile (GGUF on demand, no resident param cache).
///
/// Defaults to 4 GiB; override with `RLX_DECODE_ONESHOT_PEAK_BYTES`.
pub fn llama_decode_oneshot_compile_peak_bytes() -> usize {
    env_usize_or("RLX_DECODE_ONESHOT_PEAK_BYTES", 4 * 1024 * 1024 * 1024)
}

/// Reads `key` from the environment as a `usize`.
///
/// Falls back to `fallback` when the variable is unset, not valid Unicode,
/// or not an unsigned integer.
fn env_usize_or(key: &str, fallback: usize) -> usize {
    match std::env::var(key).map(|raw| raw.parse::<usize>()) {
        Ok(Ok(value)) => value,
        _ => fallback,
    }
}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::weight_registry::{WeightKind, WeightRegistry};
    use rlx_ir::*;
    use std::sync::Arc;

    /// `x[2,16] · w[16,32] → mm[2,32]`, all F32. This gives a 128 B input,
    /// a registry-backed weight, and a 256 B matmul output.
    fn small_graph() -> Graph {
        let f = DType::F32;
        let mut g = Graph::new("est");
        let x = g.input("x", Shape::new(&[2, 16], f)); // 128 B
        let w = g.param("w", Shape::new(&[16, 32], f)); // weight via registry
        let mm = g.matmul(x, w, Shape::new(&[2, 32], f)); // 256 B activation
        g.set_outputs(vec![mm]);
        g
    }

    /// Registry holding a single base weight `w` (shape `[16, 32]` F32)
    /// backed by `len` zero bytes.
    fn registry_with_w(len: usize) -> WeightRegistry {
        let mut reg = WeightRegistry::new();
        reg.register(
            "w",
            Shape::new(&[16, 32], DType::F32),
            Arc::from(vec![0u8; len]),
            WeightKind::Base,
        );
        reg
    }

    #[test]
    fn estimate_sums_components() {
        let g = small_graph();
        let reg = registry_with_w(16 * 32 * 4);
        let est = estimate(&g, &reg);
        assert!(
            est.activation_bytes >= 256,
            "activation arena should hold mm output"
        );
        assert_eq!(est.weight_bytes, 16 * 32 * 4);
        assert_eq!(est.input_bytes, 2 * 16 * 4);
        // Peak is exactly the sum of the three components.
        assert_eq!(
            est.peak_bytes(),
            est.activation_bytes + est.weight_bytes + est.input_bytes
        );
    }

    #[test]
    fn fits_in_passes_with_room() {
        let g = small_graph();
        let reg = registry_with_w(2048);
        let est = estimate(&g, &reg);
        assert!(
            est.fits_in(1 << 30).is_ok(),
            "1 GiB budget should fit a tiny graph"
        );
    }

    #[test]
    fn fits_in_reports_deficit() {
        let g = small_graph();
        let reg = registry_with_w(100_000_000);
        let est = estimate(&g, &reg);
        let err = est.fits_in(1024).unwrap_err();
        assert!(err.peak_bytes > err.budget_bytes);
        assert!(format!("{err}").contains("exceeds"));
    }

    #[test]
    fn available_memory_returns_something_on_macos() {
        // On macOS this should return Some(>0). Elsewhere (CI) we accept None.
        if cfg!(target_os = "macos") {
            let mem = available_unified_memory();
            assert!(mem.is_some());
            assert!(mem.unwrap() > 0);
        }
    }

    #[test]
    fn soft_budget_is_fraction_of_physical() {
        // Either override changes the expected value; only check the default path.
        if std::env::var_os("RLX_SOFT_MEMORY_BUDGET_BYTES").is_some()
            || std::env::var_os("RLX_SOFT_MEMORY_FRACTION").is_some()
        {
            return;
        }
        if let Some(total) = available_unified_memory() {
            let budget = soft_memory_budget_bytes().unwrap();
            let expect = ((total as f64) * DEFAULT_SOFT_MEMORY_FRACTION).floor() as usize;
            assert_eq!(budget, expect);
        }
    }

    #[test]
    fn would_exceed_respects_headroom() {
        // RSS is re-read on every call and other tests allocate concurrently
        // (e.g. the 100 MB registry above). An exact `budget - rss` boundary
        // check is therefore racy, and it fails outright once RSS is already
        // above the budget. Assert only properties that hold despite RSS drift.
        let Some(budget) = soft_memory_budget_bytes() else {
            return;
        };
        if budget < usize::MAX {
            assert!(would_exceed_soft_budget(usize::MAX));
        }
        let headroom = memory_headroom_bytes().expect("soft budget is known");
        assert!(headroom <= budget);
    }
}