quantrs2_ml/tensorflow_compatibility/
quantumdatasetiterator_traits.rs

1//! # QuantumDatasetIterator - Trait Implementations
2//!
3//! This module contains trait implementations for `QuantumDatasetIterator`.
4//!
5//! ## Implemented Traits
6//!
7//! - `Iterator`
8//!
9//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
10
11use crate::simulator_backends::{DynamicCircuit, Observable, SimulationResult, SimulatorBackend};
12use quantrs2_circuit::prelude::*;
13use quantrs2_core::prelude::*;
14use scirs2_core::ndarray::{s, Array1, Array2, Array3, Array4, ArrayD, Axis};
15
16use super::types::QuantumDatasetIterator;
17
18impl<'a> Iterator for QuantumDatasetIterator<'a> {
19    type Item = (Vec<DynamicCircuit>, Array2<f64>, Array1<f64>);
20    fn next(&mut self) -> Option<Self::Item> {
21        if self.current_batch >= self.total_batches {
22            return None;
23        }
24        let start_idx = self.current_batch * self.dataset.batch_size;
25        let end_idx =
26            ((self.current_batch + 1) * self.dataset.batch_size).min(self.dataset.circuits.len());
27        let batch_circuits = self.dataset.circuits[start_idx..end_idx].to_vec();
28        let batch_parameters = self
29            .dataset
30            .parameters
31            .slice(s![start_idx..end_idx, ..])
32            .to_owned();
33        let batch_labels = self.dataset.labels.slice(s![start_idx..end_idx]).to_owned();
34        self.current_batch += 1;
35        Some((batch_circuits, batch_parameters, batch_labels))
36    }
37}