Skip to main content

multiscreen_rs/
layout.rs

1use crate::config::{MultiscreenConfig, TrimConfig};
2use crate::error::{Error, Result};
3use crate::screen::Screen;
4use crate::tile::Tile;
5
/// Half-open token span `[start, end)`.
///
/// Use [`TokenSpan::new`] to get a checked span (it rejects `start > end`); because the
/// fields are public, a span written out by hand is not validated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TokenSpan {
    /// Index of the first token inside the span (inclusive).
    pub start: usize,
    /// Index one past the last token inside the span (exclusive).
    pub end: usize,
}
12
13impl TokenSpan {
14    pub fn new(start: usize, end: usize) -> Result<Self> {
15        if start > end {
16            return Err(Error::Layout("token span start must be <= end".into()));
17        }
18        Ok(Self { start, end })
19    }
20
21    pub fn len(self) -> usize {
22        self.end - self.start
23    }
24
25    pub fn is_empty(self) -> bool {
26        self.start == self.end
27    }
28
29    pub fn contains(self, index: usize) -> bool {
30        self.start <= index && index < self.end
31    }
32}
33
34/// Fully materialized screen and tile layout for a token sequence.
35#[derive(Clone, Debug, PartialEq, Eq)]
36pub struct ScreenLayout {
37    sequence_len: usize,
38    screens: Vec<Screen>,
39}
40
41impl ScreenLayout {
42    pub fn build(config: &MultiscreenConfig, sequence_len: usize) -> Result<Self> {
43        config.validate()?;
44
45        if sequence_len == 0 {
46            return Ok(Self {
47                sequence_len,
48                screens: Vec::new(),
49            });
50        }
51
52        let mut screens = Vec::new();
53        let mut screen_start = 0usize;
54
55        loop {
56            if let Some(max_screen_count) = config.screens.max_screen_count
57                && screens.len() >= max_screen_count
58            {
59                break;
60            }
61
62            let screen_end = screen_start
63                .saturating_add(config.screens.tokens_per_screen)
64                .min(sequence_len);
65            let screen_span = TokenSpan::new(screen_start, screen_end)?;
66            let screen_index = screens.len();
67            let tiles = build_tiles(config, screen_index, screen_span)?;
68            screens.push(Screen::new(screen_index, screen_span, tiles));
69
70            if screen_end == sequence_len {
71                break;
72            }
73            let Some(next_start) = next_window_start(
74                screen_start,
75                sequence_len,
76                config.screens.tokens_per_screen,
77                config.screens.screen_stride_tokens,
78                "screen",
79            )?
80            else {
81                break;
82            };
83            screen_start = next_start;
84        }
85
86        Ok(Self {
87            sequence_len,
88            screens,
89        })
90    }
91
92    /// Paper sequence length `T` for this layout.
93    pub fn sequence_len(&self) -> usize {
94        self.sequence_len
95    }
96
97    #[deprecated(note = "use sequence_len for paper-aligned naming")]
98    pub fn token_count(&self) -> usize {
99        self.sequence_len()
100    }
101
102    pub fn screens(&self) -> &[Screen] {
103        &self.screens
104    }
105
106    pub fn screening_tile_count(&self) -> usize {
107        self.screens.iter().map(|screen| screen.tiles.len()).sum()
108    }
109
110    pub fn screening_tiles(&self) -> impl Iterator<Item = &Tile> {
111        self.screens.iter().flat_map(|screen| screen.tiles.iter())
112    }
113
114    #[deprecated(note = "use screening_tiles for paper-aligned naming")]
115    pub fn tiles(&self) -> impl Iterator<Item = &Tile> {
116        self.screening_tiles()
117    }
118}
119
120fn build_tiles(
121    config: &MultiscreenConfig,
122    screen_index: usize,
123    screen_span: TokenSpan,
124) -> Result<Vec<Tile>> {
125    if screen_span.is_empty() {
126        return Ok(Vec::new());
127    }
128
129    let mut tiles = Vec::new();
130    let mut tile_start = screen_span.start;
131    let screening_grid = &config.tiles.screening_grid;
132
133    loop {
134        let tile_end = tile_start
135            .saturating_add(config.tiles.tokens_per_tile)
136            .min(screen_span.end);
137        let span = TokenSpan::new(tile_start, tile_end)?;
138        let tile_index = tiles.len();
139        let layer_index = (tile_index / screening_grid.head_count) % screening_grid.layer_count;
140        let head_index = tile_index % screening_grid.head_count;
141        tiles.push(Tile::new(
142            tile_index,
143            screen_index,
144            layer_index,
145            head_index,
146            span,
147        ));
148
149        if tile_end == screen_span.end {
150            break;
151        }
152        let Some(next_start) = next_window_start(
153            tile_start,
154            screen_span.end,
155            config.tiles.tokens_per_tile,
156            config.tiles.tile_stride_tokens,
157            "tile",
158        )?
159        else {
160            break;
161        };
162        tile_start = next_start;
163    }
164
165    Ok(tiles)
166}
167
168fn next_window_start(
169    current_start: usize,
170    sequence_end: usize,
171    window_len: usize,
172    stride: usize,
173    label: &str,
174) -> Result<Option<usize>> {
175    let next_start = current_start
176        .checked_add(stride)
177        .ok_or_else(|| Error::Layout(format!("{label} stride overflowed usize")))?;
178
179    if next_start < sequence_end {
180        return Ok(Some(next_start));
181    }
182
183    let tail_start = sequence_end.saturating_sub(window_len);
184    if tail_start > current_start {
185        Ok(Some(tail_start))
186    } else {
187        Ok(None)
188    }
189}
190
/// Trim-and-square relevance gate from the multiscreen experiment.
///
/// Computes `clamp(1 - r * (1 - s), 0, 1)^2`. A similarity of 1 always passes fully, and
/// larger `acceptance_sharpness_r` values drive the gate to 0 faster as similarity drops.
pub fn trim_and_square(similarity_s_ij: f32, acceptance_sharpness_r: f32) -> f32 {
    let dissimilarity = 1.0 - similarity_s_ij;
    let trimmed = (1.0 - acceptance_sharpness_r * dissimilarity).clamp(0.0, 1.0);
    trimmed.powi(2)
}
196
/// Causal softmask for query/key token positions.
///
/// Returns 0 for future keys (`key_index_j > query_index_i`) and for a non-positive
/// window. Otherwise the weight is a raised cosine of the lag `i - j`: 1 at lag 0,
/// falling smoothly to 0 at lag `screening_window_w`, and 0 for any lag at or beyond it.
pub fn causal_softmask(query_index_i: usize, key_index_j: usize, screening_window_w: f32) -> f32 {
    if screening_window_w <= 0.0 || key_index_j > query_index_i {
        return 0.0;
    }

    // Subtract in integer space: the causal check above guarantees `j <= i`, so this
    // cannot underflow, and the lag stays exact even when the absolute positions exceed
    // f32's 24-bit mantissa (casting each index to f32 first would round them).
    let lag = (query_index_i - key_index_j) as f32;
    if lag >= screening_window_w {
        return 0.0;
    }

    // cos is even, so cos(pi * lag / w) equals cos(pi * (j - i) / w).
    ((lag / screening_window_w) * std::f32::consts::PI).cos() * 0.5 + 0.5
}
210
211/// Combined causal trim relevance for a query/key pair.
212pub fn causal_trim_relevance(
213    query_index_i: usize,
214    key_index_j: usize,
215    similarity_s_ij: f32,
216    trim: &TrimConfig,
217) -> f32 {
218    trim_and_square(similarity_s_ij, trim.acceptance_sharpness_r)
219        * causal_softmask(query_index_i, key_index_j, trim.screening_window_w)
220}