1use tract_core::internal::*;
2use tract_core::tract_linalg::block_quant::BlockQuant;
3use tract_linalg::block_quant::{BlockQuantFact, Q4_0};
4
5use crate::fact::*;
6use crate::tensor::DeviceTensor;
7
8pub fn facts_to_device_facts(
9 facts: &[&TypedFact],
10 resolve_facts: impl Fn(&[&TypedFact]) -> TractResult<TVec<TypedFact>>,
11) -> TractResult<TVec<TypedFact>> {
12 if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
13 let device_facts = facts
14 .iter()
15 .map(|it| it.to_device_fact().map(|it| it.as_ref()))
16 .collect::<TractResult<TVec<_>>>()?;
17 let output_facts = (resolve_facts)(device_facts.as_slice())?;
18 Ok(output_facts
19 .into_iter()
20 .map(|it| Ok(DeviceFact::new(DeviceTensorOrigin::FromDevice, it)?.into_opaque_fact()))
21 .collect::<TractResult<_>>()?)
22 } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
23 (resolve_facts)(facts)
24 } else {
25 bail!(
26 "Inconsistent facts datum type: {:?}",
27 facts.iter().map(|it| it.datum_type).collect::<TVec<_>>()
28 );
29 }
30}
31
32pub fn get_device_facts<'a, 'b: 'a, T>(
33 facts: &'a [&'b TypedFact],
34 map_facts: impl Fn(&[&'b TypedFact]) -> TractResult<T>,
35) -> TractResult<T> {
36 if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
37 let device_facts = facts
38 .iter()
39 .map(|it| it.to_device_fact().map(|it| it.as_ref()))
40 .collect::<TractResult<TVec<_>>>()?;
41 (map_facts)(device_facts.as_slice())
42 } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
43 (map_facts)(facts)
44 } else {
45 bail!(
46 "Inconsistent facts datum type: {:?}",
47 facts.iter().map(|it| it.datum_type).collect::<Vec<_>>()
48 );
49 }
50}
51
52pub fn get_device_fact<'a, T: 'a>(
53 fact: &'a TypedFact,
54 map_fact: impl Fn(&'a TypedFact) -> TractResult<T>,
55) -> TractResult<T> {
56 if fact.datum_type == DatumType::Opaque {
57 (map_fact)(fact.to_device_fact()?)
58 } else {
59 (map_fact)(fact)
60 }
61}
62
63pub fn as_quant_fact<'a>(
64 fact: &'a TypedFact,
65 format: &dyn BlockQuant,
66) -> Option<&'a BlockQuantFact> {
67 fact.opaque_fact
68 .as_ref()
69 .and_then(|of| of.downcast_ref::<BlockQuantFact>())
70 .and_then(|bqf| if bqf.format.same_as(format) { Some(bqf) } else { None })
71 .or_else(|| {
72 fact.konst
73 .as_ref()
74 .and_then(|k| k.to_scalar::<Opaque>().ok())
75 .and_then(|o| o.downcast_ref::<BlobWithFact>())
76 .and_then(|bwf| bwf.fact.downcast_ref::<BlockQuantFact>())
77 .and_then(|bqf| if bqf.format.same_as(format) { Some(bqf) } else { None })
78 })
79}
80
81pub fn as_q40_tensor(a: &Tensor) -> Option<&BlobWithFact> {
82 a.to_scalar::<Opaque>().ok().and_then(|od| {
83 od.downcast_ref::<BlobWithFact>().and_then(|bwf| {
84 if bwf
85 .fact
86 .downcast_ref::<BlockQuantFact>()
87 .is_some_and(|bqf| bqf.format.same_as(&Q4_0))
88 {
89 Some(bwf)
90 } else {
91 None
92 }
93 })
94 })
95}
96
97pub fn get_quant_fact(t: &DeviceTensor, format: &dyn BlockQuant) -> Option<BlockQuantFact> {
98 if let DeviceTensor::Owned(t) = t {
99 t.opaque_fact()
100 .and_then(|of| of.downcast_ref::<BlockQuantFact>())
101 .cloned()
102 .filter(|bqf| bqf.format.same_as(format))
103 } else {
104 None
105 }
106}
107
108pub fn check_strides_validity(shape: TVec<usize>, strides: TVec<isize>) -> TractResult<()> {
109 let mut zipped_shape_strides: Vec<_> = shape.into_iter().zip(strides).collect();
110 zipped_shape_strides.sort_by_key(|&(_, stride)| stride);
111
112 let mut prev_stride = 1;
113 for (dim, stride) in zipped_shape_strides {
114 ensure!((stride == prev_stride) || (dim == 1), "Invalid strides");
115 prev_stride *= dim as isize;
116 }
117 Ok(())
118}