use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use ahash::HashMap;
use anyhow::{Context as _, anyhow};
use arrow::array::{Float64Array, Int64Array, RecordBatch, StringArray};
use arrow::buffer::ScalarBuffer;
use arrow::compute::concat_batches;
use crossbeam::channel::Sender;
use parking_lot::RwLock;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use serde::{Deserialize, Serialize};

use re_arrow_util::ArrowArrayDowncastRef as _;
use re_chunk::{ArrowArray as _, Chunk, ChunkId, RowId, TimeColumn, TimePoint, Timeline};
use re_log_types::ApplicationId;
use re_sdk_types::archetypes::{TextDocument, VideoStream};
use re_video::VideoDataDescription;

use crate::lerobot::common::{
    LEROBOT_DATASET_IGNORED_COLUMNS, LeRobotDataset, load_and_stream_versioned,
    load_episode_depth_images, load_episode_images, load_scalar,
};
use crate::lerobot::{
    DType, EpisodeIndex, Feature, LeRobotDatasetSubtask, LeRobotDatasetTask, LeRobotError,
    SubtaskIndex, TaskIndex,
};
use crate::{DataLoaderError, LoadedData};
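
/// A `v3` LeRobot dataset stored on disk.
///
/// Episode frame data is read from the dataset's parquet files and cached per episode, while
/// video blobs are loaded lazily and reference-counted so they can be dropped once every
/// episode referencing them has been converted to chunks.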
pub struct LeRobotDatasetV3 {
pub path: PathBuf,
pub metadata: LeRobotDatasetMetadataV3,
video_cache: RwLock<VideoBlobCache>,
episode_data_cache: RwLock<HashMap<EpisodeIndex, Arc<RecordBatch>>>,
}
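
/// Cache of raw video file contents, keyed by file path.
///
/// `remaining_refs` tracks how many episodes still reference each file; when the count drops
/// to zero the blob is evicted (see [`LeRobotDatasetV3::release_episode_videos`]).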
#[derive(Default)]
struct VideoBlobCache {
blobs: HashMap<PathBuf, Arc<[u8]>>,
remaining_refs: HashMap<PathBuf, usize>,
}
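
/// Half-open row range `[start_row, end_row)` occupied by a single episode within a
/// concatenated parquet [`RecordBatch`].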
#[derive(Debug, Clone)]
struct EpisodeRowRange {
start_row: usize,
end_row: usize,
}
impl LeRobotDatasetV3 {
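    /// Loads a `v3` LeRobot dataset from `path`, which must point to the dataset root, i.e.
    /// the directory containing the `meta/` subdirectory.
    ///
    /// A minimal usage sketch (the path is illustrative):
    ///
    /// ```ignore
    /// let dataset = LeRobotDatasetV3::load_from_directory("/data/my_lerobot_dataset")?;
    /// for episode in dataset.iter_episode_indices() {
    ///     let chunks = dataset.load_episode_chunks(episode)?;
    ///     // …forward the chunks to a recording…
    /// }
    /// ```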
pub fn load_from_directory(path: impl AsRef<Path>) -> Result<Self, LeRobotError> {
let path = path.as_ref();
let metadatapath = path.join("meta");
let metadata = LeRobotDatasetMetadataV3::load_from_directory(&metadatapath)?;
let dataset = Self {
path: path.to_path_buf(),
metadata,
video_cache: RwLock::new(VideoBlobCache::default()),
episode_data_cache: RwLock::new(HashMap::default()),
};
dataset.load_all_episode_data_files()?;
dataset.init_video_ref_counts();
Ok(dataset)
}
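
    /// Loads every episode data parquet file referenced by the metadata and caches the
    /// per-episode slices.
    ///
    /// Episodes are grouped by their `(chunk_index, file_index)` pair so that each file is
    /// read from disk exactly once, even when it contains multiple episodes.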
fn load_all_episode_data_files(&self) -> Result<(), LeRobotError> {
re_tracing::profile_scope!("load_all_episode_data_files");
let mut files_to_episodes: HashMap<(usize, usize), Vec<EpisodeIndex>> = HashMap::default();
for episode in self.metadata.episodes.values() {
files_to_episodes
.entry((episode.data_chunk_index, episode.data_file_index))
.or_default()
.push(episode.episode_index);
}
for episodes_in_file in files_to_episodes.into_values() {
if let Some(first_episode) = episodes_in_file.first() {
let episode_data = self
.metadata
.get_episode_data(*first_episode)
.ok_or(LeRobotError::InvalidEpisodeIndex(*first_episode))?;
self.cache_episode_file(episode_data, &episodes_in_file)?;
}
}
Ok(())
}
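
    /// Initializes the per-file reference counts of the video cache by counting, for every
    /// video feature, how many episodes reference each video file.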
fn init_video_ref_counts(&self) {
let video_features: Vec<&str> = self
.metadata
.info
.features
.iter()
.filter(|(_, feature)| feature.dtype == DType::Video)
.map(|(key, _)| key.as_str())
.collect();
if video_features.is_empty() {
return;
}
let mut cache = self.video_cache.write();
for episode_data in self.metadata.episodes.values() {
for feature_key in &video_features {
if let Ok(video_file) = self.metadata.info.video_path(feature_key, episode_data) {
let video_path = self.path.join(video_file);
*cache.remaining_refs.entry(video_path).or_insert(0) += 1;
}
}
}
re_log::debug!(
"Initialized video cache with {} unique video files across {} episodes",
cache.remaining_refs.len(),
self.metadata.episodes.len()
);
}
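
    /// Decrements the reference count of every video file used by `episode` and evicts blobs
    /// whose count reaches zero.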
fn release_episode_videos(&self, episode: EpisodeIndex) {
let Some(episode_data) = self.metadata.get_episode_data(episode) else {
return;
};
let mut cache = self.video_cache.write();
for (feature_key, feature) in &self.metadata.info.features {
if feature.dtype != DType::Video {
continue;
}
if let Ok(video_file) = self.metadata.info.video_path(feature_key, episode_data) {
let video_path = self.path.join(video_file);
if let Some(count) = cache.remaining_refs.get_mut(&video_path) {
*count = count.saturating_sub(1);
if *count == 0 {
cache.blobs.remove(&video_path);
cache.remaining_refs.remove(&video_path);
}
}
}
}
}
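
    /// Reads the parquet file described by `file_metadata` and caches a per-episode slice of
    /// it for every episode in `episodes_in_file` that is not cached yet.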
fn cache_episode_file(
&self,
file_metadata: &LeRobotEpisodeData,
episodes_in_file: &[EpisodeIndex],
) -> Result<(), LeRobotError> {
if episodes_in_file.is_empty() {
return Ok(());
}
{
let cache = self.episode_data_cache.read();
if episodes_in_file.iter().all(|ep| cache.contains_key(ep)) {
return Ok(());
}
}
let episode_data_path = self.metadata.info.episode_data_path(file_metadata);
let episode_parquet_file = self.path.join(&episode_data_path);
let file = File::open(&episode_parquet_file)
.map_err(|err| LeRobotError::io(err, episode_parquet_file.clone()))?;
let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?;
let batches: Vec<RecordBatch> = reader
.collect::<Result<_, _>>()
.map_err(LeRobotError::Arrow)?;
if batches.is_empty() {
return Ok(());
}
let schema = batches[0].schema();
let full_data = concat_batches(&schema, &batches).map_err(LeRobotError::Arrow)?;
let episode_indices = full_data
.column_by_name("episode_index")
.and_then(|c| c.downcast_array_ref::<Int64Array>())
.ok_or_else(|| {
LeRobotError::MissingDatasetInfo(
"`episode_index` column missing or wrong type".into(),
)
})?;
let row_ranges = Self::build_episode_row_index(episode_indices);
let mut cache = self.episode_data_cache.write();
for &ep_idx in episodes_in_file {
if cache.contains_key(&ep_idx) {
continue;
}
if let Some(range) = row_ranges.get(&ep_idx) {
let sliced = full_data.slice(range.start_row, range.end_row - range.start_row);
cache.insert(ep_idx, Arc::new(sliced));
}
}
Ok(())
}
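
    /// Maps each episode index to its contiguous row range within the file, assuming the rows
    /// are grouped by `episode_index`.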
fn build_episode_row_index(
episode_indices: &Int64Array,
) -> HashMap<EpisodeIndex, EpisodeRowRange> {
let mut ranges: HashMap<EpisodeIndex, EpisodeRowRange> = HashMap::default();
let mut current_episode: Option<i64> = None;
let mut current_start = 0;
for (i, ep_idx) in episode_indices.iter().enumerate() {
let ep_idx = ep_idx.unwrap_or(-1);
if Some(ep_idx) != current_episode {
if let Some(prev_ep) = current_episode
&& prev_ep >= 0
{
ranges.insert(
EpisodeIndex(prev_ep as usize),
EpisodeRowRange {
start_row: current_start,
end_row: i,
},
);
}
current_episode = Some(ep_idx);
current_start = i;
}
}
if let Some(ep_idx) = current_episode
&& ep_idx >= 0
{
ranges.insert(
EpisodeIndex(ep_idx as usize),
EpisodeRowRange {
start_row: current_start,
end_row: episode_indices.len(),
},
);
}
ranges
}
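
    /// Returns the cached frame data of `episode`, or [`LeRobotError::EmptyEpisode`] if no
    /// data was cached for it.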
pub fn read_episode_data(&self, episode: EpisodeIndex) -> Result<RecordBatch, LeRobotError> {
let cache = self.episode_data_cache.read();
if let Some(cached_data) = cache.get(&episode) {
return Ok((**cached_data).clone());
}
Err(LeRobotError::EmptyEpisode(episode))
}
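
    /// Returns the raw bytes of the video file backing `observation_key` for `episode`.
    ///
    /// The file is read from disk on first access and served from the blob cache afterwards.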
pub fn read_episode_video_contents(
&self,
observation_key: &str,
episode: EpisodeIndex,
) -> Result<Arc<[u8]>, LeRobotError> {
let episode_data = self
.metadata
.get_episode_data(episode)
.ok_or(LeRobotError::InvalidEpisodeIndex(episode))?;
let video_file = self
.metadata
.info
.video_path(observation_key, episode_data)?;
        let video_path = self.path.join(video_file);

        // Fast path: the file has already been cached.
        {
            let cache = self.video_cache.read();
            if let Some(cached_contents) = cache.blobs.get(&video_path) {
                return Ok(Arc::clone(cached_contents));
            }
        }

        let contents = {
            re_tracing::profile_scope!("fs::read");
            std::fs::read(&video_path).map_err(|err| LeRobotError::io(err, video_path.clone()))?
        };

        let mut cache = self.video_cache.write();
        // Re-check under the write lock in case another thread cached the file while we were
        // reading it from disk.
        if let Some(cached_contents) = cache.blobs.get(&video_path) {
            return Ok(Arc::clone(cached_contents));
        }
        let contents: Arc<[u8]> = Arc::from(contents.into_boxed_slice());
        cache.blobs.insert(video_path, contents.clone());
Ok(contents)
}
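
    /// Returns the task with the given index, if any.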
pub fn task_by_index(&self, task: TaskIndex) -> Option<&LeRobotDatasetTask> {
self.metadata.tasks.tasks.get(&task)
}
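
    /// Returns the subtask with the given index, if any.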
pub fn subtask_by_index(&self, subtask: SubtaskIndex) -> Option<&LeRobotDatasetSubtask> {
self.metadata.subtasks.as_ref()?.subtasks.get(&subtask)
}
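
    /// Converts every (non-ignored) feature of `episode` into Rerun [`Chunk`]s.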
fn load_episode(&self, episode: EpisodeIndex) -> Result<Vec<Chunk>, DataLoaderError> {
let data = self
.read_episode_data(episode)
.map_err(|err| anyhow!("Reading data for episode {} failed: {err}", episode.0))?;
let frame_indices = data
.column_by_name("frame_index")
.ok_or_else(|| anyhow!("Failed to get frame index column in LeRobot dataset"))?
.clone();
        let timeline = Timeline::new_sequence("frame_index");
        let times: &ScalarBuffer<i64> = frame_indices
            .downcast_array_ref::<Int64Array>()
            .ok_or_else(|| anyhow!("LeRobot dataset frame indices are of an unexpected type"))?
            .values();
        let time_column = TimeColumn::new(None, timeline, times.clone());
        let timelines = std::iter::once((*timeline.name(), time_column.clone())).collect();
let mut chunks = Vec::new();
for (feature_key, feature) in self
.metadata
.info
.features
.iter()
.filter(|(key, _)| !LEROBOT_DATASET_IGNORED_COLUMNS.contains(&key.as_str()))
{
match feature.dtype {
DType::Video => {
chunks.extend(self.load_episode_video(
feature_key,
episode,
&timeline,
&time_column,
)?);
}
DType::Image => {
let num_channels = feature.channel_dim();
match num_channels {
1 => {
chunks.extend(load_episode_depth_images(
feature_key,
&timeline,
&data,
)?);
}
3 => chunks.extend(load_episode_images(feature_key, &timeline, &data)?),
_ => re_log::warn_once!(
"Unsupported channel count {num_channels} (shape: {:?}) for LeRobot dataset; Only 1- and 3-channel images are supported",
feature.shape
),
}
}
DType::Int64 if feature_key == "task_index" => {
chunks.extend(self.log_episode_task(&timeline, &data)?);
}
DType::Int64 if feature_key == "subtask_index" => {
chunks.extend(self.log_episode_subtask(&timeline, &data)?);
}
DType::Int16 | DType::Int64 | DType::Bool | DType::String => {
re_log::warn_once!(
"Loading LeRobot feature ({feature_key}) of dtype `{:?}` into Rerun is not yet implemented",
feature.dtype
);
}
DType::Float32 | DType::Float64 => {
chunks.extend(load_scalar(feature_key, feature, &timelines, &data)?);
}
}
}
Ok(chunks)
}
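
    /// Logs the task description of each frame as a [`TextDocument`] chunk on the `task`
    /// entity.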
fn log_episode_task(
&self,
timeline: &Timeline,
data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
let task_indices = data
.column_by_name("task_index")
.and_then(|c| c.downcast_array_ref::<Int64Array>())
.with_context(|| "Failed to get task_index field from dataset!")?;
let mut chunk = Chunk::builder("task");
let mut row_id = RowId::new();
for (frame_idx, task_index_opt) in task_indices.iter().enumerate() {
let Some(task_idx) = task_index_opt
.and_then(|i| usize::try_from(i).ok())
.map(TaskIndex)
else {
continue;
};
if let Some(task) = self.task_by_index(task_idx) {
let frame_idx = i64::try_from(frame_idx)
.map_err(|err| anyhow!("Frame index exceeds max value: {err}"))?;
let timepoint = TimePoint::default().with(*timeline, frame_idx);
let text = TextDocument::new(task.task.clone());
chunk = chunk.with_archetype(row_id, timepoint, &text);
row_id = row_id.next();
}
}
Ok(std::iter::once(chunk.build()?))
}
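
    /// Logs the subtask description of each frame as a [`TextDocument`] chunk on the
    /// `subtask` entity.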
fn log_episode_subtask(
&self,
timeline: &Timeline,
data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
let subtask_indices = data
.column_by_name("subtask_index")
.and_then(|c| c.downcast_array_ref::<Int64Array>())
.with_context(|| "Failed to get subtask_index field from dataset!")?;
let mut chunk = Chunk::builder("subtask");
let mut row_id = RowId::new();
for (frame_idx, subtask_index_opt) in subtask_indices.iter().enumerate() {
let Some(subtask_idx) = subtask_index_opt
.and_then(|i| usize::try_from(i).ok())
.map(SubtaskIndex)
else {
continue;
};
if let Some(subtask) = self.subtask_by_index(subtask_idx) {
let frame_idx = i64::try_from(frame_idx)
.map_err(|err| anyhow!("Frame index exceeds max value: {err}"))?;
let timepoint = TimePoint::default().with(*timeline, frame_idx);
let text = TextDocument::new(subtask.subtask.clone());
chunk = chunk.with_archetype(row_id, timepoint, &text);
row_id = row_id.next();
}
}
Ok(std::iter::once(chunk.build()?))
}
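
    /// Returns the `(from, to)` timestamp range (in seconds) of the file backing
    /// `observation` for `episode`, falling back to `(0.0, 0.0)` when the metadata does not
    /// provide it.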
fn get_feature_timestamps(&self, episode: EpisodeIndex, observation: &str) -> (f64, f64) {
self.metadata
.get_episode_data(episode)
.and_then(|ep_data| ep_data.feature_files.get(observation))
.map(|file_meta| {
(
file_meta.from_timestamp.unwrap_or(0.0),
file_meta.to_timestamp.unwrap_or(0.0),
)
})
.unwrap_or((0.0, 0.0))
}
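
    /// Loads the video samples of `observation` that fall within the episode's timestamp
    /// range and converts them into [`VideoStream`] chunks, distributing the samples
    /// uniformly over the episode's `frame_index` timeline.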
fn load_episode_video(
&self,
observation: &str,
episode: EpisodeIndex,
timeline: &Timeline,
time_column: &TimeColumn,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
let contents = self
.read_episode_video_contents(observation, episode)
.with_context(|| format!("Reading video contents for episode {episode:?} failed!"))?;
let entity_path = observation;
let video_bytes: &[u8] = &contents;
let video = VideoDataDescription::load_from_bytes(
video_bytes,
"video/mp4",
observation,
re_log_types::external::re_tuid::Tuid::new(),
)
.map_err(|err| {
anyhow!("Failed to read video data description for feature '{observation}': {err}")
})?;
let (start_time, end_time) = self.get_feature_timestamps(episode, observation);
if video.samples.is_empty() {
return Err(DataLoaderError::Other(anyhow!(
"Video feature '{observation}' for episode {episode:?} did not contain any samples"
)));
}
let timescale = video.timescale.ok_or_else(|| {
anyhow!("Video feature '{observation}' is missing timescale information")
})?;
let start_video_time = re_video::Time::from_secs(start_time, timescale);
let end_video_time = re_video::Time::from_secs(end_time, timescale);
let start_keyframe = video
.presentation_time_keyframe_index(start_video_time)
.unwrap_or(0);
let end_keyframe = video
.presentation_time_keyframe_index(end_video_time)
.or_else(|| video.keyframe_indices.len().checked_sub(1))
.ok_or(DataLoaderError::Other(anyhow!("No keyframes in the video")))?;
let start_sample = video
.gop_sample_range_for_keyframe(start_keyframe)
.ok_or(DataLoaderError::Other(anyhow!("Bad video data")))?
.start;
let end_sample = video
.gop_sample_range_for_keyframe(end_keyframe)
.ok_or(DataLoaderError::Other(anyhow!("Bad video data")))?
.end;
let sample_range = start_sample..end_sample;
let mut samples = Vec::with_capacity(sample_range.len());
for (sample_idx, sample_meta) in video.samples.iter_index_range_clamped(&sample_range) {
let Some(sample_meta) = sample_meta.sample() else {
continue;
};
if sample_meta.presentation_timestamp < start_video_time
|| sample_meta.presentation_timestamp >= end_video_time
{
continue;
}
let chunk = sample_meta
.get(&|_| video_bytes, sample_idx)
.ok_or_else(|| {
anyhow!("Sample {sample_idx} out of bounds for feature '{observation}'")
})?;
let sample_bytes = video
.sample_data_in_stream_format(&chunk)
.with_context(|| {
format!(
"Failed to convert sample {sample_idx} for feature '{observation}' to the expected codec stream format"
)
})?;
samples.push((sample_meta.clone(), sample_bytes));
}
let (samples_meta, samples): (Vec<_>, Vec<_>) = samples.into_iter().unzip();
let samples_column = VideoStream::update_fields()
.with_many_sample(samples)
.columns_of_unit_batches()
.with_context(|| "Failed to create VideoStream")?;
let num_samples = samples_meta.len();
let frame_count = time_column.num_rows();
let uniform_times: Vec<i64> = (0..num_samples)
.map(|i| i64::try_from((i * frame_count) / num_samples).unwrap_or_default())
.collect();
        let uniform_time_column =
            TimeColumn::new(Some(true), *timeline, ScalarBuffer::from(uniform_times));
let codec = re_sdk_types::components::VideoCodec::try_from(video.codec).map_err(|err| {
anyhow!(
"Unsupported video codec {:?} for feature: '{observation}': {err}",
video.codec
)
})?;
let codec_chunk = Chunk::builder(entity_path)
.with_archetype(
RowId::new(),
TimePoint::default(),
&VideoStream::update_fields().with_codec(codec),
)
.build()?;
let samples_chunk = Chunk::from_auto_row_ids(
ChunkId::new(),
entity_path.into(),
std::iter::once((timeline.name().to_owned(), uniform_time_column)).collect(),
samples_column.collect(),
)?;
Ok([samples_chunk, codec_chunk].into_iter())
}
}
impl LeRobotDataset for LeRobotDatasetV3 {
    fn iter_episode_indices(&self) -> impl Iterator<Item = EpisodeIndex> {
self.metadata.iter_episode_indices()
}
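
    /// Loads all chunks of `episode`, then releases its cached video blobs regardless of
    /// whether loading succeeded.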
fn load_episode_chunks(&self, episode: EpisodeIndex) -> Result<Vec<Chunk>, DataLoaderError> {
let result = self.load_episode(episode);
self.release_episode_videos(episode);
result
}
}
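
/// Metadata of a `v3` LeRobot dataset, as loaded from its `meta/` directory.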
pub struct LeRobotDatasetMetadataV3 {
pub info: LeRobotDatasetInfoV3,
pub tasks: LeRobotDatasetV3Tasks,
pub subtasks: Option<LeRobotDatasetV3Subtasks>,
pub episodes: HashMap<EpisodeIndex, LeRobotEpisodeData>,
}
impl LeRobotDatasetMetadataV3 {
pub fn episode_count(&self) -> usize {
self.episodes.len()
}
pub fn get_episode_data(&self, episode: EpisodeIndex) -> Option<&LeRobotEpisodeData> {
self.episodes.get(&episode)
}
pub fn iter_episode_indices(&self) -> impl Iterator<Item = EpisodeIndex> + '_ {
self.episodes.values().map(|episode| episode.episode_index)
}
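
    /// Loads `info.json`, `tasks.parquet`, the optional `subtasks.parquet`, and the
    /// per-episode metadata under `episodes/` from the dataset's `meta/` directory.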
pub fn load_from_directory(metadir: impl AsRef<Path>) -> Result<Self, LeRobotError> {
let metadir = metadir.as_ref();
let episode_data = LeRobotEpisodeData::load_from_directory(metadir.join("episodes"))?;
let info = LeRobotDatasetInfoV3::load_from_json_file(metadir.join("info.json"))?;
let tasks = LeRobotDatasetV3Tasks::load_from_parquet_file(metadir.join("tasks.parquet"))?;
let subtasks_path = metadir.join("subtasks.parquet");
let subtasks = if subtasks_path.is_file() {
Some(LeRobotDatasetV3Subtasks::load_from_parquet_file(
subtasks_path,
)?)
} else {
None
};
let episodes = episode_data
.into_iter()
.map(|ep| (ep.episode_index, ep))
.collect();
Ok(Self {
info,
tasks,
subtasks,
episodes,
})
}
}
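
/// Location (chunk and file index) and optional timestamp range of the file backing a single
/// feature of one episode.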
#[derive(Debug, Clone)]
pub struct FeatureFileMetadata {
pub chunk_index: usize,
pub file_index: usize,
pub from_timestamp: Option<f64>,
pub to_timestamp: Option<f64>,
}
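
/// Per-episode metadata: which data file the episode's frames live in, and which file backs
/// each of its video features.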
#[derive(Debug, Clone)]
pub struct LeRobotEpisodeData {
pub episode_index: EpisodeIndex,
pub data_chunk_index: usize,
pub data_file_index: usize,
pub feature_files: HashMap<String, FeatureFileMetadata>,
}
impl LeRobotEpisodeData {
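    /// Reads every parquet file found in the subdirectories of `metadir` and extracts one
    /// [`LeRobotEpisodeData`] entry per row.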
fn load_from_directory(metadir: impl AsRef<Path>) -> Result<Vec<Self>, LeRobotError> {
let metadir = metadir.as_ref();
let mut all_episodes = vec![];
for entry in std::fs::read_dir(metadir).map_err(|err| LeRobotError::io(err, metadir))? {
let entry = entry.map_err(|err| LeRobotError::io(err, metadir))?;
let path = entry.path();
let path = path.as_path();
re_log::trace!("Loading episode metadata from: {path:?}");
if path.is_dir() {
for chunk_entry in
std::fs::read_dir(path).map_err(|err| LeRobotError::io(err, path))?
{
let chunk_entry = chunk_entry.map_err(|err| LeRobotError::io(err, path))?;
let chunk_path = chunk_entry.path();
if chunk_path.is_file() {
let chunk_parquet = ParquetRecordBatchReaderBuilder::try_new(
File::open(&chunk_path)
.map_err(|err| LeRobotError::io(err, chunk_path.clone()))?,
)?
.build()?;
let episode_data: Vec<_> = chunk_parquet
.filter_map(|batch| {
let batch = batch.ok()?;
let episode_index = batch
.column_by_name("episode_index")?
.as_any()
.downcast_ref::<Int64Array>()?;
let data_chunk_index = batch
.column_by_name("data/chunk_index")?
.as_any()
.downcast_ref::<Int64Array>()?;
let data_file_index = batch
.column_by_name("data/file_index")?
.as_any()
.downcast_ref::<Int64Array>()?;
Some(Self::collect_episode_data(
&batch,
episode_index,
data_chunk_index,
data_file_index,
))
})
.flatten()
.collect();
all_episodes.extend(episode_data);
}
}
}
}
Ok(all_episodes)
}
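
    /// Builds one [`LeRobotEpisodeData`] per row of `batch`, combining the per-row episode
    /// and data-file indices with the per-feature video file metadata columns.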
fn collect_episode_data(
batch: &RecordBatch,
episode_index: &Int64Array,
data_chunk_index: &Int64Array,
data_file_index: &Int64Array,
) -> Vec<Self> {
let feature_metadata = Self::parse_feature_metadata(batch);
let mut episodes = Vec::with_capacity(batch.num_rows());
for i in 0..batch.num_rows() {
let feature_files = feature_metadata
.iter()
.filter_map(|(feature_name, metadata)| {
let chunk_index = metadata.chunk_index.as_ref()?;
let file_index = metadata.file_index.as_ref()?;
Some((
feature_name.to_string(),
FeatureFileMetadata {
chunk_index: chunk_index.value(i) as usize,
file_index: file_index.value(i) as usize,
from_timestamp: metadata.from_timestamp.as_ref().and_then(
|timestamps| timestamps.is_valid(i).then(|| timestamps.value(i)),
),
to_timestamp: metadata.to_timestamp.as_ref().and_then(|timestamps| {
timestamps.is_valid(i).then(|| timestamps.value(i))
}),
},
))
})
.collect();
episodes.push(Self {
episode_index: EpisodeIndex(episode_index.value(i) as usize),
data_chunk_index: data_chunk_index.value(i) as usize,
data_file_index: data_file_index.value(i) as usize,
feature_files,
});
}
episodes
}
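
    /// Groups the `videos/{feature}/{field}` columns of `batch` by feature name.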
fn parse_feature_metadata(batch: &RecordBatch) -> HashMap<Arc<str>, FeatureMetadataColumns> {
let mut features: HashMap<Arc<str>, FeatureMetadataColumns> = HashMap::default();
let schema = batch.schema();
for field in schema.fields() {
let column_name = field.name();
if let Some(rest) = column_name.strip_prefix("videos/")
&& let Some((feature_name, field_name)) = rest.rsplit_once('/')
{
let entry = features.entry(Arc::from(feature_name)).or_default();
match field_name {
"chunk_index" => {
if let Some(col) = batch
.column_by_name(column_name)
.and_then(|c| c.downcast_array_ref::<Int64Array>())
{
entry.chunk_index = Some(col.clone());
}
}
"file_index" => {
if let Some(col) = batch
.column_by_name(column_name)
.and_then(|c| c.downcast_array_ref::<Int64Array>())
{
entry.file_index = Some(col.clone());
}
}
"from_timestamp" => {
if let Some(col) = batch
.column_by_name(column_name)
.and_then(|c| c.downcast_array_ref::<Float64Array>())
{
entry.from_timestamp = Some(col.clone());
}
}
"to_timestamp" => {
if let Some(col) = batch
.column_by_name(column_name)
.and_then(|c| c.downcast_array_ref::<Float64Array>())
{
entry.to_timestamp = Some(col.clone());
}
}
                    _ => {}
                }
}
}
features
}
}
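
/// Raw `chunk_index`, `file_index`, `from_timestamp`, and `to_timestamp` columns of a single
/// video feature.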
#[derive(Default)]
struct FeatureMetadataColumns {
chunk_index: Option<Int64Array>,
file_index: Option<Int64Array>,
from_timestamp: Option<Float64Array>,
to_timestamp: Option<Float64Array>,
}
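
/// Contents of a `v3` dataset's `meta/info.json` file.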
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LeRobotDatasetInfoV3 {
pub robot_type: Option<String>,
pub codebase_version: String,
pub total_episodes: usize,
pub total_frames: usize,
pub total_tasks: usize,
pub chunks_size: usize,
pub data_path: String,
pub video_path: Option<String>,
pub image_path: Option<String>,
pub fps: usize,
pub features: HashMap<String, Feature>,
}
impl LeRobotDatasetInfoV3 {
pub fn load_from_json_file(filepath: impl AsRef<Path>) -> Result<Self, LeRobotError> {
let info_file = File::open(filepath.as_ref())
.map_err(|err| LeRobotError::io(err, filepath.as_ref()))?;
let reader = BufReader::new(info_file);
serde_json::from_reader(reader).map_err(|err| err.into())
}
pub fn feature(&self, feature_key: &str) -> Option<&Feature> {
self.features.get(feature_key)
}
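
    /// Resolves the dataset's `data_path` template for the given episode.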
pub fn episode_data_path(&self, episode_data: &LeRobotEpisodeData) -> PathBuf {
self.data_path
.replace(
"{chunk_index:03d}",
&format!("{:03}", episode_data.data_chunk_index),
)
.replace(
"{file_index:03d}",
&format!("{:03}", episode_data.data_file_index),
)
.into()
}
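
    /// Resolves the path of the video file backing `feature_key` for the given episode.
    ///
    /// Uses the per-episode feature file metadata when available, and otherwise substitutes
    /// the episode's own chunk and index into the template. For example, an (illustrative)
    /// template `videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4` with
    /// chunk index 1 and file index 2 resolves to `videos/<video_key>/chunk-001/file-002.mp4`.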
pub fn video_path(
&self,
feature_key: &str,
episode_data: &LeRobotEpisodeData,
) -> Result<PathBuf, LeRobotError> {
let feature = self
.feature(feature_key)
.ok_or(LeRobotError::InvalidFeatureKey(feature_key.to_owned()))?;
if feature.dtype != DType::Video {
return Err(LeRobotError::InvalidFeatureDtype {
key: feature_key.to_owned(),
expected: DType::Video,
actual: feature.dtype,
});
}
let video_path_template = self
.video_path
.as_ref()
.ok_or_else(|| LeRobotError::MissingDatasetInfo("video_path".to_owned()))?;
if let Some(file_metadata) = episode_data.feature_files.get(feature_key) {
Ok(video_path_template
.replace("{video_key}", feature_key)
.replace(
"{chunk_index:03d}",
&format!("{:03}", file_metadata.chunk_index),
)
.replace(
"{file_index:03d}",
&format!("{:03}", file_metadata.file_index),
)
.into())
} else {
Ok(video_path_template
.replace(
"{episode_chunk:03d}",
&format!("{:03}", episode_data.data_chunk_index),
)
.replace(
"{episode_index:06d}",
&format!("{:06}", episode_data.episode_index.0),
)
.replace("{video_key}", feature_key)
.into())
}
}
}
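
/// Task definitions loaded from `meta/tasks.parquet`, keyed by task index.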
pub struct LeRobotDatasetV3Tasks {
pub tasks: HashMap<TaskIndex, LeRobotDatasetTask>,
}
impl LeRobotDatasetV3Tasks {
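    /// Reads the task table from a `tasks.parquet` file.
    ///
    /// The task description is read from the `__index_level_0__` column.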
pub fn load_from_parquet_file(filepath: impl AsRef<Path>) -> Result<Self, LeRobotError> {
let filepath = filepath.as_ref().to_owned();
let parquet_data =
File::open(&filepath).map_err(|err| LeRobotError::io(err, filepath.clone()))?;
let reader = ParquetRecordBatchReaderBuilder::try_new(parquet_data)?.build()?;
let tasks = reader
.filter_map(|record_batch| {
let b = record_batch.ok()?;
let task_index_col = b.column_by_name("task_index")?;
let task_col = b.column_by_name("__index_level_0__")?;
let task_index = task_index_col.as_any().downcast_ref::<Int64Array>()?;
let task = task_col.as_any().downcast_ref::<StringArray>()?;
let num_rows = b.num_rows();
Some(
(0..num_rows)
.map(move |i| {
(
TaskIndex(task_index.value(i) as usize),
LeRobotDatasetTask {
index: TaskIndex(task_index.value(i) as usize),
task: task.value(i).to_owned(),
},
)
})
.collect(),
)
})
.flat_map(|e: Vec<(TaskIndex, LeRobotDatasetTask)>| e)
.collect::<HashMap<_, _>>();
Ok(Self { tasks })
}
}
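
/// Subtask definitions loaded from `meta/subtasks.parquet`, keyed by subtask index.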
pub struct LeRobotDatasetV3Subtasks {
pub subtasks: HashMap<SubtaskIndex, LeRobotDatasetSubtask>,
}
impl LeRobotDatasetV3Subtasks {
pub fn load_from_parquet_file(filepath: impl AsRef<Path>) -> Result<Self, LeRobotError> {
let filepath = filepath.as_ref().to_owned();
let parquet_data =
File::open(&filepath).map_err(|err| LeRobotError::io(err, filepath.clone()))?;
let reader = ParquetRecordBatchReaderBuilder::try_new(parquet_data)?.build()?;
let subtasks = reader
.filter_map(|record_batch| {
let b = record_batch.ok()?;
let subtask_index_col = b.column_by_name("subtask_index")?;
let subtask_col = b.column_by_name("subtask")?;
let subtask_index = subtask_index_col.as_any().downcast_ref::<Int64Array>()?;
let subtask = subtask_col.as_any().downcast_ref::<StringArray>()?;
let num_rows = b.num_rows();
Some(
(0..num_rows)
.map(move |i| {
(
SubtaskIndex(subtask_index.value(i) as usize),
LeRobotDatasetSubtask {
index: SubtaskIndex(subtask_index.value(i) as usize),
subtask: subtask.value(i).to_owned(),
},
)
})
.collect(),
)
})
.flat_map(|e: Vec<(SubtaskIndex, LeRobotDatasetSubtask)>| e)
.collect::<HashMap<_, _>>();
Ok(Self { subtasks })
}
}
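
/// Loads all episodes of the given `v3` dataset and streams the resulting data to `tx`.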
pub fn load_and_stream(
dataset: &LeRobotDatasetV3,
application_id: &ApplicationId,
tx: &Sender<LoadedData>,
loader_name: &str,
) {
load_and_stream_versioned(dataset, application_id, tx, loader_name);
}