1use anyhow::{Context, Result, bail};
29use rlx_llama_base::LlamaBaseConfig;
30use std::path::{Path, PathBuf};
31
32pub use rlx_llama32::{Llama32ConfigSource, Llama32Runner, Llama32RunnerBuilder};
33
34pub mod config;
35pub mod flow;
36pub mod runner;
37
38pub use config::{NemotronHybridConfig, NemotronLayerKind};
39pub use flow::{mamba2_decode_layer_plugin_with_sink, stateless_attention_layer_plugin};
40pub use runner::{NemotronHybridRunner, NemotronHybridRunnerBuilder};
41
/// Milestone tag carried by this crate.
// NOTE(review): presumably a roadmap/plan milestone identifier — confirm with the plan doc.
pub const PLAN_MILESTONE: &str = "M5";

/// Display label for the model family this crate targets.
pub const FAMILY: &str = "Nemotron (text)";
44
/// `general.architecture` values treated as Nemotron GGUFs.
///
/// Both `NemotronRunnerBuilder::build` and `cli_run` bail on any arch that is
/// not in this list.
const ACCEPTED_ARCHES: &[&str] = &["nemotron", "nemotron_h", "nemotron_h_moe"];

/// Arches that `NemotronRunnerBuilder::build` will actually load.
///
/// Any arch that is in `ACCEPTED_ARCHES` but not here is refused by `build`
/// with a message pointing at `NemotronHybridRunner::builder()`.
const ATTN_ONLY_ARCHES: &[&str] = &["nemotron"];
47
48pub struct NemotronRunner {
49 inner: Llama32Runner,
50 config: LlamaBaseConfig,
51}
52
53impl NemotronRunner {
54 pub fn builder() -> NemotronRunnerBuilder {
55 NemotronRunnerBuilder::default()
56 }
57 pub fn config(&self) -> &LlamaBaseConfig {
58 &self.config
59 }
60 pub fn inner(&self) -> &Llama32Runner {
61 &self.inner
62 }
63 pub fn inner_mut(&mut self) -> &mut Llama32Runner {
64 &mut self.inner
65 }
66}
67
68#[derive(Debug, Default)]
69pub struct NemotronRunnerBuilder {
70 weights: Option<PathBuf>,
71 inner: Llama32RunnerBuilder,
72}
73
74impl NemotronRunnerBuilder {
75 pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
76 let p = path.into();
77 self.weights = Some(p.clone());
78 self.inner = self.inner.weights(p);
79 self
80 }
81 pub fn inner(mut self, f: impl FnOnce(Llama32RunnerBuilder) -> Llama32RunnerBuilder) -> Self {
82 self.inner = f(self.inner);
83 self
84 }
85
86 pub fn build(self) -> Result<NemotronRunner> {
87 let weights = self
88 .weights
89 .as_ref()
90 .ok_or_else(|| anyhow::anyhow!("weights path required (call .weights(...))"))?
91 .clone();
92 let config = LlamaBaseConfig::from_gguf_path(&weights)
93 .with_context(|| format!("rlx-nemotron: parse {weights:?}"))?;
94 if !ACCEPTED_ARCHES.contains(&config.arch.as_str()) {
95 bail!(
96 "rlx-nemotron: expected `general.architecture` ∈ {ACCEPTED_ARCHES:?}; \
97 got `{}` at {weights:?}",
98 config.arch
99 );
100 }
101 if !ATTN_ONLY_ARCHES.contains(&config.arch.as_str()) {
102 bail!(
103 "rlx-nemotron: arch `{}` is hybrid Mamba2+attention — use \
104 `NemotronHybridRunner::builder()` (this builder is attention-only \
105 via the Llama32Runner delegate). The hybrid runner reads the same \
106 `--weights` path, picks layer kinds from \
107 `{0}.{{layer_kinds, attn_layer_period}}` metadata, and drives \
108 per-layer Mamba2 state buffers across decode calls.",
109 config.arch
110 );
111 }
112 let inner = self
113 .inner
114 .build()
115 .context("rlx-nemotron: building underlying Llama32Runner")?;
116 Ok(NemotronRunner { inner, config })
117 }
118}
119
120pub fn cli_run(args: &[String]) -> Result<()> {
121 if let Some(first) = args.iter().position(|a| a == "--weights") {
122 if let Some(path) = args.get(first + 1) {
123 let cfg = LlamaBaseConfig::from_gguf_path(Path::new(path))
124 .with_context(|| format!("rlx-nemotron: parse {path}"))?;
125 if !ACCEPTED_ARCHES.contains(&cfg.arch.as_str()) {
126 bail!(
127 "rlx-nemotron: {path}: GGUF arch = `{}`, expected one of {ACCEPTED_ARCHES:?}",
128 cfg.arch
129 );
130 }
131 }
132 }
133 rlx_llama32::cli::run(args)
134}