/// A single token paired with its half-open position range `[start, end)`
/// in the original text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenWithOffset {
    /// The token's string content.
    pub token: String,
    /// Inclusive start position of the token in the source text.
    pub start: usize,
    /// Exclusive end position of the token in the source text.
    pub end: usize,
}
/// A tokenizer encoding: token ids, token strings, and per-token
/// `(start, end)` position ranges, stored as parallel vectors.
///
/// All three vectors are expected to have the same length (one entry per
/// token); nothing in this type enforces that invariant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncodingWithOffsets {
    /// Numeric token ids.
    pub ids: Vec<u32>,
    /// Token strings, parallel to `ids`.
    pub tokens: Vec<String>,
    /// Half-open `[start, end)` ranges, parallel to `tokens`.
    pub offsets: Vec<(usize, usize)>,
}
impl EncodingWithOffsets {
    /// Creates an encoding from three parallel vectors.
    ///
    /// The vectors should have equal lengths (one id, string, and offset per
    /// token); this constructor does not check that.
    #[must_use]
    pub const fn new(ids: Vec<u32>, tokens: Vec<String>, offsets: Vec<(usize, usize)>) -> Self {
        Self { ids, tokens, offsets }
    }

    /// Bundles every token string with its offset as a [`TokenWithOffset`].
    #[must_use]
    pub fn tokens_with_offsets(&self) -> Vec<TokenWithOffset> {
        self.tokens
            .iter()
            .zip(&self.offsets)
            .map(|(tok, &(lo, hi))| TokenWithOffset {
                token: tok.clone(),
                start: lo,
                end: hi,
            })
            .collect()
    }

    /// Index of the token whose half-open range `[start, end)` contains
    /// `char_pos`, or `None` when no token covers that position.
    #[must_use]
    pub fn char_to_token(&self, char_pos: usize) -> Option<usize> {
        self.offsets
            .iter()
            .position(|&(lo, hi)| (lo..hi).contains(&char_pos))
    }

    /// Like [`Self::char_to_token`], but when no token covers `char_pos`,
    /// falls back to the token whose range midpoint is closest to it (ties
    /// go to the earliest such token). Returns `None` only when `offsets`
    /// is empty.
    #[must_use]
    pub fn char_to_token_fuzzy(&self, char_pos: usize) -> Option<usize> {
        match self.char_to_token(char_pos) {
            Some(exact) => Some(exact),
            None => self
                .offsets
                .iter()
                .enumerate()
                .min_by_key(|&(_, &(lo, hi))| char_pos.abs_diff(usize::midpoint(lo, hi)))
                .map(|(idx, _)| idx),
        }
    }

    /// Index of the first token that starts at or after `char_pos`.
    #[must_use]
    pub fn char_to_token_start(&self, char_pos: usize) -> Option<usize> {
        self.offsets.iter().position(|&(lo, _)| lo >= char_pos)
    }

    /// Indices of every token whose range overlaps `[start_char, end_char)`.
    #[must_use]
    pub fn char_range_to_tokens(&self, start_char: usize, end_char: usize) -> Vec<usize> {
        self.offsets
            .iter()
            .enumerate()
            .filter(|&(_, &(lo, hi))| hi > start_char && lo < end_char)
            .map(|(idx, _)| idx)
            .collect()
    }

    /// The `(start, end)` range of the token at `token_idx`, if any.
    #[must_use]
    pub fn token_to_char_range(&self, token_idx: usize) -> Option<(usize, usize)> {
        self.offsets.get(token_idx).copied()
    }

    /// Number of tokens in this encoding.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.tokens.len()
    }

    /// `true` when this encoding holds no tokens.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
}
/// Result of mapping one character position to a token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PositionConversion {
    /// The queried character position.
    pub char_pos: usize,
    /// The matched token index, if any (exact or fuzzy).
    pub token_idx: Option<usize>,
    /// The matched token's string, if a token index was found.
    pub token: Option<String>,
    /// `true` when a token's range contained `char_pos` exactly;
    /// `false` when the match came from the fuzzy fallback (or failed).
    pub exact_match: bool,
}
/// Maps each character position in `char_positions` to a token of `encoding`.
///
/// Each position is first resolved exactly via `char_to_token`; when that
/// fails, the fuzzy (nearest-midpoint) lookup is tried and the result is
/// flagged with `exact_match = false`.
#[must_use]
pub fn convert_positions(
    encoding: &EncodingWithOffsets,
    char_positions: &[usize],
) -> Vec<PositionConversion> {
    let mut conversions = Vec::with_capacity(char_positions.len());
    for &char_pos in char_positions {
        let (token_idx, exact_match) = match encoding.char_to_token(char_pos) {
            Some(idx) => (Some(idx), true),
            None => (encoding.char_to_token_fuzzy(char_pos), false),
        };
        let token = token_idx.and_then(|idx| encoding.tokens.get(idx).cloned());
        conversions.push(PositionConversion {
            char_pos,
            token_idx,
            token,
            exact_match,
        });
    }
    conversions
}
/// Position of the first occurrence of `marker` within `text`, if any.
///
/// NOTE(review): this is a byte offset (as `str` search positions are); it
/// equals a character position only for ASCII text — confirm callers only
/// pass ASCII markers/text.
#[cfg(test)]
#[must_use]
fn find_marker_char_pos(text: &str, marker: &str) -> Option<usize> {
    text.match_indices(marker).next().map(|(pos, _)| pos)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Encoding of `"def add(a, b"`: eight tokens with contiguous offsets
    /// covering positions 0..12.
    fn sample_encoding() -> EncodingWithOffsets {
        EncodingWithOffsets::new(
            vec![1, 2, 3, 4, 5, 6, 7, 8],
            vec![
                "def".into(),
                " ".into(),
                "add".into(),
                "(".into(),
                "a".into(),
                ",".into(),
                " ".into(),
                "b".into(),
            ],
            vec![
                (0, 3),
                (3, 4),
                (4, 7),
                (7, 8),
                (8, 9),
                (9, 10),
                (10, 11),
                (11, 12),
            ],
        )
    }

    #[test]
    fn char_to_token_exact() {
        let encoding = sample_encoding();
        assert_eq!(encoding.char_to_token(0), Some(0));
        assert_eq!(encoding.char_to_token(4), Some(2));
        assert_eq!(encoding.char_to_token(8), Some(4));
        assert_eq!(encoding.char_to_token(100), None);
    }

    #[test]
    fn char_to_token_fuzzy_fallback() {
        let encoding = sample_encoding();
        // Position 12 lies past the last token (11, 12). The nearest range
        // midpoint is that token's (11), so fuzzy lookup snaps to index 7.
        assert_eq!(encoding.char_to_token_fuzzy(12), Some(7));
    }

    #[test]
    fn char_to_token_start_finds_next_token() {
        let encoding = sample_encoding();
        assert_eq!(encoding.char_to_token_start(0), Some(0));
        // First token starting at or after position 4 is "add" at index 2.
        assert_eq!(encoding.char_to_token_start(4), Some(2));
        assert_eq!(encoding.char_to_token_start(100), None);
    }

    #[test]
    fn tokens_with_offsets_pairs_up() {
        let encoding = sample_encoding();
        let pairs = encoding.tokens_with_offsets();
        assert_eq!(pairs.len(), 8);
        assert_eq!(pairs[0].token, "def");
        assert_eq!(pairs[0].start, 0);
        assert_eq!(pairs[0].end, 3);
    }

    #[test]
    fn char_range_to_tokens_overlap() {
        let encoding = sample_encoding();
        let tokens = encoding.char_range_to_tokens(3, 7);
        assert_eq!(tokens, vec![1, 2]);
    }

    #[test]
    fn token_to_char_range_roundtrip() {
        let encoding = sample_encoding();
        assert_eq!(encoding.token_to_char_range(2), Some((4, 7)));
        assert_eq!(encoding.token_to_char_range(100), None);
    }

    #[test]
    fn convert_positions_batch() {
        let encoding = sample_encoding();
        let results = convert_positions(&encoding, &[0, 4, 100]);
        assert_eq!(results.len(), 3);
        assert!(results[0].exact_match);
        assert_eq!(results[0].token_idx, Some(0));
        assert!(results[1].exact_match);
        assert_eq!(results[1].token_idx, Some(2));
        assert!(!results[2].exact_match);
    }

    #[test]
    fn find_marker() {
        let code = "def add(a, b):\n    \"\"\"\n    >>> add(2, 3)\n    5\n    \"\"\"";
        assert!(find_marker_char_pos(code, ">>>").is_some());
        assert!(find_marker_char_pos(code, "zzz").is_none());
    }

    #[test]
    fn encoding_len_and_empty() {
        let encoding = sample_encoding();
        assert_eq!(encoding.len(), 8);
        assert!(!encoding.is_empty());
        let empty = EncodingWithOffsets::new(vec![], vec![], vec![]);
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }
}