//! aft-tokenizer 0.36.1
//!
//! Claude lookup-encoding tokenizer for Agent File Tools.
//!
//! Documentation
use fancy_regex::Regex;
use std::collections::HashMap;
use std::sync::OnceLock;

// Generated by build.rs from ai-tokenizer Claude encoding; vendored in-tree so
// the crate builds from a clean clone without fetching `tmp/ai-tokenizer-pkg/`.
// To refresh: see crates/aft-tokenizer/build.rs.
include!("claude_data.rs");

// Pre-tokenizer pattern compiled by `pattern_regex()`. Alternatives, in order:
//   * the contraction suffixes `'s`, `'t`, `'re`, `'ve`, `'m`, `'ll`, `'d`
//   * an optional leading space followed by a run of letters
//   * an optional leading space followed by a run of digits
//   * an optional leading space followed by a run of characters that are
//     neither whitespace, letters, nor digits
//   * a whitespace run that is not followed by a non-whitespace character
//     (negative look-ahead)
//   * any remaining whitespace run
const PAT_STR: &str = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
// Sentinel meaning "no rank": returned by the rank lookups when a byte slice is
// in neither encoder table, and used to mark pair slots that cannot be merged.
const NO_RANK: u32 = u32::MAX;

// Process-wide caches, each filled on first use by the accessor function of
// the matching name (`pattern_regex`, `string_encoder`, `binary_encoder`).
static PATTERN_REGEX: OnceLock<Regex> = OnceLock::new();
static STRING_ENCODER: OnceLock<HashMap<&'static str, u32>> = OnceLock::new();
static BINARY_ENCODER: OnceLock<HashMap<Vec<u8>, u32>> = OnceLock::new();

/// Return how many Claude tokens `text` encodes to.
///
/// The empty string is handled up front and counts as zero tokens; any other
/// input is encoded and the number of resulting token IDs is reported.
pub fn count_tokens(text: &str) -> usize {
    match text {
        "" => 0,
        non_empty => encode_internal(non_empty).len(),
    }
}

/// Convert `text` into its sequence of Claude token IDs.
///
/// An empty input produces an empty vector without touching the encoder.
pub fn encode(text: &str) -> Vec<u32> {
    if !text.is_empty() {
        encode_internal(text)
    } else {
        Vec::new()
    }
}

/// Shared encoding routine behind [`encode`] and [`count_tokens`].
///
/// Inputs shorter than ten UTF-16 code units are first looked up whole in the
/// string encoder (mirroring ai-tokenizer's direct-rank shortcut). Otherwise
/// the text is split with the pre-tokenizer regex; each piece is either found
/// directly in the string encoder or broken down by `byte_pair_merge`.
///
/// # Panics
///
/// Panics if the pre-tokenizer regex reports an error while matching.
fn encode_internal(text: &str) -> Vec<u32> {
    // "Short" means fewer than ten UTF-16 code units, i.e. there is no tenth unit.
    let is_short = text.encode_utf16().nth(9).is_none();
    if is_short {
        if let Some(&id) = string_encoder().get(text) {
            return vec![id];
        }
    }

    let mut ids = Vec::new();
    for found in pattern_regex().find_iter(text) {
        let chunk = found
            .expect("Claude pre-tokenizer regex should not fail")
            .as_str();

        match string_encoder().get(chunk) {
            // The whole piece is a known token.
            Some(&id) => ids.push(id),
            // Otherwise fall back to merging from individual bytes.
            None => ids.extend(byte_pair_merge(chunk.as_bytes())),
        }
    }
    ids
}

/// Return the shared pre-tokenizer regex, compiling `PAT_STR` on first use.
///
/// # Panics
///
/// Panics if `PAT_STR` fails to compile.
fn pattern_regex() -> &'static Regex {
    fn compile() -> Regex {
        Regex::new(PAT_STR).expect("Claude pre-tokenizer regex")
    }

    PATTERN_REGEX.get_or_init(compile)
}

/// Return the shared string-to-rank table, built from
/// `STRING_ENCODER_ENTRIES` the first time it is requested.
fn string_encoder() -> &'static HashMap<&'static str, u32> {
    STRING_ENCODER.get_or_init(|| {
        let entries = STRING_ENCODER_ENTRIES.iter().copied();
        entries.collect::<HashMap<&'static str, u32>>()
    })
}

/// Return the shared byte-sequence-to-rank table, built from
/// `BINARY_ENCODER_ENTRIES` the first time it is requested.
///
/// Each entry's bytes are copied into an owned `Vec<u8>` key.
fn binary_encoder() -> &'static HashMap<Vec<u8>, u32> {
    let build = || {
        let owned_entries = BINARY_ENCODER_ENTRIES
            .iter()
            .map(|(bytes, rank)| (bytes.to_vec(), *rank));
        owned_entries.collect::<HashMap<Vec<u8>, u32>>()
    };

    BINARY_ENCODER.get_or_init(build)
}

/// Split `piece` into token IDs by greedy lowest-rank pair merging.
///
/// The piece starts as one segment per byte. On every round the adjacent pair
/// of segments whose concatenation has the smallest rank is fused (the
/// left-most pair wins ties), and the ranks of the pairs touching the fused
/// segment are recomputed. Merging stops when no adjacent pair has a rank.
///
/// Returns the rank of every remaining segment, in order. A segment with no
/// rank of its own contributes nothing to the output.
fn byte_pair_merge(piece: &[u8]) -> Vec<u32> {
    let byte_count = piece.len();

    // `boundaries[k]..boundaries[k + 1]` delimits segment `k`.
    let mut boundaries: Vec<usize> = (0..=byte_count).collect();

    // `pair_ranks[k]` is the rank of segment `k` joined with segment `k + 1`.
    // Slots without a complete two-byte pair start out as `NO_RANK`.
    let mut pair_ranks: Vec<u32> = (0..=byte_count)
        .map(|i| {
            if i + 1 < byte_count {
                get_rank_for_slice(&piece[i..i + 2])
            } else {
                NO_RANK
            }
        })
        .collect();

    while boundaries.len() > 1 {
        // Search every slot except the final one for the smallest real rank.
        let candidates = &pair_ranks[..pair_ranks.len().saturating_sub(1)];
        let best = candidates
            .iter()
            .enumerate()
            .filter(|&(_, &rank)| rank != NO_RANK)
            .min_by_key(|&(_, &rank)| rank);
        let Some((merge_at, _)) = best else { break };

        // Fuse segment `merge_at` with its right neighbour.
        boundaries.remove(merge_at + 1);
        pair_ranks.remove(merge_at);

        // Refresh the pair that now starts at the fused segment...
        if let Some(slot) = pair_ranks.get_mut(merge_at) {
            *slot = get_rank(piece, &boundaries, merge_at);
        }
        // ...and the pair that now ends with it.
        if let Some(left) = merge_at.checked_sub(1) {
            pair_ranks[left] = get_rank(piece, &boundaries, left);
        }
    }

    boundaries
        .windows(2)
        .map(|bounds| get_rank_for_slice(&piece[bounds[0]..bounds[1]]))
        .filter(|&rank| rank != NO_RANK)
        .collect()
}

/// Look up the rank of the two adjacent segments beginning at `start_index`.
///
/// `starts` holds segment boundaries into `piece`; the pair spans from
/// `starts[start_index]` to `starts[start_index + 2]`. Returns `NO_RANK` when
/// either boundary is missing or the spanned bytes have no rank.
fn get_rank(piece: &[u8], starts: &[usize], start_index: usize) -> u32 {
    let span = starts.get(start_index).and_then(|&from| {
        starts.get(start_index + 2).map(|&to| (from, to))
    });

    match span {
        Some((from, to)) => get_rank_for_slice(&piece[from..to]),
        None => NO_RANK,
    }
}

fn get_rank_for_slice(slice: &[u8]) -> u32 {
    if let Ok(as_string) = std::str::from_utf8(slice) {
        if let Some(&rank) = string_encoder().get(as_string) {
            return rank;
        }
    }

    binary_encoder().get(slice).copied().unwrap_or(NO_RANK)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The empty string encodes to no IDs and counts as zero tokens.
    #[test]
    fn empty_input_has_no_tokens() {
        let ids = encode("");
        assert_eq!(ids.len(), 0);
        assert_eq!(0, count_tokens(""));
    }

    /// A lone `!` maps to a single token with ID 5.
    #[test]
    fn direct_ascii_token() {
        let bang = "!";
        let expected: Vec<u32> = vec![5];
        assert_eq!(expected, encode(bang));
        assert_eq!(1, count_tokens(bang));
    }
}