Skip to main content

multiscreen_rs/
engine.rs

1use crate::config::MultiscreenConfig;
2use crate::error::{Error, Result};
3use crate::layout::{causal_trim_relevance, ScreenLayout};
4use serde::{Deserialize, Serialize};
5use std::collections::{BTreeMap, HashMap};
6use std::path::Path;
7
8// =====================================================================
9// Training types (from train.rs)
10// =====================================================================
11
/// User-provided training input.
///
/// The only supported form is a list of pre-tokenized sequences of `u32`
/// token IDs.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TrainInput {
    /// One inner `Vec<u32>` per training sequence. The engine observes each
    /// sequence separately, so next-token pairs never span two sequences.
    TokenSequences(Vec<Vec<u32>>),
}
17
18impl TrainInput {
19    pub fn from_token_sequences(token_sequences: Vec<Vec<u32>>) -> Self {
20        Self::TokenSequences(token_sequences)
21    }
22
23    #[deprecated(note = "use from_token_sequences for paper-aligned naming")]
24    pub fn from_token_ids(token_sequences: Vec<Vec<u32>>) -> Self {
25        Self::from_token_sequences(token_sequences)
26    }
27
28    pub fn is_empty(&self) -> bool {
29        match self {
30            Self::TokenSequences(token_sequences) => token_sequences.is_empty(),
31        }
32    }
33}
34
/// Summary returned after training (and after loading weights).
///
/// When deserializing, each field also accepts the name given in its
/// `serde(alias = ...)` attribute.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TrainReport {
    /// Number of training sequences observed.
    #[serde(alias = "sequence_count")]
    pub training_sequence_count: usize,
    /// Total number of tokens across all training sequences.
    #[serde(alias = "token_count")]
    pub training_token_count: usize,
    /// Number of distinct token IDs seen during training.
    #[serde(alias = "unique_tokens")]
    pub observed_vocab_size: usize,
    /// Sum of screen counts over every training sequence's layout.
    #[serde(alias = "screen_count")]
    pub screen_layout_count: usize,
    /// Sum of screening-tile counts over every training sequence's layout.
    #[serde(alias = "tile_count")]
    pub screening_tile_count: usize,
}
49
/// Learned screening state: per-token counts, next-token (bigram) counts, and
/// running totals used to build a [`TrainReport`].
///
/// This is the `state` section of a weights file. When deserializing, each
/// field also accepts the names given in its `serde(alias = ...)` attribute.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub(crate) struct ScreeningState {
    /// How many times each token ID occurred across all training sequences.
    #[serde(alias = "token_counts", alias = "training_observed_token_counts")]
    pub observed_token_counts: BTreeMap<u32, usize>,
    /// For each token, how many times each other token immediately followed it
    /// within one training sequence.
    // NOTE(review): `HashMap` iterates (and therefore serializes) its outer keys in
    // arbitrary order, so saving the same state twice can yield differently
    // ordered JSON. A `BTreeMap` would make weights files byte-reproducible —
    // confirm whether that matters to callers.
    #[serde(alias = "transitions")]
    next_token_counts: HashMap<u32, BTreeMap<u32, usize>>,
    /// Number of sequences folded in via `observe_token_sequence`.
    #[serde(alias = "sequence_count")]
    pub training_sequence_count: usize,
    /// Total number of tokens across all observed sequences.
    #[serde(alias = "token_count")]
    pub training_token_count: usize,
    /// Sum of `layout.screens().len()` over the observed sequences' layouts.
    #[serde(alias = "screen_count")]
    pub screen_layout_count: usize,
    /// Sum of `layout.screening_tile_count()` over the observed sequences' layouts.
    #[serde(alias = "tile_count")]
    pub screening_tile_count: usize,
}
65
66impl ScreeningState {
67    pub fn clear(&mut self) {
68        *self = Self::default();
69    }
70
71    pub fn observe_token_sequence(&mut self, tokens: &[u32], layout: &ScreenLayout) {
72        self.training_sequence_count += 1;
73        self.training_token_count += tokens.len();
74        self.screen_layout_count += layout.screens().len();
75        self.screening_tile_count += layout.screening_tile_count();
76
77        for token in tokens {
78            *self.observed_token_counts.entry(*token).or_insert(0) += 1;
79        }
80
81        for pair in tokens.windows(2) {
82            let from = pair[0];
83            let to = pair[1];
84            *self
85                .next_token_counts
86                .entry(from)
87                .or_default()
88                .entry(to)
89                .or_insert(0) += 1;
90        }
91    }
92
93    pub fn report(&self) -> TrainReport {
94        TrainReport {
95            training_sequence_count: self.training_sequence_count,
96            training_token_count: self.training_token_count,
97            observed_vocab_size: self.observed_token_counts.len(),
98            screen_layout_count: self.screen_layout_count,
99            screening_tile_count: self.screening_tile_count,
100        }
101    }
102
103    pub fn predict_next_token(&self, token: u32) -> Option<u32> {
104        self.next_token_counts.get(&token).and_then(|candidates| {
105            candidates
106                .iter()
107                .max_by(|left, right| {
108                    let count_order = left.1.cmp(right.1);
109                    if count_order == std::cmp::Ordering::Equal {
110                        right.0.cmp(left.0)
111                    } else {
112                        count_order
113                    }
114                })
115                .map(|(token, _count)| *token)
116        })
117    }
118
119    pub fn fallback_token(&self) -> Option<u32> {
120        self.observed_token_counts
121            .iter()
122            .max_by(|left, right| {
123                let count_order = left.1.cmp(right.1);
124                if count_order == std::cmp::Ordering::Equal {
125                    right.0.cmp(left.0)
126                } else {
127                    count_order
128                }
129            })
130            .map(|(token, _count)| *token)
131    }
132}
133
134// =====================================================================
135// Inference types (from inference.rs)
136// =====================================================================
137
/// Output from token inference.
#[derive(Clone, Debug, PartialEq)]
pub struct InferenceOutput {
    /// One predicted token per consumed input token. Only the first
    /// `config.inference.max_inference_tokens` inputs are consumed when that
    /// limit is set.
    pub output_token_ids: Vec<u32>,
    /// Screen layout built for the full input length.
    pub layout: ScreenLayout,
    /// Mean causal-trim relevance over every (query, key) pair in `layout`'s
    /// tiles; `0.0` when the layout contributes no pairs.
    pub mean_distance_relevance_alpha_d: f32,
}
145
/// Serialized weights file containing the config and learned state.
///
/// This is the on-disk (JSON) format. The config is embedded for two reasons:
/// `load_weights` can verify it matches the engine's active config, and
/// `from_weights_file` can adopt it directly.
#[derive(Serialize, Deserialize)]
struct ScreeningWeightsFile {
    /// Engine config at the time the file was saved.
    config: MultiscreenConfig,
    /// Learned counts and running totals.
    state: ScreeningState,
    /// Report snapshot written on save. Loading recomputes the report from
    /// `state`, so this field is informational only.
    report: TrainReport,
}
156
/// Main user-facing engine.
///
/// Pairs a validated `MultiscreenConfig` with the learned screening state.
#[derive(Clone, Debug)]
pub struct MultiscreenEngine {
    /// Active configuration; validated when the engine is constructed.
    config: MultiscreenConfig,
    /// Counts learned by `train` or restored by `load_weights` / `from_weights_file`.
    state: ScreeningState,
}
163
164impl MultiscreenEngine {
165    /// Creates a new engine with the given configuration.
166    pub fn new(config: MultiscreenConfig) -> Result<Self> {
167        config.validate()?;
168        Ok(Self {
169            config,
170            state: ScreeningState::default(),
171        })
172    }
173
174    /// Creates a new engine by loading a weights file.
175    ///
176    /// The config embedded in the weights file becomes the engine's config.
177    /// This is the easiest way to resume from a previously saved model.
178    pub fn from_weights_file(path: impl AsRef<Path>) -> Result<Self> {
179        let contents = std::fs::read_to_string(path.as_ref())
180            .map_err(|e| Error::Io(format!("{}: {e}", path.as_ref().display())))?;
181        let file: ScreeningWeightsFile =
182            serde_json::from_str(&contents).map_err(|e| Error::Serialization(e.to_string()))?;
183        file.config.validate()?;
184        Ok(Self {
185            config: file.config,
186            state: file.state,
187        })
188    }
189
190    /// Returns the active configuration.
191    pub fn config(&self) -> &MultiscreenConfig {
192        &self.config
193    }
194
195    /// Trains the lightweight transition state from token IDs.
196    pub fn train(&mut self, input: TrainInput) -> Result<TrainReport> {
197        let sequences = self.resolve_train_input(input)?;
198        self.state.clear();
199
200        for sequence in &sequences {
201            let layout = ScreenLayout::build(&self.config, sequence.len())?;
202            self.state.observe_token_sequence(sequence, &layout);
203        }
204
205        Ok(self.state.report())
206    }
207
208    /// Runs deterministic token inference.
209    pub fn infer_tokens(&self, input_ids: &[u32]) -> Result<InferenceOutput> {
210        let layout = ScreenLayout::build(&self.config, input_ids.len())?;
211        let limit = self
212            .config
213            .inference
214            .max_inference_tokens
215            .unwrap_or(input_ids.len())
216            .min(input_ids.len());
217
218        let fallback = self.state.fallback_token();
219        let output_token_ids = input_ids
220            .iter()
221            .take(limit)
222            .map(|token| {
223                self.state
224                    .predict_next_token(*token)
225                    .or(fallback)
226                    .or_else(|| {
227                        self.config
228                            .inference
229                            .use_input_token_fallback
230                            .then_some(*token)
231                    })
232                    .ok_or_else(|| {
233                        Error::Inference(
234                            "no trained transition, fallback token, or input fallback available"
235                                .into(),
236                        )
237                    })
238            })
239            .collect::<Result<Vec<_>>>()?;
240
241        Ok(InferenceOutput {
242            output_token_ids,
243            mean_distance_relevance_alpha_d: layout_relevance(&layout, &self.config),
244            layout,
245        })
246    }
247
248    // ------------------------------------------------------------------
249    // Weight persistence
250    // ------------------------------------------------------------------
251
252    /// Saves the engine's config, learned state, and training report to a JSON
253    /// weights file.
254    ///
255    /// The file embeds the full `MultiscreenConfig` so that `load_weights` can
256    /// verify a match before restoring state.
257    pub fn save_weights(&self, path: impl AsRef<Path>) -> Result<()> {
258        let file = ScreeningWeightsFile {
259            config: self.config.clone(),
260            state: self.state.clone(),
261            report: self.state.report(),
262        };
263        let json =
264            serde_json::to_string_pretty(&file).map_err(|e| Error::Serialization(e.to_string()))?;
265        std::fs::write(path, json).map_err(|e| Error::Io(e.to_string()))?;
266        Ok(())
267    }
268
269    /// Loads weights from a JSON file into this engine.
270    ///
271    /// **Config validation:** the config stored in the weights file is compared
272    /// against this engine's active config. If they do not match exactly, the
273    /// load is **rejected** and a [`Error::WeightsConfigMismatch`] is returned.
274    /// This prevents subtle bugs caused by running a state trained under one
275    /// configuration through an engine configured differently.
276    pub fn load_weights(&mut self, path: impl AsRef<Path>) -> Result<TrainReport> {
277        let contents = std::fs::read_to_string(path.as_ref())
278            .map_err(|e| Error::Io(format!("{}: {e}", path.as_ref().display())))?;
279        let file: ScreeningWeightsFile =
280            serde_json::from_str(&contents).map_err(|e| Error::Serialization(e.to_string()))?;
281
282        if file.config != self.config {
283            return Err(Error::WeightsConfigMismatch(
284                "the config embedded in the weights file does not match the engine's active \
285                 config — create a new engine with the correct config first, or use \
286                 MultiscreenEngine::from_weights_file() to auto-detect the config"
287                    .into(),
288            ));
289        }
290
291        self.state = file.state;
292        Ok(self.state.report())
293    }
294
295    // ------------------------------------------------------------------
296    // Internal helpers
297    // ------------------------------------------------------------------
298
299    fn resolve_train_input(&self, input: TrainInput) -> Result<Vec<Vec<u32>>> {
300        match input {
301            TrainInput::TokenSequences(sequences) => Ok(sequences),
302        }
303    }
304}
305
306fn layout_relevance(layout: &ScreenLayout, config: &MultiscreenConfig) -> f32 {
307    let mut alpha_d_sum = 0.0;
308    let mut relevance_pair_count = 0usize;
309
310    for screen in layout.screens() {
311        for tile in &screen.tiles {
312            if tile.span.is_empty() {
313                continue;
314            }
315            let query_index_i = tile.span.end - 1;
316            for key_index_j in tile.span.start..tile.span.end {
317                let distance = query_index_i.abs_diff(key_index_j) as f32;
318                let similarity_s_ij = 1.0 / (1.0 + distance);
319                alpha_d_sum += causal_trim_relevance(
320                    query_index_i,
321                    key_index_j,
322                    similarity_s_ij,
323                    &config.trim,
324                );
325                relevance_pair_count += 1;
326            }
327        }
328    }
329
330    if relevance_pair_count == 0 {
331        0.0
332    } else {
333        alpha_d_sum / relevance_pair_count as f32
334    }
335}