Skip to main content

tract_gpu/
utils.rs

1use tract_core::internal::*;
2use tract_core::tract_linalg::block_quant::BlockQuant;
3use tract_linalg::block_quant::{BlockQuantFact, Q4_0};
4
5use crate::fact::*;
6use crate::tensor::DeviceTensor;
7
8pub fn facts_to_device_facts(
9    facts: &[&TypedFact],
10    resolve_facts: impl Fn(&[&TypedFact]) -> TractResult<TVec<TypedFact>>,
11) -> TractResult<TVec<TypedFact>> {
12    if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
13        let device_facts = facts
14            .iter()
15            .map(|it| it.to_device_fact().map(|it| it.as_ref()))
16            .collect::<TractResult<TVec<_>>>()?;
17        let output_facts = (resolve_facts)(device_facts.as_slice())?;
18        Ok(output_facts
19            .into_iter()
20            .map(|it| Ok(DeviceFact::new(DeviceTensorOrigin::FromDevice, it)?.into_opaque_fact()))
21            .collect::<TractResult<_>>()?)
22    } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
23        (resolve_facts)(facts)
24    } else {
25        bail!(
26            "Inconsistent facts datum type: {:?}",
27            facts.iter().map(|it| it.datum_type).collect::<TVec<_>>()
28        );
29    }
30}
31
32pub fn get_device_facts<'a, 'b: 'a, T>(
33    facts: &'a [&'b TypedFact],
34    map_facts: impl Fn(&[&'b TypedFact]) -> TractResult<T>,
35) -> TractResult<T> {
36    if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
37        let device_facts = facts
38            .iter()
39            .map(|it| it.to_device_fact().map(|it| it.as_ref()))
40            .collect::<TractResult<TVec<_>>>()?;
41        (map_facts)(device_facts.as_slice())
42    } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
43        (map_facts)(facts)
44    } else {
45        bail!(
46            "Inconsistent facts datum type: {:?}",
47            facts.iter().map(|it| it.datum_type).collect::<Vec<_>>()
48        );
49    }
50}
51
52pub fn get_device_fact<'a, T: 'a>(
53    fact: &'a TypedFact,
54    map_fact: impl Fn(&'a TypedFact) -> TractResult<T>,
55) -> TractResult<T> {
56    if fact.datum_type == DatumType::Opaque {
57        (map_fact)(fact.to_device_fact()?)
58    } else {
59        (map_fact)(fact)
60    }
61}
62
63pub fn as_quant_fact<'a>(
64    fact: &'a TypedFact,
65    format: &dyn BlockQuant,
66) -> Option<&'a BlockQuantFact> {
67    fact.opaque_fact
68        .as_ref()
69        .and_then(|of| of.downcast_ref::<BlockQuantFact>())
70        .and_then(|bqf| if bqf.format.same_as(format) { Some(bqf) } else { None })
71        .or_else(|| {
72            fact.konst
73                .as_ref()
74                .and_then(|k| k.to_scalar::<Opaque>().ok())
75                .and_then(|o| o.downcast_ref::<BlobWithFact>())
76                .and_then(|bwf| bwf.fact.downcast_ref::<BlockQuantFact>())
77                .and_then(|bqf| if bqf.format.same_as(format) { Some(bqf) } else { None })
78        })
79}
80
81pub fn as_q40_tensor(a: &Tensor) -> Option<&BlobWithFact> {
82    a.to_scalar::<Opaque>().ok().and_then(|od| {
83        od.downcast_ref::<BlobWithFact>().and_then(|bwf| {
84            if bwf
85                .fact
86                .downcast_ref::<BlockQuantFact>()
87                .is_some_and(|bqf| bqf.format.same_as(&Q4_0))
88            {
89                Some(bwf)
90            } else {
91                None
92            }
93        })
94    })
95}
96
97pub fn get_quant_fact(t: &DeviceTensor, format: &dyn BlockQuant) -> Option<BlockQuantFact> {
98    if let DeviceTensor::Owned(t) = t {
99        t.opaque_fact()
100            .and_then(|of| of.downcast_ref::<BlockQuantFact>())
101            .cloned()
102            .filter(|bqf| bqf.format.same_as(format))
103    } else {
104        None
105    }
106}
107
108pub fn check_strides_validity(shape: TVec<usize>, strides: TVec<isize>) -> TractResult<()> {
109    let mut zipped_shape_strides: Vec<_> = shape.into_iter().zip(strides).collect();
110    zipped_shape_strides.sort_by_key(|&(_, stride)| stride);
111
112    let mut prev_stride = 1;
113    for (dim, stride) in zipped_shape_strides {
114        ensure!((stride == prev_stride) || (dim == 1), "Invalid strides");
115        prev_stride *= dim as isize;
116    }
117    Ok(())
118}