use crate::hyperdim::HVec10240;
/// Initial state of the 64-bit FNV-1a hash.
const FNV1A_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
/// Multiplier applied after each byte is mixed into the 64-bit FNV-1a state.
const FNV1A_PRIME: u64 = 0x0000_0100_0000_01b3;

/// Computes the 64-bit FNV-1a hash of `bytes`.
///
/// Each byte is XOR-ed into the running state, which is then multiplied by
/// [`FNV1A_PRIME`] with wrapping arithmetic. An empty slice hashes to
/// [`FNV1A_OFFSET_BASIS`]. The result depends only on the input bytes, so it is
/// stable across runs and platforms.
#[inline]
fn fnv1a_hash(bytes: &[u8]) -> u64 {
    bytes.iter().fold(FNV1A_OFFSET_BASIS, |state, &octet| {
        (state ^ u64::from(octet)).wrapping_mul(FNV1A_PRIME)
    })
}
/// Options that control how [`TextEncoder`] converts text into a hypervector.
#[derive(Debug, Clone)]
pub struct TextEncoderConfig {
    /// Factor multiplied by a token's index to obtain the amount passed to
    /// `permute` for that token.
    pub position_stride: usize,
    /// When `Some(n)`, a bundle of character n-grams of length `n` is combined
    /// with the token-level encoding; `None` disables n-grams.
    pub ngram_size: Option<usize>,
    /// Lowercase the input text before tokenizing.
    pub lowercase: bool,
    /// Additionally split whitespace-separated words on the code separators
    /// `_`, `-`, `.`, `/` and `::`.
    pub code_aware: bool,
}

impl Default for TextEncoderConfig {
    /// Plain-text defaults: stride of 1, no n-grams, lowercasing enabled,
    /// code-aware tokenization disabled.
    fn default() -> Self {
        TextEncoderConfig {
            code_aware: false,
            lowercase: true,
            ngram_size: None,
            position_stride: 1,
        }
    }
}
/// Encodes text into an `HVec10240` according to a [`TextEncoderConfig`].
///
/// The encoder holds only its configuration; encoding methods take `&self`.
#[derive(Debug, Clone)]
pub struct TextEncoder {
/// Settings consulted by every call to `encode`.
config: TextEncoderConfig,
}
impl Default for TextEncoder {
fn default() -> Self {
Self::new()
}
}
impl TextEncoder {
pub fn new() -> Self {
Self {
config: TextEncoderConfig::default(),
}
}
/// Creates an encoder that uses the supplied configuration as-is.
pub fn with_config(config: TextEncoderConfig) -> Self {
Self { config }
}
pub fn new_code_aware() -> Self {
Self {
config: TextEncoderConfig {
ngram_size: Some(3), code_aware: true,
..Default::default()
},
}
}
/// Splits `text` into code-oriented tokens.
///
/// The text is first divided on whitespace; every resulting word is then
/// broken up further by `split_on_separators`, and all pieces are returned
/// in their original order.
fn tokenize_code(text: &str) -> Vec<String> {
    text.split_whitespace()
        .flat_map(Self::split_on_separators)
        .collect()
}
/// Splits a single word on the code separators `_`, `-`, `.`, `/` and `::`.
///
/// Separators are dropped and empty pieces are never emitted, so leading,
/// trailing, or repeated separators produce no empty strings. A lone `:`
/// (one not immediately followed by another `:`) is an ordinary character
/// and stays part of the current piece.
fn split_on_separators(word: &str) -> Vec<String> {
    let mut pieces: Vec<String> = Vec::new();
    let mut buf = String::new();
    let mut stream = word.chars().peekable();

    while let Some(ch) = stream.next() {
        let is_separator = match ch {
            // `::` is a two-character separator: swallow the second colon too.
            ':' if stream.peek() == Some(&':') => {
                stream.next();
                true
            }
            '_' | '-' | '.' | '/' => true,
            _ => false,
        };

        if !is_separator {
            buf.push(ch);
        } else if !buf.is_empty() {
            // Hand the finished piece over and start a fresh, empty buffer.
            pieces.push(std::mem::take(&mut buf));
        }
    }

    if !buf.is_empty() {
        pieces.push(buf);
    }
    pieces
}
/// Encodes `text` into a single hypervector.
///
/// Steps, in order:
/// 1. Lowercase the text if `config.lowercase` is set.
/// 2. Tokenize: code-aware splitting when `config.code_aware` is set,
///    plain whitespace splitting otherwise.
/// 3. Map each token to a vector seeded from its hash and call `permute`
///    with `index * config.position_stride`.
/// 4. Combine the per-token vectors with `HVec10240::bundle`.
/// 5. If `config.ngram_size` is `Some(n)`, bundle that result with the
///    character n-gram vector of the (possibly lowercased) text.
///
/// Returns `HVec10240::zero()` when no tokens are produced. A failing
/// `bundle` call is also replaced by the zero vector rather than reported.
pub fn encode(&self, text: &str) -> HVec10240 {
    let cfg = &self.config;

    let normalized = if cfg.lowercase {
        text.to_lowercase()
    } else {
        text.to_owned()
    };

    let words: Vec<String> = if cfg.code_aware {
        Self::tokenize_code(&normalized)
    } else {
        normalized.split_whitespace().map(str::to_owned).collect()
    };
    if words.is_empty() {
        return HVec10240::zero();
    }

    // One vector per token, with the token's position folded in via `permute`.
    let mut positioned: Vec<HVec10240> = Vec::with_capacity(words.len());
    for (index, word) in words.iter().enumerate() {
        let seeded = self.token_to_hvec(word);
        positioned.push(seeded.permute(index * cfg.position_stride));
    }
    let word_bundle = HVec10240::bundle(&positioned).unwrap_or_else(|_| HVec10240::zero());

    match cfg.ngram_size {
        Some(n) => {
            let grams = self.encode_ngrams(&normalized, n);
            HVec10240::bundle(&[word_bundle, grams]).unwrap_or_else(|_| HVec10240::zero())
        }
        None => word_bundle,
    }
}
/// Encodes `text` as `encode` does, but with character n-grams of length
/// `n` enabled regardless of this encoder's own `ngram_size` setting.
///
/// The encoder itself is left untouched: a copy of its configuration with
/// `ngram_size` overridden drives a temporary encoder.
pub fn encode_with_ngrams(&self, text: &str, n: usize) -> HVec10240 {
    let mut overridden = self.config.clone();
    overridden.ngram_size = Some(n);
    Self::with_config(overridden).encode(text)
}
/// Maps a token to a hypervector by seeding `HVec10240::new_seeded` with
/// the token's stable hash.
fn token_to_hvec(&self, token: &str) -> HVec10240 {
    HVec10240::new_seeded(self.stable_hash(token))
}
/// Hashes a token's UTF-8 bytes with `fnv1a_hash`.
///
/// The value depends only on the token's bytes, not on any per-process state.
fn stable_hash(&self, token: &str) -> u64 {
fnv1a_hash(token.as_bytes())
}
/// Encodes the character n-grams of `text` into a single hypervector.
///
/// Every window of `n` consecutive `char`s (whitespace included) is turned
/// into a string, mapped to a vector with `token_to_hvec`, and the vectors
/// are combined with `HVec10240::bundle`.
///
/// Returns `HVec10240::zero()` when `n` is zero, when `text` has fewer than
/// `n` characters, or when `bundle` reports an error.
fn encode_ngrams(&self, text: &str, n: usize) -> HVec10240 {
    // `slice::windows` panics on a window size of zero; a zero-length n-gram
    // carries no information, so treat it as "no n-grams" instead.
    if n == 0 {
        return HVec10240::zero();
    }

    let chars: Vec<char> = text.chars().collect();
    if chars.len() < n {
        return HVec10240::zero();
    }

    // At this point 1 <= n <= chars.len(), so there is at least one window.
    let ngram_vectors: Vec<HVec10240> = chars
        .windows(n)
        .map(|window| {
            let ngram: String = window.iter().collect();
            self.token_to_hvec(&ngram)
        })
        .collect();

    HVec10240::bundle(&ngram_vectors).unwrap_or_else(|_| HVec10240::zero())
}
}
// Unit tests for `TextEncoder`: determinism, empty input, similarity
// thresholds, configuration switches, and code-aware tokenization.
#[cfg(test)]
mod tests {
use super::*;
// Encoding the same text twice must give (near-)identical vectors.
#[test]
fn test_encode_deterministic() {
let encoder = TextEncoder::new();
let hv1 = encoder.encode("hello world");
let hv2 = encoder.encode("hello world");
assert!(hv1.cosine_similarity(&hv2) > 0.99);
}
// An empty string yields no tokens, so the zero vector is returned.
#[test]
fn test_encode_empty_returns_zero() {
let encoder = TextEncoder::new();
let hv = encoder.encode("");
assert_eq!(hv, HVec10240::zero());
}
// Whitespace-only input also yields no tokens.
#[test]
fn test_encode_whitespace_only_returns_zero() {
let encoder = TextEncoder::new();
let hv = encoder.encode(" \t\n ");
assert_eq!(hv, HVec10240::zero());
}
// Texts sharing a four-word prefix should stay above 0.5 similarity.
#[test]
fn test_encode_similar_texts() {
let encoder = TextEncoder::new();
let hv1 = encoder.encode("the quick brown fox");
let hv2 = encoder.encode("the quick brown fox jumps");
assert!(hv1.cosine_similarity(&hv2) > 0.5);
}
// Texts with no words in common should stay below 0.7 similarity.
#[test]
fn test_encode_dissimilar_texts() {
let encoder = TextEncoder::new();
let hv1 = encoder.encode("hello world");
let hv2 = encoder.encode("xyzzy plugh");
assert!(hv1.cosine_similarity(&hv2) < 0.7);
}
// The n-gram override path must be deterministic as well.
#[test]
fn test_encode_with_ngrams() {
let encoder = TextEncoder::new();
let hv1 = encoder.encode_with_ngrams("hello", 2);
let hv2 = encoder.encode_with_ngrams("hello", 2);
assert!(hv1.cosine_similarity(&hv2) > 0.99);
}
// With the default `lowercase: true`, letter case does not affect the result.
#[test]
fn test_encode_case_insensitive_by_default() {
let encoder = TextEncoder::new();
let hv1 = encoder.encode("Hello World");
let hv2 = encoder.encode("hello world");
assert!(hv1.cosine_similarity(&hv2) > 0.99);
}
// With `lowercase: false`, differently-cased inputs encode differently.
#[test]
fn test_encode_case_sensitive() {
let config = TextEncoderConfig {
lowercase: false,
..Default::default()
};
let encoder = TextEncoder::with_config(config);
let hv1 = encoder.encode("Hello World");
let hv2 = encoder.encode("hello world");
assert!(hv1.cosine_similarity(&hv2) < 0.99);
}
// Swapping word order must change the encoding (position is encoded).
#[test]
fn test_position_encoding_affects_result() {
let encoder = TextEncoder::new();
let hv1 = encoder.encode("cat dog");
let hv2 = encoder.encode("dog cat");
assert!(hv1.cosine_similarity(&hv2) < 0.99);
}
// A non-default stride still produces a non-zero vector.
#[test]
fn test_config_custom_stride() {
let config = TextEncoderConfig {
position_stride: 5,
..Default::default()
};
let encoder = TextEncoder::with_config(config);
let hv = encoder.encode("hello world");
assert_ne!(hv, HVec10240::zero());
}
// Underscores split snake_case identifiers into their words.
#[test]
fn test_code_aware_tokenize_snake_case() {
let tokens = TextEncoder::tokenize_code("my_function_name");
assert_eq!(tokens, vec!["my", "function", "name"]);
}
// CamelCase is not split: case changes are not treated as separators.
#[test]
fn test_code_aware_tokenize_camel_case() {
let tokens = TextEncoder::tokenize_code("MyClassName");
assert_eq!(tokens, vec!["MyClassName"]);
}
// Both `/` and `.` act as separators in paths.
#[test]
fn test_code_aware_tokenize_path() {
let tokens = TextEncoder::tokenize_code("src/lib.rs");
assert_eq!(tokens, vec!["src", "lib", "rs"]);
}
// `::` is consumed as a single two-character separator.
#[test]
fn test_code_aware_tokenize_double_colon() {
let tokens = TextEncoder::tokenize_code("std::collections::HashMap");
assert_eq!(tokens, vec!["std", "collections", "HashMap"]);
}
// Different separator kinds can be mixed within one word.
#[test]
fn test_code_aware_tokenize_mixed() {
let tokens = TextEncoder::tokenize_code("my_module::MyClass.method_name");
assert_eq!(tokens, vec!["my", "module", "MyClass", "method", "name"]);
}
// Identifiers sharing most sub-words should stay above 0.5 similarity.
#[test]
fn test_code_aware_similarity() {
let encoder = TextEncoder::new_code_aware();
let hv1 = encoder.encode("get_user_by_id");
let hv2 = encoder.encode("get_user_by_name");
assert!(hv1.cosine_similarity(&hv2) > 0.5);
}
// The code-aware preset must be deterministic on a realistic code line.
#[test]
fn test_code_aware_deterministic() {
let encoder = TextEncoder::new_code_aware();
let hv1 = encoder.encode("fn process_data(input: &str) -> Result");
let hv2 = encoder.encode("fn process_data(input: &str) -> Result");
assert!(hv1.cosine_similarity(&hv2) > 0.99);
}
// The default and code-aware encoders must disagree on a snake_case identifier.
#[test]
fn test_code_aware_vs_regular() {
let regular = TextEncoder::new();
let code_aware = TextEncoder::new_code_aware();
let hv1 = regular.encode("my_function_name");
let hv2 = code_aware.encode("my_function_name");
assert!(hv1.cosine_similarity(&hv2) < 0.95);
}
// Empty input, separator-only input, and leading/trailing separators
// never produce empty tokens.
#[test]
fn test_split_on_separators_edge_cases() {
let tokens = TextEncoder::split_on_separators("");
assert!(tokens.is_empty());
let tokens = TextEncoder::split_on_separators("___");
assert!(tokens.is_empty());
let tokens = TextEncoder::split_on_separators("_leading");
assert_eq!(tokens, vec!["leading"]);
let tokens = TextEncoder::split_on_separators("trailing_");
assert_eq!(tokens, vec!["trailing"]);
}
}