1use crate::byte_dict::{ByteDictionary, CategoryOrdering};
40use crate::detcoll::SortedVecMap;
41use crate::{Column, DataFrame, TidyError, TidyView};
42use cjc_repro::Rng;
43use cjc_runtime::tensor::Tensor;
44
/// Errors produced while validating a [`DatasetPlan`] or materializing its batches.
#[derive(Debug, Clone, PartialEq)]
pub enum DatasetError {
    /// A feature, label, or encoded column name is not present in the source.
    UnknownColumn(String),
    /// A column with no registered encoding has a type that cannot be encoded
    /// implicitly (only `Float`, `Int` and `Bool` columns can).
    UnsupportedColumnType {
        column: String,
        type_name: &'static str,
    },
    /// A registered encoding does not fit the column's actual type
    /// (e.g. `Categorical` on a numeric column).
    EncodingMismatch {
        column: String,
        encoding: &'static str,
        column_type: &'static str,
    },
    /// A null was found in a `CategoricalAdaptive` column that is categorically
    /// encoded; `row` is the offending row index.
    NullCategorical {
        column: String,
        row: u32,
    },
    /// The requested split has no rows (reported by `DatasetPlan::iter_batches`).
    EmptySplit(Split),
    /// Split fractions are outside `[0, 1]` or sum to more than 1.
    InvalidFractions {
        train: f64,
        val: f64,
        test: f64,
    },
    /// `BatchSpec::batch_size` was 0.
    BadBatchSize(usize),
    /// The plan has no feature columns.
    NoFeatures,
    /// An encoding was registered for a column that is neither a feature nor the label.
    OrphanEncoding(String),
    /// Error from the tidy layer (its `Debug` rendering), or an internal encoding
    /// failure such as a dictionary intern/lookup problem.
    Tidy(String),
    /// A feature or label tensor could not be built with the expected shape.
    Shape(String),
}
82
83impl std::fmt::Display for DatasetError {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 DatasetError::UnknownColumn(c) => write!(f, "unknown column `{c}`"),
87 DatasetError::UnsupportedColumnType { column, type_name } => write!(
88 f,
89 "column `{column}` has type `{type_name}` which is not supported"
90 ),
91 DatasetError::EncodingMismatch {
92 column,
93 encoding,
94 column_type,
95 } => write!(
96 f,
97 "column `{column}` (type `{column_type}`) cannot be encoded as `{encoding}`"
98 ),
99 DatasetError::NullCategorical { column, row } => {
100 write!(f, "null value in categorical column `{column}` at row {row}")
101 }
102 DatasetError::EmptySplit(s) => write!(f, "split `{s:?}` is empty"),
103 DatasetError::InvalidFractions { train, val, test } => write!(
104 f,
105 "invalid split fractions train={train}, val={val}, test={test} \
106 (each must be in [0,1] and sum ≤ 1)"
107 ),
108 DatasetError::BadBatchSize(n) => write!(f, "batch_size must be ≥ 1 (got {n})"),
109 DatasetError::NoFeatures => write!(f, "no feature columns specified"),
110 DatasetError::OrphanEncoding(c) => {
111 write!(f, "encoding registered for column `{c}` but it is neither a feature nor the label")
112 }
113 DatasetError::Tidy(m) => write!(f, "tidy error: {m}"),
114 DatasetError::Shape(m) => write!(f, "shape error: {m}"),
115 }
116 }
117}
118
/// Marks `DatasetError` as a standard error. No `source()` override is provided:
/// wrapped tidy errors are stored as formatted strings, not as inner error values.
impl std::error::Error for DatasetError {}
120
121impl From<TidyError> for DatasetError {
122 fn from(e: TidyError) -> Self {
123 DatasetError::Tidy(format!("{e:?}"))
124 }
125}
126
/// Which portion of the rows to select from a [`DatasetPlan`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Split {
    /// Training rows.
    Train,
    /// Validation rows.
    Val,
    /// Test rows.
    Test,
    /// Every row of the source, whatever the [`SplitSpec`] is.
    Full,
}
138
/// How rows are assigned to [`Split`]s.
///
/// For the fractional variants, each fraction must lie in `[0, 1]` and their sum must be
/// at most 1 (a `1e-9` tolerance is allowed). Rows not covered by the fractions belong to
/// no split other than `Split::Full`.
#[derive(Debug, Clone, PartialEq)]
pub enum SplitSpec {
    /// No split: only `Split::Full` has rows; `Train`, `Val` and `Test` are empty.
    Full,
    /// Contiguous ranges in row order. The first `floor(n * train)` rows are train, the
    /// next `floor(n * val)` are val, then the next `floor(n * test)` are test.
    Sequential { train: f64, val: f64, test: f64 },
    /// Each row is placed independently by hashing its index mixed with `seed`. The
    /// assignment is a pure function of (row, seed), but split sizes are not guaranteed
    /// to match the fractions exactly.
    Hashed {
        seed: u64,
        train: f64,
        val: f64,
        test: f64,
    },
}
159
160impl SplitSpec {
161 fn validate(&self) -> Result<(), DatasetError> {
162 let (t, v, te) = match self {
163 SplitSpec::Full => return Ok(()),
164 SplitSpec::Sequential { train, val, test } => (*train, *val, *test),
165 SplitSpec::Hashed {
166 train, val, test, ..
167 } => (*train, *val, *test),
168 };
169 let valid_each = (0.0..=1.0).contains(&t)
170 && (0.0..=1.0).contains(&v)
171 && (0.0..=1.0).contains(&te);
172 let sum = t + v + te;
173 if !valid_each || sum > 1.0 + 1e-9 {
174 return Err(DatasetError::InvalidFractions {
175 train: t,
176 val: v,
177 test: te,
178 });
179 }
180 Ok(())
181 }
182}
183
/// Batching options used by `DatasetPlan::iter_batches`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchSpec {
    /// Rows per batch; `DatasetPlan::validate` rejects 0.
    pub batch_size: usize,
    /// When `true`, a final batch with fewer than `batch_size` rows is not yielded.
    pub drop_last: bool,
    /// When `Some(seed)`, the split's row ids are shuffled with this seed before
    /// batching; `None` keeps the ascending row order produced by the split.
    pub shuffle: Option<u64>,
}
192
193impl Default for BatchSpec {
194 fn default() -> Self {
195 Self {
196 batch_size: 1,
197 drop_last: false,
198 shuffle: None,
199 }
200 }
201}
202
203impl BatchSpec {
204 pub fn new(batch_size: usize) -> Self {
205 Self {
206 batch_size,
207 drop_last: false,
208 shuffle: None,
209 }
210 }
211 pub fn with_drop_last(mut self, drop_last: bool) -> Self {
212 self.drop_last = drop_last;
213 self
214 }
215 pub fn with_shuffle(mut self, seed: u64) -> Self {
216 self.shuffle = Some(seed);
217 self
218 }
219}
220
/// How a single column is converted to `f64` values.
///
/// A column with no registered encoding is still accepted if it is `Float`, `Int` or
/// `Bool`; it is then encoded as with `Float`, `IntAsFloat` or `BoolAsFloat`. Any other
/// column type without an encoding is rejected.
#[derive(Debug, Clone, PartialEq)]
pub enum EncodingSpec {
    /// `Float` column, passed through unchanged.
    Float,
    /// `Int` column, converted with `as f64`.
    IntAsFloat,
    /// `Bool` column: `true` becomes `1.0`, `false` becomes `0.0`.
    BoolAsFloat,
    /// `Str`, `Categorical` or `CategoricalAdaptive` column. Each value is encoded as
    /// its code in a frozen [`ByteDictionary`] built with `ordering` over the whole column.
    Categorical { ordering: CategoryOrdering },
}
238
239impl EncodingSpec {
240 fn name(&self) -> &'static str {
241 match self {
242 EncodingSpec::Float => "Float",
243 EncodingSpec::IntAsFloat => "IntAsFloat",
244 EncodingSpec::BoolAsFloat => "BoolAsFloat",
245 EncodingSpec::Categorical { .. } => "Categorical",
246 }
247 }
248}
249
/// A declarative description of how to turn a [`TidyView`] into batches of `f64`
/// feature and label tensors.
///
/// The plan records which columns to use, how to encode them, how to split the rows
/// and how to batch them. Build it with `from_view`/`from_dataframe` and the `with_*`
/// builders, check it with `validate`, and consume it with `split_rows` or `iter_batches`.
#[derive(Clone)]
pub struct DatasetPlan {
    /// Source table; `iter_batches` materializes it into a `DataFrame`.
    source: TidyView,
    /// Columns forming the feature matrix, in output column order.
    feature_cols: Vec<String>,
    /// Optional column encoded as the per-row label.
    label_col: Option<String>,
    /// Explicit per-column encodings; every key must be a feature or the label.
    encodings: SortedVecMap<String, EncodingSpec>,
    /// How rows are divided into train/val/test.
    split: SplitSpec,
    /// Batch size, `drop_last`, and optional shuffle seed.
    batch: BatchSpec,
    /// NOTE(review): initialised to `None` and exposed via `plan_hash()`, but nothing in
    /// this section sets it — confirm where (or whether) it is computed.
    plan_hash: Option<[u8; 32]>,
}
267
268impl DatasetPlan {
269 pub fn from_view(source: TidyView) -> Self {
270 Self {
271 source,
272 feature_cols: Vec::new(),
273 label_col: None,
274 encodings: SortedVecMap::new(),
275 split: SplitSpec::Full,
276 batch: BatchSpec::default(),
277 plan_hash: None,
278 }
279 }
280
281 pub fn from_dataframe(df: DataFrame) -> Self {
282 Self::from_view(df.tidy())
283 }
284
285 pub fn with_features(mut self, cols: Vec<String>) -> Self {
286 self.feature_cols = cols;
287 self
288 }
289
290 pub fn with_label(mut self, col: String) -> Self {
291 self.label_col = Some(col);
292 self
293 }
294
295 pub fn with_encoding(mut self, col: String, enc: EncodingSpec) -> Self {
296 self.encodings.insert(col, enc);
297 self
298 }
299
300 pub fn with_split(mut self, split: SplitSpec) -> Self {
301 self.split = split;
302 self
303 }
304
305 pub fn with_batch(mut self, batch: BatchSpec) -> Self {
306 self.batch = batch;
307 self
308 }
309
310 pub fn nrows(&self) -> usize {
311 self.source.nrows()
312 }
313 pub fn n_features(&self) -> usize {
314 self.feature_cols.len()
315 }
316 pub fn feature_cols(&self) -> &[String] {
317 &self.feature_cols
318 }
319 pub fn label_col(&self) -> Option<&str> {
320 self.label_col.as_deref()
321 }
322 pub fn split_spec(&self) -> &SplitSpec {
323 &self.split
324 }
325 pub fn batch_spec(&self) -> &BatchSpec {
326 &self.batch
327 }
328 pub fn plan_hash(&self) -> Option<&[u8; 32]> {
329 self.plan_hash.as_ref()
330 }
331
332 pub fn validate(&self) -> Result<(), DatasetError> {
336 if self.feature_cols.is_empty() {
337 return Err(DatasetError::NoFeatures);
338 }
339 if self.batch.batch_size == 0 {
340 return Err(DatasetError::BadBatchSize(self.batch.batch_size));
341 }
342 self.split.validate()?;
343
344 let known: std::collections::BTreeSet<&str> =
345 self.source.column_names().into_iter().collect();
346 for c in &self.feature_cols {
347 if !known.contains(c.as_str()) {
348 return Err(DatasetError::UnknownColumn(c.clone()));
349 }
350 }
351 if let Some(l) = &self.label_col {
352 if !known.contains(l.as_str()) {
353 return Err(DatasetError::UnknownColumn(l.clone()));
354 }
355 }
356 for (col, _) in self.encodings.iter() {
357 let in_features = self.feature_cols.iter().any(|c| c == col);
358 let in_label = self.label_col.as_ref().is_some_and(|l| l == col);
359 if !in_features && !in_label {
360 return Err(DatasetError::OrphanEncoding(col.clone()));
361 }
362 }
363 Ok(())
364 }
365
366 pub fn split_rows(&self, which: Split) -> Result<Vec<u32>, DatasetError> {
370 self.validate()?;
371 let n = self.nrows();
372 Ok(assign_split(n, &self.split, which))
373 }
374
375 pub fn iter_batches(&self, which: Split) -> Result<BatchIterator, DatasetError> {
380 self.validate()?;
381 let df = self.source.materialize()?;
382
383 let mut dictionaries: SortedVecMap<String, ByteDictionary> = SortedVecMap::new();
386 for (col, enc) in self.encodings.iter() {
387 if let EncodingSpec::Categorical { ordering } = enc {
388 let column = df
389 .get_column(col)
390 .ok_or_else(|| DatasetError::UnknownColumn(col.clone()))?;
391 let dict = build_dict(col, column, ordering.clone())?;
392 dictionaries.insert(col.clone(), dict);
393 }
394 }
395
396 let mut row_ids = assign_split(df.nrows(), &self.split, which);
398 if row_ids.is_empty() && !matches!(which, Split::Full) && self.nrows() == 0 {
399 return Err(DatasetError::EmptySplit(which));
400 }
401 if let Some(seed) = self.batch.shuffle {
402 shuffle_in_place(&mut row_ids, seed);
403 }
404
405 Ok(BatchIterator {
406 df,
407 feature_cols: self.feature_cols.clone(),
408 label_col: self.label_col.clone(),
409 encodings: self.encodings.clone(),
410 dictionaries,
411 row_ids,
412 batch_size: self.batch.batch_size,
413 drop_last: self.batch.drop_last,
414 cursor: 0,
415 })
416 }
417}
418
/// SplitMix64 output function applied to `x`.
///
/// It adds the golden-ratio increment, then performs two xor-shift/multiply rounds
/// and a final xor-shift. All arithmetic wraps, so it never panics.
#[inline]
fn splitmix64_mix(x: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MUL_1: u64 = 0xBF58_476D_1CE4_E5B9;
    const MUL_2: u64 = 0x94D0_49BB_1331_11EB;

    let z = x.wrapping_add(GOLDEN_GAMMA);
    let z = (z ^ (z >> 30)).wrapping_mul(MUL_1);
    let z = (z ^ (z >> 27)).wrapping_mul(MUL_2);
    z ^ (z >> 31)
}
430
431fn assign_split(nrows: usize, spec: &SplitSpec, which: Split) -> Vec<u32> {
432 match spec {
433 SplitSpec::Full => match which {
434 Split::Full => (0..nrows as u32).collect(),
435 _ => Vec::new(),
436 },
437 SplitSpec::Sequential { train, val, test } => {
438 let n = nrows as f64;
439 let train_n = (n * train).floor() as usize;
440 let val_n = (n * val).floor() as usize;
441 let test_n = (n * test).floor() as usize;
442 match which {
443 Split::Train => (0..train_n as u32).collect(),
444 Split::Val => (train_n as u32..(train_n + val_n) as u32).collect(),
445 Split::Test => {
446 let start = (train_n + val_n) as u32;
447 let end = (train_n + val_n + test_n) as u32;
448 (start..end).collect()
449 }
450 Split::Full => (0..nrows as u32).collect(),
451 }
452 }
453 SplitSpec::Hashed {
454 seed,
455 train,
456 val,
457 test,
458 } => {
459 if matches!(which, Split::Full) {
460 return (0..nrows as u32).collect();
461 }
462 let train_t = *train;
463 let val_t = train_t + *val;
464 let test_t = val_t + *test;
465 let mut out = Vec::new();
466 for r in 0..nrows as u32 {
467 let h = splitmix64_mix((r as u64) ^ *seed);
468 let bucket = (h >> 32) as f64 / (u32::MAX as f64 + 1.0);
470 let pick = if bucket < train_t {
471 Split::Train
472 } else if bucket < val_t {
473 Split::Val
474 } else if bucket < test_t {
475 Split::Test
476 } else {
477 continue; };
479 if pick == which {
480 out.push(r);
481 }
482 }
483 out
484 }
485 }
486}
487
488fn shuffle_in_place(rows: &mut [u32], seed: u64) {
489 if rows.len() <= 1 {
490 return;
491 }
492 let mut rng = Rng::seeded(seed);
493 for i in (1..rows.len()).rev() {
496 let j = (rng.next_u64() % (i as u64 + 1)) as usize;
497 rows.swap(i, j);
498 }
499}
500
501fn build_dict(
506 col_name: &str,
507 column: &Column,
508 ordering: CategoryOrdering,
509) -> Result<ByteDictionary, DatasetError> {
510 let mut dict = ByteDictionary::with_ordering(ordering);
511 match column {
512 Column::Str(values) => {
513 for v in values {
514 dict.intern(v.as_bytes())
515 .map_err(|e| DatasetError::Tidy(format!("intern: {e:?}")))?;
516 }
517 }
518 Column::Categorical { levels, codes } => {
519 for &c in codes {
520 let v = &levels[c as usize];
521 dict.intern(v.as_bytes())
522 .map_err(|e| DatasetError::Tidy(format!("intern: {e:?}")))?;
523 }
524 }
525 Column::CategoricalAdaptive(cc) => {
526 for i in 0..cc.len() {
527 match cc.get(i) {
528 Some(b) => {
529 dict.intern(b)
530 .map_err(|e| DatasetError::Tidy(format!("intern: {e:?}")))?;
531 }
532 None => {
533 return Err(DatasetError::NullCategorical {
534 column: col_name.to_string(),
535 row: i as u32,
536 });
537 }
538 }
539 }
540 }
541 other => {
542 return Err(DatasetError::EncodingMismatch {
543 column: col_name.to_string(),
544 encoding: "Categorical",
545 column_type: other.type_name(),
546 });
547 }
548 }
549 dict.freeze();
550 Ok(dict)
551}
552
/// One encoded batch produced by [`BatchIterator`].
#[derive(Debug, Clone)]
pub struct MaterializedBatch {
    /// Source row ids of the rows in this batch, in batch order.
    pub row_ids: Vec<u32>,
    /// Encoded features, built by `BatchIterator` with shape `[batch_len, n_features]`.
    /// The data is filled row by row: every feature of the first row, then the next row.
    pub features: Tensor,
    /// Encoded labels with shape `[batch_len]`, or `None` when the plan has no label column.
    pub labels: Option<Tensor>,
}
565
/// Iterator over [`MaterializedBatch`]es for one split of a [`DatasetPlan`].
///
/// Created by `DatasetPlan::iter_batches`. It owns the materialized `DataFrame` and
/// the frozen categorical dictionaries, so each batch is encoded when it is requested.
pub struct BatchIterator {
    /// Materialized source data.
    df: DataFrame,
    /// Feature columns, in feature-axis order.
    feature_cols: Vec<String>,
    /// Optional label column.
    label_col: Option<String>,
    /// Explicit per-column encodings copied from the plan.
    encodings: SortedVecMap<String, EncodingSpec>,
    /// Frozen dictionaries, one per column with a `Categorical` encoding.
    dictionaries: SortedVecMap<String, ByteDictionary>,
    /// Row ids of the split, in iteration order (shuffled if the plan requested it).
    row_ids: Vec<u32>,
    /// Rows per batch (≥ 1, enforced by `DatasetPlan::validate`).
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is skipped.
    drop_last: bool,
    /// Index into `row_ids` of the first row of the next batch.
    cursor: usize,
}
577
578impl BatchIterator {
579 pub fn split_len(&self) -> usize {
582 self.row_ids.len()
583 }
584
585 pub fn row_ids(&self) -> &[u32] {
587 &self.row_ids
588 }
589
590 fn encode_cell(
591 &self,
592 col_name: &str,
593 col: &Column,
594 row: u32,
595 ) -> Result<f64, DatasetError> {
596 let enc = self.encodings.get(&col_name.to_string()).cloned();
597 match (col, enc) {
598 (Column::Float(v), Some(EncodingSpec::Float)) => Ok(v[row as usize]),
599 (Column::Float(v), None) => Ok(v[row as usize]),
600 (Column::Int(v), Some(EncodingSpec::IntAsFloat)) => Ok(v[row as usize] as f64),
601 (Column::Int(v), None) => Ok(v[row as usize] as f64),
602 (Column::Bool(v), Some(EncodingSpec::BoolAsFloat)) => {
603 Ok(if v[row as usize] { 1.0 } else { 0.0 })
604 }
605 (Column::Bool(v), None) => Ok(if v[row as usize] { 1.0 } else { 0.0 }),
606 (Column::Str(_), Some(EncodingSpec::Categorical { .. }))
607 | (Column::Categorical { .. }, Some(EncodingSpec::Categorical { .. }))
608 | (Column::CategoricalAdaptive(_), Some(EncodingSpec::Categorical { .. })) => {
609 let dict = self
610 .dictionaries
611 .get(&col_name.to_string())
612 .ok_or_else(|| DatasetError::Tidy(format!(
613 "missing dictionary for column `{col_name}`"
614 )))?;
615 let bytes: Vec<u8> = match col {
616 Column::Str(v) => v[row as usize].as_bytes().to_vec(),
617 Column::Categorical { levels, codes } => {
618 levels[codes[row as usize] as usize].as_bytes().to_vec()
619 }
620 Column::CategoricalAdaptive(cc) => match cc.get(row as usize) {
621 Some(b) => b.to_vec(),
622 None => {
623 return Err(DatasetError::NullCategorical {
624 column: col_name.to_string(),
625 row,
626 });
627 }
628 },
629 _ => unreachable!(),
630 };
631 let code = dict.lookup(&bytes).ok_or_else(|| {
632 DatasetError::Tidy(format!(
633 "value at row {row} of `{col_name}` not in frozen dictionary"
634 ))
635 })?;
636 Ok(code as f64)
637 }
638 (other, Some(enc)) => Err(DatasetError::EncodingMismatch {
639 column: col_name.to_string(),
640 encoding: enc.name(),
641 column_type: other.type_name(),
642 }),
643 (other, None) => Err(DatasetError::UnsupportedColumnType {
644 column: col_name.to_string(),
645 type_name: other.type_name(),
646 }),
647 }
648 }
649
650 fn materialize_chunk(
651 &self,
652 chunk_rows: &[u32],
653 ) -> Result<MaterializedBatch, DatasetError> {
654 let n_features = self.feature_cols.len();
655 let bsz = chunk_rows.len();
656
657 let mut feat_columns: Vec<&Column> = Vec::with_capacity(n_features);
659 for c in &self.feature_cols {
660 let col = self
661 .df
662 .get_column(c)
663 .ok_or_else(|| DatasetError::UnknownColumn(c.clone()))?;
664 feat_columns.push(col);
665 }
666
667 let mut feat_data: Vec<f64> = Vec::with_capacity(bsz * n_features);
668 for &row in chunk_rows {
669 for (ci, c) in self.feature_cols.iter().enumerate() {
670 feat_data.push(self.encode_cell(c, feat_columns[ci], row)?);
671 }
672 }
673 let features = Tensor::from_vec(feat_data, &[bsz, n_features])
674 .map_err(|e| DatasetError::Shape(format!("features: {e:?}")))?;
675
676 let labels = if let Some(lcol) = &self.label_col {
677 let col = self
678 .df
679 .get_column(lcol)
680 .ok_or_else(|| DatasetError::UnknownColumn(lcol.clone()))?;
681 let mut data: Vec<f64> = Vec::with_capacity(bsz);
682 for &row in chunk_rows {
683 data.push(self.encode_cell(lcol, col, row)?);
684 }
685 Some(
686 Tensor::from_vec(data, &[bsz])
687 .map_err(|e| DatasetError::Shape(format!("labels: {e:?}")))?,
688 )
689 } else {
690 None
691 };
692
693 Ok(MaterializedBatch {
694 row_ids: chunk_rows.to_vec(),
695 features,
696 labels,
697 })
698 }
699}
700
701impl Iterator for BatchIterator {
702 type Item = Result<MaterializedBatch, DatasetError>;
703
704 fn next(&mut self) -> Option<Self::Item> {
705 let total = self.row_ids.len();
706 if self.cursor >= total {
707 return None;
708 }
709 let end = (self.cursor + self.batch_size).min(total);
710 let len = end - self.cursor;
711 if len < self.batch_size && self.drop_last {
712 self.cursor = total;
713 return None;
714 }
715 let chunk = self.row_ids[self.cursor..end].to_vec();
716 self.cursor = end;
717 Some(self.materialize_chunk(&chunk))
718 }
719}