//! GPU STFT and FFT operators (`tract_gpu/ops/stft.rs`).

use crate::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice};
use tract_core::internal::*;

/// Frame lengths (in samples) accepted by the GPU STFT kernels.
///
/// Every entry is a power of two. The radix-2 kernel itself works for any power of two;
/// this set covers Kaldi-style featurizers at 8/16/32/48 kHz.
pub const SUPPORTED_FRAMES: [usize; 4] = [256, 512, 1024, 2048];

/// Returns `true` when `frame` is one of [`SUPPORTED_FRAMES`], i.e. a length the GPU
/// STFT kernels can handle.
pub fn is_supported_frame(frame: usize) -> bool {
    SUPPORTED_FRAMES.iter().any(|&supported| supported == frame)
}

13/// Per-backend STFT kernel launcher: `(stride, input, window, output)`. `input` is
14/// interleaved-complex f32 `[lead.., T, 2]`, `window` the pre-padded real window
15/// `[frame]`, `output` `[lead.., frames, frame, 2]`. The kernel reads the frame length
16/// from the output shape (`output[axis + 1]`).
17pub type DispatchStftFn = fn(usize, &DeviceTensor, &DeviceTensor, &DeviceTensor) -> TractResult<()>;
18
19/// Backend-agnostic fused STFT (frame + window + forward FFT). `frame` is a supported
20/// power of two ([`SUPPORTED_FRAMES`]); the window is pre-padded to `[frame]` (all-ones
21/// when the source had none), matching `core::ops::fft::Stft`'s symmetric padding. The
22/// time axis must sit just before the trailing complex pair (`axis == rank - 2`). Each
23/// backend supplies its own `dispatch` kernel; everything else (facts, output allocation,
24/// window upload) is shared.
25#[derive(Clone)]
26pub struct GpuStft {
27    pub axis: usize,
28    pub frame: usize,
29    pub stride: usize,
30    pub window: Arc<Tensor>,
31    pub backend_name: &'static str,
32    pub dispatch: DispatchStftFn,
33}
34
35impl GpuStft {
36    fn output_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
37        let mut shape: TVec<D> = input.into();
38        let frames = (input[self.axis].clone() - self.frame) / self.stride + 1;
39        shape[self.axis] = frames;
40        shape.insert(self.axis + 1, self.frame.into());
41        shape
42    }
43}
44
45impl std::fmt::Debug for GpuStft {
46    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
47        write!(f, "{}Stft(frame={}, stride={})", self.backend_name, self.frame, self.stride)
48    }
49}
50
51impl PartialEq for GpuStft {
52    fn eq(&self, other: &Self) -> bool {
53        self.backend_name == other.backend_name
54            && self.axis == other.axis
55            && self.frame == other.frame
56            && self.stride == other.stride
57            && self.window == other.window
58    }
59}
60
61impl Eq for GpuStft {}
62
63impl std::hash::Hash for GpuStft {
64    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
65        self.backend_name.hash(state);
66        self.axis.hash(state);
67        self.frame.hash(state);
68        self.stride.hash(state);
69        self.window.hash(state);
70    }
71}
72
73impl Op for GpuStft {
74    fn name(&self) -> StaticName {
75        format!("{}Stft", self.backend_name).into()
76    }
77
78    op_as_typed_op!();
79}
80
81impl EvalOp for GpuStft {
82    fn is_stateless(&self) -> bool {
83        true
84    }
85
86    fn eval_with_session(
87        &self,
88        node_id: usize,
89        session: &TurnState,
90        inputs: TVec<TValue>,
91    ) -> TractResult<TVec<TValue>> {
92        let input = inputs[0].to_device_tensor()?;
93        let window = (*self.window).clone().into_device()?;
94        let output = crate::session_handler::make_tensor_for_node(
95            session,
96            node_id,
97            input.datum_type(),
98            &self.output_shape(input.shape()),
99        )?;
100        (self.dispatch)(self.stride, input, &window, &output)
101            .with_context(|| format!("Error while dispatching eval for {}", self.name()))?;
102        Ok(tvec!(output.into_tensor().into_tvalue()))
103    }
104}
105
106impl TypedOp for GpuStft {
107    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
108        crate::utils::facts_to_device_facts(inputs, |facts| {
109            let input = facts[0];
110            ensure!(
111                input.rank() >= 2 && input.shape[input.rank() - 1] == 2.to_dim(),
112                "{} expects a complex input [.., T, 2]",
113                self.name()
114            );
115            Ok(tvec!(input.datum_type.fact(self.output_shape(&input.shape.to_tvec()))))
116        })
117        .with_context(|| format!("Error while computing facts for {:?}", self.name()))
118    }
119
120    as_op!();
121}
122
123/// Per-backend FFT kernel launcher: `(inverse, input, output)`, both interleaved-complex
124/// f32 `[lead.., N, 2]` (N a supported power of two, transformed axis at `rank-2`). The
125/// inverse is UNNORMALIZED, matching core `Fft` (rustfft). The kernel reads N from the
126/// input shape.
127pub type DispatchFftFn = fn(bool, &DeviceTensor, &DeviceTensor) -> TractResult<()>;
128
129/// Backend-agnostic complex FFT over the innermost-but-one axis (the trailing dim is the
130/// `[re, im]` pair); forward or `inverse`, shape-preserving. Mirrors `core::ops::fft::Fft`
131/// for a supported power-of-two length with `axis == rank - 2`.
132#[derive(Clone)]
133pub struct GpuFft {
134    pub axis: usize,
135    pub inverse: bool,
136    pub backend_name: &'static str,
137    pub dispatch: DispatchFftFn,
138}
139
140impl std::fmt::Debug for GpuFft {
141    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
142        write!(f, "{}Fft({})", self.backend_name, if self.inverse { "inverse" } else { "forward" })
143    }
144}
145
146impl PartialEq for GpuFft {
147    fn eq(&self, other: &Self) -> bool {
148        self.backend_name == other.backend_name
149            && self.axis == other.axis
150            && self.inverse == other.inverse
151    }
152}
153
154impl Eq for GpuFft {}
155
156impl std::hash::Hash for GpuFft {
157    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
158        self.backend_name.hash(state);
159        self.axis.hash(state);
160        self.inverse.hash(state);
161    }
162}
163
164impl Op for GpuFft {
165    fn name(&self) -> StaticName {
166        format!("{}Fft", self.backend_name).into()
167    }
168
169    op_as_typed_op!();
170}
171
172impl EvalOp for GpuFft {
173    fn is_stateless(&self) -> bool {
174        true
175    }
176
177    fn eval_with_session(
178        &self,
179        node_id: usize,
180        session: &TurnState,
181        inputs: TVec<TValue>,
182    ) -> TractResult<TVec<TValue>> {
183        let input = inputs[0].to_device_tensor()?;
184        let output = crate::session_handler::make_tensor_for_node(
185            session,
186            node_id,
187            input.datum_type(),
188            input.shape(),
189        )?;
190        (self.dispatch)(self.inverse, input, &output)
191            .with_context(|| format!("Error while dispatching eval for {}", self.name()))?;
192        Ok(tvec!(output.into_tensor().into_tvalue()))
193    }
194}
195
196impl TypedOp for GpuFft {
197    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
198        crate::utils::facts_to_device_facts(inputs, |facts| {
199            let input = facts[0];
200            ensure!(
201                input.rank() >= 2 && input.shape[input.rank() - 1] == 2.to_dim(),
202                "{} expects a complex input [.., N, 2]",
203                self.name()
204            );
205            Ok(tvec!(input.datum_type.fact(input.shape.clone())))
206        })
207        .with_context(|| format!("Error while computing facts for {:?}", self.name()))
208    }
209
210    as_op!();
211}
212
213/// Pre-pad `window` (or all-ones) to `[frame]` exactly as core `Stft` does: symmetric
214/// padding when shorter than the frame. Shared by every backend's lowering rule.
215pub fn padded_window(window: Option<&Arc<Tensor>>, frame: usize) -> TractResult<Arc<Tensor>> {
216    let mut win = vec![0f32; frame];
217    match window {
218        Some(w) => {
219            let w = w.cast_to::<f32>()?;
220            let w = w.try_as_plain()?;
221            let w = w.as_slice::<f32>()?;
222            ensure!(w.len() <= frame, "STFT window longer than frame");
223            let pad_left = (frame - w.len()) / 2;
224            win[pad_left..pad_left + w.len()].copy_from_slice(w);
225        }
226        None => win.fill(1.0),
227    }
228    Ok(Arc::new(tensor1(&win)))
229}