use std::path::{Path, PathBuf};
use crate::error::PiperError;
/// Release archive of the UTF-8 OpenJTalk system dictionary (v1.11); only
/// fetched when the crate is built with the `dict-download` feature.
#[cfg(feature = "dict-download")]
const DICTIONARY_URL: &str =
"https://github.com/r9y9/open_jtalk/releases/download/v1.11.1/open_jtalk_dic_utf_8-1.11.tar.gz";
/// Directory name the dictionary is expected under: `find_dictionary` probes
/// `<data dir>/<this name>`, and the download path expects the archive to
/// unpack into a top-level directory with this name.
const DICTIONARY_DIR_NAME: &str = "open_jtalk_dic_utf_8-1.11";
/// Expected SHA-256 of the downloaded archive, as lowercase hex; compared
/// against `format!("{:x}", digest)` in `verify_sha256`.
#[cfg(feature = "dict-download")]
const DICTIONARY_SHA256: &str = "fe6ba0e43542cef98339abdffd903e062008ea170b04e7e2a35da805902f382a";
/// Marker file written inside the dictionary directory after extraction;
/// `download_and_extract` treats a valid dictionary carrying it as complete.
#[cfg(feature = "dict-download")]
const SENTINEL_FILE: &str = ".piper_dict_ok";
pub fn find_dictionary() -> Option<PathBuf> {
if let Ok(path) = std::env::var("OPENJTALK_DICTIONARY_PATH") {
let p = PathBuf::from(&path);
if is_valid_dictionary(&p) {
return Some(p);
}
}
if let Some(p) = exe_relative_dict_path()
&& is_valid_dictionary(&p)
{
return Some(p);
}
for p in system_dict_paths() {
if is_valid_dictionary(&p) {
return Some(p);
}
}
let data_dict = get_data_dir().join(DICTIONARY_DIR_NAME);
if is_valid_dictionary(&data_dict) {
return Some(data_dict);
}
None
}
/// Returns a usable OpenJTalk dictionary directory, downloading one if needed.
///
/// An installed dictionary (see [`find_dictionary`]) always wins. Otherwise a
/// download is attempted, unless `PIPER_OFFLINE_MODE=1` is set or
/// `PIPER_AUTO_DOWNLOAD_DICT=0` disables it. Offline mode is checked first.
///
/// # Errors
///
/// Returns [`PiperError::DictionaryLoad`] when nothing is installed and
/// downloading is forbidden. Otherwise propagates the result of the
/// download/extract step.
pub fn ensure_dictionary() -> Result<PathBuf, PiperError> {
    if let Some(installed) = find_dictionary() {
        return Ok(installed);
    }

    // Each switch that forbids downloading has its own explanation.
    let refusal = if is_offline_mode() {
        Some("OpenJTalk dictionary not found and PIPER_OFFLINE_MODE=1 is set")
    } else if !is_auto_download_enabled() {
        Some(
            "OpenJTalk dictionary not found and PIPER_AUTO_DOWNLOAD_DICT=0 is set. \
             Set OPENJTALK_DICTIONARY_PATH or enable auto-download",
        )
    } else {
        None
    };

    match refusal {
        Some(reason) => Err(PiperError::DictionaryLoad {
            path: reason.to_string(),
        }),
        None => download_and_extract(),
    }
}
/// Returns the per-user directory that holds an auto-downloaded dictionary.
///
/// Resolution order (an unset or empty variable is treated as absent):
///
/// 1. `OPENJTALK_DATA_DIR`, used verbatim.
/// 2. On Windows: `%APPDATA%\piper`, falling back to `.\data`.
/// 3. Elsewhere: `$XDG_DATA_HOME/piper`, used only when the value is absolute
///    as the XDG Base Directory spec requires. Otherwise
///    `$HOME/.local/share/piper`, falling back to `/tmp/piper`.
///
/// Previously an empty variable produced a *relative* path (e.g. `piper`),
/// i.e. a directory in whatever the current working directory happened to be.
fn get_data_dir() -> PathBuf {
    /// Reads `name` as a path; unset and empty values both yield `None`.
    /// `var_os` keeps non-UTF-8 paths usable.
    fn env_path(name: &str) -> Option<PathBuf> {
        std::env::var_os(name)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    }

    if let Some(dir) = env_path("OPENJTALK_DATA_DIR") {
        return dir;
    }
    #[cfg(target_os = "windows")]
    {
        env_path("APPDATA")
            .map(|appdata| appdata.join("piper"))
            .unwrap_or_else(|| PathBuf::from(".").join("data"))
    }
    #[cfg(not(target_os = "windows"))]
    {
        // XDG spec: relative paths in $XDG_* variables are invalid and must be ignored.
        if let Some(xdg) = env_path("XDG_DATA_HOME").filter(|p| p.is_absolute()) {
            return xdg.join("piper");
        }
        env_path("HOME")
            .map(|home| home.join(".local").join("share").join("piper"))
            .unwrap_or_else(|| PathBuf::from("/tmp/piper"))
    }
}
/// Computes `<exe dir>/../share/open_jtalk/dic` for the running executable.
///
/// Returns `None` if the executable path cannot be determined or has no
/// grandparent directory. The path is not checked for existence.
fn exe_relative_dict_path() -> Option<PathBuf> {
    let exe = std::env::current_exe().ok()?;
    let prefix = exe.parent()?.parent()?;
    Some(prefix.join("share").join("open_jtalk").join("dic"))
}
/// Well-known absolute install locations for the OpenJTalk dictionary on the
/// current platform, in probing order.
fn system_dict_paths() -> Vec<PathBuf> {
    #[cfg(target_os = "windows")]
    const CANDIDATES: &[&str] = &[
        r"C:\Program Files\open_jtalk\dic",
        r"C:\Program Files (x86)\open_jtalk\dic",
    ];
    #[cfg(not(target_os = "windows"))]
    const CANDIDATES: &[&str] = &[
        "/usr/share/open_jtalk/dic",
        "/usr/local/share/open_jtalk/dic",
        "/opt/open_jtalk/dic",
    ];

    CANDIDATES.iter().copied().map(PathBuf::from).collect()
}
/// Heuristic check that `path` looks like an OpenJTalk dictionary directory.
///
/// Returns `true` when `path` is a directory and at least one of its direct
/// entries has a `.bin` or `.dic` extension. The check is not recursive, and
/// unreadable directories and entries count as "not valid".
fn is_valid_dictionary(path: &Path) -> bool {
    if !path.is_dir() {
        return false;
    }
    let Ok(listing) = std::fs::read_dir(path) else {
        return false;
    };
    listing.flatten().any(|item| {
        matches!(item.path().extension(), Some(ext) if ext == "bin" || ext == "dic")
    })
}
/// Offline mode is on only when `PIPER_OFFLINE_MODE` is exactly `"1"`; any
/// other value, an unset variable, or a non-UTF-8 value means "online".
fn is_offline_mode() -> bool {
    matches!(std::env::var("PIPER_OFFLINE_MODE").as_deref(), Ok("1"))
}
/// Auto-download is on unless `PIPER_AUTO_DOWNLOAD_DICT` is exactly `"0"`;
/// an unset or non-UTF-8 variable leaves it enabled.
fn is_auto_download_enabled() -> bool {
    !matches!(std::env::var("PIPER_AUTO_DOWNLOAD_DICT").as_deref(), Ok("0"))
}
/// Downloads, verifies and extracts the OpenJTalk dictionary into the data
/// directory, returning the path of the extracted dictionary.
///
/// Failure cleanup:
/// - The archive is an intermediate artifact and is removed whether the
///   pipeline succeeds or fails, so a partial download is not left behind.
/// - If extraction fails and the dictionary directory did not exist before
///   this call, the partially-extracted directory is removed. Otherwise
///   [`find_dictionary`] would later accept a half-written dictionary, because
///   [`is_valid_dictionary`] only needs one `.dic`/`.bin` file. A directory
///   that already existed is left alone rather than risk deleting content this
///   call did not create.
///
/// The sentinel file is written only once the result is a valid dictionary.
///
/// # Errors
///
/// Returns the first error from creating the data directory, downloading,
/// checksum verification or extraction. Returns
/// `PiperError::DictionaryLoad` if extraction did not produce a valid
/// dictionary at the expected location.
#[cfg(feature = "dict-download")]
fn download_and_extract() -> Result<PathBuf, PiperError> {
    let data_dir = get_data_dir();
    let dict_dir = data_dir.join(DICTIONARY_DIR_NAME);
    let archive_path = data_dir.join("open_jtalk_dic_utf_8-1.11.tar.gz");
    std::fs::create_dir_all(&data_dir).map_err(|e| PiperError::DictionaryLoad {
        path: format!(
            "failed to create data directory {}: {e}",
            data_dir.display()
        ),
    })?;
    if is_valid_dictionary(&dict_dir) && dict_dir.join(SENTINEL_FILE).exists() {
        return Ok(dict_dir);
    }

    // Remember whether the target existed so failure cleanup only removes
    // what this call created.
    let dict_dir_preexisted = dict_dir.exists();
    let pipeline = fetch_verify_extract(&archive_path, &data_dir);

    // Best-effort: the archive is never needed again, on success or failure.
    if archive_path.exists() {
        let _ = std::fs::remove_file(&archive_path);
    }

    if let Err(err) = pipeline {
        if !dict_dir_preexisted && dict_dir.exists() {
            let _ = std::fs::remove_dir_all(&dict_dir);
        }
        return Err(err);
    }

    if !is_valid_dictionary(&dict_dir) {
        return Err(PiperError::DictionaryLoad {
            path: format!(
                "extraction succeeded but dictionary not found at {}",
                dict_dir.display()
            ),
        });
    }
    // Best-effort marker; a missing sentinel only means a later call
    // re-validates instead of short-circuiting.
    let _ = std::fs::write(dict_dir.join(SENTINEL_FILE), "ok");
    eprintln!("[piper] Dictionary ready: {}", dict_dir.display());
    Ok(dict_dir)
}

/// Downloads the archive to `archive_path`, checks its SHA-256, and unpacks it
/// into `data_dir`, logging each stage to stderr. Stops at the first failure.
#[cfg(feature = "dict-download")]
fn fetch_verify_extract(archive_path: &Path, data_dir: &Path) -> Result<(), PiperError> {
    eprintln!(
        "[piper] Downloading OpenJTalk dictionary from {}",
        DICTIONARY_URL
    );
    download_archive(archive_path)?;
    eprintln!("[piper] Verifying SHA-256 checksum...");
    verify_sha256(archive_path)?;
    eprintln!("[piper] Extracting dictionary to {}...", data_dir.display());
    extract_tar_gz(archive_path, data_dir)
}
/// Streams [`DICTIONARY_URL`] into the file at `dest`, creating or truncating it.
///
/// Uses a 30 s connect timeout and a 600 s overall timeout. When the server
/// reports a non-zero `Content-Length`, progress is logged to stderr roughly
/// every 10 percentage points.
///
/// # Errors
///
/// - `PiperError::Download` for client construction, request, non-2xx status
///   or body read failures.
/// - `PiperError::DictionaryLoad` for local file create/write/flush failures.
#[cfg(feature = "dict-download")]
fn download_archive(dest: &Path) -> Result<(), PiperError> {
    use std::io::{Read as _, Write};
    use std::time::Duration;

    const MIB: f64 = 1_048_576.0;

    let http = reqwest::blocking::Client::builder()
        .connect_timeout(Duration::from_secs(30))
        .timeout(Duration::from_secs(600))
        .build()
        .map_err(|e| PiperError::Download(format!("HTTP client error: {e}")))?;

    let mut resp = http
        .get(DICTIONARY_URL)
        .send()
        .map_err(|e| PiperError::Download(format!("dictionary download failed: {e}")))?;

    let status = resp.status();
    if !status.is_success() {
        return Err(PiperError::Download(format!(
            "HTTP {} downloading dictionary from {}",
            status, DICTIONARY_URL
        )));
    }

    // Progress output is only possible with a known, non-zero length.
    let expected_len = resp.content_length().filter(|&len| len > 0);

    let out_file = std::fs::File::create(dest).map_err(|e| PiperError::DictionaryLoad {
        path: format!("failed to create {}: {e}", dest.display()),
    })?;
    let mut sink = std::io::BufWriter::with_capacity(256 * 1024, out_file);
    let mut chunk = [0u8; 64 * 1024];
    let mut received: u64 = 0;
    let mut reported_pct: u64 = 0;

    loop {
        let len = resp
            .read(&mut chunk)
            .map_err(|e| PiperError::Download(format!("read error: {e}")))?;
        if len == 0 {
            break;
        }
        sink.write_all(&chunk[..len])
            .map_err(|e| PiperError::DictionaryLoad {
                path: format!("write error: {e}"),
            })?;
        received += len as u64;

        if let Some(total) = expected_len {
            let pct = received * 100 / total;
            if pct >= reported_pct + 10 {
                eprintln!(
                    "[piper] Downloaded {:.1} / {:.1} MB ({}%)",
                    received as f64 / MIB,
                    total as f64 / MIB,
                    pct
                );
                reported_pct = pct;
            }
        }
    }

    // Flush explicitly: BufWriter's Drop would swallow a late write error.
    sink.flush().map_err(|e| PiperError::DictionaryLoad {
        path: format!("flush error: {e}"),
    })?;
    eprintln!(
        "[piper] Download complete ({:.1} MB)",
        received as f64 / MIB
    );
    Ok(())
}
/// Checks that the file at `path` hashes to [`DICTIONARY_SHA256`].
///
/// The file is streamed through SHA-256 in 64 KiB chunks and the lowercase hex
/// digest is compared with the expected value. On a mismatch the file is
/// deleted (best effort) so a bad archive is not left behind.
///
/// # Errors
///
/// `PiperError::DictionaryLoad` if the file cannot be opened or read, or if
/// the digest does not match. The mismatch message includes both hashes.
#[cfg(feature = "dict-download")]
fn verify_sha256(path: &Path) -> Result<(), PiperError> {
    use sha2::{Digest, Sha256};
    use std::io::Read as _;

    let mut reader = std::fs::File::open(path).map_err(|e| PiperError::DictionaryLoad {
        path: format!("failed to open {}: {e}", path.display()),
    })?;
    let mut digest = Sha256::new();
    let mut block = [0u8; 64 * 1024];
    loop {
        let filled = reader
            .read(&mut block)
            .map_err(|e| PiperError::DictionaryLoad {
                path: format!("read error during hash: {e}"),
            })?;
        if filled == 0 {
            break;
        }
        digest.update(&block[..filled]);
    }

    let actual = format!("{:x}", digest.finalize());
    if actual == DICTIONARY_SHA256 {
        return Ok(());
    }

    let _ = std::fs::remove_file(path);
    Err(PiperError::DictionaryLoad {
        path: format!(
            "SHA-256 mismatch for {}: expected {}, got {}",
            path.display(),
            DICTIONARY_SHA256,
            actual
        ),
    })
}
/// Unpacks the gzip-compressed tarball at `archive_path` into `dest_dir`.
///
/// Entries keep their archive-relative paths. An entry `test_dict/sys.dic`
/// ends up at `dest_dir/test_dict/sys.dic`.
/// NOTE(review): `tar::Archive::unpack` is documented to skip entries that
/// would escape `dest_dir` (e.g. via `..`); confirm against the tar version in use.
///
/// # Errors
///
/// `PiperError::DictionaryLoad` if the archive cannot be opened, is not valid
/// gzip/tar data, or cannot be written under `dest_dir`.
#[cfg(feature = "dict-download")]
fn extract_tar_gz(archive_path: &Path, dest_dir: &Path) -> Result<(), PiperError> {
    let compressed =
        std::fs::File::open(archive_path).map_err(|e| PiperError::DictionaryLoad {
            path: format!("failed to open archive {}: {e}", archive_path.display()),
        })?;
    let mut tarball = tar::Archive::new(flate2::read::GzDecoder::new(compressed));
    tarball
        .unpack(dest_dir)
        .map_err(|e| PiperError::DictionaryLoad {
            path: format!(
                "failed to extract {} to {}: {e}",
                archive_path.display(),
                dest_dir.display()
            ),
        })
}
/// Fallback used when the crate is built without the `dict-download` feature.
///
/// There is no way to fetch the dictionary in this build, so this always fails
/// with instructions for enabling the feature or pointing at a local copy.
#[cfg(not(feature = "dict-download"))]
fn download_and_extract() -> Result<PathBuf, PiperError> {
    let hint = "OpenJTalk dictionary not found. Auto-download requires the \
                \"dict-download\" feature; rebuild with `--features dict-download` \
                or set OPENJTALK_DICTIONARY_PATH";
    Err(PiperError::DictionaryLoad {
        path: String::from(hint),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // ---------------------------------------------------------------------
    // Environment isolation helpers
    // ---------------------------------------------------------------------

    /// Serializes every test in this module that reads or writes the process
    /// environment.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_MUTEX`] and recovers from poisoning, so that one failing
    /// test does not cascade into `PoisonError` panics in every other
    /// env-dependent test.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_MUTEX.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Overrides one environment variable and restores its previous value (or
    /// absence) on drop, including when the test panics midway.
    ///
    /// Create only while holding [`env_lock`], and declare it *after* the lock
    /// guard so that it is dropped (restored) before the lock is released.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: `set_var` requires that no other thread accesses the
            // environment except through `std::env`. Env-touching tests here
            // serialize on ENV_MUTEX, and none of them reads the environment
            // through foreign code.
            unsafe { std::env::set_var(key, value) };
            Self { key, previous }
        }

        fn remove(key: &'static str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: same invariant as `EnvVarGuard::set`.
            unsafe { std::env::remove_var(key) };
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same invariant as `EnvVarGuard::set`; ENV_MUTEX is still
            // held because the guard is declared after the lock.
            unsafe {
                match &self.previous {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn test_env_var_guard_restores_previous_value() {
        let _lock = env_lock();
        const KEY: &str = "PIPER_TEST_ENV_GUARD_PROBE";
        {
            let _outer = EnvVarGuard::set(KEY, "outer");
            {
                let _inner = EnvVarGuard::set(KEY, "inner");
                assert_eq!(std::env::var(KEY).as_deref(), Ok("inner"));
            }
            assert_eq!(std::env::var(KEY).as_deref(), Ok("outer"));
        }
        assert!(std::env::var_os(KEY).is_none());
    }

    // ---------------------------------------------------------------------
    // is_valid_dictionary
    // ---------------------------------------------------------------------

    #[test]
    fn test_is_valid_dictionary_nonexistent() {
        assert!(!is_valid_dictionary(Path::new("/nonexistent/path/12345")));
    }

    #[test]
    fn test_is_valid_dictionary_empty_dir() {
        let dir = tempfile::tempdir().unwrap();
        assert!(!is_valid_dictionary(dir.path()));
    }

    #[test]
    fn test_is_valid_dictionary_with_dic_file() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("sys.dic"), b"fake").unwrap();
        assert!(is_valid_dictionary(dir.path()));
    }

    #[test]
    fn test_is_valid_dictionary_with_bin_extension() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("matrix.bin"), b"fake").unwrap();
        assert!(is_valid_dictionary(dir.path()));
    }

    #[test]
    fn test_is_valid_dictionary_ignores_txt_files() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("readme.txt"), b"hello").unwrap();
        assert!(!is_valid_dictionary(dir.path()));
    }

    #[test]
    fn test_is_valid_dictionary_rejects_plain_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("sys.dic");
        std::fs::write(&file, b"fake").unwrap();
        assert!(!is_valid_dictionary(&file));
    }

    // ---------------------------------------------------------------------
    // Candidate paths and constants
    // ---------------------------------------------------------------------

    #[test]
    fn test_system_dict_paths_not_empty() {
        let paths = system_dict_paths();
        assert!(!paths.is_empty());
        for p in &paths {
            assert!(p.is_absolute(), "system path should be absolute: {p:?}");
        }
    }

    #[test]
    fn test_exe_relative_dict_path_returns_some() {
        let result = exe_relative_dict_path();
        assert!(result.is_some());
        let p = result.unwrap();
        assert!(p.ends_with("dic"));
    }

    #[test]
    fn test_constants_dir_name() {
        assert_eq!(DICTIONARY_DIR_NAME, "open_jtalk_dic_utf_8-1.11");
    }

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_constants_download() {
        assert!(DICTIONARY_URL.starts_with("https://"));
        assert!(DICTIONARY_URL.ends_with(".tar.gz"));
        assert!(DICTIONARY_URL.contains("open_jtalk_dic_utf_8"));
        assert_eq!(DICTIONARY_SHA256.len(), 64);
        assert!(DICTIONARY_SHA256.chars().all(|c| c.is_ascii_hexdigit()));
    }

    // ---------------------------------------------------------------------
    // Env-dependent lookups (all hold ENV_MUTEX)
    // ---------------------------------------------------------------------

    #[test]
    fn test_get_data_dir_returns_non_empty() {
        // get_data_dir reads the environment, so it must not race env mutators.
        let _lock = env_lock();
        let dir = get_data_dir();
        assert!(!dir.as_os_str().is_empty());
    }

    #[test]
    fn test_get_data_dir_env_override() {
        let _lock = env_lock();
        let dir = tempfile::tempdir().unwrap();
        let _data = EnvVarGuard::set("OPENJTALK_DATA_DIR", dir.path());
        assert_eq!(get_data_dir(), dir.path().to_path_buf());
    }

    #[test]
    fn test_find_dictionary_returns_valid_or_none() {
        let _lock = env_lock();
        if let Some(p) = find_dictionary() {
            assert!(
                is_valid_dictionary(&p),
                "find_dictionary returned invalid path: {p:?}"
            );
        }
    }

    #[test]
    fn test_find_dictionary_env_var_valid() {
        let _lock = env_lock();
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("sys.dic"), b"test").unwrap();
        let _dict = EnvVarGuard::set("OPENJTALK_DICTIONARY_PATH", dir.path());
        assert_eq!(find_dictionary(), Some(dir.path().to_path_buf()));
    }

    #[test]
    fn test_find_dictionary_env_var_invalid_skipped() {
        let _lock = env_lock();
        let _dict = EnvVarGuard::set("OPENJTALK_DICTIONARY_PATH", "/nonexistent/path/dict");
        assert_ne!(
            find_dictionary(),
            Some(std::path::PathBuf::from("/nonexistent/path/dict"))
        );
    }

    #[test]
    fn test_offline_mode_enabled() {
        let _lock = env_lock();
        let _offline = EnvVarGuard::set("PIPER_OFFLINE_MODE", "1");
        assert!(is_offline_mode());
    }

    #[test]
    fn test_offline_mode_disabled_by_default() {
        let _lock = env_lock();
        let _offline = EnvVarGuard::remove("PIPER_OFFLINE_MODE");
        assert!(!is_offline_mode());
    }

    #[test]
    fn test_offline_mode_other_values_not_offline() {
        let _lock = env_lock();
        for value in ["0", "true", "yes", ""] {
            let _offline = EnvVarGuard::set("PIPER_OFFLINE_MODE", value);
            assert!(!is_offline_mode(), "value {value:?} must not enable offline mode");
        }
    }

    #[test]
    fn test_auto_download_enabled_by_default() {
        let _lock = env_lock();
        let _auto = EnvVarGuard::remove("PIPER_AUTO_DOWNLOAD_DICT");
        assert!(is_auto_download_enabled());
    }

    #[test]
    fn test_auto_download_disabled() {
        let _lock = env_lock();
        let _auto = EnvVarGuard::set("PIPER_AUTO_DOWNLOAD_DICT", "0");
        assert!(!is_auto_download_enabled());
    }

    #[test]
    fn test_auto_download_other_values_enabled() {
        let _lock = env_lock();
        for value in ["1", "false", "no", ""] {
            let _auto = EnvVarGuard::set("PIPER_AUTO_DOWNLOAD_DICT", value);
            assert!(is_auto_download_enabled(), "value {value:?} must keep auto-download on");
        }
    }

    // ---------------------------------------------------------------------
    // ensure_dictionary / download_and_extract (never touch the network)
    // ---------------------------------------------------------------------

    #[test]
    fn test_ensure_dictionary_offline_never_downloads() {
        let _lock = env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        let _data = EnvVarGuard::set("OPENJTALK_DATA_DIR", data_dir.path());
        let _dict = EnvVarGuard::remove("OPENJTALK_DICTIONARY_PATH");
        let _offline = EnvVarGuard::set("PIPER_OFFLINE_MODE", "1");
        match ensure_dictionary() {
            // A system-wide or exe-relative dictionary may legitimately exist.
            Ok(p) => assert!(is_valid_dictionary(&p), "invalid dictionary: {p:?}"),
            Err(e) => assert!(
                format!("{e}").contains("PIPER_OFFLINE_MODE"),
                "unexpected error: {e}"
            ),
        }
        // Nothing may have been downloaded into the data directory.
        assert!(std::fs::read_dir(data_dir.path()).unwrap().next().is_none());
    }

    #[test]
    fn test_ensure_dictionary_auto_download_disabled() {
        let _lock = env_lock();
        let data_dir = tempfile::tempdir().unwrap();
        let _data = EnvVarGuard::set("OPENJTALK_DATA_DIR", data_dir.path());
        let _dict = EnvVarGuard::remove("OPENJTALK_DICTIONARY_PATH");
        let _offline = EnvVarGuard::remove("PIPER_OFFLINE_MODE");
        let _auto = EnvVarGuard::set("PIPER_AUTO_DOWNLOAD_DICT", "0");
        match ensure_dictionary() {
            Ok(p) => assert!(is_valid_dictionary(&p), "invalid dictionary: {p:?}"),
            Err(e) => assert!(
                format!("{e}").contains("PIPER_AUTO_DOWNLOAD_DICT=0"),
                "unexpected error: {e}"
            ),
        }
        assert!(std::fs::read_dir(data_dir.path()).unwrap().next().is_none());
    }

    /// Exercises the feature-less stub directly. The old version went through
    /// `ensure_dictionary`, which with `dict-download` enabled could start a
    /// real network download, unsynchronized with the env-mutating tests.
    #[cfg(not(feature = "dict-download"))]
    #[test]
    fn test_download_and_extract_stub() {
        let err = download_and_extract().expect_err("the stub must always fail");
        let msg = format!("{err}");
        assert!(msg.contains("dict-download"), "stub error should name the feature: {msg}");
        assert!(msg.contains("OPENJTALK_DICTIONARY_PATH"), "stub error should mention the override: {msg}");
    }

    // ---------------------------------------------------------------------
    // verify_sha256 / extract_tar_gz
    // ---------------------------------------------------------------------

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_verify_sha256_bad_hash() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test_archive.tar.gz");
        std::fs::write(&path, b"not a real archive").unwrap();
        let result = verify_sha256(&path);
        assert!(result.is_err());
        let err = format!("{}", result.unwrap_err());
        assert!(err.contains("SHA-256 mismatch"));
        assert!(!path.exists());
    }

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_verify_sha256_missing_file() {
        let result = verify_sha256(Path::new("/nonexistent/file.tar.gz"));
        assert!(result.is_err());
    }

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_verify_sha256_known_hash() {
        use sha2::{Digest, Sha256};
        let data = b"hello world";
        let expected = format!("{:x}", Sha256::digest(data));
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("known_hash_test.bin");
        std::fs::write(&path, data).unwrap();
        let result = verify_sha256(&path);
        assert!(result.is_err());
        let err = format!("{}", result.unwrap_err());
        assert!(
            err.contains(&expected),
            "error should contain actual hash: {err}"
        );
    }

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_extract_tar_gz_valid() {
        use flate2::Compression;
        use flate2::write::GzEncoder;
        use std::io::Write;
        let dir = tempfile::tempdir().unwrap();
        let archive_path = dir.path().join("test.tar.gz");
        {
            let file = std::fs::File::create(&archive_path).unwrap();
            let encoder = GzEncoder::new(file, Compression::default());
            let mut builder = tar::Builder::new(encoder);
            let data = b"test dictionary content";
            let mut header = tar::Header::new_gnu();
            header.set_size(data.len() as u64);
            header.set_mode(0o644);
            header.set_cksum();
            builder
                .append_data(&mut header, "test_dict/sys.dic", &data[..])
                .unwrap();
            let mut gz = builder.into_inner().unwrap();
            gz.flush().unwrap();
            gz.finish().unwrap();
        }
        let extract_dir = dir.path().join("extracted");
        std::fs::create_dir_all(&extract_dir).unwrap();
        let result = extract_tar_gz(&archive_path, &extract_dir);
        assert!(result.is_ok(), "extraction failed: {result:?}");
        let extracted_file = extract_dir.join("test_dict").join("sys.dic");
        assert!(extracted_file.exists(), "extracted file should exist");
        let content = std::fs::read(&extracted_file).unwrap();
        assert_eq!(content, b"test dictionary content");
    }

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_extract_tar_gz_invalid_archive() {
        let dir = tempfile::tempdir().unwrap();
        let archive_path = dir.path().join("bad.tar.gz");
        std::fs::write(&archive_path, b"not a tar.gz file").unwrap();
        let result = extract_tar_gz(&archive_path, dir.path());
        assert!(result.is_err());
    }
}