pub mod common;
pub mod datasetv2;
pub mod datasetv3;
use std::{fmt, path::Path};
use serde::{
Deserialize, Deserializer, Serialize,
de::{MapAccess, SeqAccess, Visitor},
};
/// Major on-disk layout versions of a LeRobot dataset that this module can detect.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LeRobotDatasetVersion {
/// Layout recognized by non-empty `meta_data` and `data` subdirectories.
V1,
/// Layout recognized by non-empty `meta` and `data` subdirectories.
V2,
/// Same as `V2`, with an additional `meta/episodes` directory.
V3,
}
impl LeRobotDatasetVersion {
    /// Probes `path` and reports which dataset layout it matches, if any.
    ///
    /// Layouts are tried from newest to oldest. The order matters: every
    /// directory that satisfies the v3 check also satisfies the v2 check, so
    /// v3 has to be tested first.
    ///
    /// Returns `None` when none of the known layouts match.
    pub fn find_version(path: impl AsRef<Path>) -> Option<Self> {
        let root = path.as_ref();

        if is_v3_lerobot_dataset(root) {
            return Some(Self::V3);
        }
        if is_v2_lerobot_dataset(root) {
            return Some(Self::V2);
        }

        is_v1_lerobot_dataset(root).then_some(Self::V1)
    }
}
/// Returns `true` when `path` matches any of the known LeRobot dataset layouts
/// (v1, v2 or v3).
pub fn is_lerobot_dataset(path: impl AsRef<Path>) -> bool {
    let root = path.as_ref();

    if is_v1_lerobot_dataset(root) {
        return true;
    }
    if is_v2_lerobot_dataset(root) {
        return true;
    }
    is_v3_lerobot_dataset(root)
}
/// Returns `true` when `path` is a directory laid out like a v3 LeRobot dataset.
///
/// The check requires non-empty `meta` and `data` subdirectories and, on top of
/// that, a `meta/episodes` directory. The latter is what distinguishes this
/// layout from the v2 check.
fn is_v3_lerobot_dataset(path: impl AsRef<Path>) -> bool {
    let path = path.as_ref();
    if !path.is_dir() {
        return false;
    }
    has_sub_directories(&["meta", "data"], path) && path.join("meta").join("episodes").is_dir()
}
/// Returns `true` when `path` is a directory containing non-empty `meta` and
/// `data` subdirectories.
///
/// Note that a v3 layout passes this check as well; callers that need to tell
/// the two apart must test for v3 first.
fn is_v2_lerobot_dataset(path: impl AsRef<Path>) -> bool {
    let root = path.as_ref();
    root.is_dir() && has_sub_directories(&["meta", "data"], root)
}
/// Returns `true` when `path` is a directory containing non-empty `meta_data`
/// and `data` subdirectories.
fn is_v1_lerobot_dataset(path: impl AsRef<Path>) -> bool {
    let root = path.as_ref();
    root.is_dir() && has_sub_directories(&["meta_data", "data"], root)
}
/// Checks that every name in `directories` is a non-empty subdirectory of `path`.
///
/// A subdirectory counts as non-empty when it can be listed and the listing
/// yields at least one entry. A directory that cannot be read is treated as
/// not matching. An empty `directories` slice trivially yields `true`.
fn has_sub_directories(directories: &[&str], path: impl AsRef<Path>) -> bool {
    let root = path.as_ref();

    for name in directories {
        let candidate = root.join(name);
        if !candidate.is_dir() {
            return false;
        }

        let has_entries = match candidate.read_dir() {
            Ok(mut entries) => entries.next().is_some(),
            Err(_) => false,
        };
        if !has_entries {
            return false;
        }
    }

    true
}
/// Description of a single dataset feature: its element type, its shape and,
/// optionally, a list of names.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Feature {
/// Element type of the feature.
pub dtype: DType,
/// Size of each dimension of the feature.
pub shape: Vec<usize>,
/// Optional names; `channel_dim` searches these for `"channels"` and uses the
/// position it finds as an index into `shape`.
pub names: Option<Names>,
}
impl Feature {
pub fn channel_dim(&self) -> usize {
if let Some(names) = &self.names
&& let Some(channel_idx) = names.0.iter().position(|name| name == "channels")
{
if channel_idx < self.shape.len() {
return self.shape[channel_idx];
}
}
self.shape.last().copied().unwrap_or(0)
}
}
/// Element type of a [`Feature`].
///
/// Variant names are (de)serialized in `snake_case`, as configured by the
/// `rename_all` attribute below.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DType {
Video,
Image,
Bool,
Float32,
Float64,
Int16,
Int64,
String,
}
/// Flat list of names attached to a feature.
///
/// Serialization is derived; deserialization is implemented by hand (see the
/// `Deserialize` impl in this file) and accepts several input shapes, all of
/// which are flattened into this single list.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Names(pub(super) Vec<String>);
impl Names {
pub fn name_for_index(&self, index: usize) -> Option<&String> {
self.0.get(index)
}
}
/// Serde visitor that accepts the different input shapes used for feature
/// names and normalizes all of them into a flat [`Names`] list.
struct NamesVisitor;

impl<'de> Visitor<'de> for NamesVisitor {
    type Value = Names;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str(
            "a flat string array, a nested string array, or a single-entry object with a string array or null value",
        )
    }

    /// Handles array input.
    ///
    /// The array must be homogeneous: either every element is a string (flat
    /// form) or every element is an array of strings (nested form, which is
    /// flattened in order). Mixing the two forms is an error.
    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
    where
        A: SeqAccess<'de>,
    {
        /// One element of the outer array, in either accepted form.
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum ListItem {
            Str(String),
            List(Vec<String>),
        }

        /// Which form the array has committed to so far.
        #[derive(PartialEq)]
        enum ListType {
            Undetermined,
            Flat,
            Nested,
        }

        let mut names = Vec::new();
        let mut determined_type = ListType::Undetermined;

        while let Some(item) = seq.next_element::<ListItem>()? {
            match item {
                ListItem::Str(s) => {
                    if determined_type == ListType::Nested {
                        return Err(serde::de::Error::custom(
                            "Cannot mix nested lists with flat strings within names array",
                        ));
                    }
                    determined_type = ListType::Flat;
                    names.push(s);
                }
                ListItem::List(list) => {
                    if determined_type == ListType::Flat {
                        return Err(serde::de::Error::custom(
                            "Cannot mix flat strings and nested lists within names array",
                        ));
                    }
                    determined_type = ListType::Nested;
                    names.extend(list);
                }
            }
        }

        Ok(Names(names))
    }

    /// Handles object input.
    ///
    /// An empty object yields an empty list. An object with exactly one entry
    /// yields that entry's string array (`null` is treated as empty); the key
    /// itself is discarded. Any object with more than one entry is rejected
    /// with an `invalid_length` error carrying the actual number of entries.
    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
    where
        A: MapAccess<'de>,
    {
        // Only the first entry is ever interpreted, so it is the only one that
        // is parsed with the typed signature.
        let Some((_key, value)) = map.next_entry::<String, Option<Vec<String>>>()? else {
            return Ok(Names(Vec::new()));
        };

        // Consume whatever is left without interpreting it, counting entries so
        // the error reports the real size of the object rather than a fixed 2.
        let mut entry_count = 1;
        while map
            .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
            .is_some()
        {
            entry_count += 1;
        }

        if entry_count > 1 {
            return Err(serde::de::Error::invalid_length(
                entry_count,
                &"a Names object with exactly one entry.",
            ));
        }

        Ok(Names(value.unwrap_or_default()))
    }
}
impl<'de> Deserialize<'de> for Names {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(NamesVisitor)
}
}
/// Strongly typed episode index.
///
/// Newtype over `usize`; `#[serde(transparent)]` makes it (de)serialize as the
/// bare integer.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[serde(transparent)]
pub struct EpisodeIndex(pub usize);
/// Strongly typed task index.
///
/// Newtype over `usize`; `#[serde(transparent)]` makes it (de)serialize as the
/// bare integer.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[serde(transparent)]
pub struct TaskIndex(pub usize);
/// Strongly typed subtask index.
///
/// Newtype over `usize`; `#[serde(transparent)]` makes it (de)serialize as the
/// bare integer.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[serde(transparent)]
pub struct SubtaskIndex(pub usize);
/// A task record: a task index paired with its text.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LeRobotDatasetTask {
/// Index of the task; stored under the key `task_index` when (de)serialized.
#[serde(rename = "task_index")]
pub index: TaskIndex,
/// Text of the task.
pub task: String,
}
/// A subtask record: a subtask index paired with its text.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LeRobotDatasetSubtask {
/// Index of the subtask; stored under the key `subtask_index` when (de)serialized.
#[serde(rename = "subtask_index")]
pub index: SubtaskIndex,
/// Text of the subtask.
pub subtask: String,
}
/// Errors reported by the LeRobot dataset code.
#[derive(thiserror::Error, Debug)]
pub enum LeRobotError {
/// An I/O error, together with the path it occurred on.
#[error("IO error occurred on path: {path}")]
IO {
/// Underlying I/O error, exposed as the error source.
#[source]
source: std::io::Error,
/// Path that was being accessed.
path: std::path::PathBuf,
},
/// A `serde_json` error, forwarded unchanged.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// A Parquet error, forwarded unchanged.
#[error(transparent)]
Parquet(#[from] parquet::errors::ParquetError),
/// An Arrow error, forwarded unchanged.
#[error(transparent)]
Arrow(#[from] arrow::error::ArrowError),
/// A feature key was invalid; carries the offending key.
#[error("Invalid feature key: {0}")]
InvalidFeatureKey(String),
/// Required dataset information was missing; carries a description.
#[error("Missing dataset info: {0}")]
MissingDatasetInfo(String),
/// The feature `key` had dtype `actual` where `expected` was required.
#[error("Invalid feature dtype, expected {key} to be of type {expected:?}, but got {actual:?}")]
InvalidFeatureDtype {
key: String,
expected: DType,
actual: DType,
},
/// A chunk index was invalid; carries the offending index.
#[error("Invalid chunk index: {0}")]
InvalidChunkIndex(usize),
/// An episode index was invalid; carries the offending index.
#[error("Invalid episode index: {0:?}")]
InvalidEpisodeIndex(EpisodeIndex),
/// The data file of the given episode contained no records.
#[error("Episode {0:?} data file does not contain any records")]
EmptyEpisode(EpisodeIndex),
}
impl LeRobotError {
    /// Builds a [`LeRobotError::IO`] from an I/O error and the path it relates to.
    ///
    /// `path` may be anything convertible into a `PathBuf`.
    pub fn io(source: std::io::Error, path: impl Into<std::path::PathBuf>) -> Self {
        let path = path.into();
        Self::IO { source, path }
    }
}
/// Unit tests for the hand-written `Names` deserializer.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deserialize_flat_list() {
        let json = r#"["a", "b", "c"]"#;
        let expected = Names(vec!["a".to_owned(), "b".to_owned(), "c".to_owned()]);
        let names: Names = serde_json::from_str(json).unwrap();
        assert_eq!(names, expected);
    }

    #[test]
    fn test_deserialize_nested_list() {
        let json = r#"[["a", "b"], ["c"]]"#;
        let expected = Names(vec!["a".to_owned(), "b".to_owned(), "c".to_owned()]);
        let names: Names = serde_json::from_str(json).unwrap();
        assert_eq!(names, expected);
    }

    #[test]
    fn test_deserialize_empty_nested_list() {
        let json = r#"[[], []]"#;
        let expected = Names(vec![]);
        let names: Names = serde_json::from_str(json).unwrap();
        assert_eq!(names, expected);
    }

    #[test]
    fn test_deserialize_empty_list() {
        let json = r#"[]"#;
        let expected = Names(vec![]);
        let names: Names = serde_json::from_str(json).unwrap();
        assert_eq!(names, expected);
    }

    #[test]
    fn test_deserialize_object_with_list() {
        let json = r#"{ "axes": ["x", "y", "z"] }"#;
        let expected = Names(vec!["x".to_owned(), "y".to_owned(), "z".to_owned()]);
        let names: Names = serde_json::from_str(json).unwrap();
        assert_eq!(names, expected);
    }

    #[test]
    fn test_deserialize_object_with_empty_list() {
        let json = r#"{ "motors": [] }"#;
        let expected = Names(vec![]);
        let names: Names = serde_json::from_str(json).unwrap();
        assert_eq!(names, expected);
    }

    #[test]
    fn test_deserialize_object_with_null() {
        let json = r#"{ "axes": null }"#;
        let expected = Names(vec![]);
        let names: Names = serde_json::from_str(json).unwrap();
        assert_eq!(names, expected);
    }

    #[test]
    fn test_deserialize_empty_object() {
        let json = r#"{}"#;
        let expected = Names(vec![]);
        let names: Names = serde_json::from_str(json).unwrap();
        assert_eq!(names, expected);
    }

    #[test]
    fn test_deserialize_error_mixed_list() {
        // A flat string followed by a nested list.
        let json = r#"["a", ["b"]]"#;
        let result: Result<Names, _> = serde_json::from_str(json);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot mix flat strings and nested lists")
        );
    }

    #[test]
    fn test_deserialize_error_mixed_list_nested_first() {
        // The opposite order goes through a different branch with its own message.
        let json = r#"[["a"], "b"]"#;
        let result: Result<Names, _> = serde_json::from_str(json);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot mix nested lists with flat strings")
        );
    }

    #[test]
    fn test_deserialize_error_object_multiple_entries() {
        let json = r#"{ "axes": ["x"], "motors": ["m"] }"#;
        let result: Result<Names, _> = serde_json::from_str(json);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("a Names object with exactly one entry")
        );
    }
}