1#![expect(clippy::cast_possible_wrap)] use std::{
4 sync::{Arc, mpsc::Sender},
5 thread,
6};
7
8use anyhow::{Context as _, anyhow};
9use arrow::{
10 array::{
11 ArrayRef, BinaryArray, FixedSizeListArray, Int64Array, RecordBatch, StringArray,
12 StructArray,
13 },
14 compute::cast,
15 datatypes::{DataType, Field},
16};
17use itertools::Either;
18
19use re_arrow_util::ArrowArrayDowncastRef as _;
20use re_chunk::{
21 ArrowArray, Chunk, ChunkId, EntityPath, RowId, TimeColumn, TimeInt, TimePoint, Timeline,
22 TimelineName, external::nohash_hasher::IntMap,
23};
24use re_log_types::{ApplicationId, StoreId};
25use re_types::{
26 archetypes::{
27 self, AssetVideo, DepthImage, EncodedImage, Scalars, TextDocument, VideoFrameReference,
28 },
29 components::VideoTimestamp,
30};
31
32use crate::lerobot::{
33 DType, EpisodeIndex, Feature, LeRobotDataset, TaskIndex, is_lerobot_dataset,
34 is_v1_lerobot_dataset,
35};
36use crate::load_file::prepare_store_info;
37use crate::{DataLoader, DataLoaderError, LoadedData};
38
/// Parquet columns that are skipped when loading features: `frame_index` is
/// instead used to build the sequence timeline, and the others carry indexing
/// metadata rather than loggable data.
const LEROBOT_DATASET_IGNORED_COLUMNS: &[&str] =
    &["episode_index", "index", "frame_index", "timestamp"];

/// LeRobot codebase versions this loader accepts; anything else is rejected
/// with an error log (see `load_from_path`).
const LEROBOT_DATASET_SUPPORTED_VERSIONS: &[&str] = &["v2.0", "v2.1"];
46
47pub struct LeRobotDatasetLoader;
51
52impl DataLoader for LeRobotDatasetLoader {
53 fn name(&self) -> String {
54 "LeRobotDatasetLoader".into()
55 }
56
57 fn load_from_path(
58 &self,
59 settings: &crate::DataLoaderSettings,
60 filepath: std::path::PathBuf,
61 tx: Sender<LoadedData>,
62 ) -> Result<(), DataLoaderError> {
63 if !is_lerobot_dataset(&filepath) {
64 return Err(DataLoaderError::Incompatible(filepath));
65 }
66
67 if is_v1_lerobot_dataset(&filepath) {
68 re_log::error!("LeRobot 'v1.x' dataset format is unsupported.");
69 return Ok(());
70 }
71
72 let dataset = LeRobotDataset::load_from_directory(&filepath)
73 .map_err(|err| anyhow!("Loading LeRobot dataset failed: {err}"))?;
74
75 if !LEROBOT_DATASET_SUPPORTED_VERSIONS
76 .contains(&dataset.metadata.info.codebase_version.as_str())
77 {
78 re_log::error!(
79 "LeRobot '{}' dataset format is unsupported.",
80 dataset.metadata.info.codebase_version
81 );
82 return Ok(());
83 }
84
85 let application_id = settings
86 .application_id
87 .clone()
88 .unwrap_or(filepath.display().to_string().into());
89
90 thread::Builder::new()
96 .name(format!("load_and_stream({filepath:?}"))
97 .spawn({
98 move || {
99 re_log::info!(
100 "Loading LeRobot dataset from {:?}, with {} episode(s)",
101 dataset.path,
102 dataset.metadata.episode_count(),
103 );
104 load_and_stream(&dataset, &application_id, &tx);
105 }
106 })
107 .with_context(|| {
108 format!("Failed to spawn IO thread to load LeRobot dataset {filepath:?} ")
109 })?;
110
111 Ok(())
112 }
113
114 fn load_from_file_contents(
115 &self,
116 _settings: &crate::DataLoaderSettings,
117 filepath: std::path::PathBuf,
118 _contents: std::borrow::Cow<'_, [u8]>,
119 _tx: Sender<LoadedData>,
120 ) -> Result<(), DataLoaderError> {
121 Err(DataLoaderError::Incompatible(filepath))
122 }
123}
124
125fn load_and_stream(
126 dataset: &LeRobotDataset,
127 application_id: &ApplicationId,
128 tx: &Sender<crate::LoadedData>,
129) {
130 let episodes = prepare_episode_chunks(dataset, application_id, tx);
132
133 for (episode, store_id) in &episodes {
134 match load_episode(dataset, *episode) {
136 Ok(chunks) => {
137 let recording_info = re_types::archetypes::RecordingInfo::new()
138 .with_name(format!("Episode {}", episode.0));
139
140 let Ok(initial) = Chunk::builder(EntityPath::properties())
141 .with_archetype(RowId::new(), TimePoint::STATIC, &recording_info)
142 .build()
143 else {
144 re_log::error!(
145 "Failed to build recording properties chunk for episode {}",
146 episode.0
147 );
148 return;
149 };
150
151 for chunk in std::iter::once(initial).chain(chunks.into_iter()) {
152 let data = LoadedData::Chunk(
153 LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
154 store_id.clone(),
155 chunk,
156 );
157
158 if tx.send(data).is_err() {
159 break; }
161 }
162 }
163 Err(err) => {
164 re_log::warn!(
165 "Failed to load episode {} from LeRobot dataset: {err}",
166 episode.0
167 );
168 }
169 }
170 }
171}
172
173fn prepare_episode_chunks(
176 dataset: &LeRobotDataset,
177 application_id: &ApplicationId,
178 tx: &Sender<crate::LoadedData>,
179) -> Vec<(EpisodeIndex, StoreId)> {
180 let mut store_ids = vec![];
181
182 for episode_index in dataset.metadata.episodes.keys() {
183 let episode = *episode_index;
184
185 let store_id = StoreId::recording(application_id.clone(), format!("episode_{}", episode.0));
186 let set_store_info = LoadedData::LogMsg(
187 LeRobotDatasetLoader::name(&LeRobotDatasetLoader),
188 prepare_store_info(&store_id, re_log_types::FileSource::Sdk),
189 );
190
191 if tx.send(set_store_info).is_err() {
192 break;
193 }
194
195 store_ids.push((episode, store_id));
196 }
197
198 store_ids
199}
200
/// Loads all chunks for a single episode of a `LeRobot` dataset.
///
/// The episode's `frame_index` column drives a sequence timeline that every
/// feature is logged on. Columns in [`LEROBOT_DATASET_IGNORED_COLUMNS`] are
/// skipped.
///
/// # Errors
///
/// Fails if the episode data cannot be read, if the `frame_index` column is
/// missing or is not `Int64`, or if converting any feature to chunks fails.
pub fn load_episode(
    dataset: &LeRobotDataset,
    episode: EpisodeIndex,
) -> Result<Vec<Chunk>, DataLoaderError> {
    let data = dataset
        .read_episode_data(episode)
        .map_err(|err| anyhow!("Reading data for episode {} failed: {err}", episode.0))?;

    let frame_indices = data
        .column_by_name("frame_index")
        .ok_or_else(|| anyhow!("Failed to get frame index column in LeRobot dataset"))?
        .clone();

    // One sequence timeline, shared by all features of this episode.
    let timeline = re_log_types::Timeline::new_sequence("frame_index");
    let times: &arrow::buffer::ScalarBuffer<i64> = frame_indices
        .downcast_array_ref::<Int64Array>()
        .ok_or_else(|| anyhow!("LeRobot dataset frame indices are of an unexpected type"))?
        .values();

    let time_column = re_chunk::TimeColumn::new(None, timeline, times.clone());
    let timelines = std::iter::once((*timeline.name(), time_column.clone())).collect();

    let mut chunks = Vec::new();

    for (feature_key, feature) in dataset
        .metadata
        .info
        .features
        .iter()
        .filter(|(key, _)| !LEROBOT_DATASET_IGNORED_COLUMNS.contains(&key.as_str()))
    {
        match feature.dtype {
            DType::Video => {
                chunks.extend(load_episode_video(
                    dataset,
                    feature_key,
                    episode,
                    &timeline,
                    time_column.clone(),
                )?);
            }

            DType::Image => {
                // The channel count decides the interpretation:
                // 1 channel -> depth image, 3 channels -> color image.
                let num_channels = feature.channel_dim();

                match num_channels {
                    1 => chunks.extend(load_episode_depth_images(feature_key, &timeline, &data)?),
                    3 => chunks.extend(load_episode_images(feature_key, &timeline, &data)?),
                    _ => re_log::warn_once!(
                        "Unsupported channel count {num_channels} (shape: {:?}) for LeRobot dataset; Only 1- and 3-channel images are supported",
                        feature.shape
                    ),
                }
            }
            // `task_index` must match before the generic `Int64` arm below.
            DType::Int64 if feature_key == "task_index" => {
                chunks.extend(log_episode_task(dataset, &timeline, &data)?);
            }
            DType::Int16 | DType::Int64 | DType::Bool | DType::String => {
                re_log::warn_once!(
                    "Loading LeRobot feature ({feature_key}) of dtype `{:?}` into Rerun is not yet implemented",
                    feature.dtype
                );
            }
            DType::Float32 | DType::Float64 => {
                chunks.extend(load_scalar(feature_key, feature, &timelines, &data)?);
            }
        }
    }

    Ok(chunks)
}
279
280fn log_episode_task(
281 dataset: &LeRobotDataset,
282 timeline: &Timeline,
283 data: &RecordBatch,
284) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
285 let task_indices = data
286 .column_by_name("task_index")
287 .and_then(|c| c.downcast_array_ref::<Int64Array>())
288 .with_context(|| "Failed to get task_index field from dataset!")?;
289
290 let mut chunk = Chunk::builder("task");
291 let mut row_id = RowId::new();
292 let mut time_int = TimeInt::ZERO;
293
294 for task_index in task_indices {
295 let Some(task) = task_index
296 .and_then(|i| usize::try_from(i).ok())
297 .and_then(|i| dataset.task_by_index(TaskIndex(i)))
298 else {
299 time_int = time_int.inc();
301 continue;
302 };
303
304 let timepoint = TimePoint::default().with(*timeline, time_int);
305 let text = TextDocument::new(task.task.clone());
306 chunk = chunk.with_archetype(row_id, timepoint, &text);
307
308 row_id = row_id.next();
309 time_int = time_int.inc();
310 }
311
312 Ok(std::iter::once(chunk.build()?))
313}
314
315fn load_episode_images(
316 observation: &str,
317 timeline: &Timeline,
318 data: &RecordBatch,
319) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
320 let image_bytes = data
321 .column_by_name(observation)
322 .and_then(|c| c.downcast_array_ref::<StructArray>())
323 .and_then(|a| a.column_by_name("bytes"))
324 .and_then(|a| a.downcast_array_ref::<BinaryArray>())
325 .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;
326
327 let mut chunk = Chunk::builder(observation);
328 let mut row_id = RowId::new();
329
330 for frame_idx in 0..image_bytes.len() {
331 let img_buffer = image_bytes.value(frame_idx);
332 let encoded_image = EncodedImage::from_file_contents(img_buffer.to_owned());
333 let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
334 chunk = chunk.with_archetype(row_id, timepoint, &encoded_image);
335
336 row_id = row_id.next();
337 }
338
339 Ok(std::iter::once(chunk.build().with_context(|| {
340 format!("Failed to build image chunk for image: {observation}")
341 })?))
342}
343
344fn load_episode_depth_images(
345 observation: &str,
346 timeline: &Timeline,
347 data: &RecordBatch,
348) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
349 let image_bytes = data
350 .column_by_name(observation)
351 .and_then(|c| c.downcast_array_ref::<StructArray>())
352 .and_then(|a| a.column_by_name("bytes"))
353 .and_then(|a| a.downcast_array_ref::<BinaryArray>())
354 .with_context(|| format!("Failed to get binary data from image feature: {observation}"))?;
355
356 let mut chunk = Chunk::builder(observation);
357 let mut row_id = RowId::new();
358
359 for frame_idx in 0..image_bytes.len() {
360 let img_buffer = image_bytes.value(frame_idx);
361 let depth_image = DepthImage::from_file_contents(img_buffer.to_owned())
362 .map_err(|err| anyhow!("Failed to decode image: {err}"))?;
363
364 let timepoint = TimePoint::default().with(*timeline, frame_idx as i64);
365 chunk = chunk.with_archetype(row_id, timepoint, &depth_image);
366
367 row_id = row_id.next();
368 }
369
370 Ok(std::iter::once(chunk.build().with_context(|| {
371 format!("Failed to build image chunk for image: {observation}")
372 })?))
373}
374
/// Loads the video for one observation of an episode as up to two chunks:
/// the `AssetVideo` itself, plus — when frame timestamps can be read from the
/// video — a `VideoFrameReference` column on the episode's timeline.
fn load_episode_video(
    dataset: &LeRobotDataset,
    observation: &str,
    episode: EpisodeIndex,
    timeline: &Timeline,
    time_column: TimeColumn,
) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
    let contents = dataset
        .read_episode_video_contents(observation, episode)
        .with_context(|| format!("Reading video contents for episode {episode:?} failed!"))?;

    let video_asset = AssetVideo::new(contents.into_owned());
    let entity_path = observation;

    let video_frame_reference_chunk = match video_asset.read_frame_timestamps_nanos() {
        Ok(frame_timestamps_nanos) => {
            let frame_timestamps_nanos: arrow::buffer::ScalarBuffer<i64> =
                frame_timestamps_nanos.into();

            // Truncate to the timeline's row count: video frames beyond the
            // episode's frame count are dropped.
            let video_timestamps = frame_timestamps_nanos
                .iter()
                .take(time_column.num_rows())
                .copied()
                .map(VideoTimestamp::from_nanos)
                .collect::<Vec<_>>();

            let video_frame_reference_column = VideoFrameReference::update_fields()
                .with_many_timestamp(video_timestamps)
                .columns_of_unit_batches()
                .with_context(|| {
                    format!(
                        "Failed to create `VideoFrameReference` column for episode {episode:?}."
                    )
                })?;

            Some(Chunk::from_auto_row_ids(
                re_chunk::ChunkId::new(),
                entity_path.into(),
                std::iter::once((*timeline.name(), time_column)).collect(),
                video_frame_reference_column.collect(),
            )?)
        }
        Err(err) => {
            // Not fatal: fall back to emitting only the video asset chunk.
            re_log::warn_once!(
                "Failed to read frame timestamps from episode {episode:?} video: {err}"
            );
            None
        }
    };

    // The asset itself is logged with an empty (default) timepoint.
    let video_asset_chunk = Chunk::builder(entity_path)
        .with_archetype(RowId::new(), TimePoint::default(), &video_asset)
        .build()?;

    if let Some(video_frame_reference_chunk) = video_frame_reference_chunk {
        Ok(Either::Left(
            [video_asset_chunk, video_frame_reference_chunk].into_iter(),
        ))
    } else {
        Ok(Either::Right(std::iter::once(video_asset_chunk)))
    }
}
439
/// Iterator over the chunks produced for one scalar feature by [`load_scalar`].
///
/// Unifies the three possible outcomes into a single concrete return type.
enum ScalarChunkIterator {
    /// Unsupported datatype: no chunks at all.
    Empty(std::iter::Empty<Chunk>),
    /// Fixed-size-list feature: data chunk plus optional series-names chunk.
    Batch(Box<dyn ExactSizeIterator<Item = Chunk>>),

    /// List or plain float feature: exactly one chunk.
    Single(Box<std::iter::Once<Chunk>>),
}
448
449impl Iterator for ScalarChunkIterator {
450 type Item = Chunk;
451
452 fn next(&mut self) -> Option<Self::Item> {
453 match self {
454 Self::Empty(iter) => iter.next(),
455 Self::Batch(iter) => iter.next(),
456 Self::Single(iter) => iter.next(),
457 }
458 }
459}
460
461impl ExactSizeIterator for ScalarChunkIterator {}
462
/// Converts one scalar feature column into chunks on the episode's timelines.
///
/// Supports `FixedSizeList`, `List`, and plain `Float32`/`Float64` columns;
/// any other Arrow datatype is skipped with a warning and yields an empty
/// iterator. All values are cast to `Float64` before logging.
fn load_scalar(
    feature_key: &str,
    feature: &Feature,
    timelines: &IntMap<TimelineName, TimeColumn>,
    data: &RecordBatch,
) -> Result<ScalarChunkIterator, DataLoaderError> {
    let field = data
        .schema_ref()
        .field_with_name(feature_key)
        .with_context(|| {
            format!("Failed to get field for feature {feature_key} from parquet file")
        })?;

    let entity_path = EntityPath::parse_forgiving(field.name());

    match field.data_type() {
        // Fixed-size list: one data chunk, plus an optional series-names chunk.
        DataType::FixedSizeList(_, _) => {
            let fixed_size_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<FixedSizeListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!(
                        "Failed to downcast feature to FixedSizeListArray"
                    ))
                })?;

            let batch_chunks =
                make_scalar_batch_entity_chunks(entity_path, feature, timelines, fixed_size_array)?;
            Ok(ScalarChunkIterator::Batch(Box::new(batch_chunks)))
        }
        // Variable-length list: each row is cast to Float64 and logged as one chunk.
        DataType::List(_field) => {
            let list_array = data
                .column_by_name(feature_key)
                .and_then(|col| col.downcast_array_ref::<arrow::array::ListArray>())
                .ok_or_else(|| {
                    DataLoaderError::Other(anyhow!("Failed to downcast feature to ListArray"))
                })?;

            let sliced = extract_list_array_elements_as_f64(list_array).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(Box::new(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            ))))
        }
        // Plain float column: sliced into single-element rows, cast to Float64.
        DataType::Float32 | DataType::Float64 => {
            let feature_data = data.column_by_name(feature_key).ok_or_else(|| {
                DataLoaderError::Other(anyhow!(
                    "Failed to get LeRobot dataset column data for: {:?}",
                    field.name()
                ))
            })?;

            let sliced = extract_scalar_slices_as_f64(feature_data).with_context(|| {
                format!("Failed to cast scalar feature {entity_path} to Float64")
            })?;

            Ok(ScalarChunkIterator::Single(Box::new(std::iter::once(
                make_scalar_entity_chunk(entity_path, timelines, &sliced)?,
            ))))
        }
        _ => {
            re_log::warn_once!(
                "Tried logging scalar {} with unsupported dtype: {}",
                field.name(),
                field.data_type()
            );
            Ok(ScalarChunkIterator::Empty(std::iter::empty()))
        }
    }
}
535
536fn make_scalar_batch_entity_chunks(
537 entity_path: EntityPath,
538 feature: &Feature,
539 timelines: &IntMap<TimelineName, TimeColumn>,
540 data: &FixedSizeListArray,
541) -> Result<impl ExactSizeIterator<Item = Chunk> + use<>, DataLoaderError> {
542 let num_elements = data.value_length() as usize;
543
544 let mut chunks = Vec::with_capacity(num_elements);
545
546 let sliced = extract_fixed_size_list_array_elements_as_f64(data)
547 .with_context(|| format!("Failed to cast scalar feature {entity_path} to Float64"))?;
548
549 chunks.push(make_scalar_entity_chunk(
550 entity_path.clone(),
551 timelines,
552 &sliced,
553 )?);
554
555 if let Some(names) = feature.names.clone() {
557 let names: Vec<_> = (0..data.value_length() as usize)
558 .map(|idx| names.name_for_index(idx))
559 .collect();
560
561 chunks.push(
562 Chunk::builder(entity_path)
563 .with_row(
564 RowId::new(),
565 TimePoint::default(),
566 std::iter::once((
567 archetypes::SeriesLines::descriptor_names(),
568 Arc::new(StringArray::from_iter(names)) as Arc<dyn ArrowArray>,
569 )),
570 )
571 .build()?,
572 );
573 }
574
575 Ok(chunks.into_iter())
576}
577
578fn make_scalar_entity_chunk(
579 entity_path: EntityPath,
580 timelines: &IntMap<TimelineName, TimeColumn>,
581 sliced_data: &[ArrayRef],
582) -> Result<Chunk, DataLoaderError> {
583 let data_arrays = sliced_data
584 .iter()
585 .map(|e| Some(e.as_ref()))
586 .collect::<Vec<_>>();
587
588 let data_field_inner = Field::new("item", DataType::Float64, true );
589 #[expect(clippy::unwrap_used)] let data_field_array: arrow::array::ListArray =
591 re_arrow_util::arrays_to_list_array(data_field_inner.data_type().clone(), &data_arrays)
592 .unwrap();
593
594 Ok(Chunk::from_auto_row_ids(
595 ChunkId::new(),
596 entity_path,
597 timelines.clone(),
598 std::iter::once((Scalars::descriptor_scalars().clone(), data_field_array)).collect(),
599 )?)
600}
601
602fn extract_scalar_slices_as_f64(data: &ArrayRef) -> anyhow::Result<Vec<ArrayRef>> {
603 let scalar_values = cast(&data, &DataType::Float64)
605 .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))?;
606
607 Ok((0..data.len())
608 .map(|idx| scalar_values.slice(idx, 1))
609 .collect::<Vec<_>>())
610}
611
612fn extract_fixed_size_list_array_elements_as_f64(
613 data: &FixedSizeListArray,
614) -> anyhow::Result<Vec<ArrayRef>> {
615 (0..data.len())
616 .map(|idx| {
617 cast(&data.value(idx), &DataType::Float64)
618 .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))
619 })
620 .collect::<Result<Vec<_>, _>>()
621}
622
623fn extract_list_array_elements_as_f64(
624 data: &arrow::array::ListArray,
625) -> anyhow::Result<Vec<ArrayRef>> {
626 (0..data.len())
627 .map(|idx| {
628 cast(&data.value(idx), &DataType::Float64)
629 .with_context(|| format!("Failed to cast {} to Float64", data.data_type()))
630 })
631 .collect::<Result<Vec<_>, _>>()
632}