Skip to main content

rlx_phi/
lib.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Phi 3 / Phi 4 runner.
//!
//! Phi-3 and Phi-4 ship as `general.architecture = phi3` in their GGUF
//! converters (Phi-4 reuses the Phi-3 arch tag upstream — there's no
//! separate `phi4` enum in llama.cpp). This crate is a thin wrapper
//! over [`rlx_llama32::Llama32Runner`] with arch validation.
//!
//! **Caveat:** Phi-3's per-layer LayerNorm placement and partial-RoPE
//! split aren't yet implemented in `rlx-llama32` — runs will produce
//! *some* tokens but won't match the upstream reference until those
//! land. This is tracked as a follow-up to milestone M4 in PLAN.md.
27
28use anyhow::{Context, Result, bail};
29use rlx_llama_base::LlamaBaseConfig;
30use std::path::{Path, PathBuf};
31
32pub use rlx_llama32::{Llama32ConfigSource, Llama32Runner, Llama32RunnerBuilder};
33
/// Roadmap milestone label for this crate (the crate docs refer to PLAN.md M4).
pub const PLAN_MILESTONE: &str = "M4";
/// Human-readable name of the model family this crate runs.
pub const FAMILY: &str = "Phi 3 / Phi 4";

/// `general.architecture` values accepted from a GGUF weights file.
///
/// Only `phi3` is listed because Phi-4 GGUF files reuse the Phi-3
/// architecture tag (see the crate-level docs).
const ACCEPTED_ARCHES: &[&str] = &["phi3"];
38
39pub struct PhiRunner {
40    inner: Llama32Runner,
41    config: LlamaBaseConfig,
42}
43
44impl PhiRunner {
45    pub fn builder() -> PhiRunnerBuilder {
46        PhiRunnerBuilder::default()
47    }
48    pub fn config(&self) -> &LlamaBaseConfig {
49        &self.config
50    }
51    pub fn inner(&self) -> &Llama32Runner {
52        &self.inner
53    }
54    pub fn inner_mut(&mut self) -> &mut Llama32Runner {
55        &mut self.inner
56    }
57    pub fn generate_packed(
58        &mut self,
59        prompt_ids: &[u32],
60        n_new: usize,
61        on_token: impl FnMut(u32),
62    ) -> Result<Vec<u32>> {
63        self.inner.generate_packed(prompt_ids, n_new, on_token)
64    }
65}
66
67#[derive(Debug, Clone, Default)]
68pub struct PhiRunnerBuilder {
69    weights: Option<PathBuf>,
70    inner: Llama32RunnerBuilder,
71}
72
73impl PhiRunnerBuilder {
74    pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
75        let p: PathBuf = path.into();
76        self.weights = Some(p.clone());
77        self.inner = self.inner.weights(p);
78        self
79    }
80    pub fn max_seq(mut self, n: usize) -> Self {
81        self.inner = self.inner.max_seq(n);
82        self
83    }
84    pub fn packed_weights(mut self, on: bool) -> Self {
85        self.inner = self.inner.packed_weights(on);
86        self
87    }
88    pub fn build(self) -> Result<PhiRunner> {
89        let weights = self
90            .weights
91            .as_ref()
92            .ok_or_else(|| anyhow::anyhow!("weights path required"))?
93            .clone();
94        let config = LlamaBaseConfig::from_gguf_path(&weights)
95            .with_context(|| format!("rlx-phi: parse {weights:?}"))?;
96        if !ACCEPTED_ARCHES.contains(&config.arch.as_str()) {
97            bail!(
98                "rlx-phi: expected `general.architecture` ∈ {ACCEPTED_ARCHES:?}; got `{}` at {weights:?}",
99                config.arch
100            );
101        }
102        let inner = self
103            .inner
104            .build()
105            .context("rlx-phi: building underlying Llama32Runner")?;
106        Ok(PhiRunner { inner, config })
107    }
108}
109
110pub fn cli_run(args: &[String]) -> Result<()> {
111    if let Some(first) = args.iter().position(|a| a == "--weights") {
112        if let Some(path) = args.get(first + 1) {
113            let cfg = LlamaBaseConfig::from_gguf_path(Path::new(path))
114                .with_context(|| format!("rlx-phi: parse {path}"))?;
115            if !ACCEPTED_ARCHES.contains(&cfg.arch.as_str()) {
116                bail!(
117                    "rlx-phi: {path}: GGUF arch = `{}`, expected one of {ACCEPTED_ARCHES:?}",
118                    cfg.arch
119                );
120            }
121        }
122    }
123    rlx_llama32::cli::run(args)
124}