Skip to main content

tract_pulse/
lib.rs

1#![allow(clippy::len_zero)]
2#[macro_use]
3pub mod macros;
4
5pub mod blockify;
6pub mod fact;
7pub mod model;
8pub mod ops;
9
10pub mod internal {
11    pub use std::fmt;
12    pub use tract_nnef::internal::*;
13    pub use tract_pulse_opl::tract_nnef;
14
15    pub use downcast_rs::Downcast;
16
17    pub use crate::fact::PulsedFact;
18    pub use crate::model::{PulsedModel, PulsedModelExt};
19    pub use crate::ops::{OpPulsifier, PulsedOp};
20}
21
22use std::ops::ControlFlow;
23
24use internal::*;
25use tract_core::optim::TypedPass;
26use tract_core::transform::ModelTransform;
27use tract_pulse_opl::tract_nnef::tract_core;
28
29pub use ops::PulsedOp;
30
31#[derive(Debug, Default, serde::Deserialize)]
32pub struct PulseConfig {
33    pub symbol: Option<String>,
34    pub pulse: String,
35}
36
37#[derive(Debug)]
38struct PulseTransform(PulseConfig);
39
40impl ModelTransform for PulseTransform {
41    fn name(&self) -> std::borrow::Cow<'static, str> {
42        "pulse".into()
43    }
44    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
45        let symbol = self.0.symbol.as_deref().unwrap_or("S");
46        let sym = model.symbols.sym(symbol);
47        let pulse_dim = parse_tdim(&model.symbols, &self.0.pulse)?;
48        ops::diag_gather::detect_diag_gather(model)?;
49        tract_core::optim::propagate_roi::PropagateRoi.run_direct(model)?;
50        model.declutter()?;
51        let pulsed = model::PulsedModel::new(model, sym, &pulse_dim)?;
52        *model = pulsed.into_typed()?;
53        Ok(())
54    }
55}
56
57register_model_transform!("pulse", PulseConfig, |config| Ok(Box::new(PulseTransform(config))));
58
59register_model_transform!("blockify", blockify::BlockifyConfig, |config| Ok(Box::new(
60    blockify::BlockifyTransform(config)
61)));
62
63pub trait WithPulse {
64    fn enable_pulse(&mut self);
65    fn with_pulse(self) -> Self;
66}
67
68impl WithPulse for tract_nnef::framework::Nnef {
69    fn enable_pulse(&mut self) {
70        self.registries.push(tract_nnef_registry());
71    }
72    fn with_pulse(mut self) -> Self {
73        self.enable_pulse();
74        self
75    }
76}
77
78pub fn tract_nnef_registry() -> Registry {
79    let mut reg = tract_pulse_opl::tract_nnef_registry();
80    ops::delay::register(&mut reg);
81    reg.extensions.push(Box::new(decl_stream_symbol));
82    reg
83}
84
85fn decl_stream_symbol(
86    _proto_model: &mut ModelBuilder,
87    name: &Identifier,
88    _rest: &str,
89) -> TractResult<ControlFlow<(), ()>> {
90    if name.0 == "tract_pulse_streaming_symbol" {
91        Ok(ControlFlow::Break(()))
92    } else {
93        Ok(ControlFlow::Continue(()))
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// A source whose shape does not involve the streaming symbol is
    /// rejected by pulsification; a source that has it on one axis gets the
    /// pulse size on that axis.
    #[test]
    fn test_source_must_stream() {
        // No axis depends on `S`: building the pulsed model must fail.
        let mut fixed = TypedModel::default();
        let s = fixed.symbols.sym("S");
        let _a = fixed.add_source("a", f32::fact([1, 2, 3])).unwrap();
        fixed.auto_outputs().unwrap();
        let rejected = PulsedModel::new(&fixed, s.clone(), &4.to_dim());
        assert!(rejected.is_err());

        // Axis 1 is `S`: with a pulse of 4 the source fact becomes [1, 4, 3].
        let mut streaming = TypedModel::default();
        let _a = streaming.add_source("a", f32::fact(dims![1, s, 3].as_ref())).unwrap();
        streaming.auto_outputs().unwrap();
        let pulsed = PulsedModel::new(&streaming, s, &4.to_dim()).unwrap();
        let expected = f32::fact([1usize, 4, 3]);
        assert_eq!(
            *pulsed.outlet_fact(OutletId::new(0, 0)).unwrap().to_typed_fact().unwrap(),
            expected
        );
    }

    /// A model made of a single streaming source: input and output facts of
    /// the pulsed model both carry the pulse size on the streaming axis.
    #[test]
    fn test_immediate() {
        let mut typed = TypedModel::default();
        let s = typed.symbols.sym("S");
        let _a = typed.add_source("a", f32::fact(dims![s, 2, 3].as_ref())).unwrap();
        typed.auto_outputs().unwrap();

        let pulsed = PulsedModel::new(&typed, s, &4.to_dim()).unwrap();

        let expected = f32::fact([4, 2, 3]);
        assert_eq!(*pulsed.input_fact(0).unwrap().to_typed_fact().unwrap(), expected);
        assert_eq!(*pulsed.output_fact(0).unwrap().to_typed_fact().unwrap(), expected);
    }

    /// Reshaping a `[2*S, 4]` input into `[S, 2, 4]`: with a pulse of 1 the
    /// input pulse is `[2, 4]`, the output pulse is `[1, 2, 4]`, and the
    /// output streams on axis 0 with dimension `S`.
    #[test]
    fn test_reshape_split_streaming_axis() {
        use tract_core::ops::change_axes::AxisOp;

        let mut typed = TypedModel::default();
        let s = typed.symbols.sym("S");
        let input = typed.add_source("a", f32::fact(dims![s.to_dim() * 2, 4].as_ref())).unwrap();
        let reshape = AxisOp::Reshape(0, tvec!(s.to_dim() * 2), tvec!(s.to_dim(), 2.to_dim()));
        let split = typed.wire_node("split", reshape, &[input]).unwrap();
        typed.select_output_outlets(&split).unwrap();

        let pulsed = PulsedModel::new(&typed, s.clone(), &1.to_dim()).unwrap();

        assert_eq!(*pulsed.input_fact(0).unwrap().to_typed_fact().unwrap(), f32::fact([2, 4]));
        assert_eq!(*pulsed.output_fact(0).unwrap().to_typed_fact().unwrap(), f32::fact([1, 2, 4]));
        let stream = pulsed.output_fact(0).unwrap().stream.as_ref().unwrap();
        assert_eq!(stream.axis, 0);
        assert_eq!(stream.dim, s.to_dim());
    }

    /// Reshaping a `[S, 2, 4]` input into `[2*S, 4]`: with a pulse of 1 the
    /// input pulse is `[1, 2, 4]`, the output pulse is `[2, 4]`, and the
    /// output streams on axis 0 with dimension `2*S`.
    #[test]
    fn test_reshape_merge_streaming_axis() {
        use tract_core::ops::change_axes::AxisOp;

        let mut typed = TypedModel::default();
        let s = typed.symbols.sym("S");
        let input = typed.add_source("a", f32::fact(dims![s, 2, 4].as_ref())).unwrap();
        let reshape = AxisOp::Reshape(0, tvec!(s.to_dim(), 2.to_dim()), tvec!(s.to_dim() * 2));
        let merged = typed.wire_node("merge", reshape, &[input]).unwrap();
        typed.select_output_outlets(&merged).unwrap();

        let pulsed = PulsedModel::new(&typed, s.clone(), &1.to_dim()).unwrap();

        assert_eq!(*pulsed.input_fact(0).unwrap().to_typed_fact().unwrap(), f32::fact([1, 2, 4]));
        assert_eq!(*pulsed.output_fact(0).unwrap().to_typed_fact().unwrap(), f32::fact([2, 4]));
        let stream = pulsed.output_fact(0).unwrap().stream.as_ref().unwrap();
        assert_eq!(stream.axis, 0);
        assert_eq!(stream.dim, s.to_dim() * 2);
    }

    /// Runs the pulsed "split" reshape: every 2-element input chunk must
    /// come out as a single `[1, 2]` row holding the same values.
    #[test]
    fn test_reshape_split_then_run() {
        use tract_core::ops::change_axes::AxisOp;

        let mut typed = TypedModel::default();
        let s = typed.symbols.sym("S");
        let input = typed.add_source("a", f32::fact(dims![s.to_dim() * 2].as_ref())).unwrap();
        let reshape = AxisOp::Reshape(0, tvec!(s.to_dim() * 2), tvec!(s.to_dim(), 2.to_dim()));
        let split = typed.wire_node("split", reshape, &[input]).unwrap();
        typed.select_output_outlets(&split).unwrap();

        let pulsed = PulsedModel::new(&typed, s, &1.to_dim()).unwrap();
        let plan = SimplePlan::new(pulsed.into_typed().unwrap()).unwrap();
        let mut state = SimpleState::new(&plan).unwrap();

        // Feed two consecutive pulses through the same state.
        for (chunk, expected) in [([1f32, 2.0], [[1f32, 2.0]]), ([3f32, 4.0], [[3f32, 4.0]])] {
            let outputs = state.run(tvec!(tensor1(&chunk).into_tvalue())).unwrap();
            assert_eq!(*outputs[0], tensor2(&expected).into());
        }
    }

    /// Two parallel pulse paths meeting at an elementwise op produce
    /// different per-pulse stream-axis sizes when one path goes through a
    /// ConvTranspose (kernel > stride) and the other doesn't.  Before the
    /// fix, pulsification bailed at the meet point because the typed
    /// `output_facts`' `multi_broadcast` returned `Broadcast(K_a, K_b)` on
    /// the stream axis -- not equal, not 1, doesn't simplify.  After the
    /// fix, the merge uses LCM for the stream axis specifically.
    ///
    /// Minimal repro of the Pocket-TTS upsample-then-attention pattern:
    /// a ConvTranspose1d(stride=4, kernel=8) emits steady-state stride=4
    /// frames per pulse with 4-frame overlap-add; an arange of the same
    /// post-convtr length produces (after the Range slope-based fix) also
    /// 4 frames per pulse; an elementwise Add of the two requires the
    /// meet-point merge to be LCM(4, 4) = 4 (trivial here, but the path
    /// went through Broadcast(4, 8) before the slope+LCM fixes were in
    /// place).
    #[test]
    fn test_pulse_meet_with_arange_branch_types_through() {
        use tract_core::ops::array::Range;
        use tract_core::ops::change_axes::AxisOp;
        use tract_core::ops::cnn::{Deconv, KernelFormat, PaddingSpec, PoolSpec};
        use tract_core::ops::nn::DataFormat;

        let mut typed = TypedModel::default();
        let t = typed.symbols.sym("T");
        let input = typed.add_source("x", f32::fact(dims![1, 2, t.to_dim()].as_ref())).unwrap();

        // ConvTranspose1d(C=2, kernel=8, stride=4) → stream-axis dim
        // becomes 4*T + 4 (post overlap-add tail).
        let weights = typed
            .add_const("kernel", tract_core::ndarray::Array3::<f32>::zeros((2, 2, 8)))
            .unwrap();
        let bias = typed.add_const("bias", tract_core::ndarray::arr1(&[0.0f32, 0.0])).unwrap();
        let pool_spec = PoolSpec {
            data_format: DataFormat::NCHW,
            kernel_shape: tvec!(8),
            padding: PaddingSpec::Valid,
            dilations: Some(tvec!(1)),
            strides: Some(tvec!(4)),
            input_channels: 2,
            output_channels: 2,
        };
        let deconv = Deconv {
            pool_spec,
            kernel_format: KernelFormat::OIHW,
            adjustments: tvec!(0),
            group: 1,
        };
        let conv_out = typed.wire_node("convtr", deconv, &[input, weights, bias]).unwrap()[0];

        // arange(0, 4*T + 4) of the same stream-axis length — this is the
        // branch that surfaced the Broadcast bug before the fix.
        let stream_len = t.to_dim() * 4 + 4;
        let start = typed.add_const("range_start", tensor0(TDim::Val(0))).unwrap();
        let end = typed
            .add_const(
                "range_end",
                tract_core::ndarray::arr0(stream_len.clone()).into_dyn().into_tensor(),
            )
            .unwrap();
        let step = typed.add_const("range_step", tensor0(TDim::Val(1))).unwrap();
        let range_out =
            typed.wire_node("range", Range::new(stream_len), &[start, end, step]).unwrap()[0];

        // Cast the range to f32, then prepend two axes so that it can be
        // added to conv_out.
        let range_f32 = typed
            .wire_node("range_cast", tract_core::ops::cast::cast(f32::datum_type()), &[range_out])
            .unwrap()[0];
        let mut range_bc = range_f32;
        for node_name in ["range_unsqueeze", "range_unsqueeze2"] {
            range_bc = typed.wire_node(node_name, AxisOp::Add(0), &[range_bc]).unwrap()[0];
        }

        let added =
            typed.wire_node("add", tract_core::ops::math::add(), &[conv_out, range_bc]).unwrap();
        typed.select_output_outlets(&added).unwrap();

        // The point of the test: this used to panic with
        // `Pulsification requires pulse Broadcast(4, 8) ...` at the
        // downstream meet point.  Now it should pulsify without error.
        let _pulse = PulsedModel::new(&typed, t, &2.to_dim()).expect("pulsification");
    }

    /// `MultiBroadcastTo` pulsifier baseline: a target shape that grows
    /// linearly with the streaming symbol (`1 + S/2` -- the canonical
    /// `shape_of(stride-2 conv)` pattern) gets the per-pulse increment
    /// `P/2` after the boundary-subtract trick. This locks in the existing
    /// linear contract so that the non-linear fallback tested below cannot
    /// regress it.
    #[test]
    fn test_multi_broadcast_to_pulsifier_linear_axis() {
        use tract_pulse_opl::tract_core::ops::array::MultiBroadcastTo;

        let mut typed = TypedModel::default();
        let s = typed.symbols.sym("S");
        let linear: TDim = 1.to_dim() + s.to_dim() / 2;
        let target_shape: ShapeFact = tvec![1.to_dim(), linear, 4.to_dim()].into();

        let input = typed.add_source("a", f32::fact(dims![1, s.to_dim(), 4].as_ref())).unwrap();
        let broadcast =
            typed.wire_node("bc", MultiBroadcastTo::new(target_shape), &[input]).unwrap();
        typed.select_output_outlets(&broadcast).unwrap();

        let pulsed = PulsedModel::new(&typed, s, &4.to_dim()).expect("pulsification");
        let out_fact = pulsed.output_fact(0).unwrap();
        // `1 + S/2` at S=P=4 is 3, at S=0 it is 1. The trick yields
        // `3 - 1 = 2` per pulse; the linearity probe at S=8 gives delta
        // 4, matching `2 * 2`, so we stay on the linear path.
        assert_eq!(
            out_fact.shape[1],
            2.to_dim(),
            "linear streaming axis must keep the boundary-subtract delta; got fact: {out_fact:?}",
        );
    }

    /// Non-linear target shape (`min(2, S + 2)`, which equals 2 for every
    /// `S >= 0`): the boundary-subtract collapses `full - base` to 0 even
    /// though the full per-pulse shape is 2. Before the fix that produced a
    /// 0-volume PulsedFact that poisoned every downstream consumer (most
    /// visibly: a Scan body's State input reading the GRU `h_0` tile,
    /// surfacing as `Clashing resolution for expression. 2=2 != 0` on the
    /// runtime warmup turn). The fallback keeps the full value when the
    /// `substitute(S→0) == substitute(S→P) == substitute(S→2P)` probe
    /// confirms the axis is not actually streaming.
    #[test]
    fn test_multi_broadcast_to_pulsifier_non_linear_axis() {
        use tract_pulse_opl::tract_core::ops::array::MultiBroadcastTo;

        let mut typed = TypedModel::default();
        let s = typed.symbols.sym("S");
        let non_linear: TDim = (s.to_dim() + 2.to_dim()).mini(2.to_dim());
        let target_shape: ShapeFact = tvec![1.to_dim(), non_linear, 1.to_dim()].into();

        let input = typed.add_source("a", f32::fact(dims![1, s.to_dim(), 1].as_ref())).unwrap();
        let broadcast =
            typed.wire_node("bc", MultiBroadcastTo::new(target_shape), &[input]).unwrap();
        typed.select_output_outlets(&broadcast).unwrap();

        let pulsed = PulsedModel::new(&typed, s, &4.to_dim()).expect("pulsification");
        let out_fact = pulsed.output_fact(0).unwrap();
        assert_eq!(
            out_fact.shape[1],
            2.to_dim(),
            "non-linear streaming axis must keep the full value, not the collapsed delta; got fact: {out_fact:?}",
        );
    }
}