use std::{collections::HashMap, fmt::Display};
use crate::{
Client, Error,
api::{AnnotationSetID, DatasetID, ProjectID, SampleID},
mask::MaskData,
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
#[cfg(feature = "polars")]
use polars::prelude::*;
/// Kind of sensor file that can be attached to a [`Sample`].
///
/// `All` is a wildcard rather than a real file kind: `FileType::expand_types`
/// replaces it with every concrete sensor type.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum FileType {
    /// Camera image (`"image"`).
    Image,
    /// LiDAR point cloud (`"lidar.pcd"`).
    LidarPcd,
    /// LiDAR depth image (`"lidar.depth"`, also parsed from `"lidar.png"`).
    LidarDepth,
    /// LiDAR reflectance image (`"lidar.reflect"`, also parsed from `"lidar.jpg"`).
    LidarReflect,
    /// Radar point cloud (`"radar.pcd"`).
    RadarPcd,
    /// Radar cube (`"radar.png"`, also parsed from `"cube"`).
    RadarCube,
    /// Wildcard meaning "every sensor type".
    All,
}
impl std::fmt::Display for FileType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let value = match self {
FileType::Image => "image",
FileType::LidarPcd => "lidar.pcd",
FileType::LidarDepth => "lidar.depth",
FileType::LidarReflect => "lidar.reflect",
FileType::RadarPcd => "radar.pcd",
FileType::RadarCube => "radar.png",
FileType::All => "all",
};
write!(f, "{}", value)
}
}
impl FileType {
pub fn file_extension(&self) -> &'static str {
match self {
FileType::Image => "jpg", FileType::LidarPcd => "lidar.pcd",
FileType::LidarDepth => "lidar.png",
FileType::LidarReflect => "lidar.jpg",
FileType::RadarPcd => "radar.pcd",
FileType::RadarCube => "radar.png",
FileType::All => "",
}
}
}
impl TryFrom<&str> for FileType {
type Error = crate::Error;
fn try_from(s: &str) -> Result<Self, Self::Error> {
match s {
"image" => Ok(FileType::Image),
"lidar.pcd" => Ok(FileType::LidarPcd),
"lidar.png" | "lidar.depth" | "depth.png" | "depthmap" => Ok(FileType::LidarDepth),
"lidar.jpg" | "lidar.jpeg" | "lidar.reflect" => Ok(FileType::LidarReflect),
"radar.pcd" | "pcd" => Ok(FileType::RadarPcd),
"radar.png" | "cube" => Ok(FileType::RadarCube),
"all" => Ok(FileType::All),
_ => Err(crate::Error::InvalidFileType(s.to_string())),
}
}
}
impl std::str::FromStr for FileType {
    type Err = crate::Error;

    /// Delegates to `TryFrom<&str>` so `str::parse` accepts exactly the same names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::try_from(s)
    }
}
impl FileType {
pub fn all_sensor_types() -> Vec<FileType> {
vec![
FileType::Image,
FileType::LidarPcd,
FileType::LidarDepth,
FileType::LidarReflect,
FileType::RadarPcd,
FileType::RadarCube,
]
}
pub fn type_names() -> Vec<&'static str> {
vec![
"image",
"lidar.pcd",
"lidar.png",
"lidar.jpg",
"radar.pcd",
"radar.png",
"all",
]
}
pub fn expand_types(types: &[FileType]) -> Vec<FileType> {
if types.contains(&FileType::All) {
FileType::all_sensor_types()
} else {
types.to_vec()
}
}
}
/// Geometric kind of an annotation.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum AnnotationType {
    /// Axis-aligned 2D bounding box.
    Box2d,
    /// 3D bounding box.
    Box3d,
    /// Polygon segmentation. Parsed from "polygon", "seg" and also "mask".
    Polygon,
    /// Raster mask. Parsed only from "raster".
    // NOTE(review): Display writes "mask" for this variant, and "mask" parses
    // back as `Polygon`. Confirm this asymmetry is intended.
    Mask,
}
impl TryFrom<&str> for AnnotationType {
type Error = crate::Error;
fn try_from(s: &str) -> Result<Self, Self::Error> {
match s {
"box2d" => Ok(AnnotationType::Box2d),
"box3d" => Ok(AnnotationType::Box3d),
"polygon" => Ok(AnnotationType::Polygon),
"seg" => Ok(AnnotationType::Polygon),
"mask" => Ok(AnnotationType::Polygon), "raster" => Ok(AnnotationType::Mask),
_ => Err(crate::Error::InvalidAnnotationType(s.to_string())),
}
}
}
impl From<String> for AnnotationType {
fn from(s: String) -> Self {
s.as_str().try_into().unwrap_or(AnnotationType::Box2d)
}
}
impl From<&String> for AnnotationType {
fn from(s: &String) -> Self {
s.as_str().try_into().unwrap_or(AnnotationType::Box2d)
}
}
impl AnnotationType {
pub fn as_server_type(&self) -> &'static str {
match self {
AnnotationType::Box2d => "box",
AnnotationType::Box3d => "box3d",
AnnotationType::Polygon => "seg",
AnnotationType::Mask => "seg",
}
}
}
impl std::fmt::Display for AnnotationType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let value = match self {
AnnotationType::Box2d => "box2d",
AnnotationType::Box3d => "box3d",
AnnotationType::Polygon => "polygon",
AnnotationType::Mask => "mask",
};
write!(f, "{}", value)
}
}
/// A dataset record deserialized from the API, with read-only accessors and
/// convenience methods that forward to a [`Client`].
#[derive(Deserialize, Clone, Debug)]
pub struct Dataset {
    /// Dataset identifier.
    id: DatasetID,
    /// Project the dataset belongs to.
    project_id: ProjectID,
    /// Display name.
    name: String,
    /// Free-form description.
    description: String,
    /// Cloud key string. Its meaning is not established in this module.
    cloud_key: String,
    /// Creation timestamp. It is read from the JSON key `createdAt`.
    #[serde(rename = "createdAt")]
    created: DateTime<Utc>,
}
impl Display for Dataset {
    /// Formats the dataset as `"<id> <name>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let Self { id, name, .. } = self;
        write!(f, "{id} {name}")
    }
}
impl Dataset {
pub fn id(&self) -> DatasetID {
self.id
}
pub fn project_id(&self) -> ProjectID {
self.project_id
}
pub fn name(&self) -> &str {
&self.name
}
pub fn description(&self) -> &str {
&self.description
}
pub fn cloud_key(&self) -> &str {
&self.cloud_key
}
pub fn created(&self) -> &DateTime<Utc> {
&self.created
}
pub async fn project(&self, client: &Client) -> Result<crate::api::Project, Error> {
client.project(self.project_id).await
}
pub async fn annotation_sets(&self, client: &Client) -> Result<Vec<AnnotationSet>, Error> {
client.annotation_sets(self.id).await
}
pub async fn labels(&self, client: &Client) -> Result<Vec<Label>, Error> {
client.labels(self.id).await
}
pub async fn add_label(&self, client: &Client, name: &str) -> Result<(), Error> {
client.add_label(self.id, name).await
}
pub async fn remove_label(&self, client: &Client, name: &str) -> Result<(), Error> {
let labels = self.labels(client).await?;
let label = labels
.iter()
.find(|l| l.name() == name)
.ok_or_else(|| Error::MissingLabel(name.to_string()))?;
client.remove_label(label.id()).await
}
}
/// An annotation set belonging to a dataset, as deserialized from the API.
// NOTE(review): unlike `Dataset`, this derives only `Deserialize` (no `Clone`/`Debug`).
#[derive(Deserialize)]
pub struct AnnotationSet {
    /// Annotation-set identifier.
    id: AnnotationSetID,
    /// Dataset the set belongs to.
    dataset_id: DatasetID,
    /// Display name.
    name: String,
    /// Free-form description.
    description: String,
    /// Creation timestamp. It is read from the JSON key `date`.
    #[serde(rename = "date")]
    created: DateTime<Utc>,
}
impl Display for AnnotationSet {
    /// Formats the annotation set as `"<id> <name>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let Self { id, name, .. } = self;
        write!(f, "{id} {name}")
    }
}
impl AnnotationSet {
    /// Annotation-set identifier.
    pub fn id(&self) -> AnnotationSetID {
        self.id
    }

    /// Identifier of the owning dataset.
    pub fn dataset_id(&self) -> DatasetID {
        self.dataset_id
    }

    /// Display name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Free-form description.
    pub fn description(&self) -> &str {
        self.description.as_str()
    }

    /// Creation timestamp, returned by value.
    pub fn created(&self) -> DateTime<Utc> {
        self.created
    }

    /// Fetches the owning dataset through `client`.
    pub async fn dataset(&self, client: &Client) -> Result<Dataset, Error> {
        let dataset_id = self.dataset_id;
        client.dataset(dataset_id).await
    }
}
/// Optional per-stage timing values that can be attached to a [`Sample`].
///
/// `Sample` stores this under `#[serde(skip)]`, so it is never serialized.
// NOTE(review): the unit of these i64 values is not established here
// (presumably a duration such as nanoseconds). TODO confirm.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Timing {
    /// Load stage.
    pub load: Option<i64>,
    /// Preprocessing stage.
    pub preprocess: Option<i64>,
    /// Inference stage.
    pub inference: Option<i64>,
    /// Decode stage.
    pub decode: Option<i64>,
}
/// A dataset sample: identifying metadata, sensor files and annotations.
///
/// Only `Serialize` is derived. Deserialization is hand-written through
/// `SampleRaw`, so deserialize-side serde attributes on these fields
/// (`alias`, `rename(deserialize = ..)`, `default`, `deserialize_with`) do not
/// affect how a `Sample` is parsed.
#[derive(Serialize, Clone, Debug)]
pub struct Sample {
    /// Sample identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<SampleID>,
    /// Group name. It is serialized under the key `"group"`.
    #[serde(
        alias = "group_name",
        rename(serialize = "group", deserialize = "group_name"),
        skip_serializing_if = "Option::is_none"
    )]
    pub group: Option<String>,
    /// Name of the sequence this sample belongs to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_name: Option<String>,
    /// UUID of the sequence, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_uuid: Option<String>,
    /// Sequence description, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_description: Option<String>,
    /// Frame index within the sequence.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_frame_number"
    )]
    pub frame_number: Option<u32>,
    /// Sample UUID.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uuid: Option<String>,
    /// Image file name. `Sample::name` derives the sample name from it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_name: Option<String>,
    /// Image URL used by `Sample::download(FileType::Image)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<String>,
    /// Image width.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub width: Option<u32>,
    /// Image height.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub height: Option<u32>,
    /// Capture/record date.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub date: Option<DateTime<Utc>>,
    /// Source string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    /// GPS/IMU data. It is serialized under the key `"sensors"`.
    #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "sensors"))]
    pub location: Option<Location>,
    /// Degradation string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub degradation: Option<String>,
    /// Indices of labels marked negative for this sample.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub neg_label_indices: Option<Vec<u32>>,
    /// Indices of labels marked not exhaustively annotated for this sample.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub not_exhaustive_label_indices: Option<Vec<u32>>,
    /// Attached sensor files. They are serialized by `serialize_files` as a
    /// type-to-filename map and omitted when empty.
    #[serde(
        default,
        skip_serializing_if = "Vec::is_empty",
        serialize_with = "serialize_files"
    )]
    pub files: Vec<SampleFile>,
    /// Annotations. They are omitted from the output when empty.
    #[serde(
        default,
        skip_serializing_if = "Vec::is_empty",
        serialize_with = "serialize_annotations"
    )]
    pub annotations: Vec<Annotation>,
    /// Local timing data. It is never serialized.
    #[serde(skip)]
    pub timing: Option<Timing>,
}
fn deserialize_frame_number<'de, D>(deserializer: D) -> Result<Option<u32>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
let value = Option::<i32>::deserialize(deserializer)?;
Ok(value.and_then(|v| if v < 0 { None } else { Some(v as u32) }))
}
/// Returns `true` when `s` starts with an `http://` or `https://` scheme.
///
/// URI schemes are case-insensitive (RFC 3986 §3.1), so `HTTPS://…` is accepted
/// too. `str::get` keeps the prefix slice safe for non-ASCII input, because it
/// returns `None` instead of panicking on a non-char boundary.
fn is_valid_url(s: &str) -> bool {
    ["http://", "https://"].iter().any(|scheme| {
        s.get(..scheme.len())
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case(scheme))
    })
}
/// Serializes sample files as a `{ file_type: filename }` map.
///
/// Only files that carry a filename are included, so URL- or data-backed files
/// are skipped. Keys are emitted in sorted order, which makes the serialized
/// payload identical across runs. If two files share a type, the later one wins.
fn serialize_files<S>(files: &[SampleFile], serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    use serde::Serialize;
    use std::collections::BTreeMap;

    // BTreeMap instead of HashMap gives a stable key order. Borrowing the
    // strings avoids copying every key and value.
    let map: BTreeMap<&str, &str> = files
        .iter()
        .filter_map(|f| f.filename().map(|filename| (f.file_type(), filename)))
        .collect();
    map.serialize(serializer)
}
fn serialize_annotations<S>(annotations: &Vec<Annotation>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serde::Serialize::serialize(annotations, serializer)
}
/// Deserializes annotations given as a list, as a map of lists (flattened by
/// `convert_annotations_map_to_vec`), or as `null` (yielding an empty vector).
fn deserialize_annotations<'de, D>(deserializer: D) -> Result<Vec<Annotation>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::Deserialize;

    // Untagged: variants are tried in declaration order, so the list form wins
    // whenever it matches. The enum name appears in serde's error message.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum AnnotationsFormat {
        Vec(Vec<Annotation>),
        Map(HashMap<String, Vec<Annotation>>),
    }

    let annotations = match Option::<AnnotationsFormat>::deserialize(deserializer)? {
        Some(AnnotationsFormat::Vec(list)) => list,
        Some(AnnotationsFormat::Map(grouped)) => convert_annotations_map_to_vec(grouped),
        None => Vec::new(),
    };
    Ok(annotations)
}
/// Result of splitting a sample's `sensors` JSON: file entries plus an optional
/// GPS/IMU location.
#[derive(Debug, Default)]
struct SensorsData {
    /// Sensor files extracted from the JSON.
    files: Vec<SampleFile>,
    /// GPS/IMU data, when any entry provided it.
    location: Option<Location>,
}
/// Splits the loosely-typed `sensors` JSON of a sample into sensor files and an
/// optional GPS/IMU [`Location`].
///
/// Accepted shapes. Any other JSON value yields an empty `SensorsData`.
/// * `None`: no sensors at all.
/// * An array. Each object item is handled as exactly one of:
///   - a full `SampleFile` when it has a `"type"` key (items that fail to
///     parse are dropped);
///   - a location update when `"gps"` and/or `"imu"` parse successfully. Each
///     component present overwrites the previously seen one, and the item's
///     other keys are ignored;
///   - otherwise a `file_type -> value` map whose entries become files.
///
///   Non-object array items are ignored.
/// * A single object. `"gps"`/`"imu"` populate the location (only when they
///   parse), and every other key becomes a file. The `"gps"`/`"imu"` keys are
///   never turned into files in this shape, even if they fail to parse.
fn deserialize_sensors_data(value: Option<serde_json::Value>) -> SensorsData {
    use serde_json::Value;
    // String values that look like http(s) URLs become URL files. Any other
    // string is kept as inline data.
    fn create_sample_file(file_type: String, value: String) -> SampleFile {
        if is_valid_url(&value) {
            SampleFile::with_url(file_type, value)
        } else {
            SampleFile::with_data(file_type, value)
        }
    }
    // Converts one JSON entry into a file. Objects and arrays are re-encoded as
    // JSON text and stored as inline data. Numbers, bools and null are dropped.
    fn create_sample_file_from_value(file_type: String, value: Value) -> Option<SampleFile> {
        match value {
            Value::String(s) => Some(create_sample_file(file_type, s)),
            Value::Object(_) | Value::Array(_) => {
                serde_json::to_string(&value)
                    .ok()
                    .map(|data| SampleFile::with_data(file_type, data))
            }
            _ => None,
        }
    }
    // Returns a Location only when at least one of "gps"/"imu" parses.
    // Unparseable entries are treated as absent.
    fn extract_location(map: &serde_json::Map<String, Value>) -> Option<Location> {
        let gps = map
            .get("gps")
            .and_then(|v| serde_json::from_value::<GpsData>(v.clone()).ok());
        let imu = map
            .get("imu")
            .and_then(|v| serde_json::from_value::<ImuData>(v.clone()).ok());
        if gps.is_some() || imu.is_some() {
            Some(Location { gps, imu })
        } else {
            None
        }
    }
    let mut result = SensorsData::default();
    match value {
        None => result,
        Some(Value::Array(arr)) => {
            for item in arr {
                if let Value::Object(map) = item {
                    if map.contains_key("type") {
                        // NOTE(review): `map` is owned here, so this clone looks avoidable.
                        if let Ok(file) =
                            serde_json::from_value::<SampleFile>(Value::Object(map.clone()))
                        {
                            result.files.push(file);
                        }
                    } else {
                        if let Some(loc) = extract_location(&map) {
                            // Merge per component so GPS and IMU may arrive in separate items.
                            if let Some(ref mut existing) = result.location {
                                if loc.gps.is_some() {
                                    existing.gps = loc.gps;
                                }
                                if loc.imu.is_some() {
                                    existing.imu = loc.imu;
                                }
                            } else {
                                result.location = Some(loc);
                            }
                        } else {
                            for (file_type, value) in map {
                                if let Some(file) = create_sample_file_from_value(file_type, value)
                                {
                                    result.files.push(file);
                                }
                            }
                        }
                    }
                }
            }
            result
        }
        Some(Value::Object(map)) => {
            if let Some(loc) = extract_location(&map) {
                result.location = Some(loc);
            }
            for (key, value) in map {
                if key != "gps"
                    && key != "imu"
                    && let Some(file) = create_sample_file_from_value(key, value)
                {
                    result.files.push(file);
                }
            }
            result
        }
        Some(_) => result,
    }
}
/// Intermediate deserialization target for [`Sample`].
///
/// `From<SampleRaw> for Sample` splits the free-form `sensors` value into files
/// and location.
#[derive(Deserialize)]
struct SampleRaw {
    #[serde(default)]
    id: Option<SampleID>,
    /// Accepts either `group` or `group_name`.
    #[serde(alias = "group_name")]
    group: Option<String>,
    sequence_name: Option<String>,
    sequence_uuid: Option<String>,
    sequence_description: Option<String>,
    /// Negative values become `None` (see `deserialize_frame_number`).
    #[serde(default, deserialize_with = "deserialize_frame_number")]
    frame_number: Option<u32>,
    uuid: Option<String>,
    image_name: Option<String>,
    image_url: Option<String>,
    width: Option<u32>,
    height: Option<u32>,
    date: Option<DateTime<Utc>>,
    source: Option<String>,
    degradation: Option<String>,
    #[serde(default)]
    neg_label_indices: Option<Vec<u32>>,
    #[serde(default)]
    not_exhaustive_label_indices: Option<Vec<u32>>,
    /// Raw sensors JSON. Its shape varies (array or object), so it is kept untyped.
    // NOTE(review): the alias equals the field's own name and has no effect.
    #[serde(default, alias = "sensors")]
    sensors: Option<serde_json::Value>,
    /// Accepts a list, a map of lists, or null (see `deserialize_annotations`).
    #[serde(default, deserialize_with = "deserialize_annotations")]
    annotations: Vec<Annotation>,
}
impl From<SampleRaw> for Sample {
fn from(raw: SampleRaw) -> Self {
let sensors_data = deserialize_sensors_data(raw.sensors);
Sample {
id: raw.id,
group: raw.group,
sequence_name: raw.sequence_name,
sequence_uuid: raw.sequence_uuid,
sequence_description: raw.sequence_description,
frame_number: raw.frame_number,
uuid: raw.uuid,
image_name: raw.image_name,
image_url: raw.image_url,
width: raw.width,
height: raw.height,
date: raw.date,
source: raw.source,
location: sensors_data.location,
degradation: raw.degradation,
neg_label_indices: raw.neg_label_indices,
not_exhaustive_label_indices: raw.not_exhaustive_label_indices,
files: sensors_data.files,
annotations: raw.annotations,
timing: None,
}
}
}
impl<'de> serde::Deserialize<'de> for Sample {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = SampleRaw::deserialize(deserializer)?;
Ok(Sample::from(raw))
}
}
impl Display for Sample {
    /// Formats the sample as `"<id> <image_name>"`, using `unknown` for missing parts.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = self.image_name().unwrap_or("unknown");
        match self.id {
            Some(id) => write!(f, "{id} {name}"),
            None => write!(f, "unknown {name}"),
        }
    }
}
impl Default for Sample {
fn default() -> Self {
Self::new()
}
}
impl Sample {
    /// Creates an empty sample: every optional field is `None`, and there are
    /// no files, annotations or timing data.
    pub fn new() -> Self {
        Self {
            id: None,
            group: None,
            sequence_name: None,
            sequence_uuid: None,
            sequence_description: None,
            frame_number: None,
            uuid: None,
            image_name: None,
            image_url: None,
            width: None,
            height: None,
            date: None,
            source: None,
            location: None,
            degradation: None,
            neg_label_indices: None,
            not_exhaustive_label_indices: None,
            files: vec![],
            annotations: vec![],
            timing: None,
        }
    }

    /// Sample identifier, if known.
    pub fn id(&self) -> Option<SampleID> {
        self.id
    }

    /// Sample name derived from `image_name` by `extract_sample_name`.
    pub fn name(&self) -> Option<String> {
        self.image_name.as_ref().map(|n| extract_sample_name(n))
    }

    /// Group name, if any.
    pub fn group(&self) -> Option<&String> {
        self.group.as_ref()
    }

    /// Sequence name, if any.
    pub fn sequence_name(&self) -> Option<&String> {
        self.sequence_name.as_ref()
    }

    /// Sequence UUID, if any.
    pub fn sequence_uuid(&self) -> Option<&String> {
        self.sequence_uuid.as_ref()
    }

    /// Sequence description, if any.
    pub fn sequence_description(&self) -> Option<&String> {
        self.sequence_description.as_ref()
    }

    /// Frame number within the sequence, if any.
    pub fn frame_number(&self) -> Option<u32> {
        self.frame_number
    }

    /// Sample UUID, if any.
    pub fn uuid(&self) -> Option<&String> {
        self.uuid.as_ref()
    }

    /// Image file name, if any.
    pub fn image_name(&self) -> Option<&str> {
        self.image_name.as_deref()
    }

    /// Image URL, if any.
    pub fn image_url(&self) -> Option<&str> {
        self.image_url.as_deref()
    }

    /// Image width, if known.
    pub fn width(&self) -> Option<u32> {
        self.width
    }

    /// Image height, if known.
    pub fn height(&self) -> Option<u32> {
        self.height
    }

    /// Sample date, if known.
    pub fn date(&self) -> Option<DateTime<Utc>> {
        self.date
    }

    /// Source string, if any.
    pub fn source(&self) -> Option<&String> {
        self.source.as_ref()
    }

    /// GPS/IMU location, if any.
    pub fn location(&self) -> Option<&Location> {
        self.location.as_ref()
    }

    /// Attached sensor files.
    pub fn files(&self) -> &[SampleFile] {
        &self.files
    }

    /// Annotations of this sample.
    pub fn annotations(&self) -> &[Annotation] {
        &self.annotations
    }

    /// Builder-style setter that replaces all annotations.
    pub fn with_annotations(mut self, annotations: Vec<Annotation>) -> Self {
        self.annotations = annotations;
        self
    }

    /// Builder-style setter for the frame number.
    pub fn with_frame_number(mut self, frame_number: Option<u32>) -> Self {
        self.frame_number = frame_number;
        self
    }

    /// Fetches the raw bytes of one of this sample's files.
    ///
    /// * `FileType::Image` is downloaded from `image_url` when that is a valid
    ///   http(s) URL. Otherwise `Ok(None)` is returned.
    /// * Other types are located with `resolve_file`. A file with a URL is
    ///   downloaded. A file with inline `data` is decoded as follows:
    ///   - if `data` is valid base64 and the decoded payload is not UTF-8 text
    ///     starting with `{`, the decoded bytes are returned unchanged;
    ///   - otherwise the text (decoded, or `data` itself when it is not base64)
    ///     is used. If that text is a JSON object whose first value (in
    ///     `serde_json::Map` iteration order) is a string, that string is
    ///     returned; otherwise the whole text is returned.
    ///
    /// Returns `Ok(None)` when no matching file, URL or data exists.
    ///
    /// # Errors
    /// Propagates errors from `Client::download`.
    pub async fn download(
        &self,
        client: &Client,
        file_type: FileType,
    ) -> Result<Option<Vec<u8>>, Error> {
        use base64::{Engine, engine::general_purpose::STANDARD};

        if file_type == FileType::Image {
            if let Some(url) = self.image_url.as_deref()
                && is_valid_url(url)
            {
                return Ok(Some(client.download(url).await?));
            }
            return Ok(None);
        }

        let Some(file) = resolve_file(&file_type, &self.files) else {
            return Ok(None);
        };
        if let Some(url) = file.url() {
            return Ok(Some(client.download(url).await?));
        }
        let Some(data) = file.data() else {
            return Ok(None);
        };

        let decoded = match STANDARD.decode(data) {
            // `from_utf8` takes ownership, and on failure the original bytes are
            // recovered with `into_bytes`, so no defensive clone is needed.
            Ok(bytes) => match String::from_utf8(bytes) {
                Ok(text) if text.starts_with('{') => text,
                Ok(text) => return Ok(Some(text.into_bytes())),
                Err(err) => return Ok(Some(err.into_bytes())),
            },
            Err(_) => data.to_string(),
        };

        let content = if decoded.starts_with('{') {
            let parsed = serde_json::from_str::<serde_json::Value>(&decoded);
            match parsed {
                Ok(serde_json::Value::Object(obj)) => obj
                    .values()
                    .next()
                    .and_then(|v| v.as_str())
                    .map(String::from)
                    .unwrap_or(decoded),
                _ => decoded,
            }
        } else {
            decoded
        };
        Ok(Some(content.into_bytes()))
    }
}
/// A sensor file attached to a sample.
///
/// The file's content is referenced by URL, by filename, by inline `data`
/// text, or by raw `bytes`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SampleFile {
    /// File type name (see `FileType` for the known names).
    r#type: String,
    /// Remote location of the file.
    #[serde(skip_serializing_if = "Option::is_none")]
    url: Option<String>,
    /// File name.
    #[serde(skip_serializing_if = "Option::is_none")]
    filename: Option<String>,
    /// Inline text content. It is serialized but never read by `Deserialize`.
    #[serde(skip_serializing_if = "Option::is_none", skip_deserializing)]
    data: Option<String>,
    /// Raw in-memory bytes. They are never serialized or deserialized.
    #[serde(skip)]
    bytes: Option<Vec<u8>>,
}
impl SampleFile {
    /// A file of the given type with no content reference set.
    fn bare_of_type(file_type: String) -> Self {
        Self {
            r#type: file_type,
            url: None,
            filename: None,
            data: None,
            bytes: None,
        }
    }

    /// A file referenced by URL.
    pub fn with_url(file_type: String, url: String) -> Self {
        Self {
            url: Some(url),
            ..Self::bare_of_type(file_type)
        }
    }

    /// A file referenced by filename.
    pub fn with_filename(file_type: String, filename: String) -> Self {
        Self {
            filename: Some(filename),
            ..Self::bare_of_type(file_type)
        }
    }

    /// A file carrying inline text data.
    pub fn with_data(file_type: String, data: String) -> Self {
        Self {
            data: Some(data),
            ..Self::bare_of_type(file_type)
        }
    }

    /// A file carrying raw bytes together with its filename.
    pub fn with_bytes(file_type: String, filename: String, bytes: Vec<u8>) -> Self {
        Self {
            filename: Some(filename),
            bytes: Some(bytes),
            ..Self::bare_of_type(file_type)
        }
    }

    /// File type name.
    pub fn file_type(&self) -> &str {
        self.r#type.as_str()
    }

    /// URL, if set.
    pub fn url(&self) -> Option<&str> {
        self.url.as_deref()
    }

    /// Filename, if set.
    pub fn filename(&self) -> Option<&str> {
        self.filename.as_deref()
    }

    /// Inline data, if set.
    pub fn data(&self) -> Option<&str> {
        self.data.as_deref()
    }

    /// Raw bytes, if set.
    pub fn bytes(&self) -> Option<&[u8]> {
        self.bytes.as_deref()
    }
}
/// Optional GPS position and IMU orientation of a sample.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Location {
    /// GPS coordinates, omitted from output when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gps: Option<GpsData>,
    /// IMU orientation, omitted from output when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub imu: Option<ImuData>,
}
/// GPS coordinates.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct GpsData {
    /// Latitude.
    pub lat: f64,
    /// Longitude.
    pub lon: f64,
}
impl GpsData {
    /// Checks the coordinates with `validate_gps_coordinates`.
    ///
    /// # Errors
    /// Returns the validator's message when the coordinates are rejected.
    pub fn validate(&self) -> Result<(), String> {
        let Self { lat, lon } = *self;
        validate_gps_coordinates(lat, lon)
    }
}
/// IMU orientation angles.
// NOTE(review): units (degrees vs radians) are not established here. TODO confirm.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ImuData {
    /// Roll angle.
    pub roll: f64,
    /// Pitch angle.
    pub pitch: f64,
    /// Yaw angle.
    pub yaw: f64,
}
impl ImuData {
    /// Checks the angles with `validate_imu_orientation`.
    ///
    /// # Errors
    /// Returns the validator's message when the orientation is rejected.
    pub fn validate(&self) -> Result<(), String> {
        let Self { roll, pitch, yaw } = *self;
        validate_imu_orientation(roll, pitch, yaw)
    }
}
// NOTE(review): `dead_code` is allowed presumably because nothing calls
// `type_name()` in this module. Confirm it is still needed.
#[allow(dead_code)]
/// Associates a geometry type with its lowercase name
/// (`"box2d"`, `"box3d"`, `"polygon"`).
pub trait TypeName {
    /// Returns the type's name.
    fn type_name() -> String;
}
/// 3D box described by its centre and its extents along each axis.
///
/// Field names and order define the serialized form.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct Box3d {
    /// Centre x.
    x: f32,
    /// Centre y.
    y: f32,
    /// Centre z.
    z: f32,
    /// Extent along x (width).
    w: f32,
    /// Extent along y (height).
    h: f32,
    /// Extent along z (length).
    l: f32,
}
impl TypeName for Box3d {
    /// Returns `"box3d"`.
    fn type_name() -> String {
        String::from("box3d")
    }
}
impl Box3d {
    /// Builds a box from its centre `(cx, cy, cz)` and its extents.
    pub fn new(cx: f32, cy: f32, cz: f32, width: f32, height: f32, length: f32) -> Self {
        Self {
            l: length,
            h: height,
            w: width,
            z: cz,
            y: cy,
            x: cx,
        }
    }

    /// Extent along x.
    pub fn width(&self) -> f32 {
        self.w
    }

    /// Extent along y.
    pub fn height(&self) -> f32 {
        self.h
    }

    /// Extent along z.
    pub fn length(&self) -> f32 {
        self.l
    }

    /// Centre x.
    pub fn cx(&self) -> f32 {
        self.x
    }

    /// Centre y.
    pub fn cy(&self) -> f32 {
        self.y
    }

    /// Centre z.
    pub fn cz(&self) -> f32 {
        self.z
    }

    /// Minimum x: `cx - width / 2`.
    pub fn left(&self) -> f32 {
        self.cx() - self.width() / 2.0
    }

    /// Minimum y: `cy - height / 2`.
    pub fn top(&self) -> f32 {
        self.cy() - self.height() / 2.0
    }

    /// Minimum z: `cz - length / 2`.
    pub fn front(&self) -> f32 {
        self.cz() - self.length() / 2.0
    }
}
/// Axis-aligned 2D box stored as top-left corner plus size.
///
/// Field names and order define the serialized form.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct Box2d {
    /// Height.
    h: f32,
    /// Width.
    w: f32,
    /// Left edge.
    x: f32,
    /// Top edge.
    y: f32,
}
impl TypeName for Box2d {
    /// Returns `"box2d"`.
    fn type_name() -> String {
        String::from("box2d")
    }
}
impl Box2d {
    /// Builds a box from its top-left corner and its size.
    pub fn new(left: f32, top: f32, width: f32, height: f32) -> Self {
        Self {
            h: height,
            w: width,
            x: left,
            y: top,
        }
    }

    /// Box width.
    pub fn width(&self) -> f32 {
        self.w
    }

    /// Box height.
    pub fn height(&self) -> f32 {
        self.h
    }

    /// Left edge.
    pub fn left(&self) -> f32 {
        self.x
    }

    /// Top edge.
    pub fn top(&self) -> f32 {
        self.y
    }

    /// Horizontal centre: `left + width / 2`.
    pub fn cx(&self) -> f32 {
        self.left() + self.width() / 2.0
    }

    /// Vertical centre: `top + height / 2`.
    pub fn cy(&self) -> f32 {
        self.top() + self.height() / 2.0
    }
}
/// A polygon made of one or more rings of `(x, y)` points.
///
/// Serde support is hand-written: rings serialize as nested arrays, and
/// several input layouts are accepted when deserializing.
#[derive(Clone, Debug, PartialEq)]
pub struct Polygon {
    /// Rings of `(x, y)` points.
    pub rings: Vec<Vec<(f32, f32)>>,
}
impl TypeName for Polygon {
    /// Returns `"polygon"`.
    fn type_name() -> String {
        String::from("polygon")
    }
}
impl Polygon {
pub fn new(rings: Vec<Vec<(f32, f32)>>) -> Self {
Self { rings }
}
}
impl serde::Serialize for Polygon {
    /// Serializes only the rings, as a sequence of sequences of `(x, y)` tuples.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rings = &self.rings;
        serde::Serialize::serialize(rings, serializer)
    }
}
impl<'de> serde::Deserialize<'de> for Polygon {
    /// Accepts either a bare ring array, or an object holding it under `"rings"`
    /// (preferred) or `"polygon"`.
    ///
    /// Unusable input yields an empty polygon rather than an error (see
    /// `parse_polygon_value`).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let raw = serde_json::Value::deserialize(deserializer)?;
        let payload = match raw {
            serde_json::Value::Object(fields) => fields
                .get("rings")
                .or_else(|| fields.get("polygon"))
                .cloned()
                .unwrap_or(serde_json::Value::Null),
            other => other,
        };
        Ok(Self {
            rings: parse_polygon_value(&payload),
        })
    }
}
fn parse_polygon_value(value: &serde_json::Value) -> Vec<Vec<(f32, f32)>> {
let Some(outer_array) = value.as_array() else {
return vec![];
};
let mut result = Vec::new();
for ring in outer_array {
let Some(ring_array) = ring.as_array() else {
continue;
};
let is_3d = ring_array
.first()
.map(|first| first.is_array())
.unwrap_or(false);
let points: Vec<(f32, f32)> = if is_3d {
ring_array
.iter()
.filter_map(|point| {
let arr = point.as_array()?;
if arr.len() >= 2 {
let x = arr[0].as_f64()? as f32;
let y = arr[1].as_f64()? as f32;
if x.is_finite() && y.is_finite() {
Some((x, y))
} else {
None
}
} else {
None
}
})
.collect()
} else {
ring_array
.chunks(2)
.filter_map(|chunk| {
if chunk.len() >= 2 {
let x = chunk[0].as_f64()? as f32;
let y = chunk[1].as_f64()? as f32;
if x.is_finite() && y.is_finite() {
Some((x, y))
} else {
None
}
} else {
None
}
})
.collect()
};
if points.len() >= 3 {
result.push(points);
}
}
result
}
/// Intermediate deserialization target for [`Annotation`].
///
/// It accepts both a nested `box2d` object and flat `x`/`y`/`w`/`h` fields;
/// `Annotation`'s `Deserialize` impl merges the two.
#[derive(Deserialize)]
struct AnnotationRaw {
    #[serde(default)]
    sample_id: Option<SampleID>,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    sequence_name: Option<String>,
    /// Uses the same rules as `SampleRaw::frame_number`: negative values mean
    /// "no frame" and become `None` instead of failing the whole annotation.
    #[serde(default, deserialize_with = "deserialize_frame_number")]
    frame_number: Option<u32>,
    /// Read from the JSON key `group_name`.
    #[serde(rename = "group_name", default)]
    group: Option<String>,
    /// Read from `object_reference`, or `object_id` as an alias.
    #[serde(rename = "object_reference", alias = "object_id", default)]
    object_id: Option<String>,
    #[serde(default)]
    label_name: Option<String>,
    #[serde(default)]
    label_index: Option<u64>,
    #[serde(default)]
    iscrowd: Option<bool>,
    #[serde(default)]
    category_frequency: Option<String>,
    #[serde(default)]
    box2d: Option<Box2d>,
    #[serde(default)]
    box3d: Option<Box3d>,
    /// Polygon geometry. It may also arrive under the key `mask`.
    #[serde(default, alias = "mask")]
    polygon: Option<Polygon>,
    /// Flat box2d fallback: left edge.
    #[serde(default)]
    x: Option<f64>,
    /// Flat box2d fallback: top edge.
    #[serde(default)]
    y: Option<f64>,
    /// Flat box2d fallback: width.
    #[serde(default)]
    w: Option<f64>,
    /// Flat box2d fallback: height.
    #[serde(default)]
    h: Option<f64>,
}
/// A single annotation (box, 3D box, polygon and/or raster mask) plus metadata.
///
/// Only `Serialize` is derived. Deserialization is hand-written via
/// `AnnotationRaw`, so the deserialize-side attributes here (`alias`,
/// `default`) have no effect on parsing.
#[derive(Serialize, Clone, Debug)]
pub struct Annotation {
    #[serde(skip_serializing_if = "Option::is_none")]
    sample_id: Option<SampleID>,
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    sequence_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frame_number: Option<u32>,
    /// Serialized under the key `group_name`.
    #[serde(rename = "group_name", skip_serializing_if = "Option::is_none")]
    group: Option<String>,
    /// Serialized under the key `object_reference`.
    #[serde(
        rename = "object_reference",
        alias = "object_id",
        skip_serializing_if = "Option::is_none"
    )]
    object_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    label_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    label_index: Option<u64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    iscrowd: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    category_frequency: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    box2d: Option<Box2d>,
    #[serde(skip_serializing_if = "Option::is_none")]
    box3d: Option<Box3d>,
    /// Polygon geometry. It is serialized under the key `mask`.
    #[serde(rename(serialize = "mask"), skip_serializing_if = "Option::is_none")]
    polygon: Option<Polygon>,
    /// Raster mask. It is kept in memory only and never serialized.
    #[serde(skip)]
    mask: Option<MaskData>,
    /// Confidence scores per geometry. They are never set by deserialization.
    #[serde(skip_serializing_if = "Option::is_none")]
    box2d_score: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    box3d_score: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    polygon_score: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    mask_score: Option<f32>,
}
impl<'de> serde::Deserialize<'de> for Annotation {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw: AnnotationRaw = serde::Deserialize::deserialize(deserializer)?;
let box2d = raw.box2d.or_else(|| match (raw.x, raw.y, raw.w, raw.h) {
(Some(x), Some(y), Some(w), Some(h)) if w > 0.0 && h > 0.0 => {
Some(Box2d::new(x as f32, y as f32, w as f32, h as f32))
}
_ => None,
});
Ok(Annotation {
sample_id: raw.sample_id,
name: raw.name,
sequence_name: raw.sequence_name,
frame_number: raw.frame_number,
group: raw.group,
object_id: raw.object_id,
label_name: raw.label_name,
label_index: raw.label_index,
iscrowd: raw.iscrowd,
category_frequency: raw.category_frequency,
box2d,
box3d: raw.box3d,
polygon: raw.polygon,
mask: None,
box2d_score: None,
box3d_score: None,
polygon_score: None,
mask_score: None,
})
}
}
impl Default for Annotation {
fn default() -> Self {
Self::new()
}
}
impl Annotation {
    /// Creates an annotation with every field unset.
    pub fn new() -> Self {
        Self {
            sample_id: None,
            name: None,
            sequence_name: None,
            frame_number: None,
            group: None,
            object_id: None,
            label_name: None,
            label_index: None,
            iscrowd: None,
            category_frequency: None,
            box2d: None,
            box3d: None,
            polygon: None,
            mask: None,
            box2d_score: None,
            box3d_score: None,
            polygon_score: None,
            mask_score: None,
        }
    }

    /// Identifier of the sample this annotation belongs to.
    pub fn sample_id(&self) -> Option<SampleID> {
        self.sample_id
    }
    /// Replaces the sample identifier.
    pub fn set_sample_id(&mut self, sample_id: Option<SampleID>) {
        self.sample_id = sample_id;
    }

    /// Annotation name.
    pub fn name(&self) -> Option<&String> {
        self.name.as_ref()
    }
    /// Replaces the annotation name.
    pub fn set_name(&mut self, name: Option<String>) {
        self.name = name;
    }

    /// Sequence name.
    pub fn sequence_name(&self) -> Option<&String> {
        self.sequence_name.as_ref()
    }
    /// Replaces the sequence name.
    pub fn set_sequence_name(&mut self, sequence_name: Option<String>) {
        self.sequence_name = sequence_name;
    }

    /// Frame number.
    pub fn frame_number(&self) -> Option<u32> {
        self.frame_number
    }
    /// Replaces the frame number.
    pub fn set_frame_number(&mut self, frame_number: Option<u32>) {
        self.frame_number = frame_number;
    }

    /// Group name.
    pub fn group(&self) -> Option<&String> {
        self.group.as_ref()
    }
    /// Replaces the group name.
    pub fn set_group(&mut self, group: Option<String>) {
        self.group = group;
    }

    /// Object reference.
    pub fn object_id(&self) -> Option<&String> {
        self.object_id.as_ref()
    }
    /// Replaces the object reference.
    pub fn set_object_id(&mut self, object_id: Option<String>) {
        self.object_id = object_id;
    }

    /// Label name.
    pub fn label(&self) -> Option<&String> {
        self.label_name.as_ref()
    }
    /// Replaces the label name.
    pub fn set_label(&mut self, label_name: Option<String>) {
        self.label_name = label_name;
    }

    /// Label index.
    pub fn label_index(&self) -> Option<u64> {
        self.label_index
    }
    /// Replaces the label index.
    pub fn set_label_index(&mut self, label_index: Option<u64>) {
        self.label_index = label_index;
    }

    /// Crowd flag.
    pub fn iscrowd(&self) -> Option<bool> {
        self.iscrowd
    }
    /// Replaces the crowd flag.
    pub fn set_iscrowd(&mut self, iscrowd: Option<bool>) {
        self.iscrowd = iscrowd;
    }

    /// Category frequency string.
    pub fn category_frequency(&self) -> Option<&String> {
        self.category_frequency.as_ref()
    }
    /// Replaces the category frequency string.
    pub fn set_category_frequency(&mut self, category_frequency: Option<String>) {
        self.category_frequency = category_frequency;
    }

    /// 2D box geometry.
    pub fn box2d(&self) -> Option<&Box2d> {
        self.box2d.as_ref()
    }
    /// Replaces the 2D box geometry.
    pub fn set_box2d(&mut self, box2d: Option<Box2d>) {
        self.box2d = box2d;
    }

    /// 3D box geometry.
    pub fn box3d(&self) -> Option<&Box3d> {
        self.box3d.as_ref()
    }
    /// Replaces the 3D box geometry.
    pub fn set_box3d(&mut self, box3d: Option<Box3d>) {
        self.box3d = box3d;
    }

    /// Polygon geometry.
    pub fn polygon(&self) -> Option<&Polygon> {
        self.polygon.as_ref()
    }
    /// Replaces the polygon geometry.
    pub fn set_polygon(&mut self, polygon: Option<Polygon>) {
        self.polygon = polygon;
    }

    /// Raster mask.
    pub fn mask(&self) -> Option<&MaskData> {
        self.mask.as_ref()
    }
    /// Replaces the raster mask.
    pub fn set_mask(&mut self, mask: Option<MaskData>) {
        self.mask = mask;
    }

    /// Score of the 2D box.
    pub fn box2d_score(&self) -> Option<f32> {
        self.box2d_score
    }
    /// Replaces the 2D box score.
    pub fn set_box2d_score(&mut self, score: Option<f32>) {
        self.box2d_score = score;
    }

    /// Score of the 3D box.
    pub fn box3d_score(&self) -> Option<f32> {
        self.box3d_score
    }
    /// Replaces the 3D box score.
    pub fn set_box3d_score(&mut self, score: Option<f32>) {
        self.box3d_score = score;
    }

    /// Score of the polygon.
    pub fn polygon_score(&self) -> Option<f32> {
        self.polygon_score
    }
    /// Replaces the polygon score.
    pub fn set_polygon_score(&mut self, score: Option<f32>) {
        self.polygon_score = score;
    }

    /// Score of the raster mask.
    pub fn mask_score(&self) -> Option<f32> {
        self.mask_score
    }
    /// Replaces the raster mask score.
    pub fn set_mask_score(&mut self, score: Option<f32>) {
        self.mask_score = score;
    }
}
/// A label defined on a dataset.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Label {
    /// Label identifier.
    id: u64,
    /// Dataset the label belongs to.
    dataset_id: DatasetID,
    /// Numeric index of the label.
    index: u64,
    /// Label name.
    name: String,
}
impl Label {
    /// Label identifier.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Identifier of the owning dataset.
    pub fn dataset_id(&self) -> DatasetID {
        self.dataset_id
    }

    /// Numeric index of the label.
    pub fn index(&self) -> u64 {
        self.index
    }

    /// Label name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Deletes this label through `Client::remove_label`.
    pub async fn remove(&self, client: &Client) -> Result<(), Error> {
        client.remove_label(self.id()).await
    }

    /// Renames the label and pushes the change with `Client::update_label`.
    ///
    /// The new name is applied locally first so the update request carries it.
    /// If the request fails, the previous name is restored, so `self` never
    /// keeps a name the server rejected.
    ///
    /// # Errors
    /// Propagates the error from `Client::update_label`.
    pub async fn set_name(&mut self, client: &Client, name: &str) -> Result<(), Error> {
        let previous = std::mem::replace(&mut self.name, name.to_string());
        let result = client.update_label(self).await;
        if result.is_err() {
            self.name = previous;
        }
        result
    }

    /// Changes the label index and pushes the change with `Client::update_label`.
    ///
    /// If the request fails, the previous index is restored.
    ///
    /// # Errors
    /// Propagates the error from `Client::update_label`.
    pub async fn set_index(&mut self, client: &Client, index: u64) -> Result<(), Error> {
        let previous = std::mem::replace(&mut self.index, index);
        let result = client.update_label(self).await;
        if result.is_err() {
            self.index = previous;
        }
        result
    }
}
impl Display for Label {
    /// Writes the label name.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
/// A label to create, identified only by its name.
#[derive(Serialize, Clone, Debug)]
pub struct NewLabelObject {
    /// Name of the label to create.
    pub name: String,
}
/// Serializable payload listing labels to create on a dataset.
#[derive(Serialize, Clone, Debug)]
pub struct NewLabel {
    /// Target dataset.
    pub dataset_id: DatasetID,
    /// Labels to create.
    pub labels: Vec<NewLabelObject>,
}
/// A named group as deserialized from the API.
#[derive(Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Group {
    /// Group identifier.
    pub id: u64,
    /// Group name.
    pub name: String,
}
/// Derives the DataFrame row name for an annotation.
///
/// The annotation must have a `name` with a UTF-8 file stem, or `None` is
/// returned. If a sequence name is present, the result is
/// `(sequence_name, frame_number)`; otherwise it is `(file_stem, None)`.
#[cfg(feature = "polars")]
fn extract_annotation_name(ann: &Annotation) -> Option<(String, Option<u32>)> {
    use std::path::Path;
    let stem = Path::new(ann.name.as_deref()?).file_stem()?.to_str()?;
    let row = if let Some(sequence) = ann.sequence_name.as_ref() {
        (sequence.clone(), ann.frame_number)
    } else {
        (stem.to_string(), None)
    };
    Some(row)
}
/// Converts a polygon into a list Series with one entry per ring.
///
/// Each ring is flattened to `[x0, y0, x1, y1, ...]`.
#[cfg(feature = "polars")]
fn convert_polygon_to_nested_series(polygon: &Polygon) -> Series {
    let mut per_ring: Vec<Option<Series>> = Vec::with_capacity(polygon.rings.len());
    for ring in &polygon.rings {
        let flat: Vec<f32> = ring.iter().flat_map(|&(x, y)| [x, y]).collect();
        per_ring.push(Some(Series::new("".into(), flat)));
    }
    Series::new("".into(), per_ring)
}
#[cfg(feature = "polars")]
pub fn samples_dataframe(samples: &[Sample]) -> Result<DataFrame, Error> {
let mut names: Vec<String> = Vec::new();
let mut frames: Vec<Option<u32>> = Vec::new();
let mut objects: Vec<Option<String>> = Vec::new();
let mut labels: Vec<Option<String>> = Vec::new();
let mut label_indices: Vec<Option<u64>> = Vec::new();
let mut groups: Vec<Option<String>> = Vec::new();
let mut polygons: Vec<Option<Series>> = Vec::new();
let mut boxes2d: Vec<Option<Series>> = Vec::new();
let mut boxes3d: Vec<Option<Series>> = Vec::new();
let mut mask_bytes: Vec<Option<Vec<u8>>> = Vec::new();
let mut box2d_scores: Vec<Option<f32>> = Vec::new();
let mut box3d_scores: Vec<Option<f32>> = Vec::new();
let mut polygon_scores: Vec<Option<f32>> = Vec::new();
let mut mask_scores: Vec<Option<f32>> = Vec::new();
let mut sizes: Vec<Option<Vec<u32>>> = Vec::new();
let mut locations: Vec<Option<Vec<f32>>> = Vec::new();
let mut poses: Vec<Option<Vec<f32>>> = Vec::new();
let mut degradations: Vec<Option<String>> = Vec::new();
let mut iscrowds: Vec<Option<bool>> = Vec::new();
let mut category_frequencies: Vec<Option<String>> = Vec::new();
let mut neg_label_indices_vec: Vec<Option<Vec<u32>>> = Vec::new();
let mut not_exhaustive_label_indices_vec: Vec<Option<Vec<u32>>> = Vec::new();
let mut timing_load: Vec<Option<i64>> = Vec::new();
let mut timing_preprocess: Vec<Option<i64>> = Vec::new();
let mut timing_inference: Vec<Option<i64>> = Vec::new();
let mut timing_decode: Vec<Option<i64>> = Vec::new();
for sample in samples {
let size = match (sample.width, sample.height) {
(Some(w), Some(h)) => Some(vec![w, h]),
_ => None,
};
let location = sample.location.as_ref().and_then(|loc| {
loc.gps
.as_ref()
.map(|gps| vec![gps.lat as f32, gps.lon as f32])
});
let pose = sample.location.as_ref().and_then(|loc| {
loc.imu
.as_ref()
.map(|imu| vec![imu.yaw as f32, imu.pitch as f32, imu.roll as f32])
});
let degradation = sample.degradation.clone();
let t_load = sample.timing.as_ref().and_then(|t| t.load);
let t_preprocess = sample.timing.as_ref().and_then(|t| t.preprocess);
let t_inference = sample.timing.as_ref().and_then(|t| t.inference);
let t_decode = sample.timing.as_ref().and_then(|t| t.decode);
macro_rules! push_sample_fields {
() => {
sizes.push(size.clone());
locations.push(location.clone());
poses.push(pose.clone());
degradations.push(degradation.clone());
neg_label_indices_vec.push(sample.neg_label_indices.clone());
not_exhaustive_label_indices_vec.push(sample.not_exhaustive_label_indices.clone());
timing_load.push(t_load);
timing_preprocess.push(t_preprocess);
timing_inference.push(t_inference);
timing_decode.push(t_decode);
};
}
if sample.annotations.is_empty() {
let (name, frame) = match extract_annotation_name_from_sample(sample) {
Some(nf) => nf,
None => continue,
};
names.push(name);
frames.push(frame);
objects.push(None);
labels.push(None);
label_indices.push(None);
groups.push(sample.group.clone());
polygons.push(None);
boxes2d.push(None);
boxes3d.push(None);
mask_bytes.push(None);
box2d_scores.push(None);
box3d_scores.push(None);
polygon_scores.push(None);
mask_scores.push(None);
iscrowds.push(None);
category_frequencies.push(None);
push_sample_fields!();
} else {
for ann in &sample.annotations {
let (name, frame) = match extract_annotation_name(ann) {
Some(nf) => nf,
None => continue,
};
let polygon = ann.polygon.as_ref().map(convert_polygon_to_nested_series);
let box2d = ann
.box2d
.as_ref()
.map(|b| Series::new("box2d".into(), [b.cx(), b.cy(), b.width(), b.height()]));
let box3d = ann
.box3d
.as_ref()
.map(|b| Series::new("box3d".into(), [b.x, b.y, b.z, b.w, b.h, b.l]));
names.push(name);
frames.push(frame);
objects.push(ann.object_id().cloned());
labels.push(ann.label_name.clone());
label_indices.push(ann.label_index);
groups.push(sample.group.clone());
polygons.push(polygon);
boxes2d.push(box2d);
boxes3d.push(box3d);
mask_bytes.push(ann.mask.as_ref().map(|m| m.as_bytes().to_vec()));
box2d_scores.push(ann.box2d_score());
box3d_scores.push(ann.box3d_score());
polygon_scores.push(ann.polygon_score());
mask_scores.push(ann.mask_score());
iscrowds.push(ann.iscrowd);
category_frequencies.push(ann.category_frequency.clone());
push_sample_fields!();
}
}
}
let names_col: Column = Series::new("name".into(), names).into();
let frames_col: Column = Series::new("frame".into(), frames).into();
let objects_col: Column = Series::new("object_id".into(), objects).into();
let labels_col: Column = Series::new("label".into(), labels)
.cast(&DataType::Categorical(
Categories::new("labels".into(), "labels".into(), CategoricalPhysical::U16),
Arc::new(CategoricalMapping::with_hasher(
u16::MAX as usize,
Default::default(),
)),
))?
.into();
let label_indices_col: Column = Series::new("label_index".into(), label_indices).into();
let groups_col: Column = Series::new("group".into(), groups)
.cast(&DataType::Categorical(
Categories::new("groups".into(), "groups".into(), CategoricalPhysical::U8),
Arc::new(CategoricalMapping::with_hasher(
u8::MAX as usize,
Default::default(),
)),
))?
.into();
let polygons_col: Column = if polygons.iter().all(|p| p.is_none()) {
Series::new_null("polygon".into(), polygons.len()).into()
} else {
let typed_polygons: Vec<Option<Series>> = polygons
.into_iter()
.map(|opt| {
opt.map(|s| {
s.cast(&DataType::List(Box::new(DataType::Float32)))
.unwrap_or(s)
})
})
.collect();
Series::new("polygon".into(), &typed_polygons)
.cast(&DataType::List(Box::new(DataType::List(Box::new(
DataType::Float32,
)))))?
.into()
};
let boxes2d_col: Column = Series::new("box2d".into(), boxes2d)
.cast(&DataType::Array(Box::new(DataType::Float32), 4))?
.into();
let boxes3d_col: Column = Series::new("box3d".into(), boxes3d)
.cast(&DataType::Array(Box::new(DataType::Float32), 6))?
.into();
let mask_col: Column = Series::new("mask".into(), mask_bytes).into();
let box2d_score_col: Column = Series::new("box2d_score".into(), box2d_scores).into();
let box3d_score_col: Column = Series::new("box3d_score".into(), box3d_scores).into();
let polygon_score_col: Column = Series::new("polygon_score".into(), polygon_scores).into();
let mask_score_col: Column = Series::new("mask_score".into(), mask_scores).into();
let size_series: Vec<Option<Series>> = sizes
.into_iter()
.map(|opt_vec| opt_vec.map(|vec| Series::new("size".into(), vec)))
.collect();
let sizes_col: Column = Series::new("size".into(), size_series)
.cast(&DataType::Array(Box::new(DataType::UInt32), 2))?
.into();
let location_series: Vec<Option<Series>> = locations
.into_iter()
.map(|opt_vec| opt_vec.map(|vec| Series::new("location".into(), vec)))
.collect();
let locations_col: Column = Series::new("location".into(), location_series)
.cast(&DataType::Array(Box::new(DataType::Float32), 2))?
.into();
let pose_series: Vec<Option<Series>> = poses
.into_iter()
.map(|opt_vec| opt_vec.map(|vec| Series::new("pose".into(), vec)))
.collect();
let poses_col: Column = Series::new("pose".into(), pose_series)
.cast(&DataType::Array(Box::new(DataType::Float32), 3))?
.into();
let degradations_col: Column = Series::new("degradation".into(), degradations).into();
let iscrowds_col: Column = Series::new("iscrowd".into(), iscrowds).into();
let category_frequencies_col: Column =
Series::new("category_frequency".into(), category_frequencies)
.cast(&DataType::Categorical(
Categories::new(
"cat_freq".into(),
"cat_freq".into(),
CategoricalPhysical::U8,
),
Arc::new(CategoricalMapping::with_hasher(
u8::MAX as usize,
Default::default(),
)),
))?
.into();
let neg_label_indices_series: Vec<Option<Series>> = neg_label_indices_vec
.into_iter()
.map(|opt_vec| opt_vec.map(|vec| Series::new("neg_label_indices".into(), vec)))
.collect();
let neg_label_indices_col: Column =
Series::new("neg_label_indices".into(), neg_label_indices_series)
.cast(&DataType::List(Box::new(DataType::UInt32)))?
.into();
let not_exhaustive_label_indices_series: Vec<Option<Series>> = not_exhaustive_label_indices_vec
.into_iter()
.map(|opt_vec| opt_vec.map(|vec| Series::new("not_exhaustive_label_indices".into(), vec)))
.collect();
let not_exhaustive_label_indices_col: Column = Series::new(
"not_exhaustive_label_indices".into(),
not_exhaustive_label_indices_series,
)
.cast(&DataType::List(Box::new(DataType::UInt32)))?
.into();
let timing_col: Column = StructChunked::from_series(
"timing".into(),
frames_col.len(),
[
Series::new("load".into(), &timing_load),
Series::new("preprocess".into(), &timing_preprocess),
Series::new("inference".into(), &timing_inference),
Series::new("decode".into(), &timing_decode),
]
.iter(),
)?
.into_series()
.into();
let all_columns: Vec<Column> = vec![
names_col,
frames_col,
objects_col,
labels_col,
label_indices_col,
groups_col,
polygons_col,
boxes2d_col,
boxes3d_col,
mask_col,
box2d_score_col,
box3d_score_col,
polygon_score_col,
mask_score_col,
sizes_col,
locations_col,
poses_col,
degradations_col,
iscrowds_col,
category_frequencies_col,
neg_label_indices_col,
not_exhaustive_label_indices_col,
timing_col,
];
let height = all_columns.first().map(|c| c.len()).unwrap_or(0);
let non_empty_columns: Vec<Column> = all_columns
.into_iter()
.filter(|col| col.name() == "name" || !is_all_null_column(col))
.collect();
Ok(DataFrame::new(height, non_empty_columns)?)
}
#[cfg(feature = "polars")]
/// Reports whether `col` carries no data worth emitting.
///
/// Returns `true` for an empty column, for a column in which every row is
/// null, and for a struct column in which every field series is entirely
/// null. A struct column whose `struct_()` view cannot be obtained is treated
/// as carrying data.
fn is_all_null_column(col: &Column) -> bool {
    if col.is_empty() || col.null_count() == col.len() {
        return true;
    }
    if !matches!(col.dtype(), DataType::Struct(..)) {
        return false;
    }
    // A struct column can have non-null rows whose fields are all null; look
    // through to the individual fields in that case.
    col.as_materialized_series()
        .struct_()
        .map(|fields| {
            fields
                .fields_as_series()
                .iter()
                .all(|field| field.null_count() == field.len())
        })
        .unwrap_or(false)
}
#[cfg(feature = "polars")]
/// Derives the `(name, frame)` pair for the dataframe row of a sample that
/// has no annotations.
///
/// Returns `None` when the sample has no `image_name`, or when that name has
/// no UTF-8 file stem. This holds even if a sequence name is present.
/// Otherwise:
/// - a sample with a sequence name yields `(sequence_name, frame_number)`;
/// - a standalone sample yields `(file_stem, None)`.
fn extract_annotation_name_from_sample(sample: &Sample) -> Option<(String, Option<u32>)> {
    let image_name = sample.image_name.as_deref()?;
    let stem = std::path::Path::new(image_name).file_stem()?.to_str()?;
    let named = match sample.sequence_name.as_ref() {
        Some(sequence) => (sequence.clone(), sample.frame_number),
        None => (stem.to_owned(), None),
    };
    Some(named)
}
/// Derives a sample name from an image file name.
///
/// Two suffixes are removed in turn. Each is removed only when a non-empty
/// prefix would remain:
/// 1. the final extension (everything from the last `.`), then
/// 2. everything from the last occurrence of `.camera`.
///
/// For example, `"scene_001.camera.jpg"` becomes `"scene_001"`. By contrast,
/// `".jpg"` is returned unchanged because stripping would leave nothing.
fn extract_sample_name(image_name: &str) -> String {
    let stem = match image_name.rsplit_once('.') {
        Some((head, _)) if !head.is_empty() => head,
        _ => image_name,
    };
    let base = match stem.rsplit_once(".camera") {
        Some((head, _)) if !head.is_empty() => head,
        _ => stem,
    };
    base.to_string()
}
/// Finds the first entry in `files` whose type string is one of the names
/// returned by `file_type_names` for `file_type`.
///
/// `FileType::Image` and `FileType::All` never match; both always yield `None`.
fn resolve_file<'a>(file_type: &FileType, files: &'a [SampleFile]) -> Option<&'a SampleFile> {
    if matches!(file_type, FileType::Image | FileType::All) {
        return None;
    }
    let accepted = file_type_names(file_type);
    files
        .iter()
        .find(|candidate| accepted.contains(&candidate.r#type.as_str()))
}
/// Lists every `type` string that identifies `file_type` in a sample's file list.
///
/// The aliases mirror those accepted by `FileType::try_from`. Any type string
/// that parses to a given `FileType` is therefore also found by
/// `resolve_file`. `FileType::All` matches nothing.
fn file_type_names(file_type: &FileType) -> Vec<&'static str> {
    match file_type {
        FileType::Image => vec!["image"],
        FileType::LidarPcd => vec!["lidar.pcd"],
        // "lidar.png" is also accepted by `FileType::try_from` (and is the
        // on-disk extension from `file_extension`); keep these lists in sync.
        FileType::LidarDepth => vec!["lidar.depth", "lidar.png", "depth.png", "depthmap"],
        // "lidar.jpg" / "lidar.jpeg" are accepted by `FileType::try_from` as well.
        FileType::LidarReflect => vec!["lidar.reflect", "lidar.jpg", "lidar.jpeg"],
        FileType::RadarPcd => vec!["radar.pcd", "pcd"],
        FileType::RadarCube => vec!["radar.png", "cube"],
        FileType::All => vec![],
    }
}
/// Flattens a per-type annotation map into a single list.
///
/// Annotations are emitted in a fixed order: `"bbox"`, then `"box3d"`, then
/// `"mask"`. This does not depend on `HashMap` iteration order. Entries under
/// any other key are ignored.
fn convert_annotations_map_to_vec(mut map: HashMap<String, Vec<Annotation>>) -> Vec<Annotation> {
    // The map is owned, so move each vector out instead of deep-cloning it.
    ["bbox", "box3d", "mask"]
        .into_iter()
        .filter_map(|key| map.remove(key))
        .flatten()
        .collect()
}
/// Checks that a GPS fix is usable.
///
/// Both coordinates must be finite. Latitude must lie in `[-90, 90]` and
/// longitude in `[-180, 180]`; both bounds are inclusive.
///
/// # Errors
/// Returns a message describing the first failed check. Checks run in this
/// order: latitude finite, longitude finite, latitude range, longitude range.
fn validate_gps_coordinates(lat: f64, lon: f64) -> Result<(), String> {
    match (lat.is_finite(), lon.is_finite()) {
        (false, _) => Err(format!("GPS latitude is not finite: {}", lat)),
        (_, false) => Err(format!("GPS longitude is not finite: {}", lon)),
        _ if !(-90.0..=90.0).contains(&lat) => {
            Err(format!("GPS latitude out of range [-90, 90]: {}", lat))
        }
        _ if !(-180.0..=180.0).contains(&lon) => {
            Err(format!("GPS longitude out of range [-180, 180]: {}", lon))
        }
        _ => Ok(()),
    }
}
/// Checks that an IMU orientation is usable.
///
/// All three angles must be finite. Roll and yaw must lie in `[-180, 180]`
/// and pitch in `[-90, 90]`; all bounds are inclusive.
///
/// # Errors
/// Returns a message naming the first failing check. Finiteness is checked
/// for roll, pitch and yaw before any range check.
fn validate_imu_orientation(roll: f64, pitch: f64, yaw: f64) -> Result<(), String> {
    let problem = if !roll.is_finite() {
        format!("IMU roll is not finite: {}", roll)
    } else if !pitch.is_finite() {
        format!("IMU pitch is not finite: {}", pitch)
    } else if !yaw.is_finite() {
        format!("IMU yaw is not finite: {}", yaw)
    } else if !(-180.0..=180.0).contains(&roll) {
        format!("IMU roll out of range [-180, 180]: {}", roll)
    } else if !(-90.0..=90.0).contains(&pitch) {
        format!("IMU pitch out of range [-90, 90]: {}", pitch)
    } else if !(-180.0..=180.0).contains(&yaw) {
        format!("IMU yaw out of range [-180, 180]: {}", yaw)
    } else {
        return Ok(());
    };
    Err(problem)
}
#[cfg(feature = "polars")]
/// Splits a flat, NaN-separated coordinate buffer into polygons of `(f32, f32)` pairs.
///
/// Each `NaN` separates one polygon from the next. Within a segment, values
/// are taken two at a time; a trailing unpaired value is dropped. Segments
/// that produce no complete pair (empty or single-value runs) are omitted.
pub fn unflatten_polygon_coordinates(coords: &[f32]) -> Vec<Vec<(f32, f32)>> {
    coords
        .split(|value| value.is_nan())
        .map(|segment| {
            segment
                .chunks_exact(2)
                .map(|pair| (pair[0], pair[1]))
                .collect::<Vec<_>>()
        })
        .filter(|points| !points.is_empty())
        .collect()
}
#[cfg(test)]
mod tests {
use super::*;
fn flatten_annotation_map(
map: std::collections::HashMap<String, Vec<Annotation>>,
) -> Vec<Annotation> {
let mut all_annotations = Vec::new();
for key in ["bbox", "box3d", "mask"] {
if let Some(mut anns) = map.get(key).cloned() {
all_annotations.append(&mut anns);
}
}
all_annotations
}
/// JSON key under which a serialized annotation stores its group.
fn annotation_group_field_name() -> &'static str {
    const GROUP_FIELD: &str = "group_name";
    GROUP_FIELD
}
/// JSON key under which a serialized annotation stores its object reference.
fn annotation_object_id_field_name() -> &'static str {
    const OBJECT_REFERENCE_FIELD: &str = "object_reference";
    OBJECT_REFERENCE_FIELD
}
/// Alternate name (`object_id`) for the object-reference field.
///
/// NOTE(review): presumably accepted as a serde alias on deserialization;
/// confirm against `Annotation`'s serde attributes.
fn annotation_object_id_alias() -> &'static str {
    const OBJECT_ID_ALIAS: &str = "object_id";
    OBJECT_ID_ALIAS
}
fn validate_annotation_field_names(
json_str: &str,
expected_group: bool,
expected_object_ref: bool,
) -> Result<(), String> {
if expected_group && !json_str.contains("\"group_name\"") {
return Err("Missing expected field: group_name".to_string());
}
if expected_object_ref && !json_str.contains("\"object_reference\"") {
return Err("Missing expected field: object_reference".to_string());
}
Ok(())
}
/// `FileType` Display names, file extensions, parse aliases, and round-trips.
#[test]
fn test_file_type_conversions() {
    // (variant, Display / API name, on-disk file extension)
    let table = [
        (FileType::Image, "image", "jpg"),
        (FileType::LidarPcd, "lidar.pcd", "lidar.pcd"),
        (FileType::LidarDepth, "lidar.depth", "lidar.png"),
        (FileType::LidarReflect, "lidar.reflect", "lidar.jpg"),
        (FileType::RadarPcd, "radar.pcd", "radar.pcd"),
        (FileType::RadarCube, "radar.png", "radar.png"),
    ];
    for (file_type, api_name, extension) in &table {
        assert_eq!(file_type.to_string(), *api_name);
        assert_eq!(file_type.file_extension(), *extension);
    }

    let aliases = [
        ("lidar.depth", FileType::LidarDepth),
        ("lidar.png", FileType::LidarDepth),
        ("depth.png", FileType::LidarDepth),
        ("lidar.reflect", FileType::LidarReflect),
        ("lidar.jpg", FileType::LidarReflect),
        ("lidar.jpeg", FileType::LidarReflect),
    ];
    for (alias, expected) in &aliases {
        assert_eq!(FileType::try_from(*alias).unwrap(), *expected);
    }
    assert!(FileType::try_from("invalid").is_err());

    // Display output must parse back to the same variant.
    for (file_type, _, _) in &table {
        let rendered = file_type.to_string();
        assert_eq!(FileType::try_from(rendered.as_str()).unwrap(), *file_type);
    }
}
/// `AnnotationType` Display, strict `try_from`, lenient `From<String>`, and round-trips.
#[test]
fn test_annotation_type_conversions() {
    let display_names = [
        (AnnotationType::Box2d, "box2d"),
        (AnnotationType::Box3d, "box3d"),
        (AnnotationType::Polygon, "polygon"),
        (AnnotationType::Mask, "mask"),
    ];
    for (ann_type, name) in &display_names {
        assert_eq!(ann_type.to_string(), *name);
    }

    // Strict parsing. Note: "mask" parses as Polygon, while "raster" parses as Mask.
    let strict = [
        ("box2d", AnnotationType::Box2d),
        ("box3d", AnnotationType::Box3d),
        ("polygon", AnnotationType::Polygon),
        ("mask", AnnotationType::Polygon),
        ("raster", AnnotationType::Mask),
    ];
    for (input, expected) in &strict {
        assert_eq!(AnnotationType::try_from(*input).unwrap(), *expected);
    }

    // Lenient conversion: unknown strings fall back to Box2d.
    let lenient = [
        ("box2d", AnnotationType::Box2d),
        ("box3d", AnnotationType::Box3d),
        ("polygon", AnnotationType::Polygon),
        ("mask", AnnotationType::Polygon),
        ("invalid", AnnotationType::Box2d),
    ];
    for (input, expected) in &lenient {
        assert_eq!(AnnotationType::from(input.to_string()), *expected);
    }
    assert!(AnnotationType::try_from("invalid").is_err());

    // Display output of these variants parses back to the same variant.
    for ann_type in [
        AnnotationType::Box2d,
        AnnotationType::Box3d,
        AnnotationType::Polygon,
    ] {
        let rendered = ann_type.to_string();
        assert_eq!(AnnotationType::try_from(rendered.as_str()).unwrap(), ann_type);
    }
}
/// Both the extension and the `.camera` marker are stripped.
#[test]
fn test_extract_sample_name_with_extension_and_camera() {
    let name = extract_sample_name("scene_001.camera.jpg");
    assert_eq!(name, "scene_001");
}
/// Dots before the `.camera` marker are kept.
#[test]
fn test_extract_sample_name_multiple_dots() {
    let name = extract_sample_name("image.v2.camera.png");
    assert_eq!(name, "image.v2");
}
/// Without a `.camera` marker only the extension is removed.
#[test]
fn test_extract_sample_name_extension_only() {
    let name = extract_sample_name("test.jpg");
    assert_eq!(name, "test");
}
/// A bare name passes through unchanged.
#[test]
fn test_extract_sample_name_no_extension() {
    let name = extract_sample_name("test");
    assert_eq!(name, "test");
}
/// Stripping that would leave an empty name is skipped.
#[test]
fn test_extract_sample_name_edge_case_dot_prefix() {
    let name = extract_sample_name(".jpg");
    assert_eq!(name, ".jpg");
}
/// `FileType::Image` never resolves to an entry in the file list.
#[test]
fn test_resolve_file_image_type_returns_none() {
    let no_files: Vec<SampleFile> = Vec::new();
    assert!(resolve_file(&FileType::Image, &no_files).is_none());
}
/// With several files present, only the one whose type matches is returned.
#[test]
fn test_resolve_file_lidar_pcd() {
    let lidar = SampleFile::with_url(
        "lidar.pcd".to_string(),
        "https://example.com/file.pcd".to_string(),
    );
    let radar = SampleFile::with_url(
        "radar.pcd".to_string(),
        "https://example.com/radar.pcd".to_string(),
    );
    let files = [lidar, radar];
    let url = resolve_file(&FileType::LidarPcd, &files).and_then(|file| file.url());
    assert_eq!(url, Some("https://example.com/file.pcd"));
}
/// A requested type with no matching entry resolves to `None`.
#[test]
fn test_resolve_file_not_found() {
    let only_lidar = [SampleFile::with_url(
        "lidar.pcd".to_string(),
        "https://example.com/file.pcd".to_string(),
    )];
    assert!(resolve_file(&FileType::RadarPcd, &only_lidar).is_none());
}
/// The `lidar.depth` type string resolves for `FileType::LidarDepth`.
#[test]
fn test_resolve_file_lidar_depth() {
    let files = [SampleFile::with_url(
        "lidar.depth".to_string(),
        "https://example.com/depth.png".to_string(),
    )];
    let url = resolve_file(&FileType::LidarDepth, &files).and_then(|file| file.url());
    assert_eq!(url, Some("https://example.com/depth.png"));
}
/// The `lidar.reflect` type string resolves for `FileType::LidarReflect`.
#[test]
fn test_resolve_file_lidar_reflect() {
    let files = [SampleFile::with_url(
        "lidar.reflect".to_string(),
        "https://example.com/reflect.png".to_string(),
    )];
    let url = resolve_file(&FileType::LidarReflect, &files).and_then(|file| file.url());
    assert_eq!(url, Some("https://example.com/reflect.png"));
}
/// The `radar.png` type string resolves for `FileType::RadarCube`.
#[test]
fn test_resolve_file_radar_cube() {
    let files = [SampleFile::with_url(
        "radar.png".to_string(),
        "https://example.com/radar.png".to_string(),
    )];
    let url = resolve_file(&FileType::RadarCube, &files).and_then(|file| file.url());
    assert_eq!(url, Some("https://example.com/radar.png"));
}
/// A file carrying inline data (and no URL) still resolves by type.
#[test]
fn test_resolve_file_with_inline_data() {
    let files = [SampleFile::with_data(
        "radar.pcd".to_string(),
        "SGVsbG8gV29ybGQ=".to_string(),
    )];
    let file = resolve_file(&FileType::RadarPcd, &files).unwrap();
    assert_eq!(file.url(), None);
    assert_eq!(file.data(), Some("SGVsbG8gV29ybGQ="));
}
/// A single `bbox` entry is carried through.
#[test]
fn test_convert_annotations_map_to_vec_with_bbox() {
    let map = HashMap::from([("bbox".to_string(), vec![Annotation::new()])]);
    assert_eq!(convert_annotations_map_to_vec(map).len(), 1);
}
/// Every recognised key contributes its annotations.
#[test]
fn test_convert_annotations_map_to_vec_all_types() {
    let map: HashMap<String, Vec<Annotation>> = ["bbox", "box3d", "mask"]
        .into_iter()
        .map(|key| (key.to_string(), vec![Annotation::new()]))
        .collect();
    assert_eq!(convert_annotations_map_to_vec(map).len(), 3);
}
/// An empty map flattens to an empty list.
#[test]
fn test_convert_annotations_map_to_vec_empty() {
    let annotations = convert_annotations_map_to_vec(HashMap::new());
    assert!(annotations.is_empty());
}
/// Keys other than bbox/box3d/mask are dropped.
#[test]
fn test_convert_annotations_map_to_vec_unknown_type_ignored() {
    let map = HashMap::from([("unknown".to_string(), vec![Annotation::new()])]);
    assert!(convert_annotations_map_to_vec(map).is_empty());
}
/// The group key is `group_name`.
#[test]
fn test_annotation_group_field_name() {
    let expected = "group_name";
    assert_eq!(annotation_group_field_name(), expected);
}
/// The object-reference key is `object_reference`.
#[test]
fn test_annotation_object_id_field_name() {
    let expected = "object_reference";
    assert_eq!(annotation_object_id_field_name(), expected);
}
/// The object-reference alias is `object_id`.
#[test]
fn test_annotation_object_id_alias() {
    let expected = "object_id";
    assert_eq!(annotation_object_id_alias(), expected);
}
/// JSON containing both keys passes validation.
#[test]
fn test_validate_annotation_field_names_success() {
    let both_present = r#"{"group_name":"train","object_reference":"obj1"}"#;
    let outcome = validate_annotation_field_names(both_present, true, true);
    assert_eq!(outcome, Ok(()));
}
/// A missing group key is reported by name.
#[test]
fn test_validate_annotation_field_names_missing_group() {
    let json = r#"{"object_reference":"obj1"}"#;
    let message = validate_annotation_field_names(json, true, false).unwrap_err();
    assert!(message.contains("group_name"));
}
/// A missing object-reference key is reported by name.
#[test]
fn test_validate_annotation_field_names_missing_object_ref() {
    let json = r#"{"group_name":"train"}"#;
    let message = validate_annotation_field_names(json, false, true).unwrap_err();
    assert!(message.contains("object_reference"));
}
/// Serializing an annotation with group and object id emits both expected keys.
#[test]
fn test_annotation_serialization_field_names() {
    let mut ann = Annotation::new();
    ann.set_group(Some("train".to_string()));
    ann.set_object_id(Some("obj1".to_string()));
    let serialized = serde_json::to_string(&ann).unwrap();
    assert_eq!(validate_annotation_field_names(&serialized, true, true), Ok(()));
}
/// A typical fix, the origin, and both inclusive corners of the range are accepted.
#[test]
fn test_validate_gps_coordinates_valid() {
    let accepted = [
        (37.7749, -122.4194),
        (0.0, 0.0),
        (90.0, 180.0),
        (-90.0, -180.0),
    ];
    for (lat, lon) in accepted {
        assert!(validate_gps_coordinates(lat, lon).is_ok());
    }
}
/// Latitudes just outside [-90, 90] are rejected with a range message.
#[test]
fn test_validate_gps_coordinates_invalid_latitude() {
    for lat in [91.0, -91.0] {
        let message = validate_gps_coordinates(lat, 0.0).unwrap_err();
        assert!(message.contains("latitude out of range"));
    }
}
/// Longitudes just outside [-180, 180] are rejected with a range message.
#[test]
fn test_validate_gps_coordinates_invalid_longitude() {
    for lon in [181.0, -181.0] {
        let message = validate_gps_coordinates(0.0, lon).unwrap_err();
        assert!(message.contains("longitude out of range"));
    }
}
/// NaN or infinite coordinates are rejected as non-finite.
#[test]
fn test_validate_gps_coordinates_non_finite() {
    for (lat, lon) in [(f64::NAN, 0.0), (0.0, f64::INFINITY)] {
        let message = validate_gps_coordinates(lat, lon).unwrap_err();
        assert!(message.contains("not finite"));
    }
}
/// Zero, typical, and inclusive-boundary orientations are accepted.
#[test]
fn test_validate_imu_orientation_valid() {
    let accepted = [
        (0.0, 0.0, 0.0),
        (45.0, 30.0, 90.0),
        (180.0, 90.0, -180.0),
        (-180.0, -90.0, 180.0),
    ];
    for (roll, pitch, yaw) in accepted {
        assert!(validate_imu_orientation(roll, pitch, yaw).is_ok());
    }
}
/// Roll outside [-180, 180] is rejected.
#[test]
fn test_validate_imu_orientation_invalid_roll() {
    let message = validate_imu_orientation(181.0, 0.0, 0.0).unwrap_err();
    assert!(message.contains("roll out of range"));
    assert!(validate_imu_orientation(-181.0, 0.0, 0.0).is_err());
}
/// Pitch outside [-90, 90] is rejected.
#[test]
fn test_validate_imu_orientation_invalid_pitch() {
    let message = validate_imu_orientation(0.0, 91.0, 0.0).unwrap_err();
    assert!(message.contains("pitch out of range"));
    assert!(validate_imu_orientation(0.0, -91.0, 0.0).is_err());
}
/// A non-finite value on any axis is rejected.
#[test]
fn test_validate_imu_orientation_non_finite() {
    let message = validate_imu_orientation(f64::NAN, 0.0, 0.0).unwrap_err();
    assert!(message.contains("not finite"));
    for (roll, pitch, yaw) in [(0.0, f64::INFINITY, 0.0), (0.0, 0.0, f64::NEG_INFINITY)] {
        assert!(validate_imu_orientation(roll, pitch, yaw).is_err());
    }
}
/// Four values with no NaN separator form one polygon of two points.
#[test]
#[cfg(feature = "polars")]
fn test_unflatten_polygon_coordinates_single_polygon() {
    let polygons = unflatten_polygon_coordinates(&[1.0, 2.0, 3.0, 4.0]);
    assert_eq!(polygons.len(), 1);
    assert_eq!(polygons[0], vec![(1.0, 2.0), (3.0, 4.0)]);
}
/// A NaN separator splits the buffer into two polygons.
#[test]
#[cfg(feature = "polars")]
fn test_unflatten_polygon_coordinates_multiple_polygons() {
    let flat = [1.0, 2.0, 3.0, 4.0, f32::NAN, 5.0, 6.0, 7.0, 8.0];
    let polygons = unflatten_polygon_coordinates(&flat);
    assert_eq!(polygons.len(), 2);
    assert_eq!(polygons[0], vec![(1.0, 2.0), (3.0, 4.0)]);
    assert_eq!(polygons[1], vec![(5.0, 6.0), (7.0, 8.0)]);
}
/// A NaN-separated buffer unflattens to the expected nested structure.
#[test]
#[cfg(feature = "polars")]
fn test_unflatten_polygon_coordinates_roundtrip() {
    let expected: Vec<Vec<(f32, f32)>> =
        vec![vec![(1.0, 2.0), (3.0, 4.0)], vec![(5.0, 6.0), (7.0, 8.0)]];
    let flat = [1.0, 2.0, 3.0, 4.0, f32::NAN, 5.0, 6.0, 7.0, 8.0];
    assert_eq!(unflatten_polygon_coordinates(&flat), expected);
}
/// All three recognised keys appear in bbox, box3d, mask order.
#[test]
fn test_flatten_annotation_map_all_types() {
    use std::collections::HashMap;
    let labelled = |label: &str| {
        let mut ann = Annotation::new();
        ann.set_label(Some(label.to_string()));
        ann
    };
    let map = HashMap::from([
        ("bbox".to_string(), vec![labelled("bbox_label")]),
        ("box3d".to_string(), vec![labelled("box3d_label")]),
        ("mask".to_string(), vec![labelled("mask_label")]),
    ]);
    let flattened = flatten_annotation_map(map);
    assert_eq!(flattened.len(), 3);
    for (ann, expected) in flattened.iter().zip(["bbox_label", "box3d_label", "mask_label"]) {
        assert_eq!(ann.label(), Some(&expected.to_string()));
    }
}
/// A map with only a bbox entry yields exactly that annotation.
#[test]
fn test_flatten_annotation_map_single_type() {
    use std::collections::HashMap;
    let mut only_bbox = Annotation::new();
    only_bbox.set_label(Some("test".to_string()));
    let flattened = flatten_annotation_map(HashMap::from([("bbox".to_string(), vec![only_bbox])]));
    assert_eq!(flattened.len(), 1);
    assert_eq!(flattened[0].label(), Some(&"test".to_string()));
}
/// An empty map flattens to nothing.
#[test]
fn test_flatten_annotation_map_empty() {
    use std::collections::HashMap;
    assert!(flatten_annotation_map(HashMap::new()).is_empty());
}
/// Output order is fixed regardless of insertion order.
#[test]
fn test_flatten_annotation_map_deterministic_order() {
    use std::collections::HashMap;
    let labelled = |label: &str| {
        let mut ann = Annotation::new();
        ann.set_label(Some(label.to_string()));
        ann
    };
    // Keys are inserted in the reverse of the expected output order.
    let mut map = HashMap::new();
    for key in ["mask", "box3d", "bbox"] {
        map.insert(key.to_string(), vec![labelled(key)]);
    }
    let flattened = flatten_annotation_map(map);
    assert_eq!(flattened.len(), 3);
    for (ann, expected) in flattened.iter().zip(["bbox", "box3d", "mask"]) {
        assert_eq!(ann.label(), Some(&expected.to_string()));
    }
}
/// `Box2d::new(left, top, width, height)` accessors and derived centers.
#[test]
fn test_box2d_construction_and_accessors() {
    // ((left, top, width, height), (cx, cy))
    let cases = [
        ((10.0, 20.0, 100.0, 50.0), (60.0, 45.0)),
        ((0.0, 0.0, 640.0, 480.0), (320.0, 240.0)),
    ];
    for ((left, top, width, height), center) in cases {
        let bbox = Box2d::new(left, top, width, height);
        assert_eq!(
            (bbox.left(), bbox.top(), bbox.width(), bbox.height()),
            (left, top, width, height)
        );
        assert_eq!((bbox.cx(), bbox.cy()), center);
    }
}
/// The center is the top-left corner plus half the size.
#[test]
fn test_box2d_center_calculation() {
    let bbox = Box2d::new(10.0, 20.0, 100.0, 50.0);
    assert_eq!((bbox.cx(), bbox.cy()), (60.0, 45.0));
}
/// A zero-size box's center equals its top-left corner.
#[test]
fn test_box2d_zero_dimensions() {
    let bbox = Box2d::new(10.0, 20.0, 0.0, 0.0);
    assert_eq!((bbox.cx(), bbox.cy()), (10.0, 20.0));
}
/// Negative sizes are stored as-is and shift the center accordingly.
#[test]
fn test_box2d_negative_dimensions() {
    let bbox = Box2d::new(100.0, 100.0, -50.0, -50.0);
    assert_eq!((bbox.width(), bbox.height()), (-50.0, -50.0));
    assert_eq!((bbox.cx(), bbox.cy()), (75.0, 75.0));
}
/// Center, size, and left/top/front accessors of `Box3d` for three boxes.
#[test]
fn test_box3d_construction_and_accessors() {
    let center = |b: &Box3d| (b.cx(), b.cy(), b.cz());
    let size = |b: &Box3d| (b.width(), b.height(), b.length());
    let corner = |b: &Box3d| (b.left(), b.top(), b.front());

    let first = Box3d::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
    assert_eq!(center(&first), (1.0, 2.0, 3.0));
    assert_eq!(size(&first), (4.0, 5.0, 6.0));

    let offset = Box3d::new(10.0, 20.0, 30.0, 4.0, 6.0, 8.0);
    assert_eq!(corner(&offset), (8.0, 17.0, 26.0));

    let origin = Box3d::new(0.0, 0.0, 0.0, 2.0, 3.0, 4.0);
    assert_eq!(center(&origin), (0.0, 0.0, 0.0));
    assert_eq!(size(&origin), (2.0, 3.0, 4.0));
    assert_eq!(corner(&origin), (-1.0, -1.5, -2.0));
}
/// The constructor takes the center directly, so the size does not move it.
#[test]
fn test_box3d_center_calculation() {
    let bbox = Box3d::new(10.0, 20.0, 30.0, 100.0, 50.0, 40.0);
    assert_eq!((bbox.cx(), bbox.cy(), bbox.cz()), (10.0, 20.0, 30.0));
}
/// With zero extent, the left/top/front corner coincides with the center.
#[test]
fn test_box3d_zero_dimensions() {
    let bbox = Box3d::new(5.0, 10.0, 15.0, 0.0, 0.0, 0.0);
    let center = (bbox.cx(), bbox.cy(), bbox.cz());
    assert_eq!(center, (5.0, 10.0, 15.0));
    assert_eq!((bbox.left(), bbox.top(), bbox.front()), center);
}
/// Negative sizes are stored as-is, pushing the corner past the center.
#[test]
fn test_box3d_negative_dimensions() {
    let bbox = Box3d::new(100.0, 100.0, 100.0, -50.0, -50.0, -50.0);
    assert_eq!(
        (bbox.width(), bbox.height(), bbox.length()),
        (-50.0, -50.0, -50.0)
    );
    assert_eq!(
        (bbox.left(), bbox.top(), bbox.front()),
        (125.0, 125.0, 125.0)
    );
}
/// `Polygon::new` keeps its rings, and the nested `{"polygon": {"polygon": ...}}` form deserializes.
#[test]
fn test_polygon_creation_and_deserialization() {
    #[derive(serde::Deserialize)]
    struct Wrapper {
        polygon: Polygon,
    }

    let square = vec![vec![(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]];
    assert_eq!(Polygon::new(square.clone()).rings, square);

    let legacy = serde_json::json!({
        "polygon": {
            "polygon": [[
                [0.0_f32, 0.0_f32],
                [1.0_f32, 0.0_f32],
                [1.0_f32, 1.0_f32]
            ]]
        }
    });
    let parsed: Wrapper = serde_json::from_value(legacy).unwrap();
    let ring_lengths: Vec<usize> = parsed.polygon.rings.iter().map(Vec::len).collect();
    assert_eq!(ring_lengths, vec![3]);
}
/// A blank sample has no id/name/size; populated fields are exposed by accessors.
#[test]
fn test_sample_construction_and_accessors() {
    let blank = Sample::new();
    assert_eq!(blank.id(), None);
    assert_eq!(blank.image_name(), None);
    assert_eq!(blank.width(), None);
    assert_eq!(blank.height(), None);

    let populated = Sample {
        image_name: Some("test.jpg".to_string()),
        width: Some(1920),
        height: Some(1080),
        group: Some("group1".to_string()),
        ..Sample::new()
    };
    assert_eq!(populated.image_name(), Some("test.jpg"));
    assert_eq!(populated.width(), Some(1920));
    assert_eq!(populated.height(), Some(1080));
    assert_eq!(populated.group(), Some(&"group1".to_string()));
}
/// `Sample::name` strips the extension and the `.camera` marker.
#[test]
fn test_sample_name_extraction_from_image_name() {
    let mut sample = Sample::new();
    for image_name in ["test_image.jpg", "test_image.camera.jpg", "test_image"] {
        sample.image_name = Some(image_name.to_string());
        assert_eq!(sample.name(), Some("test_image".to_string()));
    }
}
/// A fresh annotation is empty; setters are reflected by the matching getters.
#[test]
fn test_annotation_construction_and_setters() {
    let blank = Annotation::new();
    assert_eq!(blank.sample_id(), None);
    assert_eq!(blank.label(), None);
    assert_eq!(blank.box2d(), None);
    assert_eq!(blank.box3d(), None);
    assert_eq!(blank.polygon(), None);

    let mut car = Annotation::new();
    car.set_label(Some("car".to_string()));
    assert_eq!(car.label(), Some(&"car".to_string()));
    car.set_label_index(Some(42));
    assert_eq!(car.label_index(), Some(42));
    car.set_box2d(Some(Box2d::new(10.0, 20.0, 100.0, 50.0)));
    assert_eq!(car.box2d().map(|bbox| bbox.left()), Some(10.0));
}
/// URL-backed and filename-backed sample files expose only their own source.
#[test]
fn test_sample_file_with_url_and_filename() {
    let remote = SampleFile::with_url(
        "lidar.pcd".to_string(),
        "https://example.com/file.pcd".to_string(),
    );
    assert_eq!(remote.file_type(), "lidar.pcd");
    assert_eq!(remote.url(), Some("https://example.com/file.pcd"));
    assert_eq!(remote.filename(), None);

    let local = SampleFile::with_filename("image".to_string(), "test.jpg".to_string());
    assert_eq!(local.file_type(), "image");
    assert_eq!(local.filename(), Some("test.jpg"));
    assert_eq!(local.url(), None);
}
/// GPS and IMU sensor entries populate `location`; other sensor entries become files.
#[test]
fn test_sample_deserializes_gps_imu_from_sensors() {
    use serde_json::json;
    let sample: Sample = serde_json::from_value(json!({
        "id": 123,
        "image_name": "test.jpg",
        "sensors": [
            {"gps": {"lat": 37.7749, "lon": -122.4194}},
            {"imu": {"roll": 1.5, "pitch": 2.5, "yaw": 3.5}},
            {"radar.pcd": "https://example.com/radar.pcd"}
        ]
    }))
    .unwrap();

    let location = sample.location.as_ref().unwrap();
    let gps = location.gps.as_ref().unwrap();
    assert!((gps.lat - 37.7749).abs() < 0.0001);
    assert!((gps.lon - (-122.4194)).abs() < 0.0001);

    let imu = location.imu.as_ref().unwrap();
    assert!((imu.roll - 1.5).abs() < 0.0001);
    assert!((imu.pitch - 2.5).abs() < 0.0001);
    assert!((imu.yaw - 3.5).abs() < 0.0001);

    assert_eq!(sample.files.len(), 1);
    let radar = &sample.files[0];
    assert_eq!(radar.file_type(), "radar.pcd");
    assert_eq!(radar.url(), Some("https://example.com/radar.pcd"));
}
/// A GPS-only sensor list yields a location with GPS and no IMU.
#[test]
fn test_sample_deserializes_gps_only() {
    use serde_json::json;
    let sample: Sample = serde_json::from_value(json!({
        "id": 456,
        "sensors": [
            {"gps": {"lat": 40.7128, "lon": -74.0060}}
        ]
    }))
    .unwrap();

    let location = sample.location.as_ref().unwrap();
    assert!(location.imu.is_none());
    let gps = location.gps.as_ref().unwrap();
    assert!((gps.lat - 40.7128).abs() < 0.0001);
    assert!((gps.lon - (-74.0060)).abs() < 0.0001);
}
#[test]
fn test_sample_deserializes_without_location() {
    // Only file-type sensor entries are supplied: no location, two files.
    let sample: Sample = serde_json::from_value(serde_json::json!({
        "id": 789,
        "sensors": [
            {"radar.pcd": "https://example.com/radar.pcd"},
            {"lidar.pcd": "https://example.com/lidar.pcd"}
        ]
    }))
    .unwrap();

    assert!(sample.location.is_none());
    assert_eq!(sample.files.len(), 2);
}
#[test]
fn test_label_deserialization_and_accessors() {
    use serde_json::json;
    // Every payload below is well-formed, so parsing failure is a test bug.
    let parse = |payload: serde_json::Value| -> Label { serde_json::from_value(payload).unwrap() };

    let car = parse(json!({
        "id": 123,
        "dataset_id": 456,
        "index": 5,
        "name": "car"
    }));
    assert_eq!(car.id(), 123);
    assert_eq!(car.index(), 5);
    assert_eq!(car.name(), "car");
    // Both `to_string` and `{}` formatting render just the label name.
    assert_eq!(car.to_string(), "car");
    assert_eq!(format!("{}", car), "car");

    let person = parse(json!({
        "id": 1,
        "dataset_id": 100,
        "index": 0,
        "name": "person"
    }));
    assert_eq!(format!("{}", person), "person");
}
#[test]
fn test_annotation_serialization_with_mask_and_box() {
    // One triangular ring for the annotation's polygon.
    let triangle: Vec<Vec<(f32, f32)>> = vec![vec![(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)]];

    let mut annotation = Annotation::new();
    annotation.set_label(Some("test".to_string()));
    annotation.set_box2d(Some(Box2d::new(10.0, 20.0, 30.0, 40.0)));
    annotation.set_polygon(Some(Polygon::new(triangle)));

    let mut sample = Sample::new();
    sample.annotations.push(annotation);

    let serialized = serde_json::to_value(&sample).unwrap();
    // Indexing a missing key yields `Null`, whose `as_array()` is `None`.
    let annotations = serialized["annotations"]
        .as_array()
        .expect("annotations serialized as array");
    assert_eq!(annotations.len(), 1);

    let fields = annotations[0].as_object().expect("annotation object");
    assert!(fields.contains_key("box2d"));
    assert!(
        fields.contains_key("mask"),
        "Annotation must serialise polygon under 'mask' key for samples.populate2; got keys: {:?}",
        fields.keys().collect::<Vec<_>>()
    );
    // The polygon is emitted only under `mask`; neither `polygon` nor `x` may appear.
    assert!(!fields.contains_key("polygon"));
    assert!(!fields.contains_key("x"));
    assert!(fields.get("mask").is_some_and(|mask| mask.is_array()));
}
#[test]
fn test_frame_number_negative_one_deserializes_as_none() {
    // A frame_number of -1 must come out as `None`, not `Some(-1)`.
    let payload = r#"{
"uuid": "test-uuid",
"frame_number": -1
}"#;
    let parsed = serde_json::from_str::<Sample>(payload).unwrap();
    assert_eq!(parsed.frame_number, None);
}
#[test]
fn test_frame_number_positive_value_deserializes_correctly() {
    // An ordinary non-negative frame_number is kept verbatim.
    let payload = r#"{
"uuid": "test-uuid",
"frame_number": 5
}"#;
    let parsed = serde_json::from_str::<Sample>(payload).unwrap();
    assert_eq!(parsed.frame_number, Some(5));
}
#[test]
fn test_frame_number_null_deserializes_as_none() {
    // An explicit JSON null maps to `None`.
    let payload = r#"{
"uuid": "test-uuid",
"frame_number": null
}"#;
    let parsed = serde_json::from_str::<Sample>(payload).unwrap();
    assert_eq!(parsed.frame_number, None);
}
#[test]
fn test_frame_number_missing_deserializes_as_none() {
    // Omitting the field entirely also maps to `None`.
    let payload = r#"{
"uuid": "test-uuid"
}"#;
    let parsed = serde_json::from_str::<Sample>(payload).unwrap();
    assert_eq!(parsed.frame_number, None);
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_preserves_group_for_samples_without_annotations() {
    use polars::prelude::*;
    // Sample in "train" with a single named annotation.
    let mut sample_with_ann = Sample::new();
    sample_with_ann.image_name = Some("annotated.jpg".to_string());
    sample_with_ann.group = Some("train".to_string());
    let mut annotation = Annotation::new();
    annotation.set_label(Some("car".to_string()));
    annotation.set_box2d(Some(Box2d::new(0.1, 0.2, 0.3, 0.4)));
    annotation.set_name(Some("annotated".to_string()));
    sample_with_ann.annotations = vec![annotation];

    // Sample in "val" with no annotations at all.
    let mut sample_no_ann = Sample::new();
    sample_no_ann.image_name = Some("unannotated.jpg".to_string());
    sample_no_ann.group = Some("val".to_string());
    sample_no_ann.annotations = vec![];

    let samples = vec![sample_with_ann, sample_no_ann];
    let df = samples_dataframe(&samples).expect("Failed to create DataFrame");
    assert_eq!(df.height(), 2, "Expected 2 rows (one per sample)");

    let groups_col = df.column("group").expect("group column should exist");
    let groups_cast = groups_col.cast(&DataType::String).expect("cast to string");
    let groups = groups_cast.str().expect("as str");
    let names_col = df.column("name").expect("name column should exist");
    let names_cast = names_col.cast(&DataType::String).expect("cast to string");
    let names = names_cast.str().expect("as str");

    // Both rows must keep their own group, annotated or not.
    let mut found_annotated = false;
    let mut found_unannotated = false;
    for idx in 0..df.height() {
        match names.get(idx) {
            Some("annotated") => {
                found_annotated = true;
                assert_eq!(
                    groups.get(idx),
                    Some("train"),
                    "Sample 'annotated' with annotations must have group 'train'"
                );
            }
            Some("unannotated") => {
                found_unannotated = true;
                assert_eq!(
                    groups.get(idx),
                    Some("val"),
                    "CRITICAL: Sample 'unannotated' without annotations must have group 'val'"
                );
            }
            _ => {}
        }
    }
    assert!(
        found_annotated,
        "Did not find 'annotated' sample in DataFrame"
    );
    assert!(
        found_unannotated,
        "Did not find 'unannotated' sample in DataFrame - \
         this means samples without annotations are not being included"
    );
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_includes_all_samples_even_without_annotations() {
    // One annotated "train" sample plus unannotated "train" and "val" samples.
    let mut annotated = Sample::new();
    annotated.image_name = Some("with_ann.jpg".to_string());
    annotated.group = Some("train".to_string());
    let mut person = Annotation::new();
    person.set_label(Some("person".to_string()));
    person.set_box2d(Some(Box2d::new(0.0, 0.0, 0.5, 0.5)));
    person.set_name(Some("with_ann".to_string()));
    annotated.annotations = vec![person];

    let mut bare_train = Sample::new();
    bare_train.image_name = Some("no_ann_train.jpg".to_string());
    bare_train.group = Some("train".to_string());
    bare_train.annotations = vec![];

    let mut bare_val = Sample::new();
    bare_val.image_name = Some("no_ann_val.jpg".to_string());
    bare_val.group = Some("val".to_string());
    bare_val.annotations = vec![];

    let df = samples_dataframe(&vec![annotated, bare_train, bare_val])
        .expect("Failed to create DataFrame");
    assert_eq!(
        df.height(),
        3,
        "Expected 3 rows (samples without annotations should create one row each)"
    );

    let groups_cast = df
        .column("group")
        .expect("group column")
        .cast(&polars::prelude::DataType::String)
        .unwrap();
    let groups = groups_cast.str().unwrap();

    // Tally rows per group; any other value (including null) fails immediately.
    let (mut train_count, mut val_count) = (0, 0);
    for row in 0..df.height() {
        let group = groups.get(row);
        if group == Some("train") {
            train_count += 1;
        } else if group == Some("val") {
            val_count += 1;
        } else {
            panic!(
                "Unexpected group value at row {}: {:?}. \
                 All samples should have their group preserved.",
                row, group
            );
        }
    }
    assert_eq!(train_count, 2, "Expected 2 samples in 'train' group");
    assert_eq!(val_count, 1, "Expected 1 sample in 'val' group");
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_group_is_not_null_for_samples_with_group() {
    // An unannotated sample that still carries a group.
    let sample = Sample {
        image_name: Some("test.jpg".to_string()),
        group: Some("test_group".to_string()),
        annotations: vec![],
        ..Sample::new()
    };
    let df = samples_dataframe(&[sample]).expect("Failed to create DataFrame");
    let null_groups = df.column("group").expect("group column").null_count();
    assert_eq!(
        null_groups,
        0,
        "Sample with group='test_group' but no annotations has NULL group in DataFrame. \
         This is a bug in samples_dataframe - group must be preserved!"
    );
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_group_consistent_across_all_rows_for_same_image() {
    use polars::prelude::*;
    // Three annotations on one image; each should become its own row.
    let annotations: Vec<Annotation> = [
        ("car", Box2d::new(0.1, 0.2, 0.3, 0.4)),
        ("truck", Box2d::new(0.5, 0.6, 0.2, 0.2)),
        ("bus", Box2d::new(0.7, 0.8, 0.1, 0.1)),
    ]
    .into_iter()
    .map(|(label, bbox)| {
        let mut ann = Annotation::new();
        ann.set_label(Some(label.to_string()));
        ann.set_box2d(Some(bbox));
        ann.set_name(Some("multi_ann".to_string()));
        ann
    })
    .collect();

    let mut sample = Sample::new();
    sample.image_name = Some("multi_ann.jpg".to_string());
    sample.group = Some("train".to_string());
    sample.annotations = annotations;

    let df = samples_dataframe(&[sample]).expect("Failed to create DataFrame");
    assert_eq!(df.height(), 3, "Expected 3 rows (one per annotation)");

    let groups_col = df.column("group").expect("group column");
    let groups_cast = groups_col.cast(&DataType::String).expect("cast to string");
    let groups = groups_cast.str().expect("as str");
    assert_eq!(groups_col.null_count(), 0, "No rows should have null group");

    // The sample's group must be repeated on every one of its rows.
    for row in 0..df.height() {
        let group = groups.get(row);
        assert_eq!(
            group,
            Some("train"),
            "Row {} should have group 'train', got {:?}. \
             All rows for the same image must have identical group values.",
            row,
            group
        );
    }
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_lvis_columns() {
    // An annotation with LVIS attributes but no geometry.
    let mut ann = Annotation::new();
    ann.set_name(Some("test".to_string()));
    ann.set_label(Some("person".to_string()));
    ann.set_label_index(Some(1));
    ann.set_iscrowd(Some(false));
    ann.set_category_frequency(Some("f".to_string()));

    let sample = Sample {
        image_name: Some("test.jpg".to_string()),
        width: Some(640),
        height: Some(480),
        annotations: vec![ann],
        neg_label_indices: Some(vec![5, 12]),
        not_exhaustive_label_indices: Some(vec![3]),
        ..Default::default()
    };
    let df = samples_dataframe(&[sample]).unwrap();

    // LVIS columns carry data here, so they must be emitted.
    for column in [
        "iscrowd",
        "category_frequency",
        "neg_label_indices",
        "not_exhaustive_label_indices",
    ] {
        assert!(df.column(column).is_ok(), "{} column missing", column);
    }
    // No geometry was set, so these columns are entirely null and must be dropped.
    for column in ["polygon", "box2d"] {
        assert!(
            df.column(column).is_err(),
            "{} column should be dropped (all null)",
            column
        );
    }
}
#[test]
fn test_annotation_serialization_skips_lvis_fields() {
    // A fresh annotation has no LVIS attributes, so their keys must be absent.
    let serialized = serde_json::to_string(&Annotation::new()).unwrap();
    for field in ["iscrowd", "category_frequency"] {
        assert!(
            !serialized.contains(field),
            "{} should be omitted when None",
            field
        );
    }
}
#[test]
fn test_sample_serialization_skips_lvis_fields() {
    // A fresh sample has no LVIS label lists, so their keys must be absent.
    let serialized = serde_json::to_string(&Sample::new()).unwrap();
    for field in ["neg_label_indices", "not_exhaustive_label_indices"] {
        assert!(
            !serialized.contains(field),
            "{} should be omitted when None",
            field
        );
    }
}
#[test]
fn test_annotation_score_fields() {
    let mut ann = Annotation::default();
    // All per-geometry scores start unset.
    for score in [ann.box2d_score, ann.polygon_score, ann.mask_score] {
        assert!(score.is_none());
    }

    ann.box2d_score = Some(0.95);
    ann.polygon_score = Some(0.87);
    ann.mask_score = Some(0.42);
    // Each score field stores exactly what was assigned.
    assert_eq!(
        (ann.box2d_score, ann.polygon_score, ann.mask_score),
        (Some(0.95), Some(0.87), Some(0.42))
    );
}
#[test]
fn test_timing_struct() {
    let timing = Timing {
        load: Some(1_000_000),
        preprocess: Some(2_000_000),
        inference: Some(50_000_000),
        decode: Some(3_000_000),
    };
    // Every stage keeps the value it was constructed with.
    assert_eq!(timing.load, Some(1_000_000));
    assert_eq!(timing.preprocess, Some(2_000_000));
    assert_eq!(timing.inference, Some(50_000_000));
    assert_eq!(timing.decode, Some(3_000_000));

    // A default Timing reports no measurement for any stage.
    let default = Timing::default();
    assert!(default.load.is_none());
    assert!(default.preprocess.is_none());
    assert!(default.inference.is_none());
    assert!(default.decode.is_none());
}
#[test]
fn test_sample_timing() {
    let mut sample = Sample::default();
    assert!(sample.timing.is_none());

    sample.timing = Some(Timing {
        load: Some(1_000_000),
        ..Default::default()
    });
    // The stored Timing must round-trip: set field kept, others left at default.
    let timing = sample
        .timing
        .as_ref()
        .expect("timing should be set after assignment");
    assert_eq!(timing.load, Some(1_000_000));
    assert!(timing.preprocess.is_none());
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_polygon_column() {
    let triangle = Polygon::new(vec![vec![(0.1, 0.2), (0.3, 0.4), (0.5, 0.6)]]);
    let mut ann = Annotation::new();
    ann.set_name(Some("test".to_string()));
    ann.set_polygon(Some(triangle));

    let sample = Sample {
        image_name: Some("test.jpg".to_string()),
        annotations: vec![ann],
        ..Default::default()
    };
    let df = samples_dataframe(&[sample]).unwrap();
    assert!(df.column("polygon").is_ok(), "Should have polygon column");

    // The mask column is optional here. If it is emitted, it must hold binary
    // data rather than polygon coordinates.
    if let Some(mask_col) = df.column("mask").ok() {
        assert_eq!(
            mask_col.dtype(),
            &polars::prelude::DataType::Binary,
            "mask column must be Binary type (PNG bytes), not float list"
        );
    }
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_column_presence_drops_all_null() {
    // A sample with no annotations, timing or scores: optional columns are all null.
    let df = samples_dataframe(&[Sample {
        image_name: Some("test.jpg".to_string()),
        ..Default::default()
    }])
    .unwrap();
    assert!(df.column("name").is_ok(), "name column must always exist");

    let dropped = [
        ("polygon", "All-null polygon should be dropped"),
        ("box2d", "All-null box2d should be dropped"),
        ("box3d", "All-null box3d should be dropped"),
        ("mask", "All-null mask should be dropped"),
        ("box2d_score", "All-null score columns should be dropped"),
        ("timing", "All-null timing should be dropped"),
    ];
    for (column, reason) in dropped {
        assert!(df.column(column).is_err(), "{}", reason);
    }
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_score_columns() {
    // Box and polygon carry scores; box3d and mask are never set.
    let mut ann = Annotation::new();
    ann.set_name(Some("test".to_string()));
    ann.set_box2d(Some(Box2d::new(0.1, 0.2, 0.3, 0.4)));
    ann.set_box2d_score(Some(0.95));
    ann.set_polygon(Some(Polygon::new(vec![vec![(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)]])));
    ann.set_polygon_score(Some(0.87));

    let df = samples_dataframe(&[Sample {
        image_name: Some("test.jpg".to_string()),
        annotations: vec![ann],
        ..Default::default()
    }])
    .unwrap();

    for column in ["box2d_score", "polygon_score"] {
        assert!(df.column(column).is_ok(), "{} column missing", column);
    }
    for column in ["box3d_score", "mask_score"] {
        assert!(
            df.column(column).is_err(),
            "{} should be dropped (all null)",
            column
        );
    }

    let first_box_score = df.column("box2d_score").unwrap().f32().unwrap().get(0);
    assert_eq!(first_box_score, Some(0.95));
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_timing_column() {
    let mut ann = Annotation::new();
    ann.set_name(Some("test".to_string()));
    ann.set_label(Some("person".to_string()));

    // All four timing stages are populated.
    let timing = Timing {
        load: Some(1_000_000),
        preprocess: Some(2_000_000),
        inference: Some(50_000_000),
        decode: Some(3_000_000),
    };
    let sample = Sample {
        image_name: Some("test.jpg".to_string()),
        annotations: vec![ann],
        timing: Some(timing),
        ..Default::default()
    };
    let df = samples_dataframe(&[sample]).unwrap();

    let Ok(timing_col) = df.column("timing") else {
        panic!("timing column missing");
    };
    // Timing is emitted as a single nested struct column.
    assert!(
        matches!(timing_col.dtype(), polars::prelude::DataType::Struct(..)),
        "timing column should be Struct type, got {:?}",
        timing_col.dtype()
    );
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_mask_binary_column() {
    // Encode a 2x2 mask (the trailing 8 is passed through to `MaskData::encode`).
    let pixels: Vec<u8> = vec![0, 255, 128, 64];
    let encoded = MaskData::encode(&pixels, 2, 2, 8).unwrap();

    let mut ann = Annotation::new();
    ann.set_name(Some("test".to_string()));
    ann.set_mask(Some(encoded));

    let df = samples_dataframe(&[Sample {
        image_name: Some("test.jpg".to_string()),
        annotations: vec![ann],
        ..Default::default()
    }])
    .unwrap();

    let mask_col = df.column("mask").unwrap();
    assert_eq!(
        mask_col.dtype(),
        &polars::prelude::DataType::Binary,
        "mask column should be Binary"
    );
    assert_eq!(mask_col.null_count(), 0, "mask value should not be null");
}
#[test]
fn test_annotation_type_seg_alias() {
    // The string "seg" is accepted as an alias for the polygon annotation type.
    let parsed = AnnotationType::try_from("seg").unwrap();
    assert_eq!(
        parsed,
        AnnotationType::Polygon,
        "\"seg\" should map to Polygon for server round-trip"
    );
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_timing_partial() {
    let mut ann = Annotation::new();
    ann.set_name(Some("test".to_string()));
    ann.set_label(Some("person".to_string()));

    // Only the load stage is recorded; the rest stay at their defaults.
    let partial = Timing {
        load: Some(1000),
        ..Default::default()
    };
    let df = samples_dataframe(&[Sample {
        image_name: Some("test.jpg".to_string()),
        annotations: vec![ann],
        timing: Some(partial),
        ..Default::default()
    }])
    .unwrap();

    assert!(
        df.column("timing").is_ok(),
        "timing column should be present when partial data exists"
    );
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_timing_all_none_omitted() {
    let mut ann = Annotation::new();
    ann.set_name(Some("test".to_string()));
    ann.set_label(Some("person".to_string()));

    // Timing is explicitly absent on the only sample.
    let df = samples_dataframe(&[Sample {
        image_name: Some("test.jpg".to_string()),
        annotations: vec![ann],
        timing: None,
        ..Default::default()
    }])
    .unwrap();

    assert!(
        df.column("timing").is_err(),
        "timing column should be omitted when all samples have timing: None"
    );
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_score_zero_survives() {
    // A score of exactly 0.0 is a real value and must not be confused with null.
    let mut ann = Annotation::new();
    ann.set_name(Some("test".to_string()));
    ann.set_box2d(Some(Box2d::new(0.1, 0.2, 0.3, 0.4)));
    ann.set_box2d_score(Some(0.0));

    let df = samples_dataframe(&[Sample {
        image_name: Some("test.jpg".to_string()),
        annotations: vec![ann],
        ..Default::default()
    }])
    .unwrap();

    let first = df.column("box2d_score").unwrap().f32().unwrap().get(0);
    assert_eq!(first, Some(0.0), "score of 0.0 should survive as non-null");
}
#[cfg(feature = "polars")]
#[test]
fn test_samples_dataframe_score_one_survives() {
    // The upper-bound score of 1.0 must be stored unchanged and non-null.
    let mut ann = Annotation::new();
    ann.set_name(Some("test".to_string()));
    ann.set_box2d(Some(Box2d::new(0.1, 0.2, 0.3, 0.4)));
    ann.set_box2d_score(Some(1.0));

    let df = samples_dataframe(&[Sample {
        image_name: Some("test.jpg".to_string()),
        annotations: vec![ann],
        ..Default::default()
    }])
    .unwrap();

    let first = df.column("box2d_score").unwrap().f32().unwrap().get(0);
    assert_eq!(first, Some(1.0), "score of 1.0 should survive as non-null");
}
}