//! High-level inference API.
//!
//! Provides [`GenerationConfig`] and [`ChatModel`] for easy token-level
//! text generation. Users encode/decode text with their own tokenizer.
//!
//! # Non-streaming (all tokens at once)
//!
//! ```rust,no_run
//! use multiscreen_rs::prelude::*;
//!
//! fn main() -> multiscreen_rs::Result<()> {
//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
//!     let token_ids = model.generate(&[1, 2, 3], GenerationConfig::default())?;
//!     println!("generated tokens: {:?}", token_ids);
//!     Ok(())
//! }
//! ```
//!
//! # Streaming (token by token, like ChatGPT)
//!
//! ```rust,no_run
//! use multiscreen_rs::prelude::*;
//!
//! fn main() -> multiscreen_rs::Result<()> {
//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
//!     let _full = model.generate_stream(
//!         &[1, 2, 3],
//!         GenerationConfig::default(),
//!         |token_id, _index| {
//!             // Decode with YOUR tokenizer and print word-by-word
//!             print!("{} ", token_id);
//!             true // return false to stop early
//!         },
//!     )?;
//!     Ok(())
//! }
//! ```

use crate::error::{Error, Result};
use crate::model::{ModelInferenceConfig, MultiscreenModel, MultiscreenModelConfig};
use crate::runtime::{DefaultAutodiffBackend, DefaultBackend, InferenceDevice, default_device};
use burn::module::AutodiffModule;
use std::fs;
use std::path::{Path, PathBuf};

// ---------------------------------------------------------------------------
// GenerationConfig
// ---------------------------------------------------------------------------

/// Configuration for text generation.
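///
/// # Example
///
/// A minimal sketch of overriding the defaults with struct-update syntax;
/// the values shown are illustrative, not recommendations.
///
/// ```rust
/// use multiscreen_rs::prelude::*;
///
/// // Generate up to 128 new tokens, keeping the default pad token ID.
/// let config = GenerationConfig {
///     max_new_tokens: 128,
///     ..GenerationConfig::default()
/// };
/// assert_eq!(config.pad_token_id, 0);
/// ```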
#[derive(Clone, Debug)]
pub struct GenerationConfig {
    /// Maximum number of new tokens to generate (default: 64).
    pub max_new_tokens: usize,
    /// Pad token ID (default: 0).
    pub pad_token_id: u32,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            max_new_tokens: 64,
            pad_token_id: 0,
        }
    }
}

// ---------------------------------------------------------------------------
// ChatModel
// ---------------------------------------------------------------------------

/// High-level model for token-level text generation.
///
/// Load a trained checkpoint and generate token IDs in a single call.
/// `ChatModel` automatically discovers the model config next to the
/// checkpoint file. Users bring their own tokenizer to encode/decode text.
///
/// # Example
///
/// ```rust,no_run
/// use multiscreen_rs::prelude::*;
///
/// fn main() -> multiscreen_rs::Result<()> {
///     let model = ChatModel::load("checkpoints/latest.mpk")?;
///     let token_ids = model.generate(&[1, 2, 3], GenerationConfig::default())?;
///     println!("generated tokens: {:?}", token_ids);
///     Ok(())
/// }
/// ```
pub struct ChatModel {
    model: MultiscreenModel<DefaultBackend>,
    device: InferenceDevice,
    config: MultiscreenModelConfig,
}

impl ChatModel {
    /// Load a `ChatModel` from a checkpoint path.
    ///
    /// `path` should point to a `.mpk` weights file (e.g.
    /// `"checkpoints/latest.mpk"` or `"runs/chat/checkpoints/latest.mpk"`).
    ///
    /// The method reads `config.json` from the checkpoint's parent directory to
    /// determine the model architecture, and falls back to the 10M-parameter
    /// preset if no config file is found.
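    ///
    /// # Example
    ///
    /// A minimal sketch of loading from a run directory; the path is
    /// illustrative and must point at an existing checkpoint.
    ///
    /// ```rust,no_run
    /// use multiscreen_rs::prelude::*;
    ///
    /// fn main() -> multiscreen_rs::Result<()> {
    ///     let model = ChatModel::load("runs/chat/checkpoints/latest.mpk")?;
    ///     // The parsed (or fallback) architecture config is available afterwards.
    ///     let _config = model.config();
    ///     Ok(())
    /// }
    /// ```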
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let checkpoint_path = path.as_ref();

        // Resolve the directory that contains the checkpoint file.
        let checkpoint_dir = checkpoint_path
            .parent()
            .ok_or_else(|| Error::Io(format!("cannot determine parent of {:?}", checkpoint_path)))?
            .to_path_buf();

        // ------ config.json ------
        let config = match find_file(&[checkpoint_dir.join("config.json")]) {
            Ok(config_path) => {
                let json = fs::read_to_string(&config_path).map_err(|e| {
                    Error::Io(format!("failed to read {}: {e}", config_path.display()))
                })?;
                serde_json::from_str::<MultiscreenModelConfig>(&json).map_err(|e| {
                    Error::Serialization(format!("failed to parse {}: {e}", config_path.display()))
                })?
            }
            Err(_) => {
                // Fall back to 10M-parameter preset defaults.
                // User should provide config.json for correct architecture.
                MultiscreenModelConfig::preset_10m(8192, 512)
            }
        };

        // ------ device + model ------
        // Load with the Autodiff backend first (needed for parameter loading),
        // then convert to the inference-only inner backend via .valid().
        // This avoids a VRAM leak from autodiff computation graphs during
        // autoregressive generation.
        let device = default_device()?;
        let mut model = MultiscreenModel::<DefaultAutodiffBackend>::new(config.clone(), &device)?;
        model.load_parameters(checkpoint_path)?;
        let model = model.valid(); // Strip Autodiff wrapper → MultiscreenModel<DefaultBackend>

        Ok(Self {
            model,
            device,
            config,
        })
    }

    /// Generate token IDs from a prompt token sequence.
    ///
    /// Returns all generated tokens (prompt + new) at once.
    /// For streaming / token-by-token output, use [`Self::generate_stream`].
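    ///
    /// # Example
    ///
    /// A minimal sketch with a custom token budget; the prompt IDs are
    /// placeholders for whatever your tokenizer produces.
    ///
    /// ```rust,no_run
    /// use multiscreen_rs::prelude::*;
    ///
    /// fn main() -> multiscreen_rs::Result<()> {
    ///     let model = ChatModel::load("checkpoints/latest.mpk")?;
    ///     let config = GenerationConfig {
    ///         max_new_tokens: 32,
    ///         ..GenerationConfig::default()
    ///     };
    ///     let token_ids = model.generate(&[1, 2, 3], config)?;
    ///     println!("generated tokens: {:?}", token_ids);
    ///     Ok(())
    /// }
    /// ```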
    pub fn generate(&self, prompt: &[u32], config: GenerationConfig) -> Result<Vec<u32>> {
        let inference_config = ModelInferenceConfig {
            max_new_tokens: config.max_new_tokens,
            pad_token_id: config.pad_token_id,
        };
        let output = self
            .model
            .infer_tokens(prompt, &inference_config, &self.device)?;
        Ok(output.token_ids)
    }

    /// Generate token IDs one at a time, invoking a callback for each newly
    /// produced token.
    ///
    /// This enables streaming / word-by-word output similar to ChatGPT.
    /// The callback receives `(token_id, index)` where `index` is the
    /// zero-based position of the *new* token (0 = first generated token).
    /// Return `false` from the callback to stop generation early.
    ///
    /// Returns the full output sequence (prompt + generated tokens).
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use multiscreen_rs::prelude::*;
    ///
    /// fn main() -> multiscreen_rs::Result<()> {
    ///     let model = ChatModel::load("checkpoints/latest.mpk")?;
    ///     let prompt: &[u32] = &[1, 2, 3];
    ///
    ///     let full_output = model.generate_stream(
    ///         prompt,
    ///         GenerationConfig::default(),
    ///         |token_id, _index| {
    ///             // Stream each token as it is produced.
    ///             // Decode with YOUR tokenizer and print word-by-word.
    ///             print!("{} ", token_id); // Replace with actual decoding
    ///             true // return false to stop early
    ///         },
    ///     )?;
    ///
    ///     println!("\nFull sequence: {:?}", full_output);
    ///     Ok(())
    /// }
    /// ```
    pub fn generate_stream(
        &self,
        prompt: &[u32],
        config: GenerationConfig,
        on_token: impl FnMut(u32, usize) -> bool,
    ) -> Result<Vec<u32>> {
        let inference_config = ModelInferenceConfig {
            max_new_tokens: config.max_new_tokens,
            pad_token_id: config.pad_token_id,
        };
        let output =
            self.model
                .infer_tokens_stream(prompt, &inference_config, &self.device, on_token)?;
        Ok(output.token_ids)
    }

    /// Run a forward pass on the padded context and return logits.
    ///
    /// Returns a tensor of shape `[1, seq_len, vocab_size]`.
    /// Use this for custom sampling strategies (top-k, temperature, etc.).
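    ///
    /// # Example
    ///
    /// A sketch of greedy (arg-max) sampling of the next token. It assumes a
    /// recent `burn` release where `Tensor::into_data().to_vec::<f32>()` is
    /// available and the logits are stored as `f32`.
    ///
    /// ```rust,no_run
    /// use multiscreen_rs::prelude::*;
    ///
    /// fn main() -> multiscreen_rs::Result<()> {
    ///     let model = ChatModel::load("checkpoints/latest.mpk")?;
    ///     let context: &[u32] = &[1, 2, 3];
    ///
    ///     let logits = model.predict_logits(context)?;
    ///     let [_, seq_len, vocab_size] = logits.dims();
    ///
    ///     // Flatten to a Vec<f32> and arg-max over the last position by hand,
    ///     // so no version-specific tensor ops are required.
    ///     let data: Vec<f32> = logits.into_data().to_vec().expect("f32 logits");
    ///     let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
    ///     let next_id = last
    ///         .iter()
    ///         .enumerate()
    ///         .max_by(|a, b| a.1.total_cmp(b.1))
    ///         .map(|(i, _)| i as u32)
    ///         .unwrap_or(0);
    ///     println!("greedy next token: {next_id}");
    ///     Ok(())
    /// }
    /// ```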
    pub fn predict_logits(
        &self,
        context: &[u32],
    ) -> Result<burn::tensor::Tensor<DefaultBackend, 3>> {
        let pad_token_id = 0;
        self.model
            .forward_logits(context, pad_token_id, &self.device)
    }

    /// Access the underlying neural model.
    pub fn model(&self) -> &MultiscreenModel<DefaultBackend> {
        &self.model
    }

    /// Access the model configuration.
    pub fn config(&self) -> &MultiscreenModelConfig {
        &self.config
    }

    /// Access the device the model is running on.
    pub fn device(&self) -> &InferenceDevice {
        &self.device
    }
}

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Return the first path in `candidates` that exists on disk, or an error.
fn find_file(candidates: &[PathBuf]) -> Result<PathBuf> {
    for candidate in candidates {
        if candidate.exists() {
            return Ok(candidate.clone());
        }
    }
    let descriptions = candidates
        .iter()
        .map(|p| format!("  {}", p.display()))
        .collect::<Vec<_>>()
        .join("\n");
    Err(Error::Io(format!(
        "file not found; searched:\n{descriptions}"
    )))
}