re_data_loader/loader_lerobot.rs

use std::{
    sync::{Arc, mpsc::Sender},
    thread,
};

use anyhow::{Context as _, anyhow};
use arrow::{
    array::{
        ArrayRef, BinaryArray, FixedSizeListArray, Int64Array, RecordBatch, StringArray,
        StructArray,
    },
    compute::cast,
    datatypes::{DataType, Field},
};
use itertools::Either;

use re_arrow_util::ArrowArrayDowncastRef as _;
use re_chunk::{
    ArrowArray, Chunk, ChunkId, EntityPath, RowId, TimeColumn, TimeInt, TimePoint, Timeline,
    TimelineName, external::nohash_hasher::IntMap,
};
use re_log_types::{ApplicationId, StoreId};
use re_types::{
    archetypes::{
        self, AssetVideo, DepthImage, EncodedImage, Scalars, TextDocument, VideoFrameReference,
    },
    components::VideoTimestamp,
};

use crate::lerobot::{
    DType, EpisodeIndex, Feature, LeRobotDataset, TaskIndex, is_lerobot_dataset,
    is_v1_lerobot_dataset,
};
use crate::load_file::prepare_store_info;
use crate::{DataLoader, DataLoaderError, LoadedData};

/// Columns in the `LeRobot` dataset schema that we do not visualize in the viewer, and thus ignore.
const LEROBOT_DATASET_IGNORED_COLUMNS: &[&str] =
    &["episode_index", "index", "frame_index", "timestamp"];

/// `LeRobot` dataset format versions supported by this loader.
/// Datasets from other versions won't load.
const LEROBOT_DATASET_SUPPORTED_VERSIONS: &[&str] = &["v2.0", "v2.1"];

/// A [`DataLoader`] for `LeRobot` datasets.
///
/// An example of a dataset that can be loaded with this loader is available on Hugging Face: [lerobot/pusht_image](https://huggingface.co/datasets/lerobot/pusht_image)
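///
/// A minimal usage sketch (not a runnable doctest; it assumes a `DataLoaderSettings` value
/// and a dataset directory are already available):
///
/// ```ignore
/// use std::sync::mpsc::channel;
///
/// let (tx, rx) = channel();
/// // `settings` construction is application-specific and omitted here.
/// LeRobotDatasetLoader.load_from_path(&settings, "path/to/lerobot_dataset".into(), tx)?;
/// // Loaded `LoadedData` items (store infos and chunks) arrive on `rx`.
/// ```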
pub struct LeRobotDatasetLoader;

impl DataLoader for LeRobotDatasetLoader {
    fn name(&self) -> String {
        "LeRobotDatasetLoader".into()
    }

    fn load_from_path(
        &self,
        settings: &crate::DataLoaderSettings,
        filepath: std::path::PathBuf,
        tx: Sender<LoadedData>,
    ) -> Result<(), DataLoaderError> {
        if !is_lerobot_dataset(&filepath) {
            return Err(DataLoaderError::Incompatible(filepath));
        }

        if is_v1_lerobot_dataset(&filepath) {
            re_log::error!("LeRobot 'v1.x' dataset format is unsupported.");
            return Ok(());
        }

        let dataset = LeRobotDataset::load_from_directory(&filepath)
            .map_err(|err| anyhow!("Loading LeRobot dataset failed: {err}"))?;

        if !LEROBOT_DATASET_SUPPORTED_VERSIONS
            .contains(&dataset.metadata.info.codebase_version.as_str())
        {
            re_log::error!(
                "LeRobot '{}' dataset format is unsupported.",
                dataset.metadata.info.codebase_version
            );
            return Ok(());
        }

        let application_id = settings
            .application_id
            .clone()
            .unwrap_or(filepath.display().to_string().into());

        // NOTE(1): `spawn` is fine, this whole function is native-only.
        // NOTE(2): this must be spawned on a dedicated thread to avoid a deadlock!
        // `load` will spawn a bunch of loaders on the common rayon thread pool and wait for
        // their responses via channels: we cannot be waiting for these responses on the
        // common rayon thread pool.
        thread::Builder::new()
            .name(format!("load_and_stream({filepath:?})"))
            .spawn({
                move || {
                    re_log::info!(
                        "Loading LeRobot dataset from {:?}, with {} episode(s)",
                        dataset.path,
                        dataset.metadata.episode_count(),
                    );
                    load_and_stream(&dataset, &application_id, &tx);
                }
            })
            .with_context(|| {
                format!("Failed to spawn IO thread to load LeRobot dataset {filepath:?}")
            })?;

        Ok(())
    }

    fn load_from_file_contents(
        &self,
        _settings: &crate::DataLoaderSettings,
        filepath: std::path::PathBuf,
        _contents: std::borrow::Cow<'_, [u8]>,
        _tx: Sender<LoadedData>,
    ) -> Result<(), DataLoaderError> {
        Err(DataLoaderError::Incompatible(filepath))
    }
}

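/// Streams every episode of the dataset to the viewer, one recording per episode.
///
/// For each episode this sends a `RecordingInfo` chunk followed by the episode's data chunks
/// over `tx`. Episodes that fail to load are logged as warnings and skipped.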
fn load_and_stream(
    dataset: &LeRobotDataset,
    application_id: &ApplicationId,
    tx: &Sender<crate::LoadedData>,
) {
    // set up all recordings
    let episodes = prepare_episode_chunks(dataset, application_id, tx);

    for (episode, store_id) in &episodes {
        // log episode data to its respective recording
        match load_episode(dataset, *episode) {
            Ok(chunks) => {
                let recording_info = re_types::archetypes::RecordingInfo::new()
                    .with_name(format!("Episode {}", episode.0));

                let Ok(initial) = Chunk::builder(EntityPath::properties())
                    .with_archetype(RowId::new(), TimePoint::STATIC, &recording_info)
                    .build()
                else {
                    re_log::error!(
                        "Failed to build recording properties chunk for episode {}",
                        episode.0
                    );
                    return;
                };

                for chunk in std::iter::once(initial).chain(chunks.into_iter()) {
                    let data = LoadedData::Chunk(
                        LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
                        store_id.clone(),
                        chunk,
                    );

                    if tx.send(data).is_err() {
                        break; // The other end has decided to hang up, not our problem.
                    }
                }
            }
            Err(err) => {
                re_log::warn!(
                    "Failed to load episode {} from LeRobot dataset: {err}",
                    episode.0
                );
            }
        }
    }
}

/// Prepare the viewer for all episodes, by sending out a [`SetStoreInfo`](`re_log_types::SetStoreInfo`)
/// [`LogMsg`](`re_log_types::LogMsg`) for each episode.
fn prepare_episode_chunks(
    dataset: &LeRobotDataset,
    application_id: &ApplicationId,
    tx: &Sender<crate::LoadedData>,
) -> Vec<(EpisodeIndex, StoreId)> {
    let mut store_ids = vec![];

    for episode_index in dataset.metadata.episodes.keys() {
        let episode = *episode_index;

        let store_id = StoreId::recording(application_id.clone(), format!("episode_{}", episode.0));
        let set_store_info = LoadedData::LogMsg(
            LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
            prepare_store_info(&store_id, re_log_types::FileSource::Sdk),
        );

        if tx.send(set_store_info).is_err() {
            break;
        }

        store_ids.push((episode, store_id));
    }

    store_ids
}

/// Loads a single episode from a `LeRobot` dataset and converts it into a collection of Rerun chunks.
///
/// This function processes an episode from the dataset by extracting the relevant data columns and
/// converting them into appropriate Rerun data structures. It handles different types of data
/// (videos, images, scalar values, etc.) based on their data type specifications in the dataset metadata.
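///
/// A minimal usage sketch (not a runnable doctest; `dataset_dir` is assumed to point at a
/// valid `LeRobot` dataset root):
///
/// ```ignore
/// let dataset = LeRobotDataset::load_from_directory(&dataset_dir)?;
/// for episode in dataset.metadata.episodes.keys() {
///     let chunks = load_episode(&dataset, *episode)?;
///     // Each `Chunk` can now be sent to its recording in the viewer.
/// }
/// ```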
pub fn load_episode(
    dataset: &LeRobotDataset,
    episode: EpisodeIndex,
) -> Result<Vec<Chunk>, DataLoaderError> {
    let data = dataset
        .read_episode_data(episode)
        .map_err(|err| anyhow!("Reading data for episode {} failed: {err}", episode.0))?;

    let frame_indices = data
        .column_by_name("frame_index")
        .ok_or_else(|| anyhow!("Failed to get frame index column in LeRobot dataset"))?
        .clone();

    let timeline = re_log_types::Timeline::new_sequence("frame_index");
    let times: &arrow::buffer::ScalarBuffer<i64> = frame_indices
        .downcast_array_ref::<Int64Array>()
        .ok_or_else(|| anyhow!("LeRobot dataset frame indices are of an unexpected type"))?
        .values();

    let time_column = re_chunk::TimeColumn::new(None, timeline, times.clone());
    let timelines = std::iter::once((*timeline.name(), time_column.clone())).collect();

    let mut chunks = Vec::new();

    for (feature_key, feature) in dataset
        .metadata
        .info
        .features
        .iter()
        .filter(|(key, _)| !LEROBOT_DATASET_IGNORED_COLUMNS.contains(&key.as_str()))
    {
        match feature.dtype {
            DType::Video => {
                chunks.extend(load_episode_video(
                    dataset,
                    feature_key,
                    episode,
                    &timeline,
                    time_column.clone(),
                )?);
            }

            DType::Image => {
                let num_channels = feature.channel_dim();

                match num_channels {
                    1 => chunks.extend(load_episode_depth_images(feature_key, &timeline, &data)?),
                    3 => chunks.extend(load_episode_images(feature_key, &timeline, &data)?),
                    _ => re_log::warn_once!(
                        "Unsupported channel count {num_channels} (shape: {:?}) for LeRobot dataset; only 1- and 3-channel images are supported",
                        feature.shape
                    ),
                }
            }
            DType::Int64 if feature_key == "task_index" => {
                // Special case: an int64 `task_index` column always refers to the task
                // descriptions in the dataset metadata.
                chunks.extend(log_episode_task(dataset, &timeline, &data)?);
            }
            DType::Int16 | DType::Int64 | DType::Bool | DType::String => {
                re_log::warn_once!(
                    "Loading LeRobot feature ({feature_key}) of dtype `{:?}` into Rerun is not yet implemented",
                    feature.dtype
                );
            }
            DType::Float32 | DType::Float64 => {
                chunks.extend(load_scalar(feature_key, feature, &timelines, &data)?);
            }
        }
    }

    Ok(chunks)
}

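/// Logs the task description of each frame as a [`TextDocument`] under the `task` entity.
///
/// The `task_index` column is mapped to task descriptions via the dataset metadata;
/// frames without a valid task are skipped.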
fn log_episode_task(
    dataset: &LeRobotDataset,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let task_indices = data
        .column_by_name("task_index")
        .and_then(|c| c.downcast_array_ref::<Int64Array>())
        .with_context(|| "Failed to get task_index field from dataset!")?;

    let mut chunk = Chunk::builder("task");
    let mut row_id = RowId::new();
    let mut time_int = TimeInt::ZERO;

    for task_index in task_indices {
        let Some(task) = task_index
            .and_then(|i| usize::try_from(i).ok())
            .and_then(|i| dataset.task_by_index(TaskIndex(i)))
        else {
            // if there is no valid task for the current frame index, we skip it.
            time_int = time_int.inc();
            continue;
        };

        let timepoint = TimePoint::default().with(*timeline, time_int);
        let text = TextDocument::new(task.task.clone());
        chunk = chunk.with_archetype(row_id, timepoint, &text);

        row_id = row_id.next();
        time_int = time_int.inc();
    }

    Ok(std::iter::once(chunk.build()?))
}

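/// Logs the encoded color images of the given image feature as [`EncodedImage`]s,
/// one per frame on the provided timeline.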
fn load_episode_images(
    observation: &str,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let image_bytes = data
        .column_by_name(observation)
        .and_then(|c| c.downcast_array_ref::<StructArray>())
        .and_then(|a| a.column_by_name("bytes"))
        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;

    let mut chunk = Chunk::builder(observation);
    let mut row_id = RowId::new();

    for frame_idx in 0..image_bytes.len() {
        let img_buffer = image_bytes.value(frame_idx);
        let encoded_image = EncodedImage::from_file_contents(img_buffer.to_owned());
        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
        chunk = chunk.with_archetype(row_id, timepoint, &encoded_image);

        row_id = row_id.next();
    }

    Ok(std::iter::once(chunk.build().with_context(|| {
        format!("Failed to build image chunk for image: {observation}")
    })?))
}

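/// Logs the single-channel images of the given image feature as [`DepthImage`]s,
/// one per frame on the provided timeline.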
fn load_episode_depth_images(
    observation: &str,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let image_bytes = data
        .column_by_name(observation)
        .and_then(|c| c.downcast_array_ref::<StructArray>())
        .and_then(|a| a.column_by_name("bytes"))
        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;

    let mut chunk = Chunk::builder(observation);
    let mut row_id = RowId::new();

    for frame_idx in 0..image_bytes.len() {
        let img_buffer = image_bytes.value(frame_idx);
        let depth_image = DepthImage::from_file_contents(img_buffer.to_owned())
            .map_err(|err| anyhow!("Failed to decode image: {err}"))?;

        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
        chunk = chunk.with_archetype(row_id, timepoint, &depth_image);

        row_id = row_id.next();
    }

    Ok(std::iter::once(chunk.build().with_context(|| {
        format!("Failed to build image chunk for image: {observation}")
    })?))
}

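/// Logs a video feature as an [`AssetVideo`] chunk plus a chunk of per-frame
/// [`VideoFrameReference`]s on the given timeline.
///
/// If the frame timestamps cannot be read from the video, only the video asset is logged.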
fn load_episode_video(
    dataset: &LeRobotDataset,
    observation: &str,
    episode: EpisodeIndex,
    timeline: &Timeline,
    time_column: TimeColumn,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let contents = dataset
        .read_episode_video_contents(observation, episode)
        .with_context(|| format!("Reading video contents for episode {episode:?} failed!"))?;

    let video_asset = AssetVideo::new(contents.into_owned());
    let entity_path = observation;

    let video_frame_reference_chunk = match video_asset.read_frame_timestamps_nanos() {
        Ok(frame_timestamps_nanos) => {
            let frame_timestamps_nanos: arrow::buffer::ScalarBuffer<i64> =
                frame_timestamps_nanos.into();

            let video_timestamps = frame_timestamps_nanos
                .iter()
                .take(time_column.num_rows())
                .copied()
                .map(VideoTimestamp::from_nanos)
                .collect::<Vec<_>>();

            let video_frame_reference_column = VideoFrameReference::update_fields()
                .with_many_timestamp(video_timestamps)
                .columns_of_unit_batches()
                .with_context(|| {
                    format!(
                        "Failed to create `VideoFrameReference` column for episode {episode:?}."
                    )
                })?;

            Some(Chunk::from_auto_row_ids(
                re_chunk::ChunkId::new(),
                entity_path.into(),
                std::iter::once((*timeline.name(), time_column)).collect(),
                video_frame_reference_column.collect(),
            )?)
        }
        Err(err) => {
            re_log::warn_once!(
                "Failed to read frame timestamps from episode {episode:?} video: {err}"
            );
            None
        }
    };

    // Put video asset into its own (static) chunk since it can be fairly large.
    let video_asset_chunk = Chunk::builder(entity_path)
        .with_archetype(RowId::new(), TimePoint::default(), &video_asset)
        .build()?;

    if let Some(video_frame_reference_chunk) = video_frame_reference_chunk {
        Ok(Either::Left(
            [video_asset_chunk, video_frame_reference_chunk].into_iter(),
        ))
    } else {
        // Still log the video asset, but don't include video frames.
        Ok(Either::Right(std::iter::once(video_asset_chunk)))
    }
}

/// Helper type similar to [`Either`], but with 3 variants.
enum ScalarChunkIterator {
    Empty(std::iter::Empty<Chunk>),
    Batch(Box<dyn ExactSizeIterator<Item = Chunk>>),

    // Boxed, because `Chunk` is huge, and by extension so is `std::iter::Once<Chunk>`.
    Single(Box<std::iter::Once<Chunk>>),
}

impl Iterator for ScalarChunkIterator {
    type Item = Chunk;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Empty(iter) => iter.next(),
            Self::Batch(iter) => iter.next(),
            Self::Single(iter) => iter.next(),
        }
    }
}

impl ExactSizeIterator for ScalarChunkIterator {}

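/// Converts a float feature column into scalar chunks, dispatching on the Arrow data type
/// of the underlying column (`FixedSizeList`, `List`, or plain `Float32`/`Float64`).
///
/// Unsupported data types log a warning and yield an empty iterator.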
fn load_scalar(
    feature_key: &str,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &RecordBatch,
) -> Result<ScalarChunkIterator, DataLoaderError> {
    let field = data
        .schema_ref()
        .field_with_name(feature_key)
        .with_context(|| {
            format!("Failed to get field for feature {feature_key} from parquet file")
        })?;

    let entity_path = EntityPath::parse_forgiving(field.name());

    match field.data_type() {
        DataType::FixedSizeList(_, _) => {
            let fixed_size_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<FixedSizeListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!(
                        "Failed to downcast feature to FixedSizeListArray"
                    ))
                })?;

            let batch_chunks =
                make_scalar_batch_entity_chunks(entity_path, feature, timelines, fixed_size_array)?;
            Ok(ScalarChunkIterator::Batch(Box::new(batch_chunks)))
        }
        DataType::List(_field) => {
            let list_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<arrow::array::ListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!("Failed to downcast feature to ListArray"))
                })?;

            let sliced = extract_list_array_elements_as_f64(list_array).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(Box::new(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            ))))
        }
        DataType::Float32 | DataType::Float64 => {
            let feature_data = data.column_by_name(feature_key).ok_or_else(|| {
                DataLoaderError::Other(anyhow!(
                    "Failed to get LeRobot dataset column data for: {:?}",
                    field.name()
                ))
            })?;

            let sliced = extract_scalar_slices_as_f64(feature_data).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(Box::new(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            ))))
        }
        _ => {
            re_log::warn_once!(
                "Tried logging scalar {} with unsupported dtype: {}",
                field.name(),
                field.data_type()
            );
            Ok(ScalarChunkIterator::Empty(std::iter::empty()))
        }
    }
}

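/// Builds the chunks for a fixed-size-list scalar feature: one [`Scalars`] chunk with the
/// per-frame values, plus an optional static chunk holding the series names if the feature
/// defines them.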
fn make_scalar_batch_entity_chunks(
    entity_path: EntityPath,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &FixedSizeListArray,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let num_elements = data.value_length() as usize;

    let mut chunks = Vec::with_capacity(num_elements);

    let sliced = extract_fixed_size_list_array_elements_as_f64(data)
        .with_context(|| format!("Failed to cast scalar feature {entity_path} to Float64"))?;

    chunks.push(make_scalar_entity_chunk(
        entity_path.clone(),
        timelines,
        &sliced,
    )?);

    // If we have names for this feature, we insert a single static chunk containing the names.
    if let Some(names) = feature.names.clone() {
        let names: Vec<_> = (0..num_elements)
            .map(|idx| names.name_for_index(idx))
            .collect();

        chunks.push(
            Chunk::builder(entity_path)
                .with_row(
                    RowId::new(),
                    TimePoint::default(),
                    std::iter::once((
                        archetypes::SeriesLines::descriptor_names(),
                        Arc::new(StringArray::from_iter(names)) as Arc<dyn ArrowArray>,
                    )),
                )
                .build()?,
        );
    }

    Ok(chunks.into_iter())
}

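/// Assembles per-row `Float64` slices into a single [`Scalars`] chunk on the given timelines.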
fn make_scalar_entity_chunk(
    entity_path: EntityPath,
    timelines: &IntMap<TimelineName, TimeColumn>,
    sliced_data: &[ArrayRef],
) -> Result<Chunk, DataLoaderError> {
    let data_arrays = sliced_data
        .iter()
        .map(|e| Some(e.as_ref()))
        .collect::<Vec<_>>();

    let data_field_inner = Field::new("item", DataType::Float64, true /* nullable */);
    #[allow(clippy::unwrap_used)] // we know we've given the right field type
    let data_field_array: arrow::array::ListArray =
        re_arrow_util::arrays_to_list_array(data_field_inner.data_type().clone(), &data_arrays)
            .unwrap();

    Ok(Chunk::from_auto_row_ids(
        ChunkId::new(),
        entity_path,
        timelines.clone(),
        std::iter::once((Scalars::descriptor_scalars().clone(), data_field_array)).collect(),
    )?)
}

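/// Casts a scalar column to `Float64` and splits it into one single-element slice per row.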
fn extract_scalar_slices_as_f64(data: &ArrayRef) -> anyhow::Result<Vec<ArrayRef>> {
    // cast the slice to f64 first, as scalars need an f64
    let scalar_values = cast(&data, &DataType::Float64)
        .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))?;

    Ok((0..data.len())
        .map(|idx| scalar_values.slice(idx, 1))
        .collect::<Vec<_>>())
}

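/// Casts every fixed-size-list element (one per row) to a `Float64` array.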
fn extract_fixed_size_list_array_elements_as_f64(
    data: &FixedSizeListArray,
) -> anyhow::Result<Vec<ArrayRef>> {
    (0..data.len())
        .map(|idx| {
            cast(&data.value(idx), &DataType::Float64)
                .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))
        })
        .collect::<Result<Vec<_>, _>>()
}

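/// Casts every list element (one per row) to a `Float64` array.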
fn extract_list_array_elements_as_f64(
    data: &arrow::array::ListArray,
) -> anyhow::Result<Vec<ArrayRef>> {
    (0..data.len())
        .map(|idx| {
            cast(&data.value(idx), &DataType::Float64)
                .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))
        })
        .collect::<Result<Vec<_>, _>>()
}