edgefirst_client/
dataset.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use std::{collections::HashMap, fmt::Display};
5
6use crate::{
7    Client, Error,
8    api::{AnnotationSetID, DatasetID, ProjectID, SampleID},
9};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12
13#[cfg(feature = "polars")]
14use polars::prelude::*;
15
/// Kinds of sensor data files that can be attached to an EdgeFirst Studio
/// dataset.
///
/// Covers camera images as well as the LiDAR and radar modalities. Each
/// variant has a canonical string identifier (for example `"lidar.pcd"`)
/// which is what `Display` produces and what `FileType::from(&str)` accepts.
///
/// # Examples
///
/// ```rust
/// use edgefirst_client::FileType;
///
/// // Create file types from strings
/// let image_type = FileType::from("image");
/// let lidar_type = FileType::from("lidar.pcd");
///
/// // Display file types
/// println!("Processing {} files", image_type); // "Processing image files"
///
/// // Use in dataset operations - example usage
/// let file_type = FileType::Image;
/// match file_type {
///     FileType::Image => println!("Processing image files"),
///     FileType::LidarPcd => println!("Processing LiDAR point cloud files"),
///     _ => println!("Processing other sensor data"),
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileType {
    /// Standard image files (JPEG, PNG, etc.); identifier `"image"`.
    Image,
    /// LiDAR point cloud data files (.pcd format); identifier `"lidar.pcd"`.
    LidarPcd,
    /// LiDAR depth images (.png format); identifier `"lidar.png"`.
    LidarDepth,
    /// LiDAR reflectance images (.jpg format); identifier `"lidar.jpg"`.
    LidarReflect,
    /// Radar point cloud data files (.pcd format); identifier `"radar.pcd"`.
    RadarPcd,
    /// Radar cube data files (.png format); identifier `"radar.png"`.
    RadarCube,
}
57
58impl std::fmt::Display for FileType {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        let value = match self {
61            FileType::Image => "image",
62            FileType::LidarPcd => "lidar.pcd",
63            FileType::LidarDepth => "lidar.png",
64            FileType::LidarReflect => "lidar.jpg",
65            FileType::RadarPcd => "radar.pcd",
66            FileType::RadarCube => "radar.png",
67        };
68        write!(f, "{}", value)
69    }
70}
71
72impl From<&str> for FileType {
73    fn from(s: &str) -> Self {
74        match s {
75            "image" => FileType::Image,
76            "lidar.pcd" => FileType::LidarPcd,
77            "lidar.png" => FileType::LidarDepth,
78            "lidar.jpg" => FileType::LidarReflect,
79            "radar.pcd" => FileType::RadarPcd,
80            "radar.png" => FileType::RadarCube,
81            _ => panic!("Invalid file type"),
82        }
83    }
84}
85
/// Kinds of annotation that can be attached to sensor data in EdgeFirst
/// Studio.
///
/// Each variant corresponds to one annotation geometry and has a canonical
/// string identifier (`"box2d"`, `"box3d"`, `"mask"`) used by `Display` and
/// by the `From` string conversions.
///
/// # Examples
///
/// ```rust
/// use edgefirst_client::AnnotationType;
///
/// // Create annotation types from strings
/// let box_2d = AnnotationType::from("box2d");
/// let segmentation = AnnotationType::from("mask");
///
/// // Display annotation types
/// println!("Annotation type: {}", box_2d); // "Annotation type: box2d"
///
/// // Use in matching and processing
/// let annotation_type = AnnotationType::Box2d;
/// match annotation_type {
///     AnnotationType::Box2d => println!("Processing 2D bounding boxes"),
///     AnnotationType::Box3d => println!("Processing 3D bounding boxes"),
///     AnnotationType::Mask => println!("Processing segmentation masks"),
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnnotationType {
    /// 2D bounding boxes for object detection in images
    Box2d,
    /// 3D bounding boxes for object detection in 3D space (LiDAR, etc.)
    Box3d,
    /// Pixel-level segmentation masks for semantic/instance segmentation
    Mask,
}
121
122impl From<&str> for AnnotationType {
123    fn from(s: &str) -> Self {
124        match s {
125            "box2d" => AnnotationType::Box2d,
126            "box3d" => AnnotationType::Box3d,
127            "mask" => AnnotationType::Mask,
128            _ => panic!("Invalid annotation type"),
129        }
130    }
131}
132
133impl From<String> for AnnotationType {
134    fn from(s: String) -> Self {
135        s.as_str().into()
136    }
137}
138
139impl From<&String> for AnnotationType {
140    fn from(s: &String) -> Self {
141        s.as_str().into()
142    }
143}
144
145impl std::fmt::Display for AnnotationType {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        let value = match self {
148            AnnotationType::Box2d => "box2d",
149            AnnotationType::Box3d => "box3d",
150            AnnotationType::Mask => "mask",
151        };
152        write!(f, "{}", value)
153    }
154}
155
156/// A dataset in EdgeFirst Studio containing sensor data and annotations.
157///
158/// Datasets are collections of multi-modal sensor data (images, LiDAR, radar)
159/// along with their corresponding annotations (bounding boxes, segmentation
160/// masks, 3D annotations). Datasets belong to projects and can be used for
161/// training and validation of machine learning models.
162///
163/// # Features
164///
165/// - **Multi-modal Data**: Support for images, LiDAR point clouds, radar data
166/// - **Rich Annotations**: 2D/3D bounding boxes, segmentation masks
167/// - **Metadata**: Timestamps, sensor configurations, calibration data
168/// - **Version Control**: Track changes and maintain data lineage
169/// - **Format Conversion**: Export to popular ML frameworks
170///
171/// # Examples
172///
173/// ```no_run
174/// use edgefirst_client::{Client, Dataset, DatasetID};
175/// use std::str::FromStr;
176///
177/// # async fn example() -> Result<(), edgefirst_client::Error> {
178/// # let client = Client::new()?;
179/// // Get dataset information
180/// let dataset_id = DatasetID::from_str("ds-abc123")?;
181/// let dataset = client.dataset(dataset_id).await?;
182/// println!("Dataset: {}", dataset.name());
183///
184/// // Access dataset metadata
185/// println!("Dataset ID: {}", dataset.id());
186/// println!("Description: {}", dataset.description());
187/// println!("Created: {}", dataset.created());
188///
189/// // Work with dataset data would require additional methods
190/// // that are implemented in the full API
191/// # Ok(())
192/// # }
193/// ```
194#[derive(Deserialize, Clone, Debug)]
195pub struct Dataset {
196    id: DatasetID,
197    project_id: ProjectID,
198    name: String,
199    description: String,
200    cloud_key: String,
201    #[serde(rename = "createdAt")]
202    created: DateTime<Utc>,
203}
204
205impl Display for Dataset {
206    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
207        write!(f, "{} {}", self.uid(), self.name)
208    }
209}
210
211impl Dataset {
212    pub fn id(&self) -> DatasetID {
213        self.id
214    }
215
216    pub fn uid(&self) -> String {
217        self.id.to_string()
218    }
219
220    pub fn project_id(&self) -> ProjectID {
221        self.project_id
222    }
223
224    pub fn name(&self) -> &str {
225        &self.name
226    }
227
228    pub fn description(&self) -> &str {
229        &self.description
230    }
231
232    pub fn cloud_key(&self) -> &str {
233        &self.cloud_key
234    }
235
236    pub fn created(&self) -> &DateTime<Utc> {
237        &self.created
238    }
239
240    pub async fn project(&self, client: &Client) -> Result<crate::api::Project, Error> {
241        client.project(self.project_id).await
242    }
243
244    pub async fn annotation_sets(&self, client: &Client) -> Result<Vec<AnnotationSet>, Error> {
245        client.annotation_sets(self.id).await
246    }
247
248    pub async fn labels(&self, client: &Client) -> Result<Vec<Label>, Error> {
249        client.labels(self.id).await
250    }
251
252    pub async fn add_label(&self, client: &Client, name: &str) -> Result<(), Error> {
253        client.add_label(self.id, name).await
254    }
255
256    pub async fn remove_label(&self, client: &Client, name: &str) -> Result<(), Error> {
257        let labels = self.labels(client).await?;
258        let label = labels
259            .iter()
260            .find(|l| l.name() == name)
261            .ok_or_else(|| Error::MissingLabel(name.to_string()))?;
262        client.remove_label(label.id()).await
263    }
264}
265
266/// The AnnotationSet class represents a collection of annotations in a dataset.
267/// A dataset can have multiple annotation sets, each containing annotations for
268/// different tasks or purposes.
269#[derive(Deserialize)]
270pub struct AnnotationSet {
271    id: AnnotationSetID,
272    dataset_id: DatasetID,
273    name: String,
274    description: String,
275    #[serde(rename = "date")]
276    created: DateTime<Utc>,
277}
278
279impl Display for AnnotationSet {
280    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
281        write!(f, "{} {}", self.uid(), self.name)
282    }
283}
284
285impl AnnotationSet {
286    pub fn id(&self) -> AnnotationSetID {
287        self.id
288    }
289
290    pub fn uid(&self) -> String {
291        self.id.to_string()
292    }
293
294    pub fn dataset_id(&self) -> DatasetID {
295        self.dataset_id
296    }
297
298    pub fn name(&self) -> &str {
299        &self.name
300    }
301
302    pub fn description(&self) -> &str {
303        &self.description
304    }
305
306    pub fn created(&self) -> DateTime<Utc> {
307        self.created
308    }
309
310    pub async fn dataset(&self, client: &Client) -> Result<Dataset, Error> {
311        client.dataset(self.dataset_id).await
312    }
313}
314
315/// A sample in a dataset, typically representing a single image with metadata
316/// and optional sensor data.
317///
318/// Each sample has a unique ID, image reference, and can include additional
319/// sensor data like LiDAR, radar, or depth maps. Samples can also have
320/// associated annotations.
321#[derive(Serialize, Deserialize, Clone, Debug)]
322pub struct Sample {
323    #[serde(skip_serializing_if = "Option::is_none")]
324    pub id: Option<SampleID>,
325    #[serde(alias = "group_name", skip_serializing_if = "Option::is_none")]
326    pub group: Option<String>,
327    #[serde(skip_serializing_if = "Option::is_none")]
328    pub sequence_name: Option<String>,
329    #[serde(skip_serializing_if = "Option::is_none")]
330    pub sequence_uuid: Option<String>,
331    #[serde(skip_serializing_if = "Option::is_none")]
332    pub sequence_description: Option<String>,
333    #[serde(skip_serializing_if = "Option::is_none")]
334    pub frame_number: Option<u32>,
335    #[serde(skip_serializing_if = "Option::is_none")]
336    pub uuid: Option<String>,
337    #[serde(skip_serializing_if = "Option::is_none")]
338    pub image_name: Option<String>,
339    #[serde(skip_serializing_if = "Option::is_none")]
340    pub image_url: Option<String>,
341    #[serde(skip_serializing_if = "Option::is_none")]
342    pub width: Option<u32>,
343    #[serde(skip_serializing_if = "Option::is_none")]
344    pub height: Option<u32>,
345    #[serde(skip_serializing_if = "Option::is_none")]
346    pub date: Option<DateTime<Utc>>,
347    #[serde(skip_serializing_if = "Option::is_none")]
348    pub source: Option<String>,
349    /// Camera location and pose (GPS + IMU data).
350    /// Serialized as "sensors" for API compatibility with populate endpoint.
351    #[serde(rename = "sensors", skip_serializing_if = "Option::is_none")]
352    pub location: Option<Location>,
353    /// Additional sensor files (LiDAR, radar, depth maps, etc.).
354    /// When deserializing from samples.list: Vec<SampleFile>
355    /// When serializing for samples.populate: HashMap<String, String>
356    /// (file_type -> filename)
357    #[serde(
358        default,
359        skip_serializing_if = "Vec::is_empty",
360        serialize_with = "serialize_files",
361        deserialize_with = "deserialize_files"
362    )]
363    pub files: Vec<SampleFile>,
364    #[serde(
365        default,
366        skip_serializing_if = "Vec::is_empty",
367        serialize_with = "serialize_annotations",
368        deserialize_with = "deserialize_annotations"
369    )]
370    pub annotations: Vec<Annotation>,
371}
372
373// Custom serializer for files field - converts Vec<SampleFile> to
374// HashMap<String, String>
375fn serialize_files<S>(files: &[SampleFile], serializer: S) -> Result<S::Ok, S::Error>
376where
377    S: serde::Serializer,
378{
379    use serde::Serialize;
380    let map: HashMap<String, String> = files
381        .iter()
382        .filter_map(|f| {
383            f.filename()
384                .map(|filename| (f.file_type().to_string(), filename.to_string()))
385        })
386        .collect();
387    map.serialize(serializer)
388}
389
390// Custom deserializer for files field - converts HashMap or Vec to
391// Vec<SampleFile>
392fn deserialize_files<'de, D>(deserializer: D) -> Result<Vec<SampleFile>, D::Error>
393where
394    D: serde::Deserializer<'de>,
395{
396    use serde::Deserialize;
397
398    #[derive(Deserialize)]
399    #[serde(untagged)]
400    enum FilesFormat {
401        Vec(Vec<SampleFile>),
402        Map(HashMap<String, String>),
403    }
404
405    let value = Option::<FilesFormat>::deserialize(deserializer)?;
406    Ok(value
407        .map(|v| match v {
408            FilesFormat::Vec(files) => files,
409            FilesFormat::Map(map) => map
410                .into_iter()
411                .map(|(file_type, filename)| SampleFile::with_filename(file_type, filename))
412                .collect(),
413        })
414        .unwrap_or_default())
415}
416
417// Custom serializer for annotations field - converts Vec<Annotation> to
418// format expected by server: {"bbox": [...], "box3d": [...], "mask": [...]}
419fn serialize_annotations<S>(annotations: &Vec<Annotation>, serializer: S) -> Result<S::Ok, S::Error>
420where
421    S: serde::Serializer,
422{
423    use serde::ser::SerializeMap;
424
425    // Group annotations by type
426    let mut bbox_annotations = Vec::new();
427    let mut box3d_annotations = Vec::new();
428    let mut mask_annotations = Vec::new();
429
430    for ann in annotations {
431        if ann.box2d().is_some() {
432            bbox_annotations.push(ann);
433        } else if ann.box3d().is_some() {
434            box3d_annotations.push(ann);
435        } else if ann.mask().is_some() {
436            mask_annotations.push(ann);
437        }
438    }
439
440    let mut map = serializer.serialize_map(Some(3))?;
441
442    if !bbox_annotations.is_empty() {
443        map.serialize_entry("bbox", &bbox_annotations)?;
444    }
445    if !box3d_annotations.is_empty() {
446        map.serialize_entry("box3d", &box3d_annotations)?;
447    }
448    if !mask_annotations.is_empty() {
449        map.serialize_entry("mask", &mask_annotations)?;
450    }
451
452    map.end()
453}
454
455// Custom deserializer for annotations field - converts server format back to
456// Vec<Annotation>
457fn deserialize_annotations<'de, D>(deserializer: D) -> Result<Vec<Annotation>, D::Error>
458where
459    D: serde::Deserializer<'de>,
460{
461    use serde::Deserialize;
462
463    #[derive(Deserialize)]
464    #[serde(untagged)]
465    enum AnnotationsFormat {
466        Vec(Vec<Annotation>),
467        Map(HashMap<String, Vec<Annotation>>),
468    }
469
470    let value = Option::<AnnotationsFormat>::deserialize(deserializer)?;
471    Ok(value
472        .map(|v| match v {
473            AnnotationsFormat::Vec(annotations) => annotations,
474            AnnotationsFormat::Map(map) => {
475                let mut all_annotations = Vec::new();
476                if let Some(bbox_anns) = map.get("bbox") {
477                    all_annotations.extend(bbox_anns.clone());
478                }
479                if let Some(box3d_anns) = map.get("box3d") {
480                    all_annotations.extend(box3d_anns.clone());
481                }
482                if let Some(mask_anns) = map.get("mask") {
483                    all_annotations.extend(mask_anns.clone());
484                }
485                all_annotations
486            }
487        })
488        .unwrap_or_default())
489}
490
491impl Display for Sample {
492    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
493        write!(
494            f,
495            "{} {}",
496            self.uid().unwrap_or_else(|| "unknown".to_string()),
497            self.image_name().unwrap_or("unknown")
498        )
499    }
500}
501
502impl Default for Sample {
503    fn default() -> Self {
504        Self::new()
505    }
506}
507
508impl Sample {
509    /// Creates a new empty sample.
510    pub fn new() -> Self {
511        Self {
512            id: None,
513            group: None,
514            sequence_name: None,
515            sequence_uuid: None,
516            sequence_description: None,
517            frame_number: None,
518            uuid: None,
519            image_name: None,
520            image_url: None,
521            width: None,
522            height: None,
523            date: None,
524            source: None,
525            location: None,
526            files: vec![],
527            annotations: vec![],
528        }
529    }
530
531    pub fn id(&self) -> Option<SampleID> {
532        self.id
533    }
534
535    pub fn uid(&self) -> Option<String> {
536        self.id.map(|id| id.to_string())
537    }
538
539    pub fn name(&self) -> Option<String> {
540        self.image_name.as_ref().map(|image_name| {
541            let name = image_name
542                .rsplit_once('.')
543                .map_or_else(|| image_name.clone(), |(name, _)| name.to_string());
544            name.rsplit_once(".camera")
545                .map_or_else(|| name.clone(), |(name, _)| name.to_string())
546        })
547    }
548
549    pub fn group(&self) -> Option<&String> {
550        self.group.as_ref()
551    }
552
553    pub fn sequence_name(&self) -> Option<&String> {
554        self.sequence_name.as_ref()
555    }
556
557    pub fn image_name(&self) -> Option<&str> {
558        self.image_name.as_deref()
559    }
560
561    pub fn image_url(&self) -> Option<&str> {
562        self.image_url.as_deref()
563    }
564
565    pub fn files(&self) -> &[SampleFile] {
566        &self.files
567    }
568
569    pub fn annotations(&self) -> &[Annotation] {
570        &self.annotations
571    }
572
573    pub fn with_annotations(mut self, annotations: Vec<Annotation>) -> Self {
574        self.annotations = annotations;
575        self
576    }
577
578    pub async fn download(
579        &self,
580        client: &Client,
581        file_type: FileType,
582    ) -> Result<Option<Vec<u8>>, Error> {
583        let url = match file_type {
584            FileType::Image => self.image_url.as_ref(),
585            file => self
586                .files
587                .iter()
588                .find(|f| f.r#type == file.to_string())
589                .and_then(|f| f.url.as_ref()),
590        };
591
592        Ok(match url {
593            Some(url) => Some(client.download(url).await?),
594            None => None,
595        })
596    }
597}
598
599/// A file associated with a sample (e.g., LiDAR point cloud, radar data).
600///
601/// For samples retrieved from the server, this contains the file type and URL.
602/// For samples being populated to the server, this can be a type and filename.
603#[derive(Serialize, Deserialize, Clone, Debug)]
604pub struct SampleFile {
605    r#type: String,
606    #[serde(skip_serializing_if = "Option::is_none")]
607    url: Option<String>,
608    #[serde(skip_serializing_if = "Option::is_none")]
609    filename: Option<String>,
610}
611
612impl SampleFile {
613    /// Creates a new sample file with type and URL (for downloaded samples).
614    pub fn with_url(file_type: String, url: String) -> Self {
615        Self {
616            r#type: file_type,
617            url: Some(url),
618            filename: None,
619        }
620    }
621
622    /// Creates a new sample file with type and filename (for populate API).
623    pub fn with_filename(file_type: String, filename: String) -> Self {
624        Self {
625            r#type: file_type,
626            url: None,
627            filename: Some(filename),
628        }
629    }
630
631    pub fn file_type(&self) -> &str {
632        &self.r#type
633    }
634
635    pub fn url(&self) -> Option<&str> {
636        self.url.as_deref()
637    }
638
639    pub fn filename(&self) -> Option<&str> {
640        self.filename.as_deref()
641    }
642}
643
644/// Location and pose information for a sample.
645///
646/// Contains GPS coordinates and IMU orientation data describing where and how
647/// the camera was positioned when capturing the sample.
648#[derive(Serialize, Deserialize, Clone, Debug)]
649pub struct Location {
650    #[serde(skip_serializing_if = "Option::is_none")]
651    pub gps: Option<GpsData>,
652    #[serde(skip_serializing_if = "Option::is_none")]
653    pub imu: Option<ImuData>,
654}
655
656/// GPS location data (latitude and longitude).
657#[derive(Serialize, Deserialize, Clone, Debug)]
658pub struct GpsData {
659    pub lat: f64,
660    pub lon: f64,
661}
662
663/// IMU orientation data (roll, pitch, yaw in degrees).
664#[derive(Serialize, Deserialize, Clone, Debug)]
665pub struct ImuData {
666    pub roll: f64,
667    pub pitch: f64,
668    pub yaw: f64,
669}
670
/// Associates a type with a fixed string name.
///
/// In this module it is implemented by the annotation geometry types, which
/// return `"box2d"`, `"box3d"` and `"mask"`.
pub trait TypeName {
    /// Returns the name associated with the implementing type.
    fn type_name() -> String;
}
674
675#[derive(Serialize, Deserialize, Clone, Debug)]
676pub struct Box3d {
677    x: f32,
678    y: f32,
679    z: f32,
680    w: f32,
681    h: f32,
682    l: f32,
683}
684
685impl TypeName for Box3d {
686    fn type_name() -> String {
687        "box3d".to_owned()
688    }
689}
690
691impl Box3d {
692    pub fn new(cx: f32, cy: f32, cz: f32, width: f32, height: f32, length: f32) -> Self {
693        Self {
694            x: cx,
695            y: cy,
696            z: cz,
697            w: width,
698            h: height,
699            l: length,
700        }
701    }
702
703    pub fn width(&self) -> f32 {
704        self.w
705    }
706
707    pub fn height(&self) -> f32 {
708        self.h
709    }
710
711    pub fn length(&self) -> f32 {
712        self.l
713    }
714
715    pub fn cx(&self) -> f32 {
716        self.x
717    }
718
719    pub fn cy(&self) -> f32 {
720        self.y
721    }
722
723    pub fn cz(&self) -> f32 {
724        self.z
725    }
726
727    pub fn left(&self) -> f32 {
728        self.x - self.w / 2.0
729    }
730
731    pub fn top(&self) -> f32 {
732        self.y - self.h / 2.0
733    }
734
735    pub fn front(&self) -> f32 {
736        self.z - self.l / 2.0
737    }
738}
739
740#[derive(Serialize, Deserialize, Clone, Debug)]
741pub struct Box2d {
742    h: f32,
743    w: f32,
744    x: f32,
745    y: f32,
746}
747
748impl TypeName for Box2d {
749    fn type_name() -> String {
750        "box2d".to_owned()
751    }
752}
753
754impl Box2d {
755    pub fn new(left: f32, top: f32, width: f32, height: f32) -> Self {
756        Self {
757            x: left,
758            y: top,
759            w: width,
760            h: height,
761        }
762    }
763
764    pub fn width(&self) -> f32 {
765        self.w
766    }
767
768    pub fn height(&self) -> f32 {
769        self.h
770    }
771
772    pub fn left(&self) -> f32 {
773        self.x
774    }
775
776    pub fn top(&self) -> f32 {
777        self.y
778    }
779
780    pub fn cx(&self) -> f32 {
781        self.x + self.w / 2.0
782    }
783
784    pub fn cy(&self) -> f32 {
785        self.y + self.h / 2.0
786    }
787}
788
789#[derive(Serialize, Deserialize, Clone, Debug)]
790pub struct Mask {
791    pub polygon: Vec<Vec<(f32, f32)>>,
792}
793
794impl TypeName for Mask {
795    fn type_name() -> String {
796        "mask".to_owned()
797    }
798}
799
800impl Mask {
801    pub fn new(polygon: Vec<Vec<(f32, f32)>>) -> Self {
802        Self { polygon }
803    }
804}
805
806#[derive(Deserialize, Clone, Debug)]
807#[serde(from = "AnnotationHelper")]
808pub struct Annotation {
809    #[serde(skip_serializing_if = "Option::is_none")]
810    sample_id: Option<SampleID>,
811    #[serde(skip_serializing_if = "Option::is_none")]
812    name: Option<String>,
813    #[serde(skip_serializing_if = "Option::is_none")]
814    sequence_name: Option<String>,
815    #[serde(skip_serializing_if = "Option::is_none")]
816    group: Option<String>,
817    #[serde(rename = "object_reference", skip_serializing_if = "Option::is_none")]
818    object_id: Option<String>,
819    #[serde(rename = "label_name", skip_serializing_if = "Option::is_none")]
820    label: Option<String>,
821    #[serde(skip_serializing_if = "Option::is_none")]
822    label_index: Option<u64>,
823    #[serde(skip_serializing_if = "Option::is_none")]
824    box2d: Option<Box2d>,
825    #[serde(skip_serializing_if = "Option::is_none")]
826    box3d: Option<Box3d>,
827    #[serde(skip_serializing_if = "Option::is_none")]
828    mask: Option<Mask>,
829}
830
831// Helper struct for deserialization that matches the nested format
832#[derive(Deserialize)]
833struct AnnotationHelper {
834    #[serde(skip_serializing_if = "Option::is_none")]
835    sample_id: Option<SampleID>,
836    #[serde(skip_serializing_if = "Option::is_none")]
837    name: Option<String>,
838    #[serde(skip_serializing_if = "Option::is_none")]
839    sequence_name: Option<String>,
840    #[serde(skip_serializing_if = "Option::is_none")]
841    group: Option<String>,
842    #[serde(rename = "object_reference", skip_serializing_if = "Option::is_none")]
843    object_id: Option<String>,
844    #[serde(rename = "label_name", skip_serializing_if = "Option::is_none")]
845    label: Option<String>,
846    #[serde(skip_serializing_if = "Option::is_none")]
847    label_index: Option<u64>,
848    #[serde(skip_serializing_if = "Option::is_none")]
849    box2d: Option<Box2d>,
850    #[serde(skip_serializing_if = "Option::is_none")]
851    box3d: Option<Box3d>,
852    #[serde(skip_serializing_if = "Option::is_none")]
853    mask: Option<Mask>,
854}
855
856impl From<AnnotationHelper> for Annotation {
857    fn from(helper: AnnotationHelper) -> Self {
858        Self {
859            sample_id: helper.sample_id,
860            name: helper.name,
861            sequence_name: helper.sequence_name,
862            group: helper.group,
863            object_id: helper.object_id,
864            label: helper.label,
865            label_index: helper.label_index,
866            box2d: helper.box2d,
867            box3d: helper.box3d,
868            mask: helper.mask,
869        }
870    }
871}
872
873// Custom serializer that flattens box2d/box3d fields
874impl Serialize for Annotation {
875    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
876    where
877        S: serde::Serializer,
878    {
879        use serde::ser::SerializeMap;
880
881        let mut map = serializer.serialize_map(None)?;
882
883        if let Some(ref sample_id) = self.sample_id {
884            map.serialize_entry("sample_id", sample_id)?;
885        }
886        if let Some(ref name) = self.name {
887            map.serialize_entry("name", name)?;
888        }
889        if let Some(ref sequence_name) = self.sequence_name {
890            map.serialize_entry("sequence_name", sequence_name)?;
891        }
892        if let Some(ref group) = self.group {
893            map.serialize_entry("group", group)?;
894        }
895        if let Some(ref object_id) = self.object_id {
896            map.serialize_entry("object_reference", object_id)?;
897        }
898        if let Some(ref label) = self.label {
899            map.serialize_entry("label_name", label)?;
900        }
901        if let Some(label_index) = self.label_index {
902            map.serialize_entry("label_index", &label_index)?;
903        }
904
905        // Flatten box2d fields
906        if let Some(ref box2d) = self.box2d {
907            map.serialize_entry("x", &box2d.x)?;
908            map.serialize_entry("y", &box2d.y)?;
909            map.serialize_entry("w", &box2d.w)?;
910            map.serialize_entry("h", &box2d.h)?;
911        }
912
913        // Flatten box3d fields
914        if let Some(ref box3d) = self.box3d {
915            map.serialize_entry("x", &box3d.x)?;
916            map.serialize_entry("y", &box3d.y)?;
917            map.serialize_entry("z", &box3d.z)?;
918            map.serialize_entry("w", &box3d.w)?;
919            map.serialize_entry("h", &box3d.h)?;
920            map.serialize_entry("l", &box3d.l)?;
921        }
922
923        if let Some(ref mask) = self.mask {
924            map.serialize_entry("mask", mask)?;
925        }
926
927        map.end()
928    }
929}
930
931impl Default for Annotation {
932    fn default() -> Self {
933        Self::new()
934    }
935}
936
937impl Annotation {
938    pub fn new() -> Self {
939        Self {
940            sample_id: None,
941            name: None,
942            sequence_name: None,
943            group: None,
944            object_id: None,
945            label: None,
946            label_index: None,
947            box2d: None,
948            box3d: None,
949            mask: None,
950        }
951    }
952
953    pub fn set_sample_id(&mut self, sample_id: Option<SampleID>) {
954        self.sample_id = sample_id;
955    }
956
957    pub fn sample_id(&self) -> Option<SampleID> {
958        self.sample_id
959    }
960
961    pub fn set_name(&mut self, name: Option<String>) {
962        self.name = name;
963    }
964
965    pub fn name(&self) -> Option<&String> {
966        self.name.as_ref()
967    }
968
969    pub fn set_sequence_name(&mut self, sequence_name: Option<String>) {
970        self.sequence_name = sequence_name;
971    }
972
973    pub fn sequence_name(&self) -> Option<&String> {
974        self.sequence_name.as_ref()
975    }
976
977    pub fn set_group(&mut self, group: Option<String>) {
978        self.group = group;
979    }
980
981    pub fn group(&self) -> Option<&String> {
982        self.group.as_ref()
983    }
984
985    pub fn object_id(&self) -> Option<&String> {
986        self.object_id.as_ref()
987    }
988
989    pub fn set_object_id(&mut self, object_id: Option<String>) {
990        self.object_id = object_id;
991    }
992
993    pub fn label(&self) -> Option<&String> {
994        self.label.as_ref()
995    }
996
997    pub fn set_label(&mut self, label: Option<String>) {
998        self.label = label;
999    }
1000
1001    pub fn label_index(&self) -> Option<u64> {
1002        self.label_index
1003    }
1004
1005    pub fn set_label_index(&mut self, label_index: Option<u64>) {
1006        self.label_index = label_index;
1007    }
1008
1009    pub fn box2d(&self) -> Option<&Box2d> {
1010        self.box2d.as_ref()
1011    }
1012
1013    pub fn set_box2d(&mut self, box2d: Option<Box2d>) {
1014        self.box2d = box2d;
1015    }
1016
1017    pub fn box3d(&self) -> Option<&Box3d> {
1018        self.box3d.as_ref()
1019    }
1020
1021    pub fn set_box3d(&mut self, box3d: Option<Box3d>) {
1022        self.box3d = box3d;
1023    }
1024
1025    pub fn mask(&self) -> Option<&Mask> {
1026        self.mask.as_ref()
1027    }
1028
1029    pub fn set_mask(&mut self, mask: Option<Mask>) {
1030        self.mask = mask;
1031    }
1032}
1033
1034#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
1035pub struct Label {
1036    id: u64,
1037    dataset_id: DatasetID,
1038    index: u64,
1039    name: String,
1040}
1041
1042impl Label {
1043    pub fn id(&self) -> u64 {
1044        self.id
1045    }
1046
1047    pub fn dataset_id(&self) -> DatasetID {
1048        self.dataset_id
1049    }
1050
1051    pub fn index(&self) -> u64 {
1052        self.index
1053    }
1054
1055    pub fn name(&self) -> &str {
1056        &self.name
1057    }
1058
1059    pub async fn remove(&self, client: &Client) -> Result<(), Error> {
1060        client.remove_label(self.id()).await
1061    }
1062
1063    pub async fn set_name(&mut self, client: &Client, name: &str) -> Result<(), Error> {
1064        self.name = name.to_string();
1065        client.update_label(self).await
1066    }
1067
1068    pub async fn set_index(&mut self, client: &Client, index: u64) -> Result<(), Error> {
1069        self.index = index;
1070        client.update_label(self).await
1071    }
1072}
1073
1074impl Display for Label {
1075    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1076        write!(f, "{}", self.name())
1077    }
1078}
1079
/// Serializable entry holding the name of a single label.
///
/// This is the element type of `NewLabel::labels`.
#[derive(Serialize, Clone, Debug)]
pub struct NewLabelObject {
    /// Name of the label.
    pub name: String,
}
1084
/// Serializable payload pairing a dataset with a list of label entries.
// NOTE(review): presumably the request body for creating labels on a dataset —
// confirm against the client call site.
#[derive(Serialize, Clone, Debug)]
pub struct NewLabel {
    /// Identifier of the dataset the labels refer to.
    pub dataset_id: DatasetID,
    /// Label entries, one per label name.
    pub labels: Vec<NewLabelObject>,
}
1090
/// Deserializable group record consisting of a numeric id and a name.
#[derive(Deserialize, Clone, Debug)]
pub struct Group {
    /// Numeric identifier of the group.
    pub id: u64, // Groups seem to use raw u64, not a specific ID type
    /// Name of the group.
    pub name: String,
}
1096
/// Flattens a slice of annotations into a Polars [`DataFrame`], one row per
/// annotation.
///
/// Columns, in order:
///
/// | column        | content                                                        |
/// |---------------|----------------------------------------------------------------|
/// | `name`        | file stem of the annotation name, or the sequence name         |
/// | `frame`       | remainder of the stem after the sequence name (sequences only) |
/// | `object_id`   | copy of the annotation's object id                             |
/// | `label`       | copy of the label, cast to a categorical                       |
/// | `label_index` | copy of the label index                                        |
/// | `group`       | copy of the group, cast to a categorical                       |
/// | `mask`        | `x, y` pairs of every polygon, polygons separated by `NaN`     |
/// | `box2d`       | `[cx, cy, width, height]`, cast to a 4-element `Float32` array |
/// | `box3d`       | `[x, y, z, w, h, l]`, cast to a 6-element `Float32` array      |
///
/// When an annotation has a sequence name, the stem must start with it; the
/// rest of the stem (with leading `_` removed) becomes `frame` and the sequence
/// name becomes `name`.
///
/// Annotations that cannot be resolved (no name, a name without a file stem,
/// or a stem that does not start with the sequence name) are logged with
/// `warn!` and still produce a row: an empty `name` and nulls in every other
/// column. The frame therefore always has `annotations.len()` rows.
///
/// # Panics
///
/// Panics if any of the Polars casts or the final `DataFrame::new` fails.
#[cfg(feature = "polars")]
pub fn annotations_dataframe(annotations: &[Annotation]) -> DataFrame {
    use itertools::Itertools;
    use log::warn;
    use std::path::Path;

    // One row of the frame, in column order.
    type Row = (
        String,         // name
        Option<String>, // frame
        Option<String>, // object_id
        Option<String>, // label
        Option<u64>,    // label_index
        Option<String>, // group
        Option<Series>, // mask
        Option<Series>, // box2d
        Option<Series>, // box3d
    );

    // Row emitted for annotations that cannot be resolved: empty name, nulls elsewhere.
    fn placeholder_row() -> Row {
        (String::new(), None, None, None, None, None, None, None, None)
    }

    let (names, frames, objects, labels, label_indices, groups, masks, boxes2d, boxes3d) =
        annotations
            .iter()
            .map(|ann| -> Row {
                let Some(image_name) = ann.name.as_deref() else {
                    warn!("annotation missing image name, skipping");
                    return placeholder_row();
                };

                // `file_stem` is `None` for names such as "" or "..": emit a
                // placeholder row instead of panicking on such data.
                let Some(stem) = Path::new(image_name)
                    .file_stem()
                    .and_then(|stem| stem.to_str())
                else {
                    warn!("image_name {} has no file stem", image_name);
                    return placeholder_row();
                };

                let (name, frame) = match &ann.sequence_name {
                    Some(sequence) => match stem.strip_prefix(sequence.as_str()) {
                        Some(frame) => (
                            sequence.to_string(),
                            Some(frame.trim_start_matches('_').to_string()),
                        ),
                        None => {
                            warn!(
                                "image_name {} does not match sequence_name {}",
                                stem, sequence
                            );
                            return placeholder_row();
                        }
                    },
                    None => (stem.to_string(), None),
                };

                let masks = ann.mask.as_ref().map(|seg| {
                    let mut coords = Vec::new();
                    for polygon in &seg.polygon {
                        for &(x, y) in polygon {
                            coords.push(x);
                            coords.push(y);
                        }
                        // Separate polygons with NaN
                        coords.push(f32::NAN);
                    }
                    // Drop the trailing separator; no-op when there were no polygons.
                    coords.pop();

                    Series::new("mask".into(), coords)
                });

                let box2d = ann.box2d.as_ref().map(|box2d| {
                    Series::new(
                        "box2d".into(),
                        [box2d.cx(), box2d.cy(), box2d.width(), box2d.height()],
                    )
                });

                let box3d = ann.box3d.as_ref().map(|box3d| {
                    Series::new(
                        "box3d".into(),
                        [box3d.x, box3d.y, box3d.z, box3d.w, box3d.h, box3d.l],
                    )
                });

                (
                    name,
                    frame,
                    ann.object_id.clone(),
                    ann.label.clone(),
                    ann.label_index,
                    ann.group.clone(),
                    masks,
                    box2d,
                    box3d,
                )
            })
            .multiunzip::<(
                Vec<_>, // names
                Vec<_>, // frames
                Vec<_>, // objects
                Vec<_>, // labels
                Vec<_>, // label_indices
                Vec<_>, // groups
                Vec<_>, // masks
                Vec<_>, // boxes2d
                Vec<_>, // boxes3d
            )>();
    let names = Series::new("name".into(), names).into();
    let frames = Series::new("frame".into(), frames).into();
    let objects = Series::new("object_id".into(), objects).into();
    // NOTE(review): the categorical uses a U8 physical type with a mapping sized
    // by `u8::MAX`; presumably this caps the number of distinct labels — confirm
    // what the cast does when a dataset exceeds it.
    let labels = Series::new("label".into(), labels)
        .cast(&DataType::Categorical(
            Categories::new("labels".into(), "labels".into(), CategoricalPhysical::U8),
            Arc::new(CategoricalMapping::new(u8::MAX as usize)),
        ))
        .unwrap()
        .into();
    let label_indices = Series::new("label_index".into(), label_indices).into();
    // NOTE(review): same U8 categorical sizing as `label` above.
    let groups = Series::new("group".into(), groups)
        .cast(&DataType::Categorical(
            Categories::new("groups".into(), "groups".into(), CategoricalPhysical::U8),
            Arc::new(CategoricalMapping::new(u8::MAX as usize)),
        ))
        .unwrap()
        .into();
    let masks = Series::new("mask".into(), masks)
        .cast(&DataType::List(Box::new(DataType::Float32)))
        .unwrap()
        .into();
    let boxes2d = Series::new("box2d".into(), boxes2d)
        .cast(&DataType::Array(Box::new(DataType::Float32), 4))
        .unwrap()
        .into();
    let boxes3d = Series::new("box3d".into(), boxes3d)
        .cast(&DataType::Array(Box::new(DataType::Float32), 6))
        .unwrap()
        .into();

    DataFrame::new(vec![
        names,
        frames,
        objects,
        labels,
        label_indices,
        groups,
        masks,
        boxes2d,
        boxes3d,
    ])
    .unwrap()
}