Skip to main content

rlx_llada2/tide/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16// RLX — TIDE-facing runner API (`LLaDA2MoeModelLM` parity).
17
18use crate::tide::{
19    BlockDenoiseConfig, BlockDenoiseLoop, BlockDenoiseStepStats, PredictiveOffloadInfo,
20    PredictiveOffloadParams, TideOffloadStats, enable_predictive_expert_offload,
21};
22use crate::{GenerateConfig, LLaDA2MoeConfig, LLaDA2Runner, LLaDA2RunnerBuilder, LLaDA2Weights};
23use anyhow::Result;
24
25/// TIDE reference model runner (LLaDA2 MoE + block diffusion + predictive offload).
26pub struct TideRunner {
27    inner: LLaDA2Runner,
28}
29
30impl TideRunner {
31    pub fn builder() -> LLaDA2RunnerBuilder {
32        LLaDA2Runner::builder()
33    }
34
35    pub fn from_llada2(inner: LLaDA2Runner) -> Self {
36        Self { inner }
37    }
38
39    pub fn into_llada2(self) -> LLaDA2Runner {
40        self.inner
41    }
42
43    pub fn llada2(&self) -> &LLaDA2Runner {
44        &self.inner
45    }
46
47    pub fn llada2_mut(&mut self) -> &mut LLaDA2Runner {
48        &mut self.inner
49    }
50
51    pub fn config(&self) -> &LLaDA2MoeConfig {
52        self.inner.config()
53    }
54
55    /// TIDE `enable_predictive_expert_offload` (configure via [`Self::builder`] before `build`).
56    pub fn predictive_offload_info(&self) -> Option<PredictiveOffloadInfo> {
57        self.inner.predictive_offload_info()
58    }
59
60    pub fn predictive_offload_enabled(&self) -> bool {
61        self.inner.predictive_offload_enabled()
62    }
63
64    pub fn jump_steps(&self) -> usize {
65        self.inner.jump_steps()
66    }
67
68    /// TIDE `get_offload_stats()` — sum across MoE layers + last-forward residency.
69    pub fn get_offload_stats(&mut self) -> TideOffloadStats {
70        self.inner.get_offload_stats()
71    }
72
73    /// TIDE `generate(input_ids, ...)`.
74    pub fn generate(
75        &mut self,
76        input_ids: &[u32],
77        gen_cfg: &GenerateConfig,
78    ) -> Result<(Vec<u32>, Vec<BlockDenoiseStepStats>)> {
79        self.inner.generate(gen_cfg, input_ids)
80    }
81
82    pub fn block_denoise_loop(
83        &mut self,
84        cfg: BlockDenoiseConfig,
85    ) -> BlockDenoiseLoop<crate::runner::LLaDA2RunnerForward<'_>> {
86        self.inner.block_denoise_loop(cfg)
87    }
88}
89
90impl LLaDA2RunnerBuilder {
91    /// TIDE `enable_predictive_expert_offload(max_gpu_experts_per_layer, ...)`.
92    pub fn tide_enable_predictive_expert_offload(
93        mut self,
94        max_gpu_experts_per_layer: usize,
95        reserve_vram_gb: f64,
96        collect_stats: bool,
97        jump_steps: usize,
98    ) -> Self {
99        self = self
100            .enable_predictive_expert_offload(max_gpu_experts_per_layer)
101            .reserve_vram_gb(reserve_vram_gb)
102            .jump_steps(jump_steps)
103            .moe_collect_stats(collect_stats);
104        self
105    }
106}
107
108/// Preview TIDE offload budget without building a runner (host/unified memory when no CUDA).
109pub fn preview_predictive_offload(
110    cfg: &LLaDA2MoeConfig,
111    weights: &LLaDA2Weights,
112    max_gpu_experts_per_layer: usize,
113    reserve_vram_gb: f64,
114    collect_stats: bool,
115    jump_steps: usize,
116) -> Option<PredictiveOffloadInfo> {
117    let layer_count = crate::moe_offload::count_moe_layers(weights).max(1);
118    let mut params = PredictiveOffloadParams::new(
119        max_gpu_experts_per_layer,
120        cfg.num_experts,
121        layer_count,
122        cfg.expert_param_bytes_f32(),
123    );
124    params.reserve_vram_gb = reserve_vram_gb;
125    params.collect_stats = collect_stats;
126    params.jump_steps = jump_steps;
127    enable_predictive_expert_offload(&params).map(|(_, info)| info)
128}