use std::{
    sync::{Arc, mpsc::Sender},
    thread,
};

use anyhow::{Context as _, anyhow};
use arrow::{
    array::{
        ArrayRef, BinaryArray, FixedSizeListArray, Int64Array, RecordBatch, StringArray,
        StructArray,
    },
    compute::cast,
    datatypes::{DataType, Field},
};
use itertools::Either;

use re_arrow_util::ArrowArrayDowncastRef as _;
use re_chunk::{
    ArrowArray, Chunk, ChunkId, EntityPath, RowId, TimeColumn, TimeInt, TimePoint, Timeline,
    TimelineName, external::nohash_hasher::IntMap,
};
use re_log_types::{ApplicationId, StoreId};
use re_types::{
    archetypes::{
        self, AssetVideo, DepthImage, EncodedImage, Scalars, TextDocument, VideoFrameReference,
    },
    components::VideoTimestamp,
};

use crate::lerobot::{
    DType, EpisodeIndex, Feature, LeRobotDataset, TaskIndex, is_lerobot_dataset,
    is_v1_lerobot_dataset,
};
use crate::load_file::prepare_store_info;
use crate::{DataLoader, DataLoaderError, LoadedData};

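/// Columns of a `LeRobot` dataset that carry per-frame bookkeeping (indices and
/// timestamps) rather than observations, and are therefore not logged as entities.
/// The `frame_index` column is still used to drive the sequence timeline.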
const LEROBOT_DATASET_IGNORED_COLUMNS: &[&str] =
    &["episode_index", "index", "frame_index", "timestamp"];

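/// The `LeRobot` dataset format versions that this loader can handle.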
const LEROBOT_DATASET_SUPPORTED_VERSIONS: &[&str] = &["v2.0", "v2.1"];

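/// A [`DataLoader`] for `LeRobot` datasets.
///
/// Each episode in the dataset is loaded into its own recording. Only the dataset
/// format versions listed in [`LEROBOT_DATASET_SUPPORTED_VERSIONS`] are supported;
/// `v1.x` datasets are rejected with an error.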
pub struct LeRobotDatasetLoader;

impl DataLoader for LeRobotDatasetLoader {
    fn name(&self) -> String {
        "LeRobotDatasetLoader".into()
    }

    fn load_from_path(
        &self,
        settings: &crate::DataLoaderSettings,
        filepath: std::path::PathBuf,
        tx: Sender<LoadedData>,
    ) -> Result<(), DataLoaderError> {
        if !is_lerobot_dataset(&filepath) {
            return Err(DataLoaderError::Incompatible(filepath));
        }

        if is_v1_lerobot_dataset(&filepath) {
            re_log::error!("LeRobot 'v1.x' dataset format is unsupported.");
            return Ok(());
        }

        let dataset = LeRobotDataset::load_from_directory(&filepath)
            .map_err(|err| anyhow!("Loading LeRobot dataset failed: {err}"))?;

        if !LEROBOT_DATASET_SUPPORTED_VERSIONS
            .contains(&dataset.metadata.info.codebase_version.as_str())
        {
            re_log::error!(
                "LeRobot '{}' dataset format is unsupported.",
                dataset.metadata.info.codebase_version
            );
            return Ok(());
        }

        let application_id = settings
            .application_id
            .clone()
            .unwrap_or(ApplicationId(filepath.display().to_string()));

        thread::Builder::new()
            .name(format!("load_and_stream({filepath:?})"))
            .spawn({
                move || {
                    re_log::info!(
                        "Loading LeRobot dataset from {:?}, with {} episode(s)",
                        dataset.path,
                        dataset.metadata.episodes.len(),
                    );
                    load_and_stream(&dataset, &application_id, &tx);
                }
            })
            .with_context(|| {
                format!("Failed to spawn IO thread to load LeRobot dataset {filepath:?}")
            })?;

        Ok(())
    }

    fn load_from_file_contents(
        &self,
        _settings: &crate::DataLoaderSettings,
        filepath: std::path::PathBuf,
        _contents: std::borrow::Cow<'_, [u8]>,
        _tx: Sender<LoadedData>,
    ) -> Result<(), DataLoaderError> {
        Err(DataLoaderError::Incompatible(filepath))
    }
}

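/// Loads the given `LeRobot` dataset and streams its episodes to the viewer.
///
/// For every episode, a recording-properties chunk carrying the episode name is sent
/// first, followed by the episode's data chunks.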
fn load_and_stream(
    dataset: &LeRobotDataset,
    application_id: &ApplicationId,
    tx: &Sender<crate::LoadedData>,
) {
    let episodes = prepare_episode_chunks(dataset, application_id, tx);

    for (episode, store_id) in &episodes {
        match load_episode(dataset, *episode) {
            Ok(chunks) => {
                let recording_info = re_types::archetypes::RecordingInfo::new()
                    .with_name(format!("Episode {}", episode.0));

                let Ok(initial) = Chunk::builder(EntityPath::properties())
                    .with_archetype(RowId::new(), TimePoint::STATIC, &recording_info)
                    .build()
                else {
                    re_log::error!(
                        "Failed to build recording properties chunk for episode {}",
                        episode.0
                    );
                    return;
                };

                for chunk in std::iter::once(initial).chain(chunks) {
                    let data = LoadedData::Chunk(
                        LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
                        store_id.clone(),
                        chunk,
                    );

                    if tx.send(data).is_err() {
                        break; // The receiver hung up; stop streaming this episode.
                    }
                }
            }
            Err(err) => {
                re_log::warn!(
                    "Failed to load episode {} from LeRobot dataset: {err}",
                    episode.0
                );
            }
        }
    }
}

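/// Creates one recording per episode by sending out a store-info message for each,
/// and returns the episode indices paired with their recording's [`StoreId`].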
fn prepare_episode_chunks(
    dataset: &LeRobotDataset,
    application_id: &ApplicationId,
    tx: &Sender<crate::LoadedData>,
) -> Vec<(EpisodeIndex, StoreId)> {
    let mut store_ids = vec![];

    for episode in &dataset.metadata.episodes {
        let episode = episode.index;

        let store_id = StoreId::from_string(
            re_log_types::StoreKind::Recording,
            format!("episode_{}", episode.0),
        );
        let set_store_info = LoadedData::LogMsg(
            LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
            prepare_store_info(
                application_id.clone(),
                &store_id,
                re_log_types::FileSource::Sdk,
            ),
        );

        if tx.send(set_store_info).is_err() {
            break;
        }

        store_ids.push((episode, store_id));
    }

    store_ids
}

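/// Loads a single episode of a `LeRobot` dataset and converts it into Rerun chunks.
///
/// All data is logged on a `frame_index` sequence timeline. Video and image features
/// become video assets and (depth) images, the `task_index` feature becomes a text
/// document, and float features become scalars. Other dtypes are not yet supported.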
pub fn load_episode(
    dataset: &LeRobotDataset,
    episode: EpisodeIndex,
) -> Result<Vec<Chunk>, DataLoaderError> {
    let data = dataset
        .read_episode_data(episode)
        .map_err(|err| anyhow!("Reading data for episode {} failed: {err}", episode.0))?;

    let frame_indices = data
        .column_by_name("frame_index")
        .ok_or_else(|| anyhow!("Failed to get frame index column in LeRobot dataset"))?
        .clone();

    let timeline = re_log_types::Timeline::new_sequence("frame_index");
    let times: &arrow::buffer::ScalarBuffer<i64> = frame_indices
        .downcast_array_ref::<Int64Array>()
        .ok_or_else(|| anyhow!("LeRobot dataset frame indices are of an unexpected type"))?
        .values();

    let time_column = re_chunk::TimeColumn::new(None, timeline, times.clone());
    let timelines = std::iter::once((*timeline.name(), time_column.clone())).collect();

    let mut chunks = Vec::new();

    for (feature_key, feature) in dataset
        .metadata
        .info
        .features
        .iter()
        .filter(|(key, _)| !LEROBOT_DATASET_IGNORED_COLUMNS.contains(&key.as_str()))
    {
        match feature.dtype {
            DType::Video => {
                chunks.extend(load_episode_video(
                    dataset,
                    feature_key,
                    episode,
                    &timeline,
                    time_column.clone(),
                )?);
            }

            DType::Image => {
                let num_channels = feature.channel_dim();

                match num_channels {
                    1 => chunks.extend(load_episode_depth_images(feature_key, &timeline, &data)?),
                    3 => chunks.extend(load_episode_images(feature_key, &timeline, &data)?),
                    _ => re_log::warn_once!(
                        "Unsupported channel count {num_channels} (shape: {:?}) for LeRobot dataset; only 1- and 3-channel images are supported",
                        feature.shape
                    ),
                };
            }
            DType::Int64 if feature_key == "task_index" => {
                chunks.extend(log_episode_task(dataset, &timeline, &data)?);
            }
            DType::Int16 | DType::Int64 | DType::Bool | DType::String => {
                re_log::warn_once!(
                    "Loading LeRobot feature ({feature_key}) of dtype `{:?}` into Rerun is not yet implemented",
                    feature.dtype
                );
            }
            DType::Float32 | DType::Float64 => {
                chunks.extend(load_scalar(feature_key, feature, &timelines, &data)?);
            }
        }
    }

    Ok(chunks)
}

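/// Logs the task description associated with each frame's `task_index` as a
/// [`TextDocument`] on the `task` entity.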
fn log_episode_task(
    dataset: &LeRobotDataset,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let task_indices = data
        .column_by_name("task_index")
        .and_then(|c| c.downcast_array_ref::<Int64Array>())
        .with_context(|| "Failed to get task_index field from dataset!")?;

    let mut chunk = Chunk::builder("task");
    let mut row_id = RowId::new();
    let mut time_int = TimeInt::ZERO;

    for task_index in task_indices {
        let Some(task) = task_index
            .and_then(|i| usize::try_from(i).ok())
            .and_then(|i| dataset.task_by_index(TaskIndex(i)))
        else {
            // Missing or unknown task for this frame: keep the timeline advancing.
            time_int = time_int.inc();
            continue;
        };

        let timepoint = TimePoint::default().with(*timeline, time_int);
        let text = TextDocument::new(task.task.clone());
        chunk = chunk.with_archetype(row_id, timepoint, &text);

        row_id = row_id.next();
        time_int = time_int.inc();
    }

    Ok(std::iter::once(chunk.build()?))
}

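/// Logs a 3-channel image feature of the episode as a sequence of [`EncodedImage`]s.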
fn load_episode_images(
    observation: &str,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let image_bytes = data
        .column_by_name(observation)
        .and_then(|c| c.downcast_array_ref::<StructArray>())
        .and_then(|a| a.column_by_name("bytes"))
        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;

    let mut chunk = Chunk::builder(observation);
    let mut row_id = RowId::new();

    for frame_idx in 0..image_bytes.len() {
        let img_buffer = image_bytes.value(frame_idx);
        let encoded_image = EncodedImage::from_file_contents(img_buffer.to_owned());
        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
        chunk = chunk.with_archetype(row_id, timepoint, &encoded_image);

        row_id = row_id.next();
    }

    Ok(std::iter::once(chunk.build().with_context(|| {
        format!("Failed to build image chunk for image: {observation}")
    })?))
}

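/// Logs a single-channel image feature of the episode as a sequence of [`DepthImage`]s.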
fn load_episode_depth_images(
    observation: &str,
    timeline: &Timeline,
    data: &RecordBatch,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let image_bytes = data
        .column_by_name(observation)
        .and_then(|c| c.downcast_array_ref::<StructArray>())
        .and_then(|a| a.column_by_name("bytes"))
        .and_then(|a| a.downcast_array_ref::<BinaryArray>())
        .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;

    let mut chunk = Chunk::builder(observation);
    let mut row_id = RowId::new();

    for frame_idx in 0..image_bytes.len() {
        let img_buffer = image_bytes.value(frame_idx);
        let depth_image = DepthImage::from_file_contents(img_buffer.to_owned())
            .map_err(|err| anyhow!("Failed to decode image: {err}"))?;

        let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
        chunk = chunk.with_archetype(row_id, timepoint, &depth_image);

        row_id = row_id.next();
    }

    Ok(std::iter::once(chunk.build().with_context(|| {
        format!("Failed to build image chunk for image: {observation}")
    })?))
}

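/// Logs a video feature of the episode as an [`AssetVideo`], along with a
/// [`VideoFrameReference`] column that maps each `frame_index` to its video timestamp.
///
/// If the frame timestamps cannot be read from the video, only the asset is logged.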
fn load_episode_video(
    dataset: &LeRobotDataset,
    observation: &str,
    episode: EpisodeIndex,
    timeline: &Timeline,
    time_column: TimeColumn,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let contents = dataset
        .read_episode_video_contents(observation, episode)
        .with_context(|| format!("Reading video contents for episode {episode:?} failed!"))?;

    let video_asset = AssetVideo::new(contents.into_owned());
    let entity_path = observation;

    let video_frame_reference_chunk = match video_asset.read_frame_timestamps_nanos() {
        Ok(frame_timestamps_nanos) => {
            let frame_timestamps_nanos: arrow::buffer::ScalarBuffer<i64> =
                frame_timestamps_nanos.into();

            let video_timestamps = frame_timestamps_nanos
                .iter()
                .take(time_column.num_rows())
                .copied()
                .map(VideoTimestamp::from_nanos)
                .collect::<Vec<_>>();

            let video_frame_reference_column = VideoFrameReference::update_fields()
                .with_many_timestamp(video_timestamps)
                .columns_of_unit_batches()
                .with_context(|| {
                    format!(
                        "Failed to create `VideoFrameReference` column for episode {episode:?}."
                    )
                })?;

            Some(Chunk::from_auto_row_ids(
                re_chunk::ChunkId::new(),
                entity_path.into(),
                std::iter::once((*timeline.name(), time_column)).collect(),
                video_frame_reference_column.collect(),
            )?)
        }
        Err(err) => {
            re_log::warn_once!(
                "Failed to read frame timestamps from episode {episode:?} video: {err}"
            );
            None
        }
    };

    let video_asset_chunk = Chunk::builder(entity_path)
        .with_archetype(RowId::new(), TimePoint::default(), &video_asset)
        .build()?;

    if let Some(video_frame_reference_chunk) = video_frame_reference_chunk {
        Ok(Either::Left(
            [video_asset_chunk, video_frame_reference_chunk].into_iter(),
        ))
    } else {
        Ok(Either::Right(std::iter::once(video_asset_chunk)))
    }
}

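/// Iterator over the chunks produced for a single scalar feature.
///
/// Unifies the three possible outcomes of [`load_scalar`] (no chunks, a batch of
/// chunks, or a single chunk) behind one [`ExactSizeIterator`] type.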
enum ScalarChunkIterator {
    Empty(std::iter::Empty<Chunk>),
    Batch(Box<dyn ExactSizeIterator<Item = Chunk>>),
    Single(std::iter::Once<Chunk>),
}

impl Iterator for ScalarChunkIterator {
    type Item = Chunk;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Empty(iter) => iter.next(),
            Self::Batch(iter) => iter.next(),
            Self::Single(iter) => iter.next(),
        }
    }
}

impl ExactSizeIterator for ScalarChunkIterator {}

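/// Converts a float feature into scalar chunks, dispatching on the Arrow datatype of
/// the backing column (`FixedSizeList`, `List`, or plain `Float32`/`Float64`).
/// Columns with any other datatype are skipped with a warning.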
fn load_scalar(
    feature_key: &str,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &RecordBatch,
) -> Result<ScalarChunkIterator, DataLoaderError> {
    let field = data
        .schema_ref()
        .field_with_name(feature_key)
        .with_context(|| {
            format!("Failed to get field for feature {feature_key} from parquet file")
        })?;

    let entity_path = EntityPath::parse_forgiving(field.name());

    match field.data_type() {
        DataType::FixedSizeList(_, _) => {
            let fixed_size_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<FixedSizeListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!(
                        "Failed to downcast feature to FixedSizeListArray"
                    ))
                })?;

            let batch_chunks =
                make_scalar_batch_entity_chunks(entity_path, feature, timelines, fixed_size_array)?;
            Ok(ScalarChunkIterator::Batch(Box::new(batch_chunks)))
        }
        DataType::List(_field) => {
            let list_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<arrow::array::ListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!("Failed to downcast feature to ListArray"))
                })?;

            let sliced = extract_list_array_elements_as_f64(list_array).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            )))
        }
        DataType::Float32 | DataType::Float64 => {
            let feature_data = data.column_by_name(feature_key).ok_or_else(|| {
                DataLoaderError::Other(anyhow!(
                    "Failed to get LeRobot dataset column data for: {:?}",
                    field.name()
                ))
            })?;

            let sliced = extract_scalar_slices_as_f64(feature_data).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            )))
        }
        _ => {
            re_log::warn_once!(
                "Tried logging scalar {} with unsupported dtype: {}",
                field.name(),
                field.data_type()
            );
            Ok(ScalarChunkIterator::Empty(std::iter::empty()))
        }
    }
}

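/// Builds the chunks for a fixed-size-list scalar feature: one chunk with the
/// per-frame scalar values, plus, if the feature has named dimensions, a static chunk
/// with the corresponding [`archetypes::SeriesLines`] names.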
fn make_scalar_batch_entity_chunks(
    entity_path: EntityPath,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &FixedSizeListArray,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let num_elements = data.value_length() as usize;

    let mut chunks = Vec::with_capacity(num_elements);

    let sliced = extract_fixed_size_list_array_elements_as_f64(data)
        .with_context(|| format!("Failed to cast scalar feature {entity_path} to Float64"))?;

    chunks.push(make_scalar_entity_chunk(
        entity_path.clone(),
        timelines,
        &sliced,
    )?);

    // If the feature has named dimensions, log them statically as series line names.
    if let Some(names) = feature.names.clone() {
        let names: Vec<_> = (0..num_elements)
            .map(|idx| names.name_for_index(idx))
            .collect();

        chunks.push(
            Chunk::builder(entity_path)
                .with_row(
                    RowId::new(),
                    TimePoint::default(),
                    std::iter::once((
                        archetypes::SeriesLines::descriptor_names(),
                        Arc::new(StringArray::from_iter(names)) as Arc<dyn ArrowArray>,
                    )),
                )
                .build()?,
        );
    }

    Ok(chunks.into_iter())
}

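/// Builds a single [`Scalars`] chunk from per-frame slices of `Float64` data.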
fn make_scalar_entity_chunk(
    entity_path: EntityPath,
    timelines: &IntMap<TimelineName, TimeColumn>,
    sliced_data: &[ArrayRef],
) -> Result<Chunk, DataLoaderError> {
    let data_arrays = sliced_data
        .iter()
        .map(|e| Some(e.as_ref()))
        .collect::<Vec<_>>();

    let data_field_inner = Field::new("item", DataType::Float64, true /* nullable */);
    #[allow(clippy::unwrap_used)] // the data was cast to `Float64` upstream, so the field type matches
    let data_field_array: arrow::array::ListArray =
        re_arrow_util::arrays_to_list_array(data_field_inner.data_type().clone(), &data_arrays)
            .unwrap();

    Ok(Chunk::from_auto_row_ids(
        ChunkId::new(),
        entity_path,
        timelines.clone(),
        std::iter::once((Scalars::descriptor_scalars().clone(), data_field_array)).collect(),
    )?)
}

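/// Casts a scalar column to `Float64` and splits it into one single-element slice per frame.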
fn extract_scalar_slices_as_f64(data: &ArrayRef) -> anyhow::Result<Vec<ArrayRef>> {
    let scalar_values = cast(&data, &DataType::Float64)
        .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))?;

    Ok((0..data.len())
        .map(|idx| scalar_values.slice(idx, 1))
        .collect::<Vec<_>>())
}

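/// Casts each element of a [`FixedSizeListArray`] to `Float64`.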
fn extract_fixed_size_list_array_elements_as_f64(
    data: &FixedSizeListArray,
) -> anyhow::Result<Vec<ArrayRef>> {
    (0..data.len())
        .map(|idx| {
            cast(&data.value(idx), &DataType::Float64)
                .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))
        })
        .collect::<Result<Vec<_>, _>>()
}

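/// Casts each element of an [`arrow::array::ListArray`] to `Float64`.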
fn extract_list_array_elements_as_f64(
    data: &arrow::array::ListArray,
) -> anyhow::Result<Vec<ArrayRef>> {
    (0..data.len())
        .map(|idx| {
            cast(&data.value(idx), &DataType::Float64)
                .with_context(|| format!("Failed to cast {:?} to Float64", data.data_type()))
        })
        .collect::<Result<Vec<_>, _>>()
}