// candle_mi/util/positioning.rs
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Model-agnostic character-based position handling.
//!
//! Provides universal position handling using character offsets instead of
//! model-specific token indices.  This enables:
//!
//! - **One corpus for all models**: No model-specific corpus files needed
//! - **Zero preprocessing**: Any new model works immediately
//! - **Guaranteed accuracy**: No offset heuristics, direct character mapping
//!
//! ## How It Works
//!
//! 1. Corpus stores character positions (byte offsets into the text)
//! 2. At runtime, tokenize with offset mapping
//! 3. Convert character positions to token indices using the offset map
/// Token paired with its character offset range.
#[derive(Debug, Clone)]
pub struct TokenWithOffset {
    /// The token string.
    pub token: String,
    /// Start character position (byte offset).
    pub start: usize,
    /// End character position (byte offset, exclusive).
    pub end: usize,
}

/// Encoding result with tokens and their character offsets.
///
/// Produced by a tokenizer's `encode_with_offsets` method (or equivalent).
/// Used to map between character positions in source text and token indices.
///
/// # Example
///
/// ```
/// use candle_mi::EncodingWithOffsets;
///
/// let encoding = EncodingWithOffsets::new(
///     vec![1, 2, 3],
///     vec!["def".into(), " ".into(), "add".into()],
///     vec![(0, 3), (3, 4), (4, 7)],
/// );
///
/// // Character 4 ('a' in "add") is in token 2
/// assert_eq!(encoding.char_to_token(4), Some(2));
/// ```
#[derive(Debug, Clone)]
pub struct EncodingWithOffsets {
    /// Token IDs.
    pub ids: Vec<u32>,
    /// Token strings.
    pub tokens: Vec<String>,
    /// Character offset for each token: `(start, end)`.
    pub offsets: Vec<(usize, usize)>,
}

impl EncodingWithOffsets {
    /// Create a new encoding with offsets.
    #[must_use]
    pub const fn new(ids: Vec<u32>, tokens: Vec<String>, offsets: Vec<(usize, usize)>) -> Self {
        Self { ids, tokens, offsets }
    }

    /// Get tokens together with their character offsets.
    #[must_use]
    pub fn tokens_with_offsets(&self) -> Vec<TokenWithOffset> {
        let mut paired = Vec::with_capacity(self.tokens.len());
        for (token, &(start, end)) in self.tokens.iter().zip(self.offsets.iter()) {
            paired.push(TokenWithOffset {
                token: token.clone(),
                start,
                end,
            });
        }
        paired
    }

    /// Find the token index that contains the given character position.
    ///
    /// Returns `None` if no token spans that position.
    #[must_use]
    pub fn char_to_token(&self, char_pos: usize) -> Option<usize> {
        // Offsets are half-open `[start, end)` ranges.
        self.offsets
            .iter()
            .position(|&(start, end)| (start..end).contains(&char_pos))
    }

    /// Find the token index for a character position, with fuzzy fallback.
    ///
    /// If the exact position isn't contained in any token, returns the
    /// index of the closest token by midpoint distance (ties favor the
    /// earliest token).
    #[must_use]
    pub fn char_to_token_fuzzy(&self, char_pos: usize) -> Option<usize> {
        self.char_to_token(char_pos).or_else(|| {
            // No exact hit: scan for the token whose midpoint is nearest.
            let mut best: Option<(usize, usize)> = None; // (distance, index)
            for (idx, &(start, end)) in self.offsets.iter().enumerate() {
                let distance = char_pos.abs_diff(usize::midpoint(start, end));
                // Strict `<` keeps the earliest token on ties, matching
                // `Iterator::min_by_key` semantics.
                if best.map_or(true, |(d, _)| distance < d) {
                    best = Some((distance, idx));
                }
            }
            best.map(|(_, idx)| idx)
        })
    }

    /// Find the token index that starts at or after the given character position.
    #[must_use]
    pub fn char_to_token_start(&self, char_pos: usize) -> Option<usize> {
        self.offsets
            .iter()
            .enumerate()
            .find_map(|(idx, &(start, _))| (start >= char_pos).then_some(idx))
    }

    /// Find all token indices that overlap with the given character range.
    #[must_use]
    pub fn char_range_to_tokens(&self, start_char: usize, end_char: usize) -> Vec<usize> {
        let mut hits = Vec::new();
        for (idx, &(start, end)) in self.offsets.iter().enumerate() {
            // Standard half-open interval overlap test.
            if end > start_char && start < end_char {
                hits.push(idx);
            }
        }
        hits
    }

    /// Get the character range for a token index.
    #[must_use]
    pub fn token_to_char_range(&self, token_idx: usize) -> Option<(usize, usize)> {
        Some(*self.offsets.get(token_idx)?)
    }

    /// Number of tokens.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.tokens.len()
    }

    /// Whether the encoding is empty.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Label each token by which named span it overlaps with.
    ///
    /// For each token, finds the first span (by input order) whose byte
    /// range overlaps the token's byte range. The last token matching
    /// each span label gets `"_final"` appended. Tokens matching no
    /// span receive `"other"`.
    ///
    /// # Example
    ///
    /// ```
    /// use candle_mi::EncodingWithOffsets;
    ///
    /// let enc = EncodingWithOffsets::new(
    ///     vec![1, 2, 3, 4],
    ///     vec!["The".into(), " Eiffel".into(), " Tower".into(), " is".into()],
    ///     vec![(0, 3), (3, 10), (10, 16), (16, 19)],
    /// );
    /// let labels = enc.label_spans(&[("subject", 0..16), ("relation", 16..19)]);
    /// assert_eq!(labels, vec!["subject", "subject", "subject_final", "relation_final"]);
    /// ```
    #[must_use]
    pub fn label_spans(&self, spans: &[(&str, std::ops::Range<usize>)]) -> Vec<String> {
        // Label a single token: first overlapping span wins, else "other".
        let label_for = |tok_start: usize, tok_end: usize| -> String {
            // Zero-width offsets (BOS / special tokens) match nothing.
            if tok_start != tok_end {
                for (name, range) in spans {
                    if tok_end > range.start && tok_start < range.end {
                        return (*name).to_string();
                    }
                }
            }
            String::from("other")
        };

        let mut labels: Vec<String> = self
            .offsets
            .iter()
            .map(|&(start, end)| label_for(start, end))
            .collect();

        // Upgrade the last occurrence of each span label to "{label}_final".
        for (name, _) in spans {
            if let Some(last_idx) = labels.iter().rposition(|l| l.as_str() == *name) {
                if let Some(slot) = labels.get_mut(last_idx) {
                    *slot = format!("{name}_final");
                }
            }
        }

        labels
    }
}
211
/// Result of converting a character position to a token index.
///
/// Produced by [`convert_positions`]; pairs the original request with the
/// resolved token (if any) and whether resolution needed the fuzzy fallback.
#[derive(Debug, Clone)]
pub struct PositionConversion {
    /// Original character position.
    pub char_pos: usize,
    /// Converted token index (if found).
    pub token_idx: Option<usize>,
    /// The token at that position (if found).
    pub token: Option<String>,
    /// Whether this was an exact match or fuzzy.
    pub exact_match: bool,
}
224
225/// Convert multiple character positions to token indices.
226#[must_use]
227pub fn convert_positions(
228    encoding: &EncodingWithOffsets,
229    char_positions: &[usize],
230) -> Vec<PositionConversion> {
231    char_positions
232        .iter()
233        .map(|&char_pos| {
234            let exact = encoding.char_to_token(char_pos);
235            let (token_idx, exact_match) = if exact.is_some() {
236                (exact, true)
237            } else {
238                (encoding.char_to_token_fuzzy(char_pos), false)
239            };
240
241            let token = token_idx.and_then(|idx| encoding.tokens.get(idx).cloned());
242
243            PositionConversion {
244                char_pos,
245                token_idx,
246                token,
247                exact_match,
248            }
249        })
250        .collect()
251}
252
/// Find the character position of a marker pattern in text.
///
/// Returns the byte offset of the first occurrence of `marker` in `text`,
/// or `None` if not found.
#[cfg(test)]
#[must_use]
fn find_marker_char_pos(text: &str, marker: &str) -> Option<usize> {
    // First match only; equivalent to `str::find`.
    text.match_indices(marker).map(|(pos, _)| pos).next()
}
262
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
266
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Shared fixture: a hand-built encoding of the first 12 bytes of
    /// `"def add(a, b):"` with contiguous, non-overlapping offsets.
    fn sample_encoding() -> EncodingWithOffsets {
        // Simulates tokenization of "def add(a, b):"
        EncodingWithOffsets::new(
            vec![1, 2, 3, 4, 5, 6, 7, 8],
            vec![
                "def".into(),
                " ".into(),
                "add".into(),
                "(".into(),
                "a".into(),
                ",".into(),
                " ".into(),
                "b".into(),
            ],
            vec![
                (0, 3),
                (3, 4),
                (4, 7),
                (7, 8),
                (8, 9),
                (9, 10),
                (10, 11),
                (11, 12),
            ],
        )
    }

    #[test]
    fn char_to_token_exact() {
        let encoding = sample_encoding();

        // 'd' at position 0 → token 0 ("def")
        assert_eq!(encoding.char_to_token(0), Some(0));
        // 'a' in "add" at position 4 → token 2
        assert_eq!(encoding.char_to_token(4), Some(2));
        // Parameter 'a' at position 8 → token 4
        assert_eq!(encoding.char_to_token(8), Some(4));
        // Beyond all tokens
        assert_eq!(encoding.char_to_token(100), None);
    }

    #[test]
    fn char_to_token_fuzzy_fallback() {
        let encoding = sample_encoding();

        // Position 12 is beyond all tokens → fuzzy finds closest
        // (by midpoint distance); we only require that *some* token is chosen.
        let result = encoding.char_to_token_fuzzy(12);
        assert!(result.is_some());
    }

    #[test]
    fn char_range_to_tokens_overlap() {
        let encoding = sample_encoding();

        // Characters 3..7 overlap tokens: " " (3,4), "add" (4,7)
        let tokens = encoding.char_range_to_tokens(3, 7);
        assert_eq!(tokens, vec![1, 2]);
    }

    #[test]
    fn token_to_char_range_roundtrip() {
        let encoding = sample_encoding();

        assert_eq!(encoding.token_to_char_range(2), Some((4, 7))); // "add"
        // Out-of-bounds index is a clean None, not a panic.
        assert_eq!(encoding.token_to_char_range(100), None);
    }

    #[test]
    fn convert_positions_batch() {
        let encoding = sample_encoding();
        let results = convert_positions(&encoding, &[0, 4, 100]);

        assert_eq!(results.len(), 3);
        assert!(results[0].exact_match);
        assert_eq!(results[0].token_idx, Some(0));
        assert!(results[1].exact_match);
        assert_eq!(results[1].token_idx, Some(2));
        assert!(!results[2].exact_match); // fuzzy fallback
    }

    #[test]
    fn find_marker() {
        // Docstring-style marker (">>>") should be found by byte offset.
        let code = "def add(a, b):\n    \"\"\"\n    >>> add(2, 3)\n    5\n    \"\"\"";
        assert!(find_marker_char_pos(code, ">>>").is_some());
        assert!(find_marker_char_pos(code, "zzz").is_none());
    }

    #[test]
    fn encoding_len_and_empty() {
        let encoding = sample_encoding();
        assert_eq!(encoding.len(), 8);
        assert!(!encoding.is_empty());

        // An all-empty encoding reports len 0 / is_empty true.
        let empty = EncodingWithOffsets::new(vec![], vec![], vec![]);
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }

    #[test]
    fn label_spans_subject_relation() {
        // "The Eiffel Tower is located in"
        let enc = EncodingWithOffsets::new(
            vec![1, 2, 3, 4, 5, 6, 7],
            vec![
                "The".into(),
                " Eiffel".into(),
                " Tower".into(),
                " is".into(),
                " located".into(),
                " in".into(),
                " Paris".into(),
            ],
            vec![
                (0, 3),
                (3, 10),
                (10, 16),
                (16, 19),
                (19, 27),
                (27, 30),
                (30, 36),
            ],
        );
        // Note: relation span 17..30 still overlaps token " is" (16,19)
        // because overlap is tested on half-open byte ranges.
        let labels = enc.label_spans(&[("subject", 0..16), ("relation", 17..30)]);
        assert_eq!(
            labels,
            vec![
                "subject",
                "subject",
                "subject_final",
                "relation",
                "relation",
                "relation_final",
                "other",
            ]
        );
    }

    #[test]
    fn label_spans_bos_token() {
        // BOS has offset (0, 0) — should be "other"
        let enc = EncodingWithOffsets::new(
            vec![0, 1, 2],
            vec!["<bos>".into(), "Hello".into(), " world".into()],
            vec![(0, 0), (0, 5), (5, 11)],
        );
        let labels = enc.label_spans(&[("greeting", 0..5)]);
        assert_eq!(labels, vec!["other", "greeting_final", "other"]);
    }

    #[test]
    fn label_spans_no_spans() {
        // With no spans, every token is labeled "other".
        let enc = EncodingWithOffsets::new(
            vec![1, 2],
            vec!["Hello".into(), " world".into()],
            vec![(0, 5), (5, 11)],
        );
        let labels = enc.label_spans(&[]);
        assert_eq!(labels, vec!["other", "other"]);
    }

    #[test]
    fn label_spans_first_span_wins() {
        // Two overlapping spans — first one wins
        let enc = EncodingWithOffsets::new(vec![1], vec!["overlap".into()], vec![(0, 7)]);
        let labels = enc.label_spans(&[("first", 0..5), ("second", 3..7)]);
        assert_eq!(labels, vec!["first_final"]);
    }
}
439}