1use crate::config::{MultiscreenConfig, TrimConfig};
2use crate::error::{Error, Result};
3use crate::screen::Screen;
4use crate::tile::Tile;
5
6#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
8pub struct TokenSpan {
9 pub start: usize,
10 pub end: usize,
11}
12
13impl TokenSpan {
14 pub fn new(start: usize, end: usize) -> Result<Self> {
15 if start > end {
16 return Err(Error::Layout("token span start must be <= end".into()));
17 }
18 Ok(Self { start, end })
19 }
20
21 pub fn len(self) -> usize {
22 self.end - self.start
23 }
24
25 pub fn is_empty(self) -> bool {
26 self.start == self.end
27 }
28
29 pub fn contains(self, index: usize) -> bool {
30 self.start <= index && index < self.end
31 }
32}
33
34#[derive(Clone, Debug, PartialEq, Eq)]
36pub struct ScreenLayout {
37 sequence_len: usize,
38 screens: Vec<Screen>,
39}
40
41impl ScreenLayout {
42 pub fn build(config: &MultiscreenConfig, sequence_len: usize) -> Result<Self> {
43 config.validate()?;
44
45 if sequence_len == 0 {
46 return Ok(Self {
47 sequence_len,
48 screens: Vec::new(),
49 });
50 }
51
52 let mut screens = Vec::new();
53 let mut screen_start = 0usize;
54
55 loop {
56 if let Some(max_screen_count) = config.screens.max_screen_count
57 && screens.len() >= max_screen_count
58 {
59 break;
60 }
61
62 let screen_end = screen_start
63 .saturating_add(config.screens.tokens_per_screen)
64 .min(sequence_len);
65 let screen_span = TokenSpan::new(screen_start, screen_end)?;
66 let screen_index = screens.len();
67 let tiles = build_tiles(config, screen_index, screen_span)?;
68 screens.push(Screen::new(screen_index, screen_span, tiles));
69
70 if screen_end == sequence_len {
71 break;
72 }
73 let Some(next_start) = next_window_start(
74 screen_start,
75 sequence_len,
76 config.screens.tokens_per_screen,
77 config.screens.screen_stride_tokens,
78 "screen",
79 )?
80 else {
81 break;
82 };
83 screen_start = next_start;
84 }
85
86 Ok(Self {
87 sequence_len,
88 screens,
89 })
90 }
91
92 pub fn sequence_len(&self) -> usize {
94 self.sequence_len
95 }
96
97 #[deprecated(note = "use sequence_len for paper-aligned naming")]
98 pub fn token_count(&self) -> usize {
99 self.sequence_len()
100 }
101
102 pub fn screens(&self) -> &[Screen] {
103 &self.screens
104 }
105
106 pub fn screening_tile_count(&self) -> usize {
107 self.screens.iter().map(|screen| screen.tiles.len()).sum()
108 }
109
110 pub fn screening_tiles(&self) -> impl Iterator<Item = &Tile> {
111 self.screens.iter().flat_map(|screen| screen.tiles.iter())
112 }
113
114 #[deprecated(note = "use screening_tiles for paper-aligned naming")]
115 pub fn tiles(&self) -> impl Iterator<Item = &Tile> {
116 self.screening_tiles()
117 }
118}
119
120fn build_tiles(
121 config: &MultiscreenConfig,
122 screen_index: usize,
123 screen_span: TokenSpan,
124) -> Result<Vec<Tile>> {
125 if screen_span.is_empty() {
126 return Ok(Vec::new());
127 }
128
129 let mut tiles = Vec::new();
130 let mut tile_start = screen_span.start;
131 let screening_grid = &config.tiles.screening_grid;
132
133 loop {
134 let tile_end = tile_start
135 .saturating_add(config.tiles.tokens_per_tile)
136 .min(screen_span.end);
137 let span = TokenSpan::new(tile_start, tile_end)?;
138 let tile_index = tiles.len();
139 let layer_index = (tile_index / screening_grid.head_count) % screening_grid.layer_count;
140 let head_index = tile_index % screening_grid.head_count;
141 tiles.push(Tile::new(
142 tile_index,
143 screen_index,
144 layer_index,
145 head_index,
146 span,
147 ));
148
149 if tile_end == screen_span.end {
150 break;
151 }
152 let Some(next_start) = next_window_start(
153 tile_start,
154 screen_span.end,
155 config.tiles.tokens_per_tile,
156 config.tiles.tile_stride_tokens,
157 "tile",
158 )?
159 else {
160 break;
161 };
162 tile_start = next_start;
163 }
164
165 Ok(tiles)
166}
167
168fn next_window_start(
169 current_start: usize,
170 sequence_end: usize,
171 window_len: usize,
172 stride: usize,
173 label: &str,
174) -> Result<Option<usize>> {
175 let next_start = current_start
176 .checked_add(stride)
177 .ok_or_else(|| Error::Layout(format!("{label} stride overflowed usize")))?;
178
179 if next_start < sequence_end {
180 return Ok(Some(next_start));
181 }
182
183 let tail_start = sequence_end.saturating_sub(window_len);
184 if tail_start > current_start {
185 Ok(Some(tail_start))
186 } else {
187 Ok(None)
188 }
189}
190
/// Maps a similarity to a squared, clamped acceptance score.
///
/// Evaluates `1 - r * (1 - s)` for similarity `s` and sharpness `r`, clamps
/// the result to `[0, 1]` and returns its square.
pub fn trim_and_square(similarity_s_ij: f32, acceptance_sharpness_r: f32) -> f32 {
    let shortfall = 1.0 - similarity_s_ij;
    let trimmed = (1.0 - acceptance_sharpness_r * shortfall).clamp(0.0, 1.0);
    trimmed.powi(2)
}
196
/// Causal raised-cosine mask over the distance between a query and a key.
///
/// Returns `0.0` when the window is not positive, when the key lies after the
/// query, or when the key is `screening_window_w` or more positions behind
/// the query. Inside the window the value is `0.5 * cos(pi * d / w) + 0.5`
/// with `d = key_index_j - query_index_i`, i.e. `1.0` at distance zero.
pub fn causal_softmask(query_index_i: usize, key_index_j: usize, screening_window_w: f32) -> f32 {
    if screening_window_w <= 0.0 || key_index_j > query_index_i {
        return 0.0;
    }

    // Subtract in integer space before converting. Casting each index to `f32`
    // first rounds indices at or above 2^24, so neighbouring positions in a
    // long sequence would collapse to distance 0. The guard above guarantees
    // `key_index_j <= query_index_i`, so the subtraction cannot underflow.
    let distance = -((query_index_i - key_index_j) as f32);
    if distance <= -screening_window_w {
        return 0.0;
    }

    ((distance / screening_window_w) * std::f32::consts::PI).cos() * 0.5 + 0.5
}
210
211pub fn causal_trim_relevance(
213 query_index_i: usize,
214 key_index_j: usize,
215 similarity_s_ij: f32,
216 trim: &TrimConfig,
217) -> f32 {
218 trim_and_square(similarity_s_ij, trim.acceptance_sharpness_r)
219 * causal_softmask(query_index_i, key_index_j, trim.screening_window_w)
220}