Skip to main content

rlx_nemotron/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! NVIDIA Nemotron 3 Nano runner.
17//!
18//! Nemotron ships as several GGUF arch tags:
19//! * `nemotron` — text-only, Llama-shaped attention stack; runs via the
20//!   [`rlx_llama32::Llama32Runner`] delegate below.
21//! * `nemotron_h` / `nemotron_h_moe` — hybrid Mamba2 + attention; the
22//!   [`NemotronHybridRunner`] in `runner.rs` drives it via per-layer
23//!   `Mamba2StepStage` interleaved with stateless attention blocks.
24//!
25//! The Omni 30B variant (vision + audio) lives in `rlx-nemotron-omni`
26//! and is wired independently.
27
28use anyhow::{Context, Result, bail};
29use rlx_llama_base::LlamaBaseConfig;
30use std::path::{Path, PathBuf};
31
32pub use rlx_llama32::{Llama32ConfigSource, Llama32Runner, Llama32RunnerBuilder};
33
34pub mod config;
35pub mod flow;
36pub mod runner;
37
38pub use config::{NemotronHybridConfig, NemotronLayerKind};
39pub use flow::{mamba2_decode_layer_plugin_with_sink, stateless_attention_layer_plugin};
40pub use runner::{NemotronHybridRunner, NemotronHybridRunnerBuilder};
41
/// Plan milestone tag attached to this crate.
// NOTE(review): presumably a roadmap identifier — confirm against the project plan.
pub const PLAN_MILESTONE: &str = "M5";

/// Human-readable label for the model family handled here.
pub const FAMILY: &str = "Nemotron (text)";

/// `general.architecture` values this crate recognizes in a GGUF file.
const ACCEPTED_ARCHES: &[&str] = &[
    "nemotron",
    "nemotron_h",
    "nemotron_h_moe",
];

/// Subset of the accepted arches that `NemotronRunnerBuilder::build` allows
/// through to the `Llama32Runner` delegate; every other accepted arch is
/// rejected there with a pointer to the hybrid runner.
const ATTN_ONLY_ARCHES: &[&str] = &[
    "nemotron",
];
47
48pub struct NemotronRunner {
49    inner: Llama32Runner,
50    config: LlamaBaseConfig,
51}
52
53impl NemotronRunner {
54    pub fn builder() -> NemotronRunnerBuilder {
55        NemotronRunnerBuilder::default()
56    }
57    pub fn config(&self) -> &LlamaBaseConfig {
58        &self.config
59    }
60    pub fn inner(&self) -> &Llama32Runner {
61        &self.inner
62    }
63    pub fn inner_mut(&mut self) -> &mut Llama32Runner {
64        &mut self.inner
65    }
66}
67
68#[derive(Debug, Default)]
69pub struct NemotronRunnerBuilder {
70    weights: Option<PathBuf>,
71    inner: Llama32RunnerBuilder,
72}
73
74impl NemotronRunnerBuilder {
75    pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
76        let p = path.into();
77        self.weights = Some(p.clone());
78        self.inner = self.inner.weights(p);
79        self
80    }
81    pub fn inner(mut self, f: impl FnOnce(Llama32RunnerBuilder) -> Llama32RunnerBuilder) -> Self {
82        self.inner = f(self.inner);
83        self
84    }
85
86    pub fn build(self) -> Result<NemotronRunner> {
87        let weights = self
88            .weights
89            .as_ref()
90            .ok_or_else(|| anyhow::anyhow!("weights path required (call .weights(...))"))?
91            .clone();
92        let config = LlamaBaseConfig::from_gguf_path(&weights)
93            .with_context(|| format!("rlx-nemotron: parse {weights:?}"))?;
94        if !ACCEPTED_ARCHES.contains(&config.arch.as_str()) {
95            bail!(
96                "rlx-nemotron: expected `general.architecture` ∈ {ACCEPTED_ARCHES:?}; \
97                 got `{}` at {weights:?}",
98                config.arch
99            );
100        }
101        if !ATTN_ONLY_ARCHES.contains(&config.arch.as_str()) {
102            bail!(
103                "rlx-nemotron: arch `{}` is hybrid Mamba2+attention — use \
104                 `NemotronHybridRunner::builder()` (this builder is attention-only \
105                 via the Llama32Runner delegate). The hybrid runner reads the same \
106                 `--weights` path, picks layer kinds from \
107                 `{0}.{{layer_kinds, attn_layer_period}}` metadata, and drives \
108                 per-layer Mamba2 state buffers across decode calls.",
109                config.arch
110            );
111        }
112        let inner = self
113            .inner
114            .build()
115            .context("rlx-nemotron: building underlying Llama32Runner")?;
116        Ok(NemotronRunner { inner, config })
117    }
118}
119
120pub fn cli_run(args: &[String]) -> Result<()> {
121    if let Some(first) = args.iter().position(|a| a == "--weights") {
122        if let Some(path) = args.get(first + 1) {
123            let cfg = LlamaBaseConfig::from_gguf_path(Path::new(path))
124                .with_context(|| format!("rlx-nemotron: parse {path}"))?;
125            if !ACCEPTED_ARCHES.contains(&cfg.arch.as_str()) {
126                bail!(
127                    "rlx-nemotron: {path}: GGUF arch = `{}`, expected one of {ACCEPTED_ARCHES:?}",
128                    cfg.arch
129                );
130            }
131        }
132    }
133    rlx_llama32::cli::run(args)
134}