use std::collections::HashMap;
use std::hash::Hash;
use crate::error::IoError;
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnknownPolicy {
Error,
UseNullCode,
AddToVocab,
}
#[derive(Debug, Clone)]
pub struct DictionaryEncoder<T: Eq + Hash + Clone> {
pub vocab: Vec<T>,
code_map: HashMap<T, u16>,
pub unknown_policy: UnknownPolicy,
}
impl<T: Eq + Hash + Clone + Ord> DictionaryEncoder<T> {
pub fn new(unknown_policy: UnknownPolicy) -> Self {
Self {
vocab: Vec::new(),
code_map: HashMap::new(),
unknown_policy,
}
}
pub fn fit(data: &[T], unknown_policy: UnknownPolicy) -> IoResult<Self> {
let mut freq: HashMap<T, u64> = HashMap::new();
for item in data {
*freq.entry(item.clone()).or_insert(0) += 1;
}
const MAX_VOCAB: usize = u16::MAX as usize - 1; if freq.len() > MAX_VOCAB {
return Err(IoError::FormatError(format!(
"dictionary encoding: too many unique values ({} > {MAX_VOCAB})",
freq.len()
)));
}
let mut entries: Vec<(T, u64)> = freq.into_iter().collect();
entries.sort_by(|(va, fa), (vb, fb)| fb.cmp(fa).then_with(|| va.cmp(vb)));
let mut vocab = Vec::with_capacity(entries.len());
let mut code_map: HashMap<T, u16> = HashMap::with_capacity(entries.len());
for (code, (value, _freq)) in entries.into_iter().enumerate() {
vocab.push(value.clone());
code_map.insert(value, code as u16);
}
Ok(Self {
vocab,
code_map,
unknown_policy,
})
}
pub fn encode(&self, data: &[T]) -> IoResult<Vec<u16>>
where
T: std::fmt::Debug,
{
let mut codes = Vec::with_capacity(data.len());
for item in data {
let code = self.encode_single(item)?;
codes.push(code);
}
Ok(codes)
}
pub fn decode(&self, codes: &[u16]) -> IoResult<Vec<T>> {
let mut out = Vec::with_capacity(codes.len());
for &code in codes {
out.push(self.decode_single(code)?.clone());
}
Ok(out)
}
pub fn vocab_size(&self) -> usize {
self.vocab.len()
}
pub fn encode_single(&self, value: &T) -> IoResult<u16>
where
T: std::fmt::Debug,
{
match self.code_map.get(value) {
Some(&code) => Ok(code),
None => match self.unknown_policy {
UnknownPolicy::Error => Err(IoError::FormatError(format!(
"dictionary encoding: unknown value {:?}",
value
))),
UnknownPolicy::UseNullCode => Ok(u16::MAX),
UnknownPolicy::AddToVocab => Err(IoError::FormatError(
"dictionary encoding: AddToVocab requires a mutable encoder; \
use encode_single_mut instead"
.to_string(),
)),
},
}
}
pub fn encode_single_mut(&mut self, value: T) -> IoResult<u16>
where
T: std::fmt::Debug,
{
if let Some(&code) = self.code_map.get(&value) {
return Ok(code);
}
match self.unknown_policy {
UnknownPolicy::Error => Err(IoError::FormatError(format!(
"dictionary encoding: unknown value {:?}",
value
))),
UnknownPolicy::UseNullCode => Ok(u16::MAX),
UnknownPolicy::AddToVocab => {
const MAX_VOCAB: usize = u16::MAX as usize - 1;
if self.vocab.len() >= MAX_VOCAB {
return Err(IoError::FormatError(
"dictionary encoding: vocabulary capacity exhausted".to_string(),
));
}
let code = self.vocab.len() as u16;
self.vocab.push(value.clone());
self.code_map.insert(value, code);
Ok(code)
}
}
}
pub fn decode_single(&self, code: u16) -> IoResult<&T> {
if code == u16::MAX {
return Err(IoError::FormatError(
"dictionary encoding: null sentinel code (u16::MAX) encountered during decode"
.to_string(),
));
}
self.vocab.get(code as usize).ok_or_else(|| {
IoError::FormatError(format!(
"dictionary encoding: code {} out of range (vocab size {})",
code,
self.vocab.len()
))
})
}
}
pub type StringDictionaryEncoder = DictionaryEncoder<String>;
type IoResult<T> = crate::error::Result<T>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dict_encoder_fit_all_unique() {
let data: Vec<String> = (0..100).map(|i| format!("val_{i}")).collect();
let enc = DictionaryEncoder::fit(&data, UnknownPolicy::Error).expect("fit failed");
assert_eq!(enc.vocab_size(), 100);
}
#[test]
fn test_dict_encoder_encode_decode_roundtrip() {
let data: Vec<String> = vec!["a", "b", "a", "c", "b", "a"]
.into_iter()
.map(|s| s.to_string())
.collect();
let enc = DictionaryEncoder::fit(&data, UnknownPolicy::Error).expect("fit failed");
assert_eq!(enc.vocab[0], "a");
let codes = enc.encode(&data).expect("encode failed");
assert_eq!(codes.len(), 6);
let decoded = enc.decode(&codes).expect("decode failed");
assert_eq!(decoded, data);
}
#[test]
fn test_dict_encoder_unknown_error_policy() {
let train: Vec<String> = vec!["x", "y"].into_iter().map(|s| s.to_string()).collect();
let enc = DictionaryEncoder::fit(&train, UnknownPolicy::Error).expect("fit failed");
let unknown = vec!["z".to_string()];
let result = enc.encode(&unknown);
assert!(result.is_err(), "expected error for unknown value");
}
#[test]
fn test_dict_encoder_null_code_policy() {
let train: Vec<String> = vec!["a".to_string()];
let enc = DictionaryEncoder::fit(&train, UnknownPolicy::UseNullCode).expect("fit failed");
let data = vec!["a".to_string(), "unknown".to_string()];
let codes = enc.encode(&data).expect("encode failed");
assert_eq!(codes[0], 0);
assert_eq!(codes[1], u16::MAX);
}
#[test]
fn test_dict_encoder_add_to_vocab_policy() {
let train: Vec<String> = vec!["a".to_string()];
let mut enc =
DictionaryEncoder::fit(&train, UnknownPolicy::AddToVocab).expect("fit failed");
let code = enc
.encode_single_mut("b".to_string())
.expect("encode failed");
assert_eq!(enc.vocab_size(), 2);
assert_eq!(enc.vocab[code as usize], "b");
}
#[test]
fn test_dict_encoder_decode_null_sentinel_error() {
let train: Vec<String> = vec!["a".to_string()];
let enc = DictionaryEncoder::fit(&train, UnknownPolicy::Error).expect("fit failed");
let result = enc.decode(&[u16::MAX]);
assert!(result.is_err(), "decoding null sentinel should error");
}
#[test]
fn test_dict_encoder_too_many_unique() {
let data: Vec<String> = (0u32..65535).map(|i| format!("v{i}")).collect();
let result = DictionaryEncoder::fit(&data, UnknownPolicy::Error);
assert!(result.is_err(), "should error on > 65534 unique values");
}
}