use std::path::{Path, PathBuf};
use std::sync::Arc;
use ahash::{HashMap, HashMapExt as _, HashSet, HashSetExt as _};
use anyhow::{Context as _, bail};
use arrow::array::{Array as _, Float64Array, ListArray, StringArray, StructArray};
use arrow::buffer::OffsetBuffer;
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field, Fields};
use itertools::Itertools as _;
use re_arrow_util::ArrowArrayDowncastRef as _;
use re_chunk::EntityPath;
use re_log_types::EntityPathPart;
use re_sdk_types::Loggable as _;
use re_sdk_types::archetypes::Transform3D;
use re_sdk_types::datatypes::{Quaternion, Vec3D};
use urdf_rs::{Geometry, Joint, Link, Material, Robot};
use super::joint_transform;
/// Entity path that static transforms are logged to unless overridden.
///
/// It is used as-is (it is not placed under the robot's entity path prefix) and can be replaced
/// per tree via `UrdfTree::with_static_transform_entity`.
const DEFAULT_TF_STATIC_ENTITY_PATH: &str = "tf_static";
/// Returns a short lowercase label describing which kind of URDF geometry `geometry` is.
///
/// The label is used as an entity path component to group collision geometries by type.
fn geometry_type_name(geometry: &Geometry) -> &'static str {
    match geometry {
        Geometry::Box { .. } => "box",
        Geometry::Capsule { .. } => "capsule",
        Geometry::Cylinder { .. } => "cylinder",
        Geometry::Mesh { .. } => "mesh",
        Geometry::Sphere { .. } => "sphere",
    }
}
/// Entity paths under which the different parts of a URDF robot are logged.
#[derive(Clone, Debug)]
pub(crate) struct UrdfLogPaths {
    /// Root entity of the robot: the robot name, optionally below an entity path prefix.
    pub root: EntityPath,

    /// Parent entity for all visual geometries (`<root>/visual_geometries`).
    pub visual_root: EntityPath,

    /// Parent entity for all collision geometries (`<root>/collision_geometries`).
    pub collision_root: EntityPath,

    /// Entity that static transforms are logged to (defaults to `tf_static`).
    pub transforms: EntityPath,
}
impl UrdfLogPaths {
    /// Derives the logging entity paths for a robot called `robot_name`.
    ///
    /// The robot root is `robot_name` as a single path part (via
    /// `EntityPath::from_single_string`), appended to `entity_path_prefix` when one is given.
    /// Visual and collision geometries live under `visual_geometries` and
    /// `collision_geometries` below that root. The static-transform entity ignores the prefix
    /// and always starts out as `DEFAULT_TF_STATIC_ENTITY_PATH`.
    pub fn new(robot_name: &str, entity_path_prefix: Option<EntityPath>) -> Self {
        let robot_part = EntityPath::from_single_string(robot_name);
        let root = if let Some(prefix) = entity_path_prefix {
            prefix / robot_part
        } else {
            robot_part
        };

        Self {
            visual_root: root.clone() / EntityPathPart::new("visual_geometries"),
            collision_root: root.clone() / EntityPathPart::new("collision_geometries"),
            transforms: EntityPath::from_single_string(DEFAULT_TF_STATIC_ENTITY_PATH),
            root,
        }
    }
}
/// A parsed URDF robot description, indexed for traversal and logging.
///
/// Construction (see `UrdfTree::new`) requires every joint to reference existing links and
/// exactly one root link, i.e. exactly one link that is not the child of any joint.
pub struct UrdfTree {
    /// Directory that contained the URDF file, if it was loaded from disk.
    // NOTE(review): presumably used to resolve relative mesh paths — confirm against callers.
    pub(crate) urdf_dir: Option<PathBuf>,
    /// Entity paths used when logging this robot.
    pub(crate) log_paths: UrdfLogPaths,
    /// Robot name, taken from the parsed `Robot`.
    name: String,
    /// The unique link that is not the child of any joint.
    root: Link,
    /// All joints, in the order provided by the parsed `Robot`.
    joints: Vec<Joint>,
    /// All links, keyed by link name.
    links: HashMap<String, Link>,
    /// Joints grouped by the name of their parent link.
    children: HashMap<String, Vec<Joint>>,
    /// Materials declared on the robot, keyed by material name.
    materials: HashMap<String, Material>,
    /// Optional prefix prepended to frame names (see `UrdfTree::apply_frame_prefix`).
    frame_prefix: Option<String>,
}
impl UrdfTree {
    /// Reads and parses the URDF file at `path` and builds a tree from it.
    ///
    /// The directory containing `path` is remembered as the tree's `urdf_dir`.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or parsed, or if [`UrdfTree::new`] rejects
    /// the robot description.
    pub fn from_file_path<P: AsRef<Path>>(
        path: P,
        entity_path_prefix: Option<EntityPath>,
    ) -> anyhow::Result<Self> {
        let path = path.as_ref();
        let robot = urdf_rs::read_file(path)?;
        let urdf_dir = path.parent().map(Path::to_path_buf);
        Self::new(robot, urdf_dir, entity_path_prefix)
    }

    /// Builds a tree from an already-parsed [`Robot`].
    ///
    /// `urdf_dir` is the directory the URDF was loaded from, if any. `entity_path_prefix` is an
    /// optional entity path that the robot's entities are placed under.
    ///
    /// # Errors
    /// Returns an error if a joint references a link that does not exist, or if there is not
    /// exactly one root link (a link that is not the child of any joint).
    pub fn new(
        robot: Robot,
        urdf_dir: Option<PathBuf>,
        entity_path_prefix: Option<EntityPath>,
    ) -> anyhow::Result<Self> {
        let urdf_rs::Robot {
            name,
            links,
            joints,
            materials,
        } = robot;

        let materials = materials
            .into_iter()
            .map(|material| (material.name.clone(), material))
            .collect::<HashMap<_, _>>();

        let links: HashMap<String, Link> = links
            .into_iter()
            .map(|link| (link.name.clone(), link))
            .collect();

        // Validate joint references before deriving the tree shape. A joint pointing at a
        // missing link would otherwise surface as a misleading "no root" / "multiple roots"
        // error instead of naming the offending joint.
        for joint in &joints {
            if !links.contains_key(&joint.child.link) {
                bail!(
                    "Joint '{}' references unknown child link '{}'",
                    joint.name,
                    joint.child.link
                );
            }
            if !links.contains_key(&joint.parent.link) {
                bail!(
                    "Joint '{}' references unknown parent link '{}'",
                    joint.name,
                    joint.parent.link
                );
            }
        }

        let mut children = HashMap::<String, Vec<Joint>>::new();
        let mut child_links = HashSet::<String>::new();
        for joint in &joints {
            children
                .entry(joint.parent.link.clone())
                .or_default()
                .push(joint.clone());
            child_links.insert(joint.child.link.clone());
        }

        // The root is the only link that never appears as a joint's child.
        let roots = links
            .iter()
            .filter(|(name, _)| !child_links.contains(*name))
            .map(|(_, link)| link)
            .collect_vec();
        let root = match roots.as_slice() {
            [] => {
                bail!("No root link found in URDF");
            }
            [root] => Link::clone(root),
            _ => {
                // Sorted so the message does not depend on hash-map iteration order.
                let mut root_names = roots.iter().map(|link| link.name.as_str()).collect_vec();
                root_names.sort_unstable();
                bail!("Multiple roots in URDF: {}", root_names.join(", "));
            }
        };

        let log_paths = UrdfLogPaths::new(&name, entity_path_prefix);
        Ok(Self {
            urdf_dir,
            name,
            root,
            joints,
            links,
            children,
            materials,
            log_paths,
            frame_prefix: None,
        })
    }

    /// Sets a prefix (e.g. `"left_arm/"`) that [`UrdfTree::apply_frame_prefix`] prepends to
    /// frame names.
    pub fn with_frame_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.frame_prefix = Some(prefix.into());
        self
    }

    /// Overrides the entity path that static transforms are logged to.
    pub fn with_static_transform_entity(mut self, entity_path: impl Into<EntityPath>) -> Self {
        self.log_paths.transforms = entity_path.into();
        self
    }

    /// Returns the robot name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the frame prefix, if one was set with [`UrdfTree::with_frame_prefix`].
    pub fn frame_prefix(&self) -> Option<&str> {
        self.frame_prefix.as_deref()
    }

    /// Returns `name` with the frame prefix prepended, or `name` unchanged when no prefix is set.
    pub fn apply_frame_prefix(&self, name: &str) -> String {
        match &self.frame_prefix {
            Some(prefix) => format!("{prefix}{name}"),
            None => name.to_owned(),
        }
    }

    /// Computes the transform of `joint` at joint position `value`.
    ///
    /// `value` and `clamp` are forwarded to the joint-transform computation unchanged.
    /// Presumably `clamp` restricts `value` to the joint limits; see the `joint_transform`
    /// module. Any warning produced by the computation is logged. The result carries the
    /// translation, the rotation, and the parent and child frame names with the frame prefix
    /// applied.
    ///
    /// # Errors
    /// Propagates errors from the joint-transform computation.
    pub fn compute_joint_transform(
        &self,
        joint: &Joint,
        value: f64,
        clamp: bool,
    ) -> Result<Transform3D, joint_transform::Error> {
        let result = joint_transform::internal::compute_joint_transform(joint, value, clamp)?;
        if let Some(warning) = &result.warning {
            re_log::warn!("{warning}");
        }
        Ok(Transform3D::update_fields()
            .with_translation(result.translation.to_array())
            .with_quaternion(result.quaternion.to_array())
            .with_parent_frame(self.apply_frame_prefix(&result.parent_frame))
            .with_child_frame(self.apply_frame_prefix(&result.child_frame)))
    }

    /// Computes joint transforms for a batch of rows.
    ///
    /// `names` must be a `list<utf8>` of joint names. `values` is a list of numeric joint
    /// positions, which are cast to `float64`. Both must have the same number of rows and the
    /// same list length in every row. The result has one row per input row. Each row is a list
    /// of structs with the fields `translation`, `quaternion`, `parent_frame` and
    /// `child_frame`, one per joint in that row. Frame names have the frame prefix applied.
    ///
    /// # Errors
    /// Returns an error if:
    /// - the row counts or per-row list lengths differ;
    /// - a row or value is null;
    /// - the names are not utf8;
    /// - the values cannot be cast to `float64`;
    /// - a joint name is unknown;
    /// - a joint-transform computation fails.
    pub fn compute_joint_transform_batches(
        &self,
        names: &ListArray,
        values: &ListArray,
        clamp: bool,
    ) -> anyhow::Result<ListArray> {
        if names.len() != values.len() {
            bail!(
                "joint name and value arrays must have the same number of rows, got {} and {}",
                names.len(),
                values.len()
            );
        }

        let joint_names = names
            .values()
            .try_downcast_array_ref::<StringArray>()
            .with_context(|| {
                format!(
                    "joint names must be list<utf8>, got {:?}",
                    names.data_type()
                )
            })?;
        let joint_values = cast(values.values().as_ref(), &DataType::Float64)
            .context("failed to cast joint values to float64")?
            .try_downcast_array::<Float64Array>()?;

        // Resolve names through a map built once per call, instead of scanning every joint for
        // every value. `or_insert` keeps the first joint of a duplicated name, which matches
        // `get_joint_by_name`.
        let mut joints_by_name = HashMap::<&str, &Joint>::with_capacity(self.joints.len());
        for joint in &self.joints {
            joints_by_name.entry(joint.name.as_str()).or_insert(joint);
        }

        // List offsets index directly into the child arrays returned by `values()`.
        let name_offsets = names.value_offsets();
        let value_offsets = values.value_offsets();

        let mut offsets = Vec::<i32>::with_capacity(names.len() + 1);
        let mut translations = Vec::<Vec3D>::new();
        let mut quaternions = Vec::<Quaternion>::new();
        let mut parent_frames = Vec::<String>::new();
        let mut child_frames = Vec::<String>::new();
        offsets.push(0);

        for row in 0..names.len() {
            if names.is_null(row) || values.is_null(row) {
                bail!(
                    "joint name and value lists must not contain null rows, got null at row {row}"
                );
            }

            // Offsets of a valid Arrow list array are non-negative, so `as usize` is lossless.
            let names_start = name_offsets[row] as usize;
            let names_end = name_offsets[row + 1] as usize;
            let values_start = value_offsets[row] as usize;
            let values_end = value_offsets[row + 1] as usize;
            let num_names = names_end - names_start;
            let num_values = values_end - values_start;
            if num_names != num_values {
                bail!(
                    "joint name and value lists must have the same length at row {row}, got {num_names} and {num_values}"
                );
            }

            for (name_index, value_index) in (names_start..names_end).zip(values_start..values_end)
            {
                if joint_names.is_null(name_index) || joint_values.is_null(value_index) {
                    bail!(
                        "joint name and value lists must not contain null values, got null at row {row}"
                    );
                }

                let joint_name = joint_names.value(name_index);
                let joint = joints_by_name
                    .get(joint_name)
                    .copied()
                    .with_context(|| format!("URDF does not contain joint {joint_name:?}"))?;
                let result = joint_transform::internal::compute_joint_transform(
                    joint,
                    joint_values.value(value_index),
                    clamp,
                )?;
                if let Some(warning) = &result.warning {
                    re_log::warn!("{warning}");
                }

                translations.push(Vec3D([
                    result.translation.x,
                    result.translation.y,
                    result.translation.z,
                ]));
                quaternions.push(Quaternion([
                    result.quaternion.x,
                    result.quaternion.y,
                    result.quaternion.z,
                    result.quaternion.w,
                ]));
                parent_frames.push(self.apply_frame_prefix(&result.parent_frame));
                child_frames.push(self.apply_frame_prefix(&result.child_frame));
            }

            offsets.push(
                i32::try_from(parent_frames.len())
                    .context("too many joint transforms for Arrow list offsets")?,
            );
        }

        let translation_array = Vec3D::to_arrow_opt(translations.iter().map(Some))?;
        let quaternion_array = Quaternion::to_arrow_opt(quaternions.iter().map(Some))?;
        let struct_fields = Fields::from(vec![
            Field::new("translation", translation_array.data_type().clone(), false),
            Field::new("quaternion", quaternion_array.data_type().clone(), false),
            Field::new("parent_frame", DataType::Utf8, false),
            Field::new("child_frame", DataType::Utf8, false),
        ]);
        let struct_array = StructArray::try_new(
            struct_fields.clone(),
            vec![
                translation_array,
                quaternion_array,
                Arc::new(StringArray::from(parent_frames)),
                Arc::new(StringArray::from(child_frames)),
            ],
            None,
        )?;

        Ok(ListArray::new(
            Arc::new(Field::new_list_field(
                DataType::Struct(struct_fields),
                false,
            )),
            OffsetBuffer::new(offsets.into()),
            Arc::new(struct_array),
            None,
        ))
    }

    /// Returns the root link, which is the unique link that is not the child of any joint.
    pub fn root(&self) -> &Link {
        &self.root
    }

    /// Iterates over all joints in the order they were given to [`UrdfTree::new`].
    pub fn joints(&self) -> impl Iterator<Item = &Joint> {
        self.joints.iter()
    }

    /// Returns the first joint called `joint_name`, if any.
    pub fn get_joint_by_name(&self, joint_name: &str) -> Option<&Joint> {
        self.joints.iter().find(|j| j.name == joint_name)
    }

    /// Returns the link called `link_name`, if any.
    pub fn get_link(&self, link_name: &str) -> Option<&Link> {
        self.links.get(link_name)
    }

    /// Returns the joints whose parent is `link_name`, or `None` if that link has no children.
    pub fn get_children(&self, link_name: &str) -> Option<&Vec<Joint>> {
        self.children.get(link_name)
    }

    /// Returns the visual elements of `link` together with the entity path each is logged to.
    ///
    /// Paths have the form `<visual root>/<link name>/<visual name>`. Unnamed visuals are called
    /// `visual_{index}`. The link is looked up by name in this tree. Returns `None` if the link
    /// is not part of this tree or has no visual elements.
    pub fn get_visual_geometries(
        &self,
        link: &Link,
    ) -> Option<Vec<(EntityPath, &urdf_rs::Visual)>> {
        let link = self.links.get(&link.name)?;
        if link.visual.is_empty() {
            return None;
        }

        let visual_base_path_for_link =
            self.log_paths.visual_root.clone() / EntityPathPart::new(&link.name);
        let geometries = link
            .visual
            .iter()
            .enumerate()
            .map(|(i, visual)| {
                let visual_name = visual.name.clone().unwrap_or_else(|| format!("visual_{i}"));
                (
                    visual_base_path_for_link.clone() / EntityPathPart::new(visual_name),
                    visual,
                )
            })
            .collect::<Vec<_>>();
        Some(geometries)
    }

    /// Returns the collision elements of `link` together with the entity path each is logged to.
    ///
    /// Paths have the form `<collision root>/<geometry type>/<link name>/<collision name>`.
    /// Unnamed elements are called `collision_{index}`. The link is looked up by name in this
    /// tree. Returns `None` if the link is not part of this tree or has no collision elements.
    pub fn get_collision_geometries(
        &self,
        link: &Link,
    ) -> Option<Vec<(EntityPath, &urdf_rs::Collision)>> {
        let link = self.links.get(&link.name)?;
        if link.collision.is_empty() {
            return None;
        }

        let geometries = link
            .collision
            .iter()
            .enumerate()
            .map(|(i, collision)| {
                let geometry_type = geometry_type_name(&collision.geometry);
                let collision_name = collision
                    .name
                    .clone()
                    .unwrap_or_else(|| format!("collision_{i}"));
                let path = self.log_paths.collision_root.clone()
                    / EntityPathPart::new(geometry_type)
                    / EntityPathPart::new(&link.name)
                    / EntityPathPart::new(collision_name);
                (path, collision)
            })
            .collect::<Vec<_>>();
        Some(geometries)
    }

    /// Returns the child link of `joint`.
    ///
    /// # Panics
    /// Panics if the joint's child link is not part of this tree. This cannot happen for the
    /// tree's own joints, because [`UrdfTree::new`] rejects joints that reference unknown links.
    pub fn get_joint_child(&self, joint: &Joint) -> &Link {
        &self.links[&joint.child.link]
    }

    /// Returns the robot-level material called `name`, if any.
    pub(crate) fn get_material(&self, name: &str) -> Option<&Material> {
        self.materials.get(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a link called `name` with default inertial data and no visual or collision
    /// elements.
    fn make_minimal_link(name: &str) -> urdf_rs::Link {
        urdf_rs::Link {
            name: name.to_owned(),
            inertial: Default::default(),
            visual: Vec::new(),
            collision: Vec::new(),
        }
    }

    /// Builds a tree for a robot called "test" that has a single "base" link and no joints.
    fn single_link_tree() -> UrdfTree {
        let robot = urdf_rs::Robot {
            name: "test".to_owned(),
            links: vec![make_minimal_link("base")],
            joints: Vec::new(),
            materials: Vec::new(),
        };
        UrdfTree::new(robot, None, None).expect("a single-link robot forms a valid tree")
    }

    #[test]
    fn test_apply_frame_prefix_without_prefix() {
        let tree = single_link_tree();
        assert!(tree.frame_prefix().is_none());
        assert_eq!(tree.apply_frame_prefix("base"), "base");
    }

    #[test]
    fn test_apply_frame_prefix_with_prefix() {
        let tree = single_link_tree().with_frame_prefix("left_arm/");
        assert_eq!(tree.frame_prefix(), Some("left_arm/"));
        assert_eq!(tree.apply_frame_prefix("base"), "left_arm/base");
    }

    #[test]
    fn test_with_static_transform_entity_overrides_default_path() {
        let tree = single_link_tree().with_static_transform_entity("robot/tf");
        assert_eq!(tree.log_paths.transforms.to_string(), "/robot/tf");
    }

    #[test]
    fn test_static_transforms_path_defaults_to_tf_static() {
        let paths = UrdfLogPaths::new("test", None);
        assert_eq!(paths.transforms.to_string(), "/tf_static");
    }

    #[test]
    fn test_static_transforms_path_is_unaffected_by_prefix() {
        let prefix = EntityPath::parse_forgiving("robots/left_arm");
        let paths = UrdfLogPaths::new("test", Some(prefix));
        assert_eq!(paths.transforms.to_string(), "/tf_static");
    }
}