Skip to main content

gstreamer_analytics/
group.rs

1// Take a look at the license at the top of the repository in the LICENSE file.
2
3use glib::translate::*;
4use std::marker::PhantomData;
5
6use crate::{AnalyticsKeypointDimensions, AnalyticsKeypointPosition, ffi, relation_meta::*};
7
8#[derive(Debug)]
9pub enum AnalyticsGroupMtd {}
10
11mod sealed {
12    pub trait Sealed {}
13    impl<T: super::AnalyticsRelationMetaGroupExt> Sealed for T {}
14}
15
16pub trait AnalyticsRelationMetaGroupExt: sealed::Sealed {
17    fn add_group_mtd(
18        &mut self,
19        pre_alloc_size: usize,
20    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError>;
21
22    fn add_group_mtd_with_size(
23        &mut self,
24        group_size: usize,
25    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError>;
26
27    fn add_keypoints_group(
28        &mut self,
29        semantic_tag: &str,
30        dimension: AnalyticsKeypointDimensions,
31        positions: &[i32],
32        confidences: Option<&[f32]>,
33        visibilities: Option<&[u8]>,
34        skeleton_pairs: &[i32],
35    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError>;
36
37    fn add_keypoints_group_from_positions(
38        &mut self,
39        semantic_tag: &str,
40        positions: &[AnalyticsKeypointPosition],
41        confidences: Option<&[f32]>,
42        visibilities: Option<&[u8]>,
43        skeleton_pairs: &[i32],
44    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError> {
45        if positions.is_empty() {
46            return Err(glib::bool_error!("No keypoint positions provided"));
47        }
48
49        let dimension = positions[0].dimension;
50
51        if positions
52            .iter()
53            .any(|position| position.dimension != dimension)
54        {
55            return Err(glib::bool_error!(
56                "All keypoint positions must use the same dimension"
57            ));
58        }
59
60        let coords_per_keypoint = match dimension {
61            AnalyticsKeypointDimensions::_2d => 2,
62            AnalyticsKeypointDimensions::_3d => 3,
63            _ => {
64                return Err(glib::bool_error!(
65                    "Unsupported keypoint dimension for positions"
66                ));
67            }
68        };
69
70        let mut flattened_positions = Vec::with_capacity(positions.len() * coords_per_keypoint);
71        for position in positions {
72            flattened_positions.push(position.x);
73            flattened_positions.push(position.y);
74            if coords_per_keypoint == 3 {
75                flattened_positions.push(position.z);
76            }
77        }
78
79        self.add_keypoints_group(
80            semantic_tag,
81            dimension,
82            &flattened_positions,
83            confidences,
84            visibilities,
85            skeleton_pairs,
86        )
87    }
88}
89
90impl AnalyticsRelationMetaGroupExt
91    for gst::MetaRefMut<'_, AnalyticsRelationMeta, gst::meta::Standalone>
92{
93    #[doc(alias = "gst_analytics_relation_meta_add_group_mtd")]
94    fn add_group_mtd(
95        &mut self,
96        pre_alloc_size: usize,
97    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError> {
98        unsafe {
99            let mut mtd = std::mem::MaybeUninit::uninit();
100            let ret = from_glib(ffi::gst_analytics_relation_meta_add_group_mtd(
101                self.as_mut_ptr(),
102                pre_alloc_size,
103                mtd.as_mut_ptr(),
104            ));
105            let id = mtd.assume_init().id;
106
107            if ret {
108                Ok(AnalyticsMtdRef::from_meta(self.as_ref(), id))
109            } else {
110                Err(glib::bool_error!("Couldn't add group metadata"))
111            }
112        }
113    }
114
115    #[doc(alias = "gst_analytics_relation_meta_add_group_mtd_with_size")]
116    fn add_group_mtd_with_size(
117        &mut self,
118        group_size: usize,
119    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError> {
120        unsafe {
121            let mut mtd = std::mem::MaybeUninit::uninit();
122            let ret = from_glib(ffi::gst_analytics_relation_meta_add_group_mtd_with_size(
123                self.as_mut_ptr(),
124                group_size,
125                mtd.as_mut_ptr(),
126            ));
127            let id = mtd.assume_init().id;
128
129            if ret {
130                Ok(AnalyticsMtdRef::from_meta(self.as_ref(), id))
131            } else {
132                Err(glib::bool_error!("Couldn't add group metadata"))
133            }
134        }
135    }
136
137    #[doc(alias = "gst_analytics_relation_meta_add_keypoints_group")]
138    fn add_keypoints_group(
139        &mut self,
140        semantic_tag: &str,
141        dimension: AnalyticsKeypointDimensions,
142        positions: &[i32],
143        confidences: Option<&[f32]>,
144        visibilities: Option<&[u8]>,
145        skeleton_pairs: &[i32],
146    ) -> Result<AnalyticsMtdRef<'_, AnalyticsGroupMtd>, glib::BoolError> {
147        let coords_per_keypoint = match dimension {
148            AnalyticsKeypointDimensions::_2d => 2,
149            AnalyticsKeypointDimensions::_3d => 3,
150            _ => {
151                return Err(glib::bool_error!(
152                    "Unsupported keypoint dimension for positions"
153                ));
154            }
155        };
156
157        if positions.is_empty() {
158            return Err(glib::bool_error!("No keypoint positions provided"));
159        }
160
161        if !positions.len().is_multiple_of(coords_per_keypoint) {
162            return Err(glib::bool_error!(
163                "Positions length must match the keypoint dimension"
164            ));
165        }
166
167        let keypoint_count = positions.len() / coords_per_keypoint;
168
169        if let Some(confidences) = confidences
170            && confidences.len() != keypoint_count
171        {
172            return Err(glib::bool_error!(
173                "Confidences length must match keypoint count"
174            ));
175        }
176
177        if let Some(visibilities) = visibilities
178            && visibilities.len() != keypoint_count
179        {
180            return Err(glib::bool_error!(
181                "Visibilities length must match keypoint count"
182            ));
183        }
184
185        unsafe {
186            let mut mtd = std::mem::MaybeUninit::uninit();
187            let ret = from_glib(ffi::gst_analytics_relation_meta_add_keypoints_group(
188                self.as_mut_ptr(),
189                semantic_tag.to_glib_none().0,
190                dimension.into_glib(),
191                positions.len(),
192                positions.as_ptr(),
193                keypoint_count,
194                confidences.map_or(std::ptr::null(), |confidences| confidences.as_ptr()),
195                visibilities.map_or(std::ptr::null(), |visibilities| visibilities.as_ptr()),
196                skeleton_pairs.len(),
197                skeleton_pairs.as_ptr(),
198                mtd.as_mut_ptr(),
199            ));
200            let id = mtd.assume_init().id;
201
202            if ret {
203                Ok(AnalyticsMtdRef::from_meta(self.as_ref(), id))
204            } else {
205                Err(glib::bool_error!("Couldn't add keypoints group metadata"))
206            }
207        }
208    }
209}
210
211impl AnalyticsMtdRef<'_, AnalyticsGroupMtd> {
212    #[doc(alias = "gst_analytics_group_mtd_get_member_count")]
213    pub fn member_count(&self) -> usize {
214        unsafe {
215            let mtd = ffi::GstAnalyticsMtd::unsafe_from(self);
216            ffi::gst_analytics_group_mtd_get_member_count(
217                &mtd as *const _ as *const ffi::GstAnalyticsGroupMtd,
218            ) as usize
219        }
220    }
221
222    #[doc(alias = "gst_analytics_group_mtd_get_member")]
223    pub fn member(&self, index: usize) -> Option<AnalyticsMtdRef<'_, AnalyticsAnyMtd>> {
224        if index >= self.member_count() {
225            return None;
226        }
227
228        unsafe {
229            let mtd = ffi::GstAnalyticsMtd::unsafe_from(self);
230            let mut member = std::mem::MaybeUninit::uninit();
231            let ret = from_glib(ffi::gst_analytics_group_mtd_get_member(
232                &mtd as *const _ as *const ffi::GstAnalyticsGroupMtd,
233                index,
234                member.as_mut_ptr(),
235            ));
236
237            if ret {
238                let member = member.assume_init();
239                let id = ffi::gst_analytics_mtd_get_id(&member);
240                Some(AnalyticsMtdRef::from_meta(self.meta_ref(), id))
241            } else {
242                None
243            }
244        }
245    }
246
247    pub fn member_typed<T: AnalyticsMtd>(&self, index: usize) -> Option<AnalyticsMtdRef<'_, T>> {
248        self.member(index)
249            .and_then(|member| member.downcast::<T>().ok())
250    }
251
252    #[doc(alias = "gst_analytics_group_mtd_iterate")]
253    pub fn iter<T: AnalyticsMtd>(&self) -> AnalyticsGroupMtdIter<'_, T> {
254        AnalyticsGroupMtdIter::new(self)
255    }
256}
257
258#[must_use = "iterators are lazy and do nothing unless consumed"]
259pub struct AnalyticsGroupMtdIter<'a, T: AnalyticsMtd> {
260    group: &'a AnalyticsMtdRef<'a, AnalyticsGroupMtd>,
261    state: glib::ffi::gpointer,
262    phantom: PhantomData<T>,
263}
264
265impl<'a, T: AnalyticsMtd> AnalyticsGroupMtdIter<'a, T> {
266    fn new(group: &'a AnalyticsMtdRef<'a, AnalyticsGroupMtd>) -> Self {
267        skip_assert_initialized!();
268        AnalyticsGroupMtdIter {
269            group,
270            state: std::ptr::null_mut(),
271            phantom: PhantomData,
272        }
273    }
274}
275
276impl<'a, T: AnalyticsMtd + 'a> Iterator for AnalyticsGroupMtdIter<'a, T> {
277    type Item = AnalyticsMtdRef<'a, T>;
278
279    fn next(&mut self) -> Option<Self::Item> {
280        unsafe {
281            let mtd = ffi::GstAnalyticsMtd::unsafe_from(self.group);
282            let mut member = std::mem::MaybeUninit::uninit();
283            let ret = from_glib(ffi::gst_analytics_group_mtd_iterate(
284                &mtd as *const _ as *const ffi::GstAnalyticsGroupMtd,
285                &mut self.state,
286                T::mtd_type(),
287                member.as_mut_ptr(),
288            ));
289
290            if ret {
291                let member = member.assume_init();
292                let id = ffi::gst_analytics_mtd_get_id(&member);
293                Some(AnalyticsMtdRef::from_meta(self.group.meta_ref(), id))
294            } else {
295                None
296            }
297        }
298    }
299}
300
301unsafe impl AnalyticsMtd for AnalyticsGroupMtd {
302    #[doc(alias = "gst_analytics_group_mtd_get_mtd_type")]
303    fn mtd_type() -> ffi::GstAnalyticsMtdType {
304        unsafe { ffi::gst_analytics_group_mtd_get_mtd_type() }
305    }
306}
307
308impl AnalyticsMtdRefMut<'_, AnalyticsGroupMtd> {
309    #[doc(alias = "gst_analytics_group_mtd_add_member")]
310    pub fn add_member(&mut self, an_meta_id: u32) -> Result<(), glib::BoolError> {
311        let ret = unsafe {
312            let mut mtd = ffi::GstAnalyticsMtd::unsafe_from(self);
313            from_glib(ffi::gst_analytics_group_mtd_add_member(
314                &mut mtd as *mut _ as *mut ffi::GstAnalyticsGroupMtd,
315                an_meta_id,
316            ))
317        };
318
319        if ret {
320            Ok(())
321        } else {
322            Err(glib::bool_error!("Couldn't add group member"))
323        }
324    }
325}
326
#[cfg(test)]
mod tests {
    use crate::*;

    /// Two 2D keypoint positions used by several tests below.
    fn two_2d_positions() -> [AnalyticsKeypointPosition; 2] {
        [
            AnalyticsKeypointPosition {
                x: 10,
                y: 20,
                z: 0,
                dimension: AnalyticsKeypointDimensions::_2d,
            },
            AnalyticsKeypointPosition {
                x: 30,
                y: 40,
                z: 0,
                dimension: AnalyticsKeypointDimensions::_2d,
            },
        ]
    }

    #[test]
    fn group_members() {
        gst::init().unwrap();

        let type_name = AnalyticsGroupMtd::type_name();
        assert_eq!(type_name, "grouping-mtd");

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        let keypoint_id = {
            let keypoint = meta
                .add_keypoint_mtd_from_position(
                    AnalyticsKeypointPosition {
                        x: 1,
                        y: 2,
                        z: 0,
                        dimension: AnalyticsKeypointDimensions::_2d,
                    },
                    AnalyticsKeypointVisibility::VISIBLE,
                    0.5,
                )
                .unwrap();
            keypoint.id()
        };

        let group = meta.add_group_mtd_with_size(1).unwrap();
        let group_id = group.id();

        let mut group_mut = meta.mtd_mut::<AnalyticsGroupMtd>(group_id).unwrap();
        group_mut.set_semantic_tag("pose").unwrap();
        group_mut.add_member(keypoint_id).unwrap();

        let group = AnalyticsMtdRef::from(group_mut);
        assert!(group.has_semantic_tag("pose"));
        assert!(group.semantic_tag_has_prefix("po"));
        assert_eq!(group.member_count(), 1);

        let member = group.member_typed::<AnalyticsKeypointMtd>(0).unwrap();
        let position = member.position().unwrap();
        assert_eq!(position.x, 1);
        assert_eq!(position.y, 2);

        // Indices at or past `member_count()` are rejected before reaching C.
        assert!(group.member(1).is_none());
        assert!(group.member_typed::<AnalyticsKeypointMtd>(1).is_none());
    }

    #[test]
    fn keypoints_group() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        let positions = two_2d_positions();
        let confidences = [0.9, 0.8];
        let visibilities = [1, 0];

        let group = meta
            .add_keypoints_group_from_positions(
                "pose",
                &positions,
                Some(&confidences),
                Some(&visibilities),
                &[],
            )
            .unwrap();

        assert!(group.has_semantic_tag("pose"));
        assert!(group.semantic_tag_has_prefix("po"));
        assert_eq!(group.member_count(), 2);
        assert!(group.member(2).is_none());
    }

    #[test]
    fn keypoints_group_rejects_mismatched_confidences() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        let positions = two_2d_positions();
        let confidences = [0.9];

        let result = meta.add_keypoints_group_from_positions(
            "pose",
            &positions,
            Some(&confidences),
            None,
            &[],
        );

        assert!(result.is_err());
    }

    #[test]
    fn keypoints_group_rejects_mismatched_visibilities() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        let positions = two_2d_positions();
        let visibilities = [1];

        let result = meta.add_keypoints_group_from_positions(
            "pose",
            &positions,
            None,
            Some(&visibilities),
            &[],
        );

        assert!(result.is_err());
    }

    #[test]
    fn keypoints_group_rejects_empty_positions() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        assert!(
            meta.add_keypoints_group_from_positions("pose", &[], None, None, &[])
                .is_err()
        );
        assert!(
            meta.add_keypoints_group(
                "pose",
                AnalyticsKeypointDimensions::_2d,
                &[],
                None,
                None,
                &[],
            )
            .is_err()
        );
    }

    #[test]
    fn keypoints_group_rejects_mixed_dimensions() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        let positions = [
            AnalyticsKeypointPosition {
                x: 1,
                y: 2,
                z: 0,
                dimension: AnalyticsKeypointDimensions::_2d,
            },
            AnalyticsKeypointPosition {
                x: 3,
                y: 4,
                z: 5,
                dimension: AnalyticsKeypointDimensions::_3d,
            },
        ];

        assert!(
            meta.add_keypoints_group_from_positions("pose", &positions, None, None, &[])
                .is_err()
        );
    }

    #[test]
    fn keypoints_group_rejects_misaligned_flat_positions() {
        gst::init().unwrap();

        let mut buf = gst::Buffer::new();
        let mut meta = AnalyticsRelationMeta::add(buf.make_mut());

        // Three coordinates cannot describe a whole number of 2D keypoints.
        assert!(
            meta.add_keypoints_group(
                "pose",
                AnalyticsKeypointDimensions::_2d,
                &[1, 2, 3],
                None,
                None,
                &[],
            )
            .is_err()
        );
    }
}