gstreamer_analytics/
tensor.rs

1// Take a look at the license at the top of the repository in the LICENSE file.
2
3use crate::ffi;
4use crate::*;
5use glib::translate::*;
6
7glib::wrapper! {
8    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
9    #[doc(alias = "GstTensor")]
10    pub struct Tensor(Boxed<ffi::GstTensor>);
11
12    match fn {
13        copy => |ptr| ffi::gst_tensor_copy(ptr),
14        free => |ptr| ffi::gst_tensor_free(ptr),
15        type_ => || ffi::gst_tensor_get_type(),
16    }
17}
18
19unsafe impl Send for Tensor {}
20unsafe impl Sync for Tensor {}
21
22impl Tensor {
23    #[doc(alias = "gst_tensor_new_simple")]
24    pub fn new_simple(
25        id: glib::Quark,
26        data_type: TensorDataType,
27        data: gst::Buffer,
28        dims_order: TensorDimOrder,
29        dims: &[usize],
30    ) -> Tensor {
31        skip_assert_initialized!();
32        unsafe {
33            from_glib_full(ffi::gst_tensor_new_simple(
34                id.into_glib(),
35                data_type.into_glib(),
36                data.into_glib_ptr(),
37                dims_order.into_glib(),
38                dims.len(),
39                dims.as_ptr() as *mut _,
40            ))
41        }
42    }
43
44    #[doc(alias = "gst_tensor_get_dims")]
45    #[doc(alias = "get_dims")]
46    pub fn dims(&self) -> &[usize] {
47        let mut num_dims: usize = 0;
48        unsafe {
49            let dims = ffi::gst_tensor_get_dims(self.as_ptr(), &mut num_dims);
50            std::slice::from_raw_parts(dims as *const _, num_dims)
51        }
52    }
53
54    #[inline]
55    pub fn id(&self) -> glib::Quark {
56        unsafe { from_glib(self.inner.id) }
57    }
58
59    #[inline]
60    pub fn data_type(&self) -> TensorDataType {
61        unsafe { from_glib(self.inner.data_type) }
62    }
63
64    #[inline]
65    pub fn data(&self) -> &gst::BufferRef {
66        unsafe { gst::BufferRef::from_ptr(self.inner.data) }
67    }
68
69    #[inline]
70    pub fn data_mut(&mut self) -> &mut gst::BufferRef {
71        unsafe {
72            self.inner.data = gst::ffi::gst_mini_object_make_writable(self.inner.data as _) as _;
73            gst::BufferRef::from_mut_ptr(self.inner.data)
74        }
75    }
76
77    #[inline]
78    pub fn dims_order(&self) -> TensorDimOrder {
79        unsafe { from_glib(self.inner.dims_order) }
80    }
81
82    #[cfg(feature = "v1_28")]
83    #[cfg_attr(docsrs, doc(cfg(feature = "v1_28")))]
84    #[doc(alias = "gst_tensor_check_type")]
85    pub fn check_type(
86        &self,
87        data_type: crate::TensorDataType,
88        order: crate::TensorDimOrder,
89        dims: &[usize],
90    ) -> bool {
91        unsafe {
92            from_glib(ffi::gst_tensor_check_type(
93                self.to_glib_none().0,
94                data_type.into_glib(),
95                order.into_glib(),
96                dims.len(),
97                dims.as_ptr(),
98            ))
99        }
100    }
101}
102
#[cfg(test)]
mod tests {
    use crate::*;

    /// Builds a small Int16 tensor and checks that every accessor reports
    /// the values it was constructed with.
    #[test]
    fn create_tensor() {
        // Dimensions of the tensor under test.
        const DIMS: [usize; 3] = [3, 4, 5];

        gst::init().unwrap();

        // 3 * 4 * 5 elements of two bytes each (Int16).
        let byte_len = DIMS.iter().product::<usize>() * std::mem::size_of::<i16>();
        let buf = gst::Buffer::with_size(byte_len).unwrap();
        assert_eq!(buf.size(), byte_len);

        let quark = glib::Quark::from_str("me");
        let mut tensor = Tensor::new_simple(
            quark,
            TensorDataType::Int16,
            buf,
            TensorDimOrder::RowMajor,
            &DIMS,
        );

        assert_eq!(tensor.id(), quark);
        assert_eq!(tensor.data_type(), TensorDataType::Int16);
        assert_eq!(tensor.dims_order(), TensorDimOrder::RowMajor);
        // Compare the whole slice so the number of dimensions is checked too.
        assert_eq!(tensor.dims(), &DIMS[..]);

        // Both the shared and the mutable accessor must expose the buffer.
        assert_eq!(tensor.data().size(), byte_len);
        assert_eq!(tensor.data_mut().size(), byte_len);
    }
}