1use crate::session;
2use eyre::{Context, ContextCompat, Result};
3use ndarray::{ArrayBase, Axis, IxDyn, ViewRepr};
4use std::{cmp::Ordering, collections::VecDeque, path::Path};
5
/// One detected stretch of speech, together with the audio it covers.
///
/// Field order is part of the `#[repr(C)]` layout and must not be changed.
#[derive(Clone, Debug)]
#[repr(C)]
pub struct Segment {
    /// Position where the segment begins, computed as a sample offset divided
    /// by the sample rate.
    pub start: f64,
    /// Position where the segment ends, in the same unit as `start`.
    pub end: f64,
    /// Owned copy of the input samples that fall inside the segment.
    pub samples: Vec<i16>,
}
13
14fn find_max_index(row: ArrayBase<ViewRepr<&f32>, IxDyn>) -> Result<usize> {
15 let (max_index, _) = row
16 .iter()
17 .enumerate()
18 .max_by(|a, b| {
19 a.1.partial_cmp(b.1)
20 .context("Comparison error")
21 .unwrap_or(Ordering::Equal)
22 })
23 .context("sub_row should not be empty")?;
24 Ok(max_index)
25}
26
27pub fn get_segments<P: AsRef<Path>>(
28 samples: &[i16],
29 sample_rate: u32,
30 model_path: P,
31) -> Result<impl Iterator<Item = Result<Segment>> + '_> {
32 let mut session = session::create_session(model_path.as_ref())?;
34
35 let frame_size = 270;
37 let frame_start = 721;
38 let window_size = (sample_rate * 10) as usize; let mut is_speeching = false;
40 let mut offset = frame_start;
41 let mut start_offset = 0.0;
42
43 let padded_samples = {
45 let mut padded = Vec::from(samples);
46 padded.extend(vec![0; window_size - (samples.len() % window_size)]);
47 padded
48 };
49
50 let mut start_iter = (0..padded_samples.len()).step_by(window_size);
51
52 let mut segments_queue = VecDeque::new();
53 Ok(std::iter::from_fn(move || {
54 if let Some(start) = start_iter.next() {
55 let end = (start + window_size).min(padded_samples.len());
56 let window = &padded_samples[start..end];
57
58 let array = ndarray::Array1::from_iter(window.iter().map(|&x| x as f32));
60 let array = array.view().insert_axis(Axis(0)).insert_axis(Axis(1));
61
62 let inputs = ort::inputs![ort::value::TensorRef::from_array_view(array.into_dyn())
64 .map_err(|e| eyre::eyre!("Failed to prepare inputs: {:?}", e))
65 .ok()?];
66
67 let ort_outs = match session.run(inputs) {
68 Ok(outputs) => outputs,
69 Err(e) => return Some(Err(eyre::eyre!("Failed to run the session: {:?}", e))),
70 };
71
72 let ort_out = match ort_outs.get("output").context("Output tensor not found") {
73 Ok(output) => output,
74 Err(e) => return Some(Err(eyre::eyre!("Output tensor error: {:?}", e))),
75 };
76
77 let ort_out = match ort_out
78 .try_extract_tensor::<f32>()
79 .context("Failed to extract tensor")
80 {
81 Ok(tensor) => tensor,
82 Err(e) => return Some(Err(eyre::eyre!("Tensor extraction error: {:?}", e))),
83 };
84
85 let (shape, data) = ort_out; let shape_slice: Vec<usize> = (0..shape.len()).map(|i| shape[i] as usize).collect();
88 let view =
89 ndarray::ArrayViewD::<f32>::from_shape(ndarray::IxDyn(&shape_slice), data).unwrap();
90
91 for row in view.outer_iter() {
92 for sub_row in row.axis_iter(Axis(0)) {
93 let max_index = match find_max_index(sub_row) {
94 Ok(index) => index,
95 Err(e) => return Some(Err(e)),
96 };
97
98 if max_index != 0 {
99 if !is_speeching {
100 start_offset = offset as f64;
101 is_speeching = true;
102 }
103 } else if is_speeching {
104 let start = start_offset / sample_rate as f64;
105 let end = offset as f64 / sample_rate as f64;
106
107 let start_f64 = start * (sample_rate as f64);
108 let end_f64 = end * (sample_rate as f64);
109
110 let start_idx = start_f64.min((samples.len() - 1) as f64) as usize;
112 let end_idx = end_f64.min(samples.len() as f64) as usize;
113
114 let segment_samples = &padded_samples[start_idx..end_idx];
115
116 is_speeching = false;
117
118 let segment = Segment {
119 start,
120 end,
121 samples: segment_samples.to_vec(),
122 };
123 segments_queue.push_back(segment);
124 }
125 offset += frame_size;
126 }
127 }
128 }
129 segments_queue.pop_front().map(Ok)
130 }))
131}