1use std::path::{Path, PathBuf};
26
27use crate::error::PiperError;
28
/// HTTPS location of the OpenJTalk UTF-8 dictionary archive (`.tar.gz`)
/// fetched by `download_archive`.
#[cfg(feature = "dict-download")]
const DICTIONARY_URL: &str =
    "https://github.com/r9y9/open_jtalk/releases/download/v1.11.1/open_jtalk_dic_utf_8-1.11.tar.gz";

/// Name of the dictionary directory looked up inside the data directory.
// NOTE(review): presumably this is also the top-level directory contained in
// the downloaded archive — confirm against the archive layout.
const DICTIONARY_DIR_NAME: &str = "open_jtalk_dic_utf_8-1.11";

/// Expected SHA-256 digest of the downloaded archive, as lowercase hex.
/// `verify_sha256` compares it against the `{:x}`-formatted digest.
#[cfg(feature = "dict-download")]
const DICTIONARY_SHA256: &str = "fe6ba0e43542cef98339abdffd903e062008ea170b04e7e2a35da805902f382a";

/// Marker file written into the dictionary directory by
/// `download_and_extract` once the archive has been unpacked.
#[cfg(feature = "dict-download")]
const SENTINEL_FILE: &str = ".piper_dict_ok";
50
51pub fn find_dictionary() -> Option<PathBuf> {
59 if let Ok(path) = std::env::var("OPENJTALK_DICTIONARY_PATH") {
61 let p = PathBuf::from(&path);
62 if is_valid_dictionary(&p) {
63 return Some(p);
64 }
65 }
66
67 if let Some(p) = exe_relative_dict_path()
69 && is_valid_dictionary(&p)
70 {
71 return Some(p);
72 }
73
74 for p in system_dict_paths() {
76 if is_valid_dictionary(&p) {
77 return Some(p);
78 }
79 }
80
81 let data_dict = get_data_dir().join(DICTIONARY_DIR_NAME);
83 if is_valid_dictionary(&data_dict) {
84 return Some(data_dict);
85 }
86
87 None
88}
89
90pub fn ensure_dictionary() -> Result<PathBuf, PiperError> {
102 if let Some(p) = find_dictionary() {
104 return Ok(p);
105 }
106
107 if is_offline_mode() {
109 return Err(PiperError::DictionaryLoad {
110 path: "OpenJTalk dictionary not found and PIPER_OFFLINE_MODE=1 is set".to_string(),
111 });
112 }
113
114 if !is_auto_download_enabled() {
115 return Err(PiperError::DictionaryLoad {
116 path: "OpenJTalk dictionary not found and PIPER_AUTO_DOWNLOAD_DICT=0 is set. \
117 Set OPENJTALK_DICTIONARY_PATH or enable auto-download"
118 .to_string(),
119 });
120 }
121
122 download_and_extract()
124}
125
/// Reads `key` from the environment as a path, treating an unset or empty
/// variable as absent. Values that are not valid Unicode are kept.
fn env_path(key: &str) -> Option<PathBuf> {
    std::env::var_os(key)
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
}

/// Directory in which downloaded dictionary data is stored.
///
/// Resolution order:
///
/// 1. `OPENJTALK_DATA_DIR`, used verbatim;
/// 2. on Windows: `%APPDATA%\piper`, falling back to `./data`;
/// 3. elsewhere: `$XDG_DATA_HOME/piper`, then `$HOME/.local/share/piper`,
///    falling back to `/tmp/piper`.
///
/// Empty variables are treated as unset so the result is never an empty or
/// accidentally relative path, and a relative `XDG_DATA_HOME` is ignored as
/// the XDG Base Directory specification requires.
fn get_data_dir() -> PathBuf {
    if let Some(explicit) = env_path("OPENJTALK_DATA_DIR") {
        return explicit;
    }

    #[cfg(target_os = "windows")]
    {
        env_path("APPDATA")
            .map(|appdata| appdata.join("piper"))
            .unwrap_or_else(|| PathBuf::from(".").join("data"))
    }

    #[cfg(not(target_os = "windows"))]
    {
        if let Some(xdg) = env_path("XDG_DATA_HOME").filter(|dir| dir.is_absolute()) {
            return xdg.join("piper");
        }
        if let Some(home) = env_path("HOME") {
            return home.join(".local").join("share").join("piper");
        }
        PathBuf::from("/tmp/piper")
    }
}
169
/// Dictionary location relative to the running executable:
/// `<exe dir>/../share/open_jtalk/dic`.
///
/// Returns `None` when the executable path cannot be determined or has fewer
/// than two parent components. The returned path is not checked for
/// existence.
fn exe_relative_dict_path() -> Option<PathBuf> {
    let exe = std::env::current_exe().ok()?;
    // First parent is the directory holding the binary, second is the
    // install prefix above it.
    let prefix = exe.parent()?.parent()?;

    let mut dict = prefix.to_path_buf();
    dict.extend(["share", "open_jtalk", "dic"]);
    Some(dict)
}
182
/// Fixed, platform-specific locations that are probed for a dictionary.
///
/// The list is returned in probing order; none of the paths is checked for
/// existence here.
fn system_dict_paths() -> Vec<PathBuf> {
    #[cfg(target_os = "windows")]
    const CANDIDATES: [&str; 2] = [
        r"C:\Program Files\open_jtalk\dic",
        r"C:\Program Files (x86)\open_jtalk\dic",
    ];

    #[cfg(not(target_os = "windows"))]
    const CANDIDATES: [&str; 3] = [
        "/usr/share/open_jtalk/dic",
        "/usr/local/share/open_jtalk/dic",
        "/opt/open_jtalk/dic",
    ];

    CANDIDATES.iter().map(PathBuf::from).collect()
}
202
/// Heuristic check that `path` is a dictionary directory.
///
/// Returns `true` when `path` is a directory that directly contains at least
/// one entry whose extension is `bin` or `dic`. Unreadable directories and
/// unreadable entries count as "no match"; subdirectories are not searched.
fn is_valid_dictionary(path: &Path) -> bool {
    if !path.is_dir() {
        return false;
    }

    match std::fs::read_dir(path) {
        Ok(entries) => entries.flatten().any(|entry| {
            entry
                .path()
                .extension()
                .is_some_and(|ext| ext == "bin" || ext == "dic")
        }),
        Err(_) => false,
    }
}
223
/// Whether offline mode is requested.
///
/// True only when `PIPER_OFFLINE_MODE` is set to exactly `"1"`; an unset
/// variable or any other value means "not offline".
fn is_offline_mode() -> bool {
    matches!(std::env::var("PIPER_OFFLINE_MODE").as_deref(), Ok("1"))
}
234
/// Whether the dictionary may be downloaded automatically.
///
/// Enabled by default; only `PIPER_AUTO_DOWNLOAD_DICT` set to exactly `"0"`
/// turns it off. Every other value, including an unset variable, leaves it
/// enabled.
fn is_auto_download_enabled() -> bool {
    !matches!(
        std::env::var("PIPER_AUTO_DOWNLOAD_DICT").as_deref(),
        Ok("0")
    )
}
243
/// Downloads, verifies and unpacks the dictionary into the data directory.
///
/// Steps: create the data directory, short-circuit if a previously completed
/// installation is present (valid directory plus sentinel file), download the
/// archive, check its SHA-256 digest, extract it, validate the result and
/// finally write the sentinel.
///
/// The archive is a temporary artifact and is removed on every exit path.
/// When extraction fails, the partially written dictionary directory is
/// removed as well, so that a later `find_dictionary` call cannot mistake it
/// for a complete installation.
///
/// # Errors
///
/// Returns `PiperError::DictionaryLoad` when the data directory cannot be
/// created or when the extracted directory is not a valid dictionary, and
/// passes through any error from `download_archive`, `verify_sha256` or
/// `extract_tar_gz`.
#[cfg(feature = "dict-download")]
fn download_and_extract() -> Result<PathBuf, PiperError> {
    let data_dir = get_data_dir();
    let dict_dir = data_dir.join(DICTIONARY_DIR_NAME);
    let archive_path = data_dir.join("open_jtalk_dic_utf_8-1.11.tar.gz");

    std::fs::create_dir_all(&data_dir).map_err(|e| PiperError::DictionaryLoad {
        path: format!(
            "failed to create data directory {}: {e}",
            data_dir.display()
        ),
    })?;

    // A valid directory carrying the sentinel is a finished installation.
    if is_valid_dictionary(&dict_dir) && dict_dir.join(SENTINEL_FILE).exists() {
        return Ok(dict_dir);
    }

    // Best-effort removal of the transient archive; a missing file is fine.
    let discard_archive = || {
        let _ = std::fs::remove_file(&archive_path);
    };

    eprintln!(
        "[piper] Downloading OpenJTalk dictionary from {}",
        DICTIONARY_URL
    );
    download_archive(&archive_path).map_err(|e| {
        discard_archive();
        e
    })?;

    eprintln!("[piper] Verifying SHA-256 checksum...");
    verify_sha256(&archive_path).map_err(|e| {
        discard_archive();
        e
    })?;

    eprintln!("[piper] Extracting dictionary to {}...", data_dir.display());
    let extracted = extract_tar_gz(&archive_path, &data_dir);
    discard_archive();
    if let Err(e) = extracted {
        // A half-unpacked directory may already contain `.bin`/`.dic` files
        // and would pass `is_valid_dictionary`; do not leave it behind.
        let _ = std::fs::remove_dir_all(&dict_dir);
        return Err(e);
    }

    if !is_valid_dictionary(&dict_dir) {
        return Err(PiperError::DictionaryLoad {
            path: format!(
                "extraction succeeded but dictionary not found at {}",
                dict_dir.display()
            ),
        });
    }

    // The sentinel is written only after validation. Failing to write it is
    // not fatal (the dictionary itself is usable), but it is reported.
    if let Err(e) = std::fs::write(dict_dir.join(SENTINEL_FILE), "ok") {
        eprintln!(
            "[piper] Warning: could not write {} in {}: {e}",
            SENTINEL_FILE,
            dict_dir.display()
        );
    }

    eprintln!("[piper] Dictionary ready: {}", dict_dir.display());
    Ok(dict_dir)
}
308
/// Streams the archive at `DICTIONARY_URL` into the file `dest`.
///
/// Uses a blocking HTTP client with a 30 s connect timeout and a 600 s
/// overall timeout. Progress is printed to stderr in steps of at least 10 %
/// whenever the server reports a non-zero content length.
///
/// # Errors
///
/// * `PiperError::Download` — the client cannot be built, the request fails,
///   the response status is not a success, or reading the body fails.
/// * `PiperError::DictionaryLoad` — `dest` cannot be created, written or
///   flushed.
#[cfg(feature = "dict-download")]
fn download_archive(dest: &Path) -> Result<(), PiperError> {
    use std::io::{Read as _, Write};
    use std::time::Duration;

    const BYTES_PER_MIB: f64 = 1_048_576.0;
    let as_mib = |bytes: u64| bytes as f64 / BYTES_PER_MIB;

    let http = reqwest::blocking::Client::builder()
        .connect_timeout(Duration::from_secs(30))
        .timeout(Duration::from_secs(600))
        .build()
        .map_err(|e| PiperError::Download(format!("HTTP client error: {e}")))?;

    let mut body = http
        .get(DICTIONARY_URL)
        .send()
        .map_err(|e| PiperError::Download(format!("dictionary download failed: {e}")))?;

    let status = body.status();
    if !status.is_success() {
        return Err(PiperError::Download(format!(
            "HTTP {} downloading dictionary from {}",
            status, DICTIONARY_URL
        )));
    }

    // A missing or zero length simply disables progress reporting.
    let expected_len = body.content_length().filter(|&len| len > 0);

    let out = std::fs::File::create(dest).map_err(|e| PiperError::DictionaryLoad {
        path: format!("failed to create {}: {e}", dest.display()),
    })?;
    let mut sink = std::io::BufWriter::with_capacity(256 * 1024, out);

    let mut chunk = [0u8; 64 * 1024];
    let mut received: u64 = 0;
    let mut reported_pct: u64 = 0;

    loop {
        let len = match body.read(&mut chunk) {
            Ok(0) => break,
            Ok(len) => len,
            Err(e) => return Err(PiperError::Download(format!("read error: {e}"))),
        };

        sink.write_all(&chunk[..len])
            .map_err(|e| PiperError::DictionaryLoad {
                path: format!("write error: {e}"),
            })?;
        received += len as u64;

        if let Some(total) = expected_len {
            let pct = (received * 100) / total;
            // Report only after at least ten more percentage points.
            if pct >= reported_pct + 10 {
                eprintln!(
                    "[piper] Downloaded {:.1} / {:.1} MB ({}%)",
                    as_mib(received),
                    as_mib(total),
                    pct
                );
                reported_pct = pct;
            }
        }
    }

    // Flush explicitly so that a failure surfaces as an error instead of
    // being swallowed when the writer is dropped.
    sink.flush().map_err(|e| PiperError::DictionaryLoad {
        path: format!("flush error: {e}"),
    })?;

    eprintln!("[piper] Download complete ({:.1} MB)", as_mib(received));

    Ok(())
}
385
/// Checks that the file at `path` hashes to `DICTIONARY_SHA256`.
///
/// The file is hashed in 64 KiB chunks. On a digest mismatch the file is
/// deleted (best effort) before the error is returned.
///
/// # Errors
///
/// Returns `PiperError::DictionaryLoad` when the file cannot be opened or
/// read, or when its digest differs from the expected one; the mismatch
/// message contains both the expected and the actual digest.
#[cfg(feature = "dict-download")]
fn verify_sha256(path: &Path) -> Result<(), PiperError> {
    use sha2::{Digest, Sha256};
    use std::io::Read as _;

    let mut archive = std::fs::File::open(path).map_err(|e| PiperError::DictionaryLoad {
        path: format!("failed to open {}: {e}", path.display()),
    })?;

    let mut digest = Sha256::new();
    let mut chunk = [0u8; 64 * 1024];
    loop {
        match archive.read(&mut chunk) {
            Ok(0) => break,
            Ok(len) => digest.update(&chunk[..len]),
            Err(e) => {
                return Err(PiperError::DictionaryLoad {
                    path: format!("read error during hash: {e}"),
                });
            }
        }
    }

    let actual = format!("{:x}", digest.finalize());
    if actual == DICTIONARY_SHA256 {
        return Ok(());
    }

    // A file with the wrong digest must not be reused by a later run.
    let _ = std::fs::remove_file(path);
    Err(PiperError::DictionaryLoad {
        path: format!(
            "SHA-256 mismatch for {}: expected {}, got {}",
            path.display(),
            DICTIONARY_SHA256,
            actual
        ),
    })
}
427
/// Unpacks the gzip-compressed tar archive at `archive_path` into `dest_dir`.
///
/// # Errors
///
/// Returns `PiperError::DictionaryLoad` when the archive cannot be opened or
/// when unpacking fails.
#[cfg(feature = "dict-download")]
fn extract_tar_gz(archive_path: &Path, dest_dir: &Path) -> Result<(), PiperError> {
    let tarball = std::fs::File::open(archive_path).map_err(|e| PiperError::DictionaryLoad {
        path: format!("failed to open archive {}: {e}", archive_path.display()),
    })?;

    // Decompression and untarring are streamed; the archive is never held in
    // memory as a whole.
    tar::Archive::new(flate2::read::GzDecoder::new(tarball))
        .unpack(dest_dir)
        .map_err(|e| PiperError::DictionaryLoad {
            path: format!(
                "failed to extract {} to {}: {e}",
                archive_path.display(),
                dest_dir.display()
            ),
        })
}
453
454#[cfg(not(feature = "dict-download"))]
456fn download_and_extract() -> Result<PathBuf, PiperError> {
457 Err(PiperError::DictionaryLoad {
458 path: "OpenJTalk dictionary not found. Auto-download requires the \
459 \"dict-download\" feature; rebuild with `--features dict-download` \
460 or set OPENJTALK_DICTIONARY_PATH"
461 .to_string(),
462 })
463}
464
/// Unit tests for dictionary discovery, the environment switches and (with
/// the `dict-download` feature) checksum verification and extraction.
///
/// Every test that reads or mutates process environment variables takes the
/// lock from `lock_env` first and changes variables only through `ScopedEnv`,
/// which restores the previous value even when the test panics.
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::sync::{Mutex, MutexGuard};

    /// Serializes all environment access performed by the tests below.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_MUTEX`, recovering from poisoning so that one failing
    /// test does not make every other environment test fail as well.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_MUTEX
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Overrides one environment variable for the lifetime of the value and
    /// restores the previous state on drop.
    ///
    /// Holding a reference to the environment lock guard makes it impossible
    /// to construct an override without the lock, or to outlive it.
    struct ScopedEnv<'lock> {
        key: &'static str,
        previous: Option<OsString>,
        _env_lock: &'lock MutexGuard<'static, ()>,
    }

    impl<'lock> ScopedEnv<'lock> {
        /// Sets `key` to `value` until the returned guard is dropped.
        fn set(
            env_lock: &'lock MutexGuard<'static, ()>,
            key: &'static str,
            value: impl AsRef<OsStr>,
        ) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: the caller holds `ENV_MUTEX`, and every test in this
            // module that touches the environment does so under that lock.
            unsafe {
                std::env::set_var(key, value);
            }
            Self {
                key,
                previous,
                _env_lock: env_lock,
            }
        }

        /// Removes `key` until the returned guard is dropped.
        fn unset(env_lock: &'lock MutexGuard<'static, ()>, key: &'static str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: see `ScopedEnv::set`.
            unsafe {
                std::env::remove_var(key);
            }
            Self {
                key,
                previous,
                _env_lock: env_lock,
            }
        }
    }

    impl Drop for ScopedEnv<'_> {
        fn drop(&mut self) {
            // SAFETY: the borrowed lock guard is still alive here, so the
            // environment is not accessed by other tests in this module.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    // ---- is_valid_dictionary -------------------------------------------

    #[test]
    fn test_is_valid_dictionary_nonexistent() {
        assert!(!is_valid_dictionary(Path::new("/nonexistent/path/12345")));
    }

    #[test]
    fn test_is_valid_dictionary_empty_dir() {
        let dir = tempfile::tempdir().unwrap();
        assert!(!is_valid_dictionary(dir.path()));
    }

    #[test]
    fn test_is_valid_dictionary_with_dic_file() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("sys.dic"), b"fake").unwrap();
        assert!(is_valid_dictionary(dir.path()));
    }

    #[test]
    fn test_is_valid_dictionary_with_bin_extension() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("matrix.bin"), b"fake").unwrap();
        assert!(is_valid_dictionary(dir.path()));
    }

    #[test]
    fn test_is_valid_dictionary_ignores_txt_files() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("readme.txt"), b"hello").unwrap();
        assert!(!is_valid_dictionary(dir.path()));
    }

    // ---- path helpers and constants ------------------------------------

    #[test]
    fn test_system_dict_paths_not_empty() {
        let paths = system_dict_paths();
        assert!(!paths.is_empty());
        for p in &paths {
            assert!(p.is_absolute(), "system path should be absolute: {p:?}");
        }
    }

    #[test]
    fn test_exe_relative_dict_path_returns_some() {
        let result = exe_relative_dict_path();
        assert!(result.is_some());
        let p = result.unwrap();
        assert!(p.ends_with("dic"));
    }

    #[test]
    fn test_constants_dir_name() {
        assert_eq!(DICTIONARY_DIR_NAME, "open_jtalk_dic_utf_8-1.11");
    }

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_constants_download() {
        assert!(DICTIONARY_URL.starts_with("https://"));
        assert!(DICTIONARY_URL.ends_with(".tar.gz"));
        assert!(DICTIONARY_URL.contains("open_jtalk_dic_utf_8"));
        assert_eq!(DICTIONARY_SHA256.len(), 64);
        assert!(DICTIONARY_SHA256.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_get_data_dir_returns_non_empty() {
        // `get_data_dir` reads the environment, so it must not run while
        // another test is modifying it.
        let _env = lock_env();
        let dir = get_data_dir();
        assert!(!dir.as_os_str().is_empty());
    }

    #[test]
    fn test_find_dictionary_returns_valid_or_none() {
        let _env = lock_env();
        if let Some(p) = find_dictionary() {
            assert!(
                is_valid_dictionary(&p),
                "find_dictionary returned invalid path: {p:?}"
            );
        }
    }

    // ---- verify_sha256 --------------------------------------------------

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_verify_sha256_bad_hash() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test_archive.tar.gz");
        std::fs::write(&path, b"not a real archive").unwrap();
        let result = verify_sha256(&path);
        assert!(result.is_err());
        let err = format!("{}", result.unwrap_err());
        assert!(err.contains("SHA-256 mismatch"));
        // The mismatching file is deleted by `verify_sha256`.
        assert!(!path.exists());
    }

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_verify_sha256_missing_file() {
        let result = verify_sha256(Path::new("/nonexistent/file.tar.gz"));
        assert!(result.is_err());
    }

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_verify_sha256_known_hash() {
        use sha2::{Digest, Sha256};
        let data = b"hello world";
        let expected = format!("{:x}", Sha256::digest(data));

        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("known_hash_test.bin");
        std::fs::write(&path, data).unwrap();

        // `verify_sha256` always compares against `DICTIONARY_SHA256`, so
        // this fixture must be rejected; the point of the test is that the
        // error message reports the digest that was actually computed.
        let result = verify_sha256(&path);
        assert!(result.is_err());
        let err = format!("{}", result.unwrap_err());
        assert!(
            err.contains(&expected),
            "error should contain actual hash: {err}"
        );
    }

    // ---- extract_tar_gz -------------------------------------------------

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_extract_tar_gz_valid() {
        use flate2::Compression;
        use flate2::write::GzEncoder;
        use std::io::Write;

        let dir = tempfile::tempdir().unwrap();
        let archive_path = dir.path().join("test.tar.gz");

        // Build a one-file archive; the scope closes the file before it is
        // read back.
        {
            let file = std::fs::File::create(&archive_path).unwrap();
            let encoder = GzEncoder::new(file, Compression::default());
            let mut builder = tar::Builder::new(encoder);

            let data = b"test dictionary content";
            let mut header = tar::Header::new_gnu();
            header.set_size(data.len() as u64);
            header.set_mode(0o644);
            header.set_cksum();
            builder
                .append_data(&mut header, "test_dict/sys.dic", &data[..])
                .unwrap();

            let mut gz = builder.into_inner().unwrap();
            gz.flush().unwrap();
            gz.finish().unwrap();
        }

        let extract_dir = dir.path().join("extracted");
        std::fs::create_dir_all(&extract_dir).unwrap();
        let result = extract_tar_gz(&archive_path, &extract_dir);
        assert!(result.is_ok(), "extraction failed: {result:?}");

        let extracted_file = extract_dir.join("test_dict").join("sys.dic");
        assert!(extracted_file.exists(), "extracted file should exist");
        let content = std::fs::read(&extracted_file).unwrap();
        assert_eq!(content, b"test dictionary content");
    }

    #[cfg(feature = "dict-download")]
    #[test]
    fn test_extract_tar_gz_invalid_archive() {
        let dir = tempfile::tempdir().unwrap();
        let archive_path = dir.path().join("bad.tar.gz");
        std::fs::write(&archive_path, b"not a tar.gz file").unwrap();

        let result = extract_tar_gz(&archive_path, dir.path());
        assert!(result.is_err());
    }

    // ---- ensure_dictionary ---------------------------------------------

    #[test]
    fn test_download_and_extract_stub() {
        let _env = lock_env();

        // With the download feature compiled in, `ensure_dictionary` would
        // otherwise start a real network download on machines without an
        // installed dictionary. Offline mode keeps the unit test hermetic;
        // without the feature the stub is still reached as before.
        #[cfg(feature = "dict-download")]
        let _offline = ScopedEnv::set(&_env, "PIPER_OFFLINE_MODE", "1");

        // Either outcome is acceptable; the call only has to return.
        let _ = ensure_dictionary();
    }

    // ---- OPENJTALK_DICTIONARY_PATH -------------------------------------

    #[test]
    fn test_find_dictionary_env_var_valid() {
        let env = lock_env();

        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("sys.dic"), b"test").unwrap();

        let _var = ScopedEnv::set(&env, "OPENJTALK_DICTIONARY_PATH", dir.path());
        let result = find_dictionary();

        assert_eq!(result, Some(dir.path().to_path_buf()));
    }

    #[test]
    fn test_find_dictionary_env_var_invalid_skipped() {
        let env = lock_env();

        let _var = ScopedEnv::set(&env, "OPENJTALK_DICTIONARY_PATH", "/nonexistent/path/dict");
        let result = find_dictionary();

        // Another location may legitimately be found; the invalid override
        // itself must never be returned.
        assert_ne!(
            result,
            Some(std::path::PathBuf::from("/nonexistent/path/dict"))
        );
    }

    // ---- PIPER_OFFLINE_MODE --------------------------------------------

    #[test]
    fn test_offline_mode_enabled() {
        let env = lock_env();

        let _var = ScopedEnv::set(&env, "PIPER_OFFLINE_MODE", "1");
        assert!(is_offline_mode());
    }

    #[test]
    fn test_offline_mode_disabled_by_default() {
        let env = lock_env();

        let _var = ScopedEnv::unset(&env, "PIPER_OFFLINE_MODE");
        assert!(!is_offline_mode());
    }

    #[test]
    fn test_offline_mode_other_values_not_offline() {
        let env = lock_env();

        for value in ["0", "true"] {
            let _var = ScopedEnv::set(&env, "PIPER_OFFLINE_MODE", value);
            assert!(!is_offline_mode(), "value {value:?} must not enable offline mode");
        }
    }

    // ---- PIPER_AUTO_DOWNLOAD_DICT --------------------------------------

    #[test]
    fn test_auto_download_enabled_by_default() {
        let env = lock_env();

        let _var = ScopedEnv::unset(&env, "PIPER_AUTO_DOWNLOAD_DICT");
        assert!(is_auto_download_enabled());
    }

    #[test]
    fn test_auto_download_disabled() {
        let env = lock_env();

        let _var = ScopedEnv::set(&env, "PIPER_AUTO_DOWNLOAD_DICT", "0");
        assert!(!is_auto_download_enabled());
    }

    #[test]
    fn test_auto_download_other_values_enabled() {
        let env = lock_env();

        for value in ["1", "false"] {
            let _var = ScopedEnv::set(&env, "PIPER_AUTO_DOWNLOAD_DICT", value);
            assert!(
                is_auto_download_enabled(),
                "value {value:?} must leave auto-download enabled"
            );
        }
    }

    // ---- OPENJTALK_DATA_DIR --------------------------------------------

    #[test]
    fn test_get_data_dir_env_override() {
        let env = lock_env();

        let dir = tempfile::tempdir().unwrap();
        let _var = ScopedEnv::set(&env, "OPENJTALK_DATA_DIR", dir.path());
        let result = get_data_dir();

        assert_eq!(result, dir.path().to_path_buf());
    }
}