1use std::sync::mpsc::Sender;
2use std::sync::Arc;
3use std::thread;
4
5use anyhow::{anyhow, Context as _};
6use arrow::array::{
7 ArrayRef, BinaryArray, FixedSizeListArray, Int64Array, RecordBatch, StringArray, StructArray,
8};
9use arrow::compute::cast;
10use arrow::datatypes::{DataType, Field};
11use itertools::Either;
12use re_arrow_util::ArrowArrayDowncastRef as _;
13use re_chunk::{external::nohash_hasher::IntMap, TimelineName};
14use re_chunk::{
15 ArrowArray, Chunk, ChunkId, EntityPath, RowId, TimeColumn, TimeInt, TimePoint, Timeline,
16};
17
18use re_log_types::{ApplicationId, StoreId};
19use re_types::archetypes::{
20 AssetVideo, DepthImage, EncodedImage, TextDocument, VideoFrameReference,
21};
22use re_types::components::{Name, Scalar, VideoTimestamp};
23use re_types::{Archetype, Component, ComponentBatch};
24
25use crate::lerobot::{
26 is_lerobot_dataset, is_v1_lerobot_dataset, DType, EpisodeIndex, Feature, LeRobotDataset,
27 TaskIndex,
28};
29use crate::load_file::prepare_store_info;
30use crate::{DataLoader, DataLoaderError, LoadedData};
31
/// Parquet columns that carry per-frame bookkeeping rather than loggable feature data.
///
/// These are handled out-of-band: `frame_index` drives the sequence timeline that all
/// other features are logged on (see `load_episode`).
const LEROBOT_DATASET_IGNORED_COLUMNS: &[&str] =
    &["episode_index", "index", "frame_index", "timestamp"];

/// The LeRobot `codebase_version` values this loader knows how to read.
const LEROBOT_DATASET_SUPPORTED_VERSIONS: &[&str] = &["v2.0", "v2.1"];
39
40pub struct LeRobotDatasetLoader;
44
45impl DataLoader for LeRobotDatasetLoader {
46 fn name(&self) -> String {
47 "LeRobotDatasetLoader".into()
48 }
49
50 fn load_from_path(
51 &self,
52 settings: &crate::DataLoaderSettings,
53 filepath: std::path::PathBuf,
54 tx: Sender<LoadedData>,
55 ) -> Result<(), DataLoaderError> {
56 if !is_lerobot_dataset(&filepath) {
57 return Err(DataLoaderError::Incompatible(filepath));
58 }
59
60 if is_v1_lerobot_dataset(&filepath) {
61 re_log::error!("LeRobot 'v1.x' dataset format is unsupported.");
62 return Ok(());
63 }
64
65 let dataset = LeRobotDataset::load_from_directory(&filepath)
66 .map_err(|err| anyhow!("Loading LeRobot dataset failed: {err}"))?;
67
68 if !LEROBOT_DATASET_SUPPORTED_VERSIONS
69 .contains(&dataset.metadata.info.codebase_version.as_str())
70 {
71 re_log::error!(
72 "LeRobot '{}' dataset format is unsupported.",
73 dataset.metadata.info.codebase_version
74 );
75 return Ok(());
76 }
77
78 let application_id = settings
79 .application_id
80 .clone()
81 .unwrap_or(ApplicationId(filepath.display().to_string()));
82
83 thread::Builder::new()
89 .name(format!("load_and_stream({filepath:?}"))
90 .spawn({
91 move || {
92 re_log::info!(
93 "Loading LeRobot dataset from {:?}, with {} episode(s)",
94 dataset.path,
95 dataset.metadata.episodes.len(),
96 );
97 load_and_stream(&dataset, &application_id, &tx);
98 }
99 })
100 .with_context(|| {
101 format!("Failed to spawn IO thread to load LeRobot dataset {filepath:?} ")
102 })?;
103
104 Ok(())
105 }
106
107 fn load_from_file_contents(
108 &self,
109 _settings: &crate::DataLoaderSettings,
110 filepath: std::path::PathBuf,
111 _contents: std::borrow::Cow<'_, [u8]>,
112 _tx: Sender<LoadedData>,
113 ) -> Result<(), DataLoaderError> {
114 Err(DataLoaderError::Incompatible(filepath))
115 }
116}
117
118fn load_and_stream(
119 dataset: &LeRobotDataset,
120 application_id: &ApplicationId,
121 tx: &Sender<crate::LoadedData>,
122) {
123 let episodes = prepare_episode_chunks(dataset, application_id, tx);
125
126 for (episode, store_id) in &episodes {
127 match load_episode(dataset, *episode) {
129 Ok(chunks) => {
130 let properties = re_types::archetypes::RecordingProperties::new()
131 .with_name(format!("Episode {}", episode.0));
132
133 debug_assert!(TimePoint::default().is_static());
134 let Ok(initial) = Chunk::builder(EntityPath::recording_properties())
135 .with_archetype(RowId::new(), TimePoint::default(), &properties)
136 .build()
137 else {
138 re_log::error!(
139 "Failed to build recording properties chunk for episode {}",
140 episode.0
141 );
142 return;
143 };
144
145 for chunk in std::iter::once(initial).chain(chunks.into_iter()) {
146 let data = LoadedData::Chunk(
147 LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
148 store_id.clone(),
149 chunk,
150 );
151
152 if tx.send(data).is_err() {
153 break; }
155 }
156 }
157 Err(err) => {
158 re_log::warn!(
159 "Failed to load episode {} from LeRobot dataset: {err}",
160 episode.0
161 );
162 }
163 }
164 }
165}
166
167fn prepare_episode_chunks(
170 dataset: &LeRobotDataset,
171 application_id: &ApplicationId,
172 tx: &Sender<crate::LoadedData>,
173) -> Vec<(EpisodeIndex, StoreId)> {
174 let mut store_ids = vec![];
175
176 for episode in &dataset.metadata.episodes {
177 let episode = episode.index;
178
179 let store_id = StoreId::from_string(
180 re_log_types::StoreKind::Recording,
181 format!("episode_{}", episode.0),
182 );
183 let set_store_info = LoadedData::LogMsg(
184 LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
185 prepare_store_info(
186 application_id.clone(),
187 &store_id,
188 re_log_types::FileSource::Sdk,
189 ),
190 );
191
192 if tx.send(set_store_info).is_err() {
193 break;
194 }
195
196 store_ids.push((episode, store_id.clone()));
197 }
198
199 store_ids
200}
201
/// Loads all features of a single episode into Rerun [`Chunk`]s.
///
/// Reads the episode's record batch, builds a `frame_index` sequence timeline
/// from the `frame_index` column, then converts every feature (except the
/// bookkeeping columns in [`LEROBOT_DATASET_IGNORED_COLUMNS`]) into chunks
/// according to its dtype.
///
/// # Errors
///
/// Fails if the episode data can't be read, if the `frame_index` column is
/// missing or not `Int64`, or if any individual feature fails to convert.
pub fn load_episode(
    dataset: &LeRobotDataset,
    episode: EpisodeIndex,
) -> Result<Vec<Chunk>, DataLoaderError> {
    let data = dataset
        .read_episode_data(episode)
        .map_err(|err| anyhow!("Reading data for episode {} failed: {err}", episode.0))?;

    let frame_indices = data
        .column_by_name("frame_index")
        .ok_or_else(|| anyhow!("Failed to get frame index column in LeRobot dataset"))?
        .clone();

    // All features are logged on this single sequence timeline.
    let timeline = re_log_types::Timeline::new_sequence("frame_index");
    let times: &arrow::buffer::ScalarBuffer<i64> = frame_indices
        .downcast_array_ref::<Int64Array>()
        .ok_or_else(|| anyhow!("LeRobot dataset frame indices are of an unexpected type"))?
        .values();

    let time_column = re_chunk::TimeColumn::new(None, timeline, times.clone());
    let timelines = std::iter::once((*timeline.name(), time_column.clone())).collect();

    let mut chunks = Vec::new();

    for (feature_key, feature) in dataset
        .metadata
        .info
        .features
        .iter()
        .filter(|(key, _)| !LEROBOT_DATASET_IGNORED_COLUMNS.contains(&key.as_str()))
    {
        match feature.dtype {
            DType::Video => {
                chunks.extend(load_episode_video(
                    dataset,
                    feature_key,
                    episode,
                    &timeline,
                    time_column.clone(),
                )?);
            }

            DType::Image => {
                // 1-channel images are treated as depth, 3-channel as color;
                // anything else is skipped with a warning.
                let num_channels = feature.channel_dim();

                match num_channels {
                    1 => chunks.extend(load_episode_depth_images(feature_key, &timeline, &data)?),
                    3 => chunks.extend(load_episode_images(feature_key, &timeline, &data)?),
                    _ => re_log::warn_once!(
                        "Unsupported channel count {num_channels} (shape: {:?}) for LeRobot dataset; Only 1- and 3-channel images are supported",
                        feature.shape
                    ),
                };
            }
            // NOTE: this guarded arm must stay *before* the generic `Int64` arm
            // below, otherwise `task_index` would fall into the "not yet
            // implemented" bucket.
            DType::Int64 if feature_key == "task_index" => {
                chunks.extend(log_episode_task(dataset, &timeline, &data)?);
            }
            DType::Int16 | DType::Int64 | DType::Bool | DType::String => {
                re_log::warn_once!(
                    "Loading LeRobot feature ({feature_key}) of dtype `{:?}` into Rerun is not yet implemented",
                    feature.dtype
                );
            }
            DType::Float32 | DType::Float64 => {
                chunks.extend(load_scalar(feature_key, feature, &timelines, &data)?);
            }
        }
    }

    Ok(chunks)
}
280
281fn log_episode_task(
282 dataset: &LeRobotDataset,
283 timeline: &Timeline,
284 data: &RecordBatch,
285) -> Result<impl ExactSizeIterator<Item = Chunk>, DataLoaderError> {
286 let task_indices = data
287 .column_by_name("task_index")
288 .and_then(|c| c.downcast_array_ref::<Int64Array>())
289 .with_context(|| "Failed to get task_index field from dataset!")?;
290
291 let mut chunk = Chunk::builder("task".into());
292 let mut row_id = RowId::new();
293 let mut time_int = TimeInt::ZERO;
294
295 for task_index in task_indices {
296 let Some(task) = task_index
297 .and_then(|i| usize::try_from(i).ok())
298 .and_then(|i| dataset.task_by_index(TaskIndex(i)))
299 else {
300 time_int = time_int.inc();
302 continue;
303 };
304
305 let timepoint = TimePoint::default().with(*timeline, time_int);
306 let text = TextDocument::new(task.task.clone());
307 chunk = chunk.with_archetype(row_id, timepoint, &text);
308
309 row_id = row_id.next();
310 time_int = time_int.inc();
311 }
312
313 Ok(std::iter::once(chunk.build()?))
314}
315
316fn load_episode_images(
317 observation: &str,
318 timeline: &Timeline,
319 data: &RecordBatch,
320) -> Result<impl ExactSizeIterator<Item = Chunk>, DataLoaderError> {
321 let image_bytes = data
322 .column_by_name(observation)
323 .and_then(|c| c.downcast_array_ref::<StructArray>())
324 .and_then(|a| a.column_by_name("bytes"))
325 .and_then(|a| a.downcast_array_ref::<BinaryArray>())
326 .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;
327
328 let mut chunk = Chunk::builder(observation.into());
329 let mut row_id = RowId::new();
330
331 for frame_idx in 0..image_bytes.len() {
332 let img_buffer = image_bytes.value(frame_idx);
333 let encoded_image = EncodedImage::from_file_contents(img_buffer.to_owned());
334 let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
335 chunk = chunk.with_archetype(row_id, timepoint, &encoded_image);
336
337 row_id = row_id.next();
338 }
339
340 Ok(std::iter::once(chunk.build().with_context(|| {
341 format!("Failed to build image chunk for image: {observation}")
342 })?))
343}
344
345fn load_episode_depth_images(
346 observation: &str,
347 timeline: &Timeline,
348 data: &RecordBatch,
349) -> Result<impl ExactSizeIterator<Item = Chunk>, DataLoaderError> {
350 let image_bytes = data
351 .column_by_name(observation)
352 .and_then(|c| c.downcast_array_ref::<StructArray>())
353 .and_then(|a| a.column_by_name("bytes"))
354 .and_then(|a| a.downcast_array_ref::<BinaryArray>())
355 .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;
356
357 let mut chunk = Chunk::builder(observation.into());
358 let mut row_id = RowId::new();
359
360 for frame_idx in 0..image_bytes.len() {
361 let img_buffer = image_bytes.value(frame_idx);
362 let depth_image = DepthImage::from_file_contents(img_buffer.to_owned())
363 .map_err(|err| anyhow!("Failed to decode image: {err}"))?;
364
365 let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
366 chunk = chunk.with_archetype(row_id, timepoint, &depth_image);
367
368 row_id = row_id.next();
369 }
370
371 Ok(std::iter::once(chunk.build().with_context(|| {
372 format!("Failed to build image chunk for image: {observation}")
373 })?))
374}
375
/// Logs a video feature: the video asset itself plus (when the container can be
/// parsed) per-frame [`VideoFrameReference`]s mapping timeline frames to video
/// timestamps.
///
/// Returns one chunk (asset only) or two (asset + frame references), depending
/// on whether frame timestamps could be read from the video.
fn load_episode_video(
    dataset: &LeRobotDataset,
    observation: &str,
    episode: EpisodeIndex,
    timeline: &Timeline,
    time_column: TimeColumn,
) -> Result<impl ExactSizeIterator<Item = Chunk>, DataLoaderError> {
    let contents = dataset
        .read_episode_video_contents(observation, episode)
        .with_context(|| format!("Reading video contents for episode {episode:?} failed!"))?;

    let video_asset = AssetVideo::new(contents.into_owned());
    let entity_path = observation;

    let video_frame_reference_chunk = match video_asset.read_frame_timestamps_nanos() {
        Ok(frame_timestamps_nanos) => {
            let frame_timestamps_nanos: arrow::buffer::ScalarBuffer<i64> =
                frame_timestamps_nanos.into();

            // Truncate to the timeline length so the timestamp column and the
            // time column stay row-aligned.
            let video_timestamps = frame_timestamps_nanos
                .iter()
                .take(time_column.num_rows())
                .copied()
                .map(VideoTimestamp::from_nanos)
                .collect::<Vec<_>>();

            let video_timestamp_batch = &video_timestamps as &dyn ComponentBatch;
            let video_timestamp_list_array = video_timestamp_batch
                .to_arrow_list_array()
                .map_err(re_chunk::ChunkError::from)?;

            // Indicator column: one (empty) indicator per timestamp row.
            let video_frame_reference_indicators =
                <VideoFrameReference as Archetype>::Indicator::new_array(video_timestamps.len());
            let video_frame_reference_indicators_list_array = video_frame_reference_indicators
                .to_arrow_list_array()
                .map_err(re_chunk::ChunkError::from)?;

            Some(Chunk::from_auto_row_ids(
                re_chunk::ChunkId::new(),
                entity_path.into(),
                std::iter::once((*timeline.name(), time_column)).collect(),
                [
                    (
                        VideoFrameReference::indicator().descriptor.clone(),
                        video_frame_reference_indicators_list_array,
                    ),
                    (
                        video_timestamp_batch.descriptor().into_owned(),
                        video_timestamp_list_array,
                    ),
                ]
                .into_iter()
                .collect(),
            )?)
        }
        Err(err) => {
            // Not fatal: the asset is still logged, just without frame references.
            re_log::warn_once!(
                "Failed to read frame timestamps from episode {episode:?} video: {err}"
            );
            None
        }
    };

    // The asset itself is logged statically (empty timepoint).
    let video_asset_chunk = Chunk::builder(entity_path.into())
        .with_archetype(RowId::new(), TimePoint::default(), &video_asset)
        .build()?;

    // `Either` lets both branches return the same opaque iterator type.
    if let Some(video_frame_reference_chunk) = video_frame_reference_chunk {
        Ok(Either::Left(
            [video_asset_chunk, video_frame_reference_chunk].into_iter(),
        ))
    } else {
        Ok(Either::Right(std::iter::once(video_asset_chunk)))
    }
}
454
/// Iterator over the chunks produced for a single scalar feature.
///
/// Unifies the three possible outcomes of [`load_scalar`] behind one concrete type.
enum ScalarChunkIterator {
    /// Unsupported dtype: yields nothing.
    Empty(std::iter::Empty<Chunk>),
    /// Fixed-size-list feature: data chunk plus an optional names chunk.
    Batch(Box<dyn ExactSizeIterator<Item = Chunk>>),
    /// Plain float or list column: exactly one chunk.
    Single(std::iter::Once<Chunk>),
}
461
462impl Iterator for ScalarChunkIterator {
463 type Item = Chunk;
464
465 fn next(&mut self) -> Option<Self::Item> {
466 match self {
467 Self::Empty(iter) => iter.next(),
468 Self::Batch(iter) => iter.next(),
469 Self::Single(iter) => iter.next(),
470 }
471 }
472}
473
474impl ExactSizeIterator for ScalarChunkIterator {}
475
/// Converts a scalar feature column into [`Scalar`] chunks, dispatching on the
/// column's arrow datatype.
///
/// * `FixedSizeList` → batch of chunks (data + optional element names),
/// * `List` → single chunk, every element cast to `Float64`,
/// * `Float32`/`Float64` → single chunk of one-element slices,
/// * anything else → warns once and yields no chunks.
fn load_scalar(
    feature_key: &str,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &RecordBatch,
) -> Result<ScalarChunkIterator, DataLoaderError> {
    let field = data
        .schema_ref()
        .field_with_name(feature_key)
        .with_context(|| {
            format!("Failed to get field for feature {feature_key} from parquet file")
        })?;

    // The feature key doubles as the entity path in the viewer.
    let entity_path = EntityPath::parse_forgiving(field.name());

    match field.data_type() {
        DataType::FixedSizeList(_, _) => {
            let fixed_size_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<FixedSizeListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!(
                        "Failed to downcast feature to FixedSizeListArray"
                    ))
                })?;

            let batch_chunks =
                make_scalar_batch_entity_chunks(entity_path, feature, timelines, fixed_size_array)?;
            Ok(ScalarChunkIterator::Batch(Box::new(batch_chunks)))
        }
        DataType::List(_field) => {
            let list_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<arrow::array::ListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!("Failed to downcast feature to ListArray"))
                })?;

            let sliced = extract_list_array_elements_as_f64(list_array).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            )))
        }
        DataType::Float32 | DataType::Float64 => {
            let feature_data = data.column_by_name(feature_key).ok_or_else(|| {
                DataLoaderError::Other(anyhow!(
                    "Failed to get LeRobot dataset column data for: {:?}",
                    field.name()
                ))
            })?;

            let sliced = extract_scalar_slices_as_f64(feature_data).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            )))
        }
        _ => {
            re_log::warn_once!(
                "Tried logging scalar {} with unsupported dtype: {}",
                field.name(),
                field.data_type()
            );
            Ok(ScalarChunkIterator::Empty(std::iter::empty()))
        }
    }
}
548
549fn make_scalar_batch_entity_chunks(
550 entity_path: EntityPath,
551 feature: &Feature,
552 timelines: &IntMap<TimelineName, TimeColumn>,
553 data: &FixedSizeListArray,
554) -> Result<impl ExactSizeIterator<Item = Chunk>, DataLoaderError> {
555 let num_elements = data.value_length() as usize;
556
557 let mut chunks = Vec::with_capacity(num_elements);
558
559 let sliced = extract_fixed_size_list_array_elements_as_f64(data)
560 .with_context(|| format!("Failed to cast scalar feature {entity_path} to Float64"))?;
561
562 chunks.push(make_scalar_entity_chunk(
563 entity_path.clone(),
564 timelines,
565 &sliced,
566 )?);
567
568 if let Some(names) = feature.names.clone() {
570 let names: Vec<_> = (0..data.value_length() as usize)
571 .map(|idx| names.name_for_index(idx))
572 .collect();
573
574 chunks.push(
575 Chunk::builder(entity_path)
576 .with_row(
577 RowId::new(),
578 TimePoint::default(),
579 std::iter::once((
580 <Name as Component>::descriptor().clone(),
581 Arc::new(StringArray::from_iter(names)) as Arc<dyn ArrowArray>,
582 )),
583 )
584 .build()?,
585 );
586 }
587
588 Ok(chunks.into_iter())
589}
590
591fn make_scalar_entity_chunk(
592 entity_path: EntityPath,
593 timelines: &IntMap<TimelineName, TimeColumn>,
594 sliced_data: &[ArrayRef],
595) -> Result<Chunk, DataLoaderError> {
596 let data_arrays = sliced_data
597 .iter()
598 .map(|e| Some(e.as_ref()))
599 .collect::<Vec<_>>();
600
601 let data_field_inner = Field::new("item", DataType::Float64, true );
602 #[allow(clippy::unwrap_used)] let data_field_array: arrow::array::ListArray =
604 re_arrow_util::arrays_to_list_array(data_field_inner.data_type().clone(), &data_arrays)
605 .unwrap();
606
607 Ok(Chunk::from_auto_row_ids(
608 ChunkId::new(),
609 entity_path,
610 timelines.clone(),
611 std::iter::once((
612 <Scalar as Component>::descriptor().clone(),
613 data_field_array,
614 ))
615 .collect(),
616 )?)
617}
618
619fn extract_scalar_slices_as_f64(data: &ArrayRef) -> anyhow::Result<Vec<ArrayRef>> {
620 let scalar_values = cast(&data, &DataType::Float64)
622 .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))?;
623
624 Ok((0..data.len())
625 .map(|idx| scalar_values.slice(idx, 1))
626 .collect::<Vec<_>>())
627}
628
629fn extract_fixed_size_list_array_elements_as_f64(
630 data: &FixedSizeListArray,
631) -> anyhow::Result<Vec<ArrayRef>> {
632 (0..data.len())
633 .map(|idx| {
634 cast(&data.value(idx), &DataType::Float64)
635 .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))
636 })
637 .collect::<Result<Vec<_>, _>>()
638}
639
640fn extract_list_array_elements_as_f64(
641 data: &arrow::array::ListArray,
642) -> anyhow::Result<Vec<ArrayRef>> {
643 (0..data.len())
644 .map(|idx| {
645 cast(&data.value(idx), &DataType::Float64)
646 .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))
647 })
648 .collect::<Result<Vec<_>, _>>()
649}