use std::collections::HashSet;
use std::path::PathBuf;
use std::time::Duration;
use base64::{Engine, engine::general_purpose::STANDARD as BASE64_ENGINE};
use core_types::{BufferId, Timestamp, TransportDomain};
use data_model::{
ControlEnvelope, DataEnvelope, DataPayload, ExternalBufferRef, Packet, PacketHeader,
SchemaId, SchemaVersion,
};
use replay_core::{FilePacketRecorder, FileReplaySession, RecordedEntry, ReplaySession};
use serde_json::Value;
use crate::constants::DEFAULT_BAG_PATH;
use crate::helpers::{
has_flag, option_value, parse_domain_option, parse_replay_speed, parse_u64_option,
parse_usize_option, replay_speed_label, resolve_runtime_endpoint,
};
use crate::gateway::{
STATUS_OP_TOPIC_POLL, STATUS_OP_TOPIC_PUBLISH, STATUS_OP_TOPIC_SUBSCRIBE,
STATUS_OP_TOPIC_UNSUBSCRIBE, STATUS_SERVICE_NAME, StatusServiceResponse,
build_op_payload_request, make_udp_service_client, next_request_id, validate_response,
};
/// Daemon endpoint handed to `resolve_runtime_endpoint` by both `bag record` and `bag play`
/// (presumably the address used when no `--endpoint` is given — confirm in `helpers`).
const DEFAULT_DAEMON_ENDPOINT: &str = "127.0.0.1:7588";
/// Opens a topic subscription on the status service at `endpoint`.
///
/// Builds a fresh UDP client, sends a `STATUS_OP_TOPIC_SUBSCRIBE` request for
/// `topic_name` with the given `max_batch`, validates the reply and returns the
/// client together with the `stream_id` reported in the reply's `op_result`.
///
/// # Errors
/// Returns a message when the client cannot be built, the call fails, the reply
/// does not validate, or the reply carries no string `stream_id`.
fn subscribe_topic_stream(
    endpoint: &str,
    timeout_ms: u64,
    topic_name: &str,
    max_batch: usize,
) -> Result<(core_api::UdpServiceClient, String), String> {
    let client = make_udp_service_client(endpoint.to_string(), timeout_ms)?;
    let id = next_request_id();
    let body = serde_json::json!({
        "topic": topic_name,
        "max_batch": max_batch,
    });
    let request = build_op_payload_request(STATUS_OP_TOPIC_SUBSCRIBE, body);

    let reply: StatusServiceResponse = match client.call_json(STATUS_SERVICE_NAME, id, &request) {
        Ok(reply) => reply,
        Err(err) => return Err(format!("topic subscribe to {endpoint} failed: {err}")),
    };
    validate_response(&reply, id, STATUS_OP_TOPIC_SUBSCRIBE)?;

    let stream_id = reply
        .op_result
        .as_ref()
        .and_then(|result| result.get("stream_id"))
        .and_then(Value::as_str)
        .map(str::to_owned)
        .ok_or_else(|| String::from("topic subscribe response missing stream_id"))?;
    Ok((client, stream_id))
}
/// Asks the status service for up to `max_items` frames buffered on `stream_id`.
///
/// # Errors
/// Returns a message when the call to `endpoint` fails, the reply does not
/// validate, or the reply's `op_result` has no `frames` array.
fn poll_topic_frames(
    client: &core_api::UdpServiceClient,
    endpoint: &str,
    stream_id: &str,
    max_items: usize,
) -> Result<Vec<Value>, String> {
    let id = next_request_id();
    let body = serde_json::json!({
        "stream_id": stream_id,
        "max_items": max_items,
    });
    let request = build_op_payload_request(STATUS_OP_TOPIC_POLL, body);

    let reply: StatusServiceResponse = client
        .call_json(STATUS_SERVICE_NAME, id, &request)
        .map_err(|err| format!("topic poll to {endpoint} failed: {err}"))?;
    validate_response(&reply, id, STATUS_OP_TOPIC_POLL)?;

    let frames = reply
        .op_result
        .as_ref()
        .and_then(|result| result.get("frames"))
        .and_then(Value::as_array);
    match frames {
        Some(list) => Ok(list.clone()),
        None => Err(String::from("topic poll response missing frames")),
    }
}
/// Releases the subscription `stream_id` on the status service at `endpoint`.
///
/// # Errors
/// Returns a message when the call fails or the reply does not validate.
fn unsubscribe_topic_stream(
    client: &core_api::UdpServiceClient,
    endpoint: &str,
    stream_id: &str,
) -> Result<(), String> {
    let id = next_request_id();
    let request = build_op_payload_request(
        STATUS_OP_TOPIC_UNSUBSCRIBE,
        serde_json::json!({ "stream_id": stream_id }),
    );
    let outcome: Result<StatusServiceResponse, _> =
        client.call_json(STATUS_SERVICE_NAME, id, &request);
    match outcome {
        Ok(reply) => validate_response(&reply, id, STATUS_OP_TOPIC_UNSUBSCRIBE),
        Err(err) => Err(format!("topic unsubscribe to {endpoint} failed: {err}")),
    }
}
/// Maps a frame's `transport` label to a [`TransportDomain`].
///
/// Only the exact, lowercase labels `"network"` and `"local"` are recognised;
/// a missing or unknown label yields `fallback`.
fn parse_transport_domain(value: Option<&str>, fallback: TransportDomain) -> TransportDomain {
    let Some(label) = value else {
        return fallback;
    };
    match label {
        "network" => TransportDomain::Network,
        "local" => TransportDomain::Local,
        _ => fallback,
    }
}
/// Extracts the raw payload bytes carried by a topic `frame`.
///
/// Resolution order:
/// 1. a string `payload_base64` field is base64-decoded;
/// 2. a string `payload` field contributes its UTF-8 bytes;
/// 3. a missing or JSON-`null` `payload` yields an empty payload;
/// 4. any other `payload` value is serialized as compact JSON.
///
/// # Errors
/// Returns a message when `payload_base64` is not valid base64 or the JSON
/// payload cannot be serialized.
fn payload_to_bytes(frame: &Value) -> Result<Vec<u8>, String> {
    if let Some(encoded) = frame.get("payload_base64").and_then(Value::as_str) {
        return BASE64_ENGINE
            .decode(encoded)
            .map_err(|err| format!("decode payload_base64 failed: {err}"));
    }
    match frame.get("payload") {
        // An explicit `null` (e.g. the `"payload": null` field that
        // `publish_recorded_entry` sends) means "no payload", exactly like a
        // missing key. Without this arm it would be recorded as the bytes `null`.
        None | Some(Value::Null) => Ok(Vec::new()),
        Some(Value::String(text)) => Ok(text.as_bytes().to_vec()),
        Some(other) => serde_json::to_vec(other)
            .map_err(|err| format!("serialize payload to bytes failed: {err}")),
    }
}
fn external_ref_from_frame(frame: &Value) -> Option<ExternalBufferRef> {
let external = frame.get("external_ref")?;
let buffer_id = external.get("buffer_id")?.as_u64()?;
let offset = usize::try_from(external.get("offset")?.as_u64()?).ok()?;
let len = usize::try_from(external.get("len")?.as_u64()?).ok()?;
Some(ExternalBufferRef {
buffer_id: BufferId(buffer_id),
offset,
len,
})
}
/// Converts one polled topic `frame` into a [`Packet`].
///
/// When a header field is absent from the frame, it falls back as follows:
/// - `transport` → `fallback_domain`
/// - `sequence` → `fallback_sequence`
/// - `captured_at_unix_nanos` → `Timestamp::now()`
/// - `schema_id` → `robotrt.topic.<topic>`
/// - `schema_version` (also when it overflows `u16`) → `1`
///
/// A `packet_kind` of `"control"` produces a [`ControlEnvelope`], labelled from
/// `control_label` (default `"control"`). Every other kind produces a
/// [`DataEnvelope`]. Its payload is the frame's `external_ref` when that parses;
/// otherwise it is the inline bytes from `payload_to_bytes`.
///
/// # Errors
/// Propagates payload decoding failures from `payload_to_bytes`.
fn packet_from_frame(
    frame: &Value,
    topic: &str,
    fallback_domain: TransportDomain,
    fallback_sequence: u64,
) -> Result<Packet, String> {
    let captured_nanos = match frame.get("captured_at_unix_nanos").and_then(Value::as_u64) {
        Some(nanos) => u128::from(nanos),
        None => Timestamp::now().0,
    };
    let schema_name = match frame.get("schema_id").and_then(Value::as_str) {
        Some(explicit) => explicit.to_string(),
        None => format!("robotrt.topic.{topic}"),
    };
    let header = PacketHeader {
        version: 1,
        domain: parse_transport_domain(
            frame.get("transport").and_then(Value::as_str),
            fallback_domain,
        ),
        session_id: None,
        stream_id: None,
        sequence: frame
            .get("sequence")
            .and_then(Value::as_u64)
            .unwrap_or(fallback_sequence),
        ack: None,
        timestamp: Timestamp(captured_nanos),
        schema_id: SchemaId::new(schema_name),
        schema_version: SchemaVersion(
            frame
                .get("schema_version")
                .and_then(Value::as_u64)
                .and_then(|raw| u16::try_from(raw).ok())
                .unwrap_or(1),
        ),
    };

    let is_control = frame.get("packet_kind").and_then(Value::as_str) == Some("control");
    let packet = if is_control {
        let label = frame
            .get("control_label")
            .and_then(Value::as_str)
            .unwrap_or("control")
            .to_string();
        Packet::Control(ControlEnvelope {
            header,
            label,
            payload: payload_to_bytes(frame)?,
        })
    } else {
        let payload = match external_ref_from_frame(frame) {
            Some(external) => DataPayload::External(external),
            None => DataPayload::Inline(payload_to_bytes(frame)?),
        };
        Packet::Data(DataEnvelope { header, payload })
    };
    Ok(packet)
}
/// Returns the wire label (`"network"` or `"local"`) of the packet's transport domain.
fn packet_transport(packet: &Packet) -> &'static str {
    let header = match packet {
        Packet::Control(control) => &control.header,
        Packet::Data(data) => &data.header,
    };
    match header.domain {
        TransportDomain::Network => "network",
        TransportDomain::Local => "local",
    }
}
fn packet_sequence(packet: &Packet) -> u64 {
match packet {
Packet::Control(control) => control.header.sequence,
Packet::Data(data) => data.header.sequence,
}
}
/// Re-publishes one recorded `entry` to the status service at `endpoint`.
///
/// The packet is flattened into the `STATUS_OP_TOPIC_PUBLISH` payload using the
/// same field names `packet_from_frame` reads. The fields are:
/// - `packet_kind` and `control_label`;
/// - a base64 `payload_base64` (control payloads and inline data), or an
///   `external_ref` object (external data buffers);
/// - schema id/version, transport, sequence and the entry's capture time.
///
/// `payload` is always sent as JSON `null`, because the bytes travel in
/// `payload_base64`.
///
/// # Errors
/// Returns a message when the publish call fails or the reply does not validate.
fn publish_recorded_entry(
    client: &core_api::UdpServiceClient,
    endpoint: &str,
    entry: &RecordedEntry,
) -> Result<(), String> {
    let packet = &entry.packet;
    let (packet_kind, control_label, header, payload_base64, external_ref) = match packet {
        Packet::Control(control) => (
            "control",
            Some(control.label.as_str()),
            &control.header,
            Some(BASE64_ENGINE.encode(&control.payload)),
            Value::Null,
        ),
        Packet::Data(data) => {
            let (payload_base64, external_ref) = match &data.payload {
                DataPayload::Inline(bytes) => (Some(BASE64_ENGINE.encode(bytes)), Value::Null),
                DataPayload::External(buffer) => (
                    None,
                    serde_json::json!({
                        "buffer_id": buffer.buffer_id.0,
                        "offset": buffer.offset,
                        "len": buffer.len,
                    }),
                ),
            };
            ("data", None, &data.header, payload_base64, external_ref)
        }
    };

    let request_id = next_request_id();
    // `json!` serializes each value through a reference, so no field needs cloning.
    let request = build_op_payload_request(
        STATUS_OP_TOPIC_PUBLISH,
        serde_json::json!({
            "topic": entry.topic,
            "payload": Value::Null,
            "payload_base64": payload_base64,
            "packet_kind": packet_kind,
            "control_label": control_label,
            "external_ref": external_ref,
            "schema_id": header.schema_id.0,
            "schema_version": header.schema_version.0,
            "transport": packet_transport(packet),
            "sequence": packet_sequence(packet),
            "captured_at_unix_nanos": entry.captured_at.0,
        }),
    );
    let response: StatusServiceResponse = client
        .call_json(STATUS_SERVICE_NAME, request_id, &request)
        .map_err(|err| format!("topic publish to {endpoint} failed: {err}"))?;
    validate_response(&response, request_id, STATUS_OP_TOPIC_PUBLISH)
}
pub fn bag_record(args: &[String]) -> Result<(), String> {
let output = option_value(args, "--output")
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from(DEFAULT_BAG_PATH));
let count = parse_usize_option(args, "--count", 20)?;
let topic = option_value(args, "--topic").unwrap_or_else(|| String::from("/robotrt/default"));
let timeout_ms = parse_u64_option(args, "--timeout-ms", 1000)?;
let max_items = parse_usize_option(args, "--max-items", 32)?.clamp(1, 256);
let poll_interval_ms = parse_u64_option(args, "--interval-ms", 50)?;
let fallback_domain = parse_domain_option(option_value(args, "--domain"))?;
let endpoint = resolve_runtime_endpoint(
args,
DEFAULT_DAEMON_ENDPOINT,
"bag record requires --endpoint in embedded mode",
)?;
if let Some(parent) = output.parent() {
std::fs::create_dir_all(parent)
.map_err(|err| format!("create output directory {} failed: {err}", parent.display()))?;
}
let mut recorder = FilePacketRecorder::create(&output)
.map_err(|err| format!("create bag file {} failed: {err}", output.display()))?;
let (client, stream_id) = subscribe_topic_stream(&endpoint, timeout_ms, &topic, max_items)?;
let mut recorded = 0usize;
let mut fallback_sequence = 1u64;
while recorded < count {
let frames = poll_topic_frames(&client, &endpoint, &stream_id, max_items)?;
if frames.is_empty() {
std::thread::sleep(Duration::from_millis(poll_interval_ms));
continue;
}
for frame in frames {
let packet = packet_from_frame(&frame, &topic, fallback_domain, fallback_sequence)?;
fallback_sequence = fallback_sequence.saturating_add(1);
let captured = frame
.get("captured_at_unix_nanos")
.and_then(Value::as_u64)
.map(u128::from)
.unwrap_or_else(|| Timestamp::now().0);
let entry = RecordedEntry::new(Timestamp(captured), topic.clone(), packet);
recorder
.append(&entry)
.map_err(|err| format!("append bag entry failed: {err}"))?;
recorded += 1;
if recorded >= count {
break;
}
}
}
let _ = unsubscribe_topic_stream(&client, &endpoint, &stream_id);
recorder
.flush()
.map_err(|err| format!("flush bag file {} failed: {err}", output.display()))?;
println!("bag recorded: {}", output.display());
println!("entries: {}", recorded);
println!("topic: {}", topic);
println!("endpoint: {}", endpoint);
Ok(())
}
/// `bag play`: replays a bag file to the daemon by re-publishing each entry.
///
/// Options (defaults in parentheses):
/// - `--input` (`DEFAULT_BAG_PATH`)
/// - `--speed` (parsed by `parse_replay_speed`)
/// - `--topic` (publish every entry under this topic instead of its recorded one)
/// - `--timeout-ms` (1000)
/// - `--json` (print the summary as pretty JSON)
/// - the endpoint options read by `resolve_runtime_endpoint`
///
/// Replay stops at the first publish failure.
///
/// The summary reports:
/// - total and emitted entry counts;
/// - the number of distinct recorded topics;
/// - the capture span, i.e. the latest minus the earliest `captured_at`;
/// - the wall-clock replay time.
///
/// # Errors
/// Fails on invalid options, an unreadable bag file, client creation errors, the
/// first publish failure, or a summary serialization error.
pub fn bag_play(args: &[String]) -> Result<(), String> {
    let input = option_value(args, "--input")
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from(DEFAULT_BAG_PATH));
    let speed = parse_replay_speed(option_value(args, "--speed"))?;
    let json = has_flag(args, "--json");
    let topic_override = option_value(args, "--topic");
    let timeout_ms = parse_u64_option(args, "--timeout-ms", 1000)?;
    let endpoint = resolve_runtime_endpoint(
        args,
        DEFAULT_DAEMON_ENDPOINT,
        "bag play requires --endpoint in embedded mode",
    )?;
    let replay_file = FileReplaySession::open(&input)
        .map_err(|err| format!("open bag file {} failed: {err}", input.display()))?;
    let entries = replay_file.entries();
    let mut session = ReplaySession::new(entries, speed.clone());
    let started_at = std::time::Instant::now();
    let client = make_udp_service_client(endpoint.clone(), timeout_ms)?;
    let mut emitted = 0usize;
    let mut publish_error: Option<String> = None;
    while !session.is_done() && publish_error.is_none() {
        let count = session.pump(|entry| {
            // Once a publish has failed, ignore any further entries handed to this
            // callback so nothing more is sent.
            if publish_error.is_some() {
                return;
            }
            let publish_entry = if let Some(topic) = topic_override.as_ref() {
                RecordedEntry::new(entry.captured_at, topic.clone(), entry.packet.clone())
            } else {
                entry.clone()
            };
            if let Err(err) = publish_recorded_entry(&client, &endpoint, &publish_entry) {
                publish_error = Some(err);
            } else {
                emitted += 1;
            }
        });
        if count == 0 {
            // Nothing was pumped this round; back off briefly instead of spinning.
            std::thread::sleep(Duration::from_millis(1));
        }
    }
    if let Some(err) = publish_error {
        return Err(err);
    }

    // Saturate instead of truncating the u128 nanosecond counts.
    let elapsed_ns = u64::try_from(started_at.elapsed().as_nanos()).unwrap_or(u64::MAX);
    let topic_count = entries
        .iter()
        .map(|entry| entry.topic.as_str())
        .collect::<HashSet<_>>()
        .len();
    // Use min/max rather than first/last so the span stays correct even if the
    // bag's entries are not stored in capture order.
    let earliest = entries.iter().map(|entry| entry.captured_at.0).min();
    let latest = entries.iter().map(|entry| entry.captured_at.0).max();
    let capture_span_ns = match (earliest, latest) {
        (Some(first), Some(last)) => u64::try_from(last - first).unwrap_or(u64::MAX),
        _ => 0,
    };

    if json {
        let payload = serde_json::json!({
            // `display()` instead of the raw `PathBuf`: serializing a non-UTF-8 path
            // fails, and `json!` would unwrap that error into a panic.
            "input": input.display().to_string(),
            "endpoint": endpoint,
            "speed": replay_speed_label(&speed),
            "total_entries": entries.len(),
            "emitted_entries": emitted,
            "topic_count": topic_count,
            "capture_span_ns": capture_span_ns,
            "replay_elapsed_ns": elapsed_ns,
        });
        let out = serde_json::to_string_pretty(&payload)
            .map_err(|err| format!("serialize bag replay summary failed: {err}"))?;
        println!("{out}");
    } else {
        println!("RobotRT Bag Replay");
        println!("input: {}", input.display());
        println!("endpoint: {}", endpoint);
        println!("speed: {}", replay_speed_label(&speed));
        println!("total_entries: {}", entries.len());
        println!("emitted_entries: {}", emitted);
        println!("topic_count: {}", topic_count);
        println!("capture_span_ns: {}", capture_span_ns);
        println!("replay_elapsed_ns: {}", elapsed_ns);
    }
    Ok(())
}