multiscreen_rs/inference.rs
//! High-level inference API.
//!
//! Provides [`GenerationConfig`] and [`ChatModel`] for easy token-level
//! text generation. Users encode/decode text with their own tokenizer.
//!
//! # Non-streaming (all tokens at once)
//!
//! ```rust,no_run
//! use multiscreen_rs::prelude::*;
//!
//! fn main() -> multiscreen_rs::Result<()> {
//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
//!     let token_ids = model.generate(&[1, 2, 3], GenerationConfig::default())?;
//!     println!("generated tokens: {:?}", token_ids);
//!     Ok(())
//! }
//! ```
//!
//! # Streaming (token by token, like ChatGPT)
//!
//! ```rust,no_run
//! use multiscreen_rs::prelude::*;
//!
//! fn main() -> multiscreen_rs::Result<()> {
//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
//!     let full = model.generate_stream(
//!         &[1, 2, 3],
//!         GenerationConfig::default(),
//!         |token_id, _index| {
//!             // Decode with YOUR tokenizer and print word-by-word.
//!             print!("{} ", token_id);
//!             true // return false to stop early
//!         },
//!     )?;
//!     println!("\nfull sequence: {:?}", full);
//!     Ok(())
//! }
//! ```

use crate::error::{Error, Result};
use crate::model::{ModelInferenceConfig, MultiscreenModel, MultiscreenModelConfig};
use crate::runtime::{DefaultAutodiffBackend, DefaultBackend, InferenceDevice, default_device};
use burn::module::AutodiffModule;
use std::fs;
use std::path::{Path, PathBuf};

// ---------------------------------------------------------------------------
// GenerationConfig
// ---------------------------------------------------------------------------

/// Configuration for text generation.
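///
/// # Example
///
/// Override individual fields and keep the rest at their defaults via
/// struct-update syntax; the values shown here are purely illustrative:
///
/// ```rust
/// use multiscreen_rs::prelude::*;
///
/// let config = GenerationConfig {
///     max_new_tokens: 128,
///     ..GenerationConfig::default()
/// };
/// assert_eq!(config.pad_token_id, 0);
/// ```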
#[derive(Clone, Debug)]
pub struct GenerationConfig {
    /// Maximum number of new tokens to generate (default: 64).
    pub max_new_tokens: usize,
    /// Pad token ID (default: 0).
    pub pad_token_id: u32,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            max_new_tokens: 64,
            pad_token_id: 0,
        }
    }
}

// ---------------------------------------------------------------------------
// ChatModel
// ---------------------------------------------------------------------------

/// High-level model for token-level text generation.
///
/// Load a trained checkpoint and generate token IDs in a single call.
/// `ChatModel` automatically discovers the model config next to the
/// checkpoint file. Users bring their own tokenizer to encode/decode text.
///
/// # Example
///
/// ```rust,no_run
/// use multiscreen_rs::prelude::*;
///
/// fn main() -> multiscreen_rs::Result<()> {
///     let model = ChatModel::load("checkpoints/latest.mpk")?;
///     let token_ids = model.generate(&[1, 2, 3], GenerationConfig::default())?;
///     println!("generated tokens: {:?}", token_ids);
///     Ok(())
/// }
/// ```
pub struct ChatModel {
    model: MultiscreenModel<DefaultBackend>,
    device: InferenceDevice,
    config: MultiscreenModelConfig,
}

impl ChatModel {
    /// Load a `ChatModel` from a checkpoint path.
    ///
    /// `path` should point to a `.mpk` weights file (e.g.
    /// `"checkpoints/latest.mpk"` or `"runs/chat/checkpoints/latest.mpk"`).
    ///
    /// The model architecture is read from a `config.json` resolved relative
    /// to the checkpoint's parent directory; if none is found, the method
    /// falls back to the 10M-parameter preset defaults.
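    ///
    /// # Example
    ///
    /// A minimal sketch; the run directory below is a placeholder and is
    /// assumed to contain `latest.mpk` with `config.json` beside it:
    ///
    /// ```rust,no_run
    /// use multiscreen_rs::prelude::*;
    ///
    /// fn main() -> multiscreen_rs::Result<()> {
    ///     let model = ChatModel::load("runs/chat/checkpoints/latest.mpk")?;
    ///     // The parsed architecture is available for inspection.
    ///     let _config = model.config();
    ///     Ok(())
    /// }
    /// ```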
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let checkpoint_path = path.as_ref();

        // Resolve the directory that contains the checkpoint file.
        let checkpoint_dir = checkpoint_path
            .parent()
            .ok_or_else(|| Error::Io(format!("cannot determine parent of {:?}", checkpoint_path)))?
            .to_path_buf();

        // ------ config.json ------
        let config = match find_file(&[checkpoint_dir.join("config.json")]) {
            Ok(config_path) => {
                let json = fs::read_to_string(&config_path).map_err(|e| {
                    Error::Io(format!("failed to read {}: {e}", config_path.display()))
                })?;
                serde_json::from_str::<MultiscreenModelConfig>(&json).map_err(|e| {
                    Error::Serialization(format!("failed to parse {}: {e}", config_path.display()))
                })?
            }
            Err(_) => {
                // Fall back to the 10M-parameter preset defaults.
                // Users should provide config.json for the correct architecture.
                MultiscreenModelConfig::preset_10m(8192, 512)
            }
        };

        // ------ device + model ------
        // Load with the Autodiff backend first (needed for parameter loading),
        // then convert to the inference-only inner backend via `.valid()`.
        // This prevents VRAM leaks from autodiff computation graphs during
        // autoregressive generation.
        let device = default_device()?;
        let mut model = MultiscreenModel::<DefaultAutodiffBackend>::new(config.clone(), &device)?;
        model.load_parameters(checkpoint_path)?;
        // Strip the Autodiff wrapper → MultiscreenModel<DefaultBackend>.
        let model = model.valid();

        Ok(Self {
            model,
            device,
            config,
        })
    }

    /// Generate token IDs from a prompt token sequence.
    ///
    /// Returns all generated tokens (prompt + new) at once.
    /// For streaming / token-by-token output, use [`Self::generate_stream`].
    pub fn generate(&self, prompt: &[u32], config: GenerationConfig) -> Result<Vec<u32>> {
        let inference_config = ModelInferenceConfig {
            max_new_tokens: config.max_new_tokens,
            pad_token_id: config.pad_token_id,
        };
        let output = self
            .model
            .infer_tokens(prompt, &inference_config, &self.device)?;
        Ok(output.token_ids)
    }

    /// Generate token IDs one at a time, invoking a callback for each newly
    /// produced token.
    ///
    /// This enables streaming / word-by-word output similar to ChatGPT.
    /// The callback receives `(token_id, index)`, where `index` is the
    /// zero-based position of the *new* token (0 = first generated token).
    /// Return `false` from the callback to stop generation early.
    ///
    /// Returns the full output sequence (prompt + generated tokens).
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use multiscreen_rs::prelude::*;
    ///
    /// fn main() -> multiscreen_rs::Result<()> {
    ///     let model = ChatModel::load("checkpoints/latest.mpk")?;
    ///     let prompt: &[u32] = &[1, 2, 3];
    ///
    ///     let full_output = model.generate_stream(
    ///         prompt,
    ///         GenerationConfig::default(),
    ///         |token_id, _index| {
    ///             // Stream each token as it is produced.
    ///             // Decode with YOUR tokenizer and print word-by-word.
    ///             print!("{} ", token_id); // Replace with actual decoding
    ///             true // return false to stop early
    ///         },
    ///     )?;
    ///
    ///     println!("\nFull sequence: {:?}", full_output);
    ///     Ok(())
    /// }
    /// ```
    pub fn generate_stream(
        &self,
        prompt: &[u32],
        config: GenerationConfig,
        on_token: impl FnMut(u32, usize) -> bool,
    ) -> Result<Vec<u32>> {
        let inference_config = ModelInferenceConfig {
            max_new_tokens: config.max_new_tokens,
            pad_token_id: config.pad_token_id,
        };
        let output =
            self.model
                .infer_tokens_stream(prompt, &inference_config, &self.device, on_token)?;
        Ok(output.token_ids)
    }

    /// Run a forward pass on the padded context and return logits.
    ///
    /// Returns a tensor of shape `[1, seq_len, vocab_size]`.
    /// Use this for custom sampling strategies (top-k, temperature, etc.).
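    ///
    /// # Example
    ///
    /// A minimal greedy-decoding sketch. It assumes burn's `dims`, `slice`,
    /// `argmax`, and `into_scalar` tensor methods; swap in your own sampling
    /// (top-k, temperature) as needed:
    ///
    /// ```rust,no_run
    /// use multiscreen_rs::prelude::*;
    ///
    /// fn main() -> multiscreen_rs::Result<()> {
    ///     let model = ChatModel::load("checkpoints/latest.mpk")?;
    ///     let logits = model.predict_logits(&[1, 2, 3])?; // [1, seq_len, vocab_size]
    ///     let [_, seq_len, vocab_size] = logits.dims();
    ///     // Greedy pick: argmax over the vocab dimension at the last position.
    ///     let next_id = logits
    ///         .slice([0..1, seq_len - 1..seq_len, 0..vocab_size])
    ///         .argmax(2)
    ///         .into_scalar();
    ///     println!("next token id: {next_id}");
    ///     Ok(())
    /// }
    /// ```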
    pub fn predict_logits(
        &self,
        context: &[u32],
    ) -> Result<burn::tensor::Tensor<DefaultBackend, 3>> {
        let pad_token_id = 0;
        self.model
            .forward_logits(context, pad_token_id, &self.device)
    }

    /// Access the underlying neural model.
    pub fn model(&self) -> &MultiscreenModel<DefaultBackend> {
        &self.model
    }

    /// Access the model configuration.
    pub fn config(&self) -> &MultiscreenModelConfig {
        &self.config
    }

    /// Access the device the model is running on.
    pub fn device(&self) -> &InferenceDevice {
        &self.device
    }
}

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Return the first path in `candidates` that exists on disk, or an error.
fn find_file(candidates: &[PathBuf]) -> Result<PathBuf> {
    for candidate in candidates {
        if candidate.exists() {
            return Ok(candidate.clone());
        }
    }
    let descriptions = candidates
        .iter()
        .map(|p| format!(" {}", p.display()))
        .collect::<Vec<_>>()
        .join("\n");
    Err(Error::Io(format!(
        "file not found; searched:\n{descriptions}"
    )))
}