pub struct MultiscreenModel<B: Backend = DefaultAutodiffBackend> { /* private fields */ }
Burn-backed Multiscreen neural language model ported from
multiscreen-testing.
Implementations
impl<B: Backend> MultiscreenModel<B>
pub fn new(config: MultiscreenModelConfig, device: &B::Device) -> Result<Self>
pub fn config(&self) -> &MultiscreenModelConfig
pub fn parameter_count(&self) -> usize
pub fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3>
pub fn save_parameters(&self, path: impl AsRef<Path>) -> Result<()>
pub fn load_parameters(&mut self, path: impl AsRef<Path>) -> Result<()>
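A construction / persistence sketch (config is an assumed MultiscreenModelConfig value; the file path is illustrative; the device type is assumed to implement Default, as Burn devices typically do):

// Build a model from a config, inspect its size, and round-trip its
// parameters through a file.
let device = Default::default(); // backend-specific device, inferred from usage
let mut model = MultiscreenModel::<DefaultAutodiffBackend>::new(config, &device)?;
println!("parameters: {}", model.parameter_count());
model.save_parameters("model.params")?;
model.load_parameters("model.params")?;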
pub fn infer_tokens_stream(
    &self,
    prompt: &[u32],
    inference: &ModelInferenceConfig,
    device: &B::Device,
    on_token: impl FnMut(u32, usize) -> bool,
) -> Result<MultiscreenModelOutput>
Greedy token generation. Generates tokens one at a time, invoking a callback for each newly produced token, which enables streaming / word-by-word output similar to ChatGPT.
The callback receives (token_id, index), where index is the
zero-based position of the new token (0 = first generated token).
If the callback returns false, generation stops early.
Returns the full (prompt + generated) token sequence.
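A minimal usage sketch, assuming model, inference, and device are already set up and the prompt token IDs come from the caller's tokenizer:

// Stream tokens as they are produced, stopping early after 32 generated
// tokens. `model`, `inference`, and `device` are assumed to exist.
let prompt: Vec<u32> = vec![12, 87, 3]; // token IDs from your tokenizer
let output = model.infer_tokens_stream(&prompt, &inference, &device, |token_id, index| {
    println!("generated token #{index}: id {token_id}");
    index + 1 < 32 // returning false stops generation early
})?;
// `output` holds the full (prompt + generated) token sequence.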
pub fn infer_tokens(
    &self,
    prompt: &[u32],
    inference: &ModelInferenceConfig,
    device: &B::Device,
) -> Result<MultiscreenModelOutput>
Generate tokens and return them all at once (non-streaming).
For streaming / token-by-token output, use Self::infer_tokens_stream.
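For example, with the same assumed setup as the streaming sketch above:

// Non-streaming: block until generation finishes, then use the full
// (prompt + generated) token sequence at once.
let output = model.infer_tokens(&prompt, &inference, &device)?;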
pub fn predict_next_token(
    &self,
    context: &[u32],
    pad_token_id: u32,
    device: &B::Device,
) -> Result<u32>
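A sketch of a hand-rolled generation loop built on this method; EOS_TOKEN_ID and max_new_tokens are assumed values:

// Greedy decoding by hand: extend the context one token at a time
// until the assumed end-of-sequence token appears.
let mut context: Vec<u32> = prompt.clone();
for _ in 0..max_new_tokens {
    let next = model.predict_next_token(&context, pad_token_id, &device)?;
    if next == EOS_TOKEN_ID {
        break; // assumed end-of-sequence token ID
    }
    context.push(next);
}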
pub fn forward_logits(
    &self,
    context: &[u32],
    pad_token_id: u32,
    device: &B::Device,
) -> Result<Tensor<B, 3>>
Run a forward pass and return the full logit tensor.
The returned tensor has shape [1, seq_len, vocab_size].
This is useful for sampling-based generation (top-k, temperature, etc.)
where you need access to the raw logit values, not just the argmax.
The context is padded/truncated to seq_len automatically.
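As a sketch, here is temperature sampling over one position's logits. It assumes the vocab_size-length logits for the last position have already been extracted into a Vec<f32> (the exact tensor-to-Vec conversion depends on your Burn version), that temperature > 0, and that the rand crate is available:

use rand::Rng;

/// Sample a token ID from raw logits with a temperature parameter.
fn sample_with_temperature(logits: &[f32], temperature: f32) -> u32 {
    // Scale by temperature and apply a numerically stable softmax
    // (subtracting the max before exponentiating).
    let max = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let weights: Vec<f32> = logits
        .iter()
        .map(|&l| ((l - max) / temperature).exp())
        .collect();
    let total: f32 = weights.iter().sum();

    // Inverse-CDF sampling over the softmax distribution.
    let mut r = rand::thread_rng().gen::<f32>() * total;
    for (i, &w) in weights.iter().enumerate() {
        r -= w;
        if r <= 0.0 {
            return i as u32;
        }
    }
    (logits.len() - 1) as u32 // fallback for floating-point rounding
}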
pub fn evaluate_on_sequences(
    &self,
    sequences: &[Vec<u32>],
    seq_len: usize,
    batch_size: usize,
    pad_token_id: u32,
    device: &B::Device,
) -> Result<EvaluationResult>
Evaluates the model on token sequences, returning average loss, perplexity, and next-token prediction accuracy.
This method works on any Backend (including non-autodiff), which makes
it safe to call on an inference-only model without VRAM growth.
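For instance (a sketch; load_validation_token_ids is an assumed helper, and the result is Debug-printed since EvaluationResult's fields are not shown here):

// Evaluate on held-out sequences with a context length of 128 and
// batches of 8. `model`, `pad_token_id`, and `device` are assumed.
let val_sequences: Vec<Vec<u32>> = load_validation_token_ids();
let result = model.evaluate_on_sequences(&val_sequences, 128, 8, pad_token_id, &device)?;
println!("{result:?}"); // assumes EvaluationResult implements Debug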
impl<B> MultiscreenModel<B>
where
    B: AutodiffBackend,
pub fn train_token_sequences(
    &mut self,
    sequences: &[Vec<u32>],
    training: &ModelTrainingConfig,
    device: &B::Device,
    on_step: impl FnMut(usize, f32),
) -> Result<ModelTrainingReport>
Trains this model directly on token sequences.
The on_step callback is invoked after each optimizer step with
(step_index, loss_value); pass a no-op closure if you don't need it.
Use it for progress logging, CSV export, etc.
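A training sketch (load_training_token_ids is an assumed helper; training is a ModelTrainingConfig built elsewhere):

// Train on pre-tokenized sequences, logging the loss every 100 steps.
let sequences: Vec<Vec<u32>> = load_training_token_ids();
let report = model.train_token_sequences(&sequences, &training, &device, |step, loss| {
    if step % 100 == 0 {
        println!("step {step}: loss {loss:.4}");
    }
})?;
// `report` is the returned ModelTrainingReport.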
pub fn train_chat_sequences(
    &mut self,
    chat_pairs: &[(Vec<u32>, Vec<u32>)],
    training: &ModelTrainingConfig,
    device: &B::Device,
    on_step: impl FnMut(usize, f32),
) -> Result<ModelTrainingReport>
Trains this model on chat-style (prompt, response) token-ID pairs.
This is the chat-aware counterpart of MultiscreenModel::train_token_sequences. The model
sees the full context (prompt + response), but the loss is computed only on
the response tokens, preventing the model from learning to generate role
labels like system:, user:, or assistant:.
Each element of chat_pairs is (prompt_token_ids, response_token_ids).
The caller is responsible for appending an EOS token to the response IDs
when desired; the EOS token will receive loss_mask = 1.0 like any other
response token.
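A fine-tuning sketch (raw_pairs and EOS_TOKEN_ID are assumed; training as above):

// Append EOS to each response, then fine-tune on the chat pairs.
let chat_pairs: Vec<(Vec<u32>, Vec<u32>)> = raw_pairs
    .into_iter()
    .map(|(prompt, mut response)| {
        response.push(EOS_TOKEN_ID); // receives loss_mask = 1.0 like other response tokens
        (prompt, response)
    })
    .collect();
let report = model.train_chat_sequences(&chat_pairs, &training, &device, |step, loss| {
    println!("step {step}: loss {loss:.4}");
})?;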