re_data_loader/
loader_lerobot.rs

1use std::sync::mpsc::Sender;
2use std::sync::Arc;
3use std::thread;
4
5use anyhow::{anyhow, Context as _};
6use arrow::array::{
7    ArrayRef, BinaryArray, FixedSizeListArray, Int64Array, RecordBatch, StringArray, StructArray,
8};
9use arrow::compute::cast;
10use arrow::datatypes::{DataType, Field};
11use itertools::Either;
12use re_arrow_util::ArrowArrayDowncastRef as _;
13use re_chunk::{external::nohash_hasher::IntMap, TimelineName};
14use re_chunk::{
15    ArrowArray, Chunk, ChunkId, EntityPath, RowId, TimeColumn, TimeInt, TimePoint, Timeline,
16};
17
18use re_log_types::{ApplicationId, StoreId};
19use re_types::archetypes::{
20    AssetVideo, DepthImage, EncodedImage, TextDocument, VideoFrameReference,
21};
22use re_types::components::{Name, Scalar, VideoTimestamp};
23use re_types::{Archetype, Component, ComponentBatch};
24
25use crate::lerobot::{
26    is_lerobot_dataset, is_v1_lerobot_dataset, DType, EpisodeIndex, Feature, LeRobotDataset,
27    TaskIndex,
28};
29use crate::load_file::prepare_store_info;
30use crate::{DataLoader, DataLoaderError, LoadedData};
31
/// Columns in the `LeRobot` dataset schema that we do not visualize in the viewer, and thus ignore.
///
/// Note: `frame_index` is still read internally — it drives the sequence timeline
/// in [`load_episode`] — it is just not logged as its own entity.
const LEROBOT_DATASET_IGNORED_COLUMNS: &[&str] =
    &["episode_index", "index", "frame_index", "timestamp"];

/// Only supports `LeRobot` datasets that are in a supported version format.
/// Datasets from unsupported versions won't load (an error is logged instead).
const LEROBOT_DATASET_SUPPORTED_VERSIONS: &[&str] = &["v2.0", "v2.1"];
39
/// A [`DataLoader`] for `LeRobot` datasets.
///
/// An example dataset which can be loaded can be found on Hugging Face: [lerobot/pusht_image](https://huggingface.co/datasets/lerobot/pusht_image)
///
/// Each episode in the dataset is loaded into its own recording (store).
pub struct LeRobotDatasetLoader;
44
45impl DataLoader for LeRobotDatasetLoader {
46    fn name(&self) -> String {
47        "LeRobotDatasetLoader".into()
48    }
49
50    fn load_from_path(
51        &self,
52        settings: &crate::DataLoaderSettings,
53        filepath: std::path::PathBuf,
54        tx: Sender<LoadedData>,
55    ) -> Result<(), DataLoaderError> {
56        if !is_lerobot_dataset(&filepath) {
57            return Err(DataLoaderError::Incompatible(filepath));
58        }
59
60        if is_v1_lerobot_dataset(&filepath) {
61            re_log::error!("LeRobot 'v1.x' dataset format is unsupported.");
62            return Ok(());
63        }
64
65        let dataset = LeRobotDataset::load_from_directory(&filepath)
66            .map_err(|err| anyhow!("Loading LeRobot dataset failed: {err}"))?;
67
68        if !LEROBOT_DATASET_SUPPORTED_VERSIONS
69            .contains(&dataset.metadata.info.codebase_version.as_str())
70        {
71            re_log::error!(
72                "LeRobot '{}' dataset format is unsupported.",
73                dataset.metadata.info.codebase_version
74            );
75            return Ok(());
76        }
77
78        let application_id = settings
79            .application_id
80            .clone()
81            .unwrap_or(ApplicationId(filepath.display().to_string()));
82
83        // NOTE(1): `spawn` is fine, this whole function is native-only.
84        // NOTE(2): this must spawned on a dedicated thread to avoid a deadlock!
85        // `load` will spawn a bunch of loaders on the common rayon thread pool and wait for
86        // their response via channels: we cannot be waiting for these responses on the
87        // common rayon thread pool.
88        thread::Builder::new()
89            .name(format!("load_and_stream({filepath:?}"))
90            .spawn({
91                move || {
92                    re_log::info!(
93                        "Loading LeRobot dataset from {:?}, with {} episode(s)",
94                        dataset.path,
95                        dataset.metadata.episodes.len(),
96                    );
97                    load_and_stream(&dataset, &application_id, &tx);
98                }
99            })
100            .with_context(|| {
101                format!("Failed to spawn IO thread to load LeRobot dataset {filepath:?} ")
102            })?;
103
104        Ok(())
105    }
106
107    fn load_from_file_contents(
108        &self,
109        _settings: &crate::DataLoaderSettings,
110        filepath: std::path::PathBuf,
111        _contents: std::borrow::Cow<'_, [u8]>,
112        _tx: Sender<LoadedData>,
113    ) -> Result<(), DataLoaderError> {
114        Err(DataLoaderError::Incompatible(filepath))
115    }
116}
117
118fn load_and_stream(
119    dataset: &LeRobotDataset,
120    application_id: &ApplicationId,
121    tx: &Sender<crate::LoadedData>,
122) {
123    // set up all recordings
124    let episodes = prepare_episode_chunks(dataset, application_id, tx);
125
126    for (episode, store_id) in &episodes {
127        // log episode data to its respective recording
128        match load_episode(dataset, *episode) {
129            Ok(chunks) => {
130                let properties = re_types::archetypes::RecordingProperties::new()
131                    .with_name(format!("Episode {}", episode.0));
132
133                debug_assert!(TimePoint::default().is_static());
134                let Ok(initial) = Chunk::builder(EntityPath::recording_properties())
135                    .with_archetype(RowId::new(), TimePoint::default(), &properties)
136                    .build()
137                else {
138                    re_log::error!(
139                        "Failed to build recording properties chunk for episode {}",
140                        episode.0
141                    );
142                    return;
143                };
144
145                for chunk in std::iter::once(initial).chain(chunks.into_iter()) {
146                    let data = LoadedData::Chunk(
147                        LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
148                        store_id.clone(),
149                        chunk,
150                    );
151
152                    if tx.send(data).is_err() {
153                        break; // The other end has decided to hang up, not our problem.
154                    }
155                }
156            }
157            Err(err) => {
158                re_log::warn!(
159                    "Failed to load episode {} from LeRobot dataset: {err}",
160                    episode.0
161                );
162            }
163        }
164    }
165}
166
167/// Prepare the viewer for all episodes, by sending out a [`SetStoreInfo`](`re_log_types::SetStoreInfo`)
168/// [`LogMsg`](`re_log_types::LogMsg`) for each episode.
169fn prepare_episode_chunks(
170    dataset: &LeRobotDataset,
171    application_id: &ApplicationId,
172    tx: &Sender<crate::LoadedData>,
173) -> Vec<(EpisodeIndex, StoreId)> {
174    let mut store_ids = vec![];
175
176    for episode in &dataset.metadata.episodes {
177        let episode = episode.index;
178
179        let store_id = StoreId::from_string(
180            re_log_types::StoreKind::Recording,
181            format!("episode_{}", episode.0),
182        );
183        let set_store_info = LoadedData::LogMsg(
184            LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
185            prepare_store_info(
186                application_id.clone(),
187                &store_id,
188                re_log_types::FileSource::Sdk,
189            ),
190        );
191
192        if tx.send(set_store_info).is_err() {
193            break;
194        }
195
196        store_ids.push((episode, store_id.clone()));
197    }
198
199    store_ids
200}
201
/// Loads a single episode from a `LeRobot` dataset and converts it into a collection of Rerun chunks.
///
/// This function processes an episode from the dataset by extracting the relevant data columns and
/// converting them into appropriate Rerun data structures. It handles different types of data
/// (videos, images, scalar values, etc.) based on their data type specifications in the dataset metadata.
pub fn load_episode(
    dataset: &LeRobotDataset,
    episode: EpisodeIndex,
) -> Result<Vec<Chunk>, DataLoaderError> {
    let data = dataset
        .read_episode_data(episode)
        .map_err(|err| anyhow!("Reading data for episode {} failed: {err}", episode.0))?;

    // The `frame_index` column drives the sequence timeline shared by all features.
    let frame_indices = data
        .column_by_name("frame_index")
        .ok_or_else(|| anyhow!("Failed to get frame index column in LeRobot dataset"))?
        .clone();

    let timeline = re_log_types::Timeline::new_sequence("frame_index");
    let times: &arrow::buffer::ScalarBuffer<i64> = frame_indices
        .downcast_array_ref::<Int64Array>()
        .ok_or_else(|| anyhow!("LeRobot dataset frame indices are of an unexpected type"))?
        .values();

    let time_column = re_chunk::TimeColumn::new(None, timeline, times.clone());
    let timelines = std::iter::once((*timeline.name(), time_column.clone())).collect();

    let mut chunks = Vec::new();

    // Iterate over every feature column except the bookkeeping columns we ignore.
    for (feature_key, feature) in dataset
        .metadata
        .info
        .features
        .iter()
        .filter(|(key, _)| !LEROBOT_DATASET_IGNORED_COLUMNS.contains(&key.as_str()))
    {
        match feature.dtype {
            DType::Video => {
                chunks.extend(load_episode_video(
                    dataset,
                    feature_key,
                    episode,
                    &timeline,
                    time_column.clone(),
                )?);
            }

            DType::Image => {
                // 1-channel images are treated as depth images, 3-channel as color images.
                let num_channels = feature.channel_dim();

                match num_channels {
                    1 => chunks.extend(load_episode_depth_images(feature_key, &timeline, &data)?),
                    3 => chunks.extend(load_episode_images(feature_key, &timeline, &data)?),
                    _ => re_log::warn_once!(
                        "Unsupported channel count {num_channels} (shape: {:?}) for LeRobot dataset; Only 1- and 3-channel images are supported",
                        feature.shape
                    ),
                };
            }
            // NOTE: this guarded arm must stay before the general `Int64` arm below,
            // otherwise `task_index` would fall through to the "not yet implemented" warning.
            DType::Int64 if feature_key == "task_index" => {
                // special case int64 task_index columns
                // this always refers to the task description in the dataset metadata.
                chunks.extend(log_episode_task(dataset, &timeline, &data)?);
            }
            DType::Int16 | DType::Int64 | DType::Bool | DType::String => {
                re_log::warn_once!(
                    "Loading LeRobot feature ({feature_key}) of dtype `{:?}` into Rerun is not yet implemented",
                    feature.dtype
                );
            }
            DType::Float32 | DType::Float64 => {
                chunks.extend(load_scalar(feature_key, feature, &timelines, &data)?);
            }
        }
    }

    Ok(chunks)
}
280
281fn log_episode_task(
282    dataset: &LeRobotDataset,
283    timeline: &Timeline,
284    data: &RecordBatch,
285) -> Result<impl ExactSizeIterator<Item = Chunk>, DataLoaderError> {
286    let task_indices = data
287        .column_by_name("task_index")
288        .and_then(|c| c.downcast_array_ref::<Int64Array>())
289        .with_context(|| "Failed to get task_index field from dataset!")?;
290
291    let mut chunk = Chunk::builder("task".into());
292    let mut row_id = RowId::new();
293    let mut time_int = TimeInt::ZERO;
294
295    for task_index in task_indices {
296        let Some(task) = task_index
297            .and_then(|i| usize::try_from(i).ok())
298            .and_then(|i| dataset.task_by_index(TaskIndex(i)))
299        else {
300            // if there is no valid task for the current frame index, we skip it.
301            time_int = time_int.inc();
302            continue;
303        };
304
305        let timepoint = TimePoint::default().with(*timeline, time_int);
306        let text = TextDocument::new(task.task.clone());
307        chunk = chunk.with_archetype(row_id, timepoint, &text);
308
309        row_id = row_id.next();
310        time_int = time_int.inc();
311    }
312
313    Ok(std::iter::once(chunk.build()?))
314}
315
316fn load_episode_images(
317    observation: &str,
318    timeline: &Timeline,
319    data: &RecordBatch,
320) -> Result<impl ExactSizeIterator<Item = Chunk>, DataLoaderError> {
321    let image_bytes = data
322        .column_by_name(observation)
323        .and_then(|c| c.downcast_array_ref::<StructArray>())
324        .and_then(|a| a.column_by_name("bytes"))
325        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
326        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;
327
328    let mut chunk = Chunk::builder(observation.into());
329    let mut row_id = RowId::new();
330
331    for frame_idx in 0..image_bytes.len() {
332        let img_buffer = image_bytes.value(frame_idx);
333        let encoded_image = EncodedImage::from_file_contents(img_buffer.to_owned());
334        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
335        chunk = chunk.with_archetype(row_id, timepoint, &encoded_image);
336
337        row_id = row_id.next();
338    }
339
340    Ok(std::iter::once(chunk.build().with_context(|| {
341        format!("Failed to build image chunk for image: {observation}")
342    })?))
343}
344
345fn load_episode_depth_images(
346    observation: &str,
347    timeline: &Timeline,
348    data: &RecordBatch,
349) -> Result<impl ExactSizeIterator<Item = Chunk>, DataLoaderError> {
350    let image_bytes = data
351        .column_by_name(observation)
352        .and_then(|c| c.downcast_array_ref::<StructArray>())
353        .and_then(|a| a.column_by_name("bytes"))
354        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
355        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;
356
357    let mut chunk = Chunk::builder(observation.into());
358    let mut row_id = RowId::new();
359
360    for frame_idx in 0..image_bytes.len() {
361        let img_buffer = image_bytes.value(frame_idx);
362        let depth_image = DepthImage::from_file_contents(img_buffer.to_owned())
363            .map_err(|err| anyhow!("Failed to decode image: {err}"))?;
364
365        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
366        chunk = chunk.with_archetype(row_id, timepoint, &depth_image);
367
368        row_id = row_id.next();
369    }
370
371    Ok(std::iter::once(chunk.build().with_context(|| {
372        format!("Failed to build image chunk for image: {observation}")
373    })?))
374}
375
/// Loads the video file backing `observation` for `episode` and converts it into chunks.
///
/// Always produces a static chunk containing the video asset itself. When frame
/// timestamps can be read from the video, a second chunk of
/// [`VideoFrameReference`]s aligned with the provided `time_column` is produced
/// as well; otherwise only the asset is logged and a warning is emitted.
fn load_episode_video(
    dataset: &LeRobotDataset,
    observation: &str,
    episode: EpisodeIndex,
    timeline: &Timeline,
    time_column: TimeColumn,
) -> Result<impl ExactSizeIterator<Item = Chunk>, DataLoaderError> {
    let contents = dataset
        .read_episode_video_contents(observation, episode)
        .with_context(|| format!("Reading video contents for episode {episode:?} failed!"))?;

    let video_asset = AssetVideo::new(contents.into_owned());
    let entity_path = observation;

    let video_frame_reference_chunk = match video_asset.read_frame_timestamps_nanos() {
        Ok(frame_timestamps_nanos) => {
            let frame_timestamps_nanos: arrow::buffer::ScalarBuffer<i64> =
                frame_timestamps_nanos.into();

            // One timestamp per time-column row; any extra video frames beyond the
            // episode's row count are dropped.
            let video_timestamps = frame_timestamps_nanos
                .iter()
                .take(time_column.num_rows())
                .copied()
                .map(VideoTimestamp::from_nanos)
                .collect::<Vec<_>>();

            let video_timestamp_batch = &video_timestamps as &dyn ComponentBatch;
            let video_timestamp_list_array = video_timestamp_batch
                .to_arrow_list_array()
                .map_err(re_chunk::ChunkError::from)?;

            // Indicator column.
            let video_frame_reference_indicators =
                <VideoFrameReference as Archetype>::Indicator::new_array(video_timestamps.len());
            let video_frame_reference_indicators_list_array = video_frame_reference_indicators
                .to_arrow_list_array()
                .map_err(re_chunk::ChunkError::from)?;

            Some(Chunk::from_auto_row_ids(
                re_chunk::ChunkId::new(),
                entity_path.into(),
                std::iter::once((*timeline.name(), time_column)).collect(),
                [
                    (
                        VideoFrameReference::indicator().descriptor.clone(),
                        video_frame_reference_indicators_list_array,
                    ),
                    (
                        video_timestamp_batch.descriptor().into_owned(),
                        video_timestamp_list_array,
                    ),
                ]
                .into_iter()
                .collect(),
            )?)
        }
        Err(err) => {
            re_log::warn_once!(
                "Failed to read frame timestamps from episode {episode:?} video: {err}"
            );
            None
        }
    };

    // Put video asset into its own (static) chunk since it can be fairly large.
    let video_asset_chunk = Chunk::builder(entity_path.into())
        .with_archetype(RowId::new(), TimePoint::default(), &video_asset)
        .build()?;

    if let Some(video_frame_reference_chunk) = video_frame_reference_chunk {
        Ok(Either::Left(
            [video_asset_chunk, video_frame_reference_chunk].into_iter(),
        ))
    } else {
        // Still log the video asset, but don't include video frames.
        Ok(Either::Right(std::iter::once(video_asset_chunk)))
    }
}
454
/// Helper type similar to [`Either`], but with 3 variants.
enum ScalarChunkIterator {
    /// No chunks at all (e.g. an unsupported scalar dtype).
    Empty(std::iter::Empty<Chunk>),
    /// Multiple chunks, e.g. scalar data plus a static chunk of element names.
    Batch(Box<dyn ExactSizeIterator<Item = Chunk>>),
    /// Exactly one chunk.
    Single(std::iter::Once<Chunk>),
}
461
462impl Iterator for ScalarChunkIterator {
463    type Item = Chunk;
464
465    fn next(&mut self) -> Option<Self::Item> {
466        match self {
467            Self::Empty(iter) => iter.next(),
468            Self::Batch(iter) => iter.next(),
469            Self::Single(iter) => iter.next(),
470        }
471    }
472}
473
474impl ExactSizeIterator for ScalarChunkIterator {}
475
/// Converts a float-typed feature column into scalar chunks.
///
/// Dispatches on the column's Arrow type:
/// * `FixedSizeList` → a batch of chunks (scalar data plus an optional static
///   chunk of per-element names, see [`make_scalar_batch_entity_chunks`]);
/// * `List` → a single chunk of per-row scalar lists;
/// * plain `Float32`/`Float64` → a single chunk of one scalar per row;
/// * anything else → no chunks, with a one-time warning.
fn load_scalar(
    feature_key: &str,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &RecordBatch,
) -> Result<ScalarChunkIterator, DataLoaderError> {
    let field = data
        .schema_ref()
        .field_with_name(feature_key)
        .with_context(|| {
            format!("Failed to get field for feature {feature_key} from parquet file")
        })?;

    let entity_path = EntityPath::parse_forgiving(field.name());

    match field.data_type() {
        DataType::FixedSizeList(_, _) => {
            let fixed_size_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<FixedSizeListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!(
                        "Failed to downcast feature to FixedSizeListArray"
                    ))
                })?;

            let batch_chunks =
                make_scalar_batch_entity_chunks(entity_path, feature, timelines, fixed_size_array)?;
            Ok(ScalarChunkIterator::Batch(Box::new(batch_chunks)))
        }
        DataType::List(_field) => {
            let list_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<arrow::array::ListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!("Failed to downcast feature to ListArray"))
                })?;

            // Each list element is cast to f64, one slice per row.
            let sliced = extract_list_array_elements_as_f64(list_array).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            )))
        }
        DataType::Float32 | DataType::Float64 => {
            let feature_data = data.column_by_name(feature_key).ok_or_else(|| {
                DataLoaderError::Other(anyhow!(
                    "Failed to get LeRobot dataset column data for: {:?}",
                    field.name()
                ))
            })?;

            // Cast the whole column to f64, then slice into single-element rows.
            let sliced = extract_scalar_slices_as_f64(feature_data).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            )))
        }
        _ => {
            re_log::warn_once!(
                "Tried logging scalar {} with unsupported dtype: {}",
                field.name(),
                field.data_type()
            );
            Ok(ScalarChunkIterator::Empty(std::iter::empty()))
        }
    }
}
548
549fn make_scalar_batch_entity_chunks(
550    entity_path: EntityPath,
551    feature: &Feature,
552    timelines: &IntMap<TimelineName, TimeColumn>,
553    data: &FixedSizeListArray,
554) -> Result<impl ExactSizeIterator<Item = Chunk>, DataLoaderError> {
555    let num_elements = data.value_length() as usize;
556
557    let mut chunks = Vec::with_capacity(num_elements);
558
559    let sliced = extract_fixed_size_list_array_elements_as_f64(data)
560        .with_context(|| format!("Failed to cast scalar feature {entity_path} to Float64"))?;
561
562    chunks.push(make_scalar_entity_chunk(
563        entity_path.clone(),
564        timelines,
565        &sliced,
566    )?);
567
568    // If we have names for this feature, we insert a single static chunk containing the names.
569    if let Some(names) = feature.names.clone() {
570        let names: Vec<_> = (0..data.value_length() as usize)
571            .map(|idx| names.name_for_index(idx))
572            .collect();
573
574        chunks.push(
575            Chunk::builder(entity_path)
576                .with_row(
577                    RowId::new(),
578                    TimePoint::default(),
579                    std::iter::once((
580                        <Name as Component>::descriptor().clone(),
581                        Arc::new(StringArray::from_iter(names)) as Arc<dyn ArrowArray>,
582                    )),
583                )
584                .build()?,
585        );
586    }
587
588    Ok(chunks.into_iter())
589}
590
591fn make_scalar_entity_chunk(
592    entity_path: EntityPath,
593    timelines: &IntMap<TimelineName, TimeColumn>,
594    sliced_data: &[ArrayRef],
595) -> Result<Chunk, DataLoaderError> {
596    let data_arrays = sliced_data
597        .iter()
598        .map(|e| Some(e.as_ref()))
599        .collect::<Vec<_>>();
600
601    let data_field_inner = Field::new("item", DataType::Float64, true /* nullable */);
602    #[allow(clippy::unwrap_used)] // we know we've given the right field type
603    let data_field_array: arrow::array::ListArray =
604        re_arrow_util::arrays_to_list_array(data_field_inner.data_type().clone(), &data_arrays)
605            .unwrap();
606
607    Ok(Chunk::from_auto_row_ids(
608        ChunkId::new(),
609        entity_path,
610        timelines.clone(),
611        std::iter::once((
612            <Scalar as Component>::descriptor().clone(),
613            data_field_array,
614        ))
615        .collect(),
616    )?)
617}
618
619fn extract_scalar_slices_as_f64(data: &ArrayRef) -> anyhow::Result<Vec<ArrayRef>> {
620    // cast the slice to f64 first, as scalars need an f64
621    let scalar_values = cast(&data, &DataType::Float64)
622        .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))?;
623
624    Ok((0..data.len())
625        .map(|idx| scalar_values.slice(idx, 1))
626        .collect::<Vec<_>>())
627}
628
629fn extract_fixed_size_list_array_elements_as_f64(
630    data: &FixedSizeListArray,
631) -> anyhow::Result<Vec<ArrayRef>> {
632    (0..data.len())
633        .map(|idx| {
634            cast(&data.value(idx), &DataType::Float64)
635                .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))
636        })
637        .collect::<Result<Vec<_>, _>>()
638}
639
640fn extract_list_array_elements_as_f64(
641    data: &arrow::array::ListArray,
642) -> anyhow::Result<Vec<ArrayRef>> {
643    (0..data.len())
644        .map(|idx| {
645            cast(&data.value(idx), &DataType::Float64)
646                .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))
647        })
648        .collect::<Result<Vec<_>, _>>()
649}