1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use crate::model::ParsingContext;
use crate::tfpb::tensorflow::NodeDef;
use tract_hir::internal::*;
use tract_hir::ops::array::Squeeze;

/// Builds a [`Squeeze`] inference op from a TensorFlow `Squeeze` node.
///
/// Reads the optional `squeeze_dims` list attribute of `pb`:
/// - if it is present and non-empty, the axes are sorted in ascending order and
///   handed to `Squeeze::new(Some(axes))`;
/// - otherwise (attribute absent or empty list) `Squeeze::default()` is returned.
///
/// `_ctx` is not used by this builder.
///
/// # Errors
///
/// Propagates any error returned by `NodeDef::get_attr_opt_list_int` while
/// reading the `squeeze_dims` attribute.
pub fn squeeze(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
    match pb.get_attr_opt_list_int("squeeze_dims")? {
        Some(mut squeeze_dims) if !squeeze_dims.is_empty() => {
            // Normalise the attribute order so the op always sees ascending axes.
            // Axes are plain integers, so an unstable sort yields the same result.
            // NOTE(review): TF permits negative squeeze_dims; they are sorted here
            // before any rank-based normalisation — confirm Squeeze handles that.
            squeeze_dims.sort_unstable();
            Ok(Box::new(Squeeze::new(Some(squeeze_dims))))
        }
        // An absent or empty `squeeze_dims` falls back to the op's default.
        _ => Ok(Box::new(Squeeze::default())),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tract_ndarray::Array;

    /// Evaluates `op` on a single input and returns its single (last) output tensor.
    ///
    /// Unwraps both the evaluation result and the output, so a failing or empty
    /// evaluation panics and fails the calling test.
    fn run<I>(op: Squeeze, input: I) -> Tensor
    where
        I: Into<Tensor>,
    {
        op.eval(tvec![input.into().into()]).unwrap().pop().unwrap().into_tensor()
    }

    /// With no explicit axes, every size-1 dimension is removed.
    #[test]
    fn squeeze_1() {
        assert_eq!(
            run(Squeeze::new(None), Array::from_elem([1, 2, 1, 3, 1, 1], 0)).shape(),
            &[2, 3]
        );
    }

    /// With explicit axes, only the listed size-1 dimensions (2 and 4) are removed.
    #[test]
    fn squeeze_2() {
        assert_eq!(
            run(Squeeze::new(Some(vec![2, 4])), Array::from_elem([1, 2, 1, 3, 1, 1], 0)).shape(),
            &[1, 2, 3, 1]
        );
    }

    /// Shape inference drops the requested axis while keeping a symbolic
    /// (stream-dependent) dimension and the datum type intact.
    #[test]
    fn squeeze_inference_1() {
        let input = InferenceFact::default()
            .with_datum_type(DatumType::TDim)
            .with_shape(shapefactoid![1, 1, (TDim::stream() - 2), 16]);
        let any = InferenceFact::default();

        let mut op = Squeeze::new(Some(vec![1]));
        let inferred = op.infer_facts(tvec!(&input), tvec!(&any), tvec!()).unwrap();

        let expect: TVec<_> = tvec!(InferenceFact::default()
            .with_datum_type(DatumType::TDim)
            .with_shape(shapefactoid![1, (TDim::stream() - 2), 16]));

        assert_eq!(inferred.1, expect);
    }
}