Skip to main content

rlx_runtime/
memory_estimate.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Pre-load memory estimation (plan #35).
17//!
18//! Borrowed from MAX's `max/python/max/pipelines/` pattern: model
19//! peak memory is estimated *before* weights load. On Apple
20//! Silicon this matters disproportionately — unified memory is
21//! shared with the OS, so a model that "would fit on a 96 GB
22//! Mac" can still OOM if you spawn it during a heavy Spotlight
23//! re-index.
24//!
25//! Three components:
26//!
27//!   - **Activation working set** — peak arena bytes. Already
28//!     computed by `rlx_opt::memory::plan_memory(&graph)`; we
29//!     just expose it.
30//!   - **Weight bytes** — sum of registered weights from a
31//!     [`WeightRegistry`]. Aliases (tied embeddings) don't
32//!     double-count.
33//!   - **Per-batch input bytes** — bytes the user is going to
34//!     hand in via `compiled.run()`. Driven by graph inputs.
35//!
//! [`MemoryEstimate::peak_bytes`] is the sum.
//! [`MemoryEstimate::fits_in`] takes a budget and returns the gating
//! decision plus a structured reason.
39
40use crate::expert_pool::gpu_expert_budget_from_vram;
41use crate::weight_registry::WeightRegistry;
42use rlx_ir::Graph;
43use rlx_opt::memory::plan_memory;
44
45#[derive(Debug, Clone)]
46pub struct MemoryEstimate {
47    /// Peak working-set during one forward pass — output of the
48    /// memory planner.
49    pub activation_bytes: usize,
50    /// Total weight bytes, deduplicated against tied aliases.
51    pub weight_bytes: usize,
52    /// Sum of input tensor sizes for one call (computed from
53    /// graph inputs, ignoring outputs since they overlap with
54    /// the activation arena).
55    pub input_bytes: usize,
56}
57
58impl MemoryEstimate {
59    pub fn peak_bytes(&self) -> usize {
60        self.activation_bytes + self.weight_bytes + self.input_bytes
61    }
62
63    /// True if the estimate fits within `budget_bytes`. The
64    /// `Result` carries the deficit on failure so callers can
65    /// surface a useful error message.
66    pub fn fits_in(&self, budget_bytes: usize) -> Result<(), MemoryDeficit> {
67        let peak = self.peak_bytes();
68        if peak <= budget_bytes {
69            Ok(())
70        } else {
71            Err(MemoryDeficit {
72                budget_bytes,
73                peak_bytes: peak,
74                activation_bytes: self.activation_bytes,
75                weight_bytes: self.weight_bytes,
76                input_bytes: self.input_bytes,
77            })
78        }
79    }
80}
81
/// Structured reason why an estimate did not fit a budget.
///
/// Carries the budget, the estimated peak and the per-component breakdown
/// (all in bytes) so callers can log or display a precise error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryDeficit {
    pub budget_bytes: usize,
    pub peak_bytes: usize,
    pub activation_bytes: usize,
    pub weight_bytes: usize,
    pub input_bytes: usize,
}

impl std::fmt::Display for MemoryDeficit {
    /// Renders the deficit as a one-line message with all sizes in MiB.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;
        let mib = |bytes: usize| bytes as f64 / BYTES_PER_MIB;

        // Fields are public, so a hand-built value may have peak < budget;
        // saturate rather than panic (debug) or wrap (release) while formatting.
        let over = self.peak_bytes.saturating_sub(self.budget_bytes);
        write!(
            f,
            "estimated peak {peak_mb:.1} MiB exceeds budget {budget_mb:.1} MiB by {over_mb:.1} MiB \
             (activation {act_mb:.1}, weights {w_mb:.1}, inputs {in_mb:.1})",
            peak_mb = mib(self.peak_bytes),
            budget_mb = mib(self.budget_bytes),
            over_mb = mib(over),
            act_mb = mib(self.activation_bytes),
            w_mb = mib(self.weight_bytes),
            in_mb = mib(self.input_bytes),
        )
    }
}

impl std::error::Error for MemoryDeficit {}
109
110/// Estimate peak memory for running `graph` on a session bound to
111/// `registry`. Pure analysis — runs the memory planner internally
112/// and queries the registry for weight bytes; doesn't compile or
113/// execute.
114/// MoE offload sizing (TIDE `enable_predictive_expert_offload`).
115#[derive(Debug, Clone)]
116pub struct MoeOffloadEstimate {
117    /// Bytes for one expert FFN (gate+up+down) at runtime dtype.
118    pub expert_param_bytes: usize,
119    pub num_moe_layers: usize,
120    pub num_experts: usize,
121    /// Experts pinned on device per layer after budget clamp.
122    pub gpu_expert_budget_per_layer: usize,
123    /// All experts resident on host+device (upper bound).
124    pub all_expert_weight_bytes: usize,
125    /// Only `gpu_expert_budget_per_layer` experts per layer on device.
126    pub resident_expert_weight_bytes: usize,
127}
128
129impl MoeOffloadEstimate {
130    /// Resident expert weights + non-expert peak from [`MemoryEstimate`].
131    pub fn peak_with_offload(&self, base: &MemoryEstimate) -> usize {
132        base.activation_bytes
133            + base.input_bytes
134            + (base.weight_bytes - self.all_expert_weight_bytes)
135            + self.resident_expert_weight_bytes
136    }
137}
138
139/// Compute GPU expert budget from a memory budget (unified RAM or VRAM).
140pub fn estimate_moe_offload(
141    expert_param_bytes: usize,
142    num_moe_layers: usize,
143    num_experts: usize,
144    max_gpu_experts_per_layer: usize,
145    memory_budget_bytes: usize,
146    reserve_fraction: f32,
147) -> MoeOffloadEstimate {
148    let reserve_bytes = (memory_budget_bytes as f64 * reserve_fraction as f64) as usize;
149    let gpu_budget = gpu_expert_budget_from_vram(
150        memory_budget_bytes,
151        reserve_bytes,
152        expert_param_bytes,
153        num_moe_layers,
154        max_gpu_experts_per_layer,
155        num_experts,
156    );
157    let all_expert = expert_param_bytes
158        .saturating_mul(num_experts)
159        .saturating_mul(num_moe_layers);
160    let resident_expert = expert_param_bytes
161        .saturating_mul(gpu_budget)
162        .saturating_mul(num_moe_layers);
163    MoeOffloadEstimate {
164        expert_param_bytes,
165        num_moe_layers,
166        num_experts,
167        gpu_expert_budget_per_layer: gpu_budget,
168        all_expert_weight_bytes: all_expert,
169        resident_expert_weight_bytes: resident_expert,
170    }
171}
172
173pub fn estimate(graph: &Graph, registry: &WeightRegistry) -> MemoryEstimate {
174    let plan = plan_memory(graph);
175    // Inputs: walk the graph once; sum each Op::Input's shape.
176    let mut input_bytes = 0usize;
177    for node in graph.nodes() {
178        if matches!(node.op, rlx_ir::Op::Input { .. }) {
179            input_bytes += node.shape.size_bytes().unwrap_or(0);
180        }
181    }
182    MemoryEstimate {
183        activation_bytes: plan.arena_size,
184        weight_bytes: registry.total_bytes(),
185        input_bytes,
186    }
187}
188
/// Physical memory size of the running machine, for use as a
/// unified-memory budget.
///
/// On macOS this reads the `hw.memsize` sysctl. That value is the installed
/// RAM size, not the amount currently free, so treat it as an upper bound and
/// apply your own headroom. On every other platform this returns `None` so
/// callers can fall back to a user-supplied budget.
///
/// Also returns `None` if the sysctl call fails, reports an unexpected value
/// size, or the value does not fit in `usize`.
pub fn available_unified_memory() -> Option<usize> {
    #[cfg(target_os = "macos")]
    {
        use std::ffi::CStr;
        use std::os::raw::{c_char, c_int, c_void};

        unsafe extern "C" {
            fn sysctlbyname(
                name: *const c_char,
                oldp: *mut c_void,
                oldlenp: *mut usize,
                newp: *mut c_void,
                newlen: usize,
            ) -> c_int;
        }

        // Static NUL-terminated name: no allocation needed.
        let name = CStr::from_bytes_with_nul(b"hw.memsize\0").ok()?;
        let mut value: u64 = 0;
        let mut len = std::mem::size_of::<u64>();

        // SAFETY: `name` is a valid NUL-terminated C string; `value` is a
        // live, writable `u64` and `len` holds exactly its size, so the
        // kernel writes at most 8 bytes into it; a null `newp` with
        // `newlen == 0` makes this a read-only query.
        let rc = unsafe {
            sysctlbyname(
                name.as_ptr(),
                (&mut value as *mut u64).cast::<c_void>(),
                &mut len,
                std::ptr::null_mut(),
                0,
            )
        };

        // Reject failures and short reads: a partially written `value`
        // would be a plausible-looking but wrong budget.
        if rc != 0 || len != std::mem::size_of::<u64>() {
            return None;
        }
        usize::try_from(value).ok()
    }
    #[cfg(not(target_os = "macos"))]
    {
        None
    }
}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::weight_registry::{WeightKind, WeightRegistry};
    use rlx_ir::*;
    use std::sync::Arc;

    /// Tiny `x @ w` graph: one input, one parameter, one matmul output.
    fn small_graph() -> Graph {
        let f = DType::F32;
        let mut g = Graph::new("est");
        let x = g.input("x", Shape::new(&[2, 16], f)); // 128 B
        let w = g.param("w", Shape::new(&[16, 32], f)); // weight via registry
        let mm = g.matmul(x, w, Shape::new(&[2, 32], f)); // 256 B activation
        g.set_outputs(vec![mm]);
        g
    }

    /// Registry holding a single base weight `w` backed by `byte_len` zero bytes.
    fn registry_with_w(byte_len: usize) -> WeightRegistry {
        let mut reg = WeightRegistry::new();
        reg.register(
            "w",
            Shape::new(&[16, 32], DType::F32),
            Arc::from(vec![0u8; byte_len]),
            WeightKind::Base,
        );
        reg
    }

    #[test]
    fn estimate_sums_components() {
        let g = small_graph();
        let reg = registry_with_w(16 * 32 * 4);
        let est = estimate(&g, &reg);
        assert!(
            est.activation_bytes >= 256,
            "activation arena should hold mm output"
        );
        assert_eq!(est.weight_bytes, 16 * 32 * 4);
        assert_eq!(est.input_bytes, 2 * 16 * 4);
        // The peak must be exactly the sum of the three components.
        assert_eq!(
            est.peak_bytes(),
            est.activation_bytes + est.weight_bytes + est.input_bytes
        );
    }

    #[test]
    fn fits_in_passes_with_room() {
        let g = small_graph();
        let reg = registry_with_w(2048);
        let est = estimate(&g, &reg);
        assert!(
            est.fits_in(1 << 30).is_ok(),
            "1 GiB budget should fit a tiny graph"
        );
    }

    #[test]
    fn fits_in_reports_deficit() {
        let g = small_graph();
        // Deliberately oversized backing buffer so the 1 KiB budget is blown.
        let reg = registry_with_w(100_000_000);
        let est = estimate(&g, &reg);
        let err = est.fits_in(1024).unwrap_err();
        assert!(err.peak_bytes > err.budget_bytes);
        assert_eq!(err.budget_bytes, 1024);
        assert_eq!(err.peak_bytes, est.peak_bytes());
        assert_eq!(err.weight_bytes, est.weight_bytes);
        assert!(err.to_string().contains("exceeds"));
    }

    #[test]
    fn available_memory_returns_something_on_macos() {
        // basic test — on macOS this should return Some(>0).
        // Elsewhere (CI) we accept None.
        if cfg!(target_os = "macos") {
            let mem = available_unified_memory();
            assert!(mem.is_some());
            assert!(mem.unwrap() > 0);
        }
    }
}