1use std::fs;
7use std::io::BufRead;
8use std::path::{Path, PathBuf};
9
10use crate::error::PiperError;
11
/// A single item of batch synthesis work: the text to speak and where to write the audio.
#[derive(Debug, Clone)]
pub struct BatchJob {
    /// Text to synthesize for this job.
    pub text: String,
    /// File the synthesized audio for this job is written to.
    pub output_path: PathBuf,
    /// Optional speaker id for multi-speaker models.
    /// NOTE(review): `None` presumably means "use the model default" — confirm in the runner.
    pub speaker_id: Option<i64>,
    /// Optional language code (e.g. `"en"`, `"fr"`) for this job.
    pub language: Option<String>,
}
20
/// Outcome of running one [`BatchJob`].
#[derive(Debug, Clone)]
pub struct BatchResult {
    /// Zero-based position of the job within its batch.
    pub job_index: usize,
    /// Where the job's audio was (or would have been) written.
    pub output_path: PathBuf,
    /// Duration of the produced audio, in seconds.
    pub audio_seconds: f64,
    /// Wall-clock time spent on inference for this job, in seconds.
    pub infer_seconds: f64,
    /// Whether the job completed successfully.
    pub success: bool,
    /// Human-readable failure description, if the job failed.
    pub error: Option<String>,
}
31
32#[derive(Debug, Clone)]
34pub struct BatchSummary {
35 pub total_jobs: usize,
36 pub successful: usize,
37 pub failed: usize,
38 pub total_audio_seconds: f64,
39 pub total_infer_seconds: f64,
40 pub results: Vec<BatchResult>,
41}
42
43impl BatchSummary {
44 pub fn from_results(results: Vec<BatchResult>) -> Self {
46 let total_jobs = results.len();
47 let successful = results.iter().filter(|r| r.success).count();
48 let failed = total_jobs - successful;
49 let total_audio_seconds: f64 = results.iter().map(|r| r.audio_seconds).sum();
50 let total_infer_seconds: f64 = results.iter().map(|r| r.infer_seconds).sum();
51
52 Self {
53 total_jobs,
54 successful,
55 failed,
56 total_audio_seconds,
57 total_infer_seconds,
58 results,
59 }
60 }
61
62 pub fn real_time_factor(&self) -> f64 {
65 if self.total_audio_seconds > 0.0 {
66 self.total_infer_seconds / self.total_audio_seconds
67 } else {
68 0.0
69 }
70 }
71
72 pub fn to_summary_string(&self) -> String {
74 format!(
75 "Batch complete: {}/{} succeeded, {} failed | audio {:.2}s, infer {:.2}s, RTF {:.3}",
76 self.successful,
77 self.total_jobs,
78 self.failed,
79 self.total_audio_seconds,
80 self.total_infer_seconds,
81 self.real_time_factor(),
82 )
83 }
84}
85
86pub type BatchProgressCallback = Box<dyn Fn(usize, usize, &BatchResult) + Send>;
90
/// Builds the automatic output path `<output_dir>/<prefix>_<NNN>.wav` for a zero-based `index`.
///
/// The number in the file name is `index + 1`, so numbering starts at 1. It is zero-padded to
/// at least three digits, and larger numbers simply widen (index 999 gives `<prefix>_1000.wav`).
pub fn auto_output_path(output_dir: &Path, index: usize, prefix: &str) -> PathBuf {
    let ordinal = index + 1;
    let file_name = format!("{}_{:03}.wav", prefix, ordinal);
    output_dir.join(file_name)
}
97
98pub fn jobs_from_text_file(
103 text_file: &Path,
104 output_dir: &Path,
105 speaker_id: Option<i64>,
106 language: Option<&str>,
107) -> Result<Vec<BatchJob>, PiperError> {
108 let content = fs::read_to_string(text_file)?;
109 let mut jobs = Vec::new();
110 let mut index = 0usize;
111
112 for line in content.lines() {
113 let text = line.trim().to_string();
114 if text.is_empty() {
115 continue;
116 }
117 jobs.push(BatchJob {
118 text,
119 output_path: auto_output_path(output_dir, index, "utt"),
120 speaker_id,
121 language: language.map(|s| s.to_string()),
122 });
123 index += 1;
124 }
125
126 Ok(jobs)
127}
128
129#[derive(serde::Deserialize)]
134struct BatchJsonlLine {
135 text: String,
136 #[serde(default)]
137 speaker_id: Option<i64>,
138 #[serde(default)]
139 language: Option<String>,
140 #[serde(default)]
141 output_file: Option<String>,
142}
143
144pub fn jobs_from_jsonl(jsonl_path: &Path, output_dir: &Path) -> Result<Vec<BatchJob>, PiperError> {
150 let file = fs::File::open(jsonl_path)?;
151 let reader = std::io::BufReader::new(file);
152 let mut jobs = Vec::new();
153 let mut auto_index = 0usize;
154
155 for (line_no, line_result) in reader.lines().enumerate() {
156 let line = line_result?;
157 let trimmed = line.trim();
158 if trimmed.is_empty() {
159 continue;
160 }
161
162 let parsed: BatchJsonlLine =
163 serde_json::from_str(trimmed).map_err(|e| PiperError::InvalidConfig {
164 reason: format!("JSONL line {}: {}", line_no + 1, e),
165 })?;
166
167 let output_path = if let Some(ref filename) = parsed.output_file {
168 output_dir.join(filename)
169 } else {
170 auto_output_path(output_dir, auto_index, "utt")
171 };
172
173 jobs.push(BatchJob {
174 text: parsed.text,
175 output_path,
176 speaker_id: parsed.speaker_id,
177 language: parsed.language,
178 });
179 auto_index += 1;
180 }
181
182 Ok(jobs)
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    // -- Fixtures --------------------------------------------------------------------------

    /// Builds a `BatchResult` whose `success` flag is derived from `error`:
    /// the job counts as successful exactly when no error message is given.
    fn make_result(
        job_index: usize,
        output_path: &str,
        audio_seconds: f64,
        infer_seconds: f64,
        error: Option<&str>,
    ) -> BatchResult {
        BatchResult {
            job_index,
            output_path: PathBuf::from(output_path),
            audio_seconds,
            infer_seconds,
            success: error.is_none(),
            error: error.map(str::to_string),
        }
    }

    /// Builds a `BatchSummary` from explicit totals with no per-job results attached.
    fn make_summary(
        total_jobs: usize,
        successful: usize,
        failed: usize,
        total_audio_seconds: f64,
        total_infer_seconds: f64,
    ) -> BatchSummary {
        BatchSummary {
            total_jobs,
            successful,
            failed,
            total_audio_seconds,
            total_infer_seconds,
            results: Vec::new(),
        }
    }

    /// Writes `content` to `file_name` inside a fresh temporary directory.
    ///
    /// The returned `TempDir` owns the directory. Keep it bound for the whole test so the
    /// directory is not deleted early.
    fn temp_input(file_name: &str, content: &str) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join(file_name);
        fs::write(&path, content).unwrap();
        (dir, path)
    }

    // -- auto_output_path ------------------------------------------------------------------

    #[test]
    fn test_auto_output_path_basic() {
        assert_eq!(
            auto_output_path(Path::new("/tmp/out"), 0, "utt"),
            PathBuf::from("/tmp/out/utt_001.wav")
        );
    }

    #[test]
    fn test_auto_output_path_double_digit() {
        assert_eq!(
            auto_output_path(Path::new("/tmp/out"), 9, "utt"),
            PathBuf::from("/tmp/out/utt_010.wav")
        );
    }

    #[test]
    fn test_auto_output_path_triple_digit() {
        assert_eq!(
            auto_output_path(Path::new("/tmp/out"), 99, "utt"),
            PathBuf::from("/tmp/out/utt_100.wav")
        );
    }

    #[test]
    fn test_auto_output_path_large_index() {
        // Padding is a minimum width, so four-digit numbers are not truncated.
        assert_eq!(
            auto_output_path(Path::new("/out"), 999, "batch"),
            PathBuf::from("/out/batch_1000.wav")
        );
    }

    #[test]
    fn test_auto_output_path_custom_prefix() {
        assert_eq!(
            auto_output_path(Path::new("/data"), 4, "chapter"),
            PathBuf::from("/data/chapter_005.wav")
        );
    }

    // -- BatchJob ----------------------------------------------------------------------------

    #[test]
    fn test_batch_job_construction() {
        let job = BatchJob {
            text: "Hello world".to_string(),
            output_path: PathBuf::from("/tmp/out.wav"),
            speaker_id: Some(3),
            language: Some("en".to_string()),
        };
        assert_eq!(job.text, "Hello world");
        assert_eq!(job.output_path, PathBuf::from("/tmp/out.wav"));
        assert_eq!(job.speaker_id, Some(3));
        assert_eq!(job.language.as_deref(), Some("en"));
    }

    #[test]
    fn test_batch_job_no_optional_fields() {
        let job = BatchJob {
            text: "Test".to_string(),
            output_path: PathBuf::from("/tmp/test.wav"),
            speaker_id: None,
            language: None,
        };
        assert!(job.speaker_id.is_none());
        assert!(job.language.is_none());
    }

    #[test]
    fn test_batch_job_clone() {
        let original = BatchJob {
            text: "Clone me".to_string(),
            output_path: PathBuf::from("/tmp/clone.wav"),
            speaker_id: Some(1),
            language: Some("ja".to_string()),
        };
        let copy = original.clone();
        assert_eq!(copy.text, original.text);
        assert_eq!(copy.output_path, original.output_path);
        assert_eq!(copy.speaker_id, original.speaker_id);
        assert_eq!(copy.language, original.language);
    }

    // -- BatchResult -------------------------------------------------------------------------

    #[test]
    fn test_batch_result_success() {
        let result = make_result(0, "/tmp/utt_001.wav", 2.5, 0.3, None);
        assert!(result.success);
        assert!(result.error.is_none());
        assert!((result.audio_seconds - 2.5).abs() < 1e-6);
    }

    #[test]
    fn test_batch_result_failure() {
        let result = make_result(5, "/tmp/utt_006.wav", 0.0, 0.0, Some("phonemization failed"));
        assert!(!result.success);
        assert_eq!(result.error.as_deref(), Some("phonemization failed"));
        assert_eq!(result.job_index, 5);
    }

    // -- BatchSummary ------------------------------------------------------------------------

    #[test]
    fn test_batch_summary_from_results() {
        let summary = BatchSummary::from_results(vec![
            make_result(0, "/tmp/utt_001.wav", 2.0, 0.4, None),
            make_result(1, "/tmp/utt_002.wav", 0.0, 0.0, Some("error")),
            make_result(2, "/tmp/utt_003.wav", 3.0, 0.6, None),
        ]);
        assert_eq!(summary.total_jobs, 3);
        assert_eq!(summary.successful, 2);
        assert_eq!(summary.failed, 1);
        assert!((summary.total_audio_seconds - 5.0).abs() < 1e-6);
        assert!((summary.total_infer_seconds - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_batch_summary_real_time_factor() {
        let summary = make_summary(2, 2, 0, 10.0, 2.0);
        assert!((summary.real_time_factor() - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_batch_summary_real_time_factor_zero_audio() {
        let summary = make_summary(1, 0, 1, 0.0, 0.1);
        assert!(summary.real_time_factor().abs() < 1e-6);
    }

    #[test]
    fn test_batch_summary_to_summary_string() {
        let s = make_summary(10, 8, 2, 25.0, 5.0).to_summary_string();
        for expected in [
            "8/10 succeeded",
            "2 failed",
            "audio 25.00s",
            "infer 5.00s",
            "RTF 0.200",
        ] {
            assert!(s.contains(expected), "got: {s}");
        }
    }

    #[test]
    fn test_batch_summary_empty() {
        let summary = BatchSummary::from_results(Vec::new());
        assert_eq!(summary.total_jobs, 0);
        assert_eq!(summary.successful, 0);
        assert_eq!(summary.failed, 0);
        assert!(summary.total_audio_seconds.abs() < 1e-6);
        let s = summary.to_summary_string();
        assert!(s.contains("0/0 succeeded"), "got: {s}");
    }

    // -- jobs_from_text_file -----------------------------------------------------------------

    #[test]
    fn test_jobs_from_text_file_basic() {
        let (dir, text_path) = temp_input("input.txt", "Hello world\nGoodbye world\n");

        let jobs = jobs_from_text_file(&text_path, dir.path(), Some(0), Some("en")).unwrap();
        assert_eq!(jobs.len(), 2);

        let first = &jobs[0];
        assert_eq!(first.text, "Hello world");
        assert_eq!(first.output_path, dir.path().join("utt_001.wav"));
        assert_eq!(first.speaker_id, Some(0));
        assert_eq!(first.language.as_deref(), Some("en"));

        let second = &jobs[1];
        assert_eq!(second.text, "Goodbye world");
        assert_eq!(second.output_path, dir.path().join("utt_002.wav"));
    }

    #[test]
    fn test_jobs_from_text_file_skips_empty_lines() {
        let (dir, text_path) = temp_input("input.txt", "Line one\n\n\nLine two\n\n");

        let jobs = jobs_from_text_file(&text_path, dir.path(), None, None).unwrap();
        assert_eq!(jobs.len(), 2);
        assert_eq!(jobs[0].text, "Line one");
        assert_eq!(jobs[1].text, "Line two");
        // Blank lines must not leave gaps in the numbering.
        assert_eq!(jobs[0].output_path, dir.path().join("utt_001.wav"));
        assert_eq!(jobs[1].output_path, dir.path().join("utt_002.wav"));
    }

    #[test]
    fn test_jobs_from_text_file_no_optional_fields() {
        let (dir, text_path) = temp_input("input.txt", "Single line\n");

        let jobs = jobs_from_text_file(&text_path, dir.path(), None, None).unwrap();
        assert_eq!(jobs.len(), 1);
        assert!(jobs[0].speaker_id.is_none());
        assert!(jobs[0].language.is_none());
    }

    #[test]
    fn test_jobs_from_text_file_nonexistent() {
        let outcome = jobs_from_text_file(
            Path::new("/nonexistent/file.txt"),
            Path::new("/tmp"),
            None,
            None,
        );
        assert!(outcome.is_err());
    }

    // -- jobs_from_jsonl ---------------------------------------------------------------------

    #[test]
    fn test_jobs_from_jsonl_basic() {
        let content = concat!(
            r#"{"text": "Hello"}"#,
            "\n",
            r#"{"text": "World", "speaker_id": 5}"#,
            "\n",
        );
        let (dir, jsonl_path) = temp_input("batch.jsonl", content);

        let jobs = jobs_from_jsonl(&jsonl_path, dir.path()).unwrap();
        assert_eq!(jobs.len(), 2);
        assert_eq!(jobs[0].text, "Hello");
        assert!(jobs[0].speaker_id.is_none());
        assert_eq!(jobs[0].output_path, dir.path().join("utt_001.wav"));
        assert_eq!(jobs[1].text, "World");
        assert_eq!(jobs[1].speaker_id, Some(5));
    }

    #[test]
    fn test_jobs_from_jsonl_with_output_file() {
        let (dir, jsonl_path) = temp_input(
            "batch.jsonl",
            r#"{"text": "Custom", "output_file": "custom_output.wav"}"#,
        );

        let jobs = jobs_from_jsonl(&jsonl_path, dir.path()).unwrap();
        assert_eq!(jobs.len(), 1);
        assert_eq!(jobs[0].output_path, dir.path().join("custom_output.wav"));
    }

    #[test]
    fn test_jobs_from_jsonl_with_language() {
        let (dir, jsonl_path) = temp_input(
            "batch.jsonl",
            r#"{"text": "Bonjour", "language": "fr", "speaker_id": 2}"#,
        );

        let jobs = jobs_from_jsonl(&jsonl_path, dir.path()).unwrap();
        assert_eq!(jobs.len(), 1);
        assert_eq!(jobs[0].language.as_deref(), Some("fr"));
        assert_eq!(jobs[0].speaker_id, Some(2));
    }

    #[test]
    fn test_jobs_from_jsonl_skips_empty_lines() {
        let (dir, jsonl_path) =
            temp_input("batch.jsonl", "{\"text\": \"A\"}\n\n{\"text\": \"B\"}\n");

        let jobs = jobs_from_jsonl(&jsonl_path, dir.path()).unwrap();
        assert_eq!(jobs.len(), 2);
        assert_eq!(jobs[0].text, "A");
        assert_eq!(jobs[1].text, "B");
    }

    #[test]
    fn test_jobs_from_jsonl_invalid_json() {
        let (dir, jsonl_path) = temp_input("bad.jsonl", "not valid json\n");

        assert!(jobs_from_jsonl(&jsonl_path, dir.path()).is_err());
    }

    // -- Edge cases --------------------------------------------------------------------------

    #[test]
    fn test_empty_text_job_fields() {
        // An empty-text job is representable, and so is a successful zero-length result for it.
        let job = BatchJob {
            text: String::new(),
            output_path: PathBuf::from("/tmp/empty.wav"),
            speaker_id: None,
            language: None,
        };
        assert!(job.text.is_empty());

        let result = BatchResult {
            job_index: 0,
            output_path: job.output_path.clone(),
            audio_seconds: 0.0,
            infer_seconds: 0.0,
            success: true,
            error: None,
        };
        assert!(result.success);
        assert!(result.audio_seconds.abs() < 1e-6);
    }

    #[test]
    fn test_text_file_all_empty_lines() {
        let (dir, text_path) = temp_input("empty.txt", "\n\n\n");

        let jobs = jobs_from_text_file(&text_path, dir.path(), None, None).unwrap();
        assert!(jobs.is_empty());
    }

    #[test]
    fn test_auto_output_path_four_digits() {
        assert_eq!(
            auto_output_path(Path::new("/out"), 1234, "utt"),
            PathBuf::from("/out/utt_1235.wav")
        );
    }

    #[test]
    fn test_batch_summary_all_success() {
        let summary = BatchSummary::from_results(vec![
            make_result(0, "a.wav", 1.0, 0.1, None),
            make_result(1, "b.wav", 2.0, 0.2, None),
        ]);
        assert_eq!(summary.successful, 2);
        assert_eq!(summary.failed, 0);
        assert!((summary.total_audio_seconds - 3.0).abs() < 1e-6);
        assert!((summary.total_infer_seconds - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_batch_summary_all_failure() {
        let summary = BatchSummary::from_results(vec![
            make_result(0, "a.wav", 0.0, 0.0, Some("err1")),
            make_result(1, "b.wav", 0.0, 0.0, Some("err2")),
        ]);
        assert_eq!(summary.successful, 0);
        assert_eq!(summary.failed, 2);
        assert!(summary.real_time_factor().abs() < 1e-6);
    }

    #[test]
    fn test_batch_result_clone() {
        let original = make_result(7, "/tmp/utt_008.wav", 1.5, 0.2, None);
        let copy = original.clone();
        assert_eq!(copy.job_index, original.job_index);
        assert_eq!(copy.output_path, original.output_path);
        assert!((copy.audio_seconds - original.audio_seconds).abs() < 1e-6);
    }

    #[test]
    fn test_jobs_from_text_file_trims_whitespace() {
        let (dir, text_path) = temp_input("spaces.txt", " hello \n world \n");

        let jobs = jobs_from_text_file(&text_path, dir.path(), None, None).unwrap();
        assert_eq!(jobs.len(), 2);
        assert_eq!(jobs[0].text, "hello");
        assert_eq!(jobs[1].text, "world");
    }

    #[test]
    fn test_jobs_from_jsonl_mixed_output_paths() {
        let dir = tempfile::tempdir().unwrap();
        let jsonl_path = dir.path().join("mixed.jsonl");
        {
            // Scoped so the file is closed (and flushed) before it is read back.
            let mut f = fs::File::create(&jsonl_path).unwrap();
            writeln!(f, r#"{{"text": "auto"}}"#).unwrap();
            writeln!(f, r#"{{"text": "custom", "output_file": "my.wav"}}"#).unwrap();
            writeln!(f, r#"{{"text": "auto2"}}"#).unwrap();
        }

        let jobs = jobs_from_jsonl(&jsonl_path, dir.path()).unwrap();
        assert_eq!(jobs.len(), 3);
        assert_eq!(jobs[0].output_path, dir.path().join("utt_001.wav"));
        assert_eq!(jobs[1].output_path, dir.path().join("my.wav"));
        // Auto numbering tracks the job position, so the custom job still consumes #2.
        assert_eq!(jobs[2].output_path, dir.path().join("utt_003.wav"));
    }

    #[test]
    fn test_jobs_from_jsonl_missing_text_field() {
        let (dir, jsonl_path) =
            temp_input("no_text.jsonl", r#"{"speaker_id": 1, "language": "en"}"#);

        let outcome = jobs_from_jsonl(&jsonl_path, dir.path());
        assert!(outcome.is_err(), "missing 'text' field should cause an error");

        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("text") || err_msg.contains("missing field"),
            "error should mention the missing field, got: {err_msg}"
        );
    }

    #[test]
    fn test_jobs_from_jsonl_invalid_speaker_id_type() {
        let (dir, jsonl_path) = temp_input(
            "bad_sid.jsonl",
            r#"{"text": "hello", "speaker_id": "not_a_number"}"#,
        );

        let outcome = jobs_from_jsonl(&jsonl_path, dir.path());
        assert!(
            outcome.is_err(),
            "speaker_id as string should cause a deserialization error"
        );
    }

    #[test]
    fn test_batch_summary_from_empty_results() {
        let summary = BatchSummary::from_results(Vec::new());
        assert_eq!(summary.total_jobs, 0);
        assert_eq!(summary.successful, 0);
        assert_eq!(summary.failed, 0);
        assert!((summary.total_audio_seconds - 0.0).abs() < 1e-9);
        assert!((summary.total_infer_seconds - 0.0).abs() < 1e-9);
        assert!((summary.real_time_factor() - 0.0).abs() < 1e-9);
        assert!(summary.results.is_empty());
    }

    #[test]
    fn test_real_time_factor_zero_audio_returns_zero() {
        // Inference time without any audio must not divide by zero.
        let summary = make_summary(5, 0, 5, 0.0, 42.0);
        assert_eq!(summary.real_time_factor(), 0.0);
    }

    #[test]
    fn test_auto_output_path_unicode_prefix() {
        let cases = [
            (Path::new("/tmp/out"), 0, "発話", "/tmp/out/発話_001.wav"),
            (Path::new("/tmp/out"), 9, "テスト", "/tmp/out/テスト_010.wav"),
            (Path::new("/data"), 2, "🔊audio", "/data/🔊audio_003.wav"),
        ];
        for (dir, index, prefix, expected) in cases {
            assert_eq!(auto_output_path(dir, index, prefix), PathBuf::from(expected));
        }
    }
}
772}