use std::borrow::Cow;
use std::sync::OnceLock;
/// Arabic tatweel / kashida (U+0640): a purely typographic elongation stroke that carries no
/// lexical information, so it is dropped along with the diacritics when building skeletons.
const TATWEEL: char = '\u{640}';
/// Four-byte magic number expected at offset 0 of every dictionary file (ASCII `TRLD`).
const MAGIC: &[u8; 4] = &[b'T', b'R', b'L', b'D'];
use crate::limits::MAX_DICT_ENTRIES;
/// A byte range `[off, off + len)` inside a [`ContextDict`]'s backing buffer.
///
/// Spans are only created after the referenced bytes were bounds-checked and validated as
/// UTF-8, which keeps each index entry at 8 bytes instead of storing owned strings.
#[derive(Copy, Clone)]
struct Span {
    /// Start offset in bytes from the beginning of the buffer.
    off: u32,
    /// Length of the range in bytes.
    len: u32,
}

/// One unigram: an undiacritized skeleton and the form chosen for it.
#[derive(Copy, Clone)]
struct UniEntry {
    /// Skeleton key the unigram table is sorted and searched by.
    skel: Span,
    /// Diacritized form returned on a hit.
    form: Span,
}

/// One bigram entry inside a [`BiGroup`]: the current word's skeleton and its form.
#[derive(Copy, Clone)]
struct BiEntry {
    /// Skeleton of the current word; entries within a group are sorted by it.
    curr: Span,
    /// Diacritized form returned when both `prev` and `curr` match.
    form: Span,
}

/// All bigram entries sharing one previous-word skeleton.
///
/// The entries live contiguously in `ContextDict::bi_entries` at
/// `start..start + len`.
#[derive(Copy, Clone)]
struct BiGroup {
    /// Skeleton of the preceding word; groups are sorted by it.
    prev: Span,
    /// Index of the group's first entry in `bi_entries`.
    start: u32,
    /// Number of entries in the group.
    len: u32,
}

/// A parsed, read-only context dictionary.
///
/// `data` holds (or borrows) the raw file bytes; the three index vectors store [`Span`]s into
/// it, sorted so that lookups can binary-search them.
pub struct ContextDict {
    /// Raw dictionary bytes: borrowed for embedded data, owned otherwise.
    data: Cow<'static, [u8]>,
    /// Unigram table, sorted by skeleton bytes and de-duplicated.
    unigrams: Vec<UniEntry>,
    /// Bigram groups, sorted by previous-word skeleton bytes.
    bi_groups: Vec<BiGroup>,
    /// Flat storage for every group's entries.
    bi_entries: Vec<BiEntry>,
}
/// Reads a little-endian `u16` from `data` at byte offset `pos`.
///
/// # Errors
/// Returns `"dictionary offset overflow"` if `pos + 2` overflows `usize`, and
/// `"unexpected end of dictionary data"` if the two bytes are not fully inside `data`.
fn read_u16(data: &[u8], pos: usize) -> Result<u16, String> {
    let Some(end) = pos.checked_add(2) else {
        return Err("dictionary offset overflow".to_string());
    };
    // `get` yields exactly two bytes when in range, so the slice pattern always matches then.
    match data.get(pos..end) {
        Some(&[lo, hi]) => Ok(u16::from_le_bytes([lo, hi])),
        _ => Err("unexpected end of dictionary data".to_string()),
    }
}
/// Reads a little-endian `u32` from `data` at byte offset `pos`.
///
/// # Errors
/// Returns `"dictionary offset overflow"` if `pos + 4` overflows `usize`, and
/// `"unexpected end of dictionary data"` if the four bytes are not fully inside `data`.
fn read_u32(data: &[u8], pos: usize) -> Result<u32, String> {
    let end = match pos.checked_add(4) {
        Some(end) => end,
        None => return Err(String::from("dictionary offset overflow")),
    };
    let word: [u8; 4] = data
        .get(pos..end)
        .and_then(|chunk| <[u8; 4]>::try_from(chunk).ok())
        .ok_or_else(|| String::from("unexpected end of dictionary data"))?;
    Ok(u32::from_le_bytes(word))
}
/// Validates that `data[pos..pos + len]` exists and is UTF-8, and returns it as a [`Span`].
///
/// Because every string is checked here, spans can later be sliced and decoded without
/// re-validation.
///
/// # Errors
/// Overflow of `pos + len`, an out-of-range slice, invalid UTF-8 (the `Utf8Error` message), or
/// an offset/length that does not fit in `u32`, checked in that order.
fn read_str_span(data: &[u8], pos: usize, len: usize) -> Result<Span, String> {
    let end = pos
        .checked_add(len)
        .ok_or_else(|| "dictionary offset overflow".to_string())?;
    let Some(raw) = data.get(pos..end) else {
        return Err("unexpected end of dictionary data".to_string());
    };
    if let Err(err) = std::str::from_utf8(raw) {
        return Err(err.to_string());
    }
    let off = u32::try_from(pos).map_err(|_| "dictionary offset exceeds u32".to_string())?;
    let len = u32::try_from(len).map_err(|_| "string length exceeds u32".to_string())?;
    Ok(Span { off, len })
}
impl ContextDict {
/// Parses a dictionary from a borrowed byte slice, copying the bytes into an owned buffer.
///
/// # Errors
/// Returns a human-readable message if `data` is not a valid dictionary (see `build`).
#[cfg_attr(not(test), allow(dead_code))]
pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
Self::build(Cow::Owned(data.to_vec()))
}
/// Parses a dictionary that borrows a `'static` buffer (such as `include_bytes!` data)
/// without copying it.
///
/// # Errors
/// Returns a human-readable message if `data` is not a valid dictionary (see `build`).
#[cfg_attr(not(feature = "embed-dicts"), allow(dead_code))]
pub fn from_static(data: &'static [u8]) -> Result<Self, String> {
Self::build(Cow::Borrowed(data))
}
/// Parses a dictionary, taking ownership of `data` (e.g. bytes just read from a file).
///
/// # Errors
/// Returns a human-readable message if `data` is not a valid dictionary (see `build`).
pub fn from_owned(data: Vec<u8>) -> Result<Self, String> {
Self::build(Cow::Owned(data))
}
/// Validates the binary dictionary in `data` and builds the sorted lookup indexes.
///
/// Layout read below (all integers little-endian):
/// - header, 24 bytes: magic `TRLD`, version `u32` (must be 1), unigram count `u32`,
///   bigram count `u32`, unigram section offset `u32`, bigram section offset `u32`;
/// - unigram record: `u16` skeleton length + UTF-8 skeleton, `u16` form count, then per
///   form a `u16` length + UTF-8 form + `u32` frequency;
/// - bigram record: three `u16`-length-prefixed UTF-8 strings (prev skeleton, current
///   skeleton, form).
///
/// Every string is bounds-checked and UTF-8-validated via `read_str_span`, so the stored
/// spans can be sliced later without further checks. A count larger than the data
/// actually present fails as soon as a read runs past the end of the buffer.
///
/// # Errors
/// A message describing the first problem found (size, magic, version, header overlap,
/// truncation, offset overflow, invalid UTF-8, or u32 overflow of an index).
fn build(data: Cow<'static, [u8]>) -> Result<Self, String> {
// One bigram record as read from the file, before grouping by `prev`.
struct RawBi {
prev: Span,
curr: Span,
form: Span,
}
let bytes: &[u8] = &data;
// Anything shorter than the fixed 24-byte header cannot be a dictionary.
if bytes.len() < 24 {
return Err("Dictionary too small".into());
}
if &bytes[0..4] != MAGIC {
return Err("Invalid dictionary magic".into());
}
let version = read_u32(bytes, 4)?;
if version != 1 {
return Err(format!("Unsupported dictionary version: {version}"));
}
let unigram_count = read_u32(bytes, 8)? as usize;
let bigram_count = read_u32(bytes, 12)? as usize;
let unigram_offset = read_u32(bytes, 16)? as usize;
let bigram_offset = read_u32(bytes, 20)? as usize;
// Sections may not start inside the header; offsets past the end of the buffer are
// caught by the first read of each section below.
if unigram_offset < 24 || bigram_offset < 24 {
return Err("Dictionary section offset overlaps header".into());
}
// Only ever applied to spans returned by `read_str_span`, which already bounds-checked them.
let span_bytes = |s: Span| &bytes[s.off as usize..s.off as usize + s.len as usize];
// Counts come from the file, so cap the up-front reservation at MAX_DICT_ENTRIES; the
// vector still grows through `push` if more entries are actually present.
let mut unigrams: Vec<UniEntry> = Vec::with_capacity(unigram_count.min(MAX_DICT_ENTRIES));
// Each `pos += n` below cannot overflow: the read just before it already confirmed
// that `pos + n <= bytes.len()`.
let mut pos = unigram_offset;
for _ in 0..unigram_count {
let skel_len = read_u16(bytes, pos)? as usize;
pos += 2;
let skel = read_str_span(bytes, pos, skel_len)?;
pos += skel_len;
let num_forms = read_u16(bytes, pos)? as usize;
pos += 2;
// Only the first listed form is kept; the other forms and every frequency are still
// read (and validated) solely to advance `pos`. NOTE(review): this assumes the writer
// emits forms most-frequent first — confirm against the dictionary generator.
let mut best: Option<Span> = None;
for i in 0..num_forms {
let form_len = read_u16(bytes, pos)? as usize;
pos += 2;
let form = read_str_span(bytes, pos, form_len)?;
pos += form_len;
let _freq = read_u32(bytes, pos)?;
pos += 4;
if i == 0 {
best = Some(form);
}
}
// A skeleton listed with zero forms contributes no unigram entry.
if let Some(form) = best {
unigrams.push(UniEntry { skel, form });
}
}
// Same capped reservation as for unigrams.
let mut raw: Vec<RawBi> = Vec::with_capacity(bigram_count.min(MAX_DICT_ENTRIES));
pos = bigram_offset;
for _ in 0..bigram_count {
let prev_len = read_u16(bytes, pos)? as usize;
pos += 2;
let prev = read_str_span(bytes, pos, prev_len)?;
pos += prev_len;
let curr_len = read_u16(bytes, pos)? as usize;
pos += 2;
let curr = read_str_span(bytes, pos, curr_len)?;
pos += curr_len;
let form_len = read_u16(bytes, pos)? as usize;
pos += 2;
let form = read_str_span(bytes, pos, form_len)?;
pos += form_len;
raw.push(RawBi { prev, curr, form });
}
// `sort_by` is stable and `dedup_by` keeps the first of each run of equal keys, so for
// duplicate skeletons (or duplicate (prev, curr) pairs) the earliest record in the file wins.
unigrams.sort_by(|a, b| span_bytes(a.skel).cmp(span_bytes(b.skel)));
unigrams.dedup_by(|a, b| span_bytes(a.skel) == span_bytes(b.skel));
raw.sort_by(|a, b| {
span_bytes(a.prev)
.cmp(span_bytes(b.prev))
.then_with(|| span_bytes(a.curr).cmp(span_bytes(b.curr)))
});
raw.dedup_by(|a, b| {
span_bytes(a.prev) == span_bytes(b.prev) && span_bytes(a.curr) == span_bytes(b.curr)
});
// Collapse each run of bigrams sharing a `prev` skeleton into one `BiGroup`. The group's
// entries are a contiguous slice of `bi_entries`, already sorted by `curr`.
let mut bi_entries: Vec<BiEntry> = Vec::with_capacity(raw.len());
let mut bi_groups: Vec<BiGroup> = Vec::new();
let mut i = 0usize;
while i < raw.len() {
let prev = raw[i].prev;
let start = bi_entries.len();
let mut j = i;
while j < raw.len() && span_bytes(raw[j].prev) == span_bytes(prev) {
bi_entries.push(BiEntry {
curr: raw[j].curr,
form: raw[j].form,
});
j += 1;
}
bi_groups.push(BiGroup {
prev,
start: u32::try_from(start).map_err(|_| "bigram index exceeds u32".to_string())?,
len: u32::try_from(bi_entries.len() - start)
.map_err(|_| "bigram group exceeds u32".to_string())?,
});
i = j;
}
Ok(ContextDict {
data,
unigrams,
bi_groups,
bi_entries,
})
}
/// Returns the bytes covered by `span`; spans were bounds-checked in `build`.
#[inline]
fn span_slice(&self, span: Span) -> &[u8] {
&self.data[span.off as usize..span.off as usize + span.len as usize]
}
/// Returns `span` as text. Spans were UTF-8-validated in `build`, so the `""` fallback is
/// not expected to be reached.
#[inline]
fn span_str(&self, span: Span) -> &str {
std::str::from_utf8(self.span_slice(span)).unwrap_or("")
}
/// Looks up the diacritized form of `curr_skeleton`.
///
/// If `prev_skeleton` is given and the bigram `(prev_skeleton, curr_skeleton)` exists, its
/// form is returned. Otherwise the unigram entry for `curr_skeleton` is used. Returns
/// `None` when neither exists. Both lookups are binary searches over the byte-sorted
/// tables built in `build`.
pub fn resolve(&self, prev_skeleton: Option<&str>, curr_skeleton: &str) -> Option<&str> {
let curr = curr_skeleton.as_bytes();
if let Some(prev) = prev_skeleton {
let prev_bytes = prev.as_bytes();
if let Ok(gi) = self
.bi_groups
.binary_search_by(|g| self.span_slice(g.prev).cmp(prev_bytes))
{
let g = self.bi_groups[gi];
// The group's entries are contiguous and sorted by `curr`.
let entries = &self.bi_entries[g.start as usize..(g.start + g.len) as usize];
if let Ok(ei) = entries.binary_search_by(|e| self.span_slice(e.curr).cmp(curr)) {
return Some(self.span_str(entries[ei].form));
}
}
}
// No usable bigram: fall back to the context-free unigram.
if let Ok(ui) = self
.unigrams
.binary_search_by(|e| self.span_slice(e.skel).cmp(curr))
{
return Some(self.span_str(self.unigrams[ui].form));
}
None
}
/// Returns `(unigram count, bigram count)` after de-duplication.
#[cfg_attr(not(test), allow(dead_code))]
pub fn stats(&self) -> (usize, usize) {
(self.unigrams.len(), self.bi_entries.len())
}
}
/// Returns the skeleton of an Arabic-script word: `word` without harakat/superscript alef
/// (see `is_arabic_diacritic`) and without tatweel. All other characters are kept in order.
pub fn strip_arabic_diacritics(word: &str) -> String {
    let mut skeleton = String::with_capacity(word.len());
    for ch in word.chars() {
        if ch == TATWEEL || is_arabic_diacritic(ch) {
            continue;
        }
        skeleton.push(ch);
    }
    skeleton
}
/// Returns the skeleton of a Hebrew word: `word` with every niqqud point (see
/// `is_hebrew_niqqud`) removed and all other characters kept in order.
pub fn strip_hebrew_niqqud(word: &str) -> String {
    let mut skeleton = String::with_capacity(word.len());
    skeleton.extend(word.chars().filter(|&ch| !is_hebrew_niqqud(ch)));
    skeleton
}
/// Strips diacritics from `word` using the rules for `lang`.
///
/// `Some("he")` selects Hebrew niqqud stripping. Every other value, including `None`, uses
/// the Arabic-script rules (also used for Persian).
pub fn strip_diacritics(word: &str, lang: Option<&str>) -> String {
    if lang == Some("he") {
        strip_hebrew_niqqud(word)
    } else {
        strip_arabic_diacritics(word)
    }
}
/// Returns `true` if `c` lies in one of the Unicode blocks used for Arabic script.
///
/// Whole blocks are matched, so Arabic-block punctuation and digits count too. So does
/// U+FEFF (ZWNBSP/BOM), which is the last code point of Presentation Forms-B.
fn is_arabic_char(c: char) -> bool {
    const ARABIC_BLOCKS: [(u32, u32); 5] = [
        (0x0600, 0x06FF), // Arabic
        (0x0750, 0x077F), // Arabic Supplement
        (0x08A0, 0x08FF), // Arabic Extended-A
        (0xFB50, 0xFDFF), // Arabic Presentation Forms-A
        (0xFE70, 0xFEFF), // Arabic Presentation Forms-B
    ];
    let cp = u32::from(c);
    ARABIC_BLOCKS
        .iter()
        .any(|&(first, last)| (first..=last).contains(&cp))
}
/// Returns `true` if `c` is in the Hebrew block (U+0590..=U+05FF) or the Hebrew part of the
/// Alphabetic Presentation Forms block (U+FB1D..=U+FB4F).
fn is_hebrew_char(c: char) -> bool {
    let cp = c as u32;
    (0x0590..=0x05FF).contains(&cp) || (0xFB1D..=0xFB4F).contains(&cp)
}
/// Returns `true` for the Arabic marks removed when building a skeleton. These are U+064B
/// (fathatan) through U+0655 (hamza below), plus U+0670 (superscript alef).
#[inline]
fn is_arabic_diacritic(c: char) -> bool {
    let cp = u32::from(c);
    cp == 0x0670 || (0x064B..=0x0655).contains(&cp)
}
/// Returns `true` for Hebrew points (niqqud) that are removed when building a skeleton.
///
/// Covers U+05B0..=U+05C5 (vowel points, dagesh, meteg, rafe, shin/sin dots, upper/lower
/// dot) and U+05C7 QAMATS QATAN. Three punctuation marks inside that range are excluded:
/// U+05BE maqaf, U+05C0 paseq and U+05C3 sof pasuq. Cantillation marks (U+0591..=U+05AF)
/// are not included.
#[inline]
fn is_hebrew_niqqud(c: char) -> bool {
    match c as u32 {
        // Punctuation that sits inside the points range must survive stripping.
        0x05BE | 0x05C0 | 0x05C3 => false,
        // U+05C7 (qamats qatan) is a vowel point that sits just past the main range. Without
        // it, words carrying it never reduce to their dictionary skeleton.
        0x05B0..=0x05C5 | 0x05C7 => true,
        _ => false,
    }
}
/// Splits `text` into maximal alternating runs of word and non-word characters.
///
/// Word characters are Arabic or Hebrew script characters (whole Unicode blocks, so Arabic-
/// and Hebrew-block punctuation and digits are included), their diacritics, and tatweel.
/// Everything else (spaces, Latin text, ASCII punctuation) forms non-word runs. The tokens
/// borrow from `text` and concatenate back to it exactly. Empty input yields no tokens.
pub fn tokenize(text: &str) -> Vec<Token<'_>> {
    // The diacritic and tatweel checks are subsets of the script ranges; they are listed
    // explicitly so the definition of a word character is self-describing.
    #[inline]
    fn is_word_char(c: char) -> bool {
        is_arabic_char(c)
            || is_hebrew_char(c)
            || is_arabic_diacritic(c)
            || is_hebrew_niqqud(c)
            || c == TATWEEL
    }
    let mut tokens = Vec::new();
    let mut chars = text.char_indices();
    let Some((_, first)) = chars.next() else {
        return tokens;
    };
    let mut run_start = 0usize;
    let mut run_is_word = is_word_char(first);
    for (idx, ch) in chars {
        let ch_is_word = is_word_char(ch);
        if ch_is_word == run_is_word {
            continue;
        }
        tokens.push(Token {
            text: Cow::Borrowed(&text[run_start..idx]),
            is_word: run_is_word,
        });
        run_start = idx;
        run_is_word = ch_is_word;
    }
    // The final run always extends to the end of the input.
    tokens.push(Token {
        text: Cow::Borrowed(&text[run_start..]),
        is_word: run_is_word,
    });
    tokens
}
/// One run of text produced by [`tokenize`].
#[derive(Clone, Debug)]
pub struct Token<'a> {
    /// The run's text. `tokenize` always borrows it from the input.
    pub text: Cow<'a, str>,
    /// `true` for a run of Arabic/Hebrew word characters, `false` for everything else.
    pub is_word: bool,
}
/// Returns `true` if `text` contains a character that ends the bigram context. These are a
/// line break, ASCII `.`/`!`/`?`, the Arabic question mark (U+061F) or the Arabic full stop
/// (U+06D4).
fn is_context_boundary(text: &str) -> bool {
    const BOUNDARY_CHARS: [char; 7] = ['\n', '\r', '.', '!', '?', '\u{061F}', '\u{06D4}'];
    text.contains(&BOUNDARY_CHARS[..])
}
/// Transliterates `text`, first restoring each word's diacritics from `dict`.
///
/// Each word token is reduced to its skeleton with [`strip_diacritics`]. It is then looked up
/// with the previous word's skeleton as bigram context. On a hit, the dictionary form is passed
/// to `transliterate_fn`; otherwise the original word is. Non-word tokens are copied verbatim.
/// The context is cleared after any token that contains a sentence or line boundary
/// ([`is_context_boundary`]).
pub fn transliterate_context(
    text: &str,
    lang: Option<&str>,
    dict: &ContextDict,
    transliterate_fn: impl Fn(&str, Option<&str>) -> String,
) -> String {
    let tokens = tokenize(text);
    let mut result = String::with_capacity(text.len());
    let mut prev_skeleton: Option<String> = None;
    for token in &tokens {
        if !token.is_word {
            result.push_str(&token.text);
            if is_context_boundary(&token.text) {
                prev_skeleton = None;
            }
            continue;
        }
        let skeleton = strip_diacritics(&token.text, lang);
        let source: &str = dict
            .resolve(prev_skeleton.as_deref(), &skeleton)
            .unwrap_or(&*token.text);
        result.push_str(&transliterate_fn(source, lang));
        // The Arabic question mark (U+061F) and full stop (U+06D4) lie inside U+0600..=U+06FF.
        // `tokenize` therefore keeps them inside word tokens, and the non-word check above
        // never sees them. Check word tokens as well so those sentence ends reset the context.
        prev_skeleton = if is_context_boundary(&token.text) {
            None
        } else {
            Some(skeleton)
        };
    }
    result
}
/// Outcome of loading one language's dictionary, cached by the `get_*_dict` accessors.
pub enum DictState {
    /// A dictionary was read and parsed successfully.
    Ok(ContextDict),
    /// No dictionary file could be read from any search location.
    Absent,
    /// A dictionary was found but failed to parse; holds the error message.
    Corrupt(String),
}
/// Cached load result for the Arabic dictionary, filled on the first `get_arabic_dict` call.
static ARABIC_DICT: OnceLock<DictState> = OnceLock::new();
/// Cached load result for the Persian dictionary, filled on the first `get_persian_dict` call.
static PERSIAN_DICT: OnceLock<DictState> = OnceLock::new();
/// Cached load result for the Hebrew dictionary, filled on the first `get_hebrew_dict` call.
static HEBREW_DICT: OnceLock<DictState> = OnceLock::new();
// With the `embed-dicts` feature the dictionary files are compiled into the binary;
// `include_bytes!` resolves these paths relative to this source file.
#[cfg(feature = "embed-dicts")]
static ARABIC_DATA: &[u8] = include_bytes!("../data/arabic_dict.bin");
#[cfg(feature = "embed-dicts")]
static PERSIAN_DATA: &[u8] = include_bytes!("../data/persian_dict.bin");
#[cfg(feature = "embed-dicts")]
static HEBREW_DATA: &[u8] = include_bytes!("../data/hebrew_dict.bin");
/// Parses a dictionary compiled into the binary.
///
/// On failure a warning is emitted and [`DictState::Corrupt`] is returned. The stored message
/// names the dictionary, like the filesystem loader's does, because it is the text the public
/// `get_*_dict` accessors hand back to callers.
#[cfg(feature = "embed-dicts")]
fn load_embedded_dict(name: &str, data: &'static [u8]) -> DictState {
    match ContextDict::from_static(data) {
        Ok(dict) => DictState::Ok(dict),
        Err(e) => {
            let msg = format!("disarm: failed to load embedded {name} dict: {e}");
            crate::emit_warning_stderr(&msg);
            DictState::Corrupt(format!("embedded {name} dictionary is corrupt: {e}"))
        }
    }
}
/// Lists the candidate file paths for the `name` dictionary, in the order they are tried.
///
/// `$DISARM_DICT_DIR/{name}_dict.bin` is included only when that variable holds an absolute
/// path; a relative value is ignored with a warning. The build-time
/// `$CARGO_MANIFEST_DIR/data/{name}_dict.bin` is always the last candidate. No candidate
/// is ever relative to the current working directory.
#[cfg(not(feature = "embed-dicts"))]
fn dict_search_paths(name: &str) -> Vec<std::path::PathBuf> {
    use std::path::{Path, PathBuf};

    let mut candidates: Vec<PathBuf> = Vec::with_capacity(2);
    if let Some(raw_dir) = std::env::var_os("DISARM_DICT_DIR") {
        let override_dir = Path::new(&raw_dir);
        if !override_dir.is_absolute() {
            crate::emit_warning_stderr(&format!(
                "disarm: ignoring relative DISARM_DICT_DIR={:?}; an absolute path is \
                 required (security #61: no CWD-relative dictionary loading).",
                override_dir.display()
            ));
        } else {
            candidates.push(override_dir.join(format!("{name}_dict.bin")));
        }
    }
    let bundled = format!("{}/data/{name}_dict.bin", env!("CARGO_MANIFEST_DIR"));
    candidates.push(PathBuf::from(bundled));
    candidates
}
/// Loads the `name` dictionary from the first readable candidate in `dict_search_paths`.
///
/// Returns `DictState::Absent` when no candidate could be read. Returns
/// `DictState::Corrupt` as soon as a file is read but fails to parse; later candidates are
/// not tried in that case. Missing files are skipped silently. Any other read error (e.g.
/// permission denied on a configured `DISARM_DICT_DIR`) is reported on stderr before moving
/// on, so the fallback is no longer invisible.
#[cfg(not(feature = "embed-dicts"))]
fn load_dict_from_fs(name: &str) -> DictState {
    for path in dict_search_paths(name) {
        let data = match std::fs::read(&path) {
            Ok(data) => data,
            // An absent candidate is the normal case: try the next location.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => continue,
            // The file may exist but be unreadable: keep the best-effort fallback, but say so.
            Err(e) => {
                crate::emit_warning_stderr(&format!(
                    "disarm: could not read {name} dict from {}: {e}",
                    path.display()
                ));
                continue;
            }
        };
        match ContextDict::from_owned(data) {
            Ok(dict) => return DictState::Ok(dict),
            Err(e) => {
                crate::emit_warning_stderr(&format!(
                    "disarm: failed to load {name} dict from {}: {e}",
                    path.display()
                ));
                return DictState::Corrupt(format!(
                    "{name} dictionary at {} is corrupt: {e}",
                    path.display()
                ));
            }
        }
    }
    DictState::Absent
}
pub fn get_arabic_dict() -> Result<Option<&'static ContextDict>, &'static str> {
match ARABIC_DICT.get_or_init(|| {
#[cfg(feature = "embed-dicts")]
{
load_embedded_dict("arabic", ARABIC_DATA)
}
#[cfg(not(feature = "embed-dicts"))]
{
load_dict_from_fs("arabic")
}
}) {
DictState::Ok(d) => Ok(Some(d)),
DictState::Absent => Ok(None),
DictState::Corrupt(msg) => Err(msg.as_str()),
}
}
pub fn get_persian_dict() -> Result<Option<&'static ContextDict>, &'static str> {
match PERSIAN_DICT.get_or_init(|| {
#[cfg(feature = "embed-dicts")]
{
load_embedded_dict("persian", PERSIAN_DATA)
}
#[cfg(not(feature = "embed-dicts"))]
{
load_dict_from_fs("persian")
}
}) {
DictState::Ok(d) => Ok(Some(d)),
DictState::Absent => Ok(None),
DictState::Corrupt(msg) => Err(msg.as_str()),
}
}
pub fn get_hebrew_dict() -> Result<Option<&'static ContextDict>, &'static str> {
match HEBREW_DICT.get_or_init(|| {
#[cfg(feature = "embed-dicts")]
{
load_embedded_dict("hebrew", HEBREW_DATA)
}
#[cfg(not(feature = "embed-dicts"))]
{
load_dict_from_fs("hebrew")
}
}) {
DictState::Ok(d) => Ok(Some(d)),
DictState::Absent => Ok(None),
DictState::Corrupt(msg) => Err(msg.as_str()),
}
}
// Unit tests for skeleton stripping, tokenization, dictionary parsing/lookup and the
// dictionary search-path policy.
#[cfg(test)]
mod tests {
use super::*;
// Harakat and shadda are removed, leaving the bare letters.
#[test]
fn test_strip_arabic_diacritics() {
assert_eq!(strip_arabic_diacritics("كَتَبَ"), "كتب");
assert_eq!(strip_arabic_diacritics("دَرَّسَ"), "درس");
}
// Vowel points and the shin dot are removed from a pointed Hebrew word.
#[test]
fn test_strip_hebrew_niqqud() {
assert_eq!(strip_hebrew_niqqud("שָׁלוֹם"), "שלום");
}
// Tatweel (kashida) is dropped even though it is not a diacritic.
#[test]
fn test_strip_tatweel() {
assert_eq!(strip_arabic_diacritics("كـتـاب"), "كتاب");
}
// Two Arabic words separated by a space become word / non-word / word.
#[test]
fn test_tokenize_arabic() {
let tokens = tokenize("كتب العربية");
assert_eq!(tokens.len(), 3); assert!(tokens[0].is_word);
assert!(!tokens[1].is_word);
assert!(tokens[2].is_word);
}
// Latin text around an Arabic word is split into separate (non-word) runs.
#[test]
fn test_tokenize_mixed() {
let tokens = tokenize("hello كتب world");
assert!(tokens.len() >= 3);
}
// Serializes a version-1 dictionary in the format `ContextDict::build` reads: a 24-byte
// header (magic, version, counts, offsets) followed by the unigram and bigram sections,
// with every string prefixed by its `u16` byte length.
fn build_dict_bytes(
unigrams: &[(&str, &[(&str, u32)])],
bigrams: &[(&str, &str, &str)],
) -> Vec<u8> {
let mut uni = Vec::new();
for (skel, forms) in unigrams {
uni.extend_from_slice(&(skel.len() as u16).to_le_bytes());
uni.extend_from_slice(skel.as_bytes());
uni.extend_from_slice(&(forms.len() as u16).to_le_bytes());
for (form, freq) in *forms {
uni.extend_from_slice(&(form.len() as u16).to_le_bytes());
uni.extend_from_slice(form.as_bytes());
uni.extend_from_slice(&freq.to_le_bytes());
}
}
let mut bi = Vec::new();
for (prev, curr, form) in bigrams {
bi.extend_from_slice(&(prev.len() as u16).to_le_bytes());
bi.extend_from_slice(prev.as_bytes());
bi.extend_from_slice(&(curr.len() as u16).to_le_bytes());
bi.extend_from_slice(curr.as_bytes());
bi.extend_from_slice(&(form.len() as u16).to_le_bytes());
bi.extend_from_slice(form.as_bytes());
}
// The unigram section starts right after the header; bigrams follow it directly.
let unigram_offset = 24u32;
let bigram_offset = 24 + uni.len() as u32;
let mut data = Vec::new();
data.extend_from_slice(MAGIC);
data.extend_from_slice(&1u32.to_le_bytes()); data.extend_from_slice(&(unigrams.len() as u32).to_le_bytes());
data.extend_from_slice(&(bigrams.len() as u32).to_le_bytes());
data.extend_from_slice(&unigram_offset.to_le_bytes());
data.extend_from_slice(&bigram_offset.to_le_bytes());
data.extend_from_slice(&uni);
data.extend_from_slice(&bi);
data
}
// Without context the first unigram form wins; with the matching previous skeleton the
// bigram form wins; unknown words resolve to None.
#[test]
fn test_context_dict_resolve() {
let bytes = build_dict_bytes(
&[("كتب", &[("كَتَبَ", 100), ("كُتُب", 80)])],
&[("ال", "كتب", "كُتُب")],
);
let dict = ContextDict::from_bytes(&bytes).expect("valid dict should parse");
assert_eq!(dict.resolve(None, "كتب"), Some("كَتَبَ"));
assert_eq!(dict.resolve(Some("ال"), "كتب"), Some("كُتُب"));
assert_eq!(dict.resolve(None, "xyz"), None);
}
// A space keeps the bigram context alive across words, while a newline resets it so the
// unigram form is used instead.
#[test]
fn test_bigram_fires_across_space() {
let bytes = build_dict_bytes(
&[("كتب", &[("كَتَبَ", 100)])], &[("ال", "كتب", "كُتُب")], );
let dict = ContextDict::from_bytes(&bytes).expect("valid dict should parse");
let out = transliterate_context("ال كتب", None, &dict, |s, _| s.to_string());
assert!(
out.contains("كُتُب"),
"space must preserve bigram context: {out}"
);
assert!(
!out.contains("كَتَبَ"),
"must not fall back to the unigram: {out}"
);
let out2 = transliterate_context("ال\nكتب", None, &dict, |s, _| s.to_string());
assert!(
out2.contains("كَتَبَ"),
"newline must reset to the unigram: {out2}"
);
}
// Entries written out of order must still be found: `build` sorts both tables for binary
// search. An unknown `prev` or a missing bigram falls back to the unigram.
#[test]
fn test_resolve_many_entries_binary_search() {
let bytes = build_dict_bytes(
&[
("dog", &[("DOG", 9)]),
("ant", &[("ANT", 7)]),
("cat", &[("CAT-best", 5), ("CAT-alt", 4)]),
("bee", &[("BEE", 3)]),
],
&[
("the", "dog", "the-DOG"),
("a", "cat", "a-CAT"),
("the", "ant", "the-ANT"),
("the", "cat", "the-CAT"),
],
);
let dict = ContextDict::from_bytes(&bytes).expect("valid dict should parse");
assert_eq!(dict.resolve(None, "ant"), Some("ANT"));
assert_eq!(dict.resolve(None, "bee"), Some("BEE"));
assert_eq!(dict.resolve(None, "cat"), Some("CAT-best"));
assert_eq!(dict.resolve(None, "dog"), Some("DOG"));
assert_eq!(dict.resolve(None, "zzz"), None);
assert_eq!(dict.resolve(Some("the"), "dog"), Some("the-DOG"));
assert_eq!(dict.resolve(Some("the"), "ant"), Some("the-ANT"));
assert_eq!(dict.resolve(Some("the"), "cat"), Some("the-CAT"));
assert_eq!(dict.resolve(Some("a"), "cat"), Some("a-CAT"));
assert_eq!(dict.resolve(Some("the"), "bee"), Some("BEE"));
assert_eq!(dict.resolve(Some("nope"), "cat"), Some("CAT-best"));
assert_eq!(dict.stats(), (4, 4));
}
// Minimal well-formed dictionary (one unigram, one bigram) used by the parser tests below.
fn build_valid_dict() -> Vec<u8> {
build_dict_bytes(&[("ab", &[("AB", 5)])], &[("ab", "cd", "X")])
}
#[test]
fn test_from_bytes_valid_roundtrip() {
let dict = ContextDict::from_bytes(&build_valid_dict()).expect("valid dict should parse");
assert_eq!(dict.resolve(None, "ab"), Some("AB"));
assert_eq!(dict.resolve(Some("ab"), "cd"), Some("X"));
}
// Buffers shorter than the header, and a corrupted magic byte, are rejected.
#[test]
fn test_from_bytes_rejects_small_and_bad_magic() {
assert!(ContextDict::from_bytes(&[]).is_err());
assert!(ContextDict::from_bytes(&[0u8; 10]).is_err());
let mut bad = build_valid_dict();
bad[0] = b'X'; assert!(ContextDict::from_bytes(&bad).is_err());
}
// Every proper prefix of a valid dictionary must parse or fail cleanly, never panic.
#[test]
fn test_from_bytes_truncation_never_panics() {
let full = build_valid_dict();
for n in 0..full.len() {
let _ = ContextDict::from_bytes(&full[..n]); }
assert!(ContextDict::from_bytes(&full).is_ok());
}
// A header claiming u32::MAX unigrams with no data behind it must return an error (the
// first record read runs off the end) instead of panicking or pre-allocating that count.
#[test]
fn test_from_bytes_bogus_counts_do_not_panic() {
let mut data = Vec::new();
data.extend_from_slice(MAGIC);
data.extend_from_slice(&1u32.to_le_bytes()); data.extend_from_slice(&u32::MAX.to_le_bytes()); data.extend_from_slice(&0u32.to_le_bytes()); data.extend_from_slice(&24u32.to_le_bytes()); data.extend_from_slice(&24u32.to_le_bytes()); assert!(ContextDict::from_bytes(&data).is_err());
}
// A unigram offset pointing past the end of the buffer is an error, not a panic.
#[test]
fn test_from_bytes_offset_out_of_range() {
let mut data = build_valid_dict();
let bad_offset = (data.len() as u32 + 100).to_le_bytes();
data[16..20].copy_from_slice(&bad_offset);
assert!(ContextDict::from_bytes(&data).is_err());
}
// A section offset inside the 24-byte header is rejected up front.
#[test]
fn test_from_bytes_offset_inside_header_rejected() {
let mut data = build_valid_dict();
data[16..20].copy_from_slice(&8u32.to_le_bytes());
assert!(ContextDict::from_bytes(&data).is_err());
}
// Security regression test: every dictionary search path must be absolute, and the
// CWD-relative `data/` path must never be probed.
#[cfg(not(feature = "embed-dicts"))]
#[test]
fn test_dict_search_paths_never_cwd_relative() {
let paths = dict_search_paths("arabic");
let manifest = paths.last().expect("at least the manifest-dir candidate");
assert!(
manifest.is_absolute(),
"dev dict path must be absolute, got {manifest:?}"
);
let cwd_relative = std::path::Path::new("data/arabic_dict.bin");
assert!(
!paths.iter().any(|p| p == cwd_relative),
"must not probe the CWD-relative data/ path; got {paths:?}"
);
assert!(
paths.iter().all(|p| p.is_absolute()),
"all dict search paths must be absolute; got {paths:?}"
);
}
}