/// A single token paired with its half-open position span in the source text.
///
/// NOTE(review): whether `start`/`end` are byte or char offsets depends on the
/// tokenizer that produced them — the code in this file only compares them.
#[derive(Debug, Clone)]
pub struct TokenWithOffset {
    // The token's text.
    pub token: String,
    // Inclusive start position of the token in the original text.
    pub start: usize,
    // Exclusive end position of the token in the original text.
    pub end: usize,
}
28
/// A tokenizer encoding: token ids, token strings, and per-token spans,
/// stored as three index-aligned vectors (one entry per token).
///
/// NOTE(review): the alignment invariant is not checked by `new`; `len()` is
/// defined by `tokens.len()`.
#[derive(Debug, Clone)]
pub struct EncodingWithOffsets {
    // Numeric id of each token.
    pub ids: Vec<u32>,
    // String form of each token.
    pub tokens: Vec<String>,
    // Half-open `(start, end)` span of each token in the original text.
    pub offsets: Vec<(usize, usize)>,
}
57
58impl EncodingWithOffsets {
59 #[must_use]
61 pub const fn new(ids: Vec<u32>, tokens: Vec<String>, offsets: Vec<(usize, usize)>) -> Self {
62 Self {
63 ids,
64 tokens,
65 offsets,
66 }
67 }
68
69 #[must_use]
71 pub fn tokens_with_offsets(&self) -> Vec<TokenWithOffset> {
72 self.tokens
73 .iter()
74 .zip(self.offsets.iter())
75 .map(|(token, (start, end))| TokenWithOffset {
76 token: token.clone(),
77 start: *start,
78 end: *end,
79 })
80 .collect()
81 }
82
83 #[must_use]
87 pub fn char_to_token(&self, char_pos: usize) -> Option<usize> {
88 self.offsets
89 .iter()
90 .position(|(start, end)| char_pos >= *start && char_pos < *end)
91 }
92
93 #[must_use]
98 pub fn char_to_token_fuzzy(&self, char_pos: usize) -> Option<usize> {
99 if let Some(idx) = self.char_to_token(char_pos) {
101 return Some(idx);
102 }
103
104 self.offsets
106 .iter()
107 .enumerate()
108 .min_by_key(|(_, (start, end))| {
109 let mid = usize::midpoint(*start, *end);
110 char_pos.abs_diff(mid)
111 })
112 .map(|(idx, _)| idx)
113 }
114
115 #[must_use]
117 pub fn char_to_token_start(&self, char_pos: usize) -> Option<usize> {
118 self.offsets
119 .iter()
120 .position(|(start, _)| *start >= char_pos)
121 }
122
123 #[must_use]
125 pub fn char_range_to_tokens(&self, start_char: usize, end_char: usize) -> Vec<usize> {
126 self.offsets
127 .iter()
128 .enumerate()
129 .filter_map(|(idx, (start, end))| {
130 if *end > start_char && *start < end_char {
131 Some(idx)
132 } else {
133 None
134 }
135 })
136 .collect()
137 }
138
139 #[must_use]
141 pub fn token_to_char_range(&self, token_idx: usize) -> Option<(usize, usize)> {
142 self.offsets.get(token_idx).copied()
143 }
144
145 #[must_use]
147 pub const fn len(&self) -> usize {
148 self.tokens.len()
149 }
150
151 #[must_use]
153 pub const fn is_empty(&self) -> bool {
154 self.tokens.is_empty()
155 }
156}
157
/// Result of resolving one character position to a token (see
/// `convert_positions`).
#[derive(Debug, Clone)]
pub struct PositionConversion {
    // The queried position in the original text.
    pub char_pos: usize,
    // Index of the resolved token, if any token could be matched.
    pub token_idx: Option<usize>,
    // Text of the resolved token, if any.
    pub token: Option<String>,
    // `true` when the position fell inside a token's span; `false` when the
    // fuzzy nearest-midpoint fallback was used (even if it found a token).
    pub exact_match: bool,
}
170
171#[must_use]
173pub fn convert_positions(
174 encoding: &EncodingWithOffsets,
175 char_positions: &[usize],
176) -> Vec<PositionConversion> {
177 char_positions
178 .iter()
179 .map(|&char_pos| {
180 let exact = encoding.char_to_token(char_pos);
181 let (token_idx, exact_match) = if exact.is_some() {
182 (exact, true)
183 } else {
184 (encoding.char_to_token_fuzzy(char_pos), false)
185 };
186
187 let token = token_idx.and_then(|idx| encoding.tokens.get(idx).cloned());
188
189 PositionConversion {
190 char_pos,
191 token_idx,
192 token,
193 exact_match,
194 }
195 })
196 .collect()
197}
198
/// Byte offset of the first occurrence of `marker` in `text`, if any.
///
/// Test-only helper; equivalent to `str::find` for string patterns.
#[cfg(test)]
#[must_use]
fn find_marker_char_pos(text: &str, marker: &str) -> Option<usize> {
    text.match_indices(marker).next().map(|(pos, _)| pos)
}
208
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Encoding for `def add(a, b` split into eight tokens with contiguous
    /// half-open spans covering positions 0..12.
    fn sample_encoding() -> EncodingWithOffsets {
        let tokens: Vec<String> = ["def", " ", "add", "(", "a", ",", " ", "b"]
            .iter()
            .map(|s| (*s).to_string())
            .collect();
        let offsets = vec![
            (0, 3),
            (3, 4),
            (4, 7),
            (7, 8),
            (8, 9),
            (9, 10),
            (10, 11),
            (11, 12),
        ];
        EncodingWithOffsets::new(vec![1, 2, 3, 4, 5, 6, 7, 8], tokens, offsets)
    }

    #[test]
    fn char_to_token_exact() {
        let encoding = sample_encoding();

        assert_eq!(encoding.char_to_token(0), Some(0));
        assert_eq!(encoding.char_to_token(4), Some(2));
        assert_eq!(encoding.char_to_token(8), Some(4));
        // Far past the last span: no exact match.
        assert_eq!(encoding.char_to_token(100), None);
    }

    #[test]
    fn char_to_token_fuzzy_fallback() {
        let encoding = sample_encoding();

        // 12 is just past the last span, so only the fuzzy path can resolve it.
        assert!(encoding.char_to_token_fuzzy(12).is_some());
    }

    #[test]
    fn char_range_to_tokens_overlap() {
        let encoding = sample_encoding();

        // [3, 7) overlaps the " " token (3,4) and the "add" token (4,7).
        assert_eq!(encoding.char_range_to_tokens(3, 7), vec![1, 2]);
    }

    #[test]
    fn token_to_char_range_roundtrip() {
        let encoding = sample_encoding();

        assert_eq!(encoding.token_to_char_range(2), Some((4, 7)));
        assert_eq!(encoding.token_to_char_range(100), None);
    }

    #[test]
    fn convert_positions_batch() {
        let encoding = sample_encoding();
        let results = convert_positions(&encoding, &[0, 4, 100]);

        assert_eq!(results.len(), 3);
        assert!(results[0].exact_match);
        assert_eq!(results[0].token_idx, Some(0));
        assert!(results[1].exact_match);
        assert_eq!(results[1].token_idx, Some(2));
        // 100 is outside every span: resolved fuzzily, flagged inexact.
        assert!(!results[2].exact_match);
    }

    #[test]
    fn find_marker() {
        let code = "def add(a, b):\n \"\"\"\n >>> add(2, 3)\n 5\n \"\"\"";
        assert!(find_marker_char_pos(code, ">>>").is_some());
        assert!(find_marker_char_pos(code, "zzz").is_none());
    }

    #[test]
    fn encoding_len_and_empty() {
        let encoding = sample_encoding();
        assert_eq!(encoding.len(), 8);
        assert!(!encoding.is_empty());

        let empty = EncodingWithOffsets::new(Vec::new(), Vec::new(), Vec::new());
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }
}