tract_tensorflow/ops/array/
squeeze.rs

1use crate::model::ParsingContext;
2use crate::tfpb::tensorflow::NodeDef;
3use tract_hir::internal::*;
4use tract_hir::ops::array::Squeeze;
5
6pub fn squeeze(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
7    let squeeze_dims = pb.get_attr_opt_list_int("squeeze_dims")?;
8    if let Some(mut squeeze_dims) = squeeze_dims {
9        if squeeze_dims.len() > 0 {
10            squeeze_dims.sort();
11            return Ok(expand(Squeeze::new(Some(squeeze_dims))));
12        }
13    }
14    Ok(expand(Squeeze::default()))
15}
16
#[cfg(test)]
mod tests {
    use super::*;
    use tract_ndarray::Array;

    /// Expands `op`, evaluates it on the single tensor `input`, and returns the
    /// (single) output tensor. Panics with a descriptive message if evaluation fails
    /// or yields no output.
    fn run<I>(op: Squeeze, input: I) -> Tensor
    where
        I: Into<Tensor>,
    {
        expand(op)
            .eval(tvec![input.into().into()])
            .expect("evaluating the expanded Squeeze op failed")
            .pop()
            .expect("Squeeze op produced no output")
            .into_tensor()
    }

    /// With no axes given, every size-1 axis is removed.
    #[test]
    fn squeeze_1() {
        assert_eq!(
            run(Squeeze::new(None), Array::from_elem([1, 2, 1, 3, 1, 1], 0)).shape(),
            &[2, 3]
        );
    }

    /// With explicit axes, only those (size-1) axes are removed; other unit axes stay.
    #[test]
    fn squeeze_2() {
        assert_eq!(
            run(Squeeze::new(Some(vec![2, 4])), Array::from_elem([1, 2, 1, 3, 1, 1], 0)).shape(),
            &[1, 2, 3, 1]
        );
    }
}