use std::{
cell::RefCell,
io::{BufRead, BufReader},
path::Path,
};
use serde_json::Value;
use smol_str::format_smolstr;
use crate::{
error::{
CapExceededPayload, EmptyInputPayload, Error, FileIoPayload, FileOp, InvariantViolationPayload,
MalformedDataPayload, MissingKeyPayload, OutOfRangePayload, ParsePayload, Result,
},
tokenizer::Tokenizer,
};
/// Default record key holding the raw string tokenized by [`TextDataset`].
pub const DEFAULT_TEXT_KEY: &str = "text";
/// Default record key holding the message array rendered by [`ChatDataset`].
pub const DEFAULT_CHAT_KEY: &str = "messages";
/// Default record key holding the user prompt for [`CompletionsDataset`].
pub const DEFAULT_PROMPT_KEY: &str = "prompt";
/// Default record key holding the assistant completion for [`CompletionsDataset`].
pub const DEFAULT_COMPLETION_KEY: &str = "completion";
/// Largest JSONL file (in bytes, 2 GiB) that [`load_dataset`] will read.
pub const MAX_DATASET_FILE_BYTES: u64 = 2 * 1024 * 1024 * 1024;
/// A processed example: `(token_ids, offset)`.
///
/// `offset` is the number of leading tokens that belong to the prompt when
/// prompt masking is enabled, and `0` otherwise.
pub type Example = (Vec<u32>, usize);
/// Random-access collection of JSONL records that can be turned into
/// tokenized training [`Example`]s.
pub trait Dataset {
    /// Number of records in the dataset.
    fn len(&self) -> usize;

    /// Returns `true` when the dataset holds no records.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrows the raw record at `idx`.
    ///
    /// The implementations in this module return an `Error::OutOfRange`
    /// when `idx >= len()`.
    fn get(&self, idx: usize) -> Result<&Value>;

    /// Tokenizes the record at `idx` into an [`Example`].
    fn process(&self, idx: usize) -> Result<Example>;
}
/// Dataset of plain-text records.
///
/// Each record's `text_key` string is tokenized and terminated with the
/// tokenizer's EOS token (see the [`Dataset`] impl).
pub struct TextDataset<'a> {
    /// Parsed JSONL records.
    data: Vec<Value>,
    tokenizer: &'a Tokenizer,
    /// Key of the string field to tokenize in every record.
    text_key: String,
}

impl std::fmt::Debug for TextDataset<'_> {
    // Shows only the record count and the key; the tokenizer and the raw
    // records are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TextDataset");
        out.field("len", &self.data.len());
        out.field("text_key", &self.text_key);
        out.finish()
    }
}

impl<'a> TextDataset<'a> {
    /// Creates a text dataset over `data`, reading the string at `text_key`
    /// from each record.
    pub fn new(data: Vec<Value>, tokenizer: &'a Tokenizer, text_key: impl Into<String>) -> Self {
        let text_key = text_key.into();
        Self {
            data,
            tokenizer,
            text_key,
        }
    }
}
impl Dataset for TextDataset<'_> {
fn len(&self) -> usize {
self.data.len()
}
fn get(&self, idx: usize) -> Result<&Value> {
self.data.get(idx).ok_or_else(|| {
Error::OutOfRange(OutOfRangePayload::new(
"TextDataset: index",
"must be < len",
format_smolstr!("{idx} (len={})", self.data.len()),
))
})
}
fn process(&self, idx: usize) -> Result<Example> {
let record = self.get(idx)?;
let text = field_as_str(record, &self.text_key, "TextDataset")?;
let mut tokens = self.tokenizer.encode(text, true)?;
if let Some(eos) = self.tokenizer.eos_token_id()
&& tokens.last() != Some(&eos)
{
tokens.push(eos);
}
Ok((tokens, 0))
}
}
/// Dataset of chat records.
///
/// Each record holds a JSON array of messages under `chat_key`, which is
/// rendered through the tokenizer's chat template (see the [`Dataset`] impl).
pub struct ChatDataset<'a> {
    /// Parsed JSONL records.
    data: Vec<Value>,
    tokenizer: &'a Tokenizer,
    /// Key of the message array in every record.
    chat_key: String,
    /// When set, `process` reports the token length of the conversation
    /// without its final message as the example offset.
    mask_prompt: bool,
}

impl std::fmt::Debug for ChatDataset<'_> {
    // Shows the record count and configuration only; the tokenizer and the
    // raw records are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            data,
            chat_key,
            mask_prompt,
            ..
        } = self;
        f.debug_struct("ChatDataset")
            .field("len", &data.len())
            .field("chat_key", chat_key)
            .field("mask_prompt", mask_prompt)
            .finish()
    }
}

impl<'a> ChatDataset<'a> {
    /// Creates a chat dataset over `data`, reading messages from `chat_key`.
    pub fn new(
        data: Vec<Value>,
        tokenizer: &'a Tokenizer,
        chat_key: impl Into<String>,
        mask_prompt: bool,
    ) -> Self {
        let chat_key = chat_key.into();
        Self {
            data,
            tokenizer,
            chat_key,
            mask_prompt,
        }
    }
}
impl Dataset for ChatDataset<'_> {
    fn len(&self) -> usize {
        self.data.len()
    }

    fn get(&self, idx: usize) -> Result<&Value> {
        self.data.get(idx).ok_or_else(|| {
            Error::OutOfRange(OutOfRangePayload::new(
                "ChatDataset: index",
                "must be < len",
                format_smolstr!("{idx} (len={})", self.data.len()),
            ))
        })
    }

    /// Renders the record's message array through the chat template.
    ///
    /// An optional top-level `tools` value is forwarded to the template. With
    /// `mask_prompt` set, the offset is the token length of every message
    /// except the last, rendered with `add_generation_prompt` enabled iff the
    /// last message has role `"assistant"`.
    ///
    /// # Errors
    /// - `Error::MissingKey` if the record has no `chat_key` field;
    /// - `Error::OutOfRange` if that field is not a JSON array;
    /// - any error from the tokenizer.
    fn process(&self, idx: usize) -> Result<Example> {
        let record = self.get(idx)?;
        let messages = record.get(&self.chat_key).ok_or_else(|| {
            Error::MissingKey(MissingKeyPayload::new(
                "ChatDataset: jsonl record missing field",
                self.chat_key.as_str(),
            ))
        })?;
        let Some(arr) = messages.as_array() else {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "ChatDataset: chat field JSON kind (must be array)",
                "must be a JSON array",
                format_smolstr!("{}={}", self.chat_key, json_kind(messages)),
            )));
        };
        // An explicit `"tools": null` is treated exactly like an absent key
        // (the reference implementation's `d.get("tools", None)` yields None
        // for both), so both spellings reach the tokenizer as `None`.
        let tools = record.get("tools").filter(|t| !t.is_null());
        let tokens = self
            .tokenizer
            .apply_chat_template_ids(messages, tools, false, false, None)?;
        if !self.mask_prompt {
            return Ok((tokens, 0));
        }
        // Prompt = everything but the final message. An empty conversation
        // yields an empty prefix and no generation prompt.
        let (history, add_generation_prompt) = match arr.split_last() {
            Some((last, history)) => (
                history,
                last.get("role").and_then(Value::as_str) == Some("assistant"),
            ),
            None => (arr.as_slice(), false),
        };
        let prefix = Value::Array(history.to_vec());
        let prefix_tokens = self.tokenizer.apply_chat_template_ids(
            &prefix,
            tools,
            add_generation_prompt,
            false,
            None,
        )?;
        Ok((tokens, prefix_tokens.len()))
    }
}
/// Dataset of prompt/completion records.
///
/// Each record's prompt and completion strings are rendered as a user turn
/// followed by an assistant turn through the chat template (see the
/// [`Dataset`] impl).
pub struct CompletionsDataset<'a> {
    /// Parsed JSONL records.
    data: Vec<Value>,
    tokenizer: &'a Tokenizer,
    /// Key of the user-prompt string in every record.
    prompt_key: String,
    /// Key of the assistant-completion string in every record.
    completion_key: String,
    /// When set, `process` reports the token length of the rendered user turn
    /// as the example offset.
    mask_prompt: bool,
}

impl std::fmt::Debug for CompletionsDataset<'_> {
    // Shows the record count and configuration only; the tokenizer and the
    // raw records are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            data,
            prompt_key,
            completion_key,
            mask_prompt,
            ..
        } = self;
        f.debug_struct("CompletionsDataset")
            .field("len", &data.len())
            .field("prompt_key", prompt_key)
            .field("completion_key", completion_key)
            .field("mask_prompt", mask_prompt)
            .finish()
    }
}

impl<'a> CompletionsDataset<'a> {
    /// Creates a completions dataset over `data`, reading the prompt from
    /// `prompt_key` and the completion from `completion_key`.
    pub fn new(
        data: Vec<Value>,
        tokenizer: &'a Tokenizer,
        prompt_key: impl Into<String>,
        completion_key: impl Into<String>,
        mask_prompt: bool,
    ) -> Self {
        let prompt_key = prompt_key.into();
        let completion_key = completion_key.into();
        Self {
            data,
            tokenizer,
            prompt_key,
            completion_key,
            mask_prompt,
        }
    }
}
impl Dataset for CompletionsDataset<'_> {
    fn len(&self) -> usize {
        self.data.len()
    }

    fn get(&self, idx: usize) -> Result<&Value> {
        self.data.get(idx).ok_or_else(|| {
            Error::OutOfRange(OutOfRangePayload::new(
                "CompletionsDataset: index",
                "must be < len",
                format_smolstr!("{idx} (len={})", self.data.len()),
            ))
        })
    }

    /// Renders the record as a `user` (prompt) + `assistant` (completion)
    /// exchange through the chat template.
    ///
    /// An optional top-level `tools` value is forwarded to the template. With
    /// `mask_prompt` set, the offset is the token length of the user turn
    /// alone, rendered with `add_generation_prompt` enabled.
    ///
    /// # Errors
    /// - `Error::MissingKey` / `Error::OutOfRange` if the prompt or completion
    ///   field is missing or not a JSON string;
    /// - any error from the tokenizer.
    fn process(&self, idx: usize) -> Result<Example> {
        let record = self.get(idx)?;
        let prompt = field_as_str(record, &self.prompt_key, "CompletionsDataset")?;
        let completion = field_as_str(record, &self.completion_key, "CompletionsDataset")?;
        // An explicit `"tools": null` is treated exactly like an absent key
        // (the reference implementation's `d.get("tools", None)` yields None
        // for both), so both spellings reach the tokenizer as `None`.
        let tools = record.get("tools").filter(|t| !t.is_null());
        let messages = serde_json::json!([
            { "role": "user", "content": prompt },
            { "role": "assistant", "content": completion },
        ]);
        let tokens = self
            .tokenizer
            .apply_chat_template_ids(&messages, tools, false, false, None)?;
        if !self.mask_prompt {
            return Ok((tokens, 0));
        }
        let prefix = serde_json::json!([
            { "role": "user", "content": prompt },
        ]);
        let prefix_tokens = self
            .tokenizer
            .apply_chat_template_ids(&prefix, tools, true, false, None)?;
        Ok((tokens, prefix_tokens.len()))
    }
}
/// Several datasets indexed as one sequence.
///
/// Indices `0..len(first)` address the first inner dataset, the next
/// `len(second)` indices the second, and so on.
pub struct ConcatenatedDataset<'a> {
    data: Vec<Box<dyn Dataset + 'a>>,
    /// Sum of the inner lengths, computed once in [`ConcatenatedDataset::new`].
    len: usize,
}

impl std::fmt::Debug for ConcatenatedDataset<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner_count = self.data.len();
        f.debug_struct("ConcatenatedDataset")
            .field("inner_count", &inner_count)
            .field("len", &self.len)
            .finish()
    }
}

impl<'a> ConcatenatedDataset<'a> {
    /// Concatenates `data` in order.
    pub fn new(data: Vec<Box<dyn Dataset + 'a>>) -> Self {
        let len = data.iter().fold(0, |acc, part| acc + part.len());
        Self { data, len }
    }

    /// Maps a global index to `(inner dataset index, index within it)`.
    ///
    /// Returns `Error::OutOfRange` when `idx` is past the combined length.
    fn resolve(&self, idx: usize) -> Result<(usize, usize)> {
        // `offset` is the global index of the current part's first record; it
        // never exceeds `idx` because earlier parts were all skipped.
        let mut offset = 0;
        for (part_idx, part) in self.data.iter().enumerate() {
            let part_len = part.len();
            let local = idx - offset;
            if local < part_len {
                return Ok((part_idx, local));
            }
            offset += part_len;
        }
        Err(Error::OutOfRange(OutOfRangePayload::new(
            "ConcatenatedDataset: index",
            "must be < len",
            format_smolstr!("{idx} (len={})", self.len),
        )))
    }
}
impl Dataset for ConcatenatedDataset<'_> {
    fn len(&self) -> usize {
        self.len
    }

    /// Borrows the record at global index `idx` from the inner dataset that
    /// owns it.
    fn get(&self, idx: usize) -> Result<&Value> {
        self.resolve(idx)
            .and_then(|(part, local)| self.data[part].get(local))
    }

    /// Processes the record at global index `idx` with the inner dataset that
    /// owns it.
    fn process(&self, idx: usize) -> Result<Example> {
        self.resolve(idx)
            .and_then(|(part, local)| self.data[part].process(local))
    }
}
/// Memoizing wrapper around another [`Dataset`].
///
/// The first successful `process(idx)` result is stored and returned (as a
/// clone) on later calls. Failed calls are not cached. The cache lives in a
/// `RefCell`, so this type is not `Sync`.
pub struct CacheDataset<'a> {
    data: Box<dyn Dataset + 'a>,
    // One slot per inner index, sized at construction; `None` until that
    // index has been processed successfully.
    proc_data: RefCell<Vec<Option<Example>>>,
}

impl std::fmt::Debug for CacheDataset<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `try_borrow` keeps formatting panic-free: the count shows as `None`
        // if the cache is mutably borrowed at that moment.
        let cached = self
            .proc_data
            .try_borrow()
            .map(|c| c.iter().filter(|e| e.is_some()).count())
            .ok();
        f.debug_struct("CacheDataset")
            .field("len", &self.data.len())
            .field("cached_count", &cached)
            .finish()
    }
}

impl<'a> CacheDataset<'a> {
    /// Wraps `data` with an empty cache of `data.len()` slots.
    pub fn new(data: Box<dyn Dataset + 'a>) -> Self {
        let n = data.len();
        Self {
            data,
            proc_data: RefCell::new(vec![None; n]),
        }
    }

    /// Token count of the example at `idx`, processing and caching it first
    /// if needed.
    ///
    /// # Errors
    /// Propagates any error from processing the example, including an
    /// out-of-range `idx`.
    pub fn item_len(&self, idx: usize) -> Result<usize> {
        // Cache hit: read the length in place rather than going through
        // `process`, which would clone the whole token vector.
        if let Some(Some((tokens, _))) = self.proc_data.borrow().get(idx) {
            return Ok(tokens.len());
        }
        // Cache miss: the shared borrow above has already been released.
        Ok(self.process(idx)?.0.len())
    }
}
impl Dataset for CacheDataset<'_> {
fn len(&self) -> usize {
self.data.len()
}
fn get(&self, idx: usize) -> Result<&Value> {
self.data.get(idx)
}
fn process(&self, idx: usize) -> Result<Example> {
{
let cache = self.proc_data.borrow();
if let Some(Some(pair)) = cache.get(idx) {
return Ok(pair.clone());
}
}
let computed = self.data.process(idx)?;
let mut cache = self.proc_data.borrow_mut();
if idx >= cache.len() {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"CacheDataset: index",
"must be < len",
format_smolstr!("{idx} (len={})", cache.len()),
)));
}
cache[idx] = Some(computed.clone());
Ok(computed)
}
}
/// Record schema of a dataset.
///
/// [`DatasetType::Auto`] asks [`create_dataset`] to infer the schema from the
/// first record. `Display` prints the same lowercase name as
/// [`DatasetType::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
pub enum DatasetType {
    /// Plain-text records, built as [`TextDataset`].
    Text,
    /// Chat-message records, built as [`ChatDataset`].
    Chat,
    /// Prompt/completion records, built as [`CompletionsDataset`].
    Completions,
    /// Infer one of the other variants from the first record's keys.
    Auto,
}

impl DatasetType {
    /// Lowercase name of the variant: `"text"`, `"chat"`, `"completions"`
    /// or `"auto"`.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Text => "text",
            Self::Chat => "chat",
            Self::Completions => "completions",
            Self::Auto => "auto",
        }
    }
}
/// Feature keys and prompt-masking option consumed by [`create_dataset`].
///
/// [`DatasetConfig::new`] (and `Default`) disable masking and use
/// [`DEFAULT_TEXT_KEY`], [`DEFAULT_CHAT_KEY`], [`DEFAULT_PROMPT_KEY`] and
/// [`DEFAULT_COMPLETION_KEY`]. Adjust the options with the `with_*` builder
/// methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DatasetConfig {
    /// Whether chat/completions datasets compute a prompt-length offset.
    /// `create_dataset` rejects this for text datasets.
    mask_prompt: bool,
    /// Record key read by text datasets.
    text_feature: String,
    /// Record key read by chat datasets.
    chat_feature: String,
    /// Prompt key read by completions datasets.
    prompt_feature: String,
    /// Completion key read by completions datasets.
    completion_feature: String,
}

impl DatasetConfig {
    /// Creates a config with masking disabled and the default feature keys.
    pub fn new() -> Self {
        Self {
            mask_prompt: false,
            text_feature: DEFAULT_TEXT_KEY.to_owned(),
            chat_feature: DEFAULT_CHAT_KEY.to_owned(),
            prompt_feature: DEFAULT_PROMPT_KEY.to_owned(),
            completion_feature: DEFAULT_COMPLETION_KEY.to_owned(),
        }
    }

    // Plain `#[inline]` suffices for these trivial getters (it enables
    // cross-crate inlining). Forcing `#[inline(always)]` only constrains the
    // optimizer.

    /// Whether prompt masking is enabled.
    #[inline]
    pub fn mask_prompt(&self) -> bool {
        self.mask_prompt
    }

    /// Record key holding the text of a text dataset.
    #[inline]
    pub fn text_feature(&self) -> &str {
        &self.text_feature
    }

    /// Record key holding the message array of a chat dataset.
    #[inline]
    pub fn chat_feature(&self) -> &str {
        &self.chat_feature
    }

    /// Record key holding the prompt of a completions dataset.
    #[inline]
    pub fn prompt_feature(&self) -> &str {
        &self.prompt_feature
    }

    /// Record key holding the completion of a completions dataset.
    #[inline]
    pub fn completion_feature(&self) -> &str {
        &self.completion_feature
    }

    /// Returns the config with prompt masking set to `mask_prompt`.
    #[must_use]
    pub fn with_mask_prompt(mut self, mask_prompt: bool) -> Self {
        self.mask_prompt = mask_prompt;
        self
    }

    /// Returns the config with the text key replaced.
    #[must_use]
    pub fn with_text_feature(mut self, text_feature: impl Into<String>) -> Self {
        self.text_feature = text_feature.into();
        self
    }

    /// Returns the config with the chat key replaced.
    #[must_use]
    pub fn with_chat_feature(mut self, chat_feature: impl Into<String>) -> Self {
        self.chat_feature = chat_feature.into();
        self
    }

    /// Returns the config with the prompt key replaced.
    #[must_use]
    pub fn with_prompt_feature(mut self, prompt_feature: impl Into<String>) -> Self {
        self.prompt_feature = prompt_feature.into();
        self
    }

    /// Returns the config with the completion key replaced.
    #[must_use]
    pub fn with_completion_feature(mut self, completion_feature: impl Into<String>) -> Self {
        self.completion_feature = completion_feature.into();
        self
    }
}

impl Default for DatasetConfig {
    /// Same as [`DatasetConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
pub fn create_dataset<'a>(
data: Vec<Value>,
tokenizer: &'a Tokenizer,
config: &DatasetConfig,
dataset_type: DatasetType,
) -> Result<Box<dyn Dataset + 'a>> {
let resolved = match dataset_type {
DatasetType::Auto => auto_detect(&data, config)?,
other => other,
};
match resolved {
DatasetType::Text => {
if config.mask_prompt() {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"create_dataset: prompt masking",
"is not supported for text dataset",
)));
}
Ok(Box::new(TextDataset::new(
data,
tokenizer,
config.text_feature().to_owned(),
)))
}
DatasetType::Chat => Ok(Box::new(ChatDataset::new(
data,
tokenizer,
config.chat_feature().to_owned(),
config.mask_prompt(),
))),
DatasetType::Completions => Ok(Box::new(CompletionsDataset::new(
data,
tokenizer,
config.prompt_feature().to_owned(),
config.completion_feature().to_owned(),
config.mask_prompt(),
))),
DatasetType::Auto => unreachable!("auto_detect returned Auto"),
}
}
fn auto_detect(data: &[Value], config: &DatasetConfig) -> Result<DatasetType> {
let sample = data.first().ok_or_else(|| {
Error::EmptyInput(EmptyInputPayload::new(
"create_dataset: jsonl records for auto-detection (pass an explicit DatasetType instead)",
))
})?;
let has = |k: &str| sample.get(k).is_some();
if has(config.prompt_feature()) && has(config.completion_feature()) {
Ok(DatasetType::Completions)
} else if has(config.chat_feature()) {
Ok(DatasetType::Chat)
} else if has(config.text_feature()) {
Ok(DatasetType::Text)
} else {
Err(Error::MalformedData(MalformedDataPayload::new(
"create_dataset: auto-detect",
"Unsupported data format, check the supported formats here:\n\
https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/LORA.md#Data.",
)))
}
}
pub fn load_dataset<'a>(
path: &Path,
tokenizer: &'a Tokenizer,
dataset_type: DatasetType,
config: &DatasetConfig,
) -> Result<CacheDataset<'a>> {
if let Some(s) = path.to_str()
&& (s.starts_with("hf://") || s.starts_with("hf:"))
{
return Err(Error::OutOfRange(OutOfRangePayload::new(
"create_dataset: HF Hub URI rejected (local-only mlxrs build)",
"pass a local jsonl file path instead",
format_smolstr!("{s}"),
)));
}
if !path.exists() {
return Err(Error::FileIo(FileIoPayload::new(
"open jsonl",
FileOp::Open,
::std::path::PathBuf::from(path),
std::io::Error::from(std::io::ErrorKind::NotFound),
)));
}
#[cfg(unix)]
let file = {
use std::os::unix::fs::OpenOptionsExt;
std::fs::OpenOptions::new()
.read(true)
.custom_flags(libc::O_NONBLOCK | libc::O_CLOEXEC)
.open(path)
.map_err(|e| {
Error::FileIo(FileIoPayload::new(
"open jsonl",
FileOp::Open,
::std::path::PathBuf::from(path),
e,
))
})?
};
#[cfg(not(unix))]
let file = std::fs::File::open(path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"open jsonl",
FileOp::Open,
::std::path::PathBuf::from(path),
e,
))
})?;
let meta = file.metadata().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"stat jsonl",
FileOp::Stat,
::std::path::PathBuf::from(path),
e,
))
})?;
if !meta.is_file() {
return Err(Error::FileIo(FileIoPayload::new(
"open jsonl: not a regular file (directories, sockets, FIFOs etc. are not accepted)",
FileOp::Stat,
::std::path::PathBuf::from(path),
std::io::Error::from(std::io::ErrorKind::InvalidInput),
)));
}
if meta.len() > MAX_DATASET_FILE_BYTES {
return Err(Error::CapExceeded(CapExceededPayload::new(
"open jsonl: file size",
"MAX_DATASET_FILE_BYTES",
MAX_DATASET_FILE_BYTES,
meta.len(),
)));
}
let data = read_jsonl_with_cap(BufReader::new(file), path, MAX_DATASET_FILE_BYTES)?;
if data.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"open jsonl: parsed records (file is empty; skip absent valid.jsonl/test.jsonl at caller level)",
)));
}
let inner = create_dataset(data, tokenizer, config, dataset_type)?;
Ok(CacheDataset::new(inner))
}
/// Reads newline-delimited JSON records from `reader`, consuming at most
/// `max_bytes` bytes in total (line terminators included).
///
/// The cap is enforced while reading, in addition to the file-size check the
/// caller (`load_dataset`) does beforehand. Memory therefore stays bounded
/// even if the reader yields more bytes than the earlier size check reported.
///
/// Per line: a trailing `\n` (and a `\r` immediately before it) is stripped,
/// the bytes must be valid UTF-8, and the text is trimmed. An empty result is
/// an error (blank lines are not skipped). The trimmed text is then parsed
/// with `serde_json::from_str`.
///
/// # Errors
/// - `Error::FileIo` (carrying `path_for_errors`) if a read fails;
/// - `Error::CapExceeded` if more than `max_bytes` bytes are available;
/// - `Error::Parse` for invalid UTF-8 or invalid JSON on a line;
/// - `Error::EmptyInput` for a blank line.
fn read_jsonl_with_cap<R: BufRead>(
    mut reader: R,
    path_for_errors: &Path,
    max_bytes: u64,
) -> Result<Vec<Value>> {
    let mut data: Vec<Value> = Vec::new();
    // Bytes consumed so far; after the checks below it never exceeds `max_bytes`.
    let mut total_bytes: u64 = 0;
    // Reused for every line to avoid one allocation per line.
    let mut line_buf: Vec<u8> = Vec::with_capacity(4096);
    loop {
        line_buf.clear();
        let remaining = max_bytes.saturating_sub(total_bytes);
        if remaining == 0 {
            // Budget used up exactly (or `max_bytes == 0`). Probe one more byte
            // to tell a clean EOF (accept) from data past the cap (reject).
            let mut peek = [0u8; 1];
            let n = std::io::Read::read(&mut reader, &mut peek).map_err(|e| {
                Error::FileIo(FileIoPayload::new(
                    "read jsonl: probing post-cap bytes",
                    FileOp::Read,
                    ::std::path::PathBuf::from(path_for_errors),
                    e,
                ))
            })?;
            if n == 0 {
                break;
            }
            // The reported size is a lower bound: only one extra byte was read.
            return Err(Error::CapExceeded(CapExceededPayload::new(
                "read jsonl: cumulative bytes (more bytes remained in reader past the cap)",
                "MAX_DATASET_FILE_BYTES",
                max_bytes,
                total_bytes.saturating_add(1),
            )));
        }
        // Allow one byte past the budget. Reading it makes
        // `total_bytes > max_bytes` below, which signals that the cap was
        // crossed, while never buffering more than `remaining + 1` bytes
        // for a single over-long line.
        let cap_this_line = remaining.saturating_add(1);
        let mut take = <&mut R as std::io::Read>::take(&mut reader, cap_this_line);
        let n = match std::io::BufRead::read_until(&mut take, b'\n', &mut line_buf) {
            Ok(n) => n,
            Err(e) => {
                return Err(Error::FileIo(FileIoPayload::new(
                    "read jsonl: read_until",
                    FileOp::Read,
                    ::std::path::PathBuf::from(path_for_errors),
                    e,
                )));
            }
        };
        // `cap_this_line >= 1`, so zero bytes means the underlying reader is at EOF.
        if n == 0 {
            break;
        }
        total_bytes = total_bytes.saturating_add(n as u64);
        if total_bytes > max_bytes {
            return Err(Error::CapExceeded(CapExceededPayload::new(
                "read jsonl: cumulative bytes (file may have grown mid-read or per-line size is unexpectedly large)",
                "MAX_DATASET_FILE_BYTES",
                max_bytes,
                total_bytes,
            )));
        }
        // Strip "\n" or "\r\n". A final line without a terminator is kept as is.
        if line_buf.last() == Some(&b'\n') {
            line_buf.pop();
            if line_buf.last() == Some(&b'\r') {
                line_buf.pop();
            }
        }
        let line = std::str::from_utf8(&line_buf).map_err(|e| {
            Error::Parse(ParsePayload::new(
                "read jsonl: line is not valid UTF-8",
                "jsonl line UTF-8",
                Box::new(e) as Box<dyn std::error::Error + Send + Sync>,
            ))
        })?;
        // NOTE(review): `str::trim` strips all Unicode whitespace, a wider set
        // than JSON's insignificant whitespace (space, tab, CR, LF).
        let trimmed = line.trim();
        if trimmed.is_empty() {
            return Err(Error::EmptyInput(EmptyInputPayload::new(
                "read jsonl: blank line (every line must be a valid JSON record; matches Python `json.loads(l)` failing on \"\")",
            )));
        }
        let v: Value = serde_json::from_str(trimmed).map_err(|e| {
            Error::Parse(ParsePayload::new(
                "read jsonl: serde_json::from_str",
                "jsonl record",
                Box::new(e) as Box<dyn std::error::Error + Send + Sync>,
            ))
        })?;
        data.push(v);
    }
    Ok(data)
}
fn field_as_str<'a>(record: &'a Value, key: &str, type_name: &'static str) -> Result<&'a str> {
let v = record.get(key).ok_or_else(|| {
Error::MissingKey(MissingKeyPayload::new(
type_name,
format_smolstr!("jsonl record missing '{key}'"),
))
})?;
v.as_str().ok_or_else(|| {
Error::OutOfRange(OutOfRangePayload::new(
type_name,
"field must be a JSON string",
format_smolstr!("'{key}'={}", json_kind(v)),
))
})
}
/// Lowercase name of a JSON value's kind, used in error messages.
fn json_kind(value: &Value) -> &'static str {
    match value {
        Value::Null => "null",
        Value::Bool(..) => "bool",
        Value::Number(..) => "number",
        Value::String(..) => "string",
        Value::Array(..) => "array",
        Value::Object(..) => "object",
    }
}
// Unit tests for this module live in the separate `tests` module file and are
// compiled only for test builds.
#[cfg(test)]
mod tests;