use crate::hyperdim::HVec10240;
/// Initial accumulator value for the 64-bit FNV-1a hash.
const FNV1A_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
/// Multiplier applied after every byte is mixed into the 64-bit FNV-1a hash.
const FNV1A_PRIME: u64 = 0x0000_0100_0000_01b3;

/// Computes the 64-bit FNV-1a hash of `bytes`.
///
/// Each byte is XOR-ed into the accumulator, which is then multiplied by
/// [`FNV1A_PRIME`] with wrapping arithmetic. An empty slice yields
/// [`FNV1A_OFFSET_BASIS`] unchanged.
#[inline]
fn fnv1a_hash(bytes: &[u8]) -> u64 {
    bytes.iter().fold(FNV1A_OFFSET_BASIS, |acc, &byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV1A_PRIME)
    })
}
/// Settings that control how [`TextEncoder`] tokenizes and encodes text.
#[derive(Debug, Clone)]
pub struct TextEncoderConfig {
    /// Multiplier applied to a token's zero-based index to obtain the
    /// permutation amount for that token's vector.
    pub position_stride: usize,
    /// When `Some(n)`, character n-grams of length `n` are encoded and bundled
    /// with the token-level vector; `None` disables n-gram encoding.
    pub ngram_size: Option<usize>,
    /// Lowercase the input before tokenizing.
    pub lowercase: bool,
    /// Split whitespace-separated words further on code separators
    /// (`::`, `_`, `-`, `.`, `/`).
    pub code_aware: bool,
}

impl Default for TextEncoderConfig {
    /// Plain-text defaults: stride of one, no n-grams, lowercasing enabled,
    /// code-aware splitting disabled.
    fn default() -> Self {
        Self {
            code_aware: false,
            lowercase: true,
            ngram_size: None,
            position_stride: 1,
        }
    }
}
/// Encodes text into `HVec10240` hypervectors according to a
/// [`TextEncoderConfig`].
#[derive(Debug, Clone, Default)]
pub struct TextEncoder {
// Settings consulted by `encode` for lowercasing, tokenization mode,
// position stride and optional n-gram encoding.
config: TextEncoderConfig,
}
impl TextEncoder {
pub fn new() -> Self {
Self {
config: TextEncoderConfig::default(),
}
}
/// Creates an encoder that uses the supplied configuration as-is.
pub const fn with_config(config: TextEncoderConfig) -> Self {
Self { config }
}
pub fn new_code_aware() -> Self {
Self {
config: TextEncoderConfig {
ngram_size: Some(3), code_aware: true,
..Default::default()
},
}
}
/// Returns a reference to the configuration this encoder was built with.
pub const fn config(&self) -> &TextEncoderConfig {
&self.config
}
/// Splits `text` on whitespace, then breaks every word into sub-tokens at
/// code separators (see `split_on_separators`). Order is preserved and the
/// returned slices borrow from `text`.
fn tokenize_code(text: &str) -> Vec<&str> {
    text.split_whitespace()
        .flat_map(|word| Self::split_on_separators(word))
        .collect()
}
/// Splits a single word at the separators `::`, `_`, `-`, `.` and `/`.
///
/// Empty pieces (from leading, trailing or adjacent separators) are dropped.
/// A lone `:` that is not followed by another `:` is not a separator and stays
/// part of the surrounding piece. The returned slices borrow from `word`.
fn split_on_separators(word: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut piece_start = 0;
    let mut chars = word.char_indices().peekable();

    while let Some((idx, ch)) = chars.next() {
        // Byte length of the separator starting at `idx`; ordinary characters
        // simply extend the current piece.
        let sep_len = match ch {
            ':' if matches!(chars.peek(), Some(&(_, ':'))) => {
                // Consume the second colon so it is not examined again.
                chars.next();
                2
            }
            '_' | '-' | '.' | '/' => 1,
            _ => continue,
        };

        if idx > piece_start {
            pieces.push(&word[piece_start..idx]);
        }
        // All separators are ASCII, so byte offsets advance by `sep_len`.
        piece_start = idx + sep_len;
    }

    if piece_start < word.len() {
        pieces.push(&word[piece_start..]);
    }
    pieces
}
/// Encodes `text` into a single hypervector.
///
/// Steps:
/// 1. Lowercase the text when `config.lowercase` is set.
/// 2. Tokenize on whitespace, or with code-aware splitting when
///    `config.code_aware` is set.
/// 3. Map each token to a seeded vector and permute it by
///    `index * config.position_stride`.
/// 4. Bundle the token vectors; when `config.ngram_size` is `Some(n)`, also
///    bundle in the character n-gram vector of the (lowercased) text.
///
/// Returns `HVec10240::zero()` when the text yields no tokens. A failing
/// `bundle` call also falls back to the zero vector.
pub fn encode(&self, text: &str) -> HVec10240 {
    // Deferred initialisation keeps the lowercased copy alive for as long as
    // `normalized` borrows from it, without allocating in the other branch.
    let lowered;
    let normalized: &str = if self.config.lowercase {
        lowered = text.to_lowercase();
        &lowered
    } else {
        text
    };

    let tokens: Vec<&str> = if self.config.code_aware {
        Self::tokenize_code(normalized)
    } else {
        normalized.split_whitespace().collect()
    };
    if tokens.is_empty() {
        return HVec10240::zero();
    }

    let stride = self.config.position_stride;
    let mut positioned: Vec<HVec10240> = Vec::with_capacity(tokens.len());
    for (index, &token) in tokens.iter().enumerate() {
        positioned.push(self.token_to_hvec(token).permute(index * stride));
    }
    let word_level = HVec10240::bundle(&positioned).unwrap_or_else(|_| HVec10240::zero());

    match self.config.ngram_size {
        Some(n) => {
            let char_level = self.encode_ngrams(normalized, n);
            HVec10240::bundle(&[word_level, char_level]).unwrap_or_else(|_| HVec10240::zero())
        }
        None => word_level,
    }
}
/// Encodes `text` like [`TextEncoder::encode`], but with character n-grams of
/// length `n` enabled regardless of this encoder's own `ngram_size`.
///
/// All other settings are copied from this encoder's configuration; the
/// encoder itself is not modified.
pub fn encode_with_ngrams(&self, text: &str, n: usize) -> HVec10240 {
    let mut config = self.config.clone();
    config.ngram_size = Some(n);
    Self::with_config(config).encode(text)
}
/// Tokenizes `text` into owned strings without encoding it.
///
/// * `lowercase` - lowercase the text before splitting.
/// * `code_aware` - additionally split each whitespace-separated word on code
///   separators; otherwise split on whitespace only.
pub fn tokenize(text: &str, code_aware: bool, lowercase: bool) -> Vec<String> {
    // Deferred initialisation: only allocate a lowercased copy when requested.
    let lowered;
    let normalized: &str = if lowercase {
        lowered = text.to_lowercase();
        &lowered
    } else {
        text
    };

    let pieces: Vec<&str> = if code_aware {
        Self::tokenize_code(normalized)
    } else {
        normalized.split_whitespace().collect()
    };
    pieces.into_iter().map(String::from).collect()
}
/// Maps a token to a hypervector by seeding `HVec10240::new_seeded` with the
/// token's stable hash.
fn token_to_hvec(&self, token: &str) -> HVec10240 {
    HVec10240::new_seeded(self.stable_hash(token))
}
/// Returns the `fnv1a_hash` of the token's UTF-8 bytes.
///
/// The result depends only on the token's bytes; no encoder state is read.
fn stable_hash(&self, token: &str) -> u64 {
fnv1a_hash(token.as_bytes())
}
/// Encodes every window of `n` consecutive characters of `text` and bundles
/// the resulting vectors.
///
/// Windows are counted in `char`s, not bytes, so multi-byte characters are
/// never split. Returns `HVec10240::zero()` when `n` is zero, when `text` has
/// fewer than `n` characters, or when bundling fails.
fn encode_ngrams(&self, text: &str, n: usize) -> HVec10240 {
    // `slice::windows` panics on a window size of zero; a zero-length n-gram
    // carries no information, so treat it like "no n-grams available".
    if n == 0 {
        return HVec10240::zero();
    }

    let char_indices: Vec<(usize, char)> = text.char_indices().collect();
    if char_indices.len() < n {
        return HVec10240::zero();
    }

    let ngram_vectors: Vec<HVec10240> = char_indices
        .windows(n)
        .map(|window| {
            // Byte range covering the window: from the first char's offset to
            // the end of the last char.
            let (first_offset, _) = window[0];
            let (last_offset, last_char) = window[n - 1];
            let end = last_offset + last_char.len_utf8();
            self.token_to_hvec(&text[first_offset..end])
        })
        .collect();
    HVec10240::bundle(&ngram_vectors).unwrap_or_else(|_| HVec10240::zero())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding the same text twice must give the same vector.
    #[test]
    fn encode_deterministic() {
        let enc = TextEncoder::new();
        let input = "hello world";
        let first = enc.encode(input);
        let second = enc.encode(input);
        assert_eq!(first, second);
    }

    /// Swapping token order must change the encoding.
    #[test]
    fn encode_position_aware() {
        let enc = TextEncoder::new();
        let forward = enc.encode("cat sat");
        let reversed = enc.encode("sat cat");
        assert_ne!(forward, reversed);
    }

    /// Plain tokenization splits on whitespace only.
    #[test]
    fn tokenize_splits_whitespace() {
        let code_aware = false;
        let lowercase = true;
        let words = TextEncoder::tokenize("hello world test", code_aware, lowercase);
        assert_eq!(words, vec!["hello", "world", "test"]);
    }

    /// Lowercasing is applied before splitting.
    #[test]
    fn tokenize_lowercase() {
        let code_aware = false;
        let lowercase = true;
        let words = TextEncoder::tokenize("HELLO World", code_aware, lowercase);
        assert_eq!(words, vec!["hello", "world"]);
    }

    /// Code-aware tokenization splits on `_` and `::`.
    #[test]
    fn tokenize_code_aware() {
        let code_aware = true;
        let lowercase = true;
        let parts = TextEncoder::tokenize("my_var::method", code_aware, lowercase);
        for expected in ["my", "var", "method"] {
            assert!(parts.contains(&expected.to_string()));
        }
    }

    /// An n-gram encoding of non-empty text differs from the zero vector.
    #[test]
    fn encode_with_ngrams() {
        let enc = TextEncoder::new();
        let encoded = enc.encode_with_ngrams("abc", 2);
        let origin = HVec10240::zero();
        assert!(encoded.hamming_distance(&origin) > 0);
    }

    /// Hashing the same token twice gives the same value.
    #[test]
    fn stable_hash_consistent() {
        let enc = TextEncoder::new();
        let first = enc.stable_hash("test_token");
        let second = enc.stable_hash("test_token");
        assert_eq!(first, second);
    }
}