Skip to main content

multiscreen_rs/
lib.rs

1//! # multiscreen-rs
2//!
3//! > A Rust implementation of the Multiscreen neural language model — training
4//! > and inference — powered by [Burn](https://github.com/tracel-ai/burn).
5//!
6//! ## Quick Start — Training
7//!
8//! ```rust,no_run
9//! use multiscreen_rs::prelude::*;
10//!
11//! fn main() -> multiscreen_rs::Result<()> {
12//!     let mut trainer = Trainer::builder()
13//!         .vocab_size(1000)
14//!         .budget(ParameterBudget::Params10M)
15//!         .device(cpu()?)
16//!         .batch_size(16)
17//!         .seq_len(128)
18//!         .steps(50_000)
19//!         .build()?;
20//!
21//!     let sequences = vec![vec![1, 2, 3, 4], vec![1, 2, 5, 4]];
22//!     let report = trainer.train_on_token_sequences(&sequences)?;
23//!     println!("trained {} steps, final loss {:.4}", report.steps, report.final_loss);
24//!     Ok(())
25//! }
26//! ```
27//!
28//! ## Quick Start — Chat / Inference
29//!
30//! ### Non-streaming (all tokens at once)
31//!
32//! ```rust,no_run
33//! use multiscreen_rs::prelude::*;
34//!
35//! fn main() -> multiscreen_rs::Result<()> {
36//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
37//!     let token_ids = model.generate(&[1, 2, 3], GenerationConfig::default())?;
38//!     println!("generated tokens: {:?}", token_ids);
39//!     Ok(())
40//! }
41//! ```
42//!
43//! ### Streaming (token by token, like ChatGPT)
44//!
45//! ```rust,no_run
46//! use multiscreen_rs::prelude::*;
47//!
48//! fn main() -> multiscreen_rs::Result<()> {
49//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
50//!     let full = model.generate_stream(
51//!         &[1, 2, 3],
52//!         GenerationConfig::default(),
53//!         |token_id, _index| {
54//!             // Decode with YOUR tokenizer and print word-by-word
55//!             print!("{} ", token_id);
56//!             true // return false to stop early
57//!         },
58//!     )?;
59//!     Ok(())
60//! }
61//! ```
62//!
63//! ## Device Selection
64//!
65//! ```rust,no_run
66//! use multiscreen_rs::prelude::*;
67//!
68//! fn main() -> multiscreen_rs::Result<()> {
69//!     let device = cpu()?;        // CPU (always available)
70//!     // let device = cuda(0)?;   // CUDA GPU (requires "cuda" feature)
71//!     // let device = auto_device()?; // best available
72//!     Ok(())
73//! }
74//! ```
75//!
76//! ## Low-Level In-Memory Quick Start
77//!
78//! ```rust
79//! use multiscreen_rs::prelude::*;
80//!
81//! fn main() -> multiscreen_rs::Result<()> {
82//!     let device = cpu()?;
83//!     let mut model = DefaultMultiscreenModel::new(
84//!         MultiscreenModelConfig::tiny_for_tests(),
85//!         &device,
86//!     )?;
87//!
88//!     model.train_token_sequences(
89//!         &[vec![1, 2, 3, 4], vec![1, 2, 5, 4]],
90//!         &ModelTrainingConfig {
91//!             steps: 2,
92//!             batch_size: 2,
93//!             learning_rate: 1e-3,
94//!             weight_decay: 0.0,
95//!             grad_clip_norm: Some(1.0),
96//!             pad_token_id: 0,
97//!         },
98//!         &device,
99//!         |_, _| {},
100//!     )?;
101//!
102//!     let output = model.infer_tokens(
103//!         &[1, 2],
104//!         &ModelInferenceConfig {
105//!             max_new_tokens: 2,
106//!             pad_token_id: 0,
107//!         },
108//!         &device,
109//!     )?;
110//!     println!("tokens: {:?}", output.token_ids);
111//!     Ok(())
112//! }
113//! ```
114//!
115//! ## Feature Flags
116//!
//! By default, the neural model uses Burn Flex, so CPU training works without
//! a Candle dependency. Enable the `cuda` feature for GPU acceleration.
119
120// ---- Public modules (the only ones users should care about) ----
121pub mod device;
122pub mod inference;
123pub mod prelude;
124pub mod training;
125
126// ---- Internal modules ----
127pub(crate) mod config;
128pub(crate) mod engine;
129pub(crate) mod error;
130pub(crate) mod layout;
131pub(crate) mod lm;
132pub(crate) mod model;
133pub(crate) mod optim;
134pub(crate) mod param_io;
135pub(crate) mod runtime;
136pub(crate) mod screen;
137pub(crate) mod tile;
138
139// ---- High-level API re-exports ----
140pub use device::{auto_device, cpu, cuda};
141pub use inference::{ChatModel, GenerationConfig};
142pub use training::{ParameterBudget, Trainer, TrainingReport};
143
144// ---- Core types (available through prelude) ----
145pub use error::{Error, Result};
146pub use model::{
147    cross_entropy_loss_with_mask, DefaultMultiscreenModel, EvaluationResult, ModelInferenceConfig,
148    ModelTrainingConfig, ModelTrainingReport, MultiscreenModel, MultiscreenModelConfig,
149    MultiscreenModelOutput, MultiscreenParameterBudget,
150};
151pub use runtime::{default_device, device_label, DefaultAutodiffBackend, DefaultBackend, Device};
152
153#[cfg(feature = "cuda")]
154pub use runtime::{cuda_device, CudaAutodiffBackend, CudaDevice, CudaMultiscreenModel};
155
156// ---- Engine types (lightweight transition engine) ----
157pub use config::{InferenceConfig, MultiscreenConfig, TrimConfig};
158pub use engine::{InferenceOutput, MultiscreenEngine, TrainInput, TrainReport};
159pub use layout::{
160    causal_softmask, causal_trim_relevance, trim_and_square, ScreenLayout, TokenSpan,
161};
162pub use screen::{Screen, ScreenConfig};
163pub use tile::{ScreeningGridConfig, Tile, TileConfig};
164
165// ---- Burn re-exports ----
166pub use burn::{
167    tensor::backend::{AutodiffBackend, Backend},
168    tensor::{Int, Tensor, TensorData},
169};
170
171#[cfg(feature = "cuda")]
172pub use burn::backend::Cuda;
173
174#[deprecated(note = "use MultiscreenConfig; the paper styles this as one word")]
175pub type MultiScreenConfig = MultiscreenConfig;
176
177#[deprecated(note = "use MultiscreenEngine; the paper styles this as one word")]
178pub type MultiScreenEngine = MultiscreenEngine;
179
180#[deprecated(note = "use ScreeningGridConfig for N_L x N_H naming")]
181pub type GridConfig = ScreeningGridConfig;