import math
from collections import defaultdict
from pathlib import Path
from typing import Any

from mcap.reader import make_reader

from edgefirst.schemas import (
    edgefirst_msgs,
    foxglove_msgs,
    geometry_msgs,
    sensor_msgs,
)
# Lookup from the schema name stored in an MCAP file to the edgefirst.schemas
# class whose ``deserialize`` decodes payloads of that type. Any schema name not
# listed here is treated as unsupported by the tests below. Grouped by package.
SCHEMA_MAP: dict[str, type] = {
    # --- sensor_msgs -------------------------------------------------------
    "sensor_msgs/msg/CameraInfo": sensor_msgs.CameraInfo,
    "sensor_msgs/msg/CompressedImage": sensor_msgs.CompressedImage,
    "sensor_msgs/msg/Image": sensor_msgs.Image,
    "sensor_msgs/msg/Imu": sensor_msgs.Imu,
    "sensor_msgs/msg/NavSatFix": sensor_msgs.NavSatFix,
    "sensor_msgs/msg/PointCloud2": sensor_msgs.PointCloud2,
    "sensor_msgs/msg/PointField": sensor_msgs.PointField,
    # --- geometry_msgs -----------------------------------------------------
    "geometry_msgs/msg/Transform": geometry_msgs.Transform,
    "geometry_msgs/msg/TransformStamped": geometry_msgs.TransformStamped,
    "geometry_msgs/msg/Vector3": geometry_msgs.Vector3,
    "geometry_msgs/msg/Quaternion": geometry_msgs.Quaternion,
    "geometry_msgs/msg/Pose": geometry_msgs.Pose,
    "geometry_msgs/msg/PoseStamped": geometry_msgs.PoseStamped,
    "geometry_msgs/msg/Point": geometry_msgs.Point,
    "geometry_msgs/msg/Twist": geometry_msgs.Twist,
    "geometry_msgs/msg/TwistStamped": geometry_msgs.TwistStamped,
    # --- foxglove_msgs -----------------------------------------------------
    "foxglove_msgs/msg/CompressedVideo": foxglove_msgs.CompressedVideo,
    "foxglove_msgs/msg/CompressedImage": foxglove_msgs.CompressedImage,
    "foxglove_msgs/msg/FrameTransform": foxglove_msgs.FrameTransform,
    "foxglove_msgs/msg/LocationFix": foxglove_msgs.LocationFix,
    "foxglove_msgs/msg/Log": foxglove_msgs.Log,
    "foxglove_msgs/msg/PointCloud": foxglove_msgs.PointCloud,
    "foxglove_msgs/msg/RawImage": foxglove_msgs.RawImage,
    # --- edgefirst_msgs ----------------------------------------------------
    "edgefirst_msgs/msg/Box": edgefirst_msgs.Box,
    "edgefirst_msgs/msg/Detect": edgefirst_msgs.Detect,
    "edgefirst_msgs/msg/DmaBuffer": edgefirst_msgs.DmaBuffer,
    "edgefirst_msgs/msg/Mask": edgefirst_msgs.Mask,
    "edgefirst_msgs/msg/ModelInfo": edgefirst_msgs.ModelInfo,
    "edgefirst_msgs/msg/RadarCube": edgefirst_msgs.RadarCube,
    "edgefirst_msgs/msg/RadarInfo": edgefirst_msgs.RadarInfo,
    "edgefirst_msgs/msg/Track": edgefirst_msgs.Track,
}
def get_mcap_summary(mcap_path: Path) -> dict[str, Any]:
    """Read the summary section of an MCAP file into plain lookup tables.

    Returns a dict with three keys:

    * ``"schemas"``: schema name -> schema record.
    * ``"channels"``: topic -> channel record.
    * ``"statistics"``: ``summary.statistics`` as reported by the reader.

    When the reader reports no summary, both mappings are empty and
    ``"statistics"`` is ``None``. If several schemas share a name, or several
    channels share a topic, the one iterated last wins.
    """
    with open(mcap_path, "rb") as stream:
        summary = make_reader(stream).get_summary()
        if not summary:
            return {"schemas": {}, "channels": {}, "statistics": None}

        schemas_by_name = {rec.name: rec for rec in summary.schemas.values()}
        channels_by_topic = {rec.topic: rec for rec in summary.channels.values()}
        return {
            "schemas": schemas_by_name,
            "channels": channels_by_topic,
            "statistics": summary.statistics,
        }
def iter_mcap_messages(mcap_path: Path):
    """Yield ``(schema, channel, message)`` for each message in an MCAP file.

    This is a generator. The file is opened on the first ``next()`` and closed
    when the generator is exhausted or closed. Callers in this module treat a
    falsy ``schema`` as ``"unknown"``.
    """
    with open(mcap_path, "rb") as stream:
        # Delegate directly to the reader; the file must stay open while
        # messages are still being pulled from it.
        yield from make_reader(stream).iter_messages()
class TestMcapSchemaSupport:
    """Every schema declared in an MCAP summary needs an entry in SCHEMA_MAP."""

    def test_all_schemas_supported(self, mcap_file: Path):
        """Fail, listing each schema name in the file summary absent from SCHEMA_MAP."""
        declared = get_mcap_summary(mcap_file)["schemas"]
        missing = [name for name in declared if name not in SCHEMA_MAP]
        assert not missing, (
            f"Unsupported schema types in {mcap_file.name}: {missing}\n"
            f"Add these to SCHEMA_MAP in test_mcap.py"
        )
class TestMcapDeserialization:
    """Every message in an MCAP file must decode with its SCHEMA_MAP class."""

    def test_deserialize_all_messages(self, mcap_file: Path):
        """Decode each message, print per-schema counts, and fail on any error.

        Messages whose schema is not in ``SCHEMA_MAP`` count as errors. They
        are not skipped. A decoder returning ``None`` is also an error. At most
        10 errors are listed in the failure message; the rest are summarised
        as a count.
        """
        problems = []
        per_schema = defaultdict(int)

        for schema, channel, message in iter_mcap_messages(mcap_file):
            name = "unknown" if not schema else schema.name
            per_schema[name] += 1

            if name not in SCHEMA_MAP:
                problems.append(
                    f"Unsupported schema: {name} (topic: {channel.topic})"
                )
                continue

            msg_cls = SCHEMA_MAP[name]
            try:
                decoded = msg_cls.deserialize(message.data)
                # The AssertionError is caught below and recorded like any
                # other decoding failure.
                assert decoded is not None, (
                    f"Deserialization returned None for {name}"
                )
            except Exception as exc:
                problems.append(
                    f"Failed to deserialize {name} (topic: {channel.topic}): {exc}"
                )

        total = sum(per_schema.values())
        print(f"\nDeserialized {total} messages from {mcap_file.name}:")
        for name, count in sorted(per_schema.items()):
            print(f" {name}: {count}")

        listed = "\n".join(f" - {p}" for p in problems[:10])
        overflow = len(problems) - 10
        tail = f"\n ... and {overflow} more" if overflow > 0 else ""
        assert not problems, (
            f"Deserialization errors in {mcap_file.name}:\n" + listed + tail
        )
class TestMcapRoundTrip:
def test_roundtrip_all_messages(self, mcap_file: Path):
errors = []
success_count = 0
for schema, channel, message in iter_mcap_messages(mcap_file):
schema_name = schema.name if schema else "unknown"
if schema_name not in SCHEMA_MAP:
continue
cls = SCHEMA_MAP[schema_name]
try:
msg1 = cls.deserialize(message.data)
cdr_bytes = msg1.serialize()
if cdr_bytes != message.data:
errors.append(
f"Byte mismatch for {schema_name} "
f"(topic: {channel.topic}): "
f"original {len(message.data)} bytes, "
f"reserialized {len(cdr_bytes)} bytes"
)
continue
msg2 = cls.deserialize(cdr_bytes)
self._compare_messages(msg1, msg2, schema_name, channel.topic)
success_count += 1
except Exception as e:
errors.append(
f"Round-trip failed for {schema_name} (topic: {channel.topic}): {e}"
)
print(f"\nRound-trip validated {success_count} messages")
assert not errors, (
f"Round-trip errors in {mcap_file.name}:\n"
+ "\n".join(f" - {e}" for e in errors[:10])
+ (f"\n ... and {len(errors) - 10} more" if len(errors) > 10 else "")
)
def _compare_messages(
self,
msg1: Any,
msg2: Any,
schema_name: str,
topic: str,
):
if hasattr(msg1, "header") and msg1.header is not None:
assert msg1.header.stamp.sec == msg2.header.stamp.sec, (
f"Header stamp.sec mismatch: {msg1.header.stamp.sec} "
f"!= {msg2.header.stamp.sec}"
)
assert msg1.header.stamp.nanosec == msg2.header.stamp.nanosec
assert msg1.header.frame_id == msg2.header.frame_id
if hasattr(msg1, "width") and hasattr(msg1, "height"):
assert msg1.width == msg2.width
assert msg1.height == msg2.height
if hasattr(msg1, "data"):
if isinstance(msg1.data, (bytes, list)):
assert len(msg1.data) == len(msg2.data), (
f"Data length mismatch: {len(msg1.data)} != {len(msg2.data)}"
)
if "RadarCube" in schema_name:
assert list(msg1.shape) == list(msg2.shape)
assert list(msg1.layout) == list(msg2.layout)
assert msg1.is_complex == msg2.is_complex
if "NavSatFix" in schema_name:
assert msg1.latitude == msg2.latitude
assert msg1.longitude == msg2.longitude
assert msg1.altitude == msg2.altitude
if "Imu" in schema_name:
assert msg1.angular_velocity.x == msg2.angular_velocity.x
assert msg1.linear_acceleration.x == msg2.linear_acceleration.x
class TestMcapFieldValidation:
def test_validate_timestamps(self, mcap_file: Path):
errors = []
for schema, channel, message in iter_mcap_messages(mcap_file):
schema_name = schema.name if schema else "unknown"
if schema_name not in SCHEMA_MAP:
continue
cls = SCHEMA_MAP[schema_name]
try:
msg = cls.deserialize(message.data)
if hasattr(msg, "header") and msg.header is not None:
stamp = msg.header.stamp
if stamp.sec > 0 and stamp.sec < 946684800:
pass
elif stamp.sec < 0:
errors.append(
f"Negative timestamp in {schema_name} "
f"(topic: {channel.topic}): sec={stamp.sec}"
)
if stamp.nanosec >= 1_000_000_000:
errors.append(
f"Invalid nanosec in {schema_name} "
f"(topic: {channel.topic}): "
f"nanosec={stamp.nanosec}"
)
except Exception as e:
errors.append(f"Validation error for {schema_name}: {e}")
assert not errors, (
f"Timestamp validation errors in {mcap_file.name}:\n"
+ "\n".join(f" - {e}" for e in errors)
)
def test_validate_dimensions(self, mcap_file: Path):
errors = []
for schema, channel, message in iter_mcap_messages(mcap_file):
schema_name = schema.name if schema else "unknown"
if schema_name not in SCHEMA_MAP:
continue
cls = SCHEMA_MAP[schema_name]
try:
msg = cls.deserialize(message.data)
if hasattr(msg, "width"):
if msg.width < 0:
errors.append(f"Negative width in {schema_name}: {msg.width}")
if hasattr(msg, "height"):
if msg.height < 0:
errors.append(f"Negative height in {schema_name}: {msg.height}")
if hasattr(msg, "point_step"):
if msg.point_step <= 0:
errors.append(
f"Invalid point_step in {schema_name}: {msg.point_step}"
)
except Exception as e:
errors.append(f"Validation error for {schema_name}: {e}")
assert not errors, (
f"Dimension validation errors in {mcap_file.name}:\n"
+ "\n".join(f" - {e}" for e in errors)
)