Skip to main content

multiscreen_rs/
inference.rs

1//! High-level inference API.
2//!
3//! Provides [`GenerationConfig`] and [`ChatModel`] for easy token-level
4//! text generation. Users encode/decode text with their own tokenizer.
5//!
6//! # Non-streaming (all tokens at once)
7//!
8//! ```rust,no_run
9//! use multiscreen_rs::prelude::*;
10//!
11//! fn main() -> multiscreen_rs::Result<()> {
12//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
13//!     let token_ids = model.generate(&[1, 2, 3], GenerationConfig::default())?;
14//!     println!("generated tokens: {:?}", token_ids);
15//!     Ok(())
16//! }
17//! ```
18//!
19//! # Streaming (token by token, like ChatGPT)
20//!
21//! ```rust,no_run
22//! use multiscreen_rs::prelude::*;
23//!
24//! fn main() -> multiscreen_rs::Result<()> {
25//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
26//!     let full = model.generate_stream(
27//!         &[1, 2, 3],
28//!         GenerationConfig::default(),
29//!         |token_id, _index| {
30//!             // Decode with YOUR tokenizer and print word-by-word
31//!             print!("{} ", token_id);
32//!             true // return false to stop early
33//!         },
34//!     )?;
35//!     Ok(())
36//! }
37//! ```
38
39use crate::error::{Error, Result};
40use crate::model::{DefaultMultiscreenModel, ModelInferenceConfig, MultiscreenModelConfig};
41use crate::runtime::{default_device, Device};
42use std::fs;
43use std::path::{Path, PathBuf};
44
45// ---------------------------------------------------------------------------
46// GenerationConfig
47// ---------------------------------------------------------------------------
48
/// Configuration for text generation.
///
/// Passed by value to `ChatModel::generate` / `ChatModel::generate_stream`.
/// Use `GenerationConfig::default()` or struct-update syntax, e.g.
/// `GenerationConfig { max_new_tokens: 128, ..Default::default() }`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GenerationConfig {
    /// Maximum number of new tokens to generate (default: 64).
    pub max_new_tokens: usize,
    /// Pad token ID (default: 0).
    pub pad_token_id: u32,
}

impl Default for GenerationConfig {
    /// Returns `max_new_tokens = 64` and `pad_token_id = 0`.
    fn default() -> Self {
        Self {
            max_new_tokens: 64,
            pad_token_id: 0,
        }
    }
}
66
67// ---------------------------------------------------------------------------
68// ChatModel
69// ---------------------------------------------------------------------------
70
71/// High-level model for token-level text generation.
72///
73/// Load a trained checkpoint and generate token IDs in a single call.
74/// `ChatModel` automatically discovers the model config next to the
75/// checkpoint file. Users bring their own tokenizer to encode/decode text.
76///
77/// # Example
78///
79/// ```rust,no_run
80/// use multiscreen_rs::prelude::*;
81///
82/// fn main() -> multiscreen_rs::Result<()> {
83///     let model = ChatModel::load("checkpoints/latest.mpk")?;
84///     let token_ids = model.generate(&[1, 2, 3], GenerationConfig::default())?;
85///     println!("generated tokens: {:?}", token_ids);
86///     Ok(())
87/// }
88/// ```
89pub struct ChatModel {
90    model: DefaultMultiscreenModel,
91    device: Device,
92    config: MultiscreenModelConfig,
93}
94
95impl ChatModel {
96    /// Load a `ChatModel` from a checkpoint path.
97    ///
98    /// `path` should point to a `.mpk` weights file (e.g.
99    /// `"checkpoints/latest.mpk"` or `"runs/chat/checkpoints/latest.mpk"`).
100    ///
101    /// The method resolves `config.json` relative to the checkpoint's parent
102    /// directory for model architecture. Falls back to Params10M defaults.
103    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
104        let checkpoint_path = path.as_ref();
105
106        // Resolve the directory that contains the checkpoint file.
107        let checkpoint_dir = checkpoint_path
108            .parent()
109            .ok_or_else(|| Error::Io(format!("cannot determine parent of {:?}", checkpoint_path)))?
110            .to_path_buf();
111
112        // ------ config.json ------
113        let config = match find_file(&[checkpoint_dir.join("config.json")]) {
114            Ok(config_path) => {
115                let json = fs::read_to_string(&config_path).map_err(|e| {
116                    Error::Io(format!("failed to read {}: {e}", config_path.display()))
117                })?;
118                serde_json::from_str::<MultiscreenModelConfig>(&json).map_err(|e| {
119                    Error::Serialization(format!("failed to parse {}: {e}", config_path.display()))
120                })?
121            }
122            Err(_) => {
123                // Fall back to 10M-parameter preset defaults.
124                // User should provide config.json for correct architecture.
125                MultiscreenModelConfig::preset_10m(8192, 512)
126            }
127        };
128
129        // ------ device + model ------
130        let device = default_device()?;
131        let mut model = DefaultMultiscreenModel::new(config.clone(), &device)?;
132        model.load_parameters(checkpoint_path)?;
133
134        Ok(Self {
135            model,
136            device,
137            config,
138        })
139    }
140
141    /// Generate token IDs from a prompt token sequence.
142    ///
143    /// Returns all generated tokens (prompt + new) at once.
144    /// For streaming / token-by-token output, use [`Self::generate_stream`].
145    pub fn generate(&self, prompt: &[u32], config: GenerationConfig) -> Result<Vec<u32>> {
146        let inference_config = ModelInferenceConfig {
147            max_new_tokens: config.max_new_tokens,
148            pad_token_id: config.pad_token_id,
149        };
150        let output = self
151            .model
152            .infer_tokens(prompt, &inference_config, &self.device)?;
153        Ok(output.token_ids)
154    }
155
156    /// Generate token IDs one at a time, invoking a callback for each newly
157    /// produced token.
158    ///
159    /// This enables streaming / word-by-word output similar to ChatGPT.
160    /// The callback receives `(token_id, index)` where `index` is the
161    /// zero-based position of the *new* token (0 = first generated token).
162    /// Return `false` from the callback to stop generation early.
163    ///
164    /// Returns the full output sequence (prompt + generated tokens).
165    ///
166    /// # Example
167    ///
168    /// ```rust,no_run
169    /// use multiscreen_rs::prelude::*;
170    ///
171    /// fn main() -> multiscreen_rs::Result<()> {
172    ///     let model = ChatModel::load("checkpoints/latest.mpk")?;
173    ///     let prompt: &[u32] = &[1, 2, 3];
174    ///
175    ///     let full_output = model.generate_stream(
176    ///         prompt,
177    ///         GenerationConfig::default(),
178    ///         |token_id, _index| {
179    ///             // Stream each token as it is produced.
180    ///             // Decode with YOUR tokenizer and print word-by-word.
181    ///             print!("{} ", token_id); // Replace with actual decoding
182    ///             true // return false to stop early
183    ///         },
184    ///     )?;
185    ///
186    ///     println!("\nFull sequence: {:?}", full_output);
187    ///     Ok(())
188    /// }
189    /// ```
190    pub fn generate_stream(
191        &self,
192        prompt: &[u32],
193        config: GenerationConfig,
194        on_token: impl FnMut(u32, usize) -> bool,
195    ) -> Result<Vec<u32>> {
196        let inference_config = ModelInferenceConfig {
197            max_new_tokens: config.max_new_tokens,
198            pad_token_id: config.pad_token_id,
199        };
200        let output =
201            self.model
202                .infer_tokens_stream(prompt, &inference_config, &self.device, on_token)?;
203        Ok(output.token_ids)
204    }
205
206    /// Access the underlying neural model.
207    pub fn model(&self) -> &DefaultMultiscreenModel {
208        &self.model
209    }
210
211    /// Access the model configuration.
212    pub fn config(&self) -> &MultiscreenModelConfig {
213        &self.config
214    }
215
216    /// Access the device the model is running on.
217    pub fn device(&self) -> &Device {
218        &self.device
219    }
220}
221
222// ---------------------------------------------------------------------------
223// Helpers
224// ---------------------------------------------------------------------------
225
226/// Return the first path in `candidates` that exists on disk, or an error.
227fn find_file(candidates: &[PathBuf]) -> Result<PathBuf> {
228    for candidate in candidates {
229        if candidate.exists() {
230            return Ok(candidate.clone());
231        }
232    }
233    let descriptions = candidates
234        .iter()
235        .map(|p| format!("  {}", p.display()))
236        .collect::<Vec<_>>()
237        .join("\n");
238    Err(Error::Io(format!(
239        "file not found; searched:\n{descriptions}"
240    )))
241}