1use crate::CaptionGenError;
5
/// Timing and confidence information for a single transcribed word.
///
/// Times are in milliseconds (as the `_ms` field names state). `split_long_segments`
/// compares `start_ms` directly against segment bounds, so word and segment times share
/// one time base.
#[derive(Debug, Clone, PartialEq)]
pub struct WordTimestamp {
    /// Text of the word; `align_to_frames` emits it verbatim into frame text.
    pub word: String,
    /// Start time in milliseconds; `align_to_frames` derives the word's frame from this.
    pub start_ms: u64,
    /// End time in milliseconds.
    pub end_ms: u64,
    /// Confidence score associated with this word.
    /// NOTE(review): how this differs from `word_confidence` is not visible here — confirm.
    pub confidence: f32,
    /// Per-word confidence; `WordTimestamp::is_high_quality` compares this against a threshold.
    pub word_confidence: f32,
}
19
20impl WordTimestamp {
21 pub fn with_word_confidence(
23 word: String,
24 start_ms: u64,
25 end_ms: u64,
26 confidence: f32,
27 word_confidence: f32,
28 ) -> Self {
29 Self {
30 word,
31 start_ms,
32 end_ms,
33 confidence,
34 word_confidence,
35 }
36 }
37
38 pub fn is_high_quality(&self, threshold: f32) -> bool {
40 self.word_confidence >= threshold
41 }
42}
43
/// A contiguous span of transcript text, optionally with per-word timings.
#[derive(Debug, Clone, PartialEq)]
pub struct TranscriptSegment {
    /// The segment's text.
    pub text: String,
    /// Start time in milliseconds.
    pub start_ms: u64,
    /// End time in milliseconds. `align_to_frames` rejects a segment with non-empty `text`
    /// whose `start_ms >= end_ms`.
    pub end_ms: u64,
    /// Speaker label, if known.
    /// NOTE(review): the numbering scheme for speakers is not defined here — confirm.
    pub speaker_id: Option<u8>,
    /// Word-level timings. May be empty; `align_to_frames` then falls back to segment-level
    /// timing.
    pub words: Vec<WordTimestamp>,
}
54
55impl TranscriptSegment {
56 pub fn duration_ms(&self) -> u64 {
58 self.end_ms.saturating_sub(self.start_ms)
59 }
60}
61
/// Errors describing why transcript timing could not be aligned.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum AlignmentError {
    /// A segment's duration cannot be reconciled with the video's duration.
    /// NOTE(review): not constructed in the code visible here — confirm where it is raised.
    #[error(
        "segment duration ({segment_ms}ms) is incompatible with video duration ({video_ms}ms)"
    )]
    IncompatibleDuration { segment_ms: u64, video_ms: u64 },

    /// There were no segments to align.
    /// NOTE(review): not constructed in the code visible here.
    #[error("transcript is empty — no segments to align")]
    EmptyTranscript,

    /// A segment's `start_ms` is not strictly before its `end_ms`. `align_to_frames` returns
    /// this for segments with non-empty text.
    #[error("invalid timestamp: start_ms ({start_ms}) >= end_ms ({end_ms})")]
    InvalidTimestamp { start_ms: u64, end_ms: u64 },
}
76
/// On-screen placement of a caption block.
#[derive(Debug, Clone, PartialEq)]
pub enum CaptionPosition {
    /// Bottom placement; `build_caption_blocks` assigns this to every block it creates.
    Bottom,
    /// Top placement.
    Top,
    /// Explicit placement.
    /// NOTE(review): presumably an (x, y) pair — units and normalisation are not defined here,
    /// confirm with the renderer.
    Custom(f32, f32),
}
87
/// A renderable caption: one or more text lines shown over a time range.
#[derive(Debug, Clone, PartialEq)]
pub struct CaptionBlock {
    /// Block identifier; `build_caption_blocks` assigns sequential ids starting at 1.
    pub id: u32,
    /// Display start in milliseconds.
    pub start_ms: u64,
    /// Display end in milliseconds.
    pub end_ms: u64,
    /// Caption text, one entry per rendered line.
    pub lines: Vec<String>,
    /// Speaker label carried over from the source segment, if any.
    pub speaker_id: Option<u8>,
    /// Where the block is placed on screen.
    pub position: CaptionPosition,
}
100
101impl CaptionBlock {
102 pub fn char_count(&self) -> usize {
104 self.lines.iter().map(|l| l.chars().count()).sum()
105 }
106
107 pub fn duration_ms(&self) -> u64 {
109 self.end_ms.saturating_sub(self.start_ms)
110 }
111}
112
113pub fn align_to_frames(
124 segment: &TranscriptSegment,
125 fps: f32,
126) -> Result<Vec<(u64, String)>, CaptionGenError> {
127 if fps <= 0.0 {
128 return Err(CaptionGenError::InvalidParameter(
129 "fps must be positive".to_string(),
130 ));
131 }
132 if segment.start_ms >= segment.end_ms && !segment.text.is_empty() {
133 return Err(CaptionGenError::Alignment(
134 AlignmentError::InvalidTimestamp {
135 start_ms: segment.start_ms,
136 end_ms: segment.end_ms,
137 },
138 ));
139 }
140
141 let ms_per_frame = 1000.0 / fps as f64;
142
143 if !segment.words.is_empty() {
145 let mut result: Vec<(u64, String)> = Vec::new();
146 for word in &segment.words {
147 let frame = (word.start_ms as f64 / ms_per_frame).floor() as u64;
148 if let Some(last) = result.last_mut() {
150 if last.0 == frame {
151 last.1.push(' ');
152 last.1.push_str(&word.word);
153 continue;
154 }
155 }
156 result.push((frame, word.word.clone()));
157 }
158 return Ok(result);
159 }
160
161 let start_frame = (segment.start_ms as f64 / ms_per_frame).floor() as u64;
163 Ok(vec![(start_frame, segment.text.clone())])
164}
165
166pub fn align_to_frames_batch(
176 segments: &[TranscriptSegment],
177 fps: f32,
178) -> Result<Vec<Vec<(u64, String)>>, CaptionGenError> {
179 if fps <= 0.0 {
180 return Err(CaptionGenError::InvalidParameter(
181 "fps must be positive".to_string(),
182 ));
183 }
184 segments
185 .iter()
186 .map(|seg| align_to_frames(seg, fps))
187 .collect()
188}
189
190pub fn merge_short_segments(
199 segments: &[TranscriptSegment],
200 min_duration_ms: u32,
201) -> Vec<TranscriptSegment> {
202 if segments.is_empty() {
203 return Vec::new();
204 }
205
206 let mut merged: Vec<TranscriptSegment> = segments.to_vec();
208 let min_ms = u64::from(min_duration_ms);
209
210 loop {
212 let mut changed = false;
213 let mut output: Vec<TranscriptSegment> = Vec::with_capacity(merged.len());
214
215 let mut i = 0;
216 while i < merged.len() {
217 let seg = merged[i].clone();
218 if seg.duration_ms() < min_ms && output.is_empty() && i + 1 < merged.len() {
219 let next = merged[i + 1].clone();
221 let combined = combine_segments(&seg, &next);
222 output.push(combined);
223 i += 2;
224 changed = true;
225 } else if seg.duration_ms() < min_ms {
226 if let Some(prev) = output.last_mut() {
228 let combined = combine_segments(prev, &seg);
229 *prev = combined;
230 changed = true;
231 } else {
232 output.push(seg);
233 }
234 i += 1;
235 } else {
236 output.push(seg);
237 i += 1;
238 }
239 }
240
241 merged = output;
242 if !changed {
243 break;
244 }
245 }
246
247 merged
248}
249
250fn combine_segments(a: &TranscriptSegment, b: &TranscriptSegment) -> TranscriptSegment {
253 let mut text = a.text.clone();
254 if !a.text.is_empty() && !b.text.is_empty() {
255 text.push(' ');
256 }
257 text.push_str(&b.text);
258
259 let mut words = a.words.clone();
260 words.extend_from_slice(&b.words);
261
262 TranscriptSegment {
263 text,
264 start_ms: a.start_ms.min(b.start_ms),
265 end_ms: a.end_ms.max(b.end_ms),
266 speaker_id: a.speaker_id,
267 words,
268 }
269}
270
271pub fn split_long_segments(
280 segment: &TranscriptSegment,
281 max_duration_ms: u32,
282 max_chars: u16,
283) -> Vec<TranscriptSegment> {
284 let max_dur = u64::from(max_duration_ms);
285 let max_ch = usize::from(max_chars);
286
287 let needs_split = segment.duration_ms() > max_dur || segment.text.chars().count() > max_ch;
288 if !needs_split {
289 return vec![segment.clone()];
290 }
291
292 let chunks = split_text_into_chunks(&segment.text, max_ch);
294 if chunks.len() <= 1 {
295 return vec![segment.clone()];
296 }
297
298 let total_chars: usize = chunks.iter().map(|c| c.chars().count()).sum();
300 let total_duration = segment.duration_ms();
301 let mut result = Vec::with_capacity(chunks.len());
302 let mut cursor_ms = segment.start_ms;
303
304 for (idx, chunk) in chunks.iter().enumerate() {
305 let chunk_chars = chunk.chars().count();
306 let chunk_duration = if idx + 1 < chunks.len() {
307 if total_chars > 0 {
308 (total_duration as f64 * chunk_chars as f64 / total_chars as f64).round() as u64
309 } else {
310 total_duration / chunks.len() as u64
311 }
312 } else {
313 segment.end_ms.saturating_sub(cursor_ms)
315 };
316
317 let start_ms = cursor_ms;
318 let end_ms = (cursor_ms + chunk_duration).min(segment.end_ms);
319
320 let sub_words: Vec<WordTimestamp> = segment
322 .words
323 .iter()
324 .filter(|w| w.start_ms >= start_ms && w.start_ms < end_ms)
325 .cloned()
326 .collect();
327
328 result.push(TranscriptSegment {
329 text: chunk.clone(),
330 start_ms,
331 end_ms,
332 speaker_id: segment.speaker_id,
333 words: sub_words,
334 });
335
336 cursor_ms = end_ms;
337 }
338
339 result
340}
341
342fn split_text_into_chunks(text: &str, max_chars: usize) -> Vec<String> {
345 if text.chars().count() <= max_chars {
346 return vec![text.to_string()];
347 }
348
349 let mut chunks: Vec<String> = Vec::new();
350 let mut remaining = text.trim();
351
352 while !remaining.is_empty() {
353 if remaining.chars().count() <= max_chars {
354 chunks.push(remaining.to_string());
355 break;
356 }
357
358 let window: String = remaining.chars().take(max_chars + 1).collect();
360 let cut = find_sentence_boundary(&window, max_chars)
361 .or_else(|| find_word_boundary(&window, max_chars))
362 .unwrap_or(max_chars);
363
364 let (chunk, rest) = split_at_char_index(remaining, cut);
365 chunks.push(chunk.trim().to_string());
366 remaining = rest.trim();
367 }
368
369 chunks
370}
371
/// Returns the cut position just after the last `.`, `!` or `?` among the first `max_chars`
/// chars of `text`, or `None` if there is no such punctuation.
///
/// The returned value is a char index (not a byte index) and is at most `max_chars`.
fn find_sentence_boundary(text: &str, max_chars: usize) -> Option<usize> {
    text.chars()
        .take(max_chars)
        .enumerate()
        .filter(|&(_, ch)| matches!(ch, '.' | '!' | '?'))
        .last()
        .map(|(idx, _)| idx + 1)
}
383
/// Returns the char index of the last space among the first `max_chars + 1` chars of `text`,
/// or `None` if there is none.
///
/// Cutting at the returned index leaves a first chunk of at most `max_chars` chars; the space
/// itself falls between the chunks and is trimmed away by the caller.
fn find_word_boundary(text: &str, max_chars: usize) -> Option<usize> {
    // Scan one char past the limit. A space exactly at index `max_chars` means the first
    // `max_chars` chars already end on a whole word. `split_text_into_chunks` builds its
    // `max_chars + 1` window for exactly this case, but a `take(max_chars)` scan never saw
    // that char and cut a word earlier than necessary.
    text.chars()
        .take(max_chars.saturating_add(1))
        .enumerate()
        .filter(|&(_, ch)| ch == ' ')
        .last()
        .map(|(idx, _)| idx)
}
395
/// Splits `text` before its `idx`-th char (a char index, not a byte offset).
///
/// If `idx` is at or past the end of `text`, the whole string is the first half and the
/// second half is empty. Never panics: the split point is always on a char boundary.
fn split_at_char_index(text: &str, idx: usize) -> (&str, &str) {
    let boundary = text
        .char_indices()
        .nth(idx)
        .map_or(text.len(), |(byte_offset, _)| byte_offset);
    text.split_at(boundary)
}
405
406pub fn build_caption_blocks(
413 segments: &[TranscriptSegment],
414 max_lines: u8,
415 max_chars_per_line: u8,
416) -> Vec<CaptionBlock> {
417 use crate::line_breaking::greedy_break;
418
419 let max_l = max_lines.max(1) as usize;
420 let max_c = max_chars_per_line.max(1);
421
422 segments
423 .iter()
424 .enumerate()
425 .map(|(idx, seg)| {
426 let all_lines = greedy_break(&seg.text, max_c);
427 let lines = if all_lines.len() <= max_l {
429 all_lines
430 } else {
431 let mut truncated = all_lines[..max_l - 1].to_vec();
432 let overflow = all_lines[max_l - 1..].join(" ");
433 truncated.push(overflow);
434 truncated
435 };
436
437 CaptionBlock {
438 id: (idx as u32) + 1,
439 start_ms: seg.start_ms,
440 end_ms: seg.end_ms,
441 lines,
442 speaker_id: seg.speaker_id,
443 position: CaptionPosition::Bottom,
444 }
445 })
446 .collect()
447}
448
#[cfg(test)]
mod tests {
    //! Unit tests for frame alignment, segment merging/splitting, text chunking and
    //! caption-block construction.

    use super::*;

    /// Builds a segment with the given text and span, no speaker and no word timings.
    fn make_seg(text: &str, start_ms: u64, end_ms: u64) -> TranscriptSegment {
        TranscriptSegment {
            text: text.to_string(),
            start_ms,
            end_ms,
            speaker_id: None,
            words: Vec::new(),
        }
    }

    /// Builds a word timing with both confidence scores set to 1.0.
    fn make_word(word: &str, start_ms: u64, end_ms: u64) -> WordTimestamp {
        WordTimestamp {
            word: word.to_string(),
            start_ms,
            end_ms,
            confidence: 1.0,
            word_confidence: 1.0,
        }
    }

    // ---- align_to_frames ----

    #[test]
    fn align_to_frames_segment_level() {
        // No word timings: a single entry at the segment's start frame.
        let seg = make_seg("Hello world", 0, 2000);
        let frames = align_to_frames(&seg, 25.0).expect("align to frames should succeed");
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].0, 0);
        assert_eq!(frames[0].1, "Hello world");
    }

    #[test]
    fn align_to_frames_word_level() {
        // At 25 fps, 1000 ms is frame 25.
        let mut seg = make_seg("Hello world", 0, 2000);
        seg.words = vec![make_word("Hello", 0, 1000), make_word("world", 1000, 2000)];
        let frames = align_to_frames(&seg, 25.0).expect("align to frames should succeed");
        assert_eq!(frames[0].0, 0);
        assert_eq!(frames[1].0, 25);
    }

    #[test]
    fn align_to_frames_rejects_zero_fps() {
        let seg = make_seg("test", 0, 1000);
        assert!(align_to_frames(&seg, 0.0).is_err());
    }

    #[test]
    fn align_to_frames_rejects_negative_fps() {
        let seg = make_seg("test", 0, 1000);
        assert!(align_to_frames(&seg, -30.0).is_err());
    }

    #[test]
    fn align_to_frames_same_start_frame_merges_words() {
        // "there" starts at 20 ms, still inside frame 0 at 25 fps (40 ms per frame).
        let mut seg = make_seg("Hi", 0, 500);
        seg.words = vec![make_word("Hi", 0, 200), make_word("there", 20, 300)];
        let frames = align_to_frames(&seg, 25.0).expect("align to frames should succeed");
        assert_eq!(frames.len(), 1);
        assert!(frames[0].1.contains("Hi"));
        assert!(frames[0].1.contains("there"));
    }

    #[test]
    fn align_to_frames_correct_frame_numbers_at_30fps() {
        let mut seg = make_seg("A B C", 0, 3000);
        seg.words = vec![
            make_word("A", 0, 1000),
            make_word("B", 1000, 2000),
            make_word("C", 2000, 3000),
        ];
        let frames = align_to_frames(&seg, 30.0).expect("align");
        assert_eq!(frames[0].0, 0);
        // Either value is accepted to absorb floating-point rounding in the frame computation.
        assert!(frames[1].0 == 29 || frames[1].0 == 30);
        assert!(frames[2].0 == 59 || frames[2].0 == 60);
    }

    // ---- merge_short_segments ----

    #[test]
    fn merge_short_segments_empty() {
        assert!(merge_short_segments(&[], 500).is_empty());
    }

    #[test]
    fn merge_short_segments_no_op_if_all_long_enough() {
        let segs = vec![make_seg("hello", 0, 1000), make_seg("world", 1000, 2000)];
        let result = merge_short_segments(&segs, 500);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn merge_short_segments_merges_short_prefix() {
        // A short first segment is merged forward into the following one.
        let segs = vec![
            make_seg("Hi", 0, 100),
            make_seg("world", 100, 1200),
        ];
        let result = merge_short_segments(&segs, 500);
        assert_eq!(result.len(), 1);
        assert!(result[0].text.contains("Hi"));
        assert!(result[0].text.contains("world"));
    }

    #[test]
    fn merge_short_segments_merges_short_suffix() {
        // A short trailing segment is merged backward into the preceding one.
        let segs = vec![
            make_seg("Hello there", 0, 1000),
            make_seg("ok", 1000, 1050),
        ];
        let result = merge_short_segments(&segs, 500);
        assert_eq!(result.len(), 1);
        assert!(result[0].text.contains("Hello"));
        assert!(result[0].text.contains("ok"));
    }

    #[test]
    fn merge_short_segments_span_extends() {
        // The merged segment spans from the earliest start to the latest end.
        let segs = vec![
            make_seg("A", 0, 100),
            make_seg("long segment here", 100, 2000),
        ];
        let result = merge_short_segments(&segs, 500);
        assert_eq!(result[0].start_ms, 0);
        assert_eq!(result[0].end_ms, 2000);
    }

    // ---- split_long_segments ----

    #[test]
    fn split_long_segments_no_op_if_short() {
        let seg = make_seg("Hello", 0, 1000);
        let result = split_long_segments(&seg, 5000, 200);
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn split_long_segments_by_duration() {
        // The 47-char text also exceeds the 20-char limit, so chunking is text-driven here.
        let seg = make_seg("This is a longer sentence for testing purposes.", 0, 20000);
        let result = split_long_segments(&seg, 5000, 20);
        assert!(result.len() > 1, "expected multiple segments");
        for s in &result {
            assert!(s.duration_ms() <= 20000);
        }
    }

    #[test]
    fn split_long_segments_preserves_total_duration() {
        // The sub-segments must start at the original start and end at the original end.
        let seg = make_seg("Word one. Word two. Word three. Word four.", 0, 10000);
        let result = split_long_segments(&seg, 3000, 20);
        let first_start = result.first().map(|s| s.start_ms).unwrap_or(0);
        let last_end = result.last().map(|s| s.end_ms).unwrap_or(0);
        assert_eq!(first_start, 0);
        assert_eq!(last_end, 10000);
    }

    #[test]
    fn split_long_segments_respects_max_chars() {
        let seg = make_seg(
            "This is a very long text that exceeds the character limit.",
            0,
            10000,
        );
        let result = split_long_segments(&seg, 100_000, 15);
        for s in &result {
            // NOTE(review): asserts a looser bound (20) than the `max_chars` passed (15).
            assert!(s.text.chars().count() <= 20, "chunk '{}' too long", s.text);
        }
    }

    #[test]
    fn split_long_segments_words_assigned_to_subsegments() {
        let mut seg = make_seg("Hello world test", 0, 3000);
        seg.words = vec![
            make_word("Hello", 0, 1000),
            make_word("world", 1000, 2000),
            make_word("test", 2000, 3000),
        ];
        let result = split_long_segments(&seg, 1200, 8);
        assert!(result.len() > 1);
    }

    // ---- build_caption_blocks ----

    #[test]
    fn build_caption_blocks_basic() {
        // Ids are assigned sequentially starting at 1.
        let segs = vec![
            make_seg("Hello world", 0, 2000),
            make_seg("How are you", 2000, 4000),
        ];
        let blocks = build_caption_blocks(&segs, 2, 40);
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].id, 1);
        assert_eq!(blocks[1].id, 2);
    }

    #[test]
    fn build_caption_blocks_respects_max_lines() {
        let seg = make_seg(
            "This is a very very very very very very very very long text to wrap over many lines.",
            0,
            5000,
        );
        let blocks = build_caption_blocks(&[seg], 2, 20);
        assert_eq!(blocks.len(), 1);
        assert!(
            blocks[0].lines.len() <= 2,
            "got {} lines",
            blocks[0].lines.len()
        );
    }

    #[test]
    fn build_caption_blocks_preserves_timestamps() {
        let segs = vec![make_seg("Test", 1500, 3000)];
        let blocks = build_caption_blocks(&segs, 2, 40);
        assert_eq!(blocks[0].start_ms, 1500);
        assert_eq!(blocks[0].end_ms, 3000);
    }

    #[test]
    fn build_caption_blocks_default_position_bottom() {
        let segs = vec![make_seg("Test", 0, 1000)];
        let blocks = build_caption_blocks(&segs, 2, 40);
        assert_eq!(blocks[0].position, CaptionPosition::Bottom);
    }

    #[test]
    fn build_caption_blocks_speaker_id_preserved() {
        let mut seg = make_seg("Test", 0, 1000);
        seg.speaker_id = Some(3);
        let blocks = build_caption_blocks(&[seg], 2, 40);
        assert_eq!(blocks[0].speaker_id, Some(3));
    }

    // ---- data types ----

    #[test]
    fn caption_block_char_count() {
        // Line separators are not counted: "Hello" + "world" = 10 chars.
        let block = CaptionBlock {
            id: 1,
            start_ms: 0,
            end_ms: 1000,
            lines: vec!["Hello".to_string(), "world".to_string()],
            speaker_id: None,
            position: CaptionPosition::Bottom,
        };
        assert_eq!(block.char_count(), 10);
    }

    #[test]
    fn word_timestamp_fields_accessible() {
        let w = make_word("hello", 100, 500);
        assert_eq!(w.word, "hello");
        assert_eq!(w.start_ms, 100);
        assert_eq!(w.end_ms, 500);
        assert!((w.confidence - 1.0).abs() < 1e-6);
        assert!((w.word_confidence - 1.0).abs() < 1e-6);
    }

    #[test]
    fn word_timestamp_with_word_confidence() {
        let w = WordTimestamp::with_word_confidence("uncertain".to_string(), 100, 500, 0.9, 0.55);
        assert_eq!(w.word, "uncertain");
        assert!((w.confidence - 0.9).abs() < 1e-6);
        assert!((w.word_confidence - 0.55).abs() < 1e-6);
        assert!(w.is_high_quality(0.5));
        assert!(!w.is_high_quality(0.8));
    }

    #[test]
    fn build_caption_blocks_with_overlapping_word_timestamps() {
        // Overlapping spans are passed through as-is; blocks copy segment timing verbatim.
        let mut seg1 = make_seg("Hello there", 0, 2000);
        seg1.words = vec![
            make_word("Hello", 0, 900),
            make_word("there", 800, 2000),
        ];
        let mut seg2 = make_seg("world", 1900, 3500);
        seg2.words = vec![make_word("world", 1900, 3500)];
        let blocks = build_caption_blocks(&[seg1, seg2], 2, 40);
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].start_ms, 0);
        assert_eq!(blocks[0].end_ms, 2000);
        assert_eq!(blocks[1].start_ms, 1900);
        assert_eq!(blocks[1].end_ms, 3500);
    }

    #[test]
    fn transcript_segment_duration() {
        let s = make_seg("test", 1000, 3500);
        assert_eq!(s.duration_ms(), 2500);
    }

    #[test]
    fn alignment_error_display_empty_transcript() {
        let e = AlignmentError::EmptyTranscript;
        assert!(e.to_string().contains("empty"));
    }

    #[test]
    fn alignment_error_display_invalid_timestamp() {
        let e = AlignmentError::InvalidTimestamp {
            start_ms: 5000,
            end_ms: 3000,
        };
        assert!(e.to_string().contains("5000"));
    }

    // ---- text chunking helpers ----

    #[test]
    fn split_text_sentence_boundary_preferred() {
        let text = "Hello there! How are you doing today? Fine thanks.";
        let chunks = split_text_into_chunks(text, 15);
        for c in &chunks {
            assert!(c.chars().count() <= 15, "chunk '{c}' exceeds 15 chars");
        }
    }

    #[test]
    fn split_text_word_boundary_fallback() {
        // No sentence punctuation, so cuts fall back to spaces.
        let text = "AAAA BBBB CCCC DDDD EEEE";
        let chunks = split_text_into_chunks(text, 10);
        for c in &chunks {
            assert!(c.chars().count() <= 12, "chunk '{c}' too long");
        }
    }

    #[test]
    fn round_trip_split_then_merge_preserves_text() {
        // Every alphanumeric word of the original must survive a split followed by a merge.
        let original_text = "Hello world. This is a test. We have multiple sentences here.";
        let seg = make_seg(original_text, 0, 10000);

        let split = split_long_segments(&seg, 3000, 20);
        assert!(split.len() > 1, "expected multiple segments after split");

        // A threshold of 0 ms never counts a segment as short, so this merge changes nothing.
        let merged = merge_short_segments(&split, 0);

        let reconstructed: String = merged
            .iter()
            .map(|s| s.text.as_str())
            .collect::<Vec<_>>()
            .join(" ");

        let original_words: std::collections::HashSet<&str> =
            original_text.split_whitespace().collect();
        let reconstructed_words: std::collections::HashSet<&str> =
            reconstructed.split_whitespace().collect();

        for word in &original_words {
            // Compare without surrounding punctuation, since chunking may move it.
            let cleaned = word.trim_matches(|c: char| !c.is_alphanumeric());
            if !cleaned.is_empty() {
                assert!(
                    reconstructed_words.iter().any(|w| w.contains(cleaned)),
                    "word '{cleaned}' missing from reconstruction"
                );
            }
        }
    }

    // ---- align_to_frames_batch ----

    #[test]
    fn align_to_frames_batch_basic() {
        let segs = vec![make_seg("Hello", 0, 1000), make_seg("World", 1000, 2000)];
        let result =
            align_to_frames_batch(&segs, 25.0).expect("align to frames batch should succeed");
        assert_eq!(result.len(), 2);
        assert_eq!(result[0][0].1, "Hello");
        assert_eq!(result[1][0].1, "World");
    }

    #[test]
    fn align_to_frames_batch_rejects_zero_fps() {
        let segs = vec![make_seg("test", 0, 1000)];
        assert!(align_to_frames_batch(&segs, 0.0).is_err());
    }
}