// candle_mi/util/positioning.rs

// SPDX-License-Identifier: MIT OR Apache-2.0

//! Model-agnostic character-based position handling.
//!
//! Provides universal position handling using character offsets instead of
//! model-specific token indices.  This enables:
//!
//! - **One corpus for all models**: No model-specific corpus files needed
//! - **Zero preprocessing**: Any new model works immediately
//! - **Guaranteed accuracy**: No offset heuristics, direct character mapping
//!
//! ## How It Works
//!
//! 1. Corpus stores character positions (byte offsets into the text)
//! 2. At runtime, tokenize with offset mapping
//! 3. Convert character positions to token indices using the offset map
/// Token with its character offset range.
///
/// The range is half-open: `start..end` in byte offsets into the source text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TokenWithOffset {
    /// The token string.
    pub token: String,
    /// Start character position (byte offset).
    pub start: usize,
    /// End character position (byte offset, exclusive).
    pub end: usize,
}
28
/// Encoding result with tokens and their character offsets.
///
/// Produced by a tokenizer's `encode_with_offsets` method (or equivalent).
/// Used to map between character positions in source text and token indices.
///
/// The three vectors are parallel: `ids[i]`, `tokens[i]`, and `offsets[i]`
/// all describe the same token.
///
/// # Example
///
/// ```
/// use candle_mi::EncodingWithOffsets;
///
/// let encoding = EncodingWithOffsets::new(
///     vec![1, 2, 3],
///     vec!["def".into(), " ".into(), "add".into()],
///     vec![(0, 3), (3, 4), (4, 7)],
/// );
///
/// // Character 4 ('a' in "add") is in token 2
/// assert_eq!(encoding.char_to_token(4), Some(2));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct EncodingWithOffsets {
    /// Token IDs.
    pub ids: Vec<u32>,
    /// Token strings.
    pub tokens: Vec<String>,
    /// Character offset for each token: `(start, end)`.
    pub offsets: Vec<(usize, usize)>,
}
57
58impl EncodingWithOffsets {
59    /// Create a new encoding with offsets.
60    #[must_use]
61    pub const fn new(ids: Vec<u32>, tokens: Vec<String>, offsets: Vec<(usize, usize)>) -> Self {
62        Self {
63            ids,
64            tokens,
65            offsets,
66        }
67    }
68
69    /// Get tokens with their character offsets.
70    #[must_use]
71    pub fn tokens_with_offsets(&self) -> Vec<TokenWithOffset> {
72        self.tokens
73            .iter()
74            .zip(self.offsets.iter())
75            .map(|(token, (start, end))| TokenWithOffset {
76                token: token.clone(),
77                start: *start,
78                end: *end,
79            })
80            .collect()
81    }
82
83    /// Find the token index that contains the given character position.
84    ///
85    /// Returns `None` if no token spans that position.
86    #[must_use]
87    pub fn char_to_token(&self, char_pos: usize) -> Option<usize> {
88        self.offsets
89            .iter()
90            .position(|(start, end)| char_pos >= *start && char_pos < *end)
91    }
92
93    /// Find the token index for a character position, with fuzzy fallback.
94    ///
95    /// If the exact position isn't contained in any token, returns the
96    /// index of the closest token by midpoint distance.
97    #[must_use]
98    pub fn char_to_token_fuzzy(&self, char_pos: usize) -> Option<usize> {
99        // Try exact match first.
100        if let Some(idx) = self.char_to_token(char_pos) {
101            return Some(idx);
102        }
103
104        // Find closest token by midpoint distance.
105        self.offsets
106            .iter()
107            .enumerate()
108            .min_by_key(|(_, (start, end))| {
109                let mid = usize::midpoint(*start, *end);
110                char_pos.abs_diff(mid)
111            })
112            .map(|(idx, _)| idx)
113    }
114
115    /// Find the token index that starts at or after the given character position.
116    #[must_use]
117    pub fn char_to_token_start(&self, char_pos: usize) -> Option<usize> {
118        self.offsets
119            .iter()
120            .position(|(start, _)| *start >= char_pos)
121    }
122
123    /// Find all token indices that overlap with the given character range.
124    #[must_use]
125    pub fn char_range_to_tokens(&self, start_char: usize, end_char: usize) -> Vec<usize> {
126        self.offsets
127            .iter()
128            .enumerate()
129            .filter_map(|(idx, (start, end))| {
130                if *end > start_char && *start < end_char {
131                    Some(idx)
132                } else {
133                    None
134                }
135            })
136            .collect()
137    }
138
139    /// Get the character range for a token index.
140    #[must_use]
141    pub fn token_to_char_range(&self, token_idx: usize) -> Option<(usize, usize)> {
142        self.offsets.get(token_idx).copied()
143    }
144
145    /// Number of tokens.
146    #[must_use]
147    pub const fn len(&self) -> usize {
148        self.tokens.len()
149    }
150
151    /// Whether the encoding is empty.
152    #[must_use]
153    pub const fn is_empty(&self) -> bool {
154        self.tokens.is_empty()
155    }
156}
157
/// Result of converting a character position to a token index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PositionConversion {
    /// Original character position.
    pub char_pos: usize,
    /// Converted token index (if found).
    pub token_idx: Option<usize>,
    /// The token at that position (if found).
    pub token: Option<String>,
    /// Whether this was an exact match (`true`) or a fuzzy fallback (`false`).
    pub exact_match: bool,
}
170
171/// Convert multiple character positions to token indices.
172#[must_use]
173pub fn convert_positions(
174    encoding: &EncodingWithOffsets,
175    char_positions: &[usize],
176) -> Vec<PositionConversion> {
177    char_positions
178        .iter()
179        .map(|&char_pos| {
180            let exact = encoding.char_to_token(char_pos);
181            let (token_idx, exact_match) = if exact.is_some() {
182                (exact, true)
183            } else {
184                (encoding.char_to_token_fuzzy(char_pos), false)
185            };
186
187            let token = token_idx.and_then(|idx| encoding.tokens.get(idx).cloned());
188
189            PositionConversion {
190                char_pos,
191                token_idx,
192                token,
193                exact_match,
194            }
195        })
196        .collect()
197}
198
/// Find the character position of a marker pattern in text.
///
/// Returns the byte offset of the first occurrence of `marker` in `text`,
/// or `None` if not found.
#[cfg(test)]
#[must_use]
fn find_marker_char_pos(text: &str, marker: &str) -> Option<usize> {
    // Equivalent to `text.find(marker)`: take the first match's byte index.
    text.match_indices(marker).next().map(|(pos, _)| pos)
}
208
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Encoding that simulates tokenization of `"def add(a, b):"`.
    fn sample_encoding() -> EncodingWithOffsets {
        let tokens: Vec<String> = ["def", " ", "add", "(", "a", ",", " ", "b"]
            .into_iter()
            .map(String::from)
            .collect();
        let offsets = vec![
            (0, 3),
            (3, 4),
            (4, 7),
            (7, 8),
            (8, 9),
            (9, 10),
            (10, 11),
            (11, 12),
        ];
        EncodingWithOffsets::new((1..=8).collect(), tokens, offsets)
    }

    #[test]
    fn char_to_token_exact() {
        let enc = sample_encoding();

        // "def" covers 0..3; "add" covers 4..7; parameter 'a' covers 8..9.
        assert_eq!(enc.char_to_token(0), Some(0));
        assert_eq!(enc.char_to_token(4), Some(2));
        assert_eq!(enc.char_to_token(8), Some(4));
        // A position past every token has no owner.
        assert_eq!(enc.char_to_token(100), None);
    }

    #[test]
    fn char_to_token_fuzzy_fallback() {
        // Position 12 lies beyond all tokens, so fuzzy picks the closest one.
        assert!(sample_encoding().char_to_token_fuzzy(12).is_some());
    }

    #[test]
    fn char_range_to_tokens_overlap() {
        // Range 3..7 touches " " at (3, 4) and "add" at (4, 7).
        assert_eq!(sample_encoding().char_range_to_tokens(3, 7), vec![1, 2]);
    }

    #[test]
    fn token_to_char_range_roundtrip() {
        let enc = sample_encoding();

        assert_eq!(enc.token_to_char_range(2), Some((4, 7))); // "add"
        assert_eq!(enc.token_to_char_range(100), None); // out of bounds
    }

    #[test]
    fn convert_positions_batch() {
        let enc = sample_encoding();
        let results = convert_positions(&enc, &[0, 4, 100]);
        assert_eq!(results.len(), 3);

        // The first two positions resolve exactly.
        assert!(results[0].exact_match);
        assert_eq!(results[0].token_idx, Some(0));
        assert!(results[1].exact_match);
        assert_eq!(results[1].token_idx, Some(2));
        // Position 100 only resolves via the fuzzy fallback.
        assert!(!results[2].exact_match);
    }

    #[test]
    fn find_marker() {
        let code = "def add(a, b):\n    \"\"\"\n    >>> add(2, 3)\n    5\n    \"\"\"";
        assert!(find_marker_char_pos(code, ">>>").is_some());
        assert!(find_marker_char_pos(code, "zzz").is_none());
    }

    #[test]
    fn encoding_len_and_empty() {
        let enc = sample_encoding();
        assert_eq!(enc.len(), 8);
        assert!(!enc.is_empty());

        let empty = EncodingWithOffsets::new(Vec::new(), Vec::new(), Vec::new());
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }
}
315}