use std::borrow::Cow;
use std::collections::BTreeMap;
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};

use ahash::HashMap;
use anyhow::{Context as _, anyhow};
use arrow::array::{Int64Array, RecordBatch};
use crossbeam::channel::Sender;
use itertools::Either;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

use re_arrow_util::ArrowArrayDowncastRef as _;
use re_chunk::{Chunk, RowId, TimeColumn, TimeInt, TimePoint, Timeline};
use re_log_types::ApplicationId;
use re_sdk_types::{
    archetypes::{AssetVideo, TextDocument, VideoFrameReference},
    components::VideoTimestamp,
};

use crate::lerobot::common::{
    LEROBOT_DATASET_IGNORED_COLUMNS, LeRobotDataset, load_and_stream_versioned,
    load_episode_depth_images, load_episode_images, load_scalar,
};
use crate::lerobot::{DType, EpisodeIndex, Feature, LeRobotDatasetTask, LeRobotError, TaskIndex};
use crate::{DataLoaderError, LoadedData};
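
/// A LeRobot dataset (`v2.x`) stored on disk.
///
/// Couples the dataset root directory with the parsed contents of its `meta/`
/// directory.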
#[derive(Debug, Clone)]
pub struct LeRobotDatasetV2 {
pub path: PathBuf,
pub metadata: LeRobotDatasetMetadata,
}
impl LeRobotDatasetV2 {
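    /// Load a LeRobot dataset from the given directory.
    ///
    /// The directory is expected to contain the standard LeRobot `meta/`
    /// subdirectory, holding `info.json`, `episodes.jsonl`, and `tasks.jsonl`.
    ///
    /// A minimal usage sketch (the path is a placeholder):
    ///
    /// ```ignore
    /// let dataset = LeRobotDatasetV2::load_from_directory("/path/to/dataset")?;
    /// println!("found {} episodes", dataset.metadata.episode_count());
    /// ```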
pub fn load_from_directory(path: impl AsRef<Path>) -> Result<Self, LeRobotError> {
let path = path.as_ref();
        let metadata_path = path.join("meta");
        let metadata = LeRobotDatasetMetadata::load_from_directory(&metadata_path)?;
Ok(Self {
path: path.to_path_buf(),
metadata,
})
}
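
    /// Read the per-frame data of a single episode from its parquet file.
    ///
    /// Only the first record batch of the file is read; an empty file yields
    /// [`LeRobotError::EmptyEpisode`]. A sketch of inspecting the result:
    ///
    /// ```ignore
    /// let batch = dataset.read_episode_data(EpisodeIndex(0))?;
    /// for field in batch.schema().fields() {
    ///     println!("column: {}", field.name());
    /// }
    /// ```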
pub fn read_episode_data(&self, episode: EpisodeIndex) -> Result<RecordBatch, LeRobotError> {
if !self.metadata.episodes.contains_key(&episode) {
return Err(LeRobotError::InvalidEpisodeIndex(episode));
}
let episode_data_path = self.metadata.info.episode_data_path(episode)?;
let episode_parquet_file = self.path.join(episode_data_path);
let file = File::open(&episode_parquet_file)
.map_err(|err| LeRobotError::io(err, episode_parquet_file))?;
let mut reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?;
        reader
            .next()
            .transpose()
            .map_err(LeRobotError::Arrow)?
            .ok_or(LeRobotError::EmptyEpisode(episode))
}
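
    /// Read the raw contents of the video file for the given video feature
    /// (e.g. a camera stream) and episode.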
pub fn read_episode_video_contents(
&self,
observation_key: &str,
episode: EpisodeIndex,
) -> Result<Cow<'_, [u8]>, LeRobotError> {
let video_file = self.metadata.info.video_path(observation_key, episode)?;
        let video_path = self.path.join(video_file);
        let contents = {
            re_tracing::profile_scope!("fs::read");
            std::fs::read(&video_path).map_err(|err| LeRobotError::io(err, video_path))?
};
Ok(Cow::Owned(contents))
}
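
    /// Look up a task by its index.
    ///
    /// This indexes directly into the sorted task list, which assumes that task
    /// indices in `tasks.jsonl` are dense and zero-based.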
pub fn task_by_index(&self, task: TaskIndex) -> Option<&LeRobotDatasetTask> {
self.metadata.tasks.get(task.0)
}
}
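
/// Parsed contents of a LeRobot dataset's `meta/` directory.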
#[derive(Debug, Clone)]
pub struct LeRobotDatasetMetadata {
pub info: LeRobotDatasetInfo,
pub episodes: BTreeMap<EpisodeIndex, LeRobotDatasetEpisode>,
pub tasks: Vec<LeRobotDatasetTask>,
}
impl LeRobotDatasetMetadata {
pub fn episode_count(&self) -> usize {
self.episodes.len()
}
pub fn get_episode(&self, episode: EpisodeIndex) -> Option<&LeRobotDatasetEpisode> {
self.episodes.get(&episode)
}
pub fn iter_episode_indices(&self) -> impl Iterator<Item = EpisodeIndex> {
self.episodes.keys().copied()
}
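
    /// Load dataset metadata from a `meta/` directory containing `info.json`,
    /// `episodes.jsonl`, and `tasks.jsonl`.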
    pub fn load_from_directory(metadir: impl AsRef<Path>) -> Result<Self, LeRobotError> {
        let metadir = metadir.as_ref();

        let info = LeRobotDatasetInfo::load_from_json_file(metadir.join("info.json"))?;
        let episodes: Vec<LeRobotDatasetEpisode> =
            load_jsonl_file(metadir.join("episodes.jsonl"))?;
        let mut tasks: Vec<LeRobotDatasetTask> = load_jsonl_file(metadir.join("tasks.jsonl"))?;

        // The `BTreeMap` keeps episodes ordered by index, and the tasks are sorted
        // so that `task_by_index` can index directly into the `Vec`.
        let episodes = episodes
            .into_iter()
            .map(|episode| (episode.index, episode))
            .collect::<BTreeMap<_, _>>();
        tasks.sort_by_key(|task| task.index);

        Ok(Self {
            info,
            episodes,
            tasks,
        })
    }
}
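
/// Deserialized contents of a dataset's `meta/info.json` file.
///
/// `data_path`, `video_path`, and `image_path` are path *templates* containing
/// placeholders such as `{episode_chunk:03d}`, `{episode_index:06d}`, and
/// `{video_key}`; they are expanded by [`Self::episode_data_path`] and
/// [`Self::video_path`].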
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LeRobotDatasetInfo {
pub robot_type: Option<String>,
pub codebase_version: String,
pub total_episodes: usize,
pub total_frames: usize,
pub total_tasks: usize,
pub total_videos: usize,
pub total_chunks: usize,
pub chunks_size: usize,
pub data_path: String,
pub video_path: Option<String>,
pub image_path: Option<String>,
pub fps: usize,
pub features: HashMap<String, Feature>,
}
impl LeRobotDatasetInfo {
pub fn load_from_json_file(filepath: impl AsRef<Path>) -> Result<Self, LeRobotError> {
let info_file = File::open(filepath.as_ref())
.map_err(|err| LeRobotError::io(err, filepath.as_ref()))?;
let reader = BufReader::new(info_file);
serde_json::from_reader(reader).map_err(|err| err.into())
}
pub fn feature(&self, feature_key: &str) -> Option<&Feature> {
self.features.get(feature_key)
}
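
    /// Return the zero-based index of the chunk containing the given episode.
    ///
    /// Episodes are grouped into chunks of `chunks_size` episodes each, so e.g.
    /// with `chunks_size = 1000`, episode `1234` lives in chunk `1`.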
    pub fn chunk_index(&self, episode: EpisodeIndex) -> Result<usize, LeRobotError> {
        // Episode indices are zero-based, so `total_episodes` itself is out of range.
        if episode.0 >= self.total_episodes {
            return Err(LeRobotError::InvalidEpisodeIndex(episode));
        }

        let chunk_idx = episode.0 / self.chunks_size;
        if chunk_idx < self.total_chunks {
            Ok(chunk_idx)
        } else {
            Err(LeRobotError::InvalidChunkIndex(chunk_idx))
        }
    }
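
    /// Resolve the relative path of the parquet file holding the data of the
    /// given episode, by filling in the `data_path` template from `info.json`.
    ///
    /// For example, with the common template
    /// `data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet`,
    /// episode 1 in chunk 0 resolves to `data/chunk-000/episode_000001.parquet`.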
pub fn episode_data_path(&self, episode: EpisodeIndex) -> Result<PathBuf, LeRobotError> {
let chunk = self.chunk_index(episode)?;
Ok(self
.data_path
.replace("{episode_chunk:03d}", &format!("{chunk:03}"))
.replace("{episode_index:06d}", &format!("{:06}", episode.0))
.into())
}
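
    /// Resolve the relative path of the video file for the given video feature
    /// and episode, by filling in the `video_path` template from `info.json`.
    ///
    /// Fails if the feature does not exist, is not of dtype `video`, or if the
    /// dataset does not define a `video_path` template.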
pub fn video_path(
&self,
feature_key: &str,
episode: EpisodeIndex,
) -> Result<PathBuf, LeRobotError> {
let chunk = self.chunk_index(episode)?;
let feature = self
.feature(feature_key)
.ok_or(LeRobotError::InvalidFeatureKey(feature_key.to_owned()))?;
if feature.dtype != DType::Video {
return Err(LeRobotError::InvalidFeatureDtype {
key: feature_key.to_owned(),
expected: DType::Video,
actual: feature.dtype,
});
}
self.video_path
.as_ref()
.ok_or_else(|| LeRobotError::MissingDatasetInfo("video_path".to_owned()))
.map(|path| {
path.replace("{episode_chunk:03d}", &format!("{chunk:03}"))
.replace("{episode_index:06d}", &format!("{:06}", episode.0))
.replace("{video_key}", feature_key)
.into()
})
}
}
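
/// Deserialize a JSON Lines file (one JSON value per line) into a `Vec` of `D`.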
#[cfg(not(target_arch = "wasm32"))]
fn load_jsonl_file<D>(filepath: impl AsRef<Path>) -> Result<Vec<D>, LeRobotError>
where
D: DeserializeOwned,
{
let entries = std::fs::read_to_string(filepath.as_ref())
.map_err(|err| LeRobotError::io(err, filepath.as_ref()))?
.lines()
.map(|line| serde_json::from_str(line))
.collect::<Result<Vec<D>, _>>()?;
Ok(entries)
}
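
/// A single episode entry from `meta/episodes.jsonl`.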
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LeRobotDatasetEpisode {
#[serde(rename = "episode_index")]
pub index: EpisodeIndex,
pub tasks: Vec<String>,
pub length: u32,
}
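
/// Load the dataset and stream all of its episodes to the given channel.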
pub fn load_and_stream(
dataset: &LeRobotDatasetV2,
application_id: &ApplicationId,
tx: &Sender<LoadedData>,
loader_name: &str,
) {
load_and_stream_versioned(dataset, application_id, tx, loader_name);
}
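
/// Convert a single episode into Rerun chunks.
///
/// All features are logged on a shared `frame_index` sequence timeline: videos
/// become [`AssetVideo`] and [`VideoFrameReference`] chunks, images and scalars
/// are logged per frame, and `task_index` values are resolved to their task
/// descriptions.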
fn load_episode(
dataset: &LeRobotDatasetV2,
episode: EpisodeIndex,
) -> Result<Vec<Chunk>, DataLoaderError> {
let data = dataset
.read_episode_data(episode)
.map_err(|err| anyhow!("Reading data for episode {} failed: {err}", episode.0))?;
let frame_indices = data
.column_by_name("frame_index")
.ok_or_else(|| anyhow!("Failed to get frame index column in LeRobot dataset"))?
.clone();
    let timeline = Timeline::new_sequence("frame_index");
    let times: &arrow::buffer::ScalarBuffer<i64> = frame_indices
        .downcast_array_ref::<Int64Array>()
        .ok_or_else(|| anyhow!("LeRobot dataset frame indices are of an unexpected type"))?
        .values();

    let time_column = TimeColumn::new(None, timeline, times.clone());
    let timelines = std::iter::once((*timeline.name(), time_column.clone())).collect();
let mut chunks = Vec::new();
for (feature_key, feature) in dataset
.metadata
.info
.features
.iter()
.filter(|(key, _)| !LEROBOT_DATASET_IGNORED_COLUMNS.contains(&key.as_str()))
{
match feature.dtype {
DType::Video => {
chunks.extend(load_episode_video(
dataset,
feature_key,
episode,
&timeline,
time_column.clone(),
)?);
}
            DType::Image => {
                // Single-channel images are treated as depth images, three-channel
                // images as regular color images.
                let num_channels = feature.channel_dim();
                match num_channels {
                    1 => chunks.extend(load_episode_depth_images(feature_key, &timeline, &data)?),
                    3 => chunks.extend(load_episode_images(feature_key, &timeline, &data)?),
                    _ => re_log::warn_once!(
                        "Unsupported channel count {num_channels} (shape: {:?}) for LeRobot dataset; only 1- and 3-channel images are supported",
                        feature.shape
                    ),
                }
            }
DType::Int64 if feature_key == "task_index" => {
chunks.extend(log_episode_task(dataset, &timeline, &data)?);
}
DType::Int16 | DType::Int64 | DType::Bool | DType::String => {
re_log::warn_once!(
"Loading LeRobot feature ({feature_key}) of dtype `{:?}` into Rerun is not yet implemented",
feature.dtype
);
}
DType::Float32 | DType::Float64 => {
chunks.extend(load_scalar(feature_key, feature, &timelines, &data)?);
}
}
}
Ok(chunks)
}
impl LeRobotDataset for LeRobotDatasetV2 {
fn iter_episode_indices(&self) -> impl std::iter::Iterator<Item = EpisodeIndex> {
self.metadata.iter_episode_indices()
}
fn load_episode_chunks(&self, episode: EpisodeIndex) -> Result<Vec<Chunk>, DataLoaderError> {
load_episode(self, episode)
}
}
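
/// Log the active task of each frame in the episode as a [`TextDocument`].
///
/// Frames whose `task_index` does not resolve to a known task are skipped, but
/// still advance the timeline so that subsequent rows stay aligned.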
fn log_episode_task(
dataset: &LeRobotDatasetV2,
timeline: &Timeline,
data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let task_indices = data
        .column_by_name("task_index")
        .and_then(|c| c.downcast_array_ref::<Int64Array>())
        .context("Failed to get `task_index` column from the LeRobot dataset")?;
let mut chunk = Chunk::builder("task");
let mut row_id = RowId::new();
let mut time_int = TimeInt::ZERO;
for task_index in task_indices {
let Some(task) = task_index
.and_then(|i| usize::try_from(i).ok())
.and_then(|i| dataset.task_by_index(TaskIndex(i)))
else {
time_int = time_int.inc();
continue;
};
let timepoint = TimePoint::default().with(*timeline, time_int);
let text = TextDocument::new(task.task.clone());
chunk = chunk.with_archetype(row_id, timepoint, &text);
row_id = row_id.next();
time_int = time_int.inc();
}
Ok(std::iter::once(chunk.build()?))
}
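
/// Load the video stream of a single camera feature for the given episode.
///
/// This emits an [`AssetVideo`] chunk with the raw video bytes and, if frame
/// timestamps can be decoded from the video, a [`VideoFrameReference`] chunk
/// mapping each `frame_index` to the corresponding video timestamp.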
fn load_episode_video(
dataset: &LeRobotDatasetV2,
observation: &str,
episode: EpisodeIndex,
timeline: &Timeline,
time_column: TimeColumn,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
let contents = dataset
.read_episode_video_contents(observation, episode)
.with_context(|| format!("Reading video contents for episode {episode:?} failed!"))?;
let video_asset = AssetVideo::new(contents.into_owned());
let entity_path = observation;
let video_frame_reference_chunk = match video_asset.read_frame_timestamps_nanos() {
Ok(frame_timestamps_nanos) => {
let frame_timestamps_nanos: arrow::buffer::ScalarBuffer<i64> =
frame_timestamps_nanos.into();
let video_timestamps = frame_timestamps_nanos
.iter()
.take(time_column.num_rows())
.copied()
.map(VideoTimestamp::from_nanos)
.collect::<Vec<_>>();
let video_frame_reference_column = VideoFrameReference::update_fields()
.with_many_timestamp(video_timestamps)
.columns_of_unit_batches()
.with_context(|| {
format!(
"Failed to create `VideoFrameReference` column for episode {episode:?}."
)
})?;
Some(Chunk::from_auto_row_ids(
re_chunk::ChunkId::new(),
entity_path.into(),
std::iter::once((*timeline.name(), time_column)).collect(),
video_frame_reference_column.collect(),
)?)
}
Err(err) => {
re_log::warn_once!(
"Failed to read frame timestamps from episode {episode:?} video: {err}"
);
None
}
};
let video_asset_chunk = Chunk::builder(entity_path)
.with_archetype(RowId::new(), TimePoint::default(), &video_asset)
.build()?;
if let Some(video_frame_reference_chunk) = video_frame_reference_chunk {
Ok(Either::Left(
[video_asset_chunk, video_frame_reference_chunk].into_iter(),
))
} else {
Ok(Either::Right(std::iter::once(video_asset_chunk)))
}
}