1use anyhow::{Context, Result, bail};
29use rlx_llama_base::LlamaBaseConfig;
30use std::path::{Path, PathBuf};
31
32pub use rlx_llama32::{Llama32ConfigSource, Llama32Runner, Llama32RunnerBuilder};
33
/// Plan milestone tag associated with this crate.
// NOTE(review): what "M4" refers to is defined outside this file — confirm against the project plan.
pub const PLAN_MILESTONE: &str = "M4";
/// Human-readable name of the model family this crate targets.
pub const FAMILY: &str = "Phi 3 / Phi 4";

/// Values of the GGUF `general.architecture` key that this crate accepts.
/// Both the builder and the CLI entry point reject any architecture not listed here.
const ACCEPTED_ARCHES: &[&str] = &["phi3"];
38
/// Runner for Phi-family weights, implemented as a thin wrapper around a
/// [`Llama32Runner`].
///
/// Construct one with [`PhiRunner::builder`].
pub struct PhiRunner {
    /// Underlying runner that every generation call is forwarded to.
    inner: Llama32Runner,
    /// Configuration parsed from the GGUF weights file at build time.
    config: LlamaBaseConfig,
}
43
44impl PhiRunner {
45 pub fn builder() -> PhiRunnerBuilder {
46 PhiRunnerBuilder::default()
47 }
48 pub fn config(&self) -> &LlamaBaseConfig {
49 &self.config
50 }
51 pub fn inner(&self) -> &Llama32Runner {
52 &self.inner
53 }
54 pub fn inner_mut(&mut self) -> &mut Llama32Runner {
55 &mut self.inner
56 }
57 pub fn generate_packed(
58 &mut self,
59 prompt_ids: &[u32],
60 n_new: usize,
61 on_token: impl FnMut(u32),
62 ) -> Result<Vec<u32>> {
63 self.inner.generate_packed(prompt_ids, n_new, on_token)
64 }
65}
66
/// Builder for [`PhiRunner`].
///
/// Wraps a [`Llama32RunnerBuilder`] and additionally remembers the weights
/// path so the GGUF architecture can be checked before the runner is built.
#[derive(Debug, Clone, Default)]
pub struct PhiRunnerBuilder {
    /// Weights path recorded by `weights`; `None` until that method is called.
    weights: Option<PathBuf>,
    /// Underlying builder that receives every option set on this builder.
    inner: Llama32RunnerBuilder,
}
72
73impl PhiRunnerBuilder {
74 pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
75 let p: PathBuf = path.into();
76 self.weights = Some(p.clone());
77 self.inner = self.inner.weights(p);
78 self
79 }
80 pub fn max_seq(mut self, n: usize) -> Self {
81 self.inner = self.inner.max_seq(n);
82 self
83 }
84 pub fn packed_weights(mut self, on: bool) -> Self {
85 self.inner = self.inner.packed_weights(on);
86 self
87 }
88 pub fn build(self) -> Result<PhiRunner> {
89 let weights = self
90 .weights
91 .as_ref()
92 .ok_or_else(|| anyhow::anyhow!("weights path required"))?
93 .clone();
94 let config = LlamaBaseConfig::from_gguf_path(&weights)
95 .with_context(|| format!("rlx-phi: parse {weights:?}"))?;
96 if !ACCEPTED_ARCHES.contains(&config.arch.as_str()) {
97 bail!(
98 "rlx-phi: expected `general.architecture` ∈ {ACCEPTED_ARCHES:?}; got `{}` at {weights:?}",
99 config.arch
100 );
101 }
102 let inner = self
103 .inner
104 .build()
105 .context("rlx-phi: building underlying Llama32Runner")?;
106 Ok(PhiRunner { inner, config })
107 }
108}
109
110pub fn cli_run(args: &[String]) -> Result<()> {
111 if let Some(first) = args.iter().position(|a| a == "--weights") {
112 if let Some(path) = args.get(first + 1) {
113 let cfg = LlamaBaseConfig::from_gguf_path(Path::new(path))
114 .with_context(|| format!("rlx-phi: parse {path}"))?;
115 if !ACCEPTED_ARCHES.contains(&cfg.arch.as_str()) {
116 bail!(
117 "rlx-phi: {path}: GGUF arch = `{}`, expected one of {ACCEPTED_ARCHES:?}",
118 cfg.arch
119 );
120 }
121 }
122 }
123 rlx_llama32::cli::run(args)
124}