Skip to main content

rlx_cohere/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Cohere Command-R runner.
17//!
18//! Command-R / Cohere2 ship as `general.architecture = command-r` or
19//! `cohere2` in their GGUF converters — Llama-shaped with
20//! parallel-residual attention and no embedding-output norm. This crate
21//! is a thin wrapper over [`rlx_llama32::Llama32Runner`] with arch
22//! validation.
23//!
24//! **Caveat:** Command-R's parallel residual (Q/K/V and FFN added in
25//! one residual pass) and missing embedding-output LayerNorm aren't
26//! yet wired in `rlx-llama32` — runs will produce *some* tokens but
27//! won't match the upstream reference until those land. PLAN.md M4
28//! follow-up.
29
30use anyhow::{Context, Result, bail};
31use rlx_llama_base::LlamaBaseConfig;
32use std::path::{Path, PathBuf};
33
34pub use rlx_llama32::{Llama32ConfigSource, Llama32Runner, Llama32RunnerBuilder};
35
/// PLAN.md milestone this runner is tracked under.
pub const PLAN_MILESTONE: &str = "M4";

/// Human-readable name of the model family served by this crate.
pub const FAMILY: &str = "Command-R / Cohere";

/// `general.architecture` values this crate is willing to load.
const ACCEPTED_ARCHES: &[&str] = &[
    "command-r", // Command-R
    "cohere2",   // Cohere2
];
40
41pub struct CohereRunner {
42    inner: Llama32Runner,
43    config: LlamaBaseConfig,
44}
45
46impl CohereRunner {
47    pub fn builder() -> CohereRunnerBuilder {
48        CohereRunnerBuilder::default()
49    }
50    pub fn config(&self) -> &LlamaBaseConfig {
51        &self.config
52    }
53    pub fn inner(&self) -> &Llama32Runner {
54        &self.inner
55    }
56    pub fn inner_mut(&mut self) -> &mut Llama32Runner {
57        &mut self.inner
58    }
59    pub fn generate_packed(
60        &mut self,
61        prompt_ids: &[u32],
62        n_new: usize,
63        on_token: impl FnMut(u32),
64    ) -> Result<Vec<u32>> {
65        self.inner.generate_packed(prompt_ids, n_new, on_token)
66    }
67}
68
69#[derive(Debug, Clone, Default)]
70pub struct CohereRunnerBuilder {
71    weights: Option<PathBuf>,
72    inner: Llama32RunnerBuilder,
73}
74
75impl CohereRunnerBuilder {
76    pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
77        let p: PathBuf = path.into();
78        self.weights = Some(p.clone());
79        self.inner = self.inner.weights(p);
80        self
81    }
82    pub fn max_seq(mut self, n: usize) -> Self {
83        self.inner = self.inner.max_seq(n);
84        self
85    }
86    pub fn packed_weights(mut self, on: bool) -> Self {
87        self.inner = self.inner.packed_weights(on);
88        self
89    }
90    pub fn build(self) -> Result<CohereRunner> {
91        let weights = self
92            .weights
93            .as_ref()
94            .ok_or_else(|| anyhow::anyhow!("weights path required"))?
95            .clone();
96        let config = LlamaBaseConfig::from_gguf_path(&weights)
97            .with_context(|| format!("rlx-cohere: parse {weights:?}"))?;
98        if !ACCEPTED_ARCHES.contains(&config.arch.as_str()) {
99            bail!(
100                "rlx-cohere: expected `general.architecture` ∈ {ACCEPTED_ARCHES:?}; got `{}` at {weights:?}",
101                config.arch
102            );
103        }
104        let inner = self
105            .inner
106            .build()
107            .context("rlx-cohere: building underlying Llama32Runner")?;
108        Ok(CohereRunner { inner, config })
109    }
110}
111
112pub fn cli_run(args: &[String]) -> Result<()> {
113    if let Some(first) = args.iter().position(|a| a == "--weights") {
114        if let Some(path) = args.get(first + 1) {
115            let cfg = LlamaBaseConfig::from_gguf_path(Path::new(path))
116                .with_context(|| format!("rlx-cohere: parse {path}"))?;
117            if !ACCEPTED_ARCHES.contains(&cfg.arch.as_str()) {
118                bail!(
119                    "rlx-cohere: {path}: GGUF arch = `{}`, expected one of {ACCEPTED_ARCHES:?}",
120                    cfg.arch
121                );
122            }
123        }
124    }
125    rlx_llama32::cli::run(args)
126}