use std::{
path::PathBuf,
sync::atomic::{AtomicU64, Ordering},
};
use serde_json::json;
use super::*;
/// A `BufRead` adapter that counts how many bytes a caller has taken from
/// the wrapped reader.
///
/// Bytes are counted when they leave the reader:
/// - through [`std::io::Read::read`], by the returned length;
/// - through [`std::io::BufRead::consume`], by the consumed amount.
///
/// Bytes that are only peeked via `fill_buf` are not counted. So
/// [`CountingReader::consumed`] reports what the caller actually took, not
/// what the inner reader happened to buffer. The tests use this to prove a
/// capped reader stops pulling bytes near the cap.
///
/// The `BufRead` bound lives on the `impl` blocks rather than on the struct
/// definition, following the Rust API guidelines.
struct CountingReader<R> {
    inner: R,
    consumed: usize,
}

impl<R: std::io::BufRead> CountingReader<R> {
    /// Wraps `inner` with a byte counter starting at zero.
    fn new(inner: R) -> Self {
        Self { inner, consumed: 0 }
    }

    /// Total bytes taken so far through `read` or `consume`.
    fn consumed(&self) -> usize {
        self.consumed
    }
}

impl<R: std::io::BufRead> std::io::Read for CountingReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let n = self.inner.read(buf)?;
        self.consumed += n;
        Ok(n)
    }
}

impl<R: std::io::BufRead> std::io::BufRead for CountingReader<R> {
    /// Peeks at the inner buffer; peeked bytes are not counted.
    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
        self.inner.fill_buf()
    }

    /// Marks `amt` bytes as taken, counting them before forwarding.
    fn consume(&mut self, amt: usize) {
        self.consumed += amt;
        self.inner.consume(amt);
    }
}
/// Creates a new, empty scratch directory under the system temp dir.
///
/// The name combines `tag`, the current process id and a process-wide
/// sequence number. Repeated calls, even with the same tag, therefore never
/// share a directory. A stale directory with the same name, left by an
/// earlier process that had the same pid, is removed first.
///
/// # Panics
/// Panics if the directory cannot be created.
fn fresh_dir(tag: &str) -> PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    let pid = std::process::id();
    let dir = std::env::temp_dir().join(format!("mlxrs-lm-tuner-datasets-{tag}-{pid}-{seq}"));
    // Best effort: there is usually nothing to remove.
    std::fs::remove_dir_all(&dir).ok();
    std::fs::create_dir_all(&dir).unwrap();
    dir
}
fn write_tokenizer(dir: &Path) -> Tokenizer {
let tokenizer_json = json!({
"version": "1.0",
"added_tokens": [
{"id": 0, "content": "<unk>", "single_word": false, "lstrip": false,
"rstrip": false, "normalized": false, "special": true},
{"id": 1, "content": "<s>", "single_word": false, "lstrip": false,
"rstrip": false, "normalized": false, "special": true},
{"id": 2, "content": "</s>", "single_word": false, "lstrip": false,
"rstrip": false, "normalized": false, "special": true},
{"id": 7, "content": ":", "single_word": false, "lstrip": false,
"rstrip": false, "normalized": false, "special": true},
{"id": 8, "content": "tools", "single_word": false, "lstrip": false,
"rstrip": false, "normalized": false, "special": true},
],
"normalizer": null,
"pre_tokenizer": { "type": "Whitespace" },
"post_processor": null,
"decoder": null,
"model": {
"type": "WordLevel",
"vocab": {
"<unk>": 0, "<s>": 1, "</s>": 2,
"hello": 3, "world": 4,
"user": 5, "assistant": 6,
":": 7, "tools": 8
},
"unk_token": "<unk>"
}
});
let cfg = json!({
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
"chat_template":
"{% for m in messages %}{{ m['role'] }} : {{ m['content'] }} \
{% endfor %}{% if add_generation_prompt %}assistant : {% endif %}"
});
std::fs::write(dir.join("tokenizer.json"), tokenizer_json.to_string()).unwrap();
std::fs::write(dir.join("tokenizer_config.json"), cfg.to_string()).unwrap();
Tokenizer::from_path(dir, None).unwrap()
}
/// Builds the shared test tokenizer inside a fresh, uniquely named temp dir.
///
/// `tag` only affects the directory name, which keeps fixtures from
/// different tests apart. The directory is not removed afterwards.
fn tokenizer_fixture(tag: &str) -> Tokenizer {
    write_tokenizer(&fresh_dir(tag))
}
/// Writes `lines` to `path` as JSON Lines.
///
/// Each record is written as one compact JSON document followed by `\n`, so
/// a non-empty file always ends with a newline.
///
/// # Panics
/// Panics if the file cannot be written.
fn write_jsonl(path: &Path, lines: &[Value]) {
    let body: String = lines.iter().map(|record| format!("{record}\n")).collect();
    std::fs::write(path, body).unwrap();
}
/// `TextDataset` tokenizes each record's `text` field and appends `</s>`
/// (id 2). Offsets are always 0.
#[test]
fn text_dataset_happy_path_appends_eos_when_missing() {
    let tok = tokenizer_fixture("text_happy");
    let records = vec![
        json!({ "text": "hello world" }),
        json!({ "text": "world hello" }),
        json!({ "text": "hello" }),
    ];
    let ds = TextDataset::new(records, &tok, DEFAULT_TEXT_KEY);
    assert_eq!(ds.len(), 3);
    let expected = [vec![3, 4, 2], vec![4, 3, 2], vec![3, 2]];
    for (idx, want) in expected.iter().enumerate() {
        let (tokens, offset) = ds.process(idx).unwrap();
        assert_eq!(&tokens, want);
        assert_eq!(offset, 0);
    }
}
/// A text that already ends in `</s>` must not receive a second EOS.
#[test]
fn text_dataset_does_not_double_append_eos() {
    let tok = tokenizer_fixture("text_no_dup_eos");
    let ds = TextDataset::new(vec![json!({ "text": "hello </s>" })], &tok, DEFAULT_TEXT_KEY);
    let tokens = ds.process(0).unwrap().0;
    assert_eq!(tokens, vec![3, 2]);
}
/// A record without the configured text key yields `Error::MissingKey`.
#[test]
fn text_dataset_missing_field_errors() {
    let tok = tokenizer_fixture("text_missing_field");
    let ds = TextDataset::new(vec![json!({ "not_text": "hello" })], &tok, DEFAULT_TEXT_KEY);
    let payload = match ds.process(0).unwrap_err() {
        Error::MissingKey(payload) => payload,
        other => panic!("expected MissingKey, got: {other:?}"),
    };
    assert_eq!(payload.key(), "jsonl record missing 'text'");
}
/// A non-string `text` value is rejected with `Error::OutOfRange`.
#[test]
fn text_dataset_wrong_type_errors() {
    let tok = tokenizer_fixture("text_wrong_type");
    let ds = TextDataset::new(vec![json!({ "text": 42 })], &tok, DEFAULT_TEXT_KEY);
    let err = ds.process(0).unwrap_err();
    let is_out_of_range = matches!(err, Error::OutOfRange(_));
    assert!(is_out_of_range, "expected OutOfRange, got: {err:?}");
}
/// Without masking, a chat renders via the template to
/// `user : hello assistant : world` with offset 0 and no EOS appended.
#[test]
fn chat_dataset_happy_path_no_mask() {
    let tok = tokenizer_fixture("chat_happy_no_mask");
    let conversation = json!({
        "messages": [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]
    });
    let ds = ChatDataset::new(vec![conversation], &tok, DEFAULT_CHAT_KEY, false);
    assert_eq!(ds.process(0).unwrap(), (vec![5, 7, 3, 6, 7, 4], 0));
}
/// With masking, the offset equals the prompt length: `user : hello
/// assistant : ` is the first 5 tokens.
#[test]
fn chat_dataset_mask_prompt_returns_prefix_offset() {
    let tok = tokenizer_fixture("chat_mask");
    let conversation = json!({
        "messages": [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]
    });
    let ds = ChatDataset::new(vec![conversation], &tok, DEFAULT_CHAT_KEY, true);
    assert_eq!(ds.process(0).unwrap(), (vec![5, 7, 3, 6, 7, 4], 5));
}
/// A chat record without `messages` yields `Error::MissingKey("messages")`.
#[test]
fn chat_dataset_missing_messages_errors() {
    let tok = tokenizer_fixture("chat_missing");
    let records = vec![json!({ "no_messages_field": [] })];
    let ds = ChatDataset::new(records, &tok, DEFAULT_CHAT_KEY, false);
    let payload = match ds.process(0).unwrap_err() {
        Error::MissingKey(payload) => payload,
        other => panic!("expected MissingKey, got: {other:?}"),
    };
    assert_eq!(payload.key(), "messages");
}
/// A `messages` field that is not an array is rejected with `Error::OutOfRange`.
#[test]
fn chat_dataset_messages_not_array_errors() {
    let tok = tokenizer_fixture("chat_not_array");
    let records = vec![json!({ "messages": "not an array" })];
    let ds = ChatDataset::new(records, &tok, DEFAULT_CHAT_KEY, false);
    let err = ds.process(0).unwrap_err();
    let is_out_of_range = matches!(err, Error::OutOfRange(_));
    assert!(is_out_of_range, "expected OutOfRange, got: {err:?}");
}
/// A prompt/completion pair renders as a user/assistant chat; unmasked
/// offset is 0.
#[test]
fn completions_dataset_happy_path_no_mask() {
    let tok = tokenizer_fixture("comp_happy_no_mask");
    let record = json!({ "prompt": "hello", "completion": "world" });
    let ds = CompletionsDataset::new(
        vec![record],
        &tok,
        DEFAULT_PROMPT_KEY,
        DEFAULT_COMPLETION_KEY,
        false,
    );
    assert_eq!(ds.process(0).unwrap(), (vec![5, 7, 3, 6, 7, 4], 0));
}
/// With masking, the completion starts after the 5-token prompt prefix.
#[test]
fn completions_dataset_mask_prompt_returns_prefix_offset() {
    let tok = tokenizer_fixture("comp_mask");
    let record = json!({ "prompt": "hello", "completion": "world" });
    let ds = CompletionsDataset::new(
        vec![record],
        &tok,
        DEFAULT_PROMPT_KEY,
        DEFAULT_COMPLETION_KEY,
        true,
    );
    assert_eq!(ds.process(0).unwrap(), (vec![5, 7, 3, 6, 7, 4], 5));
}
/// A completions record without the prompt key yields `Error::MissingKey`.
#[test]
fn completions_dataset_missing_prompt_errors() {
    let tok = tokenizer_fixture("comp_missing_prompt");
    let ds = CompletionsDataset::new(
        vec![json!({ "completion": "world" })],
        &tok,
        DEFAULT_PROMPT_KEY,
        DEFAULT_COMPLETION_KEY,
        false,
    );
    let payload = match ds.process(0).unwrap_err() {
        Error::MissingKey(payload) => payload,
        other => panic!("expected MissingKey, got: {other:?}"),
    };
    assert_eq!(payload.key(), "jsonl record missing 'prompt'");
}
/// Global indices walk the inner datasets in order: index 0 hits the first
/// dataset, indices 1 and 2 hit the second.
#[test]
fn concatenated_dataset_indexes_across_inner_in_order() {
    let tok = tokenizer_fixture("concat_indexes");
    let first = TextDataset::new(vec![json!({ "text": "hello" })], &tok, DEFAULT_TEXT_KEY);
    let second = TextDataset::new(
        vec![json!({ "text": "world" }), json!({ "text": "hello world" })],
        &tok,
        DEFAULT_TEXT_KEY,
    );
    let cat = ConcatenatedDataset::new(vec![Box::new(first), Box::new(second)]);
    assert_eq!(cat.len(), 3);
    let expected = [vec![3, 2], vec![4, 2], vec![3, 4, 2]];
    for (idx, want) in expected.iter().enumerate() {
        let (tokens, _) = cat.process(idx).unwrap();
        assert_eq!(&tokens, want);
    }
}
/// An index past the combined length is an error, not a panic.
#[test]
fn concatenated_dataset_out_of_range_errors() {
    let tok = tokenizer_fixture("concat_oor");
    let only = TextDataset::new(vec![json!({ "text": "hello" })], &tok, DEFAULT_TEXT_KEY);
    let cat = ConcatenatedDataset::new(vec![Box::new(only)]);
    let outcome = cat.process(7);
    assert!(outcome.is_err());
}
/// Concatenating only empty datasets produces an empty dataset.
#[test]
fn concatenated_dataset_empty_inputs_yield_empty_dataset() {
    let tok = tokenizer_fixture("concat_empty");
    let cat = ConcatenatedDataset::new(vec![
        Box::new(TextDataset::new(vec![], &tok, DEFAULT_TEXT_KEY)),
        Box::new(TextDataset::new(vec![], &tok, DEFAULT_TEXT_KEY)),
    ]);
    assert!(cat.is_empty());
    assert_eq!(cat.len(), 0);
}
/// `CacheDataset` returns the same example on repeated lookups and reports
/// per-item token lengths.
#[test]
fn cache_dataset_returns_consistent_result_on_repeat() {
    let tok = tokenizer_fixture("cache_repeat");
    let inner = TextDataset::new(
        vec![json!({ "text": "hello" }), json!({ "text": "world" })],
        &tok,
        DEFAULT_TEXT_KEY,
    );
    let cache = CacheDataset::new(Box::new(inner));
    assert_eq!(cache.len(), 2);
    let (first, second) = (cache.process(0).unwrap(), cache.process(0).unwrap());
    assert_eq!(first, second);
    // "world" plus the appended EOS.
    assert_eq!(cache.item_len(1).unwrap(), 2);
}
/// An index past the inner length is an error from the cache as well.
#[test]
fn cache_dataset_out_of_range_errors() {
    let tok = tokenizer_fixture("cache_oor");
    let cache = CacheDataset::new(Box::new(TextDataset::new(
        vec![json!({ "text": "hello" })],
        &tok,
        DEFAULT_TEXT_KEY,
    )));
    assert!(cache.process(99).is_err());
}
/// Loading a text JSONL file yields one example per line.
#[test]
fn load_dataset_text() {
    let tok = tokenizer_fixture("load_text");
    let path = fresh_dir("load_text_data").join("train.jsonl");
    write_jsonl(&path, &[json!({ "text": "hello" }), json!({ "text": "world" })]);
    let ds = load_dataset(&path, &tok, DatasetType::Text, &DatasetConfig::default()).unwrap();
    assert_eq!(ds.len(), 2);
    for (idx, want) in [vec![3, 2], vec![4, 2]].iter().enumerate() {
        let (tokens, _) = ds.process(idx).unwrap();
        assert_eq!(&tokens, want);
    }
}
/// Loading a chat JSONL file renders each conversation through the template.
#[test]
fn load_dataset_chat() {
    let tok = tokenizer_fixture("load_chat");
    let path = fresh_dir("load_chat_data").join("train.jsonl");
    let conversation = json!({
        "messages": [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]
    });
    write_jsonl(&path, &[conversation]);
    let ds = load_dataset(&path, &tok, DatasetType::Chat, &DatasetConfig::default()).unwrap();
    assert_eq!(ds.len(), 1);
    let (tokens, _) = ds.process(0).unwrap();
    assert_eq!(tokens, vec![5, 7, 3, 6, 7, 4]);
}
/// Loading a completions JSONL file renders prompt/completion as a chat.
#[test]
fn load_dataset_completions() {
    let tok = tokenizer_fixture("load_comp");
    let path = fresh_dir("load_comp_data").join("train.jsonl");
    write_jsonl(&path, &[json!({ "prompt": "hello", "completion": "world" })]);
    let cfg = DatasetConfig::default();
    let ds = load_dataset(&path, &tok, DatasetType::Completions, &cfg).unwrap();
    assert_eq!(ds.len(), 1);
    let (tokens, _) = ds.process(0).unwrap();
    assert_eq!(tokens, vec![5, 7, 3, 6, 7, 4]);
}
/// Two separately loaded files can be concatenated into one dataset.
#[test]
fn load_dataset_concatenated() {
    let tok = tokenizer_fixture("load_concat");
    let dir = fresh_dir("load_concat_data");
    let (train, valid) = (dir.join("train.jsonl"), dir.join("valid.jsonl"));
    write_jsonl(&train, &[json!({ "text": "hello" })]);
    write_jsonl(&valid, &[json!({ "text": "world" })]);
    let cfg = DatasetConfig::default();
    let a = load_dataset(&train, &tok, DatasetType::Text, &cfg).unwrap();
    let b = load_dataset(&valid, &tok, DatasetType::Text, &cfg).unwrap();
    let cat = ConcatenatedDataset::new(vec![Box::new(a), Box::new(b)]);
    assert_eq!(cat.len(), 2);
    for (idx, want) in [vec![3, 2], vec![4, 2]].iter().enumerate() {
        let (tokens, _) = cat.process(idx).unwrap();
        assert_eq!(&tokens, want);
    }
}
/// A dataset loaded with `DatasetType::Auto` answers repeated lookups of the
/// same index identically.
#[test]
fn load_dataset_cache() {
    let tok = tokenizer_fixture("load_cache");
    let path = fresh_dir("load_cache_data").join("train.jsonl");
    write_jsonl(&path, &[json!({ "text": "hello" })]);
    let ds = load_dataset(&path, &tok, DatasetType::Auto, &DatasetConfig::default()).unwrap();
    let runs = [ds.process(0).unwrap(), ds.process(0).unwrap()];
    assert_eq!(runs[0], runs[1]);
}
/// When a record has `prompt`/`completion` and also `text`, auto-detection
/// prefers the completions format.
#[test]
fn load_dataset_auto_detects_completions_first() {
    let tok = tokenizer_fixture("load_auto_comp");
    let path = fresh_dir("load_auto_comp_data").join("train.jsonl");
    let record = json!({ "prompt": "hello", "completion": "world", "text": "ignored" });
    write_jsonl(&path, &[record]);
    let ds = load_dataset(&path, &tok, DatasetType::Auto, &DatasetConfig::default()).unwrap();
    assert_eq!(ds.len(), 1);
    let (tokens, _) = ds.process(0).unwrap();
    assert_eq!(tokens, vec![5, 7, 3, 6, 7, 4]);
}
/// Records with none of the recognised keys fail auto-detection.
#[test]
fn load_dataset_auto_unsupported_format_errors() {
    let tok = tokenizer_fixture("load_auto_bad");
    let path = fresh_dir("load_auto_bad_data").join("train.jsonl");
    write_jsonl(&path, &[json!({ "irrelevant": "junk" })]);
    let cfg = DatasetConfig::default();
    let message = load_dataset(&path, &tok, DatasetType::Auto, &cfg)
        .unwrap_err()
        .to_string();
    assert!(message.contains("Unsupported data format"), "got: {message}");
}
/// `hf://` URIs are refused rather than treated as local paths.
#[test]
fn load_dataset_rejects_hf_hub_path() {
    let tok = tokenizer_fixture("load_hf");
    let uri = PathBuf::from("hf://datasets/mlx-community/some-dataset");
    let outcome = load_dataset(&uri, &tok, DatasetType::Auto, &DatasetConfig::default());
    let payload = match outcome.unwrap_err() {
        Error::OutOfRange(payload) => payload,
        other => panic!("expected OutOfRange, got: {other:?}"),
    };
    assert!(payload.context().contains("HF Hub URI rejected"), "got: {payload:?}");
}
/// Prompt masking makes no sense for plain text, so it is rejected.
#[test]
fn load_dataset_text_with_mask_prompt_errors() {
    let tok = tokenizer_fixture("load_text_mask_err");
    let path = fresh_dir("load_text_mask_err_data").join("train.jsonl");
    write_jsonl(&path, &[json!({ "text": "hello" })]);
    let masked = DatasetConfig::new().with_mask_prompt(true);
    let message = load_dataset(&path, &tok, DatasetType::Text, &masked)
        .unwrap_err()
        .to_string();
    assert!(message.contains("not supported for text dataset"), "got: {message}");
}
/// An empty dataset file is rejected up front with `Error::EmptyInput`
/// instead of producing an empty dataset, and the message says "empty".
///
/// NOTE(review): despite the name, this test does not check that the path
/// appears in the message; add that assertion once the message format is
/// confirmed.
#[test]
fn load_dataset_empty_file_errors_with_path() {
    let tok = tokenizer_fixture("load_empty");
    let dir = fresh_dir("load_empty_data");
    let p = dir.join("train.jsonl");
    std::fs::write(&p, "").unwrap();
    let err = load_dataset(&p, &tok, DatasetType::Auto, &DatasetConfig::default()).unwrap_err();
    assert!(
        matches!(err, Error::EmptyInput(_)),
        "expected EmptyInput rejection, got: {err:?}",
    );
    let s = err.to_string();
    assert!(s.contains("empty"), "expected 'empty' in error message, got: {s}");
}
/// A blank line in the middle of a JSONL file is rejected with
/// `Error::EmptyInput` mentioning "blank".
#[test]
fn load_dataset_blank_line_errors_with_line_number() {
    let tok = tokenizer_fixture("load_blank");
    let path = fresh_dir("load_blank_data").join("train.jsonl");
    // Line 2 is empty.
    std::fs::write(&path, "{\"text\": \"hello\"}\n\n{\"text\": \"world\"}\n").unwrap();
    let err = load_dataset(&path, &tok, DatasetType::Text, &DatasetConfig::default()).unwrap_err();
    let is_empty_input = matches!(err, Error::EmptyInput(_));
    assert!(is_empty_input, "expected EmptyInput on blank line, got: {err:?}");
    let message = err.to_string();
    assert!(message.contains("blank"), "expected 'blank' in blank-line error, got: {message}");
}
/// Pointing the loader at a directory is rejected as "not a regular file".
#[test]
fn load_dataset_rejects_non_regular_file() {
    let tok = tokenizer_fixture("load_dir");
    let not_a_file = fresh_dir("load_dir_data");
    let cfg = DatasetConfig::default();
    let message = load_dataset(&not_a_file, &tok, DatasetType::Auto, &cfg)
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("not a regular file"),
        "expected non-regular-file rejection, got: {message}",
    );
}
/// The byte cap is enforced while reading line by line. Three 16-byte
/// records (48 bytes) against a 40-byte cap must yield `Error::CapExceeded`.
#[test]
fn load_dataset_cap_enforced_during_read_loop() {
    use std::io::Cursor;
    let cap: u64 = 40;
    let body = "{\"text\": \"aaa\"}\n{\"text\": \"bbb\"}\n{\"text\": \"ccc\"}\n";
    let path = PathBuf::from("/synthetic/grows.jsonl");
    let err = read_jsonl_with_cap(Cursor::new(body), &path, cap).unwrap_err();
    let Error::CapExceeded(payload) = &err else {
        panic!("expected CapExceeded, got: {err:?}");
    };
    assert_eq!(payload.cap(), 40);
    assert_eq!(payload.cap_name(), "MAX_DATASET_FILE_BYTES");
    assert!(payload.observed() > 40, "observed must exceed cap, got: {payload:?}");
    assert!(
        payload.context().contains("read jsonl"),
        "expected read-jsonl context, got: {payload:?}"
    );
}
/// A line that is not valid JSON is reported as `Error::Parse`.
#[test]
fn load_dataset_malformed_line_errors_with_line_number() {
    let tok = tokenizer_fixture("load_malformed");
    let path = fresh_dir("load_malformed_data").join("train.jsonl");
    let contents = "{\"text\": \"hello\"}\n{this is not json}\n{\"text\": \"world\"}\n";
    std::fs::write(&path, contents).unwrap();
    let err = load_dataset(&path, &tok, DatasetType::Text, &DatasetConfig::default()).unwrap_err();
    let is_parse = matches!(err, Error::Parse(_));
    assert!(is_parse, "expected Parse error on malformed line, got: {err:?}");
}
/// A missing file surfaces as `Error::FileIo` wrapping `NotFound`.
#[test]
fn load_dataset_nonexistent_path_errors() {
    let tok = tokenizer_fixture("load_nopath");
    let pid = std::process::id();
    let missing = std::env::temp_dir().join(format!("mlxrs-a6-does-not-exist-{pid}.jsonl"));
    let outcome = load_dataset(&missing, &tok, DatasetType::Auto, &DatasetConfig::default());
    let payload = match outcome.unwrap_err() {
        Error::FileIo(payload) => payload,
        other => panic!("expected FileIo NotFound, got: {other:?}"),
    };
    assert_eq!(payload.inner().kind(), std::io::ErrorKind::NotFound);
}
/// Rewriting a dataset file in place must be picked up by the next
/// `load_dataset` call.
///
/// The first load sees `hello` (`[3, 2]`). After the file is rewritten with
/// `world`, the second load must see `[4, 2]`.
///
/// NOTE(review): both files have the same byte length (`hello` and `world`
/// are both 5 characters), and only a 10 ms sleep separates the writes. If
/// any cache behind `load_dataset` is keyed on (size, mtime), a filesystem
/// with coarse mtime resolution could leave the mtime unchanged and make this
/// test flaky. Confirm against the implementation.
#[test]
fn cache_dataset_invalidates_on_source_mtime_change() {
    let tok = tokenizer_fixture("cache_mtime");
    let dir = fresh_dir("cache_mtime_data");
    let p = dir.join("train.jsonl");
    write_jsonl(&p, &[json!({ "text": "hello" })]);
    // Each load runs in its own scope, so the dataset is dropped before the
    // file is rewritten.
    let first = {
        let ds = load_dataset(&p, &tok, DatasetType::Text, &DatasetConfig::default()).unwrap();
        ds.process(0).unwrap()
    };
    assert_eq!(first.0, vec![3, 2]);
    // Give the filesystem a chance to record a different modification time.
    std::thread::sleep(std::time::Duration::from_millis(10));
    write_jsonl(&p, &[json!({ "text": "world" })]);
    let second = {
        let ds = load_dataset(&p, &tok, DatasetType::Text, &DatasetConfig::default()).unwrap();
        ds.process(0).unwrap()
    };
    assert_eq!(second.0, vec![4, 2]);
    assert_ne!(first.0, second.0);
}
/// A FIFO with no writer at the dataset path must be rejected promptly as
/// "not a regular file" (`Error::FileIo`), not block the caller.
///
/// The load runs on a worker thread, and the main thread waits at most two
/// seconds for its result, so a hang shows up as a test failure instead of a
/// stuck test run.
///
/// On the timeout path the worker is not joined; it stays blocked. Cleanup
/// of `dir` only runs after the `match` (or in the timeout branch), so a
/// failing assertion in the success branch leaves the directory behind.
#[cfg(unix)]
#[test]
fn load_dataset_rejects_fifo_without_blocking() {
    use std::{os::unix::ffi::OsStrExt, sync::mpsc};
    let dir = fresh_dir("load_fifo");
    let tok = tokenizer_fixture("load_fifo_tok");
    let path = dir.join("train.jsonl");
    let c_path = std::ffi::CString::new(path.as_os_str().as_bytes()).unwrap();
    // SAFETY: `c_path` is a valid NUL-terminated C string (CString::new
    // succeeded) and it outlives the `mkfifo` call.
    let rc = unsafe { libc::mkfifo(c_path.as_ptr(), 0o600) };
    assert_eq!(rc, 0, "mkfifo failed (rc {rc})");
    let (tx, rx) = mpsc::channel();
    let handle = std::thread::spawn(move || {
        let r = load_dataset(&path, &tok, DatasetType::Auto, &DatasetConfig::default());
        // Only the FileIo message is needed; every other outcome maps to None.
        let msg = match &r {
            Err(Error::FileIo(p)) => Some(p.to_string()),
            _ => None,
        };
        let _ = tx.send(msg);
    });
    match rx.recv_timeout(std::time::Duration::from_secs(2)) {
        Ok(Some(msg)) => {
            handle.join().unwrap();
            assert!(
                msg.contains("not a regular file"),
                "FIFO at dataset path must yield 'not a regular file' \
                 rejection, got: {msg}",
            );
        }
        Ok(None) => {
            handle.join().unwrap();
            panic!(
                "FIFO at dataset path must yield Err(FileIo), got a \
                 different result"
            );
        }
        Err(_) => {
            // The worker is presumably still blocked inside open(); it is
            // deliberately left unjoined.
            std::fs::remove_dir_all(&dir).ok();
            panic!(
                "load_dataset HUNG on a writer-less FIFO — the O_NONBLOCK \
                 open regressed"
            );
        }
    }
    std::fs::remove_dir_all(&dir).ok();
}
/// A single line larger than the cap is rejected without reading far past
/// the cap.
///
/// The source is 100 bytes with no newline. The `CountingReader` checks that
/// at most `cap + 1` bytes were taken from it. A final sanity check confirms
/// that a small record without a trailing newline still parses under the
/// same cap.
#[test]
fn load_dataset_cap_enforced_on_single_giant_line() {
    use std::io::{BufReader, Cursor};
    let cap: u64 = 40;
    let path = PathBuf::from("/synthetic/giant.jsonl");
    let giant: Vec<u8> = vec![b'a'; 100];
    let mut counting = CountingReader::new(BufReader::new(Cursor::new(giant)));
    let err = read_jsonl_with_cap(&mut counting, &path, cap).unwrap_err();
    let Error::CapExceeded(payload) = &err else {
        panic!("expected CapExceeded, got: {err:?}");
    };
    assert_eq!(payload.cap(), 40);
    assert_eq!(payload.cap_name(), "MAX_DATASET_FILE_BYTES");
    assert!(
        payload.observed() >= cap,
        "observed must be at or above cap, got: {payload:?}"
    );
    assert!(
        payload.context().contains("read jsonl"),
        "expected read-jsonl context, got: {payload:?}"
    );

    let consumed = counting.consumed();
    let limit = cap as usize + 1;
    assert!(
        consumed <= limit,
        "take(remaining + 1) allocation cap violated: consumed {consumed} bytes from a 100-byte \
         fixture with cap={cap} (expected <= {limit}); a lines() impl would consume 100",
    );

    let parsed = read_jsonl_with_cap(Cursor::new(b"{\"text\":\"abc\"}".to_vec()), &path, cap).unwrap();
    assert_eq!(parsed.len(), 1);
    assert_eq!(parsed[0]["text"].as_str(), Some("abc"));
}
/// `Debug` for `TextDataset` shows the type name, length and text key.
#[test]
fn text_dataset_debug_reports_len_and_key() {
    let tok = tokenizer_fixture("text_debug");
    let ds = TextDataset::new(
        vec![json!({ "text": "hello" }), json!({ "text": "world" })],
        &tok,
        "body",
    );
    let rendered = format!("{ds:?}");
    for needle in ["TextDataset", "len: 2", "text_key: \"body\""] {
        assert!(rendered.contains(needle), "got: {rendered}");
    }
}
/// `Debug` for `ChatDataset` shows length, chat key and mask flag.
#[test]
fn chat_dataset_debug_reports_len_key_and_mask() {
    let tok = tokenizer_fixture("chat_debug");
    let ds = ChatDataset::new(vec![json!({ "messages": [] })], &tok, "msgs", true);
    let rendered = format!("{ds:?}");
    for needle in ["ChatDataset", "len: 1", "chat_key: \"msgs\"", "mask_prompt: true"] {
        assert!(rendered.contains(needle), "got: {rendered}");
    }
}
/// `Debug` for `CompletionsDataset` shows every configuration field.
#[test]
fn completions_dataset_debug_reports_all_fields() {
    let tok = tokenizer_fixture("comp_debug");
    let ds = CompletionsDataset::new(
        vec![json!({ "prompt": "p", "completion": "c" })],
        &tok,
        "pk",
        "ck",
        false,
    );
    let rendered = format!("{ds:?}");
    let needles = [
        "CompletionsDataset",
        "len: 1",
        "prompt_key: \"pk\"",
        "completion_key: \"ck\"",
        "mask_prompt: false",
    ];
    for needle in needles {
        assert!(rendered.contains(needle), "got: {rendered}");
    }
}
/// `Debug` for `ConcatenatedDataset` shows the inner count and total length.
#[test]
fn concatenated_dataset_debug_reports_inner_count_and_len() {
    let tok = tokenizer_fixture("concat_debug");
    let first = TextDataset::new(vec![json!({ "text": "hello" })], &tok, DEFAULT_TEXT_KEY);
    let second = TextDataset::new(
        vec![json!({ "text": "world" }), json!({ "text": "hello" })],
        &tok,
        DEFAULT_TEXT_KEY,
    );
    let cat = ConcatenatedDataset::new(vec![Box::new(first), Box::new(second)]);
    let rendered = format!("{cat:?}");
    for needle in ["ConcatenatedDataset", "inner_count: 2", "len: 3"] {
        assert!(rendered.contains(needle), "got: {rendered}");
    }
}
/// `Debug` for `CacheDataset` reports how many entries have been cached:
/// 0 before any lookup, 1 after processing one index.
#[test]
fn cache_dataset_debug_reports_len_and_cached_count() {
    let tok = tokenizer_fixture("cache_debug");
    let inner = TextDataset::new(
        vec![json!({ "text": "hello" }), json!({ "text": "world" })],
        &tok,
        DEFAULT_TEXT_KEY,
    );
    let cache = CacheDataset::new(Box::new(inner));
    let before = format!("{cache:?}");
    for needle in ["CacheDataset", "len: 2", "cached_count: Some(0)"] {
        assert!(before.contains(needle), "got: {before}");
    }
    let _ = cache.process(0).unwrap();
    let after = format!("{cache:?}");
    assert!(after.contains("cached_count: Some(1)"), "got: {after}");
}
/// `ChatDataset::get` past the end yields a descriptive `Error::OutOfRange`.
#[test]
fn chat_dataset_get_out_of_range_errors() {
    let tok = tokenizer_fixture("chat_get_oor");
    let ds = ChatDataset::new(vec![json!({ "messages": [] })], &tok, DEFAULT_CHAT_KEY, false);
    assert!(ds.get(0).is_ok());
    let payload = match ds.get(1).unwrap_err() {
        Error::OutOfRange(payload) => payload,
        other => panic!("expected OutOfRange, got: {other:?}"),
    };
    assert_eq!(payload.context(), "ChatDataset: index");
    assert!(payload.value().contains("1 (len=1)"), "got: {payload:?}");
}
/// `CompletionsDataset::get` past the end yields a descriptive
/// `Error::OutOfRange`.
#[test]
fn completions_dataset_get_out_of_range_errors() {
    let tok = tokenizer_fixture("comp_get_oor");
    let ds = CompletionsDataset::new(
        vec![json!({ "prompt": "p", "completion": "c" })],
        &tok,
        DEFAULT_PROMPT_KEY,
        DEFAULT_COMPLETION_KEY,
        false,
    );
    assert!(ds.get(0).is_ok());
    let payload = match ds.get(5).unwrap_err() {
        Error::OutOfRange(payload) => payload,
        other => panic!("expected OutOfRange, got: {other:?}"),
    };
    assert_eq!(payload.context(), "CompletionsDataset: index");
    assert!(payload.value().contains("5 (len=1)"), "got: {payload:?}");
}
/// `ConcatenatedDataset::get` returns raw records from the right inner
/// dataset and rejects indices past the combined length.
#[test]
fn concatenated_dataset_get_routes_to_inner() {
    let tok = tokenizer_fixture("concat_get");
    let first = TextDataset::new(vec![json!({ "text": "hello" })], &tok, DEFAULT_TEXT_KEY);
    let second = TextDataset::new(
        vec![json!({ "text": "world" }), json!({ "text": "extra" })],
        &tok,
        DEFAULT_TEXT_KEY,
    );
    let cat = ConcatenatedDataset::new(vec![Box::new(first), Box::new(second)]);
    for (idx, want) in ["hello", "world", "extra"].iter().enumerate() {
        assert_eq!(cat.get(idx).unwrap()["text"].as_str(), Some(*want));
    }
    let payload = match cat.get(3).unwrap_err() {
        Error::OutOfRange(payload) => payload,
        other => panic!("expected OutOfRange, got: {other:?}"),
    };
    assert_eq!(payload.context(), "ConcatenatedDataset: index");
    assert!(payload.value().contains("3 (len=3)"), "got: {payload:?}");
}
/// `CacheDataset::get` forwards to the wrapped dataset.
#[test]
fn cache_dataset_get_routes_to_inner() {
    let tok = tokenizer_fixture("cache_get");
    let cache = CacheDataset::new(Box::new(TextDataset::new(
        vec![json!({ "text": "hello" })],
        &tok,
        DEFAULT_TEXT_KEY,
    )));
    assert_eq!(cache.get(0).unwrap()["text"].as_str(), Some("hello"));
    assert!(cache.get(1).is_err());
}
struct OverreachingDataset;
impl Dataset for OverreachingDataset {
fn len(&self) -> usize {
0
}
fn get(&self, _idx: usize) -> Result<&Value> {
Err(Error::OutOfRange(OutOfRangePayload::new(
"OverreachingDataset: get",
"unused",
"0",
)))
}
fn process(&self, _idx: usize) -> Result<Example> {
Ok((vec![1, 2, 3], 0))
}
}
/// Even if the inner dataset "processes" an index past its own `len()`,
/// `CacheDataset` must reject it.
///
/// `OverreachingDataset::process(0)` succeeds, so this rejection has to come
/// from the cache itself.
#[test]
fn cache_dataset_process_rejects_index_past_presized_cache() {
    let cache = CacheDataset::new(Box::new(OverreachingDataset));
    assert_eq!(cache.len(), 0);
    let payload = match cache.process(0).unwrap_err() {
        Error::OutOfRange(payload) => payload,
        other => panic!("expected OutOfRange from post-compute cache guard, got: {other:?}"),
    };
    assert_eq!(payload.context(), "CacheDataset: index");
    assert!(payload.value().contains("0 (len=0)"), "got: {payload:?}");
}
/// `as_str` and `Display` agree on the lowercase name of every variant.
#[test]
fn dataset_type_as_str_and_display_cover_all_variants() {
    let cases = [
        (DatasetType::Text, "text"),
        (DatasetType::Chat, "chat"),
        (DatasetType::Completions, "completions"),
        (DatasetType::Auto, "auto"),
    ];
    for (variant, name) in cases {
        assert_eq!(variant.to_string(), name);
        assert_eq!(variant.as_str(), name);
    }
}
/// Each `with_*` builder overrides exactly its own field.
///
/// The defaults are `text`/`messages`/`prompt`/`completion` with masking off.
#[test]
fn dataset_config_builders_override_each_feature() {
    let custom = DatasetConfig::new()
        .with_text_feature("body")
        .with_chat_feature("turns")
        .with_prompt_feature("question")
        .with_completion_feature("answer")
        .with_mask_prompt(true);
    assert_eq!(
        [
            custom.text_feature(),
            custom.chat_feature(),
            custom.prompt_feature(),
            custom.completion_feature(),
        ],
        ["body", "turns", "question", "answer"],
    );
    assert!(custom.mask_prompt());

    let defaults = DatasetConfig::default();
    assert_eq!(
        [
            defaults.text_feature(),
            defaults.chat_feature(),
            defaults.prompt_feature(),
            defaults.completion_feature(),
        ],
        ["text", "messages", "prompt", "completion"],
    );
    assert!(!defaults.mask_prompt());
}
/// `create_dataset` reads from the feature keys configured in
/// `DatasetConfig`, for each explicit dataset type.
#[test]
fn create_dataset_honors_custom_features_per_type() {
    let tok = tokenizer_fixture("create_custom");

    // Text records under a custom "body" key.
    let text_cfg = DatasetConfig::new().with_text_feature("body");
    let text_ds =
        create_dataset(vec![json!({ "body": "hello" })], &tok, &text_cfg, DatasetType::Text)
            .unwrap();
    assert_eq!(text_ds.len(), 1);
    assert_eq!(text_ds.process(0).unwrap(), (vec![3, 2], 0));

    // Chat with the default "messages" key.
    let chat_records = vec![json!({ "messages": [{"role": "user", "content": "hello"}] })];
    let chat_ds =
        create_dataset(chat_records, &tok, &DatasetConfig::default(), DatasetType::Chat).unwrap();
    assert_eq!(chat_ds.len(), 1);
    assert_eq!(chat_ds.process(0).unwrap(), (vec![5, 7, 3], 0));

    // Completions under custom "q"/"a" keys.
    let comp_cfg = DatasetConfig::new()
        .with_prompt_feature("q")
        .with_completion_feature("a");
    let comp_records = vec![json!({ "q": "hello", "a": "world" })];
    let comp_ds =
        create_dataset(comp_records, &tok, &comp_cfg, DatasetType::Completions).unwrap();
    assert_eq!(comp_ds.len(), 1);
    assert_eq!(comp_ds.process(0).unwrap(), (vec![5, 7, 3, 6, 7, 4], 0));
}
/// Auto-detection has nothing to inspect in an empty record list, so it
/// fails with `Error::EmptyInput`.
#[test]
fn create_dataset_auto_empty_records_errors() {
    let tok = tokenizer_fixture("auto_empty");
    let outcome = create_dataset(vec![], &tok, &DatasetConfig::default(), DatasetType::Auto);
    // `match` instead of `unwrap_err`, which would require the dataset to be `Debug`.
    let err = match outcome {
        Ok(_) => panic!("expected EmptyInput error from empty auto-detect, got Ok(dataset)"),
        Err(err) => err,
    };
    let payload = match err {
        Error::EmptyInput(payload) => payload,
        other => panic!("expected EmptyInput, got: {other:?}"),
    };
    assert!(
        payload.context().contains("auto-detection"),
        "expected auto-detection context, got: {payload:?}"
    );
}
/// Records that only carry `messages` are auto-detected as chat.
#[test]
fn create_dataset_auto_detects_chat_when_only_messages_present() {
    let tok = tokenizer_fixture("auto_chat");
    let records = vec![json!({ "messages": [{"role": "user", "content": "hello"}] })];
    let ds = create_dataset(records, &tok, &DatasetConfig::default(), DatasetType::Auto).unwrap();
    assert_eq!(ds.len(), 1);
    assert_eq!(ds.process(0).unwrap(), (vec![5, 7, 3], 0));
}
/// Records that only carry `text` are auto-detected as plain text.
#[test]
fn create_dataset_auto_detects_text_when_only_text_present() {
    let tok = tokenizer_fixture("auto_text");
    let records = vec![json!({ "text": "hello world" })];
    let ds = create_dataset(records, &tok, &DatasetConfig::default(), DatasetType::Auto).unwrap();
    assert_eq!(ds.len(), 1);
    assert_eq!(ds.process(0).unwrap(), (vec![3, 4, 2], 0));
}
/// An array in the `text` field is rejected, and the error names the JSON
/// kind it actually found.
#[test]
fn text_dataset_array_valued_field_reports_array_kind() {
    let tok = tokenizer_fixture("text_array_kind");
    let ds = TextDataset::new(vec![json!({ "text": [1, 2, 3] })], &tok, DEFAULT_TEXT_KEY);
    let payload = match ds.process(0).unwrap_err() {
        Error::OutOfRange(payload) => payload,
        other => panic!("expected OutOfRange, got: {other:?}"),
    };
    assert_eq!(payload.requirement(), "field must be a JSON string");
    assert!(payload.value().contains("=array"), "got: {payload:?}");
}
/// For every non-array JSON kind, the `messages` error names that kind
/// (e.g. `messages=null`).
#[test]
fn chat_dataset_non_array_messages_report_json_kind() {
    let tok = tokenizer_fixture("chat_kinds");
    let cases = [
        (json!(null), "null"),
        (json!(true), "bool"),
        (json!(7), "number"),
        (json!({ "a": 1 }), "object"),
    ];
    for (messages, kind) in cases {
        let records = vec![json!({ "messages": messages })];
        let ds = ChatDataset::new(records, &tok, DEFAULT_CHAT_KEY, false);
        let payload = match ds.process(0).unwrap_err() {
            Error::OutOfRange(payload) => payload,
            other => panic!("expected OutOfRange for kind {kind}, got: {other:?}"),
        };
        let needle = format!("messages={kind}");
        assert!(
            payload.value().contains(needle.as_str()),
            "expected json_kind '{kind}', got: {payload:?}"
        );
    }
}
/// A body whose length equals the cap exactly is accepted.
#[test]
fn read_jsonl_with_cap_exact_cap_at_eof_succeeds() {
    use std::io::Cursor;
    let body = "{\"text\":\"a\"}\n";
    let cap = body.len() as u64;
    let path = PathBuf::from("/synthetic/exact.jsonl");
    let records = read_jsonl_with_cap(Cursor::new(body), &path, cap).unwrap();
    assert_eq!(records.len(), 1);
    assert_eq!(records[0]["text"].as_str(), Some("a"));
}
/// Once the cap is fully spent on the first line, any further byte trips
/// the post-cap probe. The error reports `observed == cap + 1`.
#[test]
fn read_jsonl_with_cap_zero_remaining_with_pending_bytes_errors() {
    use std::io::Cursor;
    let first_line = "{\"text\":\"a\"}\n";
    let cap = first_line.len() as u64;
    let body = format!("{first_line}{{\"text\":\"b\"}}\n");
    let path = PathBuf::from("/synthetic/pending.jsonl");
    let err = read_jsonl_with_cap(Cursor::new(body), &path, cap).unwrap_err();
    let Error::CapExceeded(payload) = &err else {
        panic!("expected CapExceeded (post-cap probe), got: {err:?}");
    };
    assert_eq!(payload.cap(), cap);
    assert_eq!(payload.cap_name(), "MAX_DATASET_FILE_BYTES");
    assert_eq!(payload.observed(), cap + 1);
    assert!(
        payload.context().contains("more bytes remained"),
        "expected post-cap probe context, got: {payload:?}"
    );
}
/// Windows-style `\r\n` line endings are accepted.
#[test]
fn read_jsonl_with_cap_strips_crlf_line_ending() {
    use std::io::Cursor;
    let path = PathBuf::from("/synthetic/crlf.jsonl");
    let records = read_jsonl_with_cap(Cursor::new("{\"text\":\"a\"}\r\n"), &path, 1000).unwrap();
    assert_eq!(records.len(), 1);
    assert_eq!(records[0]["text"].as_str(), Some("a"));
}
/// A line that is not valid UTF-8 is reported as `Error::Parse`.
#[test]
fn read_jsonl_with_cap_invalid_utf8_line_errors_parse() {
    use std::io::Cursor;
    let path = PathBuf::from("/synthetic/badutf8.jsonl");
    let bytes: Vec<u8> = vec![0xFF, 0xFE, b'\n'];
    let payload = match read_jsonl_with_cap(Cursor::new(bytes), &path, 1000).unwrap_err() {
        Error::Parse(payload) => payload,
        other => panic!("expected Parse (invalid UTF-8), got: {other:?}"),
    };
    assert!(
        payload.context().contains("not valid UTF-8"),
        "expected UTF-8 parse context, got: {payload:?}"
    );
}
/// A reader whose every read fails with an `ErrorKind::Other` error.
///
/// The tests use it to check that I/O failures inside `read_jsonl_with_cap`
/// surface as `Error::FileIo`.
struct ErroringReader;

impl std::io::Read for ErroringReader {
    fn read(&mut self, _: &mut [u8]) -> std::io::Result<usize> {
        Err(std::io::Error::other("synthetic read failure"))
    }
}

impl std::io::BufRead for ErroringReader {
    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
        Err(std::io::Error::other("synthetic fill_buf failure"))
    }

    /// Nothing is ever buffered, so there is nothing to consume.
    fn consume(&mut self, _: usize) {}
}
/// An I/O error from the underlying reader is wrapped as
/// `Error::FileIo` with `FileOp::Read` and a `read_until` context.
#[test]
fn read_jsonl_with_cap_read_until_error_surfaces_fileio() {
    let path = PathBuf::from("/synthetic/ioerror.jsonl");
    let payload = match read_jsonl_with_cap(ErroringReader, &path, 100).unwrap_err() {
        Error::FileIo(payload) => payload,
        other => panic!("expected FileIo(Read) from read_until error, got: {other:?}"),
    };
    assert_eq!(payload.op(), FileOp::Read);
    assert!(
        payload.context().contains("read_until"),
        "expected read_until context, got: {payload:?}"
    );
}
/// An unreadable dataset file (mode 0o000) fails at open time with
/// `Error::FileIo` carrying `FileOp::Open` and `PermissionDenied`.
///
/// If the process can still open the file despite mode 0o000 (e.g. it runs
/// with elevated privileges), the precondition cannot be set up and the test
/// returns early without asserting.
#[cfg(unix)]
#[test]
fn load_dataset_open_permission_denied_errors() {
    use std::os::unix::fs::PermissionsExt;
    let tok = tokenizer_fixture("load_perm_tok");
    let dir = fresh_dir("load_perm");
    let p = dir.join("train.jsonl");
    std::fs::write(&p, "{\"text\": \"hello\"}\n").unwrap();
    std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o000)).unwrap();
    // Probe whether permission bits are actually enforced for this process.
    let readable = std::fs::File::open(&p).is_ok();
    let res = load_dataset(&p, &tok, DatasetType::Auto, &DatasetConfig::default());
    // Restore owner read/write and clean up before asserting, so a failing
    // assertion does not leave an unreadable file behind.
    let _ = std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600));
    std::fs::remove_dir_all(&dir).ok();
    if readable {
        return;
    }
    match res {
        Err(Error::FileIo(payload)) => {
            assert_eq!(payload.op(), FileOp::Open);
            assert_eq!(payload.inner().kind(), std::io::ErrorKind::PermissionDenied);
        }
        other => panic!("expected FileIo(Open, PermissionDenied), got: {other:?}"),
    }
}