use std::{
    sync::{Arc, mpsc::Sender},
    thread,
};

use anyhow::{Context as _, anyhow};
use arrow::{
    array::{
        ArrayRef, BinaryArray, FixedSizeListArray, Int64Array, RecordBatch, StringArray,
        StructArray,
    },
    compute::cast,
    datatypes::{DataType, Field},
};
use itertools::Either;

use re_arrow_util::ArrowArrayDowncastRef as _;
use re_chunk::{
    ArrowArray, Chunk, ChunkId, EntityPath, RowId, TimeColumn, TimeInt, TimePoint, Timeline,
    TimelineName, external::nohash_hasher::IntMap,
};
use re_log_types::{ApplicationId, StoreId};
use re_types::{
    archetypes::{
        self, AssetVideo, DepthImage, EncodedImage, Scalars, TextDocument, VideoFrameReference,
    },
    components::VideoTimestamp,
};

use crate::lerobot::{
    DType, EpisodeIndex, Feature, LeRobotDataset, TaskIndex, is_lerobot_dataset,
    is_v1_lerobot_dataset,
};
use crate::load_file::prepare_store_info;
use crate::{DataLoader, DataLoaderError, LoadedData};

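/// Columns in the dataset that are not logged as features.
///
/// `frame_index` is instead used as the timeline for each episode.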
const LEROBOT_DATASET_IGNORED_COLUMNS: &[&str] =
    &["episode_index", "index", "frame_index", "timestamp"];

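/// The `codebase_version`s of the LeRobot dataset format that this loader supports.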
const LEROBOT_DATASET_SUPPORTED_VERSIONS: &[&str] = &["v2.0", "v2.1"];

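/// A [`DataLoader`] for `LeRobot` datasets.
///
/// Each episode of the dataset is loaded into its own recording.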
pub struct LeRobotDatasetLoader;

impl DataLoader for LeRobotDatasetLoader {
    fn name(&self) -> String {
        "LeRobotDatasetLoader".into()
    }

    fn load_from_path(
        &self,
        settings: &crate::DataLoaderSettings,
        filepath: std::path::PathBuf,
        tx: Sender<LoadedData>,
    ) -> Result<(), DataLoaderError> {
        if !is_lerobot_dataset(&filepath) {
            return Err(DataLoaderError::Incompatible(filepath));
        }

        if is_v1_lerobot_dataset(&filepath) {
            re_log::error!("LeRobot 'v1.x' dataset format is unsupported.");
            return Ok(());
        }

        let dataset = LeRobotDataset::load_from_directory(&filepath)
            .map_err(|err| anyhow!("Loading LeRobot dataset failed: {err}"))?;

        if !LEROBOT_DATASET_SUPPORTED_VERSIONS
            .contains(&dataset.metadata.info.codebase_version.as_str())
        {
            re_log::error!(
                "LeRobot '{}' dataset format is unsupported.",
                dataset.metadata.info.codebase_version
            );
            return Ok(());
        }

        let application_id = settings
            .application_id
            .clone()
            .unwrap_or(filepath.display().to_string().into());

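        // Loading and decoding can be slow, so spawn a dedicated IO thread and stream the
        // chunks from there instead of blocking the caller.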
        thread::Builder::new()
            .name(format!("load_and_stream({filepath:?})"))
            .spawn(move || {
                re_log::info!(
                    "Loading LeRobot dataset from {:?}, with {} episode(s)",
                    dataset.path,
                    dataset.metadata.episode_count(),
                );
                load_and_stream(&dataset, &application_id, &tx);
            })
            .with_context(|| {
                format!("Failed to spawn IO thread to load LeRobot dataset {filepath:?}")
            })?;

        Ok(())
    }

    fn load_from_file_contents(
        &self,
        _settings: &crate::DataLoaderSettings,
        filepath: std::path::PathBuf,
        _contents: std::borrow::Cow<'_, [u8]>,
        _tx: Sender<LoadedData>,
    ) -> Result<(), DataLoaderError> {
        Err(DataLoaderError::Incompatible(filepath))
    }
}

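/// Loads every episode of the dataset and streams its chunks to the viewer, one recording
/// per episode.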
fn load_and_stream(
    dataset: &LeRobotDataset,
    application_id: &ApplicationId,
    tx: &Sender<crate::LoadedData>,
) {
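    // Announce one store (recording) per episode before streaming any data.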
    let episodes = prepare_episode_chunks(dataset, application_id, tx);

    for (episode, store_id) in &episodes {
        match load_episode(dataset, *episode) {
            Ok(chunks) => {
                let recording_info = re_types::archetypes::RecordingInfo::new()
                    .with_name(format!("Episode {}", episode.0));

                let Ok(initial) = Chunk::builder(EntityPath::properties())
                    .with_archetype(RowId::new(), TimePoint::STATIC, &recording_info)
                    .build()
                else {
                    re_log::error!(
                        "Failed to build recording properties chunk for episode {}",
                        episode.0
                    );
                    return;
                };

                for chunk in std::iter::once(initial).chain(chunks) {
                    let data =
                        LoadedData::Chunk(LeRobotDatasetLoader.name(), store_id.clone(), chunk);

                    if tx.send(data).is_err() {
                        // The other end has hung up; stop sending.
                        break;
                    }
                }
            }
            Err(err) => {
                re_log::warn!(
                    "Failed to load episode {} from LeRobot dataset: {err}",
                    episode.0
                );
            }
        }
    }
}

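/// Sets up the viewer for all episodes in the dataset by announcing a store info for each
/// episode's recording.
///
/// Returns the episode indices together with their assigned [`StoreId`]s.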
fn prepare_episode_chunks(
    dataset: &LeRobotDataset,
    application_id: &ApplicationId,
    tx: &Sender<crate::LoadedData>,
) -> Vec<(EpisodeIndex, StoreId)> {
    let mut store_ids = vec![];

    for episode_index in dataset.metadata.episodes.keys() {
        let episode = *episode_index;

        let store_id =
            StoreId::recording(application_id.clone(), format!("episode_{}", episode.0));
        let set_store_info = LoadedData::LogMsg(
            LeRobotDatasetLoader.name(),
            prepare_store_info(&store_id, re_log_types::FileSource::Sdk),
        );

        if tx.send(set_store_info).is_err() {
            break;
        }

        store_ids.push((episode, store_id));
    }

    store_ids
}

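/// Loads a single episode from a `LeRobot` dataset and converts it into Rerun [`Chunk`]s.
///
/// The episode's `frame_index` column is used as the timeline; each remaining feature column
/// (videos, images, scalars, task indices) is converted according to its dtype.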
pub fn load_episode(
    dataset: &LeRobotDataset,
    episode: EpisodeIndex,
) -> Result<Vec<Chunk>, DataLoaderError> {
    let data = dataset
        .read_episode_data(episode)
        .map_err(|err| anyhow!("Reading data for episode {} failed: {err}", episode.0))?;

    let frame_indices = data
        .column_by_name("frame_index")
        .ok_or_else(|| anyhow!("Failed to get frame index column in LeRobot dataset"))?
        .clone();

    let timeline = re_log_types::Timeline::new_sequence("frame_index");
    let times: &arrow::buffer::ScalarBuffer<i64> = frame_indices
        .downcast_array_ref::<Int64Array>()
        .ok_or_else(|| anyhow!("LeRobot dataset frame indices are of an unexpected type"))?
        .values();

    let time_column = re_chunk::TimeColumn::new(None, timeline, times.clone());
    let timelines = std::iter::once((*timeline.name(), time_column.clone())).collect();

    let mut chunks = Vec::new();

    for (feature_key, feature) in dataset
        .metadata
        .info
        .features
        .iter()
        .filter(|(key, _)| !LEROBOT_DATASET_IGNORED_COLUMNS.contains(&key.as_str()))
    {
        match feature.dtype {
            DType::Video => {
                chunks.extend(load_episode_video(
                    dataset,
                    feature_key,
                    episode,
                    &timeline,
                    time_column.clone(),
                )?);
            }

            DType::Image => {
                let num_channels = feature.channel_dim();

                match num_channels {
                    1 => chunks.extend(load_episode_depth_images(feature_key, &timeline, &data)?),
                    3 => chunks.extend(load_episode_images(feature_key, &timeline, &data)?),
                    _ => re_log::warn_once!(
                        "Unsupported channel count {num_channels} (shape: {:?}) for LeRobot dataset; only 1- and 3-channel images are supported",
                        feature.shape
                    ),
                }
            }
            // `task_index` is special-cased: it is resolved to the task description and
            // logged as a text document.
            DType::Int64 if feature_key == "task_index" => {
                chunks.extend(log_episode_task(dataset, &timeline, &data)?);
            }
            DType::Int16 | DType::Int64 | DType::Bool | DType::String => {
                re_log::warn_once!(
                    "Loading LeRobot feature ({feature_key}) of dtype `{:?}` into Rerun is not yet implemented",
                    feature.dtype
                );
            }
            DType::Float32 | DType::Float64 => {
                chunks.extend(load_scalar(feature_key, feature, &timelines, &data)?);
            }
        }
    }

    Ok(chunks)
}

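/// Logs the task description for each frame of the episode as a [`TextDocument`].
///
/// Every row's `task_index` is resolved to the corresponding task in the dataset metadata;
/// rows with an unknown index are skipped.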
fn log_episode_task(
    dataset: &LeRobotDataset,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let task_indices = data
        .column_by_name("task_index")
        .and_then(|c| c.downcast_array_ref::<Int64Array>())
        .with_context(|| "Failed to get task_index field from dataset!")?;

    let mut chunk = Chunk::builder("task");
    let mut row_id = RowId::new();
    let mut time_int = TimeInt::ZERO;

    for task_index in task_indices {
        let Some(task) = task_index
            .and_then(|i| usize::try_from(i).ok())
            .and_then(|i| dataset.task_by_index(TaskIndex(i)))
        else {
            // Skip rows with an unresolvable task index, but keep the timeline advancing.
            time_int = time_int.inc();
            continue;
        };

        let timepoint = TimePoint::default().with(*timeline, time_int);
        let text = TextDocument::new(task.task.clone());
        chunk = chunk.with_archetype(row_id, timepoint, &text);

        row_id = row_id.next();
        time_int = time_int.inc();
    }

    Ok(std::iter::once(chunk.build()?))
}

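/// Logs the per-frame contents of an image feature as [`EncodedImage`]s.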
fn load_episode_images(
    observation: &str,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let image_bytes = data
        .column_by_name(observation)
        .and_then(|c| c.downcast_array_ref::<StructArray>())
        .and_then(|a| a.column_by_name("bytes"))
        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;

    let mut chunk = Chunk::builder(observation);
    let mut row_id = RowId::new();

    for frame_idx in 0..image_bytes.len() {
        let img_buffer = image_bytes.value(frame_idx);
        let encoded_image = EncodedImage::from_file_contents(img_buffer.to_owned());
        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
        chunk = chunk.with_archetype(row_id, timepoint, &encoded_image);

        row_id = row_id.next();
    }

    Ok(std::iter::once(chunk.build().with_context(|| {
        format!("Failed to build image chunk for image: {observation}")
    })?))
}

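/// Logs the per-frame contents of a single-channel image feature as [`DepthImage`]s.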
fn load_episode_depth_images(
    observation: &str,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let image_bytes = data
        .column_by_name(observation)
        .and_then(|c| c.downcast_array_ref::<StructArray>())
        .and_then(|a| a.column_by_name("bytes"))
        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;

    let mut chunk = Chunk::builder(observation);
    let mut row_id = RowId::new();

    for frame_idx in 0..image_bytes.len() {
        let img_buffer = image_bytes.value(frame_idx);
        let depth_image = DepthImage::from_file_contents(img_buffer.to_owned())
            .map_err(|err| anyhow!("Failed to decode image: {err}"))?;

        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
        chunk = chunk.with_archetype(row_id, timepoint, &depth_image);

        row_id = row_id.next();
    }

    Ok(std::iter::once(chunk.build().with_context(|| {
        format!("Failed to build image chunk for image: {observation}")
    })?))
}

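/// Logs a video feature as a single [`AssetVideo`], together with one
/// [`VideoFrameReference`] per frame that maps the `frame_index` timeline onto the video's
/// own timestamps.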
fn load_episode_video(
    dataset: &LeRobotDataset,
    observation: &str,
    episode: EpisodeIndex,
    timeline: &Timeline,
    time_column: TimeColumn,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let contents = dataset
        .read_episode_video_contents(observation, episode)
        .with_context(|| format!("Reading video contents for episode {episode:?} failed!"))?;

    let video_asset = AssetVideo::new(contents.into_owned());
    let entity_path = observation;

    let video_frame_reference_chunk = match video_asset.read_frame_timestamps_nanos() {
        Ok(frame_timestamps_nanos) => {
            let frame_timestamps_nanos: arrow::buffer::ScalarBuffer<i64> =
                frame_timestamps_nanos.into();

            let video_timestamps = frame_timestamps_nanos
                .iter()
                .take(time_column.num_rows())
                .copied()
                .map(VideoTimestamp::from_nanos)
                .collect::<Vec<_>>();

            let video_frame_reference_column = VideoFrameReference::update_fields()
                .with_many_timestamp(video_timestamps)
                .columns_of_unit_batches()
                .with_context(|| {
                    format!(
                        "Failed to create `VideoFrameReference` column for episode {episode:?}."
                    )
                })?;

            Some(Chunk::from_auto_row_ids(
                re_chunk::ChunkId::new(),
                entity_path.into(),
                std::iter::once((*timeline.name(), time_column)).collect(),
                video_frame_reference_column.collect(),
            )?)
        }
        Err(err) => {
            re_log::warn_once!(
                "Failed to read frame timestamps from episode {episode:?} video: {err}"
            );
            None
        }
    };

    // Always log the video asset itself, even when the frame references could not be built.
    let video_asset_chunk = Chunk::builder(entity_path)
        .with_archetype(RowId::new(), TimePoint::default(), &video_asset)
        .build()?;

    if let Some(video_frame_reference_chunk) = video_frame_reference_chunk {
        Ok(Either::Left(
            [video_asset_chunk, video_frame_reference_chunk].into_iter(),
        ))
    } else {
        Ok(Either::Right(std::iter::once(video_asset_chunk)))
    }
}

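/// An iterator over the chunks produced for a single scalar feature.
///
/// Unifies the three shapes returned by [`load_scalar`]: no chunks (unsupported dtype), a
/// batch of chunks, or a single chunk.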
enum ScalarChunkIterator {
    Empty(std::iter::Empty<Chunk>),
    Batch(Box<dyn ExactSizeIterator<Item = Chunk>>),
    Single(Box<std::iter::Once<Chunk>>),
}

impl Iterator for ScalarChunkIterator {
    type Item = Chunk;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Empty(iter) => iter.next(),
            Self::Batch(iter) => iter.next(),
            Self::Single(iter) => iter.next(),
        }
    }
}

impl ExactSizeIterator for ScalarChunkIterator {}

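/// Converts a scalar feature column into chunks of [`Scalars`].
///
/// Supports `FixedSizeList`, `List`, and plain `Float32`/`Float64` columns; other dtypes are
/// skipped with a warning.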
fn load_scalar(
    feature_key: &str,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &RecordBatch,
) -> Result<ScalarChunkIterator, DataLoaderError> {
    let field = data
        .schema_ref()
        .field_with_name(feature_key)
        .with_context(|| {
            format!("Failed to get field for feature {feature_key} from parquet file")
        })?;

    let entity_path = EntityPath::parse_forgiving(field.name());

    match field.data_type() {
        DataType::FixedSizeList(_, _) => {
            let fixed_size_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<FixedSizeListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!(
                        "Failed to downcast feature to FixedSizeListArray"
                    ))
                })?;

            let batch_chunks =
                make_scalar_batch_entity_chunks(entity_path, feature, timelines, fixed_size_array)?;
            Ok(ScalarChunkIterator::Batch(Box::new(batch_chunks)))
        }
        DataType::List(_field) => {
            let list_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<arrow::array::ListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!("Failed to downcast feature to ListArray"))
                })?;

            let sliced = extract_list_array_elements_as_f64(list_array).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(Box::new(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            ))))
        }
        DataType::Float32 | DataType::Float64 => {
            let feature_data = data.column_by_name(feature_key).ok_or_else(|| {
                DataLoaderError::Other(anyhow!(
                    "Failed to get LeRobot dataset column data for: {:?}",
                    field.name()
                ))
            })?;

            let sliced = extract_scalar_slices_as_f64(feature_data).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(Box::new(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            ))))
        }
        _ => {
            re_log::warn_once!(
                "Tried logging scalar {} with unsupported dtype: {}",
                field.name(),
                field.data_type()
            );
            Ok(ScalarChunkIterator::Empty(std::iter::empty()))
        }
    }
}

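/// Builds the chunks for a fixed-size-list scalar feature: one chunk with the per-frame
/// scalar values, plus a static [`archetypes::SeriesLines`] chunk carrying the series names,
/// if the feature declares any.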
fn make_scalar_batch_entity_chunks(
    entity_path: EntityPath,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &FixedSizeListArray,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let num_elements = data.value_length() as usize;

    let mut chunks = Vec::with_capacity(num_elements);

    let sliced = extract_fixed_size_list_array_elements_as_f64(data)
        .with_context(|| format!("Failed to cast scalar feature {entity_path} to Float64"))?;

    chunks.push(make_scalar_entity_chunk(
        entity_path.clone(),
        timelines,
        &sliced,
    )?);

    // If the feature has named dimensions, log them once, statically, as series names.
    if let Some(names) = feature.names.clone() {
        let names: Vec<_> = (0..num_elements)
            .map(|idx| names.name_for_index(idx))
            .collect();

        chunks.push(
            Chunk::builder(entity_path)
                .with_row(
                    RowId::new(),
                    TimePoint::default(),
                    std::iter::once((
                        archetypes::SeriesLines::descriptor_names(),
                        Arc::new(StringArray::from_iter(names)) as Arc<dyn ArrowArray>,
                    )),
                )
                .build()?,
        );
    }

    Ok(chunks.into_iter())
}

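/// Builds a single chunk that logs one [`Scalars`] row per frame for the given entity.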
fn make_scalar_entity_chunk(
    entity_path: EntityPath,
    timelines: &IntMap<TimelineName, TimeColumn>,
    sliced_data: &[ArrayRef],
) -> Result<Chunk, DataLoaderError> {
    let data_arrays = sliced_data
        .iter()
        .map(|e| Some(e.as_ref()))
        .collect::<Vec<_>>();

    let data_field_inner = Field::new("item", DataType::Float64, true);
    // The slices were all cast to `Float64` by the caller, so this should not fail.
    #[allow(clippy::unwrap_used)]
    let data_field_array: arrow::array::ListArray =
        re_arrow_util::arrays_to_list_array(data_field_inner.data_type().clone(), &data_arrays)
            .unwrap();

    Ok(Chunk::from_auto_row_ids(
        ChunkId::new(),
        entity_path,
        timelines.clone(),
        std::iter::once((Scalars::descriptor_scalars().clone(), data_field_array)).collect(),
    )?)
}

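/// Casts a scalar column to `Float64` and splits it into one single-element slice per row.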
fn extract_scalar_slices_as_f64(data: &ArrayRef) -> anyhow::Result<Vec<ArrayRef>> {
    let scalar_values = cast(&data, &DataType::Float64)
        .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))?;

    Ok((0..data.len())
        .map(|idx| scalar_values.slice(idx, 1))
        .collect::<Vec<_>>())
}

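/// Casts every element of a fixed-size-list array (one list per row) to a `Float64` array.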
fn extract_fixed_size_list_array_elements_as_f64(
    data: &FixedSizeListArray,
) -> anyhow::Result<Vec<ArrayRef>> {
    (0..data.len())
        .map(|idx| {
            cast(&data.value(idx), &DataType::Float64)
                .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))
        })
        .collect::<Result<Vec<_>, _>>()
}

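/// Casts every element of a list array (one list per row) to a `Float64` array.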
fn extract_list_array_elements_as_f64(
    data: &arrow::array::ListArray,
) -> anyhow::Result<Vec<ArrayRef>> {
    (0..data.len())
        .map(|idx| {
            cast(&data.value(idx), &DataType::Float64)
                .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))
        })
        .collect::<Result<Vec<_>, _>>()
}