re_data_loader/loader_lerobot.rs

use std::{
    sync::{Arc, mpsc::Sender},
    thread,
};

use anyhow::{Context as _, anyhow};
use arrow::{
    array::{
        ArrayRef, BinaryArray, FixedSizeListArray, Int64Array, RecordBatch, StringArray,
        StructArray,
    },
    compute::cast,
    datatypes::{DataType, Field},
};
use itertools::Either;

use re_arrow_util::ArrowArrayDowncastRef as _;
use re_chunk::{
    ArrowArray, Chunk, ChunkId, EntityPath, RowId, TimeColumn, TimeInt, TimePoint, Timeline,
    TimelineName, external::nohash_hasher::IntMap,
};
use re_log_types::{ApplicationId, StoreId};
use re_types::{
    archetypes::{
        self, AssetVideo, DepthImage, EncodedImage, Scalars, TextDocument, VideoFrameReference,
    },
    components::VideoTimestamp,
};

use crate::lerobot::{
    DType, EpisodeIndex, Feature, LeRobotDataset, TaskIndex, is_lerobot_dataset,
    is_v1_lerobot_dataset,
};
use crate::load_file::prepare_store_info;
use crate::{DataLoader, DataLoaderError, LoadedData};

/// Columns in the `LeRobot` dataset schema that we do not visualize in the viewer, and thus ignore.
const LEROBOT_DATASET_IGNORED_COLUMNS: &[&str] =
    &["episode_index", "index", "frame_index", "timestamp"];

/// The `LeRobot` dataset format versions this loader supports.
/// Datasets in any other version are reported as unsupported and skipped.
const LEROBOT_DATASET_SUPPORTED_VERSIONS: &[&str] = &["v2.0", "v2.1"];

/// A [`DataLoader`] for `LeRobot` datasets.
///
/// An example dataset that can be loaded with this loader can be found on Hugging Face: [lerobot/pusht_image](https://huggingface.co/datasets/lerobot/pusht_image)
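///
/// # Example
///
/// A minimal sketch of driving this loader directly (the `settings` value and the
/// dataset path are hypothetical; the loader streams its results on the channel from
/// a background thread):
///
/// ```ignore
/// use std::sync::mpsc::channel;
///
/// let (tx, rx) = channel();
/// LeRobotDatasetLoader.load_from_path(&settings, "datasets/pusht".into(), tx)?;
/// for loaded in rx {
///     // Forward the `LoadedData` (store infos and chunks) to the viewer.
/// }
/// ```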
pub struct LeRobotDatasetLoader;

impl DataLoader for LeRobotDatasetLoader {
    fn name(&self) -> String {
        "LeRobotDatasetLoader".into()
    }

    fn load_from_path(
        &self,
        settings: &crate::DataLoaderSettings,
        filepath: std::path::PathBuf,
        tx: Sender<LoadedData>,
    ) -> Result<(), DataLoaderError> {
        if !is_lerobot_dataset(&filepath) {
            return Err(DataLoaderError::Incompatible(filepath));
        }

        if is_v1_lerobot_dataset(&filepath) {
            re_log::error!("LeRobot 'v1.x' dataset format is unsupported.");
            return Ok(());
        }

        let dataset = LeRobotDataset::load_from_directory(&filepath)
            .map_err(|err| anyhow!("Loading LeRobot dataset failed: {err}"))?;

        if !LEROBOT_DATASET_SUPPORTED_VERSIONS
            .contains(&dataset.metadata.info.codebase_version.as_str())
        {
            re_log::error!(
                "LeRobot '{}' dataset format is unsupported.",
                dataset.metadata.info.codebase_version
            );
            return Ok(());
        }

        let application_id = settings
            .application_id
            .clone()
            .unwrap_or_else(|| ApplicationId(filepath.display().to_string()));

        // NOTE(1): `spawn` is fine, this whole function is native-only.
        // NOTE(2): this must be spawned on a dedicated thread to avoid a deadlock!
        // `load` will spawn a bunch of loaders on the common rayon thread pool and wait for
        // their response via channels: we cannot be waiting for these responses on the
        // common rayon thread pool.
        thread::Builder::new()
            .name(format!("load_and_stream({filepath:?})"))
            .spawn({
                move || {
                    re_log::info!(
                        "Loading LeRobot dataset from {:?}, with {} episode(s)",
                        dataset.path,
                        dataset.metadata.episodes.len(),
                    );
                    load_and_stream(&dataset, &application_id, &tx);
                }
            })
            .with_context(|| {
                format!("Failed to spawn IO thread to load LeRobot dataset {filepath:?}")
            })?;

        Ok(())
    }

    fn load_from_file_contents(
        &self,
        _settings: &crate::DataLoaderSettings,
        filepath: std::path::PathBuf,
        _contents: std::borrow::Cow<'_, [u8]>,
        _tx: Sender<LoadedData>,
    ) -> Result<(), DataLoaderError> {
        Err(DataLoaderError::Incompatible(filepath))
    }
}

fn load_and_stream(
    dataset: &LeRobotDataset,
    application_id: &ApplicationId,
    tx: &Sender<crate::LoadedData>,
) {
    // Set up a recording per episode.
    let episodes = prepare_episode_chunks(dataset, application_id, tx);

    for (episode, store_id) in &episodes {
        // Log the episode data to its respective recording.
        match load_episode(dataset, *episode) {
            Ok(chunks) => {
                let recording_info = re_types::archetypes::RecordingInfo::new()
                    .with_name(format!("Episode {}", episode.0));

                let Ok(initial) = Chunk::builder(EntityPath::properties())
                    .with_archetype(RowId::new(), TimePoint::STATIC, &recording_info)
                    .build()
                else {
                    re_log::error!(
                        "Failed to build recording properties chunk for episode {}",
                        episode.0
                    );
                    return;
                };

                for chunk in std::iter::once(initial).chain(chunks) {
                    let data = LoadedData::Chunk(
                        LeRobotDatasetLoader.name(),
                        store_id.clone(),
                        chunk,
                    );

                    if tx.send(data).is_err() {
                        break; // The other end has decided to hang up, not our problem.
                    }
                }
            }
            Err(err) => {
                re_log::warn!(
                    "Failed to load episode {} from LeRobot dataset: {err}",
                    episode.0
                );
            }
        }
    }
}

/// Prepares the viewer for all episodes by sending out a [`SetStoreInfo`](`re_log_types::SetStoreInfo`)
/// [`LogMsg`](`re_log_types::LogMsg`) for each episode.
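///
/// Returns the [`StoreId`] created for each episode, so that subsequent chunks can be
/// routed to the correct recording.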
fn prepare_episode_chunks(
    dataset: &LeRobotDataset,
    application_id: &ApplicationId,
    tx: &Sender<crate::LoadedData>,
) -> Vec<(EpisodeIndex, StoreId)> {
    let mut store_ids = vec![];

    for episode in &dataset.metadata.episodes {
        let episode = episode.index;

        let store_id = StoreId::from_string(
            re_log_types::StoreKind::Recording,
            format!("episode_{}", episode.0),
        );
        let set_store_info = LoadedData::LogMsg(
            LeRobotDatasetLoader.name(),
            prepare_store_info(
                application_id.clone(),
                &store_id,
                re_log_types::FileSource::Sdk,
            ),
        );

        if tx.send(set_store_info).is_err() {
            break;
        }

        store_ids.push((episode, store_id));
    }

    store_ids
}

/// Loads a single episode from a `LeRobot` dataset and converts it into a collection of Rerun chunks.
///
/// This function processes an episode from the dataset by extracting the relevant data columns and
/// converting them into appropriate Rerun data structures. It handles different types of data
/// (videos, images, scalar values, etc.) based on their data type specifications in the dataset metadata.
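///
/// A minimal sketch of the intended call pattern (the dataset path and episode index
/// are hypothetical):
///
/// ```ignore
/// let dataset = LeRobotDataset::load_from_directory("datasets/pusht")?;
/// let chunks = load_episode(&dataset, EpisodeIndex(0))?;
/// ```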
pub fn load_episode(
    dataset: &LeRobotDataset,
    episode: EpisodeIndex,
) -> Result<Vec<Chunk>, DataLoaderError> {
    let data = dataset
        .read_episode_data(episode)
        .map_err(|err| anyhow!("Reading data for episode {} failed: {err}", episode.0))?;

    let frame_indices = data
        .column_by_name("frame_index")
        .ok_or_else(|| anyhow!("Failed to get frame index column in LeRobot dataset"))?
        .clone();

    let timeline = re_log_types::Timeline::new_sequence("frame_index");
    let times: &arrow::buffer::ScalarBuffer<i64> = frame_indices
        .downcast_array_ref::<Int64Array>()
        .ok_or_else(|| anyhow!("LeRobot dataset frame indices are of an unexpected type"))?
        .values();

    let time_column = re_chunk::TimeColumn::new(None, timeline, times.clone());
    let timelines = std::iter::once((*timeline.name(), time_column.clone())).collect();

    let mut chunks = Vec::new();

    for (feature_key, feature) in dataset
        .metadata
        .info
        .features
        .iter()
        .filter(|(key, _)| !LEROBOT_DATASET_IGNORED_COLUMNS.contains(&key.as_str()))
    {
        match feature.dtype {
            DType::Video => {
                chunks.extend(load_episode_video(
                    dataset,
                    feature_key,
                    episode,
                    &timeline,
                    time_column.clone(),
                )?);
            }

            DType::Image => {
                let num_channels = feature.channel_dim();

                match num_channels {
                    1 => chunks.extend(load_episode_depth_images(feature_key, &timeline, &data)?),
                    3 => chunks.extend(load_episode_images(feature_key, &timeline, &data)?),
                    _ => re_log::warn_once!(
                        "Unsupported channel count {num_channels} (shape: {:?}) for LeRobot dataset; only 1- and 3-channel images are supported",
                        feature.shape
                    ),
                };
            }
            DType::Int64 if feature_key == "task_index" => {
                // Special-case `task_index` columns: the value always refers to a task
                // description in the dataset metadata.
                chunks.extend(log_episode_task(dataset, &timeline, &data)?);
            }
            DType::Int16 | DType::Int64 | DType::Bool | DType::String => {
                re_log::warn_once!(
                    "Loading LeRobot feature ({feature_key}) of dtype `{:?}` into Rerun is not yet implemented",
                    feature.dtype
                );
            }
            DType::Float32 | DType::Float64 => {
                chunks.extend(load_scalar(feature_key, feature, &timelines, &data)?);
            }
        }
    }

    Ok(chunks)
}

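/// Logs the active task description as a [`TextDocument`] on the `task` entity.
///
/// The `task_index` column stores an index into the task list in the dataset
/// metadata; frames without a valid task index are skipped.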
fn log_episode_task(
    dataset: &LeRobotDataset,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let task_indices = data
        .column_by_name("task_index")
        .and_then(|c| c.downcast_array_ref::<Int64Array>())
        .with_context(|| "Failed to get task_index field from dataset!")?;

    let mut chunk = Chunk::builder("task");
    let mut row_id = RowId::new();
    let mut time_int = TimeInt::ZERO;

    for task_index in task_indices {
        let Some(task) = task_index
            .and_then(|i| usize::try_from(i).ok())
            .and_then(|i| dataset.task_by_index(TaskIndex(i)))
        else {
            // If there is no valid task for the current frame index, skip it.
            time_int = time_int.inc();
            continue;
        };

        let timepoint = TimePoint::default().with(*timeline, time_int);
        let text = TextDocument::new(task.task.clone());
        chunk = chunk.with_archetype(row_id, timepoint, &text);

        row_id = row_id.next();
        time_int = time_int.inc();
    }

    Ok(std::iter::once(chunk.build()?))
}

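/// Logs the encoded per-frame image bytes of an image feature as [`EncodedImage`]s.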
fn load_episode_images(
    observation: &str,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let image_bytes = data
        .column_by_name(observation)
        .and_then(|c| c.downcast_array_ref::<StructArray>())
        .and_then(|a| a.column_by_name("bytes"))
        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;

    let mut chunk = Chunk::builder(observation);
    let mut row_id = RowId::new();

    for frame_idx in 0..image_bytes.len() {
        let img_buffer = image_bytes.value(frame_idx);
        let encoded_image = EncodedImage::from_file_contents(img_buffer.to_owned());
        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
        chunk = chunk.with_archetype(row_id, timepoint, &encoded_image);

        row_id = row_id.next();
    }

    Ok(std::iter::once(chunk.build().with_context(|| {
        format!("Failed to build image chunk for image: {observation}")
    })?))
}

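/// Logs the per-frame image bytes of a single-channel image feature as [`DepthImage`]s.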
fn load_episode_depth_images(
    observation: &str,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let image_bytes = data
        .column_by_name(observation)
        .and_then(|c| c.downcast_array_ref::<StructArray>())
        .and_then(|a| a.column_by_name("bytes"))
        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;

    let mut chunk = Chunk::builder(observation);
    let mut row_id = RowId::new();

    for frame_idx in 0..image_bytes.len() {
        let img_buffer = image_bytes.value(frame_idx);
        let depth_image = DepthImage::from_file_contents(img_buffer.to_owned())
            .map_err(|err| anyhow!("Failed to decode image: {err}"))?;

        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
        chunk = chunk.with_archetype(row_id, timepoint, &depth_image);

        row_id = row_id.next();
    }

    Ok(std::iter::once(chunk.build().with_context(|| {
        format!("Failed to build depth image chunk for image: {observation}")
    })?))
}

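/// Logs an episode's video feature as a static [`AssetVideo`] chunk plus a chunk of
/// per-frame [`VideoFrameReference`]s on the given timeline.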
fn load_episode_video(
    dataset: &LeRobotDataset,
    observation: &str,
    episode: EpisodeIndex,
    timeline: &Timeline,
    time_column: TimeColumn,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let contents = dataset
        .read_episode_video_contents(observation, episode)
        .with_context(|| format!("Reading video contents for episode {episode:?} failed!"))?;

    let video_asset = AssetVideo::new(contents.into_owned());
    let entity_path = observation;

    let video_frame_reference_chunk = match video_asset.read_frame_timestamps_nanos() {
        Ok(frame_timestamps_nanos) => {
            let frame_timestamps_nanos: arrow::buffer::ScalarBuffer<i64> =
                frame_timestamps_nanos.into();

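            // Truncate to the length of the time column so that every video timestamp
            // lines up with a `frame_index` row.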
            let video_timestamps = frame_timestamps_nanos
                .iter()
                .take(time_column.num_rows())
                .copied()
                .map(VideoTimestamp::from_nanos)
                .collect::<Vec<_>>();

            let video_frame_reference_column = VideoFrameReference::update_fields()
                .with_many_timestamp(video_timestamps)
                .columns_of_unit_batches()
                .with_context(|| {
                    format!(
                        "Failed to create `VideoFrameReference` column for episode {episode:?}."
                    )
                })?;

            Some(Chunk::from_auto_row_ids(
                re_chunk::ChunkId::new(),
                entity_path.into(),
                std::iter::once((*timeline.name(), time_column)).collect(),
                video_frame_reference_column.collect(),
            )?)
        }
        Err(err) => {
            re_log::warn_once!(
                "Failed to read frame timestamps from episode {episode:?} video: {err}"
            );
            None
        }
    };

    // Put the video asset into its own (static) chunk since it can be fairly large.
    let video_asset_chunk = Chunk::builder(entity_path)
        .with_archetype(RowId::new(), TimePoint::default(), &video_asset)
        .build()?;

    if let Some(video_frame_reference_chunk) = video_frame_reference_chunk {
        Ok(Either::Left(
            [video_asset_chunk, video_frame_reference_chunk].into_iter(),
        ))
    } else {
        // Still log the video asset, but don't include video frames.
        Ok(Either::Right(std::iter::once(video_asset_chunk)))
    }
}

/// Helper type similar to [`Either`], but with 3 variants.
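///
/// Lets [`load_scalar`] return no chunks, a batch of chunks, or a single chunk behind
/// one concrete type, boxing only in the batch case.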
enum ScalarChunkIterator {
    Empty(std::iter::Empty<Chunk>),
    Batch(Box<dyn ExactSizeIterator<Item = Chunk>>),
    Single(std::iter::Once<Chunk>),
}

impl Iterator for ScalarChunkIterator {
    type Item = Chunk;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Empty(iter) => iter.next(),
            Self::Batch(iter) => iter.next(),
            Self::Single(iter) => iter.next(),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Empty(iter) => iter.size_hint(),
            Self::Batch(iter) => iter.size_hint(),
            Self::Single(iter) => iter.size_hint(),
        }
    }
}

// All three variants report exact `size_hint`s, so the default `len()` is correct.
impl ExactSizeIterator for ScalarChunkIterator {}

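/// Converts a float feature column into per-entity [`Scalars`] chunks, dispatching on
/// the Arrow layout of the column.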
fn load_scalar(
    feature_key: &str,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &RecordBatch,
) -> Result<ScalarChunkIterator, DataLoaderError> {
    let field = data
        .schema_ref()
        .field_with_name(feature_key)
        .with_context(|| {
            format!("Failed to get field for feature {feature_key} from parquet file")
        })?;

    let entity_path = EntityPath::parse_forgiving(field.name());

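    // Scalar features come in three Arrow layouts: fixed-size lists, variable-length
    // lists, and plain float columns. Each arm normalizes the values to per-row
    // `Float64` slices before building chunks.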
    match field.data_type() {
        DataType::FixedSizeList(_, _) => {
            let fixed_size_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<FixedSizeListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!(
                        "Failed to downcast feature to FixedSizeListArray"
                    ))
                })?;

            let batch_chunks =
                make_scalar_batch_entity_chunks(entity_path, feature, timelines, fixed_size_array)?;
            Ok(ScalarChunkIterator::Batch(Box::new(batch_chunks)))
        }
        DataType::List(_field) => {
            let list_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<arrow::array::ListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!("Failed to downcast feature to ListArray"))
                })?;

            let sliced = extract_list_array_elements_as_f64(list_array).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            )))
        }
        DataType::Float32 | DataType::Float64 => {
            let feature_data = data.column_by_name(feature_key).ok_or_else(|| {
                DataLoaderError::Other(anyhow!(
                    "Failed to get LeRobot dataset column data for: {:?}",
                    field.name()
                ))
            })?;

            let sliced = extract_scalar_slices_as_f64(feature_data).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            )))
        }
        _ => {
            re_log::warn_once!(
                "Tried logging scalar {} with unsupported dtype: {}",
                field.name(),
                field.data_type()
            );
            Ok(ScalarChunkIterator::Empty(std::iter::empty()))
        }
    }
}

fn make_scalar_batch_entity_chunks(
    entity_path: EntityPath,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &FixedSizeListArray,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let num_elements = data.value_length() as usize;

    // At most two chunks: the scalar data itself, plus an optional static chunk of
    // series names.
    let mut chunks = Vec::with_capacity(2);

    let sliced = extract_fixed_size_list_array_elements_as_f64(data)
        .with_context(|| format!("Failed to cast scalar feature {entity_path} to Float64"))?;

    chunks.push(make_scalar_entity_chunk(
        entity_path.clone(),
        timelines,
        &sliced,
    )?);

    // If we have names for this feature, we insert a single static chunk containing the names.
    if let Some(names) = feature.names.clone() {
        let names: Vec<_> = (0..num_elements)
            .map(|idx| names.name_for_index(idx))
            .collect();

        chunks.push(
            Chunk::builder(entity_path)
                .with_row(
                    RowId::new(),
                    TimePoint::default(),
                    std::iter::once((
                        archetypes::SeriesLines::descriptor_names(),
                        Arc::new(StringArray::from_iter(names)) as Arc<dyn ArrowArray>,
                    )),
                )
                .build()?,
        );
    }

    Ok(chunks.into_iter())
}

fn make_scalar_entity_chunk(
    entity_path: EntityPath,
    timelines: &IntMap<TimelineName, TimeColumn>,
    sliced_data: &[ArrayRef],
) -> Result<Chunk, DataLoaderError> {
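    // Each element of `sliced_data` holds the scalar value(s) for one row; wrap them
    // into a single `ListArray` column with one list entry per row of the timeline.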
    let data_arrays = sliced_data
        .iter()
        .map(|e| Some(e.as_ref()))
        .collect::<Vec<_>>();

    let data_field_inner = Field::new("item", DataType::Float64, true /* nullable */);
    #[allow(clippy::unwrap_used)] // we know we've given the right field type
    let data_field_array: arrow::array::ListArray =
        re_arrow_util::arrays_to_list_array(data_field_inner.data_type().clone(), &data_arrays)
            .unwrap();

    Ok(Chunk::from_auto_row_ids(
        ChunkId::new(),
        entity_path,
        timelines.clone(),
        std::iter::once((Scalars::descriptor_scalars().clone(), data_field_array)).collect(),
    )?)
}

fn extract_scalar_slices_as_f64(data: &ArrayRef) -> anyhow::Result<Vec<ArrayRef>> {
    // Cast to f64 first, since the `Scalars` archetype expects f64 values.
    let scalar_values = cast(&data, &DataType::Float64)
        .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))?;

    Ok((0..data.len())
        .map(|idx| scalar_values.slice(idx, 1))
        .collect::<Vec<_>>())
}

fn extract_fixed_size_list_array_elements_as_f64(
    data: &FixedSizeListArray,
) -> anyhow::Result<Vec<ArrayRef>> {
    (0..data.len())
        .map(|idx| {
            cast(&data.value(idx), &DataType::Float64)
                .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))
        })
        .collect::<Result<Vec<_>, _>>()
}

fn extract_list_array_elements_as_f64(
    data: &arrow::array::ListArray,
) -> anyhow::Result<Vec<ArrayRef>> {
    (0..data.len())
        .map(|idx| {
            cast(&data.value(idx), &DataType::Float64)
                .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))
        })
        .collect::<Result<Vec<_>, _>>()
}