1use std::collections::HashSet;
2use std::io::{Read, Seek};
3use std::ops::Range;
4use std::str::FromStr;
5use std::sync::Mutex;
6
7use crate::model::Model;
8use tract_hir::internal::*;
9use tract_num_traits::Zero;
10
11#[cfg(feature = "transformers")]
12use tract_transformers::figure_out_causal_llm_b_s_p;
13
/// Ordered collection of per-tensor overrides (facts, concrete values,
/// random ranges) supplied on the command line, matched by name or index.
#[derive(Debug, Default, Clone)]
pub struct TensorsValues(pub Vec<TensorValues>);
16
17impl TensorsValues {
18 pub fn by_name(&self, name: &str) -> Option<&TensorValues> {
19 self.0.iter().find(|t| t.name.as_deref() == Some(name))
20 }
21 pub fn by_name_mut(&mut self, name: &str) -> Option<&mut TensorValues> {
22 self.0.iter_mut().find(|t| t.name.as_deref() == Some(name))
23 }
24 pub fn by_name_mut_with_default(&mut self, name: &str) -> &mut TensorValues {
25 if self.by_name_mut(name).is_none() {
26 self.add(TensorValues { name: Some(name.to_string()), ..TensorValues::default() });
27 }
28 self.by_name_mut(name).unwrap()
29 }
30
31 pub fn by_input_ix(&self, ix: usize) -> Option<&TensorValues> {
32 self.0.iter().find(|t| t.input_index == Some(ix))
33 }
34 pub fn by_input_ix_mut(&mut self, ix: usize) -> Option<&mut TensorValues> {
35 self.0.iter_mut().find(|t| t.input_index == Some(ix))
36 }
37 pub fn by_input_ix_mut_with_default(&mut self, ix: usize) -> &mut TensorValues {
38 if self.by_input_ix_mut(ix).is_none() {
39 self.add(TensorValues { input_index: Some(ix), ..TensorValues::default() });
40 }
41 self.by_input_ix_mut(ix).unwrap()
42 }
43
44 pub fn add(&mut self, other: TensorValues) {
45 let mut tensor = other.input_index.and_then(|ix| self.by_input_ix_mut(ix));
46
47 if tensor.is_none() {
48 tensor = other.name.as_deref().and_then(|ix| self.by_name_mut(ix))
49 }
50
51 if let Some(tensor) = tensor {
52 if tensor.fact.is_none() {
53 tensor.fact = other.fact;
54 }
55 if tensor.values.is_none() {
56 tensor.values = other.values;
57 }
58 } else {
59 self.0.push(other.clone());
60 };
61 }
62
63 pub fn input_by_name(&self, name: &str) -> Option<&TensorValues> {
64 self.0
65 .iter()
66 .filter(|tv| tv.output_index.is_none() && !tv.only_output)
67 .find(|t| t.name.as_deref() == Some(name))
68 }
69}
70
/// A single tensor override: matches a model input/output by index or name
/// and optionally pins its fact, concrete values, or random-fill range.
#[derive(Debug, PartialEq, Clone, Default)]
pub struct TensorValues {
    // Input index this entry binds to, if any.
    pub input_index: Option<usize>,
    // Output index this entry binds to, if any.
    pub output_index: Option<usize>,
    // Tensor/node name this entry binds to, if any.
    pub name: Option<String>,
    // Shape/type fact to impose on the matched tensor.
    pub fact: Option<InferenceFact>,
    // Concrete values to feed, one tensor per turn.
    pub values: Option<Vec<TValue>>,
    // Range to draw from when generating random data for this tensor.
    pub random_range: Option<Range<f32>>,
    // Restrict this entry to input matching only.
    pub only_input: bool,
    // Restrict this entry to output matching only.
    pub only_output: bool,
}
82
83fn parse_dt(dt: &str) -> TractResult<DatumType> {
84 Ok(match dt.to_lowercase().as_ref() {
85 "bool" => DatumType::Bool,
86 "f16" => DatumType::F16,
87 "f32" => DatumType::F32,
88 "f64" => DatumType::F64,
89 "i8" => DatumType::I8,
90 "i16" => DatumType::I16,
91 "i32" => DatumType::I32,
92 "i64" => DatumType::I64,
93 "u8" => DatumType::U8,
94 "u16" => DatumType::U16,
95 "u32" => DatumType::U32,
96 "u64" => DatumType::U64,
97 "tdim" => DatumType::TDim,
98 _ => bail!(
99 "Type of the input should be f16, f32, f64, i8, i16, i16, i32, u8, u16, u32, u64, TDim."
100 ),
101 })
102}
103
104pub fn parse_spec(symbol_table: &SymbolScope, size: &str) -> TractResult<InferenceFact> {
105 if size.is_empty() {
106 return Ok(InferenceFact::default());
107 }
108 parse_coma_spec(symbol_table, size)
109}
110
111pub fn parse_coma_spec(symbol_table: &SymbolScope, size: &str) -> TractResult<InferenceFact> {
112 let splits = size.split(',').collect::<Vec<_>>();
113
114 #[allow(clippy::literal_string_with_formatting_args)]
115 if splits.is_empty() {
116 bail!("The <size> argument should be formatted as {{size}},{{...}},{{type}}.");
117 }
118
119 let last = splits.last().unwrap();
120 let (datum_type, shape) = if let Ok(dt) = parse_dt(last) {
121 (Some(dt), &splits[0..splits.len() - 1])
122 } else {
123 (None, &*splits)
124 };
125
126 let shape = ShapeFactoid::closed(
127 shape
128 .iter()
129 .map(|&s| {
130 Ok(if s == "_" {
131 GenericFactoid::Any
132 } else {
133 GenericFactoid::Only(parse_tdim(symbol_table, s)?)
134 })
135 })
136 .collect::<TractResult<TVec<DimFact>>>()?,
137 );
138
139 if let Some(dt) = datum_type {
140 Ok(InferenceFact::dt_shape(dt, shape))
141 } else {
142 Ok(InferenceFact::shape(shape))
143 }
144}
145
146fn parse_values<T: Datum + FromStr>(shape: &[usize], it: Vec<&str>) -> TractResult<Tensor> {
147 let values = it
148 .into_iter()
149 .map(|v| v.parse::<T>().map_err(|_| format_err!("Failed to parse {}", v)))
150 .collect::<TractResult<Vec<T>>>()?;
151 Ok(tract_ndarray::Array::from_shape_vec(shape, values)?.into())
152}
153
/// Read a tensor from our text format: the first line is a fact spec
/// (shape and type), the remaining lines hold whitespace-separated values.
fn tensor_for_text_data(
    symbol_table: &SymbolScope,
    _filename: &str,
    mut reader: impl Read,
) -> TractResult<Tensor> {
    let mut data = String::new();
    reader.read_to_string(&mut data)?;

    let mut lines = data.lines();
    let proto = parse_spec(symbol_table, lines.next().context("Empty data file")?)?;
    // NOTE(review): panics if the header shape is not fully specified — confirm
    // the text format always gives a concretizable shape.
    let shape = proto.shape.concretize().unwrap();

    let values = lines.flat_map(|l| l.split_whitespace()).collect::<Vec<&str>>();

    // A dim may be symbolic: take the product of the known dims (symbolic
    // ones count as 1), then deduce the missing dim from the number of
    // values actually present in the file.
    let product: usize = shape.iter().map(|o| o.to_usize().unwrap_or(1)).product();
    let missing = values.len() / product;

    let shape: Vec<_> = shape.iter().map(|d| d.to_usize().unwrap_or(missing)).collect();
    // NOTE(review): the header must name a concrete datum type or this unwrap panics.
    dispatch_numbers!(parse_values(proto.datum_type.concretize().unwrap())(&*shape, values))
}
176
/// Load a tensor (and optional name) from a file, dispatching on the
/// filename: `.pb` → ONNX TensorProto, `file.npz:inner` → one entry of an
/// NPZ archive, anything else → our simple text format.
pub fn for_data(
    symbol_table: &SymbolScope,
    filename: &str,
    reader: impl Read + std::io::Seek,
) -> TractResult<(Option<String>, InferenceFact)> {
    #[allow(unused_imports)]
    use std::convert::TryFrom;
    if filename.ends_with(".pb") {
        #[cfg(feature = "onnx")]
        {
            use tract_onnx::data_resolver::FopenDataResolver;
            use tract_onnx::tensor::load_tensor;
            let proto = ::tract_onnx::tensor::proto_from_reader(reader)?;
            let tensor = load_tensor(&FopenDataResolver, &proto, None)?;
            // An empty proto name is reported as "no name".
            Ok((Some(proto.name.to_string()).filter(|s| !s.is_empty()), tensor.into()))
        }
        #[cfg(not(feature = "onnx"))]
        {
            panic!("Loading tensor from protobuf requires onnx features");
        }
    } else if filename.contains(".npz:") {
        // "<archive>.npz:<tensor-name>" selects one entry of the archive.
        let mut tokens = filename.split(':');
        let (_filename, inner) = (tokens.next().unwrap(), tokens.next().unwrap());
        let mut npz = ndarray_npy::NpzReader::new(reader)?;
        Ok((None, for_npz(&mut npz, inner)?.into()))
    } else {
        Ok((None, tensor_for_text_data(symbol_table, filename, reader)?.into()))
    }
}
207
208pub fn for_npz(
209 npz: &mut ndarray_npy::NpzReader<impl Read + Seek>,
210 name: &str,
211) -> TractResult<Tensor> {
212 if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<f32>, tract_ndarray::IxDyn>(name) {
213 return Ok(t.into_tensor());
214 }
215 if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<f64>, tract_ndarray::IxDyn>(name) {
216 return Ok(t.into_tensor());
217 }
218 if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i8>, tract_ndarray::IxDyn>(name) {
219 return Ok(t.into_tensor());
220 }
221 if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i16>, tract_ndarray::IxDyn>(name) {
222 return Ok(t.into_tensor());
223 }
224 if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i32>, tract_ndarray::IxDyn>(name) {
225 return Ok(t.into_tensor());
226 }
227 if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i64>, tract_ndarray::IxDyn>(name) {
228 return Ok(t.into_tensor());
229 }
230 if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u8>, tract_ndarray::IxDyn>(name) {
231 return Ok(t.into_tensor());
232 }
233 if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u16>, tract_ndarray::IxDyn>(name) {
234 return Ok(t.into_tensor());
235 }
236 if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u32>, tract_ndarray::IxDyn>(name) {
237 return Ok(t.into_tensor());
238 }
239 if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u64>, tract_ndarray::IxDyn>(name) {
240 return Ok(t.into_tensor());
241 }
242 if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<bool>, tract_ndarray::IxDyn>(name) {
243 return Ok(t.into_tensor());
244 }
245 bail!("Can not extract tensor from {}", name);
246}
247
/// Parse a command-line tensor spec of the form `[name:]shape,dt[=v1,v2,...]`.
///
/// Returns the optional tensor name and the fact; when `=values` is present
/// the returned fact also carries the concrete tensor.
pub fn for_string(
    symbol_table: &SymbolScope,
    value: &str,
) -> TractResult<(Option<String>, InferenceFact)> {
    // Optional leading "name:" prefix.
    let (name, value) = if value.contains(':') {
        let mut splits = value.split(':');
        (Some(splits.next().unwrap().to_string()), splits.next().unwrap())
    } else {
        (None, value)
    };
    if value.contains('=') {
        // Left of '=': the fact spec; right of it: comma-separated values.
        let mut split = value.split('=');
        let spec = parse_spec(symbol_table, split.next().unwrap())?;
        let value = split.next().unwrap().split(',');
        // Concrete type and shape are mandatory once values are given.
        let dt =
            spec.datum_type.concretize().context("Must specify type when giving tensor value")?;
        let shape = spec
            .shape
            .as_concrete_finite()?
            .context("Must specify concrete shape when giving tensor value")?;
        let tensor = if dt == TDim::datum_type() {
            // TDim values may reference symbols, so each one goes through
            // the symbolic-dimension parser rather than FromStr.
            let mut tensor = Tensor::zero::<TDim>(&shape)?;
            let values =
                value.map(|v| parse_tdim(symbol_table, v)).collect::<TractResult<Vec<_>>>()?;
            tensor
                .try_as_plain_mut()?
                .as_slice_mut::<TDim>()?
                .iter_mut()
                .zip(values)
                .for_each(|(t, v)| *t = v);
            tensor
        } else {
            dispatch_numbers!(parse_values(dt)(&*shape, value.collect()))?
        };
        Ok((name, tensor.into()))
    } else {
        Ok((name, parse_spec(symbol_table, value)?))
    }
}
287
lazy_static::lazy_static! {
    // Messages already emitted by `info_once`, to avoid log spam on repeats.
    static ref MESSAGE_ONCE: Mutex<HashSet<String>> = Mutex::new(HashSet::new());
}
291
292fn info_once(msg: String) {
293 if MESSAGE_ONCE.lock().unwrap().insert(msg.clone()) {
294 info!("{msg}");
295 }
296}
297
/// Parameters driving how run-time input tensors are produced.
pub struct RunParams {
    // Per-tensor overrides (facts, values, random ranges).
    pub tensors_values: TensorsValues,
    // Generate random data for inputs with no provided values.
    pub allow_random_input: bool,
    // Allow casting provided f32 data when the model expects f16.
    pub allow_float_casts: bool,
    // Concrete values for the model's symbolic dimensions.
    pub symbols: SymbolValues,
    // If set, split prompt-length inputs into chunks of this size.
    pub prompt_chunk_size: Option<usize>,
    // Truncate a trailing partial pulse instead of zero-padding it.
    pub drop_partial_pulse: bool,
}
306
/// Tensors ready to feed the model: `sources[turn][input_index]`.
pub struct RunTensors {
    pub sources: Vec<TVec<TValue>>,
}
310
#[cfg(feature = "transformers")]
/// Split `fact` along the causal-LLM sequence axis into pieces of at most
/// `params.prompt_chunk_size`, returning one fact per chunk.
///
/// Returns `vec![fact]` unchanged whenever chunking does not apply: no chunk
/// size configured, model not a `TypedModel`, no sequence symbol detected,
/// no dim equal to that symbol, or the resolved length already fits.
fn chunk_fact(
    fact: &TypedFact,
    params: &RunParams,
    model: &Arc<dyn Model>,
) -> TractResult<Vec<TypedFact>> {
    let Some(chunk_size) = params.prompt_chunk_size else {
        return Ok(vec![fact.clone()]);
    };
    let Some(model) = model.downcast_ref::<TypedModel>() else {
        return Ok(vec![fact.clone()]);
    };
    // (batch, sequence, past) symbols of the causal LLM; only `s` matters here.
    let (_, s, _) = figure_out_causal_llm_b_s_p(model)?;
    let Some(s) = s else {
        return Ok(vec![fact.clone()]);
    };

    let dims = fact.shape.dims();
    let Some(sym_idx) = dims.iter().position(|d| *d == TDim::Sym(s.clone())) else {
        return Ok(vec![fact.clone()]);
    };

    // Resolve the symbolic sequence length to a concrete value.
    let resolved_sym = dims[sym_idx].eval_to_i64(&params.symbols)? as usize;
    if resolved_sym <= chunk_size {
        return Ok(vec![fact.clone()]);
    }

    let num_chunks = resolved_sym.div_ceil(chunk_size);
    let mut out = Vec::with_capacity(num_chunks);

    for start in (0..resolved_sym).step_by(chunk_size) {
        // The final chunk may be shorter than chunk_size.
        let this = chunk_size.min(resolved_sym - start) as i64;

        // All other dims are concretized with the provided symbol values.
        let mut new_fact = fact.clone();
        new_fact.shape = new_fact
            .shape
            .iter()
            .enumerate()
            .map(|(i, d)| if i == sym_idx { TDim::Val(this) } else { d.eval(&params.symbols) })
            .collect();

        out.push(new_fact);
    }

    Ok(out)
}
357
#[cfg(feature = "transformers")]
/// Slice `tensor` along the causal-LLM sequence axis into chunks of at most
/// `params.prompt_chunk_size`, mirroring `chunk_fact` for concrete data.
///
/// Returns the tensor unchanged (as a single TValue) whenever chunking does
/// not apply — same bail-out conditions as `chunk_fact`.
fn chunk_tensor(
    tensor: Tensor,
    fact: &TypedFact,
    params: &RunParams,
    model: &Arc<dyn Model>,
) -> TractResult<Vec<TValue>> {
    let Some(chunk_size) = params.prompt_chunk_size else {
        return Ok(vec![tensor.into_tvalue()]);
    };

    let Some(model) = model.downcast_ref::<TypedModel>() else {
        return Ok(vec![tensor.into_tvalue()]);
    };
    // (batch, sequence, past) symbols; only the sequence symbol is needed.
    let (_, s, _) = figure_out_causal_llm_b_s_p(model)?;
    let Some(s) = s else {
        return Ok(vec![tensor.into_tvalue()]);
    };

    let dims = fact.shape.dims();
    let Some(symb_axis) = dims.iter().position(|d| *d == TDim::Sym(s.clone())) else {
        return Ok(vec![tensor.into_tvalue()]);
    };

    // Unlike chunk_fact, the actual tensor shape gives the concrete length.
    let resolved_sym = tensor.shape()[symb_axis];
    if resolved_sym <= chunk_size {
        return Ok(vec![tensor.into_tvalue()]);
    }

    let num_chunks = resolved_sym.div_ceil(chunk_size);
    let mut out = Vec::with_capacity(num_chunks);

    for start in (0..resolved_sym).step_by(chunk_size) {
        // The final chunk may be shorter than chunk_size.
        let this = chunk_size.min(resolved_sym - start);
        out.push(tensor.slice(symb_axis, start, start + this)?.into_tvalue());
    }

    Ok(out)
}
397
/// Produce the per-turn tensors for one model input and append them (as one
/// `Vec<TValue>`) to `target`.
///
/// Resolution order:
/// 1. values provided via `params.tensors_values` (matched by name, then by
///    input index), re-typed/cast when needed and chunked along the prompt
///    axis; a pulse-mode model instead gets its single provided tensor
///    repackaged into fixed-size pulses;
/// 2. an all-zero tensor when the fact has a concrete, zero-volume shape;
/// 3. a random tensor when `params.allow_random_input` is set;
/// 4. otherwise an error.
fn get_or_make_tensors(
    model: &Arc<dyn Model>,
    params: &RunParams,
    fact: TypedFact,
    name: &str,
    input_idx: usize,
    target: &mut TVec<Vec<TValue>>,
) -> TractResult<()> {
    if let Some(mut value) = params
        .tensors_values
        .by_name(name)
        .or_else(|| params.tensors_values.by_input_ix(input_idx))
        .and_then(|t| t.values.clone())
    {
        // Re-tag unquantized provided data with the model's quantized dt when
        // the underlying storage types line up (no numeric conversion).
        if !value[0].datum_type().is_quantized()
            && fact.datum_type.is_quantized()
            && value[0].datum_type() == fact.datum_type.unquantized()
        {
            value = value
                .iter()
                .map(|v| {
                    let mut v = v.clone().into_tensor();
                    unsafe { v.set_datum_type(fact.datum_type) };
                    v.into()
                })
                .collect();
        }
        let mut chunked_tensors: Vec<TValue> = vec![];
        for t in &value {
            // NOTE(review): the compatibility check inspects value[0] rather
            // than `t`; if turns may have differing shapes/types this could
            // misclassify later turns — confirm intent.
            let tensor = if TypedFact::shape_and_dt_of(&value[0]).compatible_with(&fact) {
                info_once(format!(
                    "Using fixed input for input called {} ({} turn(s))",
                    name,
                    value.len()
                ));
                t.clone().into_tensor()
            } else if fact.datum_type == f16::datum_type()
                && value[0].datum_type() == f32::datum_type()
                && params.allow_float_casts
            {
                debug!("Casting input to F16 for input called {} ({} turn(s))", name, value.len());
                t.cast_to::<f16>()?.into_owned()
            } else {
                // Incompatible: abandon chunking and fall through to the
                // pulse/error handling below.
                break;
            };

            // NOTE(review): chunk_tensor/chunk_fact are only declared under
            // the "transformers" feature in this file — presumably a fallback
            // exists elsewhere; confirm the build without that feature.
            chunked_tensors.extend(chunk_tensor(tensor, &fact, params, model)?);
        }
        if !chunked_tensors.is_empty() {
            target.push(chunked_tensors);
            return Ok(());
        }

        // Pulse-mode model: cut the single provided tensor into fixed-size
        // pulses along the declared input axis.
        if value.len() == 1 && model.properties().contains_key("pulse.delay") {
            let value = &value[0];
            let input_pulse_axis = model
                .properties()
                .get("pulse.input_axes")
                .context("Expect pulse.input_axes property")?
                .cast_to::<i64>()?
                .try_as_plain()?
                .as_slice::<i64>()?[input_idx] as usize;
            let input_pulse = fact.shape.get(input_pulse_axis).unwrap().to_usize().unwrap();
            let mut input_len = value.shape()[input_pulse_axis];
            // Optionally drop the trailing partial pulse instead of padding it.
            if params.drop_partial_pulse && input_len % input_pulse != 0 {
                input_len = (input_len / input_pulse) * input_pulse;
                info!(
                    "Dropping partial trailing pulse: truncating input from {} to {} on axis {}.",
                    value.shape()[input_pulse_axis],
                    input_len,
                    input_pulse_axis
                );
            }

            let output_pulse_axis = model
                .properties()
                .get("pulse.output_axes")
                .context("Expect pulse.output_axes property")?
                .cast_to::<i64>()?
                .try_as_plain()?
                .as_slice::<i64>()?[0] as usize;
            let output_fact = model.outlet_typedfact(model.output_outlets()[0])?;
            let output_pulse =
                output_fact.shape.get(output_pulse_axis).unwrap().to_usize().unwrap();
            let output_len = input_len * output_pulse / input_pulse;
            // Account for the model's latency: feed enough pulses for the
            // last output frame to come out.
            let output_delay =
                model.properties()["pulse.delay"].try_as_plain()?.as_slice::<i64>()?[0] as usize;
            let last_frame = output_len + output_delay;
            let needed_pulses = last_frame.divceil(output_pulse);
            let mut values = vec![];
            for ix in 0..needed_pulses {
                // Zero-filled pulse, partially overwritten with real data.
                let mut t = Tensor::zero_dt(fact.datum_type, fact.shape.as_concrete().unwrap())?;
                let start = ix * input_pulse;
                let end = (start + input_pulse).min(input_len);
                if end > start {
                    t.assign_slice(0..end - start, value, start..end, input_pulse_axis)?;
                }
                values.push(t.into());
            }
            info!(
                "Generated {} pulses of shape {:?} for input {}.",
                needed_pulses, fact.shape, input_idx
            );
            target.push(values);
        } else {
            bail!(
                "For input {}, can not reconcile model input fact {:?} with provided input {:?}",
                name,
                fact,
                value[0]
            );
        };
    } else if fact.shape.is_concrete() && fact.shape.volume() == TDim::zero() {
        // Degenerate (empty) tensor: feed zeros of the right dt/shape.
        let shape = fact.shape.as_concrete().unwrap();
        let tensor = Tensor::zero_dt(fact.datum_type, shape)?;
        target.push(vec![tensor.into()]);
    } else if params.allow_random_input {
        info_once(format!("Using random input for input called {name:?}: {fact:?}"));
        // A matching override may still constrain the random range.
        let tv = params
            .tensors_values
            .by_name(name)
            .or_else(|| params.tensors_values.by_input_ix(input_idx));

        let mut chunked_facts = chunk_fact(&fact, params, model)?;

        let mut chunked_tensors = Vec::with_capacity(chunked_facts.len());
        for fact in &mut chunked_facts {
            // Concretize symbolic dims before generating random data.
            fact.shape = fact.shape.iter().map(|dim| dim.eval(&params.symbols)).collect();
            chunked_tensors.push(tensor_for_fact(fact, None, tv)?.into());
        }
        target.push(chunked_tensors);
    } else {
        bail!(
            "Unmatched tensor {}. Fix the input or use \"--allow-random-input\" if this was intended",
            name
        );
    }
    Ok(())
}
539
540pub fn get_or_make_inputs(tract: &Arc<dyn Model>, params: &RunParams) -> TractResult<RunTensors> {
541 let mut tmp_inputs = tvec![];
543 for (ix, input) in tract.input_outlets().iter().enumerate() {
544 let fact = tract.outlet_typedfact(*input)?;
545 let name = tract.node_name(input.node);
546 get_or_make_tensors(tract, params, fact, name, ix, &mut tmp_inputs)?;
547 }
548
549 let n_turns = tmp_inputs.iter().map(|t| t.len()).max().unwrap_or(0);
550 let sources = (0..n_turns)
551 .map(|i| {
552 tmp_inputs
553 .iter()
554 .map(|t| if i < t.len() { t[i].clone() } else { t[t.len() - 1].clone() })
555 .collect::<TVec<_>>()
556 })
557 .collect::<Vec<_>>();
558
559 Ok(RunTensors { sources })
560}
561
562fn make_inputs(values: &[impl std::borrow::Borrow<TypedFact>]) -> TractResult<TVec<TValue>> {
563 values.iter().map(|v| tensor_for_fact(v.borrow(), None, None).map(|t| t.into())).collect()
564}
565
566pub fn make_inputs_for_model(model: &dyn Model) -> TractResult<TVec<TValue>> {
567 make_inputs(
568 &model
569 .input_outlets()
570 .iter()
571 .map(|&t| model.outlet_typedfact(t))
572 .collect::<TractResult<Vec<TypedFact>>>()?,
573 )
574}
575
576#[allow(unused_variables)]
577pub fn tensor_for_fact(
578 fact: &TypedFact,
579 streaming_dim: Option<usize>,
580 tv: Option<&TensorValues>,
581) -> TractResult<Tensor> {
582 if let Some(value) = &fact.konst {
583 return Ok(value.clone().into_tensor());
584 }
585 Ok(random(
586 fact.shape
587 .as_concrete()
588 .with_context(|| format!("Expected concrete shape, found: {fact:?}"))?,
589 fact.datum_type,
590 tv,
591 ))
592}
593
594pub fn random(sizes: &[usize], datum_type: DatumType, tv: Option<&TensorValues>) -> Tensor {
596 use rand::{RngExt, SeedableRng};
597 let mut rng = rand::rngs::StdRng::seed_from_u64(21242);
598 let mut tensor = Tensor::zero::<f32>(sizes).unwrap();
599 let mut tensor_plain = tensor.try_as_plain_mut().unwrap();
600 let slice = tensor_plain.as_slice_mut::<f32>().unwrap();
601 if let Some(range) = tv.and_then(|tv| tv.random_range.as_ref()) {
602 slice.iter_mut().for_each(|x| *x = rng.random_range(range.clone()))
603 } else {
604 slice.iter_mut().for_each(|x| *x = rng.random())
605 };
606 tensor.cast_to_dt(datum_type).unwrap().into_owned()
607}