pyannote_rs/
segment.rs

1use crate::session;
2use eyre::{Context, ContextCompat, Result};
3use ndarray::{ArrayBase, Axis, IxDyn, ViewRepr};
4use std::{cmp::Ordering, collections::VecDeque, path::Path};
5
/// A contiguous stretch of audio classified as speech.
///
/// `start` and `end` are measured in seconds from the beginning of the audio
/// that was segmented; `samples` holds the corresponding slice of that audio.
#[repr(C)]
#[derive(Clone, Debug)]
pub struct Segment {
    /// Start time of the segment, in seconds.
    pub start: f64,
    /// End time of the segment, in seconds.
    pub end: f64,
    /// Raw PCM samples covered by this segment.
    pub samples: Vec<i16>,
}
13
14fn find_max_index(row: ArrayBase<ViewRepr<&f32>, IxDyn>) -> Result<usize> {
15    let (max_index, _) = row
16        .iter()
17        .enumerate()
18        .max_by(|a, b| {
19            a.1.partial_cmp(b.1)
20                .context("Comparison error")
21                .unwrap_or(Ordering::Equal)
22        })
23        .context("sub_row should not be empty")?;
24    Ok(max_index)
25}
26
27pub fn get_segments<P: AsRef<Path>>(
28    samples: &[i16],
29    sample_rate: u32,
30    model_path: P,
31) -> Result<impl Iterator<Item = Result<Segment>> + '_> {
32    // Create session using the provided model path
33    let session = session::create_session(model_path.as_ref())?;
34
35    // Define frame parameters
36    let frame_size = 270;
37    let frame_start = 721;
38    let window_size = (sample_rate * 10) as usize; // 10 seconds
39    let mut is_speeching = false;
40    let mut offset = frame_start;
41    let mut start_offset = 0.0;
42
43    // Pad end with silence for full last segment
44    let padded_samples = {
45        let mut padded = Vec::from(samples);
46        padded.extend(vec![0; window_size - (samples.len() % window_size)]);
47        padded
48    };
49
50    let mut start_iter = (0..padded_samples.len()).step_by(window_size);
51
52    let mut segments_queue = VecDeque::new();
53    Ok(std::iter::from_fn(move || {
54        if let Some(start) = start_iter.next() {
55            let end = (start + window_size).min(padded_samples.len());
56            let window = &padded_samples[start..end];
57
58            // Convert window to ndarray::Array1
59            let array = ndarray::Array1::from_iter(window.iter().map(|&x| x as f32));
60            let array = array.view().insert_axis(Axis(0)).insert_axis(Axis(1));
61
62            // Handle potential errors during the session and input processing
63            let inputs = match ort::inputs![array.into_dyn()] {
64                Ok(inputs) => inputs,
65                Err(e) => return Some(Err(eyre::eyre!("Failed to prepare inputs: {:?}", e))),
66            };
67
68            let ort_outs = match session.run(inputs) {
69                Ok(outputs) => outputs,
70                Err(e) => return Some(Err(eyre::eyre!("Failed to run the session: {:?}", e))),
71            };
72
73            let ort_out = match ort_outs.get("output").context("Output tensor not found") {
74                Ok(output) => output,
75                Err(e) => return Some(Err(eyre::eyre!("Output tensor error: {:?}", e))),
76            };
77
78            let ort_out = match ort_out
79                .try_extract_tensor::<f32>()
80                .context("Failed to extract tensor")
81            {
82                Ok(tensor) => tensor,
83                Err(e) => return Some(Err(eyre::eyre!("Tensor extraction error: {:?}", e))),
84            };
85
86            for row in ort_out.outer_iter() {
87                for sub_row in row.axis_iter(Axis(0)) {
88                    let max_index = match find_max_index(sub_row) {
89                        Ok(index) => index,
90                        Err(e) => return Some(Err(e)),
91                    };
92
93                    if max_index != 0 {
94                        if !is_speeching {
95                            start_offset = offset as f64;
96                            is_speeching = true;
97                        }
98                    } else if is_speeching {
99                        let start = start_offset / sample_rate as f64;
100                        let end = offset as f64 / sample_rate as f64;
101
102                        let start_f64 = start * (sample_rate as f64);
103                        let end_f64 = end * (sample_rate as f64);
104
105                        // Ensure indices are within bounds
106                        let start_idx = start_f64.min((samples.len() - 1) as f64) as usize;
107                        let end_idx = end_f64.min(samples.len() as f64) as usize;
108
109                        let segment_samples = &padded_samples[start_idx..end_idx];
110
111                        is_speeching = false;
112
113                        let segment = Segment {
114                            start,
115                            end,
116                            samples: segment_samples.to_vec(),
117                        };
118                        segments_queue.push_back(segment);
119                    }
120                    offset += frame_size;
121                }
122            }
123        }
124        segments_queue.pop_front().map(Ok)
125    }))
126}