1use crate::config::MultiscreenConfig;
2use crate::error::{Error, Result};
3use crate::layout::{causal_trim_relevance, ScreenLayout};
4use serde::{Deserialize, Serialize};
5use std::collections::{BTreeMap, HashMap};
6use std::path::Path;
7
/// Training data handed to [`MultiscreenEngine::train`].
///
/// The only supported form is a batch of already-tokenized sequences.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrainInput {
    /// One inner `Vec<u32>` of token ids per training sequence.
    TokenSequences(Vec<Vec<u32>>),
}

impl TrainInput {
    /// Wraps a batch of token-id sequences as training input.
    pub fn from_token_sequences(token_sequences: Vec<Vec<u32>>) -> Self {
        TrainInput::TokenSequences(token_sequences)
    }

    /// Older spelling of [`TrainInput::from_token_sequences`]; produces the
    /// same value.
    #[deprecated(note = "use from_token_sequences for paper-aligned naming")]
    pub fn from_token_ids(token_sequences: Vec<Vec<u32>>) -> Self {
        TrainInput::TokenSequences(token_sequences)
    }

    /// Returns `true` when the batch contains no sequences at all.
    ///
    /// Only the outer collection is inspected: a batch consisting solely of
    /// empty sequences still counts as non-empty.
    pub fn is_empty(&self) -> bool {
        // Irrefutable today; adding a variant turns this into a compile error,
        // which forces the new case to be handled here.
        let TrainInput::TokenSequences(batch) = self;
        batch.is_empty()
    }
}
34
35#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
37pub struct TrainReport {
38 #[serde(alias = "sequence_count")]
39 pub training_sequence_count: usize,
40 #[serde(alias = "token_count")]
41 pub training_token_count: usize,
42 #[serde(alias = "unique_tokens")]
43 pub observed_vocab_size: usize,
44 #[serde(alias = "screen_count")]
45 pub screen_layout_count: usize,
46 #[serde(alias = "tile_count")]
47 pub screening_tile_count: usize,
48}
49
50#[derive(Clone, Debug, Default, Serialize, Deserialize)]
51pub(crate) struct ScreeningState {
52 #[serde(alias = "token_counts", alias = "training_observed_token_counts")]
53 pub observed_token_counts: BTreeMap<u32, usize>,
54 #[serde(alias = "transitions")]
55 next_token_counts: HashMap<u32, BTreeMap<u32, usize>>,
56 #[serde(alias = "sequence_count")]
57 pub training_sequence_count: usize,
58 #[serde(alias = "token_count")]
59 pub training_token_count: usize,
60 #[serde(alias = "screen_count")]
61 pub screen_layout_count: usize,
62 #[serde(alias = "tile_count")]
63 pub screening_tile_count: usize,
64}
65
66impl ScreeningState {
67 pub fn clear(&mut self) {
68 *self = Self::default();
69 }
70
71 pub fn observe_token_sequence(&mut self, tokens: &[u32], layout: &ScreenLayout) {
72 self.training_sequence_count += 1;
73 self.training_token_count += tokens.len();
74 self.screen_layout_count += layout.screens().len();
75 self.screening_tile_count += layout.screening_tile_count();
76
77 for token in tokens {
78 *self.observed_token_counts.entry(*token).or_insert(0) += 1;
79 }
80
81 for pair in tokens.windows(2) {
82 let from = pair[0];
83 let to = pair[1];
84 *self
85 .next_token_counts
86 .entry(from)
87 .or_default()
88 .entry(to)
89 .or_insert(0) += 1;
90 }
91 }
92
93 pub fn report(&self) -> TrainReport {
94 TrainReport {
95 training_sequence_count: self.training_sequence_count,
96 training_token_count: self.training_token_count,
97 observed_vocab_size: self.observed_token_counts.len(),
98 screen_layout_count: self.screen_layout_count,
99 screening_tile_count: self.screening_tile_count,
100 }
101 }
102
103 pub fn predict_next_token(&self, token: u32) -> Option<u32> {
104 self.next_token_counts.get(&token).and_then(|candidates| {
105 candidates
106 .iter()
107 .max_by(|left, right| {
108 let count_order = left.1.cmp(right.1);
109 if count_order == std::cmp::Ordering::Equal {
110 right.0.cmp(left.0)
111 } else {
112 count_order
113 }
114 })
115 .map(|(token, _count)| *token)
116 })
117 }
118
119 pub fn fallback_token(&self) -> Option<u32> {
120 self.observed_token_counts
121 .iter()
122 .max_by(|left, right| {
123 let count_order = left.1.cmp(right.1);
124 if count_order == std::cmp::Ordering::Equal {
125 right.0.cmp(left.0)
126 } else {
127 count_order
128 }
129 })
130 .map(|(token, _count)| *token)
131 }
132}
133
134#[derive(Clone, Debug, PartialEq)]
140pub struct InferenceOutput {
141 pub output_token_ids: Vec<u32>,
142 pub layout: ScreenLayout,
143 pub mean_distance_relevance_alpha_d: f32,
144}
145
146#[derive(Serialize, Deserialize)]
151struct ScreeningWeightsFile {
152 config: MultiscreenConfig,
153 state: ScreeningState,
154 report: TrainReport,
155}
156
157#[derive(Clone, Debug)]
159pub struct MultiscreenEngine {
160 config: MultiscreenConfig,
161 state: ScreeningState,
162}
163
164impl MultiscreenEngine {
165 pub fn new(config: MultiscreenConfig) -> Result<Self> {
167 config.validate()?;
168 Ok(Self {
169 config,
170 state: ScreeningState::default(),
171 })
172 }
173
174 pub fn from_weights_file(path: impl AsRef<Path>) -> Result<Self> {
179 let contents = std::fs::read_to_string(path.as_ref())
180 .map_err(|e| Error::Io(format!("{}: {e}", path.as_ref().display())))?;
181 let file: ScreeningWeightsFile =
182 serde_json::from_str(&contents).map_err(|e| Error::Serialization(e.to_string()))?;
183 file.config.validate()?;
184 Ok(Self {
185 config: file.config,
186 state: file.state,
187 })
188 }
189
190 pub fn config(&self) -> &MultiscreenConfig {
192 &self.config
193 }
194
195 pub fn train(&mut self, input: TrainInput) -> Result<TrainReport> {
197 let sequences = self.resolve_train_input(input)?;
198 self.state.clear();
199
200 for sequence in &sequences {
201 let layout = ScreenLayout::build(&self.config, sequence.len())?;
202 self.state.observe_token_sequence(sequence, &layout);
203 }
204
205 Ok(self.state.report())
206 }
207
208 pub fn infer_tokens(&self, input_ids: &[u32]) -> Result<InferenceOutput> {
210 let layout = ScreenLayout::build(&self.config, input_ids.len())?;
211 let limit = self
212 .config
213 .inference
214 .max_inference_tokens
215 .unwrap_or(input_ids.len())
216 .min(input_ids.len());
217
218 let fallback = self.state.fallback_token();
219 let output_token_ids = input_ids
220 .iter()
221 .take(limit)
222 .map(|token| {
223 self.state
224 .predict_next_token(*token)
225 .or(fallback)
226 .or_else(|| {
227 self.config
228 .inference
229 .use_input_token_fallback
230 .then_some(*token)
231 })
232 .ok_or_else(|| {
233 Error::Inference(
234 "no trained transition, fallback token, or input fallback available"
235 .into(),
236 )
237 })
238 })
239 .collect::<Result<Vec<_>>>()?;
240
241 Ok(InferenceOutput {
242 output_token_ids,
243 mean_distance_relevance_alpha_d: layout_relevance(&layout, &self.config),
244 layout,
245 })
246 }
247
248 pub fn save_weights(&self, path: impl AsRef<Path>) -> Result<()> {
258 let file = ScreeningWeightsFile {
259 config: self.config.clone(),
260 state: self.state.clone(),
261 report: self.state.report(),
262 };
263 let json =
264 serde_json::to_string_pretty(&file).map_err(|e| Error::Serialization(e.to_string()))?;
265 std::fs::write(path, json).map_err(|e| Error::Io(e.to_string()))?;
266 Ok(())
267 }
268
269 pub fn load_weights(&mut self, path: impl AsRef<Path>) -> Result<TrainReport> {
277 let contents = std::fs::read_to_string(path.as_ref())
278 .map_err(|e| Error::Io(format!("{}: {e}", path.as_ref().display())))?;
279 let file: ScreeningWeightsFile =
280 serde_json::from_str(&contents).map_err(|e| Error::Serialization(e.to_string()))?;
281
282 if file.config != self.config {
283 return Err(Error::WeightsConfigMismatch(
284 "the config embedded in the weights file does not match the engine's active \
285 config — create a new engine with the correct config first, or use \
286 MultiscreenEngine::from_weights_file() to auto-detect the config"
287 .into(),
288 ));
289 }
290
291 self.state = file.state;
292 Ok(self.state.report())
293 }
294
295 fn resolve_train_input(&self, input: TrainInput) -> Result<Vec<Vec<u32>>> {
300 match input {
301 TrainInput::TokenSequences(sequences) => Ok(sequences),
302 }
303 }
304}
305
306fn layout_relevance(layout: &ScreenLayout, config: &MultiscreenConfig) -> f32 {
307 let mut alpha_d_sum = 0.0;
308 let mut relevance_pair_count = 0usize;
309
310 for screen in layout.screens() {
311 for tile in &screen.tiles {
312 if tile.span.is_empty() {
313 continue;
314 }
315 let query_index_i = tile.span.end - 1;
316 for key_index_j in tile.span.start..tile.span.end {
317 let distance = query_index_i.abs_diff(key_index_j) as f32;
318 let similarity_s_ij = 1.0 / (1.0 + distance);
319 alpha_d_sum += causal_trim_relevance(
320 query_index_i,
321 key_index_j,
322 similarity_s_ij,
323 &config.trim,
324 );
325 relevance_pair_count += 1;
326 }
327 }
328 }
329
330 if relevance_pair_count == 0 {
331 0.0
332 } else {
333 alpha_d_sum / relevance_pair_count as f32
334 }
335}