use std::collections::HashSet;
use std::io::{Read, Seek};
use std::ops::Range;
use std::str::FromStr;
use std::sync::Mutex;

use crate::model::Model;
use tract_hir::internal::*;
use tract_num_traits::Zero;

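/// A collection of per-tensor overrides (facts and/or values), addressable by
/// tensor name or by input index.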
#[derive(Debug, Default, Clone)]
pub struct TensorsValues(pub Vec<TensorValues>);

impl TensorsValues {
    pub fn by_name(&self, name: &str) -> Option<&TensorValues> {
        self.0.iter().find(|t| t.name.as_deref() == Some(name))
    }
    pub fn by_name_mut(&mut self, name: &str) -> Option<&mut TensorValues> {
        self.0.iter_mut().find(|t| t.name.as_deref() == Some(name))
    }
    pub fn by_name_mut_with_default(&mut self, name: &str) -> &mut TensorValues {
        if self.by_name_mut(name).is_none() {
            self.add(TensorValues { name: Some(name.to_string()), ..TensorValues::default() });
        }
        self.by_name_mut(name).unwrap()
    }

    pub fn by_input_ix(&self, ix: usize) -> Option<&TensorValues> {
        self.0.iter().find(|t| t.input_index == Some(ix))
    }
    pub fn by_input_ix_mut(&mut self, ix: usize) -> Option<&mut TensorValues> {
        self.0.iter_mut().find(|t| t.input_index == Some(ix))
    }
    pub fn by_input_ix_mut_with_default(&mut self, ix: usize) -> &mut TensorValues {
        if self.by_input_ix_mut(ix).is_none() {
            self.add(TensorValues { input_index: Some(ix), ..TensorValues::default() });
        }
        self.by_input_ix_mut(ix).unwrap()
    }

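    /// Merge `other` into the collection: if an entry with the same input index
    /// or name already exists, only its missing `fact` and `values` are filled
    /// in; otherwise `other` is appended as a new entry.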
    pub fn add(&mut self, other: TensorValues) {
        let mut tensor = other.input_index.and_then(|ix| self.by_input_ix_mut(ix));

        if tensor.is_none() {
            tensor = other.name.as_deref().and_then(|ix| self.by_name_mut(ix))
        }

        if let Some(tensor) = tensor {
            if tensor.fact.is_none() {
                tensor.fact = other.fact;
            }
            if tensor.values.is_none() {
                tensor.values = other.values;
            }
        } else {
            self.0.push(other.clone());
        };
    }
}

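/// Overrides for a single tensor: an optional inference fact, optional
/// concrete values (one per turn), and an optional range for random data.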
#[derive(Debug, PartialEq, Clone, Default)]
pub struct TensorValues {
    pub input_index: Option<usize>,
    pub output_index: Option<usize>,
    pub name: Option<String>,
    pub fact: Option<InferenceFact>,
    pub values: Option<Vec<TValue>>,
    pub random_range: Option<Range<f32>>,
}

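/// Parse a datum type name such as "f32", "i64" or "tdim" (case-insensitive).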
fn parse_dt(dt: &str) -> TractResult<DatumType> {
    Ok(match dt.to_lowercase().as_ref() {
        "bool" => DatumType::Bool,
        "f16" => DatumType::F16,
        "f32" => DatumType::F32,
        "f64" => DatumType::F64,
        "i8" => DatumType::I8,
        "i16" => DatumType::I16,
        "i32" => DatumType::I32,
        "i64" => DatumType::I64,
        "u8" => DatumType::U8,
        "u16" => DatumType::U16,
        "u32" => DatumType::U32,
        "u64" => DatumType::U64,
        "tdim" => DatumType::TDim,
        _ => bail!(
            "Type of the input should be bool, f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64 or TDim."
        ),
    })
}

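/// Parse a shape/type specification such as "1,224,224,3,f32" into an
/// `InferenceFact`. An empty string yields a fully unspecified fact.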
pub fn parse_spec(symbol_table: &SymbolScope, size: &str) -> TractResult<InferenceFact> {
    if size.is_empty() {
        return Ok(InferenceFact::default());
    }
    parse_coma_spec(symbol_table, size)
}

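/// Parse a comma-separated specification: each item is a dimension (a number,
/// a symbolic expression, or "_" for unknown), optionally followed by a final
/// datum type (e.g. "8,_,f32").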
pub fn parse_coma_spec(symbol_table: &SymbolScope, size: &str) -> TractResult<InferenceFact> {
    let splits = size.split(',').collect::<Vec<_>>();

    #[allow(clippy::literal_string_with_formatting_args)]
    if splits.is_empty() {
        bail!("The <size> argument should be formatted as {{size}},{{...}},{{type}}.");
    }

    let last = splits.last().unwrap();
    let (datum_type, shape) = if let Ok(dt) = parse_dt(last) {
        (Some(dt), &splits[0..splits.len() - 1])
    } else {
        (None, &*splits)
    };

    let shape = ShapeFactoid::closed(
        shape
            .iter()
            .map(|&s| {
                Ok(if s == "_" {
                    GenericFactoid::Any
                } else {
                    GenericFactoid::Only(parse_tdim(symbol_table, s)?)
                })
            })
            .collect::<TractResult<TVec<DimFact>>>()?,
    );

    if let Some(dt) = datum_type {
        Ok(InferenceFact::dt_shape(dt, shape))
    } else {
        Ok(InferenceFact::shape(shape))
    }
}

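/// Parse whitespace-separated textual values into a tensor of the given shape.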
fn parse_values<T: Datum + FromStr>(shape: &[usize], it: Vec<&str>) -> TractResult<Tensor> {
    let values = it
        .into_iter()
        .map(|v| v.parse::<T>().map_err(|_| format_err!("Failed to parse {}", v)))
        .collect::<TractResult<Vec<T>>>()?;
    Ok(tract_ndarray::Array::from_shape_vec(shape, values)?.into())
}

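/// Read a tensor from a text file: the first line is a shape/type spec, the
/// remaining lines are the values. An unspecified dimension is inferred from
/// the number of values actually present.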
fn tensor_for_text_data(
    symbol_table: &SymbolScope,
    _filename: &str,
    mut reader: impl Read,
) -> TractResult<Tensor> {
    let mut data = String::new();
    reader.read_to_string(&mut data)?;

    let mut lines = data.lines();
    let proto = parse_spec(symbol_table, lines.next().context("Empty data file")?)?;
    let shape = proto.shape.concretize().unwrap();

    let values = lines.flat_map(|l| l.split_whitespace()).collect::<Vec<&str>>();

    let product: usize = shape.iter().map(|o| o.to_usize().unwrap_or(1)).product();
    let missing = values.len() / product;

    let shape: Vec<_> = shape.iter().map(|d| d.to_usize().unwrap_or(missing)).collect();
    dispatch_numbers!(parse_values(proto.datum_type.concretize().unwrap())(&*shape, values))
}

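/// Load a tensor from a file: an ONNX tensor protobuf (".pb", needs the "onnx"
/// feature), an array from an npz archive ("file.npz:array_name"), or a plain
/// text file. Also returns the tensor name when the source provides one.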
pub fn for_data(
    symbol_table: &SymbolScope,
    filename: &str,
    reader: impl Read + std::io::Seek,
) -> TractResult<(Option<String>, InferenceFact)> {
    #[allow(unused_imports)]
    use std::convert::TryFrom;
    if filename.ends_with(".pb") {
        #[cfg(feature = "onnx")]
        {
            use tract_onnx::data_resolver::FopenDataResolver;
            use tract_onnx::tensor::load_tensor;
            let proto = ::tract_onnx::tensor::proto_from_reader(reader)?;
            let tensor = load_tensor(&FopenDataResolver, &proto, None)?;
            Ok((Some(proto.name.to_string()).filter(|s| !s.is_empty()), tensor.into()))
        }
        #[cfg(not(feature = "onnx"))]
        {
            panic!("Loading tensor from protobuf requires onnx features");
        }
    } else if filename.contains(".npz:") {
        let mut tokens = filename.split(':');
        let (_filename, inner) = (tokens.next().unwrap(), tokens.next().unwrap());
        let mut npz = ndarray_npy::NpzReader::new(reader)?;
        Ok((None, for_npz(&mut npz, inner)?.into()))
    } else {
        Ok((None, tensor_for_text_data(symbol_table, filename, reader)?.into()))
    }
}

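/// Extract a named array from an npz archive, trying every supported datum
/// type in turn.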
pub fn for_npz(
    npz: &mut ndarray_npy::NpzReader<impl Read + Seek>,
    name: &str,
) -> TractResult<Tensor> {
    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<f32>, tract_ndarray::IxDyn>(name) {
        return Ok(t.into_tensor());
    }
    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<f64>, tract_ndarray::IxDyn>(name) {
        return Ok(t.into_tensor());
    }
    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i8>, tract_ndarray::IxDyn>(name) {
        return Ok(t.into_tensor());
    }
    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i16>, tract_ndarray::IxDyn>(name) {
        return Ok(t.into_tensor());
    }
    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i32>, tract_ndarray::IxDyn>(name) {
        return Ok(t.into_tensor());
    }
    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i64>, tract_ndarray::IxDyn>(name) {
        return Ok(t.into_tensor());
    }
    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u8>, tract_ndarray::IxDyn>(name) {
        return Ok(t.into_tensor());
    }
    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u16>, tract_ndarray::IxDyn>(name) {
        return Ok(t.into_tensor());
    }
    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u32>, tract_ndarray::IxDyn>(name) {
        return Ok(t.into_tensor());
    }
    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u64>, tract_ndarray::IxDyn>(name) {
        return Ok(t.into_tensor());
    }
    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<bool>, tract_ndarray::IxDyn>(name) {
        return Ok(t.into_tensor());
    }
    bail!("Can not extract tensor from {}", name);
}

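/// Parse a tensor specification given as a string, optionally prefixed with a
/// name ("name:spec") and optionally carrying concrete values after an equal
/// sign (e.g. "input:2,f32=0.5,1.5").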
pub fn for_string(
    symbol_table: &SymbolScope,
    value: &str,
) -> TractResult<(Option<String>, InferenceFact)> {
    let (name, value) = if value.contains(':') {
        let mut splits = value.split(':');
        (Some(splits.next().unwrap().to_string()), splits.next().unwrap())
    } else {
        (None, value)
    };
    if value.contains('=') {
        let mut split = value.split('=');
        let spec = parse_spec(symbol_table, split.next().unwrap())?;
        let value = split.next().unwrap().split(',');
        let dt =
            spec.datum_type.concretize().context("Must specify type when giving tensor value")?;
        let shape = spec
            .shape
            .as_concrete_finite()?
            .context("Must specify concrete shape when giving tensor value")?;
        let tensor = if dt == TDim::datum_type() {
            let mut tensor = Tensor::zero::<TDim>(&shape)?;
            let values =
                value.map(|v| parse_tdim(symbol_table, v)).collect::<TractResult<Vec<_>>>()?;
            tensor.as_slice_mut::<TDim>()?.iter_mut().zip(values).for_each(|(t, v)| *t = v);
            tensor
        } else {
            dispatch_numbers!(parse_values(dt)(&*shape, value.collect()))?
        };
        Ok((name, tensor.into()))
    } else {
        Ok((name, parse_spec(symbol_table, value)?))
    }
}

lazy_static::lazy_static! {
    static ref MESSAGE_ONCE: Mutex<HashSet<String>> = Mutex::new(HashSet::new());
}

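/// Log a message at info level only the first time it is seen.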
fn info_once(msg: String) {
    if MESSAGE_ONCE.lock().unwrap().insert(msg.clone()) {
        info!("{}", msg);
    }
}

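/// Options controlling how inputs are sourced when running a model.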
pub struct RunParams {
    pub tensors_values: TensorsValues,
    pub allow_random_input: bool,
    pub allow_float_casts: bool,
    pub symbols: SymbolValues,
}

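/// Assemble the input tensors for each turn of a run: use values supplied in
/// `params` when they are compatible with the model's input facts, otherwise
/// fall back to random tensors when `allow_random_input` is set.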
pub fn retrieve_or_make_inputs(
    tract: &dyn Model,
    params: &RunParams,
) -> TractResult<Vec<TVec<TValue>>> {
    let mut tmp: TVec<Vec<TValue>> = tvec![];
    for (ix, input) in tract.input_outlets().iter().enumerate() {
        let name = tract.node_name(input.node);
        let fact = tract.outlet_typedfact(*input)?;
        if let Some(mut value) = params
            .tensors_values
            .by_name(name)
            .or_else(|| params.tensors_values.by_input_ix(ix))
            .and_then(|t| t.values.clone())
        {
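            // Provided value is unquantized but the model expects the quantized
            // flavour of the same storage type: relabel the datum type in place.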
            if !value[0].datum_type().is_quantized()
                && fact.datum_type.is_quantized()
                && value[0].datum_type() == fact.datum_type.unquantized()
            {
                value = value
                    .iter()
                    .map(|v| {
                        let mut v = v.clone().into_tensor();
                        unsafe { v.set_datum_type(fact.datum_type) };
                        v.into()
                    })
                    .collect();
            }
            if TypedFact::shape_and_dt_of(&value[0]).compatible_with(&fact) {
                info!("Using fixed input for input called {} ({} turn(s))", name, value.len());
                tmp.push(value.iter().map(|t| t.clone().into_tensor().into()).collect())
            } else if fact.datum_type == f16::datum_type()
                && value[0].datum_type() == f32::datum_type()
                && params.allow_float_casts
            {
                tmp.push(
                    value.iter().map(|t| t.cast_to::<f16>().unwrap().into_owned().into()).collect(),
                )
            } else if value.len() == 1 && tract.properties().contains_key("pulse.delay") {
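                // Pulsed (streaming) model: cut the single provided tensor into as
                // many input pulses as needed to flush the whole output, delay included.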
                let value = &value[0];
                let input_pulse_axis = tract
                    .properties()
                    .get("pulse.input_axes")
                    .context("Expect pulse.input_axes property")?
                    .cast_to::<i64>()?
                    .as_slice::<i64>()?[ix] as usize;
                let input_pulse = fact.shape.get(input_pulse_axis).unwrap().to_usize().unwrap();
                let input_len = value.shape()[input_pulse_axis];

                let output_pulse_axis = tract
                    .properties()
                    .get("pulse.output_axes")
                    .context("Expect pulse.output_axes property")?
                    .cast_to::<i64>()?
                    .as_slice::<i64>()?[0] as usize;
                let output_fact = tract.outlet_typedfact(tract.output_outlets()[0])?;
                let output_pulse =
                    output_fact.shape.get(output_pulse_axis).unwrap().to_usize().unwrap();
                let output_len = input_len * output_pulse / input_pulse;
                let output_delay = tract.properties()["pulse.delay"].as_slice::<i64>()?[0] as usize;
                let last_frame = output_len + output_delay;
                let needed_pulses = last_frame.divceil(output_pulse);
                let mut values = vec![];
                for ix in 0..needed_pulses {
                    let mut t =
                        Tensor::zero_dt(fact.datum_type, fact.shape.as_concrete().unwrap())?;
                    let start = ix * input_pulse;
                    let end = (start + input_pulse).min(input_len);
                    if end > start {
                        t.assign_slice(0..end - start, value, start..end, input_pulse_axis)?;
                    }
                    values.push(t.into());
                }
                info!(
                    "Generated {} pulses of shape {:?} for input {}.",
                    needed_pulses, fact.shape, ix
                );
                tmp.push(values);
            } else {
                bail!("For input {}, can not reconcile model input fact {:?} with provided input {:?}", name, fact, value[0]);
            };
        } else if fact.shape.is_concrete() && fact.shape.volume() == TDim::zero() {
            let shape = fact.shape.as_concrete().unwrap();
            let tensor = Tensor::zero_dt(fact.datum_type, shape)?;
            tmp.push(vec![tensor.into()]);
        } else if params.allow_random_input {
            let mut fact: TypedFact = tract.outlet_typedfact(*input)?.clone();
            info_once(format!("Using random input for input called {name:?}: {fact:?}"));
            let tv = params
                .tensors_values
                .by_name(name)
                .or_else(|| params.tensors_values.by_input_ix(ix));
            fact.shape = fact.shape.iter().map(|dim| dim.eval(&params.symbols)).collect();
            tmp.push(vec![crate::tensor::tensor_for_fact(&fact, None, tv)?.into()]);
        } else {
            bail!("Unmatched tensor {}. Fix the input or use \"--allow-random-input\" if this was intended", name);
        }
    }
    Ok((0..tmp[0].len()).map(|turn| tmp.iter().map(|t| t[turn].clone()).collect()).collect())
}

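/// Build one placeholder tensor per fact (the constant value when the fact has
/// one, random data otherwise).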
fn make_inputs(values: &[impl std::borrow::Borrow<TypedFact>]) -> TractResult<TVec<TValue>> {
    values.iter().map(|v| tensor_for_fact(v.borrow(), None, None).map(|t| t.into())).collect()
}

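/// Build placeholder tensors for every input of the model.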
pub fn make_inputs_for_model(model: &dyn Model) -> TractResult<TVec<TValue>> {
    make_inputs(
        &model
            .input_outlets()
            .iter()
            .map(|&t| model.outlet_typedfact(t))
            .collect::<TractResult<Vec<TypedFact>>>()?,
    )
}

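/// Produce a tensor matching `fact`: its constant value if it has one,
/// otherwise random data, drawn from `tv`'s random range when provided.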
#[allow(unused_variables)]
pub fn tensor_for_fact(
    fact: &TypedFact,
    streaming_dim: Option<usize>,
    tv: Option<&TensorValues>,
) -> TractResult<Tensor> {
    if let Some(value) = &fact.konst {
        return Ok(value.clone().into_tensor());
    }
    Ok(random(
        fact.shape
            .as_concrete()
            .with_context(|| format!("Expected concrete shape, found: {fact:?}"))?,
        fact.datum_type,
        tv,
    ))
}

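/// Fill a tensor of the given shape with deterministic pseudo-random data
/// (fixed seed), using `tv`'s random range when one is set, then cast it to
/// `datum_type`.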
pub fn random(sizes: &[usize], datum_type: DatumType, tv: Option<&TensorValues>) -> Tensor {
    use rand::{Rng, SeedableRng};
    let mut rng = rand::rngs::StdRng::seed_from_u64(21242);
    let mut tensor = Tensor::zero::<f32>(sizes).unwrap();
    let slice = tensor.as_slice_mut::<f32>().unwrap();
    if let Some(range) = tv.and_then(|tv| tv.random_range.as_ref()) {
        slice.iter_mut().for_each(|x| *x = rng.gen_range(range.clone()))
    } else {
        slice.iter_mut().for_each(|x| *x = rng.gen())
    };
    tensor.cast_to_dt(datum_type).unwrap().into_owned()
}