use ahash::HashMap;
use arrow2::{array::TryExtend, datatypes::DataType};
use arrow2_convert::{
deserialize::ArrowDeserialize, field::ArrowField, serialize::ArrowSerialize, ArrowDeserialize,
ArrowField, ArrowSerialize,
};
use re_types::components::{ClassId, KeypointId};
use crate::{LegacyClassId, LegacyColor, LegacyKeypointId, LegacyLabel};
/// Annotation metadata (an optional label and an optional color) attached to an `id`.
///
/// Used both for a whole class ([`ClassDescription::info`]) and for individual keypoints
/// (the values of [`ClassDescription::keypoint_map`]).
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, Default, PartialEq, Eq, ArrowField, ArrowSerialize, ArrowDeserialize)]
pub struct AnnotationInfo {
    /// The id this info describes.
    ///
    /// For keypoint entries, the Arrow deserialization of `ClassDescription` uses this value
    /// as the `keypoint_map` key, so it should equal the key the entry is stored under.
    pub id: u16,
    /// Optional label.
    pub label: Option<LegacyLabel>,
    /// Optional color.
    pub color: Option<LegacyColor>,
}
/// Description of one annotation class: its own info plus per-keypoint info and connections.
///
/// Note: the Arrow conversion (`ClassDescriptionArrow`) stores only the *values* of
/// `keypoint_map` and rebuilds the keys from each value's [`AnnotationInfo::id`]. An entry
/// whose key differs from its `id` is therefore re-keyed by a round-trip, and entries sharing
/// the same `id` collapse into one.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ClassDescription {
    /// Info (id, label, color) for the class itself.
    pub info: AnnotationInfo,
    /// Per-keypoint info keyed by keypoint id; each key should equal its value's `id`.
    pub keypoint_map: HashMap<KeypointId, AnnotationInfo>,
    /// Pairs of keypoint ids listed as connections.
    pub keypoint_connections: Vec<(KeypointId, KeypointId)>,
}
/// Arrow-serializable form of one `(KeypointId, KeypointId)` entry of
/// [`ClassDescription::keypoint_connections`].
#[derive(ArrowField, ArrowSerialize, ArrowDeserialize)]
struct KeypointPairArrow {
    /// First keypoint of the pair.
    keypoint0: LegacyKeypointId,
    /// Second keypoint of the pair.
    keypoint1: LegacyKeypointId,
}
/// Arrow-serializable form of [`ClassDescription`].
///
/// `keypoint_map` is flattened to a list of the map's values. The keys are dropped and later
/// recovered from each [`AnnotationInfo::id`] when converting back. Each connection tuple
/// becomes a [`KeypointPairArrow`] struct.
#[derive(ArrowField, ArrowSerialize, ArrowDeserialize)]
struct ClassDescriptionArrow {
    /// Info for the class itself.
    info: AnnotationInfo,
    /// Values of `ClassDescription::keypoint_map`, without their keys.
    keypoint_map: Vec<AnnotationInfo>,
    /// Connections as structs instead of tuples.
    keypoint_connections: Vec<KeypointPairArrow>,
}
impl From<&ClassDescription> for ClassDescriptionArrow {
fn from(v: &ClassDescription) -> Self {
ClassDescriptionArrow {
info: v.info.clone(),
keypoint_map: v.keypoint_map.values().cloned().collect(),
keypoint_connections: v
.keypoint_connections
.iter()
.map(|(k0, k1)| KeypointPairArrow {
keypoint0: (*k0).into(),
keypoint1: (*k1).into(),
})
.collect(),
}
}
}
impl From<ClassDescriptionArrow> for ClassDescription {
fn from(v: ClassDescriptionArrow) -> Self {
ClassDescription {
info: v.info,
keypoint_map: v
.keypoint_map
.into_iter()
.map(|elem| (KeypointId(elem.id), elem))
.collect(),
keypoint_connections: v
.keypoint_connections
.into_iter()
.map(|elem| (elem.keypoint0.into(), elem.keypoint1.into()))
.collect(),
}
}
}
/// Per-class annotation metadata, keyed by [`ClassId`].
///
/// Serialized to Arrow as a list of `(class_id, class_description)` records
/// (`ClassMapElemArrow`). The `HashMap` is rebuilt on deserialization.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct AnnotationContext {
    /// Description of each class, keyed by its class id.
    pub class_map: HashMap<ClassId, ClassDescription>,
}
/// Exposes [`AnnotationContext`] as a legacy component with a fixed component name.
impl re_log_types::LegacyComponent for AnnotationContext {
    /// The legacy component name: `"rerun.annotation_context"`.
    #[inline]
    fn legacy_name() -> re_log_types::ComponentName {
        "rerun.annotation_context".into()
    }
}
/// One `(class_id, class_description)` entry of [`AnnotationContext::class_map`], in Arrow form.
// NOTE(review): presumably `pub` because the array types derived for it surface in the public
// `ArrowSerialize`/`ArrowDeserialize` impls of `AnnotationContext` — confirm before narrowing.
#[derive(ArrowField, ArrowSerialize, ArrowDeserialize)]
pub struct ClassMapElemArrow {
    /// The map key.
    class_id: LegacyClassId,
    /// The map value, in its flattened Arrow form.
    class_description: ClassDescriptionArrow,
}
/// Arrow representation of a whole [`AnnotationContext`]: its class map as a list of entries.
type AnnotationContextArrow = Vec<ClassMapElemArrow>;
impl From<&AnnotationContext> for AnnotationContextArrow {
#[inline]
fn from(v: &AnnotationContext) -> Self {
v.class_map
.iter()
.map(|(class_id, class_description)| ClassMapElemArrow {
class_id: (*class_id).into(),
class_description: class_description.into(),
})
.collect()
}
}
impl From<Vec<ClassMapElemArrow>> for AnnotationContext {
#[inline]
fn from(v: AnnotationContextArrow) -> Self {
AnnotationContext {
class_map: v
.into_iter()
.map(|elem| (elem.class_id.into(), elem.class_description.into()))
.collect(),
}
}
}
impl ArrowField for AnnotationContext {
type Type = Self;
#[inline]
fn data_type() -> DataType {
<AnnotationContextArrow as ArrowField>::data_type()
}
}
impl ArrowSerialize for AnnotationContext {
    type MutableArrayType = <AnnotationContextArrow as ArrowSerialize>::MutableArrayType;

    /// Creates an empty mutable array of the flattened representation's type.
    #[inline]
    fn new_array() -> Self::MutableArrayType {
        <AnnotationContextArrow as ArrowSerialize>::new_array()
    }

    /// Appends `v` as a single list entry.
    ///
    /// The flattened records are pushed into the child array first, and then the list slot
    /// itself is marked valid.
    #[inline]
    fn arrow_serialize(v: &Self, array: &mut Self::MutableArrayType) -> arrow2::error::Result<()> {
        let records = AnnotationContextArrow::from(v);
        let values = array.mut_values();
        values.try_extend(records.iter().map(Some))?;
        array.try_push_valid()
    }
}
impl ArrowDeserialize for AnnotationContext {
type ArrayType = <AnnotationContextArrow as ArrowDeserialize>::ArrayType;
#[inline]
fn arrow_deserialize(
v: <&Self::ArrayType as IntoIterator>::Item,
) -> Option<<Self as ArrowField>::Type> {
let v = <AnnotationContextArrow as ArrowDeserialize>::arrow_deserialize(v);
v.map(|v| v.into())
}
}
// NOTE(review): presumably generates the glue that lets this arrow2_convert-based legacy
// component be used through re_log_types' component traits — confirm against the macro.
re_log_types::component_legacy_shim!(AnnotationContext);
/// Round-trips [`AnnotationContext`]s through Arrow and checks that nothing is lost.
///
/// The test covers three cases:
/// - a fully populated class (labels, colors, keypoints, connections);
/// - a class with no keypoints or connections;
/// - a completely empty context.
///
/// Together they exercise both populated and empty lists at every nesting level.
#[test]
fn test_context_roundtrip() {
    use arrow2::array::Array;
    use arrow2_convert::{deserialize::TryIntoCollection, serialize::TryIntoArrow};

    let keypoint_map = vec![
        (
            KeypointId(43),
            AnnotationInfo {
                id: 43,
                label: Some(LegacyLabel("head".to_owned())),
                color: None,
            },
        ),
        (
            KeypointId(94),
            AnnotationInfo {
                id: 94,
                label: Some(LegacyLabel("leg".to_owned())),
                color: Some(LegacyColor(0x654321)),
            },
        ),
    ]
    .into_iter()
    .collect();

    let full_class = ClassDescription {
        info: AnnotationInfo {
            id: 32,
            label: Some(LegacyLabel("hello".to_owned())),
            color: Some(LegacyColor(0x123456)),
        },
        keypoint_map,
        keypoint_connections: vec![(KeypointId(43), KeypointId(94))],
    };

    // A class that only carries its own info: empty keypoint map and empty connection list.
    let bare_class = ClassDescription {
        info: AnnotationInfo {
            id: 7,
            label: None,
            color: Some(LegacyColor(0x00ff00)),
        },
        ..Default::default()
    };

    let context = AnnotationContext {
        class_map: vec![(ClassId(13), full_class), (ClassId(7), bare_class)]
            .into_iter()
            .collect(),
    };

    // The default context has an empty class map, i.e. an empty top-level list.
    let context_in = vec![context, AnnotationContext::default()];
    let array: Box<dyn Array> = context_in.try_into_arrow().unwrap();
    let context_out: Vec<AnnotationContext> =
        TryIntoCollection::try_into_collection(array).unwrap();
    assert_eq!(context_in, context_out);
}