Skip to main content

tract_gpu/
utils.rs

1use tract_core::internal::*;
2use tract_linalg::block_quant::{BlockQuantFact, BlockQuantValue, Q4_0};
3
4use crate::fact::*;
5
6pub fn facts_to_device_facts(
7    facts: &[&TypedFact],
8    resolve_facts: impl Fn(&[&TypedFact]) -> TractResult<TVec<TypedFact>>,
9) -> TractResult<TVec<TypedFact>> {
10    if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
11        let device_facts = facts
12            .iter()
13            .map(|it| it.to_device_fact().map(|it| it.as_ref()))
14            .collect::<TractResult<TVec<_>>>()?;
15        let output_facts = (resolve_facts)(device_facts.as_slice())?;
16        Ok(output_facts
17            .into_iter()
18            .map(|it| Ok(DeviceFact::new(DeviceTensorOrigin::FromDevice, it)?.into_opaque_fact()))
19            .collect::<TractResult<_>>()?)
20    } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
21        (resolve_facts)(facts)
22    } else {
23        bail!(
24            "Inconsistent facts datum type: {:?}",
25            facts.iter().map(|it| it.datum_type).collect::<TVec<_>>()
26        );
27    }
28}
29
30pub fn get_device_facts<'a, 'b: 'a, T>(
31    facts: &'a [&'b TypedFact],
32    map_facts: impl Fn(&[&'b TypedFact]) -> TractResult<T>,
33) -> TractResult<T> {
34    if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
35        let device_facts = facts
36            .iter()
37            .map(|it| it.to_device_fact().map(|it| it.as_ref()))
38            .collect::<TractResult<TVec<_>>>()?;
39        (map_facts)(device_facts.as_slice())
40    } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
41        (map_facts)(facts)
42    } else {
43        bail!(
44            "Inconsistent facts datum type: {:?}",
45            facts.iter().map(|it| it.datum_type).collect::<Vec<_>>()
46        );
47    }
48}
49
50pub fn get_device_fact<'a, T: 'a>(
51    fact: &'a TypedFact,
52    map_fact: impl Fn(&'a TypedFact) -> TractResult<T>,
53) -> TractResult<T> {
54    if fact.datum_type == DatumType::Opaque {
55        (map_fact)(fact.to_device_fact()?)
56    } else {
57        (map_fact)(fact)
58    }
59}
60
61pub fn as_q40_fact(fact: &TypedFact) -> Option<&BlockQuantFact> {
62    fact.opaque_fact
63        .as_ref()
64        .and_then(|of| of.downcast_ref::<BlockQuantFact>())
65        .and_then(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
66        .or_else(|| {
67            fact.konst
68                .as_ref()
69                .and_then(|k| k.to_scalar::<Opaque>().ok())
70                .and_then(|o| o.downcast_ref::<BlockQuantValue>())
71                .map(|v| &v.fact)
72                .and_then(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
73        })
74}
75
76pub fn as_q40_tensor(a: &Tensor) -> Option<&BlockQuantValue> {
77    a.to_scalar::<Opaque>().ok().and_then(|od| {
78        od.downcast_ref::<BlockQuantValue>()
79            .and_then(|bqv| if bqv.fact.format.same_as(&Q4_0) { Some(bqv) } else { None })
80    })
81}
82
83pub fn check_strides_validity(shape: TVec<usize>, strides: TVec<isize>) -> TractResult<()> {
84    let mut zipped_shape_strides: Vec<_> = shape.into_iter().zip(strides).collect();
85    zipped_shape_strides.sort_by_key(|&(_, stride)| stride);
86
87    let mut prev_stride = 1;
88    for (dim, stride) in zipped_shape_strides {
89        ensure!((stride == prev_stride) || (dim == 1), "Invalid strides");
90        prev_stride *= dim as isize;
91    }
92    Ok(())
93}