use alloc::{collections::BTreeMap, string::String, vec::Vec};
use serde::{Deserialize, Serialize};
use crate::PronunciationDict;
use crate::dictionary::entry::{DictEntry, Pronunciation, Region};
use crate::error::{Result, ShabdakoshError};
/// File signature that every shabdakosh binary dictionary starts with.
const MAGIC: [u8; 4] = *b"SHBD";
/// Format version byte written right after `MAGIC`; `from_binary` rejects any other value.
const VERSION: u8 = 1;
/// Header length: the magic bytes followed by the single version byte.
/// Derived from `MAGIC` so the two can never drift apart.
const HEADER_SIZE: usize = MAGIC.len() + 1;
/// Owned, decoded form of a binary dictionary payload; produced by `from_binary`.
///
/// postcard is not self-describing: fields are encoded positionally, so the
/// field order and types here ARE the wire format. They must stay in lockstep
/// with `BinaryDictRef` (the encoder's view), and changing them requires a
/// `VERSION` bump.
#[derive(Deserialize)]
struct BinaryDict {
    /// Base dictionary entries, keyed by word.
    entries: BTreeMap<String, Vec<BinaryPronunciation>>,
    /// User-overlay entries, keyed by word.
    user_entries: BTreeMap<String, Vec<BinaryPronunciation>>,
    /// Optional language tag of the dictionary.
    language: Option<String>,
}
/// Owned form of one pronunciation as decoded from the binary payload.
///
/// This type is only ever decoded (encoding goes through
/// `BinaryPronunciationRef`), so it derives `Deserialize` alone. Field order
/// and types are part of the positional postcard wire format and must match
/// `BinaryPronunciationRef`.
#[derive(Deserialize)]
struct BinaryPronunciation {
    /// Phoneme sequence of this pronunciation.
    phonemes: Vec<svara::phoneme::Phoneme>,
    /// Optional frequency weight, if one was recorded.
    frequency: Option<f32>,
    /// Optional regional tag, if one was recorded.
    region: Option<Region>,
}
/// Borrowed, serializable view of a `PronunciationDict`, used by `to_binary`
/// so encoding does not clone word strings or phoneme sequences.
///
/// Must encode byte-for-byte like `BinaryDict` decodes: postcard is
/// positional, so keep the field order and serialized types identical
/// (`&str` <-> `String`, `Option<&str>` <-> `Option<String>`).
#[derive(Serialize)]
struct BinaryDictRef<'a> {
    /// Base dictionary entries, keyed by borrowed word.
    entries: BTreeMap<&'a str, Vec<BinaryPronunciationRef<'a>>>,
    /// User-overlay entries, keyed by borrowed word.
    user_entries: BTreeMap<&'a str, Vec<BinaryPronunciationRef<'a>>>,
    /// Optional language tag of the dictionary.
    language: Option<&'a str>,
}
/// Borrowed counterpart of `BinaryPronunciation` used while encoding.
///
/// A `&[Phoneme]` serializes as a sequence, the same as the `Vec<Phoneme>` it
/// is decoded into, so this must keep the same field order as
/// `BinaryPronunciation`.
#[derive(Serialize)]
struct BinaryPronunciationRef<'a> {
    /// Borrowed phoneme sequence.
    phonemes: &'a [svara::phoneme::Phoneme],
    /// Optional frequency weight.
    frequency: Option<f32>,
    /// Optional regional tag.
    region: Option<Region>,
}
impl<'a> BinaryDictRef<'a> {
fn from_dict(dict: &'a PronunciationDict) -> Self {
Self {
entries: dict
.entries()
.iter()
.map(|(k, v)| (k.as_str(), convert_entry_ref(v)))
.collect(),
user_entries: dict
.user_entries()
.iter()
.map(|(k, v)| (k.as_str(), convert_entry_ref(v)))
.collect(),
language: dict.language(),
}
}
}
impl BinaryDict {
    /// Rebuilds a `PronunciationDict` from the decoded representation.
    ///
    /// The language tag is applied first, then base entries, then user-overlay
    /// entries. A word is skipped whenever `to_dict_entry` yields `None` for
    /// its pronunciation list.
    fn into_dict(self) -> PronunciationDict {
        let Self {
            entries,
            user_entries,
            language,
        } = self;

        let mut dict = PronunciationDict::new();
        if let Some(lang) = &language {
            dict.set_language(lang);
        }

        let base = entries
            .into_iter()
            .filter_map(|(word, prons)| to_dict_entry(prons).map(|entry| (word, entry)));
        for (word, entry) in base {
            dict.insert_entry(&word, entry);
        }

        let overlay = user_entries
            .into_iter()
            .filter_map(|(word, prons)| to_dict_entry(prons).map(|entry| (word, entry)));
        for (word, entry) in overlay {
            dict.insert_user_entry(&word, entry);
        }

        dict
    }
}
/// Borrows every pronunciation of `entry` into its serializable form,
/// preserving the order returned by `DictEntry::all`.
fn convert_entry_ref<'a>(entry: &'a DictEntry) -> Vec<BinaryPronunciationRef<'a>> {
    let mut refs = Vec::new();
    for pron in entry.all().iter() {
        refs.push(BinaryPronunciationRef {
            phonemes: pron.phonemes(),
            frequency: pron.frequency(),
            region: pron.region(),
        });
    }
    refs
}
fn to_dict_entry(prons: Vec<BinaryPronunciation>) -> Option<DictEntry> {
let pronunciations: Vec<Pronunciation> = prons
.into_iter()
.map(|bp| {
let mut p = Pronunciation::new(bp.phonemes);
if let Some(f) = bp.frequency {
p = p.with_frequency(f);
}
if let Some(r) = bp.region {
p = p.with_region(r);
}
p
})
.collect();
DictEntry::from_pronunciations(pronunciations)
}
/// Encodes `dict` into the shabdakosh binary format.
///
/// Layout: the 4-byte `MAGIC`, one `VERSION` byte, then the postcard payload
/// holding the base entries, user overlay and language tag.
///
/// # Errors
///
/// Returns `ShabdakoshError::DictParseError` if postcard fails to serialize
/// the dictionary.
#[must_use = "serialization result should be used"]
pub fn to_binary(dict: &PronunciationDict) -> Result<Vec<u8>> {
    let view = BinaryDictRef::from_dict(dict);
    let payload = match postcard::to_allocvec(&view) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(ShabdakoshError::DictParseError(alloc::format!(
                "binary serialize error: {e}"
            )))
        }
    };
    // Header (magic + version) immediately followed by the payload.
    Ok([&MAGIC[..], &[VERSION][..], payload.as_slice()].concat())
}
/// Decodes a dictionary previously produced by `to_binary`.
///
/// # Errors
///
/// Returns `ShabdakoshError::DictParseError` when `data`:
/// - is shorter than the header,
/// - does not start with `MAGIC`,
/// - carries a version byte other than `VERSION`,
/// - fails postcard decoding, or
/// - has bytes left over after the encoded payload.
#[must_use = "deserialization result should be used"]
pub fn from_binary(data: &[u8]) -> Result<PronunciationDict> {
    if data.len() < HEADER_SIZE {
        return Err(ShabdakoshError::DictParseError(
            "binary data too short for header".into(),
        ));
    }
    let (header, payload) = data.split_at(HEADER_SIZE);
    if header[..MAGIC.len()] != MAGIC {
        return Err(ShabdakoshError::DictParseError(
            "invalid magic number: not a shabdakosh binary dictionary".into(),
        ));
    }
    // The version byte sits directly after the magic bytes.
    let version = header[MAGIC.len()];
    if version != VERSION {
        return Err(ShabdakoshError::DictParseError(alloc::format!(
            "unsupported binary format version: {version} (expected {VERSION})"
        )));
    }
    // `postcard::from_bytes` silently ignores unconsumed input. `take_from_bytes`
    // returns the remainder so trailing garbage (corrupted or concatenated data)
    // can be rejected instead of accepted.
    let (intermediate, rest): (BinaryDict, &[u8]) =
        postcard::take_from_bytes(payload).map_err(|e| {
            ShabdakoshError::DictParseError(alloc::format!("binary deserialize error: {e}"))
        })?;
    if !rest.is_empty() {
        return Err(ShabdakoshError::DictParseError(alloc::format!(
            "binary deserialize error: {} unexpected trailing byte(s) after payload",
            rest.len()
        )));
    }
    Ok(intermediate.into_dict())
}
/// Encodes `dict` with `to_binary` and writes the bytes to `path` via
/// `std::fs::write`, which creates the file or truncates an existing one.
///
/// # Errors
///
/// Propagates errors from `to_binary`. Write failures are reported as
/// `ShabdakoshError::DictParseError`.
#[cfg(feature = "std")]
pub fn save_binary_file(dict: &PronunciationDict, path: &std::path::Path) -> Result<()> {
    let encoded = to_binary(dict)?;
    if let Err(io_err) = std::fs::write(path, encoded) {
        return Err(ShabdakoshError::DictParseError(alloc::format!(
            "failed to write binary file: {io_err}"
        )));
    }
    Ok(())
}
/// Reads the whole file at `path` and decodes it with `from_binary`.
///
/// # Errors
///
/// Read failures are reported as `ShabdakoshError::DictParseError`. Decoding
/// errors from `from_binary` are returned unchanged.
#[cfg(feature = "std")]
pub fn load_binary_file(path: &std::path::Path) -> Result<PronunciationDict> {
    match std::fs::read(path) {
        Ok(bytes) => from_binary(&bytes),
        Err(io_err) => Err(ShabdakoshError::DictParseError(alloc::format!(
            "failed to read binary file: {io_err}"
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binary_roundtrip_minimal() {
        let dict = PronunciationDict::english_minimal();
        let bytes = to_binary(&dict).unwrap();
        let dict2 = from_binary(&bytes).unwrap();
        assert_eq!(dict.len(), dict2.len());
        assert_eq!(dict.language(), dict2.language());
        assert_eq!(dict.lookup("hello"), dict2.lookup("hello"));
        assert_eq!(dict.lookup("the"), dict2.lookup("the"));
    }

    #[test]
    fn test_binary_roundtrip_with_user_overlay() {
        let mut dict = PronunciationDict::english_minimal();
        dict.insert_user("custom", &[svara::phoneme::Phoneme::VowelA]);
        let bytes = to_binary(&dict).unwrap();
        let dict2 = from_binary(&bytes).unwrap();
        assert_eq!(dict2.user_len(), 1);
        assert_eq!(dict2.lookup("custom"), dict.lookup("custom"));
    }

    #[test]
    fn test_binary_roundtrip_empty() {
        let dict = PronunciationDict::new();
        let bytes = to_binary(&dict).unwrap();
        let dict2 = from_binary(&bytes).unwrap();
        assert!(dict2.is_empty());
    }

    #[test]
    fn test_binary_has_header() {
        let dict = PronunciationDict::new();
        let bytes = to_binary(&dict).unwrap();
        assert!(bytes.len() >= HEADER_SIZE);
        assert_eq!(&bytes[..4], b"SHBD");
        assert_eq!(bytes[4], VERSION);
    }

    #[test]
    fn test_binary_reject_empty() {
        let result = from_binary(&[]);
        assert!(matches!(result, Err(ShabdakoshError::DictParseError(_))));
    }

    #[test]
    fn test_binary_reject_short_data() {
        let result = from_binary(&[0, 1, 2]);
        assert!(matches!(result, Err(ShabdakoshError::DictParseError(_))));
    }

    #[test]
    fn test_binary_reject_bad_magic() {
        let result = from_binary(&[0, 0, 0, 0, 1]);
        assert!(matches!(result, Err(ShabdakoshError::DictParseError(_))));
    }

    #[test]
    fn test_binary_reject_bad_version() {
        let mut bytes = to_binary(&PronunciationDict::new()).unwrap();
        bytes[4] = 99;
        let result = from_binary(&bytes);
        assert!(matches!(result, Err(ShabdakoshError::DictParseError(_))));
    }

    #[test]
    fn test_binary_reject_truncated_payload() {
        let bytes = to_binary(&PronunciationDict::english_minimal()).unwrap();
        assert!(bytes.len() > HEADER_SIZE);
        // Dropping the final byte leaves the postcard payload incomplete.
        let result = from_binary(&bytes[..bytes.len() - 1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_binary_smaller_than_json() {
        let dict = PronunciationDict::english_minimal();
        let binary = to_binary(&dict).unwrap();
        let json = serde_json::to_string(&dict).unwrap();
        assert!(
            binary.len() < json.len(),
            "binary ({}) should be smaller than JSON ({})",
            binary.len(),
            json.len()
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_binary_file_roundtrip() {
        let dict = PronunciationDict::english_minimal();
        // Per-process file name so concurrent test runs don't clobber each other.
        let tmp = std::env::temp_dir().join(alloc::format!(
            "shabdakosh_test_binary_{}.bin",
            std::process::id()
        ));
        save_binary_file(&dict, &tmp).unwrap();
        let loaded = load_binary_file(&tmp);
        // Clean up before asserting so a failing assertion doesn't leak the file.
        let _ = std::fs::remove_file(&tmp);
        let dict2 = loaded.unwrap();
        assert_eq!(dict.len(), dict2.len());
        assert_eq!(dict.lookup("hello"), dict2.lookup("hello"));
    }
}