multiscreen_rs/inference.rs
//! High-level inference API.
//!
//! Provides [`GenerationConfig`] and [`ChatModel`] for token-level text
//! generation. The API works on token IDs only: encode prompts and decode
//! outputs with your own tokenizer.
//!
//! # Non-streaming (all tokens at once)
//!
//! ```rust,no_run
//! use multiscreen_rs::prelude::*;
//!
//! fn main() -> multiscreen_rs::Result<()> {
//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
//!     let token_ids = model.generate(&[1, 2, 3], GenerationConfig::default())?;
//!     println!("generated tokens: {:?}", token_ids);
//!     Ok(())
//! }
//! ```
//!
//! # Streaming (token by token, as each one is produced)
//!
//! ```rust,no_run
//! use multiscreen_rs::prelude::*;
//!
//! fn main() -> multiscreen_rs::Result<()> {
//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
//!     let full = model.generate_stream(
//!         &[1, 2, 3],
//!         GenerationConfig::default(),
//!         |token_id, _index| {
//!             // Decode with YOUR tokenizer and print word-by-word.
//!             print!("{} ", token_id);
//!             true // return false to stop early
//!         },
//!     )?;
//!     println!("\nfull sequence: {:?}", full);
//!     Ok(())
//! }
//! ```
38
39use crate::error::{Error, Result};
40use crate::model::{DefaultMultiscreenModel, ModelInferenceConfig, MultiscreenModelConfig};
41use crate::runtime::{default_device, Device};
42use std::fs;
43use std::path::{Path, PathBuf};
44
45// ---------------------------------------------------------------------------
46// GenerationConfig
47// ---------------------------------------------------------------------------
48
49/// Configuration for text generation.
50#[derive(Clone, Debug)]
51pub struct GenerationConfig {
52 /// Maximum number of new tokens to generate (default: 64).
53 pub max_new_tokens: usize,
54 /// Pad token ID (default: 0).
55 pub pad_token_id: u32,
56}
57
58impl Default for GenerationConfig {
59 fn default() -> Self {
60 Self {
61 max_new_tokens: 64,
62 pad_token_id: 0,
63 }
64 }
65}
66
67// ---------------------------------------------------------------------------
68// ChatModel
69// ---------------------------------------------------------------------------
70
71/// High-level model for token-level text generation.
72///
73/// Load a trained checkpoint and generate token IDs in a single call.
74/// `ChatModel` automatically discovers the model config next to the
75/// checkpoint file. Users bring their own tokenizer to encode/decode text.
76///
77/// # Example
78///
79/// ```rust,no_run
80/// use multiscreen_rs::prelude::*;
81///
82/// fn main() -> multiscreen_rs::Result<()> {
83/// let model = ChatModel::load("checkpoints/latest.mpk")?;
84/// let token_ids = model.generate(&[1, 2, 3], GenerationConfig::default())?;
85/// println!("generated tokens: {:?}", token_ids);
86/// Ok(())
87/// }
88/// ```
89pub struct ChatModel {
90 model: DefaultMultiscreenModel,
91 device: Device,
92 config: MultiscreenModelConfig,
93}
94
95impl ChatModel {
96 /// Load a `ChatModel` from a checkpoint path.
97 ///
98 /// `path` should point to a `.mpk` weights file (e.g.
99 /// `"checkpoints/latest.mpk"` or `"runs/chat/checkpoints/latest.mpk"`).
100 ///
101 /// The method resolves `config.json` relative to the checkpoint's parent
102 /// directory for model architecture. Falls back to Params10M defaults.
103 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
104 let checkpoint_path = path.as_ref();
105
106 // Resolve the directory that contains the checkpoint file.
107 let checkpoint_dir = checkpoint_path
108 .parent()
109 .ok_or_else(|| Error::Io(format!("cannot determine parent of {:?}", checkpoint_path)))?
110 .to_path_buf();
111
112 // ------ config.json ------
113 let config = match find_file(&[checkpoint_dir.join("config.json")]) {
114 Ok(config_path) => {
115 let json = fs::read_to_string(&config_path).map_err(|e| {
116 Error::Io(format!("failed to read {}: {e}", config_path.display()))
117 })?;
118 serde_json::from_str::<MultiscreenModelConfig>(&json).map_err(|e| {
119 Error::Serialization(format!("failed to parse {}: {e}", config_path.display()))
120 })?
121 }
122 Err(_) => {
123 // Fall back to 10M-parameter preset defaults.
124 // User should provide config.json for correct architecture.
125 MultiscreenModelConfig::preset_10m(8192, 512)
126 }
127 };
128
129 // ------ device + model ------
130 let device = default_device()?;
131 let mut model = DefaultMultiscreenModel::new(config.clone(), &device)?;
132 model.load_parameters(checkpoint_path)?;
133
134 Ok(Self {
135 model,
136 device,
137 config,
138 })
139 }
140
141 /// Generate token IDs from a prompt token sequence.
142 ///
143 /// Returns all generated tokens (prompt + new) at once.
144 /// For streaming / token-by-token output, use [`Self::generate_stream`].
145 pub fn generate(&self, prompt: &[u32], config: GenerationConfig) -> Result<Vec<u32>> {
146 let inference_config = ModelInferenceConfig {
147 max_new_tokens: config.max_new_tokens,
148 pad_token_id: config.pad_token_id,
149 };
150 let output = self
151 .model
152 .infer_tokens(prompt, &inference_config, &self.device)?;
153 Ok(output.token_ids)
154 }
155
156 /// Generate token IDs one at a time, invoking a callback for each newly
157 /// produced token.
158 ///
159 /// This enables streaming / word-by-word output similar to ChatGPT.
160 /// The callback receives `(token_id, index)` where `index` is the
161 /// zero-based position of the *new* token (0 = first generated token).
162 /// Return `false` from the callback to stop generation early.
163 ///
164 /// Returns the full output sequence (prompt + generated tokens).
165 ///
166 /// # Example
167 ///
168 /// ```rust,no_run
169 /// use multiscreen_rs::prelude::*;
170 ///
171 /// fn main() -> multiscreen_rs::Result<()> {
172 /// let model = ChatModel::load("checkpoints/latest.mpk")?;
173 /// let prompt: &[u32] = &[1, 2, 3];
174 ///
175 /// let full_output = model.generate_stream(
176 /// prompt,
177 /// GenerationConfig::default(),
178 /// |token_id, _index| {
179 /// // Stream each token as it is produced.
180 /// // Decode with YOUR tokenizer and print word-by-word.
181 /// print!("{} ", token_id); // Replace with actual decoding
182 /// true // return false to stop early
183 /// },
184 /// )?;
185 ///
186 /// println!("\nFull sequence: {:?}", full_output);
187 /// Ok(())
188 /// }
189 /// ```
190 pub fn generate_stream(
191 &self,
192 prompt: &[u32],
193 config: GenerationConfig,
194 on_token: impl FnMut(u32, usize) -> bool,
195 ) -> Result<Vec<u32>> {
196 let inference_config = ModelInferenceConfig {
197 max_new_tokens: config.max_new_tokens,
198 pad_token_id: config.pad_token_id,
199 };
200 let output =
201 self.model
202 .infer_tokens_stream(prompt, &inference_config, &self.device, on_token)?;
203 Ok(output.token_ids)
204 }
205
206 /// Access the underlying neural model.
207 pub fn model(&self) -> &DefaultMultiscreenModel {
208 &self.model
209 }
210
211 /// Access the model configuration.
212 pub fn config(&self) -> &MultiscreenModelConfig {
213 &self.config
214 }
215
216 /// Access the device the model is running on.
217 pub fn device(&self) -> &Device {
218 &self.device
219 }
220}
221
222// ---------------------------------------------------------------------------
223// Helpers
224// ---------------------------------------------------------------------------
225
226/// Return the first path in `candidates` that exists on disk, or an error.
227fn find_file(candidates: &[PathBuf]) -> Result<PathBuf> {
228 for candidate in candidates {
229 if candidate.exists() {
230 return Ok(candidate.clone());
231 }
232 }
233 let descriptions = candidates
234 .iter()
235 .map(|p| format!(" {}", p.display()))
236 .collect::<Vec<_>>()
237 .join("\n");
238 Err(Error::Io(format!(
239 "file not found; searched:\n{descriptions}"
240 )))
241}