Skip to main content

cjc_data/
dataset_plan.rs

1//! Phase 1 — Deterministic ML training-data plan.
2//!
3//! `DatasetPlan` is a small, immutable composition over a [`TidyView`]
4//! that adds the four ML-specific concerns missing from the data engine:
5//!
6//! 1. **Feature/label column selection** with deterministic ordering.
7//! 2. **Encoding**: float / int / bool / categorical → `f64` features.
8//! 3. **Train/val/test splits**, sequential or hashed-by-row.
9//! 4. **Batching** with optional seeded SplitMix64 shuffle, materializing
10//!    each batch into a row-major `Tensor`.
11//!
12//! Phase 1 is Rust-only; not yet exposed to `.cjcl`. That's Phase 3. Phase
13//! 6 will wire `plan_hash` into a training manifest — for now the field is
14//! reserved (`Option<[u8; 32]>`, always `None`).
15//!
16//! ## Determinism contract
17//!
18//! - Row IDs are always ascending `u32` by default; `TidyView` already
19//!   guarantees this for the underlying selection.
20//! - Shuffles use `cjc_repro::Rng::seeded(seed)` (SplitMix64) with
21//!   Fisher-Yates over the split's row vector.
22//! - Hashed splits use the fixed `splitmix64` mixer keyed by `row ^ seed`.
23//! - Categorical dictionaries are built over **all source rows** (not just
24//!   train) so val/test see codes consistent with train, then frozen
25//!   before any batch is materialized.
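//! - Tensor materialization is row-major; no reductions, no FMA — each
//!   cell is copied or converted to `f64` on its own (`as f64` for ints,
//!   `0.0` / `1.0` for bools, dictionary codes for categoricals).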
28//!
29//! ## Reuse map
30//!
31//! | Need                              | Existing primitive                       |
32//! |-----------------------------------|------------------------------------------|
33//! | Filter / project upstream         | `TidyView::filter`, `TidyView::select`   |
34//! | Row mask                          | `AdaptiveSelection` inside the TidyView  |
35//! | Categorical encoding              | `ByteDictionary::intern` + `freeze`      |
36//! | Column-name → encoding map        | `detcoll::SortedVecMap`                  |
37//! | Seeded RNG                        | `cjc_repro::Rng::seeded` (SplitMix64)    |
38
39use crate::byte_dict::{ByteDictionary, CategoryOrdering};
40use crate::detcoll::SortedVecMap;
41use crate::{Column, DataFrame, TidyError, TidyView};
42use cjc_repro::Rng;
43use cjc_runtime::tensor::Tensor;
44
45// ════════════════════════════════════════════════════════════════════════
46//  Errors
47// ════════════════════════════════════════════════════════════════════════
48
/// Everything that can go wrong while validating a [`DatasetPlan`] or
/// materializing its batches.
#[derive(Debug, Clone, PartialEq)]
pub enum DatasetError {
    /// A feature, label, or categorically-encoded column name is not
    /// present in the source.
    UnknownColumn(String),
    /// The column has no registered encoding and its type is not one of
    /// the types encoded implicitly (float, int, bool).
    UnsupportedColumnType {
        column: String,
        type_name: &'static str,
    },
    /// The registered encoding cannot be applied to a column of this type.
    EncodingMismatch {
        column: String,
        encoding: &'static str,
        column_type: &'static str,
    },
    /// Categorical encoding requested but the column row at this index is
    /// null. Phase 1 does not have a null policy — this is an error.
    NullCategorical {
        column: String,
        row: u32,
    },
    /// A split other than `Split::Full` was requested from a source that
    /// has no rows.
    EmptySplit(Split),
    /// Fractions must each be in `[0, 1]` and sum to ≤ 1 (a `1e-9`
    /// tolerance is applied to the sum).
    InvalidFractions {
        train: f64,
        val: f64,
        test: f64,
    },
    /// `batch_size` was zero; carries the rejected value.
    BadBatchSize(usize),
    /// The plan lists no feature columns.
    NoFeatures,
    /// Encoding was registered for a column not in `feature_cols` or
    /// `label_col`.
    OrphanEncoding(String),
    /// Stringified failure from the tidy layer, the dictionary builder,
    /// or a frozen-dictionary lookup.
    Tidy(String),
    /// Tensor construction rejected the data/shape pair.
    Shape(String),
}
82
83impl std::fmt::Display for DatasetError {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        match self {
86            DatasetError::UnknownColumn(c) => write!(f, "unknown column `{c}`"),
87            DatasetError::UnsupportedColumnType { column, type_name } => write!(
88                f,
89                "column `{column}` has type `{type_name}` which is not supported"
90            ),
91            DatasetError::EncodingMismatch {
92                column,
93                encoding,
94                column_type,
95            } => write!(
96                f,
97                "column `{column}` (type `{column_type}`) cannot be encoded as `{encoding}`"
98            ),
99            DatasetError::NullCategorical { column, row } => {
100                write!(f, "null value in categorical column `{column}` at row {row}")
101            }
102            DatasetError::EmptySplit(s) => write!(f, "split `{s:?}` is empty"),
103            DatasetError::InvalidFractions { train, val, test } => write!(
104                f,
105                "invalid split fractions train={train}, val={val}, test={test} \
106                 (each must be in [0,1] and sum ≤ 1)"
107            ),
108            DatasetError::BadBatchSize(n) => write!(f, "batch_size must be ≥ 1 (got {n})"),
109            DatasetError::NoFeatures => write!(f, "no feature columns specified"),
110            DatasetError::OrphanEncoding(c) => {
111                write!(f, "encoding registered for column `{c}` but it is neither a feature nor the label")
112            }
113            DatasetError::Tidy(m) => write!(f, "tidy error: {m}"),
114            DatasetError::Shape(m) => write!(f, "shape error: {m}"),
115        }
116    }
117}
118
// Marker impl: the empty body means every `std::error::Error` method
// keeps its default, which is enough for `DatasetError` to be used as a
// `dyn std::error::Error`.
impl std::error::Error for DatasetError {}
120
121impl From<TidyError> for DatasetError {
122    fn from(e: TidyError) -> Self {
123        DatasetError::Tidy(format!("{e:?}"))
124    }
125}
126
127// ════════════════════════════════════════════════════════════════════════
128//  Specs
129// ════════════════════════════════════════════════════════════════════════
130
/// Which partition of the source a caller is asking for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Split {
    /// Training partition.
    Train,
    /// Validation partition.
    Val,
    /// Held-out test partition.
    Test,
    /// Every row of the source, whatever the split spec says.
    Full,
}
138
/// How rows are assigned to train / val / test.
///
/// Under every variant, requesting `Split::Full` returns all rows.
#[derive(Debug, Clone, PartialEq)]
pub enum SplitSpec {
    /// Single full-dataset partition. `Split::Train`, `Split::Val`, and
    /// `Split::Test` all yield empty row vectors; only `Split::Full`
    /// returns rows.
    Full,
    /// Sequential ranges by ascending row index. Train gets the first
    /// `floor(nrows * train)`, val the next `floor(nrows * val)`, test
    /// the next `floor(nrows * test)`. Trailing rows are excluded.
    Sequential { train: f64, val: f64, test: f64 },
    /// Per-row deterministic hash assignment. Bucket =
    /// `splitmix64(row as u64 ^ seed) >> 32` divided by `2^32`. Same
    /// `seed` ⇒ identical assignment, regardless of `nrows`.
    ///
    /// The bucket is compared against the cumulative thresholds `train`,
    /// `train + val`, `train + val + test`; rows whose bucket falls past
    /// the last threshold belong to no split.
    Hashed {
        seed: u64,
        train: f64,
        val: f64,
        test: f64,
    },
}
159
160impl SplitSpec {
161    fn validate(&self) -> Result<(), DatasetError> {
162        let (t, v, te) = match self {
163            SplitSpec::Full => return Ok(()),
164            SplitSpec::Sequential { train, val, test } => (*train, *val, *test),
165            SplitSpec::Hashed {
166                train, val, test, ..
167            } => (*train, *val, *test),
168        };
169        let valid_each = (0.0..=1.0).contains(&t)
170            && (0.0..=1.0).contains(&v)
171            && (0.0..=1.0).contains(&te);
172        let sum = t + v + te;
173        if !valid_each || sum > 1.0 + 1e-9 {
174            return Err(DatasetError::InvalidFractions {
175                train: t,
176                val: v,
177                test: te,
178            });
179        }
180        Ok(())
181    }
182}
183
/// How the rows of a split are grouped into batches.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchSpec {
    /// Maximum number of rows per batch.
    pub batch_size: usize,
    /// When `true`, a final batch shorter than `batch_size` is skipped.
    pub drop_last: bool,
    /// `None` ⇒ ascending row order. `Some(seed)` ⇒ SplitMix64
    /// Fisher-Yates permutation of the split's row IDs.
    pub shuffle: Option<u64>,
}

/// One row per batch, keep the last batch, no shuffle.
impl Default for BatchSpec {
    fn default() -> Self {
        Self::new(1)
    }
}

impl BatchSpec {
    /// Spec with the given `batch_size`, `drop_last` off and no shuffle.
    pub fn new(batch_size: usize) -> Self {
        Self {
            batch_size,
            drop_last: false,
            shuffle: None,
        }
    }

    /// Same spec with `drop_last` replaced.
    pub fn with_drop_last(self, drop_last: bool) -> Self {
        Self { drop_last, ..self }
    }

    /// Same spec with shuffling enabled under `seed`.
    pub fn with_shuffle(self, seed: u64) -> Self {
        Self {
            shuffle: Some(seed),
            ..self
        }
    }
}
220
/// Per-column encoding directive. A feature/label column carries at most
/// one of these. Float, int and bool columns with no registered encoding
/// are encoded as if `Float` / `IntAsFloat` / `BoolAsFloat` had been
/// given; any other column type needs an explicit entry. Phase 1
/// supports four encodings; richer schemes (one-hot, embedding lookup)
/// are deferred to Phase 3.
#[derive(Debug, Clone, PartialEq)]
pub enum EncodingSpec {
    /// `Column::Float` → `f64` pass-through.
    Float,
    /// `Column::Int` → `f64` via `as f64` cast (lossy for |x| ≥ 2^53;
    /// caller's responsibility to know).
    IntAsFloat,
    /// `Column::Bool` → `0.0` / `1.0`.
    BoolAsFloat,
    /// `Column::Str | Categorical | CategoricalAdaptive` → integer code
    /// from a fresh `ByteDictionary` (cast to `f64`). Dictionary built
    /// over **all source rows** before any split is materialized.
    Categorical { ordering: CategoryOrdering },
}
238
239impl EncodingSpec {
240    fn name(&self) -> &'static str {
241        match self {
242            EncodingSpec::Float => "Float",
243            EncodingSpec::IntAsFloat => "IntAsFloat",
244            EncodingSpec::BoolAsFloat => "BoolAsFloat",
245            EncodingSpec::Categorical { .. } => "Categorical",
246        }
247    }
248}
249
250// ════════════════════════════════════════════════════════════════════════
251//  DatasetPlan
252// ════════════════════════════════════════════════════════════════════════
253
/// Immutable training-data plan. Cheap to clone (`TidyView` holds
/// `Rc<DataFrame>`).
///
/// Configured through the `with_*` methods, each of which consumes the
/// plan and returns the updated one. Nothing is checked against the
/// source until `validate` runs.
#[derive(Clone)]
pub struct DatasetPlan {
    /// View the rows are read from.
    source: TidyView,
    /// Feature columns, in the order they are laid out within each row of
    /// the feature tensor.
    feature_cols: Vec<String>,
    /// Optional label column.
    label_col: Option<String>,
    /// Column name → encoding directive.
    encodings: SortedVecMap<String, EncodingSpec>,
    /// Train / val / test assignment rule.
    split: SplitSpec,
    /// Batch size, `drop_last`, and optional shuffle seed.
    batch: BatchSpec,
    /// Phase 6 reserved field. Always `None` today.
    plan_hash: Option<[u8; 32]>,
}
267
268impl DatasetPlan {
269    pub fn from_view(source: TidyView) -> Self {
270        Self {
271            source,
272            feature_cols: Vec::new(),
273            label_col: None,
274            encodings: SortedVecMap::new(),
275            split: SplitSpec::Full,
276            batch: BatchSpec::default(),
277            plan_hash: None,
278        }
279    }
280
281    pub fn from_dataframe(df: DataFrame) -> Self {
282        Self::from_view(df.tidy())
283    }
284
285    pub fn with_features(mut self, cols: Vec<String>) -> Self {
286        self.feature_cols = cols;
287        self
288    }
289
290    pub fn with_label(mut self, col: String) -> Self {
291        self.label_col = Some(col);
292        self
293    }
294
295    pub fn with_encoding(mut self, col: String, enc: EncodingSpec) -> Self {
296        self.encodings.insert(col, enc);
297        self
298    }
299
300    pub fn with_split(mut self, split: SplitSpec) -> Self {
301        self.split = split;
302        self
303    }
304
305    pub fn with_batch(mut self, batch: BatchSpec) -> Self {
306        self.batch = batch;
307        self
308    }
309
310    pub fn nrows(&self) -> usize {
311        self.source.nrows()
312    }
313    pub fn n_features(&self) -> usize {
314        self.feature_cols.len()
315    }
316    pub fn feature_cols(&self) -> &[String] {
317        &self.feature_cols
318    }
319    pub fn label_col(&self) -> Option<&str> {
320        self.label_col.as_deref()
321    }
322    pub fn split_spec(&self) -> &SplitSpec {
323        &self.split
324    }
325    pub fn batch_spec(&self) -> &BatchSpec {
326        &self.batch
327    }
328    pub fn plan_hash(&self) -> Option<&[u8; 32]> {
329        self.plan_hash.as_ref()
330    }
331
332    /// Validate the plan against the source schema. Cheap; no
333    /// materialization. Called automatically by `iter_batches` and
334    /// `split_rows`; useful in tests / dry-runs.
335    pub fn validate(&self) -> Result<(), DatasetError> {
336        if self.feature_cols.is_empty() {
337            return Err(DatasetError::NoFeatures);
338        }
339        if self.batch.batch_size == 0 {
340            return Err(DatasetError::BadBatchSize(self.batch.batch_size));
341        }
342        self.split.validate()?;
343
344        let known: std::collections::BTreeSet<&str> =
345            self.source.column_names().into_iter().collect();
346        for c in &self.feature_cols {
347            if !known.contains(c.as_str()) {
348                return Err(DatasetError::UnknownColumn(c.clone()));
349            }
350        }
351        if let Some(l) = &self.label_col {
352            if !known.contains(l.as_str()) {
353                return Err(DatasetError::UnknownColumn(l.clone()));
354            }
355        }
356        for (col, _) in self.encodings.iter() {
357            let in_features = self.feature_cols.iter().any(|c| c == col);
358            let in_label = self.label_col.as_ref().is_some_and(|l| l == col);
359            if !in_features && !in_label {
360                return Err(DatasetError::OrphanEncoding(col.clone()));
361            }
362        }
363        Ok(())
364    }
365
366    /// Ascending row IDs assigned to `which`. Row IDs are indices into the
367    /// **materialized** source (post-filter, post-select), not the raw
368    /// underlying DataFrame.
369    pub fn split_rows(&self, which: Split) -> Result<Vec<u32>, DatasetError> {
370        self.validate()?;
371        let n = self.nrows();
372        Ok(assign_split(n, &self.split, which))
373    }
374
375    /// Iterate batches over `which` split. Each batch is fully resolved
376    /// into row-major `Tensor`s. Categorical dictionaries are built over
377    /// the entire materialized source (so val/test see codes consistent
378    /// with train) and frozen before iteration begins.
379    pub fn iter_batches(&self, which: Split) -> Result<BatchIterator, DatasetError> {
380        self.validate()?;
381        let df = self.source.materialize()?;
382
383        // Build dictionaries for any categorically-encoded column over
384        // ALL source rows. Frozen after build.
385        let mut dictionaries: SortedVecMap<String, ByteDictionary> = SortedVecMap::new();
386        for (col, enc) in self.encodings.iter() {
387            if let EncodingSpec::Categorical { ordering } = enc {
388                let column = df
389                    .get_column(col)
390                    .ok_or_else(|| DatasetError::UnknownColumn(col.clone()))?;
391                let dict = build_dict(col, column, ordering.clone())?;
392                dictionaries.insert(col.clone(), dict);
393            }
394        }
395
396        // Compute split rows + apply shuffle.
397        let mut row_ids = assign_split(df.nrows(), &self.split, which);
398        if row_ids.is_empty() && !matches!(which, Split::Full) && self.nrows() == 0 {
399            return Err(DatasetError::EmptySplit(which));
400        }
401        if let Some(seed) = self.batch.shuffle {
402            shuffle_in_place(&mut row_ids, seed);
403        }
404
405        Ok(BatchIterator {
406            df,
407            feature_cols: self.feature_cols.clone(),
408            label_col: self.label_col.clone(),
409            encodings: self.encodings.clone(),
410            dictionaries,
411            row_ids,
412            batch_size: self.batch.batch_size,
413            drop_last: self.batch.drop_last,
414            cursor: 0,
415        })
416    }
417}
418
419// ════════════════════════════════════════════════════════════════════════
420//  Split assignment
421// ════════════════════════════════════════════════════════════════════════
422
/// One SplitMix64 step applied to `x`: add the golden-ratio increment,
/// then run the two xor-shift/multiply rounds and the final xor-shift.
/// All arithmetic wraps.
#[inline]
fn splitmix64_mix(x: u64) -> u64 {
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MUL_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MUL_B: u64 = 0x94D0_49BB_1331_11EB;

    let z = x.wrapping_add(GAMMA);
    let z = (z ^ (z >> 30)).wrapping_mul(MUL_A);
    let z = (z ^ (z >> 27)).wrapping_mul(MUL_B);
    z ^ (z >> 31)
}
430
431fn assign_split(nrows: usize, spec: &SplitSpec, which: Split) -> Vec<u32> {
432    match spec {
433        SplitSpec::Full => match which {
434            Split::Full => (0..nrows as u32).collect(),
435            _ => Vec::new(),
436        },
437        SplitSpec::Sequential { train, val, test } => {
438            let n = nrows as f64;
439            let train_n = (n * train).floor() as usize;
440            let val_n = (n * val).floor() as usize;
441            let test_n = (n * test).floor() as usize;
442            match which {
443                Split::Train => (0..train_n as u32).collect(),
444                Split::Val => (train_n as u32..(train_n + val_n) as u32).collect(),
445                Split::Test => {
446                    let start = (train_n + val_n) as u32;
447                    let end = (train_n + val_n + test_n) as u32;
448                    (start..end).collect()
449                }
450                Split::Full => (0..nrows as u32).collect(),
451            }
452        }
453        SplitSpec::Hashed {
454            seed,
455            train,
456            val,
457            test,
458        } => {
459            if matches!(which, Split::Full) {
460                return (0..nrows as u32).collect();
461            }
462            let train_t = *train;
463            let val_t = train_t + *val;
464            let test_t = val_t + *test;
465            let mut out = Vec::new();
466            for r in 0..nrows as u32 {
467                let h = splitmix64_mix((r as u64) ^ *seed);
468                // Bucket in [0, 1): top 32 bits of mixed value.
469                let bucket = (h >> 32) as f64 / (u32::MAX as f64 + 1.0);
470                let pick = if bucket < train_t {
471                    Split::Train
472                } else if bucket < val_t {
473                    Split::Val
474                } else if bucket < test_t {
475                    Split::Test
476                } else {
477                    continue; // excluded
478                };
479                if pick == which {
480                    out.push(r);
481                }
482            }
483            out
484        }
485    }
486}
487
488fn shuffle_in_place(rows: &mut [u32], seed: u64) {
489    if rows.len() <= 1 {
490        return;
491    }
492    let mut rng = Rng::seeded(seed);
493    // Fisher-Yates: for i from n-1 down to 1, swap rows[i] with rows[j],
494    // j = rng.next_u64() % (i+1).
495    for i in (1..rows.len()).rev() {
496        let j = (rng.next_u64() % (i as u64 + 1)) as usize;
497        rows.swap(i, j);
498    }
499}
500
501// ════════════════════════════════════════════════════════════════════════
502//  Categorical dictionary builder
503// ════════════════════════════════════════════════════════════════════════
504
505fn build_dict(
506    col_name: &str,
507    column: &Column,
508    ordering: CategoryOrdering,
509) -> Result<ByteDictionary, DatasetError> {
510    let mut dict = ByteDictionary::with_ordering(ordering);
511    match column {
512        Column::Str(values) => {
513            for v in values {
514                dict.intern(v.as_bytes())
515                    .map_err(|e| DatasetError::Tidy(format!("intern: {e:?}")))?;
516            }
517        }
518        Column::Categorical { levels, codes } => {
519            for &c in codes {
520                let v = &levels[c as usize];
521                dict.intern(v.as_bytes())
522                    .map_err(|e| DatasetError::Tidy(format!("intern: {e:?}")))?;
523            }
524        }
525        Column::CategoricalAdaptive(cc) => {
526            for i in 0..cc.len() {
527                match cc.get(i) {
528                    Some(b) => {
529                        dict.intern(b)
530                            .map_err(|e| DatasetError::Tidy(format!("intern: {e:?}")))?;
531                    }
532                    None => {
533                        return Err(DatasetError::NullCategorical {
534                            column: col_name.to_string(),
535                            row: i as u32,
536                        });
537                    }
538                }
539            }
540        }
541        other => {
542            return Err(DatasetError::EncodingMismatch {
543                column: col_name.to_string(),
544                encoding: "Categorical",
545                column_type: other.type_name(),
546            });
547        }
548    }
549    dict.freeze();
550    Ok(dict)
551}
552
553// ════════════════════════════════════════════════════════════════════════
554//  BatchIterator + MaterializedBatch
555// ════════════════════════════════════════════════════════════════════════
556
/// One fully-encoded batch.
#[derive(Debug, Clone)]
pub struct MaterializedBatch {
    /// Source row IDs in this batch, in the order the tensor rows follow.
    pub row_ids: Vec<u32>,
    /// Shape `[rows_in_batch, n_features]`, row-major. `rows_in_batch`
    /// equals `batch_size` except for a shorter final batch.
    pub features: Tensor,
    /// Shape `[rows_in_batch]` (1-D). `None` if no `label_col` was set.
    pub labels: Option<Tensor>,
}
565
/// Iterator over the batches of one split, produced by
/// `DatasetPlan::iter_batches`. Owns the materialized frame and the
/// frozen dictionaries.
pub struct BatchIterator {
    /// Materialized source the rows are read from.
    df: DataFrame,
    /// Feature columns, in tensor column order.
    feature_cols: Vec<String>,
    /// Optional label column.
    label_col: Option<String>,
    /// Column name → encoding directive (copied from the plan).
    encodings: SortedVecMap<String, EncodingSpec>,
    /// Column name → frozen dictionary, one per categorical encoding.
    dictionaries: SortedVecMap<String, ByteDictionary>,
    /// Row IDs of the split, already shuffled if a seed was given.
    row_ids: Vec<u32>,
    /// Maximum rows per batch.
    batch_size: usize,
    /// Skip a final batch shorter than `batch_size`.
    drop_last: bool,
    /// Index into `row_ids` of the next row to emit.
    cursor: usize,
}
577
578impl BatchIterator {
579    /// Number of rows in this iterator's split (after shuffle, before
580    /// any drop_last accounting).
581    pub fn split_len(&self) -> usize {
582        self.row_ids.len()
583    }
584
585    /// Return the (already-shuffled) row IDs this iterator will visit.
586    pub fn row_ids(&self) -> &[u32] {
587        &self.row_ids
588    }
589
590    fn encode_cell(
591        &self,
592        col_name: &str,
593        col: &Column,
594        row: u32,
595    ) -> Result<f64, DatasetError> {
596        let enc = self.encodings.get(&col_name.to_string()).cloned();
597        match (col, enc) {
598            (Column::Float(v), Some(EncodingSpec::Float)) => Ok(v[row as usize]),
599            (Column::Float(v), None) => Ok(v[row as usize]),
600            (Column::Int(v), Some(EncodingSpec::IntAsFloat)) => Ok(v[row as usize] as f64),
601            (Column::Int(v), None) => Ok(v[row as usize] as f64),
602            (Column::Bool(v), Some(EncodingSpec::BoolAsFloat)) => {
603                Ok(if v[row as usize] { 1.0 } else { 0.0 })
604            }
605            (Column::Bool(v), None) => Ok(if v[row as usize] { 1.0 } else { 0.0 }),
606            (Column::Str(_), Some(EncodingSpec::Categorical { .. }))
607            | (Column::Categorical { .. }, Some(EncodingSpec::Categorical { .. }))
608            | (Column::CategoricalAdaptive(_), Some(EncodingSpec::Categorical { .. })) => {
609                let dict = self
610                    .dictionaries
611                    .get(&col_name.to_string())
612                    .ok_or_else(|| DatasetError::Tidy(format!(
613                        "missing dictionary for column `{col_name}`"
614                    )))?;
615                let bytes: Vec<u8> = match col {
616                    Column::Str(v) => v[row as usize].as_bytes().to_vec(),
617                    Column::Categorical { levels, codes } => {
618                        levels[codes[row as usize] as usize].as_bytes().to_vec()
619                    }
620                    Column::CategoricalAdaptive(cc) => match cc.get(row as usize) {
621                        Some(b) => b.to_vec(),
622                        None => {
623                            return Err(DatasetError::NullCategorical {
624                                column: col_name.to_string(),
625                                row,
626                            });
627                        }
628                    },
629                    _ => unreachable!(),
630                };
631                let code = dict.lookup(&bytes).ok_or_else(|| {
632                    DatasetError::Tidy(format!(
633                        "value at row {row} of `{col_name}` not in frozen dictionary"
634                    ))
635                })?;
636                Ok(code as f64)
637            }
638            (other, Some(enc)) => Err(DatasetError::EncodingMismatch {
639                column: col_name.to_string(),
640                encoding: enc.name(),
641                column_type: other.type_name(),
642            }),
643            (other, None) => Err(DatasetError::UnsupportedColumnType {
644                column: col_name.to_string(),
645                type_name: other.type_name(),
646            }),
647        }
648    }
649
650    fn materialize_chunk(
651        &self,
652        chunk_rows: &[u32],
653    ) -> Result<MaterializedBatch, DatasetError> {
654        let n_features = self.feature_cols.len();
655        let bsz = chunk_rows.len();
656
657        // Resolve each feature column once (avoids n_rows × n_features lookups).
658        let mut feat_columns: Vec<&Column> = Vec::with_capacity(n_features);
659        for c in &self.feature_cols {
660            let col = self
661                .df
662                .get_column(c)
663                .ok_or_else(|| DatasetError::UnknownColumn(c.clone()))?;
664            feat_columns.push(col);
665        }
666
667        let mut feat_data: Vec<f64> = Vec::with_capacity(bsz * n_features);
668        for &row in chunk_rows {
669            for (ci, c) in self.feature_cols.iter().enumerate() {
670                feat_data.push(self.encode_cell(c, feat_columns[ci], row)?);
671            }
672        }
673        let features = Tensor::from_vec(feat_data, &[bsz, n_features])
674            .map_err(|e| DatasetError::Shape(format!("features: {e:?}")))?;
675
676        let labels = if let Some(lcol) = &self.label_col {
677            let col = self
678                .df
679                .get_column(lcol)
680                .ok_or_else(|| DatasetError::UnknownColumn(lcol.clone()))?;
681            let mut data: Vec<f64> = Vec::with_capacity(bsz);
682            for &row in chunk_rows {
683                data.push(self.encode_cell(lcol, col, row)?);
684            }
685            Some(
686                Tensor::from_vec(data, &[bsz])
687                    .map_err(|e| DatasetError::Shape(format!("labels: {e:?}")))?,
688            )
689        } else {
690            None
691        };
692
693        Ok(MaterializedBatch {
694            row_ids: chunk_rows.to_vec(),
695            features,
696            labels,
697        })
698    }
699}
700
701impl Iterator for BatchIterator {
702    type Item = Result<MaterializedBatch, DatasetError>;
703
704    fn next(&mut self) -> Option<Self::Item> {
705        let total = self.row_ids.len();
706        if self.cursor >= total {
707            return None;
708        }
709        let end = (self.cursor + self.batch_size).min(total);
710        let len = end - self.cursor;
711        if len < self.batch_size && self.drop_last {
712            self.cursor = total;
713            return None;
714        }
715        let chunk = self.row_ids[self.cursor..end].to_vec();
716        self.cursor = end;
717        Some(self.materialize_chunk(&chunk))
718    }
719}