re_data_loader/
loader_lerobot.rs

1#![expect(clippy::cast_possible_wrap)] // usize -> i64
2
3use std::{
4    sync::{Arc, mpsc::Sender},
5    thread,
6};
7
8use anyhow::{Context as _, anyhow};
9use arrow::{
10    array::{
11        ArrayRef, BinaryArray, FixedSizeListArray, Int64Array, RecordBatch, StringArray,
12        StructArray,
13    },
14    compute::cast,
15    datatypes::{DataType, Field},
16};
17use itertools::Either;
18
19use re_arrow_util::ArrowArrayDowncastRef as _;
20use re_chunk::{
21    ArrowArray, Chunk, ChunkId, EntityPath, RowId, TimeColumn, TimeInt, TimePoint, Timeline,
22    TimelineName, external::nohash_hasher::IntMap,
23};
24use re_log_types::{ApplicationId, StoreId};
25use re_types::{
26    archetypes::{
27        self, AssetVideo, DepthImage, EncodedImage, Scalars, TextDocument, VideoFrameReference,
28    },
29    components::VideoTimestamp,
30};
31
32use crate::lerobot::{
33    DType, EpisodeIndex, Feature, LeRobotDataset, TaskIndex, is_lerobot_dataset,
34    is_v1_lerobot_dataset,
35};
36use crate::load_file::prepare_store_info;
37use crate::{DataLoader, DataLoaderError, LoadedData};
38
/// Columns in the `LeRobot` dataset schema that we do not visualize in the viewer, and thus ignore.
///
/// Note that `frame_index` is not entirely ignored: it is used as the timeline for all
/// per-frame data (see `load_episode`), it just isn't logged as a data column itself.
const LEROBOT_DATASET_IGNORED_COLUMNS: &[&str] =
    &["episode_index", "index", "frame_index", "timestamp"];

/// Only supports `LeRobot` datasets that are in a supported version format.
/// Datasets from unsupported versions won't load.
const LEROBOT_DATASET_SUPPORTED_VERSIONS: &[&str] = &["v2.0", "v2.1"];
46
/// A [`DataLoader`] for `LeRobot` datasets.
///
/// Each episode of the dataset is streamed out as its own recording (see `load_from_path`,
/// which spawns a dedicated IO thread for the actual loading).
///
/// An example dataset which can be loaded can be found on Hugging Face: [lerobot/pusht_image](https://huggingface.co/datasets/lerobot/pusht_image)
pub struct LeRobotDatasetLoader;
51
52impl DataLoader for LeRobotDatasetLoader {
53    fn name(&self) -> String {
54        "LeRobotDatasetLoader".into()
55    }
56
57    fn load_from_path(
58        &self,
59        settings: &crate::DataLoaderSettings,
60        filepath: std::path::PathBuf,
61        tx: Sender<LoadedData>,
62    ) -> Result<(), DataLoaderError> {
63        if !is_lerobot_dataset(&filepath) {
64            return Err(DataLoaderError::Incompatible(filepath));
65        }
66
67        if is_v1_lerobot_dataset(&filepath) {
68            re_log::error!("LeRobot 'v1.x' dataset format is unsupported.");
69            return Ok(());
70        }
71
72        let dataset = LeRobotDataset::load_from_directory(&filepath)
73            .map_err(|err| anyhow!("Loading LeRobot dataset failed: {err}"))?;
74
75        if !LEROBOT_DATASET_SUPPORTED_VERSIONS
76            .contains(&dataset.metadata.info.codebase_version.as_str())
77        {
78            re_log::error!(
79                "LeRobot '{}' dataset format is unsupported.",
80                dataset.metadata.info.codebase_version
81            );
82            return Ok(());
83        }
84
85        let application_id = settings
86            .application_id
87            .clone()
88            .unwrap_or(filepath.display().to_string().into());
89
90        // NOTE(1): `spawn` is fine, this whole function is native-only.
91        // NOTE(2): this must spawned on a dedicated thread to avoid a deadlock!
92        // `load` will spawn a bunch of loaders on the common rayon thread pool and wait for
93        // their response via channels: we cannot be waiting for these responses on the
94        // common rayon thread pool.
95        thread::Builder::new()
96            .name(format!("load_and_stream({filepath:?}"))
97            .spawn({
98                move || {
99                    re_log::info!(
100                        "Loading LeRobot dataset from {:?}, with {} episode(s)",
101                        dataset.path,
102                        dataset.metadata.episode_count(),
103                    );
104                    load_and_stream(&dataset, &application_id, &tx);
105                }
106            })
107            .with_context(|| {
108                format!("Failed to spawn IO thread to load LeRobot dataset {filepath:?} ")
109            })?;
110
111        Ok(())
112    }
113
114    fn load_from_file_contents(
115        &self,
116        _settings: &crate::DataLoaderSettings,
117        filepath: std::path::PathBuf,
118        _contents: std::borrow::Cow<'_, [u8]>,
119        _tx: Sender<LoadedData>,
120    ) -> Result<(), DataLoaderError> {
121        Err(DataLoaderError::Incompatible(filepath))
122    }
123}
124
125fn load_and_stream(
126    dataset: &LeRobotDataset,
127    application_id: &ApplicationId,
128    tx: &Sender<crate::LoadedData>,
129) {
130    // set up all recordings
131    let episodes = prepare_episode_chunks(dataset, application_id, tx);
132
133    for (episode, store_id) in &episodes {
134        // log episode data to its respective recording
135        match load_episode(dataset, *episode) {
136            Ok(chunks) => {
137                let recording_info = re_types::archetypes::RecordingInfo::new()
138                    .with_name(format!("Episode {}", episode.0));
139
140                let Ok(initial) = Chunk::builder(EntityPath::properties())
141                    .with_archetype(RowId::new(), TimePoint::STATIC, &recording_info)
142                    .build()
143                else {
144                    re_log::error!(
145                        "Failed to build recording properties chunk for episode {}",
146                        episode.0
147                    );
148                    return;
149                };
150
151                for chunk in std::iter::once(initial).chain(chunks.into_iter()) {
152                    let data = LoadedData::Chunk(
153                        LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
154                        store_id.clone(),
155                        chunk,
156                    );
157
158                    if tx.send(data).is_err() {
159                        break; // The other end has decided to hang up, not our problem.
160                    }
161                }
162            }
163            Err(err) => {
164                re_log::warn!(
165                    "Failed to load episode {} from LeRobot dataset: {err}",
166                    episode.0
167                );
168            }
169        }
170    }
171}
172
173/// Prepare the viewer for all episodes, by sending out a [`SetStoreInfo`](`re_log_types::SetStoreInfo`)
174/// [`LogMsg`](`re_log_types::LogMsg`) for each episode.
175fn prepare_episode_chunks(
176    dataset: &LeRobotDataset,
177    application_id: &ApplicationId,
178    tx: &Sender<crate::LoadedData>,
179) -> Vec<(EpisodeIndex, StoreId)> {
180    let mut store_ids = vec![];
181
182    for episode_index in dataset.metadata.episodes.keys() {
183        let episode = *episode_index;
184
185        let store_id = StoreId::recording(application_id.clone(), format!("episode_{}", episode.0));
186        let set_store_info = LoadedData::LogMsg(
187            LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
188            prepare_store_info(&store_id, re_log_types::FileSource::Sdk),
189        );
190
191        if tx.send(set_store_info).is_err() {
192            break;
193        }
194
195        store_ids.push((episode, store_id));
196    }
197
198    store_ids
199}
200
/// Loads a single episode from a `LeRobot` dataset and converts it into a collection of Rerun chunks.
///
/// This function processes an episode from the dataset by extracting the relevant data columns and
/// converting them into appropriate Rerun data structures. It handles different types of data
/// (videos, images, scalar values, etc.) based on their data type specifications in the dataset metadata.
pub fn load_episode(
    dataset: &LeRobotDataset,
    episode: EpisodeIndex,
) -> Result<Vec<Chunk>, DataLoaderError> {
    let data = dataset
        .read_episode_data(episode)
        .map_err(|err| anyhow!("Reading data for episode {} failed: {err}", episode.0))?;

    // Every row of the episode is indexed by its `frame_index` column, which becomes the timeline.
    let frame_indices = data
        .column_by_name("frame_index")
        .ok_or_else(|| anyhow!("Failed to get frame index column in LeRobot dataset"))?
        .clone();

    let timeline = re_log_types::Timeline::new_sequence("frame_index");
    let times: &arrow::buffer::ScalarBuffer<i64> = frame_indices
        .downcast_array_ref::<Int64Array>()
        .ok_or_else(|| anyhow!("LeRobot dataset frame indices are of an unexpected type"))?
        .values();

    // Shared time column: all per-frame chunks built below use this same timeline.
    let time_column = re_chunk::TimeColumn::new(None, timeline, times.clone());
    let timelines = std::iter::once((*timeline.name(), time_column.clone())).collect();

    let mut chunks = Vec::new();

    // Dispatch each (non-ignored) feature to the loader matching its declared dtype.
    for (feature_key, feature) in dataset
        .metadata
        .info
        .features
        .iter()
        .filter(|(key, _)| !LEROBOT_DATASET_IGNORED_COLUMNS.contains(&key.as_str()))
    {
        match feature.dtype {
            DType::Video => {
                chunks.extend(load_episode_video(
                    dataset,
                    feature_key,
                    episode,
                    &timeline,
                    time_column.clone(),
                )?);
            }

            DType::Image => {
                // 1-channel images are treated as depth images, 3-channel as color images.
                let num_channels = feature.channel_dim();

                match num_channels {
                    1 => chunks.extend(load_episode_depth_images(feature_key, &timeline, &data)?),
                    3 => chunks.extend(load_episode_images(feature_key, &timeline, &data)?),
                    _ => re_log::warn_once!(
                        "Unsupported channel count {num_channels} (shape: {:?}) for LeRobot dataset; Only 1- and 3-channel images are supported",
                        feature.shape
                    ),
                }
            }
            DType::Int64 if feature_key == "task_index" => {
                // special case int64 task_index columns
                // this always refers to the task description in the dataset metadata.
                chunks.extend(log_episode_task(dataset, &timeline, &data)?);
            }
            DType::Int16 | DType::Int64 | DType::Bool | DType::String => {
                re_log::warn_once!(
                    "Loading LeRobot feature ({feature_key}) of dtype `{:?}` into Rerun is not yet implemented",
                    feature.dtype
                );
            }
            DType::Float32 | DType::Float64 => {
                chunks.extend(load_scalar(feature_key, feature, &timelines, &data)?);
            }
        }
    }

    Ok(chunks)
}
279
/// Logs the per-frame task description (looked up via the `task_index` column) as a
/// [`TextDocument`] chunk on `timeline`.
///
/// Rows whose `task_index` does not resolve to a task in the dataset metadata are skipped,
/// but still advance the time cursor so later rows stay aligned with their frame index.
fn log_episode_task(
    dataset: &LeRobotDataset,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let task_indices = data
        .column_by_name("task_index")
        .and_then(|c| c.downcast_array_ref::<Int64Array>())
        .with_context(|| "Failed to get task_index field from dataset!")?;

    let mut chunk = Chunk::builder("task");
    let mut row_id = RowId::new();
    // Manually tracked time cursor: incremented once per input row, whether or not we log it.
    let mut time_int = TimeInt::ZERO;

    for task_index in task_indices {
        // `task_index` is nullable and may be negative; both cases resolve to "no task".
        let Some(task) = task_index
            .and_then(|i| usize::try_from(i).ok())
            .and_then(|i| dataset.task_by_index(TaskIndex(i)))
        else {
            // if there is no valid task for the current frame index, we skip it.
            time_int = time_int.inc();
            continue;
        };

        let timepoint = TimePoint::default().with(*timeline, time_int);
        let text = TextDocument::new(task.task.clone());
        chunk = chunk.with_archetype(row_id, timepoint, &text);

        row_id = row_id.next();
        time_int = time_int.inc();
    }

    Ok(std::iter::once(chunk.build()?))
}
314
315fn load_episode_images(
316    observation: &str,
317    timeline: &Timeline,
318    data: &RecordBatch,
319) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
320    let image_bytes = data
321        .column_by_name(observation)
322        .and_then(|c| c.downcast_array_ref::<StructArray>())
323        .and_then(|a| a.column_by_name("bytes"))
324        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
325        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;
326
327    let mut chunk = Chunk::builder(observation);
328    let mut row_id = RowId::new();
329
330    for frame_idx in 0..image_bytes.len() {
331        let img_buffer = image_bytes.value(frame_idx);
332        let encoded_image = EncodedImage::from_file_contents(img_buffer.to_owned());
333        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
334        chunk = chunk.with_archetype(row_id, timepoint, &encoded_image);
335
336        row_id = row_id.next();
337    }
338
339    Ok(std::iter::once(chunk.build().with_context(|| {
340        format!("Failed to build image chunk for image: {observation}")
341    })?))
342}
343
344fn load_episode_depth_images(
345    observation: &str,
346    timeline: &Timeline,
347    data: &RecordBatch,
348) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
349    let image_bytes = data
350        .column_by_name(observation)
351        .and_then(|c| c.downcast_array_ref::<StructArray>())
352        .and_then(|a| a.column_by_name("bytes"))
353        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
354        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;
355
356    let mut chunk = Chunk::builder(observation);
357    let mut row_id = RowId::new();
358
359    for frame_idx in 0..image_bytes.len() {
360        let img_buffer = image_bytes.value(frame_idx);
361        let depth_image = DepthImage::from_file_contents(img_buffer.to_owned())
362            .map_err(|err| anyhow!("Failed to decode image: {err}"))?;
363
364        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
365        chunk = chunk.with_archetype(row_id, timepoint, &depth_image);
366
367        row_id = row_id.next();
368    }
369
370    Ok(std::iter::once(chunk.build().with_context(|| {
371        format!("Failed to build image chunk for image: {observation}")
372    })?))
373}
374
/// Loads the video file backing `observation` for `episode` and logs it as a static
/// [`AssetVideo`], plus — when the frame timestamps can be read from the container —
/// one [`VideoFrameReference`] per row of `time_column`.
///
/// If timestamps cannot be read, the asset is still logged (without frame references)
/// and a warning is emitted.
fn load_episode_video(
    dataset: &LeRobotDataset,
    observation: &str,
    episode: EpisodeIndex,
    timeline: &Timeline,
    time_column: TimeColumn,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let contents = dataset
        .read_episode_video_contents(observation, episode)
        .with_context(|| format!("Reading video contents for episode {episode:?} failed!"))?;

    let video_asset = AssetVideo::new(contents.into_owned());
    let entity_path = observation;

    let video_frame_reference_chunk = match video_asset.read_frame_timestamps_nanos() {
        Ok(frame_timestamps_nanos) => {
            let frame_timestamps_nanos: arrow::buffer::ScalarBuffer<i64> =
                frame_timestamps_nanos.into();

            // One timestamp per row of the time column; any extra video frames are ignored.
            let video_timestamps = frame_timestamps_nanos
                .iter()
                .take(time_column.num_rows())
                .copied()
                .map(VideoTimestamp::from_nanos)
                .collect::<Vec<_>>();

            let video_frame_reference_column = VideoFrameReference::update_fields()
                .with_many_timestamp(video_timestamps)
                .columns_of_unit_batches()
                .with_context(|| {
                    format!(
                        "Failed to create `VideoFrameReference` column for episode {episode:?}."
                    )
                })?;

            Some(Chunk::from_auto_row_ids(
                re_chunk::ChunkId::new(),
                entity_path.into(),
                std::iter::once((*timeline.name(), time_column)).collect(),
                video_frame_reference_column.collect(),
            )?)
        }
        Err(err) => {
            re_log::warn_once!(
                "Failed to read frame timestamps from episode {episode:?} video: {err}"
            );
            None
        }
    };

    // Put video asset into its own (static) chunk since it can be fairly large.
    let video_asset_chunk = Chunk::builder(entity_path)
        .with_archetype(RowId::new(), TimePoint::default(), &video_asset)
        .build()?;

    if let Some(video_frame_reference_chunk) = video_frame_reference_chunk {
        Ok(Either::Left(
            [video_asset_chunk, video_frame_reference_chunk].into_iter(),
        ))
    } else {
        // Still log the video asset, but don't include video frames.
        Ok(Either::Right(std::iter::once(video_asset_chunk)))
    }
}
439
/// Helper type similar to [`Either`], but with 3 variants.
///
/// Used by [`load_scalar`] to return a uniform iterator type regardless of whether a
/// feature produced zero, one, or several chunks.
enum ScalarChunkIterator {
    /// No chunks (unsupported dtype).
    Empty(std::iter::Empty<Chunk>),
    /// Several chunks (e.g. a fixed-size-list feature with an extra names chunk).
    Batch(Box<dyn ExactSizeIterator<Item = Chunk>>),

    // Boxed, because `Chunk` is huge, and by extension so is `std::iter::Once<Chunk>`.
    Single(Box<std::iter::Once<Chunk>>),
}

impl Iterator for ScalarChunkIterator {
    type Item = Chunk;

    fn next(&mut self) -> Option<Self::Item> {
        // Pure delegation to the wrapped iterator.
        match self {
            Self::Empty(iter) => iter.next(),
            Self::Batch(iter) => iter.next(),
            Self::Single(iter) => iter.next(),
        }
    }
}

// Sound because every wrapped variant is itself an `ExactSizeIterator`.
impl ExactSizeIterator for ScalarChunkIterator {}
462
/// Converts a float feature column into [`Scalars`] chunks on the episode's timelines.
///
/// Supports `FixedSizeList` (multi-element scalars, optionally with series names),
/// `List`, and plain `Float32`/`Float64` columns; any other Arrow dtype is warned
/// about once and yields an empty iterator.
fn load_scalar(
    feature_key: &str,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &RecordBatch,
) -> Result<ScalarChunkIterator, DataLoaderError> {
    let field = data
        .schema_ref()
        .field_with_name(feature_key)
        .with_context(|| {
            format!("Failed to get field for feature {feature_key} from parquet file")
        })?;

    let entity_path = EntityPath::parse_forgiving(field.name());

    match field.data_type() {
        DataType::FixedSizeList(_, _) => {
            let fixed_size_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<FixedSizeListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!(
                        "Failed to downcast feature to FixedSizeListArray"
                    ))
                })?;

            // May yield more than one chunk (scalar data + optional static names chunk).
            let batch_chunks =
                make_scalar_batch_entity_chunks(entity_path, feature, timelines, fixed_size_array)?;
            Ok(ScalarChunkIterator::Batch(Box::new(batch_chunks)))
        }
        DataType::List(_field) => {
            let list_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<arrow::array::ListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!("Failed to downcast feature to ListArray"))
                })?;

            // Cast each row's list to f64 before chunking.
            let sliced = extract_list_array_elements_as_f64(list_array).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(Box::new(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            ))))
        }
        DataType::Float32 | DataType::Float64 => {
            let feature_data = data.column_by_name(feature_key).ok_or_else(|| {
                DataLoaderError::Other(anyhow!(
                    "Failed to get LeRobot dataset column data for: {:?}",
                    field.name()
                ))
            })?;

            // One single-element slice per row, upcast to f64.
            let sliced = extract_scalar_slices_as_f64(feature_data).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(Box::new(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            ))))
        }
        _ => {
            re_log::warn_once!(
                "Tried logging scalar {} with unsupported dtype: {}",
                field.name(),
                field.data_type()
            );
            Ok(ScalarChunkIterator::Empty(std::iter::empty()))
        }
    }
}
535
536fn make_scalar_batch_entity_chunks(
537    entity_path: EntityPath,
538    feature: &Feature,
539    timelines: &IntMap<TimelineName, TimeColumn>,
540    data: &FixedSizeListArray,
541) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
542    let num_elements = data.value_length() as usize;
543
544    let mut chunks = Vec::with_capacity(num_elements);
545
546    let sliced = extract_fixed_size_list_array_elements_as_f64(data)
547        .with_context(|| format!("Failed to cast scalar feature {entity_path} to Float64"))?;
548
549    chunks.push(make_scalar_entity_chunk(
550        entity_path.clone(),
551        timelines,
552        &sliced,
553    )?);
554
555    // If we have names for this feature, we insert a single static chunk containing the names.
556    if let Some(names) = feature.names.clone() {
557        let names: Vec<_> = (0..data.value_length() as usize)
558            .map(|idx| names.name_for_index(idx))
559            .collect();
560
561        chunks.push(
562            Chunk::builder(entity_path)
563                .with_row(
564                    RowId::new(),
565                    TimePoint::default(),
566                    std::iter::once((
567                        archetypes::SeriesLines::descriptor_names(),
568                        Arc::new(StringArray::from_iter(names)) as Arc<dyn ArrowArray>,
569                    )),
570                )
571                .build()?,
572        );
573    }
574
575    Ok(chunks.into_iter())
576}
577
578fn make_scalar_entity_chunk(
579    entity_path: EntityPath,
580    timelines: &IntMap<TimelineName, TimeColumn>,
581    sliced_data: &[ArrayRef],
582) -> Result<Chunk, DataLoaderError> {
583    let data_arrays = sliced_data
584        .iter()
585        .map(|e| Some(e.as_ref()))
586        .collect::<Vec<_>>();
587
588    let data_field_inner = Field::new("item", DataType::Float64, true /* nullable */);
589    #[expect(clippy::unwrap_used)] // we know we've given the right field type
590    let data_field_array: arrow::array::ListArray =
591        re_arrow_util::arrays_to_list_array(data_field_inner.data_type().clone(), &data_arrays)
592            .unwrap();
593
594    Ok(Chunk::from_auto_row_ids(
595        ChunkId::new(),
596        entity_path,
597        timelines.clone(),
598        std::iter::once((Scalars::descriptor_scalars().clone(), data_field_array)).collect(),
599    )?)
600}
601
602fn extract_scalar_slices_as_f64(data: &ArrayRef) -> anyhow::Result<Vec<ArrayRef>> {
603    // cast the slice to f64 first, as scalars need an f64
604    let scalar_values = cast(&data, &DataType::Float64)
605        .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))?;
606
607    Ok((0..data.len())
608        .map(|idx| scalar_values.slice(idx, 1))
609        .collect::<Vec<_>>())
610}
611
612fn extract_fixed_size_list_array_elements_as_f64(
613    data: &FixedSizeListArray,
614) -> anyhow::Result<Vec<ArrayRef>> {
615    (0..data.len())
616        .map(|idx| {
617            cast(&data.value(idx), &DataType::Float64)
618                .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))
619        })
620        .collect::<Result<Vec<_>, _>>()
621}
622
623fn extract_list_array_elements_as_f64(
624    data: &arrow::array::ListArray,
625) -> anyhow::Result<Vec<ArrayRef>> {
626    (0..data.len())
627        .map(|idx| {
628            cast(&data.value(idx), &DataType::Float64)
629                .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))
630        })
631        .collect::<Result<Vec<_>, _>>()
632}