1use anyhow::{Context, Result, bail};
31use rlx_llama_base::LlamaBaseConfig;
32use std::path::{Path, PathBuf};
33
34pub use rlx_llama32::{Llama32ConfigSource, Llama32Runner, Llama32RunnerBuilder};
35
/// Milestone identifier associated with this crate (currently `"M4"`).
// NOTE(review): presumably refers to an external roadmap/plan document — confirm.
pub const PLAN_MILESTONE: &str = "M4";
/// Name of the model family this crate targets.
pub const FAMILY: &str = "Command-R / Cohere";

/// GGUF `general.architecture` values this crate accepts.
///
/// Both `CohereRunnerBuilder::build` and `cli_run` reject a model whose parsed
/// architecture string is not in this list.
const ACCEPTED_ARCHES: &[&str] = &["command-r", "cohere2"];
40
/// Runner for Command-R / Cohere GGUF models.
///
/// Wraps a [`Llama32Runner`], which performs all generation, together with the
/// [`LlamaBaseConfig`] parsed from the same weights file when the runner was
/// built. Construct one via [`CohereRunner::builder`].
pub struct CohereRunner {
    /// Underlying runner that generation calls are delegated to.
    inner: Llama32Runner,
    /// Model configuration parsed from the GGUF file at build time.
    config: LlamaBaseConfig,
}
45
46impl CohereRunner {
47 pub fn builder() -> CohereRunnerBuilder {
48 CohereRunnerBuilder::default()
49 }
50 pub fn config(&self) -> &LlamaBaseConfig {
51 &self.config
52 }
53 pub fn inner(&self) -> &Llama32Runner {
54 &self.inner
55 }
56 pub fn inner_mut(&mut self) -> &mut Llama32Runner {
57 &mut self.inner
58 }
59 pub fn generate_packed(
60 &mut self,
61 prompt_ids: &[u32],
62 n_new: usize,
63 on_token: impl FnMut(u32),
64 ) -> Result<Vec<u32>> {
65 self.inner.generate_packed(prompt_ids, n_new, on_token)
66 }
67}
68
/// Builder for [`CohereRunner`].
///
/// Every option is forwarded to an inner [`Llama32RunnerBuilder`]. The weights
/// path is also kept here so that `build` can parse the GGUF metadata and
/// validate its architecture before building the inner runner.
#[derive(Debug, Clone, Default)]
pub struct CohereRunnerBuilder {
    /// GGUF weights path; `None` until `weights` is called.
    weights: Option<PathBuf>,
    /// Builder for the wrapped runner; receives every configured option.
    inner: Llama32RunnerBuilder,
}
74
75impl CohereRunnerBuilder {
76 pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
77 let p: PathBuf = path.into();
78 self.weights = Some(p.clone());
79 self.inner = self.inner.weights(p);
80 self
81 }
82 pub fn max_seq(mut self, n: usize) -> Self {
83 self.inner = self.inner.max_seq(n);
84 self
85 }
86 pub fn packed_weights(mut self, on: bool) -> Self {
87 self.inner = self.inner.packed_weights(on);
88 self
89 }
90 pub fn build(self) -> Result<CohereRunner> {
91 let weights = self
92 .weights
93 .as_ref()
94 .ok_or_else(|| anyhow::anyhow!("weights path required"))?
95 .clone();
96 let config = LlamaBaseConfig::from_gguf_path(&weights)
97 .with_context(|| format!("rlx-cohere: parse {weights:?}"))?;
98 if !ACCEPTED_ARCHES.contains(&config.arch.as_str()) {
99 bail!(
100 "rlx-cohere: expected `general.architecture` ∈ {ACCEPTED_ARCHES:?}; got `{}` at {weights:?}",
101 config.arch
102 );
103 }
104 let inner = self
105 .inner
106 .build()
107 .context("rlx-cohere: building underlying Llama32Runner")?;
108 Ok(CohereRunner { inner, config })
109 }
110}
111
112pub fn cli_run(args: &[String]) -> Result<()> {
113 if let Some(first) = args.iter().position(|a| a == "--weights") {
114 if let Some(path) = args.get(first + 1) {
115 let cfg = LlamaBaseConfig::from_gguf_path(Path::new(path))
116 .with_context(|| format!("rlx-cohere: parse {path}"))?;
117 if !ACCEPTED_ARCHES.contains(&cfg.arch.as_str()) {
118 bail!(
119 "rlx-cohere: {path}: GGUF arch = `{}`, expected one of {ACCEPTED_ARCHES:?}",
120 cfg.arch
121 );
122 }
123 }
124 }
125 rlx_llama32::cli::run(args)
126}