/// A single token paired with its half-open `[start, end)` span in the
/// source text.
///
/// Spans are presumably character positions (the query methods on
/// `EncodingWithOffsets` call them `char_pos`), but byte vs. char is
/// indistinguishable from the ASCII-only tests — confirm with the tokenizer.
#[derive(Debug, Clone)]
pub struct TokenWithOffset {
    /// The token's text.
    pub token: String,
    /// Inclusive start position of the token in the source text.
    pub start: usize,
    /// Exclusive end position (one past the token's last position).
    pub end: usize,
}
28
/// A tokenized sequence with parallel per-token data.
///
/// The three vectors are index-aligned: `ids[i]`, `tokens[i]`, and
/// `offsets[i]` all describe token `i` (the accessors zip them pairwise).
/// Offsets are half-open `(start, end)` positions into the source text.
#[derive(Debug, Clone)]
pub struct EncodingWithOffsets {
    /// Numeric token ids, aligned with `tokens` and `offsets`.
    pub ids: Vec<u32>,
    /// Token strings, aligned with `ids` and `offsets`.
    pub tokens: Vec<String>,
    /// Half-open `(start, end)` spans in the source text, one per token.
    /// A zero-width span (`start == end`) marks a special token (e.g. BOS).
    pub offsets: Vec<(usize, usize)>,
}
57
58impl EncodingWithOffsets {
59 #[must_use]
61 pub const fn new(ids: Vec<u32>, tokens: Vec<String>, offsets: Vec<(usize, usize)>) -> Self {
62 Self {
63 ids,
64 tokens,
65 offsets,
66 }
67 }
68
69 #[must_use]
71 pub fn tokens_with_offsets(&self) -> Vec<TokenWithOffset> {
72 self.tokens
73 .iter()
74 .zip(self.offsets.iter())
75 .map(|(token, (start, end))| TokenWithOffset {
76 token: token.clone(),
77 start: *start,
78 end: *end,
79 })
80 .collect()
81 }
82
83 #[must_use]
87 pub fn char_to_token(&self, char_pos: usize) -> Option<usize> {
88 self.offsets
89 .iter()
90 .position(|(start, end)| char_pos >= *start && char_pos < *end)
91 }
92
93 #[must_use]
98 pub fn char_to_token_fuzzy(&self, char_pos: usize) -> Option<usize> {
99 if let Some(idx) = self.char_to_token(char_pos) {
101 return Some(idx);
102 }
103
104 self.offsets
106 .iter()
107 .enumerate()
108 .min_by_key(|(_, (start, end))| {
109 let mid = usize::midpoint(*start, *end);
110 char_pos.abs_diff(mid)
111 })
112 .map(|(idx, _)| idx)
113 }
114
115 #[must_use]
117 pub fn char_to_token_start(&self, char_pos: usize) -> Option<usize> {
118 self.offsets
119 .iter()
120 .position(|(start, _)| *start >= char_pos)
121 }
122
123 #[must_use]
125 pub fn char_range_to_tokens(&self, start_char: usize, end_char: usize) -> Vec<usize> {
126 self.offsets
127 .iter()
128 .enumerate()
129 .filter_map(|(idx, (start, end))| {
130 if *end > start_char && *start < end_char {
131 Some(idx)
132 } else {
133 None
134 }
135 })
136 .collect()
137 }
138
139 #[must_use]
141 pub fn token_to_char_range(&self, token_idx: usize) -> Option<(usize, usize)> {
142 self.offsets.get(token_idx).copied()
143 }
144
145 #[must_use]
147 pub const fn len(&self) -> usize {
148 self.tokens.len()
149 }
150
151 #[must_use]
153 pub const fn is_empty(&self) -> bool {
154 self.tokens.is_empty()
155 }
156
157 #[must_use]
178 pub fn label_spans(&self, spans: &[(&str, std::ops::Range<usize>)]) -> Vec<String> {
179 let mut labels: Vec<String> = self
180 .offsets
181 .iter()
182 .map(|&(tok_start, tok_end)| {
183 if tok_start == tok_end {
185 return String::from("other");
186 }
187 for (name, range) in spans {
189 if tok_end > range.start && tok_start < range.end {
190 return String::from(*name);
191 }
192 }
193 String::from("other")
194 })
195 .collect();
196
197 for (name, _) in spans {
199 if let Some(last_idx) = labels.iter().rposition(|l| l.as_str() == *name) {
200 #[allow(clippy::indexing_slicing)]
202 {
203 labels[last_idx] = format!("{name}_final");
204 }
205 }
206 }
207
208 labels
209 }
210}
211
/// The result of mapping a single character position onto a token.
#[derive(Debug, Clone)]
pub struct PositionConversion {
    /// The character position that was queried.
    pub char_pos: usize,
    /// Index of the matched token, if any (exact or fuzzy).
    pub token_idx: Option<usize>,
    /// The matched token's text, when `token_idx` resolved to a valid index.
    pub token: Option<String>,
    /// `true` when some token's span contained `char_pos` exactly;
    /// `false` when the fuzzy nearest-midpoint fallback was used.
    pub exact_match: bool,
}
224
225#[must_use]
227pub fn convert_positions(
228 encoding: &EncodingWithOffsets,
229 char_positions: &[usize],
230) -> Vec<PositionConversion> {
231 char_positions
232 .iter()
233 .map(|&char_pos| {
234 let exact = encoding.char_to_token(char_pos);
235 let (token_idx, exact_match) = if exact.is_some() {
236 (exact, true)
237 } else {
238 (encoding.char_to_token_fuzzy(char_pos), false)
239 };
240
241 let token = token_idx.and_then(|idx| encoding.tokens.get(idx).cloned());
242
243 PositionConversion {
244 char_pos,
245 token_idx,
246 token,
247 exact_match,
248 }
249 })
250 .collect()
251}
252
#[cfg(test)]
#[must_use]
/// Position of the first occurrence of `marker` in `text`, if any.
fn find_marker_char_pos(text: &str, marker: &str) -> Option<usize> {
    text.match_indices(marker).next().map(|(pos, _)| pos)
}
262
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Eight tokens covering the text `def add(a, b` with contiguous spans.
    fn sample_encoding() -> EncodingWithOffsets {
        EncodingWithOffsets::new(
            vec![1, 2, 3, 4, 5, 6, 7, 8],
            vec![
                "def".into(),
                " ".into(),
                "add".into(),
                "(".into(),
                "a".into(),
                ",".into(),
                " ".into(),
                "b".into(),
            ],
            vec![
                (0, 3),
                (3, 4),
                (4, 7),
                (7, 8),
                (8, 9),
                (9, 10),
                (10, 11),
                (11, 12),
            ],
        )
    }

    #[test]
    fn char_to_token_exact() {
        let encoding = sample_encoding();

        assert_eq!(encoding.char_to_token(0), Some(0));
        assert_eq!(encoding.char_to_token(4), Some(2));
        assert_eq!(encoding.char_to_token(8), Some(4));
        assert_eq!(encoding.char_to_token(100), None);
    }

    #[test]
    fn char_to_token_fuzzy_fallback() {
        let encoding = sample_encoding();

        // 12 is just past the final span (11, 12); fuzzy must still resolve.
        let result = encoding.char_to_token_fuzzy(12);
        assert!(result.is_some());
    }

    #[test]
    fn char_range_to_tokens_overlap() {
        let encoding = sample_encoding();

        let tokens = encoding.char_range_to_tokens(3, 7);
        assert_eq!(tokens, vec![1, 2]);
    }

    #[test]
    fn token_to_char_range_roundtrip() {
        let encoding = sample_encoding();

        assert_eq!(encoding.token_to_char_range(2), Some((4, 7)));
        assert_eq!(encoding.token_to_char_range(100), None);
    }

    #[test]
    fn convert_positions_batch() {
        let encoding = sample_encoding();
        let results = convert_positions(&encoding, &[0, 4, 100]);

        assert_eq!(results.len(), 3);
        assert!(results[0].exact_match);
        assert_eq!(results[0].token_idx, Some(0));
        assert!(results[1].exact_match);
        assert_eq!(results[1].token_idx, Some(2));
        assert!(!results[2].exact_match);
    }

    #[test]
    fn find_marker() {
        let code = "def add(a, b):\n \"\"\"\n >>> add(2, 3)\n 5\n \"\"\"";
        assert!(find_marker_char_pos(code, ">>>").is_some());
        assert!(find_marker_char_pos(code, "zzz").is_none());
    }

    #[test]
    fn encoding_len_and_empty() {
        let encoding = sample_encoding();
        assert_eq!(encoding.len(), 8);
        assert!(!encoding.is_empty());

        let empty = EncodingWithOffsets::new(vec![], vec![], vec![]);
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }

    #[test]
    fn label_spans_subject_relation() {
        let enc = EncodingWithOffsets::new(
            vec![1, 2, 3, 4, 5, 6, 7],
            vec![
                "The".into(),
                " Eiffel".into(),
                " Tower".into(),
                " is".into(),
                " located".into(),
                " in".into(),
                " Paris".into(),
            ],
            vec![
                (0, 3),
                (3, 10),
                (10, 16),
                (16, 19),
                (19, 27),
                (27, 30),
                (30, 36),
            ],
        );
        let labels = enc.label_spans(&[("subject", 0..16), ("relation", 17..30)]);
        assert_eq!(
            labels,
            vec![
                "subject",
                "subject",
                "subject_final",
                "relation",
                "relation",
                "relation_final",
                "other",
            ]
        );
    }

    #[test]
    fn label_spans_bos_token() {
        let enc = EncodingWithOffsets::new(
            vec![0, 1, 2],
            vec!["<bos>".into(), "Hello".into(), " world".into()],
            vec![(0, 0), (0, 5), (5, 11)],
        );
        let labels = enc.label_spans(&[("greeting", 0..5)]);
        // The zero-width BOS token must stay "other" even though it sits at 0.
        assert_eq!(labels, vec!["other", "greeting_final", "other"]);
    }

    #[test]
    fn label_spans_no_spans() {
        let enc = EncodingWithOffsets::new(
            vec![1, 2],
            vec!["Hello".into(), " world".into()],
            vec![(0, 5), (5, 11)],
        );
        let labels = enc.label_spans(&[]);
        assert_eq!(labels, vec!["other", "other"]);
    }

    #[test]
    fn label_spans_first_span_wins() {
        let enc = EncodingWithOffsets::new(vec![1], vec!["overlap".into()], vec![(0, 7)]);
        let labels = enc.label_spans(&[("first", 0..5), ("second", 3..7)]);
        assert_eq!(labels, vec!["first_final"]);
    }
}
439}