use std::collections::HashMap;
use std::sync::RwLock;
/// First code point handed out as a symbol: U+E000, the start of the Unicode
/// Basic Multilingual Plane Private Use Area.
const PUA_START: u32 = 0xE000;
/// Last code point (inclusive) that may be handed out as a symbol: U+F8FF,
/// the end of that Private Use Area.
const PUA_END: u32 = 0xF8FF;
/// Number of code points in the inclusive range `PUA_START..=PUA_END` (6400).
///
/// This is the value reported in [`SymbolOverflowError::max`] when the range
/// is exhausted.
pub const MAX_SYMBOLS: usize = (PUA_END - PUA_START + 1) as usize;
/// Cached encoder state: `(patterns, replacements, automaton)`.
///
/// `replacements[i]` is the text substituted for automaton pattern `i`. The
/// pattern strings are stored alongside but are not read back by any code
/// visible in this file.
type AcCache = (Vec<String>, Vec<String>, aho_corasick::AhoCorasick);
/// Bidirectional dictionary that assigns each interned term a single
/// Private Use Area character ("symbol").
///
/// Symbols are allocated sequentially starting at `PUA_START`; once a term has
/// a symbol it keeps it for the lifetime of the dictionary.
pub struct SymbolDict {
/// Term → symbol lookup.
encode: HashMap<String, char>,
/// Symbol → term lookup (the inverse of `encode`).
decode: HashMap<char, String>,
/// Code point that the next newly interned term will receive.
next_code: u32,
/// Lazily built Aho-Corasick automaton used by `encode_str`; reset to `None`
/// whenever a new term is interned.
ac_cache: RwLock<Option<AcCache>>,
}
impl Default for SymbolDict {
/// Returns an empty dictionary; identical to calling `SymbolDict::new()`.
fn default() -> Self {
Self::new()
}
}
impl SymbolDict {
pub fn new() -> Self {
Self {
encode: HashMap::new(),
decode: HashMap::new(),
next_code: PUA_START,
ac_cache: RwLock::new(None),
}
}
pub fn len(&self) -> usize {
self.encode.len()
}
pub fn is_empty(&self) -> bool {
self.encode.is_empty()
}
pub fn intern(&mut self, term: &str) -> Result<char, SymbolOverflowError> {
if let Some(&sym) = self.encode.get(term) {
return Ok(sym);
}
if self.next_code > PUA_END {
return Err(SymbolOverflowError { max: MAX_SYMBOLS });
}
let sym = char::from_u32(self.next_code)
.expect("codepoints within the PUA range are always valid");
self.encode.insert(term.to_string(), sym);
self.decode.insert(sym, term.to_string());
self.next_code += 1;
*self.ac_cache.write().unwrap() = None;
Ok(sym)
}
#[cfg(test)]
pub(crate) fn decode_str(&self, input: &str) -> String {
input
.chars()
.flat_map(|c| {
if let Some(term) = self.decode.get(&c) {
term.chars().collect::<Vec<_>>()
} else {
vec![c]
}
})
.collect()
}
pub fn encode_str(&self, input: &str) -> String {
if self.encode.is_empty() {
return input.to_string();
}
{
let cache = self.ac_cache.read().unwrap();
if let Some((_, replacements, ac)) = cache.as_ref() {
return ac.replace_all(input, replacements);
}
}
{
let mut pairs: Vec<(String, String)> = self
.encode
.iter()
.map(|(k, v)| (k.clone(), v.to_string()))
.collect();
pairs.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
let patterns: Vec<&str> = pairs.iter().map(|(k, _)| k.as_str()).collect();
let replacements: Vec<String> = pairs.iter().map(|(_, v)| v.clone()).collect();
let ac = aho_corasick::AhoCorasick::builder()
.match_kind(aho_corasick::MatchKind::LeftmostLongest)
.build(&patterns)
.expect("AhoCorasick build cannot fail with valid patterns");
let pattern_strs: Vec<String> = pairs.into_iter().map(|(k, _)| k).collect();
*self.ac_cache.write().unwrap() = Some((pattern_strs, replacements, ac));
}
let cache = self.ac_cache.read().unwrap();
let (_, replacements, ac) = cache.as_ref().unwrap();
ac.replace_all(input, replacements)
}
pub fn render_dict_header(&self) -> String {
if self.is_empty() {
return String::new();
}
let mut entries: Vec<(char, &str)> =
self.decode.iter().map(|(c, s)| (*c, s.as_str())).collect();
entries.sort_by_key(|(c, _)| *c as u32);
let body: String = entries
.iter()
.map(|(sym, term)| format!("{}={}", sym, term))
.collect::<Vec<_>>()
.join("\n");
format!("<D>\n{}\n</D>\n", body)
}
}
/// Error returned when a new term cannot be interned because every available
/// symbol has already been allocated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SymbolOverflowError {
    /// Total number of symbols the table can hold.
    pub max: usize,
}

impl std::fmt::Display for SymbolOverflowError {
    /// Writes a one-line, human-readable description that includes the capacity.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let capacity = self.max;
        f.write_fmt(format_args!(
            "symbol table overflow: maximum {} symbols",
            capacity
        ))
    }
}

impl std::error::Error for SymbolOverflowError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Interning a term twice returns the first symbol and allocates nothing new.
    #[test]
    fn intern_idempotent() {
        let mut dict = SymbolDict::new();
        let sym1 = dict.intern("법률용어").unwrap();
        let sym2 = dict.intern("법률용어").unwrap();
        assert_eq!(
            sym1, sym2,
            "re-interning the same term must return the same symbol"
        );
        assert_eq!(dict.len(), 1, "re-interning must not add a second entry");
    }

    /// Encoding substitutes exactly the interned terms and decoding undoes it.
    #[test]
    fn encode_decode_roundtrip() {
        let mut dict = SymbolDict::new();
        let damages = dict.intern("손해배상").unwrap();
        let rescission = dict.intern("계약해제").unwrap();
        let original = "손해배상 청구와 계약해제 요건";
        let encoded = dict.encode_str(original);
        assert_eq!(
            encoded,
            format!("{} 청구와 {} 요건", damages, rescission),
            "each interned term must be replaced by its own symbol"
        );
        let decoded = dict.decode_str(&encoded);
        assert_eq!(
            decoded, original,
            "encode → decode round-trip must restore the original text"
        );
    }

    /// Symbols come from the Private Use Area, so they cannot be `$` or any
    /// other ordinary text character.
    #[test]
    fn no_collision_with_dollar_sign() {
        let mut dict = SymbolDict::new();
        let sym = dict.intern("테스트용어").unwrap();
        assert!(sym as u32 >= PUA_START);
        assert!(sym as u32 <= PUA_END);
        assert_ne!(sym, '$');
    }

    /// Private Use Area characters that were never allocated are left alone.
    #[test]
    fn decode_passes_through_unknown_pua() {
        let dict = SymbolDict::new();
        let unknown = "\u{E100}hello";
        assert_eq!(dict.decode_str(unknown), unknown);
    }

    /// With no terms interned, the input is returned unchanged.
    #[test]
    fn encode_str_without_terms_returns_input() {
        let dict = SymbolDict::new();
        assert_eq!(dict.encode_str("nothing to replace"), "nothing to replace");
    }

    #[test]
    fn render_dict_header_empty() {
        let dict = SymbolDict::new();
        assert!(dict.render_dict_header().is_empty());
    }

    /// The header lists `symbol=term` lines in allocation order between the
    /// `<D>` and `</D>` markers.
    #[test]
    fn render_dict_header_format() {
        let mut dict = SymbolDict::new();
        let alpha = dict.intern("Alpha").unwrap();
        let beta = dict.intern("Beta").unwrap();
        let header = dict.render_dict_header();
        assert_eq!(
            header,
            format!("<D>\n{}=Alpha\n{}=Beta\n</D>\n", alpha, beta)
        );
    }

    /// Once the range is exhausted, interning a new term fails with the
    /// capacity in the error and leaves the dictionary untouched.
    #[test]
    fn overflow_returns_error() {
        let mut dict = SymbolDict::new();
        dict.next_code = PUA_END + 1;
        let result = dict.intern("overflow_term");
        assert_eq!(result, Err(SymbolOverflowError { max: MAX_SYMBOLS }));
        assert!(dict.is_empty(), "a failed intern must not add an entry");
    }

    /// The last code point of the range is still usable; only the one after
    /// it overflows.
    #[test]
    fn last_code_point_is_allocatable() {
        let mut dict = SymbolDict::new();
        dict.next_code = PUA_END;
        assert_eq!(dict.intern("last").unwrap() as u32, PUA_END);
        assert!(dict.intern("one_too_many").is_err());
        // A term that already has a symbol is still resolvable after overflow.
        assert_eq!(dict.intern("last").unwrap() as u32, PUA_END);
    }

    /// When one term is a prefix of another, the longer term wins where both
    /// match and the shorter one still matches on its own.
    #[test]
    fn encode_str_aho_corasick_no_partial_match() {
        let mut dict = SymbolDict::new();
        dict.intern("ab").unwrap();
        dict.intern("abc").unwrap();
        let sym_ab = *dict.encode.get("ab").unwrap();
        let sym_abc = *dict.encode.get("abc").unwrap();
        let encoded = dict.encode_str("abc");
        assert_eq!(
            encoded,
            sym_abc.to_string(),
            "LeftmostLongest: full 'abc' must be substituted, sym_ab={:?}",
            sym_ab
        );
        assert_eq!(dict.encode_str("ab"), sym_ab.to_string());
    }

    /// A term interned after the automaton was built must be picked up by the
    /// next `encode_str` call.
    #[test]
    fn encode_str_sees_terms_interned_after_first_use() {
        let mut dict = SymbolDict::new();
        let alpha = dict.intern("alpha").unwrap();
        assert_eq!(dict.encode_str("alpha beta"), format!("{} beta", alpha));
        let beta = dict.intern("beta").unwrap();
        assert_eq!(
            dict.encode_str("alpha beta"),
            format!("{} {}", alpha, beta)
        );
    }

    #[test]
    fn symbol_dict_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SymbolDict>();
    }
}