1use tract_core::internal::*;
2use tract_core::tract_linalg::block_quant::BlockQuant;
3use tract_linalg::block_quant::{BlockQuantFact, Q4_0};
4
5use crate::fact::*;
6use crate::tensor::DeviceTensor;
7
8pub fn facts_to_device_facts(
9 facts: &[&TypedFact],
10 resolve_facts: impl Fn(&[&TypedFact]) -> TractResult<TVec<TypedFact>>,
11) -> TractResult<TVec<TypedFact>> {
12 if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
13 let device_facts = facts
14 .iter()
15 .map(|it| it.to_device_fact().map(|it| it.as_ref()))
16 .collect::<TractResult<TVec<_>>>()?;
17 let output_facts = (resolve_facts)(device_facts.as_slice())?;
18 Ok(output_facts
19 .into_iter()
20 .map(|it| Ok(DeviceFact::new(DeviceTensorOrigin::FromDevice, it)?.into_opaque_fact()))
21 .collect::<TractResult<_>>()?)
22 } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
23 (resolve_facts)(facts)
24 } else {
25 bail!(
26 "Inconsistent facts datum type: {:?}",
27 facts.iter().map(|it| it.datum_type).collect::<TVec<_>>()
28 );
29 }
30}
31
32pub fn get_device_facts<'a, 'b: 'a, T>(
33 facts: &'a [&'b TypedFact],
34 map_facts: impl Fn(&[&'b TypedFact]) -> TractResult<T>,
35) -> TractResult<T> {
36 if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
37 let device_facts = facts
38 .iter()
39 .map(|it| it.to_device_fact().map(|it| it.as_ref()))
40 .collect::<TractResult<TVec<_>>>()?;
41 (map_facts)(device_facts.as_slice())
42 } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
43 (map_facts)(facts)
44 } else {
45 bail!(
46 "Inconsistent facts datum type: {:?}",
47 facts.iter().map(|it| it.datum_type).collect::<Vec<_>>()
48 );
49 }
50}
51
52pub fn get_device_fact<'a, T: 'a>(
53 fact: &'a TypedFact,
54 map_fact: impl Fn(&'a TypedFact) -> TractResult<T>,
55) -> TractResult<T> {
56 if fact.datum_type == DatumType::Opaque {
57 (map_fact)(fact.to_device_fact()?)
58 } else {
59 (map_fact)(fact)
60 }
61}
62
63pub fn as_quant_fact<'a>(
64 fact: &'a TypedFact,
65 format: &dyn BlockQuant,
66) -> Option<&'a BlockQuantFact> {
67 fact.opaque_fact
68 .as_ref()
69 .and_then(|of| of.downcast_ref::<BlockQuantFact>())
70 .and_then(|bqf| if bqf.format.same_as(format) { Some(bqf) } else { None })
71 .or_else(|| {
72 fact.konst
73 .as_ref()
74 .and_then(|k| k.try_as_dense().ok()?.to_scalar::<Opaque>().ok())
75 .and_then(|o| o.downcast_ref::<BlobWithFact>())
76 .and_then(|bwf| bwf.fact.downcast_ref::<BlockQuantFact>())
77 .and_then(|bqf| if bqf.format.same_as(format) { Some(bqf) } else { None })
78 })
79}
80
81pub fn as_q40_tensor(a: &Tensor) -> Option<&BlobWithFact> {
82 let od = a.try_as_dense().ok()?.to_scalar::<Opaque>().ok()?;
83 od.downcast_ref::<BlobWithFact>().and_then(|bwf| {
84 if bwf.fact.downcast_ref::<BlockQuantFact>().is_some_and(|bqf| bqf.format.same_as(&Q4_0)) {
85 Some(bwf)
86 } else {
87 None
88 }
89 })
90}
91
92pub fn get_quant_fact(t: &DeviceTensor, format: &dyn BlockQuant) -> Option<BlockQuantFact> {
93 if let DeviceTensor::Owned(t) = t {
94 t.opaque_fact()
95 .and_then(|of| of.downcast_ref::<BlockQuantFact>())
96 .cloned()
97 .filter(|bqf| bqf.format.same_as(format))
98 } else {
99 None
100 }
101}
102
103pub fn check_strides_validity(shape: TVec<usize>, strides: TVec<isize>) -> TractResult<()> {
104 let mut zipped_shape_strides: Vec<_> = shape.into_iter().zip(strides).collect();
105 zipped_shape_strides.sort_by_key(|&(_, stride)| stride);
106
107 let mut prev_stride = 1;
108 for (dim, stride) in zipped_shape_strides {
109 ensure!((stride == prev_stride) || (dim == 1), "Invalid strides");
110 prev_stride *= dim as isize;
111 }
112 Ok(())
113}