1#![allow(clippy::len_zero)]
2#[macro_use]
3pub mod macros;
4
5pub mod blockify;
6pub mod fact;
7pub mod model;
8pub mod ops;
9
/// Re-exports of the tract NNEF internals together with this crate's pulsed fact, model and
/// operator types, gathered so they can be glob-imported in one line.
pub mod internal {
    pub use std::fmt;
    pub use tract_nnef::internal::*;
    pub use tract_pulse_opl::tract_nnef;

    pub use downcast_rs::Downcast;

    pub use crate::fact::PulsedFact;
    pub use crate::model::{PulsedModel, PulsedModelExt};
    pub use crate::ops::{OpPulsifier, PulsedOp};
}
21
22use std::ops::ControlFlow;
23
24use internal::*;
25use tract_core::optim::TypedPass;
26use tract_core::transform::ModelTransform;
27use tract_pulse_opl::tract_nnef::tract_core;
28
29pub use ops::PulsedOp;
30
/// Configuration for the `pulse` model transform.
#[derive(Debug, Default, serde::Deserialize)]
pub struct PulseConfig {
    /// Name of the streaming symbol. `PulseTransform` falls back to `"S"` when this is `None`.
    pub symbol: Option<String>,
    /// Pulse size as a dimension expression, parsed with `parse_tdim` against the model's
    /// symbol scope.
    pub pulse: String,
}

/// The `pulse` model transform, carrying its `PulseConfig`.
#[derive(Debug)]
struct PulseTransform(PulseConfig);
39
40impl ModelTransform for PulseTransform {
41 fn name(&self) -> std::borrow::Cow<'static, str> {
42 "pulse".into()
43 }
44 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
45 let symbol = self.0.symbol.as_deref().unwrap_or("S");
46 let sym = model.symbols.sym(symbol);
47 let pulse_dim = parse_tdim(&model.symbols, &self.0.pulse)?;
48 ops::diag_gather::detect_diag_gather(model)?;
49 tract_core::optim::propagate_roi::PropagateRoi.run_direct(model)?;
50 model.declutter()?;
51 let pulsed = model::PulsedModel::new(model, sym, &pulse_dim)?;
52 *model = pulsed.into_typed()?;
53 Ok(())
54 }
55}
56
// Register the `pulse` transform under config type `PulseConfig`, building a `PulseTransform`.
register_model_transform!("pulse", PulseConfig, |config| Ok(Box::new(PulseTransform(config))));

// Register the `blockify` transform under config type `blockify::BlockifyConfig`,
// building a `blockify::BlockifyTransform`.
register_model_transform!("blockify", blockify::BlockifyConfig, |config| Ok(Box::new(
    blockify::BlockifyTransform(config)
)));
62
63pub trait WithPulse {
64 fn enable_pulse(&mut self);
65 fn with_pulse(self) -> Self;
66}
67
68impl WithPulse for tract_nnef::framework::Nnef {
69 fn enable_pulse(&mut self) {
70 self.registries.push(tract_nnef_registry());
71 }
72 fn with_pulse(mut self) -> Self {
73 self.enable_pulse();
74 self
75 }
76}
77
78pub fn tract_nnef_registry() -> Registry {
79 let mut reg = tract_pulse_opl::tract_nnef_registry();
80 ops::delay::register(&mut reg);
81 reg.extensions.push(Box::new(decl_stream_symbol));
82 reg
83}
84
85fn decl_stream_symbol(
86 _proto_model: &mut ModelBuilder,
87 name: &Identifier,
88 _rest: &str,
89) -> TractResult<ControlFlow<(), ()>> {
90 if name.0 == "tract_pulse_streaming_symbol" {
91 Ok(ControlFlow::Break(()))
92 } else {
93 Ok(ControlFlow::Continue(()))
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// A source with only concrete dimensions cannot be pulsified. Once its middle axis is
    /// the symbol `S`, pulsifying with pulse 4 turns that axis into 4.
    #[test]
    fn test_source_must_stream() {
        let mut model = TypedModel::default();
        let s = model.symbols.sym("S");
        let _a = model.add_source("a", f32::fact([1, 2, 3])).unwrap();
        model.auto_outputs().unwrap();
        assert!(PulsedModel::new(&model, s.clone(), &4.to_dim()).is_err());

        let mut model = TypedModel::default();
        let _a = model.add_source("a", f32::fact(dims![1, s, 3].as_ref())).unwrap();
        model.auto_outputs().unwrap();
        let pulse = PulsedModel::new(&model, s, &4.to_dim()).unwrap();
        assert_eq!(
            *pulse.outlet_fact(OutletId::new(0, 0)).unwrap().to_typed_fact().unwrap(),
            f32::fact([1usize, 4, 3])
        );
    }

    /// A model whose output is its streaming source. With pulse 4 on axis 0, both the
    /// pulsed input and output facts are `[4, 2, 3]`.
    #[test]
    fn test_immediate() {
        let mut model = TypedModel::default();
        let s = model.symbols.sym("S");
        let _a = model.add_source("a", f32::fact(dims![s, 2, 3].as_ref())).unwrap();
        model.auto_outputs().unwrap();

        let pulse = PulsedModel::new(&model, s, &4.to_dim()).unwrap();

        assert_eq!(*pulse.input_fact(0).unwrap().to_typed_fact().unwrap(), f32::fact([4, 2, 3]));
        assert_eq!(*pulse.output_fact(0).unwrap().to_typed_fact().unwrap(), f32::fact([4, 2, 3]));
    }

    /// A reshape splits the streaming axis `2*S` into `(S, 2)`. With pulse 1, the input pulse
    /// is `[2, 4]` and the output pulse is `[1, 2, 4]`, still streaming on axis 0 with full
    /// length `S`.
    #[test]
    fn test_reshape_split_streaming_axis() {
        use tract_core::ops::change_axes::AxisOp;
        let mut model = TypedModel::default();
        let s = model.symbols.sym("S");
        let a = model.add_source("a", f32::fact(dims![s.to_dim() * 2, 4].as_ref())).unwrap();
        let split = model
            .wire_node(
                "split",
                AxisOp::Reshape(0, tvec!(s.to_dim() * 2), tvec!(s.to_dim(), 2.to_dim())),
                &[a],
            )
            .unwrap();
        model.select_output_outlets(&split).unwrap();
        let pulse = PulsedModel::new(&model, s.clone(), &1.to_dim()).unwrap();
        assert_eq!(*pulse.input_fact(0).unwrap().to_typed_fact().unwrap(), f32::fact([2, 4]));
        assert_eq!(*pulse.output_fact(0).unwrap().to_typed_fact().unwrap(), f32::fact([1, 2, 4]));
        let out_stream = pulse.output_fact(0).unwrap().stream.as_ref().unwrap();
        assert_eq!(out_stream.axis, 0);
        assert_eq!(out_stream.dim, s.to_dim());
    }

    /// A reshape merges `(S, 2)` into a single streaming axis `2*S`. With pulse 1, the input
    /// pulse is `[1, 2, 4]` and the output pulse is `[2, 4]`, streaming on axis 0 with full
    /// length `2*S`.
    #[test]
    fn test_reshape_merge_streaming_axis() {
        use tract_core::ops::change_axes::AxisOp;
        let mut model = TypedModel::default();
        let s = model.symbols.sym("S");
        let a = model.add_source("a", f32::fact(dims![s, 2, 4].as_ref())).unwrap();
        let merged = model
            .wire_node(
                "merge",
                AxisOp::Reshape(0, tvec!(s.to_dim(), 2.to_dim()), tvec!(s.to_dim() * 2)),
                &[a],
            )
            .unwrap();
        model.select_output_outlets(&merged).unwrap();
        let pulse = PulsedModel::new(&model, s.clone(), &1.to_dim()).unwrap();
        assert_eq!(*pulse.input_fact(0).unwrap().to_typed_fact().unwrap(), f32::fact([1, 2, 4]));
        assert_eq!(*pulse.output_fact(0).unwrap().to_typed_fact().unwrap(), f32::fact([2, 4]));
        let out_stream = pulse.output_fact(0).unwrap().stream.as_ref().unwrap();
        assert_eq!(out_stream.axis, 0);
        assert_eq!(out_stream.dim, s.to_dim() * 2);
    }

    /// Executes the pulsified `2*S -> (S, 2)` reshape with pulse 1. Each two-value input
    /// chunk comes out as the matching `[[a, b]]` row.
    #[test]
    fn test_reshape_split_then_run() {
        use tract_core::ops::change_axes::AxisOp;
        let mut model = TypedModel::default();
        let s = model.symbols.sym("S");
        let a = model.add_source("a", f32::fact(dims![s.to_dim() * 2].as_ref())).unwrap();
        let split = model
            .wire_node(
                "split",
                AxisOp::Reshape(0, tvec!(s.to_dim() * 2), tvec!(s.to_dim(), 2.to_dim())),
                &[a],
            )
            .unwrap();
        model.select_output_outlets(&split).unwrap();

        let pulse = PulsedModel::new(&model, s, &1.to_dim()).unwrap();
        let plan = SimplePlan::new(pulse.into_typed().unwrap()).unwrap();
        let mut state = SimpleState::new(&plan).unwrap();
        let chunk1 = tensor1(&[1f32, 2.0]);
        let out1 = state.run(tvec!(chunk1.into_tvalue())).unwrap();
        assert_eq!(*out1[0], tensor2(&[[1f32, 2.0]]).into());
        let chunk2 = tensor1(&[3f32, 4.0]);
        let out2 = state.run(tvec!(chunk2.into_tvalue())).unwrap();
        assert_eq!(*out2[0], tensor2(&[[3f32, 4.0]]).into());
    }

    /// Combines a transposed convolution on a `[1, 2, T]` stream (kernel 8, stride 4, valid
    /// padding) with a `Range` of length `4*T + 4` by adding them. The range is cast to f32
    /// and unsqueezed twice to rank 3 first. `4*T + 4` presumably matches the deconvolution's
    /// output length `(T - 1) * 4 + 8`. The test only checks that pulsification with pulse 2
    /// succeeds.
    #[test]
    fn test_pulse_meet_with_arange_branch_types_through() {
        use tract_core::ops::array::Range;
        use tract_core::ops::cnn::{Deconv, KernelFormat, PaddingSpec, PoolSpec};
        use tract_core::ops::nn::DataFormat;

        let mut model = TypedModel::default();
        let t = model.symbols.sym("T");
        let src = model.add_source("x", f32::fact(dims![1, 2, t.to_dim()].as_ref())).unwrap();

        // Zero-filled kernel of shape (2, 2, 8), interpreted as OIHW below.
        let kernel = model
            .add_const("kernel", tract_core::ndarray::Array3::<f32>::zeros((2, 2, 8)))
            .unwrap();
        let bias = model.add_const("bias", tract_core::ndarray::arr1(&[0.0f32, 0.0])).unwrap();
        let conv_out = model
            .wire_node(
                "convtr",
                Deconv {
                    pool_spec: PoolSpec {
                        data_format: DataFormat::NCHW,
                        kernel_shape: tvec!(8),
                        padding: PaddingSpec::Valid,
                        dilations: Some(tvec!(1)),
                        strides: Some(tvec!(4)),
                        input_channels: 2,
                        output_channels: 2,
                    },
                    kernel_format: KernelFormat::OIHW,
                    adjustments: tvec!(0),
                    group: 1,
                },
                &[src, kernel, bias],
            )
            .unwrap()[0];

        // Range 0..(4*T + 4) step 1, with TDim-typed bounds.
        let start = model.add_const("range_start", tensor0(TDim::Val(0))).unwrap();
        let end = model
            .add_const(
                "range_end",
                tract_core::ndarray::arr0(t.to_dim() * 4 + 4).into_dyn().into_tensor(),
            )
            .unwrap();
        let step = model.add_const("range_step", tensor0(TDim::Val(1))).unwrap();
        let range_out = model
            .wire_node("range", Range::new(t.to_dim() * 4 + 4), &[start, end, step])
            .unwrap()[0];

        // Cast to f32, then prepend two unit axes so the rank matches the deconvolution output.
        let range_f32 = model
            .wire_node("range_cast", tract_core::ops::cast::cast(f32::datum_type()), &[range_out])
            .unwrap()[0];
        let range_bc = model
            .wire_node(
                "range_unsqueeze",
                tract_core::ops::change_axes::AxisOp::Add(0),
                &[range_f32],
            )
            .unwrap()[0];
        let range_bc = model
            .wire_node(
                "range_unsqueeze2",
                tract_core::ops::change_axes::AxisOp::Add(0),
                &[range_bc],
            )
            .unwrap()[0];

        let added =
            model.wire_node("add", tract_core::ops::math::add(), &[conv_out, range_bc]).unwrap();
        model.select_output_outlets(&added).unwrap();

        let _pulse = PulsedModel::new(&model, t, &2.to_dim()).expect("pulsification");
    }

    /// `MultiBroadcastTo` targets a streaming axis `1 + S/2`, which is linear in `S`. With
    /// pulse 4, the pulsed output axis is expected to be 2.
    #[test]
    fn test_multi_broadcast_to_pulsifier_linear_axis() {
        use tract_pulse_opl::tract_core::ops::array::MultiBroadcastTo;

        let mut model = TypedModel::default();
        let s = model.symbols.sym("S");
        let linear: TDim = 1.to_dim() + s.to_dim() / 2;
        let target_shape: ShapeFact = tvec![1.to_dim(), linear, 4.to_dim()].into();

        let a = model.add_source("a", f32::fact(dims![1, s.to_dim(), 4].as_ref())).unwrap();
        let out = model.wire_node("bc", MultiBroadcastTo::new(target_shape), &[a]).unwrap();
        model.select_output_outlets(&out).unwrap();

        let pulse = PulsedModel::new(&model, s, &4.to_dim()).expect("pulsification");
        let out_fact = pulse.output_fact(0).unwrap();
        assert_eq!(
            out_fact.shape[1],
            2.to_dim(),
            "linear streaming axis must keep the boundary-subtract delta; got fact: {out_fact:?}",
        );
    }

    /// `MultiBroadcastTo` targets an axis `min(S + 2, 2)`, which is not linear in `S`. With
    /// pulse 4, the pulsed output axis is expected to be 2.
    #[test]
    fn test_multi_broadcast_to_pulsifier_non_linear_axis() {
        use tract_pulse_opl::tract_core::ops::array::MultiBroadcastTo;

        let mut model = TypedModel::default();
        let s = model.symbols.sym("S");
        let non_linear: TDim = (s.to_dim() + 2.to_dim()).mini(2.to_dim());
        let target_shape: ShapeFact = tvec![1.to_dim(), non_linear, 1.to_dim()].into();

        let a = model.add_source("a", f32::fact(dims![1, s.to_dim(), 1].as_ref())).unwrap();
        let out = model.wire_node("bc", MultiBroadcastTo::new(target_shape), &[a]).unwrap();
        model.select_output_outlets(&out).unwrap();

        let pulse = PulsedModel::new(&model, s, &4.to_dim()).expect("pulsification");
        let out_fact = pulse.output_fact(0).unwrap();
        assert_eq!(
            out_fact.shape[1],
            2.to_dim(),
            "non-linear streaming axis must keep the full value, not the collapsed delta; got fact: {out_fact:?}",
        );
    }
}