//! # multiscreen-rs
//!
//! A Rust implementation of the Multiscreen neural language model — training
//! and inference — powered by [Burn](https://github.com/tracel-ai/burn).
//! Based on [Screening Is Enough](https://arxiv.org/pdf/2604.01178).
//!
//! ## Quick Start
//!
//! ### Training
//!
//! ```rust,no_run
//! use multiscreen_rs::prelude::*;
//!
//! fn main() -> multiscreen_rs::Result<()> {
//!     let mut trainer = Trainer::builder()
//!         .vocab_size(1000)
//!         .budget(ParameterBudget::Params10M)
//!         .device(auto_device()?)
//!         .batch_size(16).seq_len(128).steps(50_000)
//!         .build()?;
//!
//!     let sequences = vec![vec![1, 2, 3, 4], vec![1, 2, 5, 4]];
//!     let report = trainer.train_on_token_sequences(&sequences)?;
//!     // or: trainer.train_on_chat_sequences(&chat_pairs)?;
//!     Ok(())
//! }
//! ```
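//!
//! The `train_on_chat_sequences` variant mentioned above takes prompt/response
//! pairs. A minimal sketch, reusing the `trainer` from the example above; the
//! `(Vec<u32>, Vec<u32>)` pair shape is an assumption, not a documented
//! signature:
//!
//! ```rust,ignore
//! // Hypothetical pair shape: (prompt tokens, response tokens).
//! let chat_pairs: Vec<(Vec<u32>, Vec<u32>)> = vec![
//!     (vec![1, 2, 3], vec![4, 5]),
//!     (vec![1, 2, 6], vec![7, 5]),
//! ];
//! trainer.train_on_chat_sequences(&chat_pairs)?;
//! ```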
//!
//! ### Inference
//!
//! ```rust,no_run
//! use multiscreen_rs::prelude::*;
//!
//! fn main() -> multiscreen_rs::Result<()> {
//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
//!
//!     // One-shot
//!     let tokens = model.generate(&[1, 2, 3], GenerationConfig::default())?;
//!
//!     // Streaming
//!     model.generate_stream(&[1, 2, 3], GenerationConfig::default(), |id, _| {
//!         print!("{id} "); true
//!     })?;
//!     Ok(())
//! }
//! ```
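//!
//! The streaming callback's `bool` return is what keeps generation going; the
//! example above always returns `true`. A minimal sketch that collects the
//! streamed ids and stops early, assuming returning `false` halts the stream:
//!
//! ```rust,no_run
//! use multiscreen_rs::prelude::*;
//!
//! fn main() -> multiscreen_rs::Result<()> {
//!     let model = ChatModel::load("checkpoints/latest.mpk")?;
//!     let mut out = Vec::new();
//!     model.generate_stream(&[1, 2, 3], GenerationConfig::default(), |id, _| {
//!         out.push(id);
//!         out.len() < 32 // request the next token only while under 32
//!     })?;
//!     Ok(())
//! }
//! ```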
//!
//! ## Feature Flags
//!
//! Enable the `cuda` feature for NVIDIA GPU acceleration. The default build
//! trains on the CPU via Burn Flex with automatic SIMD detection.
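//!
//! A minimal sketch of opting in from `Cargo.toml` (the version number is a
//! placeholder):
//!
//! ```toml
//! [dependencies]
//! multiscreen-rs = { version = "0.1", features = ["cuda"] }
//! ```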

// ---- Public modules (the only ones users should care about) ----
pub mod device;
pub mod inference;
pub mod prelude;
pub mod training;

// ---- Internal modules ----
pub(crate) mod config;
pub(crate) mod engine;
pub(crate) mod error;
pub(crate) mod layout;
pub(crate) mod lm;
pub(crate) mod model;
pub(crate) mod optim;
pub(crate) mod param_io;
pub(crate) mod runtime;
pub(crate) mod screen;
pub(crate) mod tile;

// ---- High-level API re-exports ----
#[cfg(not(feature = "cuda"))]
pub use device::cpu;
#[cfg(feature = "cuda")]
pub use device::cuda;
pub use device::auto_device;
pub use inference::{ChatModel, GenerationConfig};
pub use training::{ParameterBudget, Trainer, TrainingReport};

// ---- Core types (available through prelude) ----
pub use error::{Error, Result};
pub use model::{
    DefaultMultiscreenModel, EvaluationResult, ModelInferenceConfig, ModelTrainingConfig,
    ModelTrainingReport, MultiscreenModel, MultiscreenModelConfig, MultiscreenModelOutput,
    MultiscreenParameterBudget, cross_entropy_loss_with_mask,
};
pub use runtime::{DefaultAutodiffBackend, DefaultBackend, Device, device_label};

#[cfg(not(feature = "cuda"))]
pub use runtime::default_device;

#[cfg(feature = "cuda")]
pub use runtime::{CudaAutodiffBackend, CudaDevice, CudaMultiscreenModel};

// ---- Engine types (lightweight transition engine) ----
pub use config::{InferenceConfig, MultiscreenConfig, TrimConfig};
pub use engine::{InferenceOutput, MultiscreenEngine, TrainInput, TrainReport};
pub use layout::{
    ScreenLayout, TokenSpan, causal_softmask, causal_trim_relevance, trim_and_square,
};
pub use screen::{Screen, ScreenConfig};
pub use tile::{ScreeningGridConfig, Tile, TileConfig};

// ---- Burn re-exports ----
pub use burn::{
    tensor::backend::{AutodiffBackend, Backend},
    tensor::{Int, Tensor, TensorData},
};

#[cfg(feature = "cuda")]
pub use burn::backend::Cuda;

#[deprecated(note = "use MultiscreenConfig; the paper styles this as one word")]
pub type MultiScreenConfig = MultiscreenConfig;

#[deprecated(note = "use MultiscreenEngine; the paper styles this as one word")]
pub type MultiScreenEngine = MultiscreenEngine;

#[deprecated(note = "use ScreeningGridConfig for N_L x N_H naming")]
pub type GridConfig = ScreeningGridConfig;