edgefirst_client/
dataset.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use std::{collections::HashMap, fmt::Display};
5
6use crate::{
7    Client, Error,
8    api::{AnnotationSetID, DatasetID, ProjectID, SampleID},
9};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12
13#[cfg(feature = "polars")]
14use polars::prelude::*;
15
/// File types supported in EdgeFirst Studio datasets.
///
/// Represents the different types of sensor data files that can be stored
/// and processed in a dataset. EdgeFirst Studio supports various modalities
/// including visual images and different forms of LiDAR and radar data.
///
/// Every variant has a string identifier (for example `"image"` or
/// `"lidar.pcd"`). The `Display` implementation writes that identifier and
/// `From<&str>` parses it back; `From<&str>` panics on any other string.
///
/// # Examples
///
/// ```rust
/// use edgefirst_client::FileType;
///
/// // Create file types from strings
/// let image_type = FileType::from("image");
/// let lidar_type = FileType::from("lidar.pcd");
///
/// // Display file types
/// println!("Processing {} files", image_type); // "Processing image files"
///
/// // Use in dataset operations - example usage
/// let file_type = FileType::Image;
/// match file_type {
///     FileType::Image => println!("Processing image files"),
///     FileType::LidarPcd => println!("Processing LiDAR point cloud files"),
///     _ => println!("Processing other sensor data"),
/// }
/// ```
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum FileType {
    /// Standard image files (JPEG, PNG, etc.); identifier `"image"`.
    Image,
    /// LiDAR point cloud data files (.pcd format); identifier `"lidar.pcd"`.
    LidarPcd,
    /// LiDAR depth images (.png format); identifier `"lidar.png"`.
    LidarDepth,
    /// LiDAR reflectance images (.jpg format); identifier `"lidar.jpg"`.
    LidarReflect,
    /// Radar point cloud data files (.pcd format); identifier `"radar.pcd"`.
    RadarPcd,
    /// Radar cube data files (.png format); identifier `"radar.png"`.
    RadarCube,
}
57
58impl std::fmt::Display for FileType {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        let value = match self {
61            FileType::Image => "image",
62            FileType::LidarPcd => "lidar.pcd",
63            FileType::LidarDepth => "lidar.png",
64            FileType::LidarReflect => "lidar.jpg",
65            FileType::RadarPcd => "radar.pcd",
66            FileType::RadarCube => "radar.png",
67        };
68        write!(f, "{}", value)
69    }
70}
71
72impl From<&str> for FileType {
73    fn from(s: &str) -> Self {
74        match s {
75            "image" => FileType::Image,
76            "lidar.pcd" => FileType::LidarPcd,
77            "lidar.png" => FileType::LidarDepth,
78            "lidar.jpg" => FileType::LidarReflect,
79            "radar.pcd" => FileType::RadarPcd,
80            "radar.png" => FileType::RadarCube,
81            _ => panic!("Invalid file type"),
82        }
83    }
84}
85
/// Annotation types supported for labeling data in EdgeFirst Studio.
///
/// Represents the different types of annotations that can be applied to
/// sensor data for machine learning tasks. Each type corresponds to a
/// different annotation geometry and use case.
///
/// Each variant has a string identifier (`"box2d"`, `"box3d"`, `"mask"`)
/// that `Display` writes and the `From` string conversions parse; those
/// conversions panic on any other string.
///
/// # Examples
///
/// ```rust
/// use edgefirst_client::AnnotationType;
///
/// // Create annotation types from strings
/// let box_2d = AnnotationType::from("box2d");
/// let segmentation = AnnotationType::from("mask");
///
/// // Display annotation types
/// println!("Annotation type: {}", box_2d); // "Annotation type: box2d"
///
/// // Use in matching and processing
/// let annotation_type = AnnotationType::Box2d;
/// match annotation_type {
///     AnnotationType::Box2d => println!("Processing 2D bounding boxes"),
///     AnnotationType::Box3d => println!("Processing 3D bounding boxes"),
///     AnnotationType::Mask => println!("Processing segmentation masks"),
/// }
/// ```
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum AnnotationType {
    /// 2D bounding boxes for object detection in images; identifier `"box2d"`.
    Box2d,
    /// 3D bounding boxes for object detection in 3D space (LiDAR, etc.);
    /// identifier `"box3d"`.
    Box3d,
    /// Pixel-level segmentation masks for semantic/instance segmentation;
    /// identifier `"mask"`.
    Mask,
}
121
122impl From<&str> for AnnotationType {
123    fn from(s: &str) -> Self {
124        match s {
125            "box2d" => AnnotationType::Box2d,
126            "box3d" => AnnotationType::Box3d,
127            "mask" => AnnotationType::Mask,
128            _ => panic!("Invalid annotation type"),
129        }
130    }
131}
132
133impl From<String> for AnnotationType {
134    fn from(s: String) -> Self {
135        s.as_str().into()
136    }
137}
138
139impl From<&String> for AnnotationType {
140    fn from(s: &String) -> Self {
141        s.as_str().into()
142    }
143}
144
145impl std::fmt::Display for AnnotationType {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        let value = match self {
148            AnnotationType::Box2d => "box2d",
149            AnnotationType::Box3d => "box3d",
150            AnnotationType::Mask => "mask",
151        };
152        write!(f, "{}", value)
153    }
154}
155
/// A dataset in EdgeFirst Studio containing sensor data and annotations.
///
/// Datasets are collections of multi-modal sensor data (images, LiDAR, radar)
/// along with their corresponding annotations (bounding boxes, segmentation
/// masks, 3D annotations). Datasets belong to projects and can be used for
/// training and validation of machine learning models.
///
/// # Features
///
/// - **Multi-modal Data**: Support for images, LiDAR point clouds, radar data
/// - **Rich Annotations**: 2D/3D bounding boxes, segmentation masks
/// - **Metadata**: Timestamps, sensor configurations, calibration data
/// - **Version Control**: Track changes and maintain data lineage
/// - **Format Conversion**: Export to popular ML frameworks
///
/// # Examples
///
/// ```no_run
/// use edgefirst_client::{Client, Dataset, DatasetID};
/// use std::str::FromStr;
///
/// # async fn example() -> Result<(), edgefirst_client::Error> {
/// # let client = Client::new()?;
/// // Get dataset information
/// let dataset_id = DatasetID::from_str("ds-abc123")?;
/// let dataset = client.dataset(dataset_id).await?;
/// println!("Dataset: {}", dataset.name());
///
/// // Access dataset metadata
/// println!("Dataset ID: {}", dataset.id());
/// println!("Description: {}", dataset.description());
/// println!("Created: {}", dataset.created());
///
/// // Work with dataset data would require additional methods
/// // that are implemented in the full API
/// # Ok(())
/// # }
/// ```
#[derive(Deserialize, Clone, Debug)]
pub struct Dataset {
    /// Identifier of this dataset.
    id: DatasetID,
    /// Identifier of the project the dataset belongs to.
    project_id: ProjectID,
    /// Dataset name.
    name: String,
    /// Dataset description.
    description: String,
    /// Cloud key of the dataset (exposed through `cloud_key()`).
    cloud_key: String,
    /// Creation timestamp; deserialized from the `createdAt` key.
    #[serde(rename = "createdAt")]
    created: DateTime<Utc>,
}
204
205impl Display for Dataset {
206    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
207        write!(f, "{} {}", self.uid(), self.name)
208    }
209}
210
211impl Dataset {
212    pub fn id(&self) -> DatasetID {
213        self.id
214    }
215
216    pub fn uid(&self) -> String {
217        self.id.to_string()
218    }
219
220    pub fn project_id(&self) -> ProjectID {
221        self.project_id
222    }
223
224    pub fn name(&self) -> &str {
225        &self.name
226    }
227
228    pub fn description(&self) -> &str {
229        &self.description
230    }
231
232    pub fn cloud_key(&self) -> &str {
233        &self.cloud_key
234    }
235
236    pub fn created(&self) -> &DateTime<Utc> {
237        &self.created
238    }
239
240    pub async fn project(&self, client: &Client) -> Result<crate::api::Project, Error> {
241        client.project(self.project_id).await
242    }
243
244    pub async fn annotation_sets(&self, client: &Client) -> Result<Vec<AnnotationSet>, Error> {
245        client.annotation_sets(self.id).await
246    }
247
248    pub async fn labels(&self, client: &Client) -> Result<Vec<Label>, Error> {
249        client.labels(self.id).await
250    }
251
252    pub async fn add_label(&self, client: &Client, name: &str) -> Result<(), Error> {
253        client.add_label(self.id, name).await
254    }
255
256    pub async fn remove_label(&self, client: &Client, name: &str) -> Result<(), Error> {
257        let labels = self.labels(client).await?;
258        let label = labels
259            .iter()
260            .find(|l| l.name() == name)
261            .ok_or_else(|| Error::MissingLabel(name.to_string()))?;
262        client.remove_label(label.id()).await
263    }
264}
265
/// A collection of annotations in a dataset.
///
/// A dataset can have multiple annotation sets, each containing annotations
/// for different tasks or purposes.
// NOTE(review): unlike `Dataset`, this type derives only `Deserialize` (no
// `Clone`/`Debug`) — confirm whether that is intentional.
#[derive(Deserialize)]
pub struct AnnotationSet {
    /// Identifier of this annotation set.
    id: AnnotationSetID,
    /// Identifier of the dataset the set belongs to.
    dataset_id: DatasetID,
    /// Annotation set name.
    name: String,
    /// Annotation set description.
    description: String,
    /// Creation timestamp; deserialized from the `date` key.
    #[serde(rename = "date")]
    created: DateTime<Utc>,
}
278
279impl Display for AnnotationSet {
280    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
281        write!(f, "{} {}", self.uid(), self.name)
282    }
283}
284
285impl AnnotationSet {
286    pub fn id(&self) -> AnnotationSetID {
287        self.id
288    }
289
290    pub fn uid(&self) -> String {
291        self.id.to_string()
292    }
293
294    pub fn dataset_id(&self) -> DatasetID {
295        self.dataset_id
296    }
297
298    pub fn name(&self) -> &str {
299        &self.name
300    }
301
302    pub fn description(&self) -> &str {
303        &self.description
304    }
305
306    pub fn created(&self) -> DateTime<Utc> {
307        self.created
308    }
309
310    pub async fn dataset(&self, client: &Client) -> Result<Dataset, Error> {
311        client.dataset(self.dataset_id).await
312    }
313}
314
/// A sample in a dataset, typically representing a single image with metadata
/// and optional sensor data.
///
/// Each sample has a unique ID, image reference, and can include additional
/// sensor data like LiDAR, radar, or depth maps. Samples can also have
/// associated annotations.
///
/// Every scalar field is optional; fields that are `None` are omitted when
/// the sample is serialized, as are empty `files` and `annotations` lists.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Sample {
    /// Sample identifier, when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<SampleID>,
    /// Group name; `group_name` is also accepted when deserializing.
    #[serde(alias = "group_name", skip_serializing_if = "Option::is_none")]
    pub group: Option<String>,
    /// Name of the sequence this sample is part of, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_name: Option<String>,
    /// UUID of the sequence this sample is part of, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_uuid: Option<String>,
    /// Description of the sequence this sample is part of, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_description: Option<String>,
    /// Frame number of the sample, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frame_number: Option<u32>,
    /// UUID of the sample, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uuid: Option<String>,
    /// File name of the sample image.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_name: Option<String>,
    /// URL of the sample image; used by `download` for `FileType::Image`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<String>,
    /// Image width.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub width: Option<u32>,
    /// Image height.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub height: Option<u32>,
    /// Timestamp associated with the sample.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub date: Option<DateTime<Utc>>,
    /// Source of the sample, if recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    /// Camera location and pose (GPS + IMU data).
    /// Serialized as "sensors" for API compatibility with populate endpoint.
    #[serde(rename = "sensors", skip_serializing_if = "Option::is_none")]
    pub location: Option<Location>,
    /// Additional sensor files (LiDAR, radar, depth maps, etc.).
    ///
    /// When deserializing (e.g. from samples.list) either a list of
    /// `SampleFile` objects or a `file_type -> filename` map is accepted.
    /// When serializing (e.g. for samples.populate) the list is written as a
    /// `HashMap<String, String>` of `file_type -> filename`.
    #[serde(
        default,
        skip_serializing_if = "Vec::is_empty",
        serialize_with = "serialize_files",
        deserialize_with = "deserialize_files"
    )]
    pub files: Vec<SampleFile>,
    /// Annotations attached to the sample.
    ///
    /// Serialized grouped by kind (`bbox`, `box3d`, `mask`); deserialized
    /// from either a flat list or that grouped map.
    #[serde(
        default,
        skip_serializing_if = "Vec::is_empty",
        serialize_with = "serialize_annotations",
        deserialize_with = "deserialize_annotations"
    )]
    pub annotations: Vec<Annotation>,
}
372
373// Custom serializer for files field - converts Vec<SampleFile> to
374// HashMap<String, String>
375fn serialize_files<S>(files: &[SampleFile], serializer: S) -> Result<S::Ok, S::Error>
376where
377    S: serde::Serializer,
378{
379    use serde::Serialize;
380    let map: HashMap<String, String> = files
381        .iter()
382        .filter_map(|f| {
383            f.filename()
384                .map(|filename| (f.file_type().to_string(), filename.to_string()))
385        })
386        .collect();
387    map.serialize(serializer)
388}
389
390// Custom deserializer for files field - converts HashMap or Vec to
391// Vec<SampleFile>
392fn deserialize_files<'de, D>(deserializer: D) -> Result<Vec<SampleFile>, D::Error>
393where
394    D: serde::Deserializer<'de>,
395{
396    use serde::Deserialize;
397
398    #[derive(Deserialize)]
399    #[serde(untagged)]
400    enum FilesFormat {
401        Vec(Vec<SampleFile>),
402        Map(HashMap<String, String>),
403    }
404
405    let value = Option::<FilesFormat>::deserialize(deserializer)?;
406    Ok(value
407        .map(|v| match v {
408            FilesFormat::Vec(files) => files,
409            FilesFormat::Map(map) => map
410                .into_iter()
411                .map(|(file_type, filename)| SampleFile::with_filename(file_type, filename))
412                .collect(),
413        })
414        .unwrap_or_default())
415}
416
417// Custom serializer for annotations field - converts Vec<Annotation> to
418// format expected by server: {"bbox": [...], "box3d": [...], "mask": [...]}
419fn serialize_annotations<S>(annotations: &Vec<Annotation>, serializer: S) -> Result<S::Ok, S::Error>
420where
421    S: serde::Serializer,
422{
423    use serde::ser::SerializeMap;
424
425    // Group annotations by type
426    let mut bbox_annotations = Vec::new();
427    let mut box3d_annotations = Vec::new();
428    let mut mask_annotations = Vec::new();
429
430    for ann in annotations {
431        if ann.box2d().is_some() {
432            bbox_annotations.push(ann);
433        } else if ann.box3d().is_some() {
434            box3d_annotations.push(ann);
435        } else if ann.mask().is_some() {
436            mask_annotations.push(ann);
437        }
438    }
439
440    let mut map = serializer.serialize_map(Some(3))?;
441
442    if !bbox_annotations.is_empty() {
443        map.serialize_entry("bbox", &bbox_annotations)?;
444    }
445    if !box3d_annotations.is_empty() {
446        map.serialize_entry("box3d", &box3d_annotations)?;
447    }
448    if !mask_annotations.is_empty() {
449        map.serialize_entry("mask", &mask_annotations)?;
450    }
451
452    map.end()
453}
454
455// Custom deserializer for annotations field - converts server format back to
456// Vec<Annotation>
457fn deserialize_annotations<'de, D>(deserializer: D) -> Result<Vec<Annotation>, D::Error>
458where
459    D: serde::Deserializer<'de>,
460{
461    use serde::Deserialize;
462
463    #[derive(Deserialize)]
464    #[serde(untagged)]
465    enum AnnotationsFormat {
466        Vec(Vec<Annotation>),
467        Map(HashMap<String, Vec<Annotation>>),
468    }
469
470    let value = Option::<AnnotationsFormat>::deserialize(deserializer)?;
471    Ok(value
472        .map(|v| match v {
473            AnnotationsFormat::Vec(annotations) => annotations,
474            AnnotationsFormat::Map(map) => {
475                let mut all_annotations = Vec::new();
476                if let Some(bbox_anns) = map.get("bbox") {
477                    all_annotations.extend(bbox_anns.clone());
478                }
479                if let Some(box3d_anns) = map.get("box3d") {
480                    all_annotations.extend(box3d_anns.clone());
481                }
482                if let Some(mask_anns) = map.get("mask") {
483                    all_annotations.extend(mask_anns.clone());
484                }
485                all_annotations
486            }
487        })
488        .unwrap_or_default())
489}
490
491impl Display for Sample {
492    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
493        write!(
494            f,
495            "{} {}",
496            self.uid().unwrap_or_else(|| "unknown".to_string()),
497            self.image_name().unwrap_or("unknown")
498        )
499    }
500}
501
502impl Default for Sample {
503    fn default() -> Self {
504        Self::new()
505    }
506}
507
508impl Sample {
509    /// Creates a new empty sample.
510    pub fn new() -> Self {
511        Self {
512            id: None,
513            group: None,
514            sequence_name: None,
515            sequence_uuid: None,
516            sequence_description: None,
517            frame_number: None,
518            uuid: None,
519            image_name: None,
520            image_url: None,
521            width: None,
522            height: None,
523            date: None,
524            source: None,
525            location: None,
526            files: vec![],
527            annotations: vec![],
528        }
529    }
530
531    pub fn id(&self) -> Option<SampleID> {
532        self.id
533    }
534
535    pub fn uid(&self) -> Option<String> {
536        self.id.map(|id| id.to_string())
537    }
538
539    pub fn name(&self) -> Option<String> {
540        self.image_name.as_ref().map(|image_name| {
541            let name = image_name
542                .rsplit_once('.')
543                .map_or_else(|| image_name.clone(), |(name, _)| name.to_string());
544            name.rsplit_once(".camera")
545                .map_or_else(|| name.clone(), |(name, _)| name.to_string())
546        })
547    }
548
549    pub fn group(&self) -> Option<&String> {
550        self.group.as_ref()
551    }
552
553    pub fn sequence_name(&self) -> Option<&String> {
554        self.sequence_name.as_ref()
555    }
556
557    pub fn sequence_uuid(&self) -> Option<&String> {
558        self.sequence_uuid.as_ref()
559    }
560
561    pub fn sequence_description(&self) -> Option<&String> {
562        self.sequence_description.as_ref()
563    }
564
565    pub fn frame_number(&self) -> Option<u32> {
566        self.frame_number
567    }
568
569    pub fn uuid(&self) -> Option<&String> {
570        self.uuid.as_ref()
571    }
572
573    pub fn image_name(&self) -> Option<&str> {
574        self.image_name.as_deref()
575    }
576
577    pub fn image_url(&self) -> Option<&str> {
578        self.image_url.as_deref()
579    }
580
581    pub fn width(&self) -> Option<u32> {
582        self.width
583    }
584
585    pub fn height(&self) -> Option<u32> {
586        self.height
587    }
588
589    pub fn date(&self) -> Option<DateTime<Utc>> {
590        self.date
591    }
592
593    pub fn source(&self) -> Option<&String> {
594        self.source.as_ref()
595    }
596
597    pub fn location(&self) -> Option<&Location> {
598        self.location.as_ref()
599    }
600
601    pub fn files(&self) -> &[SampleFile] {
602        &self.files
603    }
604
605    pub fn annotations(&self) -> &[Annotation] {
606        &self.annotations
607    }
608
609    pub fn with_annotations(mut self, annotations: Vec<Annotation>) -> Self {
610        self.annotations = annotations;
611        self
612    }
613
614    pub async fn download(
615        &self,
616        client: &Client,
617        file_type: FileType,
618    ) -> Result<Option<Vec<u8>>, Error> {
619        let url = match file_type {
620            FileType::Image => self.image_url.as_ref(),
621            file => self
622                .files
623                .iter()
624                .find(|f| f.r#type == file.to_string())
625                .and_then(|f| f.url.as_ref()),
626        };
627
628        Ok(match url {
629            Some(url) => Some(client.download(url).await?),
630            None => None,
631        })
632    }
633}
634
/// A file associated with a sample (e.g., LiDAR point cloud, radar data).
///
/// For samples retrieved from the server, this contains the file type and URL.
/// For samples being populated to the server, this can be a type and filename.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SampleFile {
    /// File type string (for example `lidar.pcd`); `Sample::download`
    /// compares it against the string form of a `FileType`.
    r#type: String,
    /// URL of the file; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    url: Option<String>,
    /// File name; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    filename: Option<String>,
}
647
648impl SampleFile {
649    /// Creates a new sample file with type and URL (for downloaded samples).
650    pub fn with_url(file_type: String, url: String) -> Self {
651        Self {
652            r#type: file_type,
653            url: Some(url),
654            filename: None,
655        }
656    }
657
658    /// Creates a new sample file with type and filename (for populate API).
659    pub fn with_filename(file_type: String, filename: String) -> Self {
660        Self {
661            r#type: file_type,
662            url: None,
663            filename: Some(filename),
664        }
665    }
666
667    pub fn file_type(&self) -> &str {
668        &self.r#type
669    }
670
671    pub fn url(&self) -> Option<&str> {
672        self.url.as_deref()
673    }
674
675    pub fn filename(&self) -> Option<&str> {
676        self.filename.as_deref()
677    }
678}
679
/// Location and pose information for a sample.
///
/// Contains GPS coordinates and IMU orientation data describing where and how
/// the camera was positioned when capturing the sample. Either part may be
/// absent; absent parts are omitted from serialized output.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Location {
    /// GPS position, if available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gps: Option<GpsData>,
    /// IMU orientation, if available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub imu: Option<ImuData>,
}
691
/// GPS location data (latitude and longitude).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct GpsData {
    /// Latitude.
    pub lat: f64,
    /// Longitude.
    pub lon: f64,
}
698
/// IMU orientation data (roll, pitch, yaw in degrees).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ImuData {
    /// Roll angle.
    pub roll: f64,
    /// Pitch angle.
    pub pitch: f64,
    /// Yaw angle.
    pub yaw: f64,
}
706
/// Associates a string name with a type.
///
/// The annotation geometry types in this file implement it, returning
/// `"box2d"`, `"box3d"` and `"mask"` respectively.
#[allow(dead_code)]
pub trait TypeName {
    /// Returns the string name of the implementing type.
    fn type_name() -> String;
}
711
/// A 3D bounding box stored as a center point and its extents.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Box3d {
    /// Center x coordinate (`cx` in `Box3d::new`).
    x: f32,
    /// Center y coordinate (`cy` in `Box3d::new`).
    y: f32,
    /// Center z coordinate (`cz` in `Box3d::new`).
    z: f32,
    /// Width, the extent along x.
    w: f32,
    /// Height, the extent along y.
    h: f32,
    /// Length, the extent along z.
    l: f32,
}
721
722impl TypeName for Box3d {
723    fn type_name() -> String {
724        "box3d".to_owned()
725    }
726}
727
728impl Box3d {
729    pub fn new(cx: f32, cy: f32, cz: f32, width: f32, height: f32, length: f32) -> Self {
730        Self {
731            x: cx,
732            y: cy,
733            z: cz,
734            w: width,
735            h: height,
736            l: length,
737        }
738    }
739
740    pub fn width(&self) -> f32 {
741        self.w
742    }
743
744    pub fn height(&self) -> f32 {
745        self.h
746    }
747
748    pub fn length(&self) -> f32 {
749        self.l
750    }
751
752    pub fn cx(&self) -> f32 {
753        self.x
754    }
755
756    pub fn cy(&self) -> f32 {
757        self.y
758    }
759
760    pub fn cz(&self) -> f32 {
761        self.z
762    }
763
764    pub fn left(&self) -> f32 {
765        self.x - self.w / 2.0
766    }
767
768    pub fn top(&self) -> f32 {
769        self.y - self.h / 2.0
770    }
771
772    pub fn front(&self) -> f32 {
773        self.z - self.l / 2.0
774    }
775}
776
/// A 2D bounding box stored as its left/top corner plus width and height.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Box2d {
    /// Height of the box.
    h: f32,
    /// Width of the box.
    w: f32,
    /// Left edge (`left` in `Box2d::new`).
    x: f32,
    /// Top edge (`top` in `Box2d::new`).
    y: f32,
}
784
785impl TypeName for Box2d {
786    fn type_name() -> String {
787        "box2d".to_owned()
788    }
789}
790
791impl Box2d {
792    pub fn new(left: f32, top: f32, width: f32, height: f32) -> Self {
793        Self {
794            x: left,
795            y: top,
796            w: width,
797            h: height,
798        }
799    }
800
801    pub fn width(&self) -> f32 {
802        self.w
803    }
804
805    pub fn height(&self) -> f32 {
806        self.h
807    }
808
809    pub fn left(&self) -> f32 {
810        self.x
811    }
812
813    pub fn top(&self) -> f32 {
814        self.y
815    }
816
817    pub fn cx(&self) -> f32 {
818        self.x + self.w / 2.0
819    }
820
821    pub fn cy(&self) -> f32 {
822        self.y + self.h / 2.0
823    }
824}
825
/// A segmentation mask described by polygons.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Mask {
    /// List of polygons, each a list of coordinate pairs.
    // NOTE(review): presumably each pair is an (x, y) point — confirm the
    // axis order and units against the server API.
    pub polygon: Vec<Vec<(f32, f32)>>,
}
830
831impl TypeName for Mask {
832    fn type_name() -> String {
833        "mask".to_owned()
834    }
835}
836
837impl Mask {
838    pub fn new(polygon: Vec<Vec<(f32, f32)>>) -> Self {
839        Self { polygon }
840    }
841}
842
/// A single annotation: optional identifying metadata plus up to three
/// optional geometries (2D box, 3D box, mask).
///
/// Deserialization goes through `AnnotationHelper` (`serde(from = ...)`),
/// and serialization is implemented by hand in `impl Serialize for
/// Annotation`, which writes box coordinates as flat `x`/`y`/... entries.
// NOTE(review): because deserialization is delegated to `AnnotationHelper`
// and `Serialize` is hand-written, the field-level serde attributes below
// appear to have no effect — confirm before relying on or removing them.
#[derive(Deserialize, Clone, Debug)]
#[serde(from = "AnnotationHelper")]
pub struct Annotation {
    /// Identifier of the sample the annotation belongs to.
    #[serde(skip_serializing_if = "Option::is_none")]
    sample_id: Option<SampleID>,
    /// Name associated with the annotation.
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Sequence name associated with the annotation.
    #[serde(skip_serializing_if = "Option::is_none")]
    sequence_name: Option<String>,
    /// Group associated with the annotation.
    #[serde(skip_serializing_if = "Option::is_none")]
    group: Option<String>,
    /// Object identifier; written as `object_reference` by the serializer.
    #[serde(rename = "object_reference", skip_serializing_if = "Option::is_none")]
    object_id: Option<String>,
    /// Label name; written as `label_name` by the serializer.
    #[serde(rename = "label_name", skip_serializing_if = "Option::is_none")]
    label: Option<String>,
    /// Label index.
    #[serde(skip_serializing_if = "Option::is_none")]
    label_index: Option<u64>,
    /// 2D bounding box geometry, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    box2d: Option<Box2d>,
    /// 3D bounding box geometry, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    box3d: Option<Box3d>,
    /// Segmentation mask geometry, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    mask: Option<Mask>,
}
867
// Helper struct for deserialization that matches the nested format, i.e.
// `box2d`, `box3d` and `mask` arrive as nested objects. It mirrors the
// fields of `Annotation` one-to-one and is converted with `From`.
// NOTE(review): this struct derives only `Deserialize`, so the
// `skip_serializing_if` attributes presumably have no effect here — confirm.
#[derive(Deserialize)]
struct AnnotationHelper {
    #[serde(skip_serializing_if = "Option::is_none")]
    sample_id: Option<SampleID>,
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    sequence_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    group: Option<String>,
    // Read from the `object_reference` key.
    #[serde(rename = "object_reference", skip_serializing_if = "Option::is_none")]
    object_id: Option<String>,
    // Read from the `label_name` key.
    #[serde(rename = "label_name", skip_serializing_if = "Option::is_none")]
    label: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    label_index: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    box2d: Option<Box2d>,
    #[serde(skip_serializing_if = "Option::is_none")]
    box3d: Option<Box3d>,
    #[serde(skip_serializing_if = "Option::is_none")]
    mask: Option<Mask>,
}
892
893impl From<AnnotationHelper> for Annotation {
894    fn from(helper: AnnotationHelper) -> Self {
895        Self {
896            sample_id: helper.sample_id,
897            name: helper.name,
898            sequence_name: helper.sequence_name,
899            group: helper.group,
900            object_id: helper.object_id,
901            label: helper.label,
902            label_index: helper.label_index,
903            box2d: helper.box2d,
904            box3d: helper.box3d,
905            mask: helper.mask,
906        }
907    }
908}
909
910// Custom serializer that flattens box2d/box3d fields
911impl Serialize for Annotation {
912    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
913    where
914        S: serde::Serializer,
915    {
916        use serde::ser::SerializeMap;
917
918        let mut map = serializer.serialize_map(None)?;
919
920        if let Some(ref sample_id) = self.sample_id {
921            map.serialize_entry("sample_id", sample_id)?;
922        }
923        if let Some(ref name) = self.name {
924            map.serialize_entry("name", name)?;
925        }
926        if let Some(ref sequence_name) = self.sequence_name {
927            map.serialize_entry("sequence_name", sequence_name)?;
928        }
929        if let Some(ref group) = self.group {
930            map.serialize_entry("group", group)?;
931        }
932        if let Some(ref object_id) = self.object_id {
933            map.serialize_entry("object_reference", object_id)?;
934        }
935        if let Some(ref label) = self.label {
936            map.serialize_entry("label_name", label)?;
937        }
938        if let Some(label_index) = self.label_index {
939            map.serialize_entry("label_index", &label_index)?;
940        }
941
942        // Flatten box2d fields
943        if let Some(ref box2d) = self.box2d {
944            map.serialize_entry("x", &box2d.x)?;
945            map.serialize_entry("y", &box2d.y)?;
946            map.serialize_entry("w", &box2d.w)?;
947            map.serialize_entry("h", &box2d.h)?;
948        }
949
950        // Flatten box3d fields
951        if let Some(ref box3d) = self.box3d {
952            map.serialize_entry("x", &box3d.x)?;
953            map.serialize_entry("y", &box3d.y)?;
954            map.serialize_entry("z", &box3d.z)?;
955            map.serialize_entry("w", &box3d.w)?;
956            map.serialize_entry("h", &box3d.h)?;
957            map.serialize_entry("l", &box3d.l)?;
958        }
959
960        if let Some(ref mask) = self.mask {
961            map.serialize_entry("mask", mask)?;
962        }
963
964        map.end()
965    }
966}
967
968impl Default for Annotation {
969    fn default() -> Self {
970        Self::new()
971    }
972}
973
974impl Annotation {
975    pub fn new() -> Self {
976        Self {
977            sample_id: None,
978            name: None,
979            sequence_name: None,
980            group: None,
981            object_id: None,
982            label: None,
983            label_index: None,
984            box2d: None,
985            box3d: None,
986            mask: None,
987        }
988    }
989
990    pub fn set_sample_id(&mut self, sample_id: Option<SampleID>) {
991        self.sample_id = sample_id;
992    }
993
994    pub fn sample_id(&self) -> Option<SampleID> {
995        self.sample_id
996    }
997
998    pub fn set_name(&mut self, name: Option<String>) {
999        self.name = name;
1000    }
1001
1002    pub fn name(&self) -> Option<&String> {
1003        self.name.as_ref()
1004    }
1005
1006    pub fn set_sequence_name(&mut self, sequence_name: Option<String>) {
1007        self.sequence_name = sequence_name;
1008    }
1009
1010    pub fn sequence_name(&self) -> Option<&String> {
1011        self.sequence_name.as_ref()
1012    }
1013
1014    pub fn set_group(&mut self, group: Option<String>) {
1015        self.group = group;
1016    }
1017
1018    pub fn group(&self) -> Option<&String> {
1019        self.group.as_ref()
1020    }
1021
1022    pub fn object_id(&self) -> Option<&String> {
1023        self.object_id.as_ref()
1024    }
1025
1026    pub fn set_object_id(&mut self, object_id: Option<String>) {
1027        self.object_id = object_id;
1028    }
1029
1030    pub fn label(&self) -> Option<&String> {
1031        self.label.as_ref()
1032    }
1033
1034    pub fn set_label(&mut self, label: Option<String>) {
1035        self.label = label;
1036    }
1037
1038    pub fn label_index(&self) -> Option<u64> {
1039        self.label_index
1040    }
1041
1042    pub fn set_label_index(&mut self, label_index: Option<u64>) {
1043        self.label_index = label_index;
1044    }
1045
1046    pub fn box2d(&self) -> Option<&Box2d> {
1047        self.box2d.as_ref()
1048    }
1049
1050    pub fn set_box2d(&mut self, box2d: Option<Box2d>) {
1051        self.box2d = box2d;
1052    }
1053
1054    pub fn box3d(&self) -> Option<&Box3d> {
1055        self.box3d.as_ref()
1056    }
1057
1058    pub fn set_box3d(&mut self, box3d: Option<Box3d>) {
1059        self.box3d = box3d;
1060    }
1061
1062    pub fn mask(&self) -> Option<&Mask> {
1063        self.mask.as_ref()
1064    }
1065
1066    pub fn set_mask(&mut self, mask: Option<Mask>) {
1067        self.mask = mask;
1068    }
1069}
1070
/// A named, indexed label belonging to a dataset.
///
/// Fields are private; read them through the accessor methods on `Label`.
/// The type is (de)serializable with serde and usable as a hash-map key.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Label {
    /// Numeric identifier of the label itself.
    id: u64,
    /// Identifier of the dataset associated with this label.
    dataset_id: DatasetID,
    /// Numeric index of the label.
    index: u64,
    /// Label name; this is what `Display` prints.
    name: String,
}
1078
1079impl Label {
1080    pub fn id(&self) -> u64 {
1081        self.id
1082    }
1083
1084    pub fn dataset_id(&self) -> DatasetID {
1085        self.dataset_id
1086    }
1087
1088    pub fn index(&self) -> u64 {
1089        self.index
1090    }
1091
1092    pub fn name(&self) -> &str {
1093        &self.name
1094    }
1095
1096    pub async fn remove(&self, client: &Client) -> Result<(), Error> {
1097        client.remove_label(self.id()).await
1098    }
1099
1100    pub async fn set_name(&mut self, client: &Client, name: &str) -> Result<(), Error> {
1101        self.name = name.to_string();
1102        client.update_label(self).await
1103    }
1104
1105    pub async fn set_index(&mut self, client: &Client, index: u64) -> Result<(), Error> {
1106        self.index = index;
1107        client.update_label(self).await
1108    }
1109}
1110
1111impl Display for Label {
1112    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1113        write!(f, "{}", self.name())
1114    }
1115}
1116
/// Serializable entry carrying the name of a single label.
///
/// Used as the element type of [`NewLabel::labels`].
#[derive(Serialize, Clone, Debug)]
pub struct NewLabelObject {
    /// Label name.
    pub name: String,
}
1121
/// Serializable payload pairing a dataset ID with a list of label entries.
///
/// NOTE(review): presumably the request body for creating labels in a
/// dataset; confirm against the client call that builds it.
#[derive(Serialize, Clone, Debug)]
pub struct NewLabel {
    /// Dataset the label entries are associated with.
    pub dataset_id: DatasetID,
    /// Label entries, one per label name.
    pub labels: Vec<NewLabelObject>,
}
1127
/// Deserializable group record consisting of a numeric ID and a name.
#[derive(Deserialize, Clone, Debug)]
#[allow(dead_code)]
pub struct Group {
    /// Numeric group identifier.
    pub id: u64, // Groups seem to use raw u64, not a specific ID type
    /// Group name.
    pub name: String,
}
1134
/// Builds a Polars [`DataFrame`] with one row per entry of `annotations`.
///
/// Columns, in order:
///
/// | column        | contents                                                        |
/// |---------------|-----------------------------------------------------------------|
/// | `name`        | file stem of the annotation's image name, or the sequence name   |
/// | `frame`       | remainder of the stem after the sequence name, if any            |
/// | `object_id`   | the annotation's object ID                                       |
/// | `label`       | the annotation's label, cast to a categorical                    |
/// | `label_index` | the annotation's label index                                     |
/// | `group`       | the annotation's group, cast to a categorical                    |
/// | `mask`        | flattened `x, y` polygon points, polygons separated by NaN       |
/// | `box2d`       | `[cx, cy, width, height]`, cast to a 4-element `Float32` array   |
/// | `box3d`       | `[x, y, z, w, h, l]`, cast to a 6-element `Float32` array        |
///
/// When an annotation has a sequence name, the image stem must start with it:
/// `name` then becomes the sequence name and `frame` the rest of the stem with
/// leading underscores removed.
///
/// Annotations that cannot be resolved are logged with a warning and still
/// produce a row, so the output keeps one row per input. That placeholder row
/// has an empty `name` and nulls in every other column. This applies when the
/// image name is missing, when it has no file stem, and when the stem does not
/// start with the sequence name.
///
/// # Panics
///
/// Panics if any of the Polars casts or the final `DataFrame` construction
/// fails.
#[cfg(feature = "polars")]
pub fn annotations_dataframe(annotations: &[Annotation]) -> DataFrame {
    use itertools::Itertools;
    use log::warn;
    use std::path::Path;

    /// One output row: name, frame, object ID, label, label index, group,
    /// mask, 2D box, 3D box.
    type Row = (
        String,
        Option<String>,
        Option<String>,
        Option<String>,
        Option<u64>,
        Option<String>,
        Option<Series>,
        Option<Series>,
        Option<Series>,
    );

    /// Placeholder row emitted for annotations that cannot be resolved.
    fn blank_row() -> Row {
        (String::new(), None, None, None, None, None, None, None, None)
    }

    let (names, frames, objects, labels, label_indices, groups, masks, boxes2d, boxes3d) =
        annotations
            .iter()
            .map(|ann| -> Row {
                let Some(image_name) = ann.name.as_deref() else {
                    warn!("annotation missing image name, skipping");
                    return blank_row();
                };

                // `file_stem` is `None` for names such as "" or "..";
                // treat those like a missing name instead of panicking.
                let Some(stem) = Path::new(image_name)
                    .file_stem()
                    .and_then(|stem| stem.to_str())
                else {
                    warn!("image_name {} has no file stem, skipping", image_name);
                    return blank_row();
                };

                let (name, frame) = match &ann.sequence_name {
                    Some(sequence) => match stem.strip_prefix(sequence.as_str()) {
                        Some(frame) => (
                            sequence.to_string(),
                            Some(frame.trim_start_matches('_').to_string()),
                        ),
                        None => {
                            warn!(
                                "image_name {} does not match sequence_name {}",
                                stem, sequence
                            );
                            return blank_row();
                        }
                    },
                    None => (stem.to_string(), None),
                };

                let mask = ann.mask.as_ref().map(|seg| {
                    let mut coords = Vec::new();
                    for polygon in &seg.polygon {
                        for &(x, y) in polygon {
                            coords.push(x);
                            coords.push(y);
                        }
                        // Separate polygons with NaN
                        coords.push(f32::NAN);
                    }
                    // Drop the separator that trails the last polygon; this
                    // is a no-op when there were no polygons at all.
                    coords.pop();

                    Series::new("mask".into(), coords)
                });

                let box2d = ann.box2d.as_ref().map(|box2d| {
                    Series::new(
                        "box2d".into(),
                        [box2d.cx(), box2d.cy(), box2d.width(), box2d.height()],
                    )
                });

                let box3d = ann.box3d.as_ref().map(|box3d| {
                    Series::new(
                        "box3d".into(),
                        [box3d.x, box3d.y, box3d.z, box3d.w, box3d.h, box3d.l],
                    )
                });

                (
                    name,
                    frame,
                    ann.object_id.clone(),
                    ann.label.clone(),
                    ann.label_index,
                    ann.group.clone(),
                    mask,
                    box2d,
                    box3d,
                )
            })
            .multiunzip::<(
                Vec<_>, // names
                Vec<_>, // frames
                Vec<_>, // objects
                Vec<_>, // labels
                Vec<_>, // label_indices
                Vec<_>, // groups
                Vec<_>, // masks
                Vec<_>, // boxes2d
                Vec<_>, // boxes3d
            )>();
    let names = Series::new("name".into(), names).into();
    let frames = Series::new("frame".into(), frames).into();
    let objects = Series::new("object_id".into(), objects).into();
    let labels = Series::new("label".into(), labels)
        .cast(&DataType::Categorical(
            Categories::new("labels".into(), "labels".into(), CategoricalPhysical::U8),
            Arc::new(CategoricalMapping::new(u8::MAX as usize)),
        ))
        .unwrap()
        .into();
    let label_indices = Series::new("label_index".into(), label_indices).into();
    let groups = Series::new("group".into(), groups)
        .cast(&DataType::Categorical(
            Categories::new("groups".into(), "groups".into(), CategoricalPhysical::U8),
            Arc::new(CategoricalMapping::new(u8::MAX as usize)),
        ))
        .unwrap()
        .into();
    let masks = Series::new("mask".into(), masks)
        .cast(&DataType::List(Box::new(DataType::Float32)))
        .unwrap()
        .into();
    let boxes2d = Series::new("box2d".into(), boxes2d)
        .cast(&DataType::Array(Box::new(DataType::Float32), 4))
        .unwrap()
        .into();
    let boxes3d = Series::new("box3d".into(), boxes3d)
        .cast(&DataType::Array(Box::new(DataType::Float32), 6))
        .unwrap()
        .into();

    DataFrame::new(vec![
        names,
        frames,
        objects,
        labels,
        label_indices,
        groups,
        masks,
        boxes2d,
        boxes3d,
    ])
    .unwrap()
}