Skip to main content

multiscreen_rs/
lib.rs

1//! # multiscreen-rs
2//!
3//! > A Rust implementation of the Multiscreen neural language model — training
4//! > and inference — powered by [Burn](https://github.com/tracel-ai/burn).
5//!
6//! ## Quick Start — Training
7//!
8//! ```rust,no_run
9//! use multiscreen_rs::prelude::*;
10//!
11//! fn main() -> multiscreen_rs::Result<()> {
12//!     let mut trainer = Trainer::builder()
13//!         .vocab_size(1000)
14//!         .budget(ParameterBudget::Params10M)
15//!         .device(auto_device()?)
16//!         .batch_size(16)
17//!         .seq_len(128)
18//!         .steps(50_000)
19//!         .build()?;
20//!
21//!     let sequences = vec![vec![1, 2, 3, 4], vec![1, 2, 5, 4]];
22//!     let report = trainer.train_on_token_sequences(&sequences)?;
23//!     println!("trained {} steps, final loss {:.4}", report.steps, report.final_loss);
24//!     Ok(())
25//! }
26//! ```
27//!
28//! ## Quick Start — Chat / Inference
29//!
30//! ### Non-streaming (all tokens at once)
31//!
32//! ```rust,no_run
33//! use multiscreen_rs::prelude::*;
34//!
35//! fn main() -> multiscreen_rs::Result<()> {
36//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
37//!     let token_ids = model.generate(&[1, 2, 3], GenerationConfig::default())?;
38//!     println!("generated tokens: {:?}", token_ids);
39//!     Ok(())
40//! }
41//! ```
42//!
43//! ### Streaming (token by token, like ChatGPT)
44//!
45//! ```rust,no_run
46//! use multiscreen_rs::prelude::*;
47//!
48//! fn main() -> multiscreen_rs::Result<()> {
49//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
50//!     let full = model.generate_stream(
51//!         &[1, 2, 3],
52//!         GenerationConfig::default(),
53//!         |token_id, _index| {
54//!             // Decode with YOUR tokenizer and print word-by-word
55//!             print!("{} ", token_id);
56//!             true // return false to stop early
57//!         },
58//!     )?;
59//!     Ok(())
60//! }
61//! ```
62//!
63//! ## Device Selection
64//!
65//! ```rust,no_run
66//! use multiscreen_rs::prelude::*;
67//!
68//! fn main() -> multiscreen_rs::Result<()> {
69//!     let device = auto_device()?;  // best available (CPU or CUDA)
70//!     // let device = cuda(0)?;       // CUDA GPU (requires "cuda" feature)
71//!     Ok(())
72//! }
73//! ```
74//!
75//! ## Low-Level In-Memory Quick Start
76//!
77//! ```rust
78//! use multiscreen_rs::prelude::*;
79//!
80//! fn main() -> multiscreen_rs::Result<()> {
81//!     let device = auto_device()?;
82//!     let mut model = DefaultMultiscreenModel::new(
83//!         MultiscreenModelConfig::tiny_for_tests(),
84//!         &device,
85//!     )?;
86//!
87//!     model.train_token_sequences(
88//!         &[vec![1, 2, 3, 4], vec![1, 2, 5, 4]],
89//!         &ModelTrainingConfig {
90//!             steps: 2,
91//!             batch_size: 2,
92//!             learning_rate: 1e-3,
93//!             weight_decay: 0.0,
94//!             grad_clip_norm: Some(1.0),
95//!             pad_token_id: 0,
96//!         },
97//!         &device,
98//!         |_, _| {},
99//!     )?;
100//!
101//!     let output = model.infer_tokens(
102//!         &[1, 2],
103//!         &ModelInferenceConfig {
104//!             max_new_tokens: 2,
105//!             pad_token_id: 0,
106//!         },
107//!         &device,
108//!     )?;
109//!     println!("tokens: {:?}", output.token_ids);
110//!     Ok(())
111//! }
112//! ```
113//!
114//! ## Feature Flags
115//!
116//! The default neural path uses Burn Flex, so CPU training does not depend on
117//! Candle. Enable the `cuda` feature for GPU acceleration.
118
119// ---- Public modules (the only ones users should care about) ----
// These four modules are declared `pub`, so they make up the crate's public
// module tree. All other modules declared in this file are `pub(crate)`.
120pub mod device;
121pub mod inference;
122pub mod prelude;
123pub mod training;
124
125// ---- Internal modules ----
// Declared `pub(crate)`, so none of these can be reached by module path from
// outside the crate. In the visible part of this file, selected items from
// `config`, `engine`, `error`, `layout`, `model`, `runtime`, `screen`, and
// `tile` are re-exported at the crate root; `lm`, `optim`, and `param_io`
// have no root-level re-export here.
126pub(crate) mod config;
127pub(crate) mod engine;
128pub(crate) mod error;
129pub(crate) mod layout;
130pub(crate) mod lm;
131pub(crate) mod model;
132pub(crate) mod optim;
133pub(crate) mod param_io;
134pub(crate) mod runtime;
135pub(crate) mod screen;
136pub(crate) mod tile;
137
138// ---- High-level API re-exports ----
// Crate-root shortcuts for items defined in the public `device`, `inference`,
// and `training` modules.
//
// `cpu` is re-exported only when the `cuda` feature is disabled.
139#[cfg(not(feature = "cuda"))]
140pub use device::cpu;
// NOTE(review): `cuda` is re-exported regardless of the `cuda` feature, while
// the crate-level docs describe it as requiring that feature — confirm that
// `device::cuda` is defined in both configurations.
141pub use device::{auto_device, cuda};
142pub use inference::{ChatModel, GenerationConfig};
143pub use training::{ParameterBudget, Trainer, TrainingReport};
144
145// ---- Core types (available through prelude) ----
// Re-exported from the internal `error`, `model`, and `runtime` modules so
// they can be named from the crate root.
146pub use error::{Error, Result};
147pub use model::{
148    cross_entropy_loss_with_mask, DefaultMultiscreenModel, EvaluationResult, ModelInferenceConfig,
149    ModelTrainingConfig, ModelTrainingReport, MultiscreenModel, MultiscreenModelConfig,
150    MultiscreenModelOutput, MultiscreenParameterBudget,
151};
152pub use runtime::{device_label, DefaultAutodiffBackend, DefaultBackend, Device};
153
// `default_device` is re-exported only when the `cuda` feature is disabled.
154#[cfg(not(feature = "cuda"))]
155pub use runtime::default_device;
156
// The CUDA-specific runtime types are re-exported only when the `cuda`
// feature is enabled.
157#[cfg(feature = "cuda")]
158pub use runtime::{CudaAutodiffBackend, CudaDevice, CudaMultiscreenModel};
159
160// ---- Engine types (lightweight transition engine) ----
// Re-exported from the internal `config`, `engine`, `layout`, `screen`, and
// `tile` modules. The deprecated aliases at the bottom of this file point at
// `MultiscreenConfig`, `MultiscreenEngine`, and `ScreeningGridConfig` from
// this group.
161pub use config::{InferenceConfig, MultiscreenConfig, TrimConfig};
162pub use engine::{InferenceOutput, MultiscreenEngine, TrainInput, TrainReport};
163pub use layout::{
164    causal_softmask, causal_trim_relevance, trim_and_square, ScreenLayout, TokenSpan,
165};
166pub use screen::{Screen, ScreenConfig};
167pub use tile::{ScreeningGridConfig, Tile, TileConfig};
168
169// ---- Burn re-exports ----
// Surfaces Burn's backend traits and tensor types at this crate's root
// (presumably so callers can name them without importing `burn` themselves —
// confirm intent).
170pub use burn::{
171    tensor::backend::{AutodiffBackend, Backend},
172    tensor::{Int, Tensor, TensorData},
173};
174
// Burn's `Cuda` backend is re-exported only when the `cuda` feature is enabled.
175#[cfg(feature = "cuda")]
176pub use burn::backend::Cuda;
177
/// Deprecated alias for [`MultiscreenConfig`], kept under the older
/// `MultiScreen` capitalisation. New code should name `MultiscreenConfig`.
178#[deprecated(note = "use MultiscreenConfig; the paper styles this as one word")]
179pub type MultiScreenConfig = MultiscreenConfig;
180
/// Deprecated alias for [`MultiscreenEngine`], kept under the older
/// `MultiScreen` capitalisation. New code should name `MultiscreenEngine`.
181#[deprecated(note = "use MultiscreenEngine; the paper styles this as one word")]
182pub type MultiScreenEngine = MultiscreenEngine;
183
/// Deprecated alias for [`ScreeningGridConfig`]. New code should name
/// `ScreeningGridConfig`.
184#[deprecated(note = "use ScreeningGridConfig for N_L x N_H naming")]
185pub type GridConfig = ScreeningGridConfig;