use std::fs::File;
use std::io::{self, Read};
use std::path::{Path, PathBuf};
use arrow_array::{Array, RecordBatch};
use arrow_schema::DataType;
use hf_hub::{
Repo, RepoType,
api::sync::{ApiBuilder, ApiRepo},
};
use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder};
use serde_json::{Map, Value};
use crate::error::{Error, Result};
/// Parse an `hf://<namespace>/<repo>[@<revision>]` dataset URL.
///
/// On success returns `(repo_id, revision)`, where `repo_id` is
/// `"<namespace>/<repo>"` and `revision` is the text after the first `@`.
///
/// Returns `None` when:
/// - the `hf://` prefix is missing;
/// - the repo id is not exactly two non-empty `/`-separated segments;
/// - the repo id contains any whitespace;
/// - an `@` is present with nothing after it.
pub fn parse_hf_url(s: &str) -> Option<(String, Option<String>)> {
    let rest = s.strip_prefix("hf://")?;
    let (repo, revision) = match rest.split_once('@') {
        // `hf://ns/repo@` names no revision. Reject it instead of folding the
        // stray `@` into the repo id.
        Some((_, "")) => return None,
        Some((repo, rev)) => (repo, Some(rev)),
        None => (rest, None),
    };
    let mut segments = repo.split('/');
    let two_segments = matches!(
        (segments.next(), segments.next(), segments.next()),
        (Some(ns), Some(name), None) if !ns.is_empty() && !name.is_empty()
    );
    if !two_segments || repo.chars().any(char::is_whitespace) {
        return None;
    }
    Some((repo.to_owned(), revision.map(str::to_owned)))
}
/// Lazily-downloading view over the parquet shards of a Hugging Face dataset.
///
/// [`HfShardSource::open`] lists the shard names up front. The shard files
/// themselves are fetched one at a time, on demand, by `fetch_shard`.
pub struct HfShardSource {
    /// Repository id parsed from the `hf://` URL, used in log output.
    repo_id: String,
    /// Sorted repo-relative file names under `data/` that end in `.parquet`.
    shards: Vec<String>,
    /// hf-hub handle used to fetch individual shard files.
    repo_handle: ApiRepo,
}
impl HfShardSource {
pub fn open(url: &str) -> Result<Self> {
let (repo_id, revision) = parse_hf_url(url).ok_or_else(|| {
Error::Input(format!(
"hf-hub: malformed URL {url:?}; expected hf://<namespace>/<repo>[@<revision>]"
))
})?;
let api = ApiBuilder::new()
.with_progress(false)
.build()
.map_err(|e| Error::Input(format!("hf-hub: build api failed: {e}")))?;
let repo_handle = match revision {
Some(rev) => api.repo(Repo::with_revision(repo_id.clone(), RepoType::Dataset, rev)),
None => api.dataset(repo_id.clone()),
};
let info = repo_handle
.info()
.map_err(|e| Error::Input(format!("hf-hub: fetch info for {repo_id:?} failed: {e}")))?;
let mut shards: Vec<String> = info
.siblings
.into_iter()
.map(|s| s.rfilename)
.filter(|n: &String| n.starts_with("data/") && n.ends_with(".parquet"))
.collect();
if shards.is_empty() {
return Err(Error::Input(format!(
"hf-hub: no data/*.parquet shards found in {repo_id:?}"
)));
}
shards.sort();
tracing::info!(
repo = %repo_id,
shards = shards.len(),
"HF dataset manifest loaded (shards will download lazily)",
);
Ok(Self {
repo_handle,
repo_id,
shards,
})
}
pub fn num_shards(&self) -> usize {
self.shards.len()
}
pub fn fetch_shard(&self, idx: usize) -> Result<PathBuf> {
let name = self.shards.get(idx).ok_or_else(|| {
Error::Input(format!(
"hf-hub: shard index {idx} out of range (have {} shards)",
self.shards.len()
))
})?;
tracing::info!(
repo = %self.repo_id,
shard_idx = idx,
shard = %name,
"downloading HF shard (or cache-hit)",
);
self.repo_handle
.get(name)
.map_err(|e| Error::Input(format!("hf-hub: get {name:?} failed: {e}")))
}
}
/// Where a `ParquetJsonlReader` obtains its parquet input files.
enum PathSource {
    /// Shards of a Hugging Face dataset, fetched one at a time by index.
    Hf(HfShardSource),
    /// Files already on local disk, consumed in the order given.
    Local(Vec<PathBuf>),
}
/// Adapter that exposes a sequence of parquet files as a JSON Lines byte stream.
///
/// Each parquet row is emitted as one JSON object followed by `\n`. Files are
/// opened one at a time and decoded batch by batch.
pub struct ParquetJsonlReader {
    /// Origin of the input files.
    source: PathSource,
    /// Column names to keep; `None` keeps every column.
    projection: Option<Vec<String>>,
    /// Index (into `source`) of the next file to open.
    next_file: usize,
    /// Decoder for the file currently being read, if any.
    reader: Option<ParquetRecordBatchReader>,
    /// Record batch whose rows are currently being serialized.
    current_batch: Option<RecordBatch>,
    /// Next row of `current_batch` to serialize.
    next_row_in_batch: usize,
    /// JSON bytes, plus trailing newline, of the most recently serialized row.
    out_buf: Vec<u8>,
    /// Number of bytes of `out_buf` already handed out by `read`.
    out_pos: usize,
}
impl ParquetJsonlReader {
pub fn new(files: Vec<PathBuf>, projection: Option<Vec<String>>) -> Self {
Self {
source: PathSource::Local(files),
next_file: 0,
projection,
reader: None,
current_batch: None,
next_row_in_batch: 0,
out_buf: Vec::with_capacity(16 * 1024),
out_pos: 0,
}
}
pub fn from_hf(source: HfShardSource, projection: Option<Vec<String>>) -> Self {
Self {
source: PathSource::Hf(source),
next_file: 0,
projection,
reader: None,
current_batch: None,
next_row_in_batch: 0,
out_buf: Vec::with_capacity(16 * 1024),
out_pos: 0,
}
}
fn total_files(&self) -> usize {
match &self.source {
PathSource::Local(v) => v.len(),
PathSource::Hf(h) => h.num_shards(),
}
}
fn next_file_path(&mut self) -> Result<Option<PathBuf>> {
if self.next_file >= self.total_files() {
return Ok(None);
}
let idx = self.next_file;
self.next_file += 1;
let path = match &self.source {
PathSource::Local(v) => v[idx].clone(),
PathSource::Hf(h) => h.fetch_shard(idx)?,
};
Ok(Some(path))
}
fn open_next_file(&mut self) -> Result<bool> {
let Some(path) = self.next_file_path()? else {
return Ok(false);
};
let path_ref = &path;
let f = File::open(path_ref).map_err(Error::Io)?;
let builder = ParquetRecordBatchReaderBuilder::try_new(f)
.map_err(|e| Error::Input(format!("parquet open {path_ref:?}: {e}")))?;
let builder = if let Some(cols) = &self.projection {
let schema = builder.parquet_schema();
let mut indices = Vec::new();
for name in cols {
let idx = (0..schema.num_columns()).find(|i| schema.column(*i).name() == name);
match idx {
Some(i) => indices.push(i),
None => {
return Err(Error::Input(format!(
"parquet {path_ref:?}: projected column {name:?} not found. Available: {:?}",
(0..schema.num_columns())
.map(|i| schema.column(i).name().to_string())
.collect::<Vec<_>>()
)));
}
}
}
let mask = parquet::arrow::ProjectionMask::leaves(schema, indices);
builder.with_projection(mask)
} else {
builder
};
let reader = builder
.build()
.map_err(|e| Error::Input(format!("parquet build reader {path_ref:?}: {e}")))?;
tracing::debug!(path = %path_ref.display(), "opened parquet file");
self.reader = Some(reader);
self.current_batch = None;
self.next_row_in_batch = 0;
Ok(true)
}
fn pull_next_batch(&mut self) -> Result<bool> {
let Some(reader) = self.reader.as_mut() else {
return Ok(false);
};
loop {
match reader.next() {
Some(Ok(batch)) if batch.num_rows() == 0 => continue,
Some(Ok(batch)) => {
self.current_batch = Some(batch);
self.next_row_in_batch = 0;
return Ok(true);
}
Some(Err(e)) => {
return Err(Error::Input(format!("parquet decode: {e}")));
}
None => return Ok(false),
}
}
}
fn prepare_next_row(&mut self) -> Result<bool> {
loop {
if let Some(batch) = self.current_batch.as_ref() {
if self.next_row_in_batch < batch.num_rows() {
let row = serialize_row(batch, self.next_row_in_batch)?;
self.next_row_in_batch += 1;
self.out_buf.clear();
self.out_buf.extend_from_slice(row.as_bytes());
self.out_buf.push(b'\n');
self.out_pos = 0;
return Ok(true);
}
self.current_batch = None;
}
if self.reader.is_none() && !self.open_next_file()? {
return Ok(false);
}
if !self.pull_next_batch()? {
self.reader = None;
}
}
}
}
impl Read for ParquetJsonlReader {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
let mut written = 0;
while written < dst.len() {
if self.out_pos >= self.out_buf.len() {
match self.prepare_next_row() {
Ok(true) => {}
Ok(false) => return Ok(written),
Err(e) => {
return Err(io::Error::other(e.to_string()));
}
}
}
let remaining = &self.out_buf[self.out_pos..];
let n = remaining.len().min(dst.len() - written);
dst[written..written + n].copy_from_slice(&remaining[..n]);
self.out_pos += n;
written += n;
}
Ok(written)
}
}
fn serialize_row(batch: &RecordBatch, row: usize) -> Result<String> {
use arrow_array::cast::AsArray;
let schema = batch.schema();
let mut map: Map<String, Value> = Map::with_capacity(batch.num_columns());
for (col_idx, col) in batch.columns().iter().enumerate() {
let field = schema.field(col_idx);
let name = field.name().to_string();
let value: Value = if col.is_null(row) {
Value::Null
} else {
match field.data_type() {
DataType::Utf8 => Value::String(col.as_string::<i32>().value(row).to_string()),
DataType::LargeUtf8 => Value::String(col.as_string::<i64>().value(row).to_string()),
DataType::Int8 => Value::from(
col.as_primitive::<arrow_array::types::Int8Type>()
.value(row),
),
DataType::Int16 => Value::from(
col.as_primitive::<arrow_array::types::Int16Type>()
.value(row),
),
DataType::Int32 => Value::from(
col.as_primitive::<arrow_array::types::Int32Type>()
.value(row),
),
DataType::Int64 => Value::from(
col.as_primitive::<arrow_array::types::Int64Type>()
.value(row),
),
DataType::UInt8 => Value::from(
col.as_primitive::<arrow_array::types::UInt8Type>()
.value(row),
),
DataType::UInt16 => Value::from(
col.as_primitive::<arrow_array::types::UInt16Type>()
.value(row),
),
DataType::UInt32 => Value::from(
col.as_primitive::<arrow_array::types::UInt32Type>()
.value(row),
),
DataType::UInt64 => Value::from(
col.as_primitive::<arrow_array::types::UInt64Type>()
.value(row),
),
DataType::Float32 => serde_json::Number::from_f64(
col.as_primitive::<arrow_array::types::Float32Type>()
.value(row) as f64,
)
.map(Value::Number)
.unwrap_or(Value::Null),
DataType::Float64 => serde_json::Number::from_f64(
col.as_primitive::<arrow_array::types::Float64Type>()
.value(row),
)
.map(Value::Number)
.unwrap_or(Value::Null),
DataType::Boolean => Value::Bool(col.as_boolean().value(row)),
DataType::Binary => {
let bytes = col.as_binary::<i32>().value(row);
Value::String(hex_lowercase(bytes))
}
DataType::List(_) | DataType::LargeList(_) => {
let list_arr = col.as_any().downcast_ref::<arrow_array::ListArray>();
if let Some(la) = list_arr {
let inner = la.value(row);
Value::Array(arrow_array_to_json_values(&inner))
} else if let Some(la) =
col.as_any().downcast_ref::<arrow_array::LargeListArray>()
{
let inner = la.value(row);
Value::Array(arrow_array_to_json_values(&inner))
} else {
tracing::warn!(column = %name, "list column downcast failed");
Value::Null
}
}
other => {
tracing::warn!(
column = %name,
dtype = ?other,
"unsupported arrow type for JSON serialization; emitting null"
);
Value::Null
}
}
};
map.insert(name, value);
}
serde_json::to_string(&Value::Object(map))
.map_err(|e| Error::Input(format!("json serialize row: {e}")))
}
fn arrow_array_to_json_values(arr: &dyn arrow_array::Array) -> Vec<Value> {
use arrow_array::cast::AsArray;
let n = arr.len();
let mut out = Vec::with_capacity(n);
for i in 0..n {
if arr.is_null(i) {
out.push(Value::Null);
continue;
}
let v = match arr.data_type() {
DataType::Int8 => {
Value::from(arr.as_primitive::<arrow_array::types::Int8Type>().value(i))
}
DataType::Int16 => {
Value::from(arr.as_primitive::<arrow_array::types::Int16Type>().value(i))
}
DataType::Int32 => {
Value::from(arr.as_primitive::<arrow_array::types::Int32Type>().value(i))
}
DataType::Int64 => {
Value::from(arr.as_primitive::<arrow_array::types::Int64Type>().value(i))
}
DataType::UInt8 => {
Value::from(arr.as_primitive::<arrow_array::types::UInt8Type>().value(i))
}
DataType::UInt16 => Value::from(
arr.as_primitive::<arrow_array::types::UInt16Type>()
.value(i),
),
DataType::UInt32 => Value::from(
arr.as_primitive::<arrow_array::types::UInt32Type>()
.value(i),
),
DataType::UInt64 => Value::from(
arr.as_primitive::<arrow_array::types::UInt64Type>()
.value(i),
),
DataType::Float32 => serde_json::Number::from_f64(
arr.as_primitive::<arrow_array::types::Float32Type>()
.value(i) as f64,
)
.map(Value::Number)
.unwrap_or(Value::Null),
DataType::Float64 => serde_json::Number::from_f64(
arr.as_primitive::<arrow_array::types::Float64Type>()
.value(i),
)
.map(Value::Number)
.unwrap_or(Value::Null),
DataType::Boolean => Value::Bool(arr.as_boolean().value(i)),
DataType::Utf8 => Value::String(arr.as_string::<i32>().value(i).to_string()),
_ => Value::Null,
};
out.push(v);
}
out
}
/// Encode `bytes` as lowercase hexadecimal, two digits per byte, high nibble first.
fn hex_lowercase(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    encoded.extend(bytes.iter().flat_map(|&byte| {
        [
            DIGITS[usize::from(byte >> 4)],
            DIGITS[usize::from(byte & 0x0F)],
        ]
        .map(char::from)
    }));
    encoded
}
/// Heuristically decide whether `path` names parquet input.
///
/// Returns true for any of:
/// - `hf://` dataset URLs;
/// - any path whose extension is `parquet`;
/// - a directory that directly contains at least one regular file (or symlink
///   to one) with a `.parquet` extension.
///
/// A subdirectory merely named `*.parquet` does not count. A directory that
/// cannot be read is treated as non-parquet.
pub fn looks_like_parquet_input(path: &Path) -> bool {
    let has_parquet_ext = |p: &Path| p.extension().is_some_and(|ext| ext == "parquet");
    if path.to_string_lossy().starts_with("hf://") || has_parquet_ext(path) {
        return true;
    }
    if !path.is_dir() {
        return false;
    }
    let Ok(entries) = std::fs::read_dir(path) else {
        return false;
    };
    entries
        .filter_map(|entry| entry.ok())
        .map(|entry| entry.path())
        // `Path::is_file` follows symlinks, so linked shard files still count.
        .any(|p| has_parquet_ext(&p) && p.is_file())
}
pub fn list_parquet_shards(dir: &Path) -> Result<Vec<PathBuf>> {
let mut shards: Vec<PathBuf> = Vec::new();
let iter = std::fs::read_dir(dir).map_err(Error::Io)?;
for entry in iter {
let p = entry.map_err(Error::Io)?.path();
if p.extension().is_some_and(|e| e == "parquet") {
shards.push(p);
}
}
shards.sort();
Ok(shards)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Create `dir/name` holding a single placeholder byte; the helpers under
    /// test only look at names and file types, not contents.
    fn touch(dir: &Path, name: &str) {
        std::fs::File::create(dir.join(name))
            .unwrap()
            .write_all(b"x")
            .unwrap();
    }

    #[test]
    fn parses_valid_hf_urls() {
        let cases = [
            (
                "hf://alea-institute/kl3m-data-sample-006-medium",
                ("alea-institute/kl3m-data-sample-006-medium", None),
            ),
            ("hf://ns/repo@v1", ("ns/repo", Some("v1"))),
        ];
        for &(url, (repo, rev)) in &cases {
            let expected = Some((repo.to_string(), rev.map(|r: &str| r.to_string())));
            assert_eq!(parse_hf_url(url), expected, "url: {url}");
        }
    }

    #[test]
    fn rejects_malformed_hf_urls() {
        for &url in &[
            "hf://",
            "hf://no-slash",
            "https://huggingface.co/x",
            "hf://with space/repo",
        ] {
            assert_eq!(parse_hf_url(url), None, "url: {url}");
        }
    }

    #[test]
    fn looks_like_parquet() {
        let expectations = [
            ("hf://a/b", true),
            ("data.parquet", true),
            ("/x/y.parquet", true),
            ("data.jsonl", false),
            ("data.jsonl.zst", false),
        ];
        for &(raw, expected) in &expectations {
            assert_eq!(
                looks_like_parquet_input(Path::new(raw)),
                expected,
                "path: {raw}"
            );
        }
    }

    #[test]
    fn list_parquet_shards_sorted() {
        let dir = tempfile::tempdir().unwrap();
        for &name in &[
            "train-00002.parquet",
            "train-00000.parquet",
            "other.txt",
            "train-00001.parquet",
        ] {
            touch(dir.path(), name);
        }
        let shards = list_parquet_shards(dir.path()).unwrap();
        let names: Vec<String> = shards
            .iter()
            .map(|p| p.file_name().unwrap().to_string_lossy().into_owned())
            .collect();
        assert_eq!(
            names,
            [
                "train-00000.parquet",
                "train-00001.parquet",
                "train-00002.parquet",
            ]
        );
    }

    #[test]
    fn looks_like_parquet_recognizes_dir_with_shards() {
        let dir = tempfile::tempdir().unwrap();
        touch(dir.path(), "train-00000.parquet");
        assert!(looks_like_parquet_input(dir.path()));
    }

    #[test]
    fn looks_like_parquet_rejects_empty_dir() {
        let dir = tempfile::tempdir().unwrap();
        assert!(!looks_like_parquet_input(dir.path()));
    }
}