1use crate::circuit_integration::{QuantumLayer, QuantumMLExecutor};
8use crate::error::{MLError, Result};
9use crate::simulator_backends::{DynamicCircuit, Observable, SimulationResult, SimulatorBackend};
10use scirs2_core::ndarray::{s, Array1, Array2, Array3, Array4, ArrayD, Axis};
11use quantrs2_circuit::prelude::*;
12use quantrs2_core::prelude::*;
13use std::collections::HashMap;
14use std::sync::Arc;
15
16pub struct QuantumCircuitLayer {
18 circuit: Circuit<8>, symbols: Vec<String>,
22 observable: Observable,
24 backend: Arc<dyn SimulatorBackend>,
26 differentiable: bool,
28 repetitions: Option<usize>,
30}
31
32impl QuantumCircuitLayer {
33 pub fn new(
35 circuit: Circuit<8>, symbols: Vec<String>,
37 observable: Observable,
38 backend: Arc<dyn SimulatorBackend>,
39 ) -> Self {
40 Self {
41 circuit,
42 symbols,
43 observable,
44 backend,
45 differentiable: true,
46 repetitions: None,
47 }
48 }
49
50 pub fn set_differentiable(mut self, differentiable: bool) -> Self {
52 self.differentiable = differentiable;
53 self
54 }
55
56 pub fn set_repetitions(mut self, repetitions: usize) -> Self {
58 self.repetitions = Some(repetitions);
59 self
60 }
61
62 pub fn forward(&self, inputs: &Array2<f64>, parameters: &Array2<f64>) -> Result<Array1<f64>> {
64 let batch_size = inputs.nrows();
65 let mut outputs = Array1::zeros(batch_size);
66
67 for batch_idx in 0..batch_size {
68 let input_data = inputs.row(batch_idx);
70 let param_data = parameters.row(batch_idx % parameters.nrows());
71 let combined_params: Vec<f64> = input_data
72 .iter()
73 .chain(param_data.iter())
74 .copied()
75 .collect();
76
77 let dynamic_circuit =
79 crate::simulator_backends::DynamicCircuit::from_circuit(self.circuit.clone())?;
80 let expectation = self.backend.expectation_value(
81 &dynamic_circuit,
82 &combined_params,
83 &self.observable,
84 )?;
85
86 outputs[batch_idx] = expectation;
87 }
88
89 Ok(outputs)
90 }
91
92 pub fn compute_gradients(
94 &self,
95 inputs: &Array2<f64>,
96 parameters: &Array2<f64>,
97 upstream_gradients: &Array1<f64>,
98 ) -> Result<(Array2<f64>, Array2<f64>)> {
99 if !self.differentiable {
100 return Err(MLError::InvalidConfiguration(
101 "Layer is not differentiable".to_string(),
102 ));
103 }
104
105 let batch_size = inputs.nrows();
106 let num_input_params = inputs.ncols();
107 let num_trainable_params = parameters.ncols();
108
109 let mut input_gradients = Array2::zeros((batch_size, num_input_params));
110 let mut param_gradients = Array2::zeros((batch_size, num_trainable_params));
111
112 for batch_idx in 0..batch_size {
113 let input_data = inputs.row(batch_idx);
114 let param_data = parameters.row(batch_idx % parameters.nrows());
115 let combined_params: Vec<f64> = input_data
116 .iter()
117 .chain(param_data.iter())
118 .copied()
119 .collect();
120
121 let dynamic_circuit =
123 crate::simulator_backends::DynamicCircuit::from_circuit(self.circuit.clone())?;
124 let gradients = self.backend.compute_gradients(
125 &dynamic_circuit,
126 &combined_params,
127 &self.observable,
128 crate::simulator_backends::GradientMethod::ParameterShift,
129 )?;
130
131 let upstream_grad = upstream_gradients[batch_idx];
133 for (i, grad) in gradients.iter().enumerate() {
134 if i < num_input_params {
135 input_gradients[[batch_idx, i]] = grad * upstream_grad;
136 } else {
137 param_gradients[[batch_idx, i - num_input_params]] = grad * upstream_grad;
138 }
139 }
140 }
141
142 Ok((input_gradients, param_gradients))
143 }
144}
145
146pub struct PQCLayer {
148 layer: QuantumCircuitLayer,
150 input_scaling: f64,
152 init_strategy: ParameterInitStrategy,
154 regularization: Option<RegularizationType>,
156}
157
/// How `PQCLayer::initialize_parameters` fills a fresh parameter matrix.
#[derive(Debug, Clone)]
pub enum ParameterInitStrategy {
    /// Gaussian samples with the given mean and standard deviation.
    RandomNormal {
        mean: f64,
        std: f64,
    },
    /// Uniform samples scaled into the range between `low` and `high`.
    RandomUniform {
        low: f64,
        high: f64,
    },
    /// Every entry is `0.0`.
    Zeros,
    /// Every entry is `1.0`.
    Ones,
    /// Values copied into the leading columns of every row; remaining columns stay zero.
    Custom(Vec<f64>),
}
172
/// Penalty applied to trainable parameters in `PQCLayer::compute_gradients`.
#[derive(Debug, Clone)]
pub enum RegularizationType {
    /// L1 penalty with the given strength.
    L1(f64),
    /// L2 penalty with the given strength; its gradient is `2 * lambda * theta`.
    L2(f64),
    /// Dropout with the given value (presumably a rate); currently contributes nothing to gradients.
    Dropout(f64),
}
183
184impl PQCLayer {
185 pub fn new(
187 circuit: Circuit<8>, symbols: Vec<String>,
189 observable: Observable,
190 backend: Arc<dyn SimulatorBackend>,
191 ) -> Self {
192 let layer = QuantumCircuitLayer::new(circuit, symbols, observable, backend);
193
194 Self {
195 layer,
196 input_scaling: 1.0,
197 init_strategy: ParameterInitStrategy::RandomNormal {
198 mean: 0.0,
199 std: 0.1,
200 },
201 regularization: None,
202 }
203 }
204
205 pub fn with_input_scaling(mut self, scaling: f64) -> Self {
207 self.input_scaling = scaling;
208 self
209 }
210
211 pub fn with_initialization(mut self, strategy: ParameterInitStrategy) -> Self {
213 self.init_strategy = strategy;
214 self
215 }
216
217 pub fn with_regularization(mut self, regularization: RegularizationType) -> Self {
219 self.regularization = Some(regularization);
220 self
221 }
222
223 pub fn initialize_parameters(&self, batch_size: usize, num_params: usize) -> Array2<f64> {
225 match &self.init_strategy {
226 ParameterInitStrategy::RandomNormal { mean, std } => {
227 Array2::from_shape_fn((batch_size, num_params), |_| {
229 let u1 = fastrand::f64();
230 let u2 = fastrand::f64();
231 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
232 mean + std * z0
233 })
234 }
235 ParameterInitStrategy::RandomUniform { low, high } => {
236 Array2::from_shape_fn((batch_size, num_params), |_| {
237 fastrand::f64() * (high - low) + low
238 })
239 }
240 ParameterInitStrategy::Zeros => Array2::zeros((batch_size, num_params)),
241 ParameterInitStrategy::Ones => Array2::ones((batch_size, num_params)),
242 ParameterInitStrategy::Custom(values) => {
243 let mut params = Array2::zeros((batch_size, num_params));
244 for i in 0..batch_size {
245 for j in 0..num_params.min(values.len()) {
246 params[[i, j]] = values[j];
247 }
248 }
249 params
250 }
251 }
252 }
253
254 pub fn forward(&self, inputs: &Array2<f64>, parameters: &Array2<f64>) -> Result<Array1<f64>> {
256 let scaled_inputs = inputs * self.input_scaling;
258
259 let outputs = self.layer.forward(&scaled_inputs, parameters)?;
261
262 Ok(outputs)
266 }
267
268 pub fn compute_gradients(
270 &self,
271 inputs: &Array2<f64>,
272 parameters: &Array2<f64>,
273 upstream_gradients: &Array1<f64>,
274 ) -> Result<(Array2<f64>, Array2<f64>)> {
275 let scaled_inputs = inputs * self.input_scaling;
276 let (mut input_grads, mut param_grads) =
277 self.layer
278 .compute_gradients(&scaled_inputs, parameters, upstream_gradients)?;
279
280 input_grads *= self.input_scaling;
282
283 if let Some(ref reg) = self.regularization {
285 match reg {
286 RegularizationType::L1(lambda) => {
287 param_grads += &(parameters.mapv(|x| lambda * x.signum()));
288 }
289 RegularizationType::L2(lambda) => {
290 param_grads += &(parameters * (2.0 * lambda));
291 }
292 RegularizationType::Dropout(_) => {
293 }
295 }
296 }
297
298 Ok((input_grads, param_grads))
299 }
300}
301
302pub struct QuantumConvolutionalLayer {
304 pqc: PQCLayer,
306 filter_size: (usize, usize),
308 stride: (usize, usize),
310 padding: PaddingType,
312}
313
/// Padding mode for `QuantumConvolutionalLayer`.
#[derive(Debug, Clone)]
pub enum PaddingType {
    /// No padding.
    Valid,
    /// "Same"-style padding.
    Same,
    /// Padding amount given explicitly.
    Custom(usize),
}
324
325impl QuantumConvolutionalLayer {
326 pub fn new(
328 circuit: Circuit<8>, symbols: Vec<String>,
330 observable: Observable,
331 backend: Arc<dyn SimulatorBackend>,
332 filter_size: (usize, usize),
333 ) -> Self {
334 let pqc = PQCLayer::new(circuit, symbols, observable, backend);
335
336 Self {
337 pqc,
338 filter_size,
339 stride: (1, 1),
340 padding: PaddingType::Valid,
341 }
342 }
343
344 pub fn with_stride(mut self, stride: (usize, usize)) -> Self {
346 self.stride = stride;
347 self
348 }
349
350 pub fn with_padding(mut self, padding: PaddingType) -> Self {
352 self.padding = padding;
353 self
354 }
355
356 pub fn forward(&self, inputs: &Array4<f64>, parameters: &Array2<f64>) -> Result<Array4<f64>> {
358 let (batch_size, height, width, channels) = inputs.dim();
359 let (filter_h, filter_w) = self.filter_size;
360 let (stride_h, stride_w) = self.stride;
361
362 let output_h = (height - filter_h) / stride_h + 1;
364 let output_w = (width - filter_w) / stride_w + 1;
365
366 let mut outputs = Array4::zeros((batch_size, output_h, output_w, 1));
367
368 for batch in 0..batch_size {
369 for out_y in 0..output_h {
370 for out_x in 0..output_w {
371 let start_y = out_y * stride_h;
373 let start_x = out_x * stride_w;
374
375 let mut patch_data = Array2::zeros((1, filter_h * filter_w * channels));
376 let mut patch_idx = 0;
377
378 for dy in 0..filter_h {
379 for dx in 0..filter_w {
380 for c in 0..channels {
381 if start_y + dy < height && start_x + dx < width {
382 patch_data[[0, patch_idx]] =
383 inputs[[batch, start_y + dy, start_x + dx, c]];
384 }
385 patch_idx += 1;
386 }
387 }
388 }
389
390 let result = self.pqc.forward(&patch_data, parameters)?;
392 outputs[[batch, out_y, out_x, 0]] = result[0];
393 }
394 }
395 }
396
397 Ok(outputs)
398 }
399}
400
401pub struct TFQModel {
403 layers: Vec<Box<dyn TFQLayer>>,
405 input_shape: Vec<usize>,
407 loss_function: TFQLossFunction,
409 optimizer: TFQOptimizer,
411}
412
413pub trait TFQLayer: Send + Sync {
415 fn forward(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>>;
417
418 fn backward(&self, upstream_gradients: &ArrayD<f64>) -> Result<ArrayD<f64>>;
420
421 fn get_parameters(&self) -> Vec<Array1<f64>>;
423
424 fn set_parameters(&mut self, params: Vec<Array1<f64>>) -> Result<()>;
426
427 fn name(&self) -> &str;
429}
430
/// Loss functions selectable on a `TFQModel`.
#[derive(Debug, Clone)]
pub enum TFQLossFunction {
    /// Mean of squared differences.
    MeanSquaredError,
    /// Mean binary cross-entropy on predictions clipped away from 0 and 1.
    BinaryCrossentropy,
    /// Not implemented by `TFQModel` yet.
    CategoricalCrossentropy,
    /// Not implemented by `TFQModel` yet.
    Hinge,
    /// Named custom loss; not implemented by `TFQModel` yet.
    Custom(String),
}
445
/// Optimizer configurations accepted by `TFQModel::set_optimizer`.
#[derive(Debug, Clone)]
pub enum TFQOptimizer {
    /// Adam hyper-parameters.
    Adam { learning_rate: f64, beta1: f64, beta2: f64, epsilon: f64 },
    /// Stochastic gradient descent with momentum.
    SGD {
        learning_rate: f64,
        momentum: f64,
    },
    /// RMSprop hyper-parameters.
    RMSprop { learning_rate: f64, rho: f64, epsilon: f64 },
}
465
466impl TFQModel {
467 pub fn new(input_shape: Vec<usize>) -> Self {
469 Self {
470 layers: Vec::new(),
471 input_shape,
472 loss_function: TFQLossFunction::MeanSquaredError,
473 optimizer: TFQOptimizer::Adam {
474 learning_rate: 0.001,
475 beta1: 0.9,
476 beta2: 0.999,
477 epsilon: 1e-8,
478 },
479 }
480 }
481
482 pub fn add_layer(&mut self, layer: Box<dyn TFQLayer>) {
484 self.layers.push(layer);
485 }
486
487 pub fn set_loss(mut self, loss: TFQLossFunction) -> Self {
489 self.loss_function = loss;
490 self
491 }
492
493 pub fn set_optimizer(mut self, optimizer: TFQOptimizer) -> Self {
495 self.optimizer = optimizer;
496 self
497 }
498
499 pub fn compile(&mut self) -> Result<()> {
501 if self.layers.is_empty() {
503 return Err(MLError::InvalidConfiguration(
504 "Model must have at least one layer".to_string(),
505 ));
506 }
507
508 Ok(())
509 }
510
511 pub fn predict(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>> {
513 let mut current = inputs.clone();
514
515 for layer in &self.layers {
516 current = layer.forward(¤t)?;
517 }
518
519 Ok(current)
520 }
521
522 pub fn train_step(&mut self, inputs: &ArrayD<f64>, targets: &ArrayD<f64>) -> Result<f64> {
524 let predictions = self.predict(inputs)?;
526
527 let loss = self.compute_loss(&predictions, targets)?;
529
530 let mut gradients = self.compute_loss_gradients(&predictions, targets)?;
532
533 for layer in self.layers.iter().rev() {
534 gradients = layer.backward(&gradients)?;
535 }
536
537 self.update_parameters()?;
539
540 Ok(loss)
541 }
542
543 fn compute_loss(&self, predictions: &ArrayD<f64>, targets: &ArrayD<f64>) -> Result<f64> {
545 match &self.loss_function {
546 TFQLossFunction::MeanSquaredError => {
547 let diff = predictions - targets;
548 Ok(diff.mapv(|x| x * x).mean().unwrap())
549 }
550 TFQLossFunction::BinaryCrossentropy => {
551 let epsilon = 1e-15;
552 let clipped_preds = predictions.mapv(|x| x.max(epsilon).min(1.0 - epsilon));
553 let loss = targets * clipped_preds.mapv(|x| x.ln())
554 + (1.0 - targets) * clipped_preds.mapv(|x| (1.0 - x).ln());
555 Ok(-loss.mean().unwrap())
556 }
557 _ => Err(MLError::InvalidConfiguration(
558 "Loss function not implemented".to_string(),
559 )),
560 }
561 }
562
563 fn compute_loss_gradients(
565 &self,
566 predictions: &ArrayD<f64>,
567 targets: &ArrayD<f64>,
568 ) -> Result<ArrayD<f64>> {
569 match &self.loss_function {
570 TFQLossFunction::MeanSquaredError => {
571 Ok(2.0 * (predictions - targets) / predictions.len() as f64)
572 }
573 TFQLossFunction::BinaryCrossentropy => {
574 let epsilon = 1e-15;
575 let clipped_preds = predictions.mapv(|x| x.max(epsilon).min(1.0 - epsilon));
576 Ok((clipped_preds.clone() - targets)
577 / (clipped_preds.clone() * (1.0 - &clipped_preds)))
578 }
579 _ => Err(MLError::InvalidConfiguration(
580 "Loss gradient not implemented".to_string(),
581 )),
582 }
583 }
584
585 fn update_parameters(&mut self) -> Result<()> {
587 Ok(())
589 }
590}
591
592pub struct QuantumDataset {
594 circuits: Vec<DynamicCircuit>,
596 parameters: Array2<f64>,
598 labels: Array1<f64>,
600 batch_size: usize,
602}
603
604impl QuantumDataset {
605 pub fn new(
607 circuits: Vec<Circuit<8>>, parameters: Array2<f64>,
609 labels: Array1<f64>,
610 batch_size: usize,
611 ) -> Result<Self> {
612 let dynamic_circuits: std::result::Result<Vec<DynamicCircuit>, crate::error::MLError> =
613 circuits
614 .into_iter()
615 .map(|c| DynamicCircuit::from_circuit(c))
616 .collect();
617
618 Ok(Self {
619 circuits: dynamic_circuits?,
620 parameters,
621 labels,
622 batch_size,
623 })
624 }
625
626 pub fn batches(&self) -> QuantumDatasetIterator {
628 QuantumDatasetIterator::new(self)
629 }
630
631 pub fn shuffle(&mut self) {
633 let n = self.circuits.len();
634 let mut indices: Vec<usize> = (0..n).collect();
635
636 for i in (1..n).rev() {
638 let j = fastrand::usize(0..=i);
639 indices.swap(i, j);
640 }
641
642 let mut new_circuits = Vec::with_capacity(n);
644 let mut new_parameters = Array2::zeros(self.parameters.dim());
645 let mut new_labels = Array1::zeros(self.labels.dim());
646
647 for (new_idx, &old_idx) in indices.iter().enumerate() {
648 new_circuits.push(self.circuits[old_idx].clone());
649 new_parameters
650 .row_mut(new_idx)
651 .assign(&self.parameters.row(old_idx));
652 new_labels[new_idx] = self.labels[old_idx];
653 }
654
655 self.circuits = new_circuits;
656 self.parameters = new_parameters;
657 self.labels = new_labels;
658 }
659}
660
661pub struct QuantumDatasetIterator<'a> {
663 dataset: &'a QuantumDataset,
664 current_batch: usize,
665 total_batches: usize,
666}
667
668impl<'a> QuantumDatasetIterator<'a> {
669 fn new(dataset: &'a QuantumDataset) -> Self {
670 let total_batches = (dataset.circuits.len() + dataset.batch_size - 1) / dataset.batch_size;
671 Self {
672 dataset,
673 current_batch: 0,
674 total_batches,
675 }
676 }
677}
678
679impl<'a> Iterator for QuantumDatasetIterator<'a> {
680 type Item = (Vec<DynamicCircuit>, Array2<f64>, Array1<f64>);
681
682 fn next(&mut self) -> Option<Self::Item> {
683 if self.current_batch >= self.total_batches {
684 return None;
685 }
686
687 let start_idx = self.current_batch * self.dataset.batch_size;
688 let end_idx =
689 ((self.current_batch + 1) * self.dataset.batch_size).min(self.dataset.circuits.len());
690
691 let batch_circuits = self.dataset.circuits[start_idx..end_idx].to_vec();
692 let batch_parameters = self
693 .dataset
694 .parameters
695 .slice(s![start_idx..end_idx, ..])
696 .to_owned();
697 let batch_labels = self.dataset.labels.slice(s![start_idx..end_idx]).to_owned();
698
699 self.current_batch += 1;
700 Some((batch_circuits, batch_parameters, batch_labels))
701 }
702}
703
704pub mod tfq_utils {
706 use super::*;
707
708 pub fn circuit_to_tfq_format(circuit: &DynamicCircuit) -> Result<TFQCircuitFormat> {
710 let tfq_gates: Vec<TFQGate> = Vec::new();
714
715 Ok(TFQCircuitFormat {
716 gates: tfq_gates,
717 num_qubits: circuit.num_qubits(),
718 })
719 }
720
721 pub fn create_data_encoding_circuit(
723 num_qubits: usize,
724 encoding_type: DataEncodingType,
725 ) -> Result<DynamicCircuit> {
726 let mut builder: Circuit<8> = CircuitBuilder::new(); match encoding_type {
729 DataEncodingType::Amplitude => {
730 for qubit in 0..num_qubits {
732 builder.ry(qubit, 0.0)?; }
734 }
735 DataEncodingType::Angle => {
736 for qubit in 0..num_qubits {
738 builder.rz(qubit, 0.0)?; }
740 }
741 DataEncodingType::Basis => {
742 for qubit in 0..num_qubits {
744 builder.x(qubit)?; }
746 }
747 }
748
749 let circuit = builder.build();
750 DynamicCircuit::from_circuit(circuit)
751 }
752
753 pub fn create_hardware_efficient_ansatz(
755 num_qubits: usize,
756 layers: usize,
757 ) -> Result<DynamicCircuit> {
758 let mut builder: Circuit<8> = CircuitBuilder::new(); for layer in 0..layers {
761 for qubit in 0..num_qubits {
763 builder.ry(qubit, 0.0)?;
764 builder.rz(qubit, 0.0)?;
765 }
766
767 for qubit in 0..num_qubits - 1 {
769 builder.cnot(qubit, qubit + 1)?;
770 }
771
772 if layer < layers - 1 && num_qubits > 2 {
774 builder.cnot(num_qubits - 1, 0)?;
775 }
776 }
777
778 let circuit = builder.build();
779 DynamicCircuit::from_circuit(circuit)
780 }
781
782 pub fn batch_execute_circuits(
784 circuits: &[DynamicCircuit],
785 parameters: &Array2<f64>,
786 observables: &[Observable],
787 backend: &dyn SimulatorBackend,
788 ) -> Result<Array2<f64>> {
789 let batch_size = circuits.len();
790 let num_observables = observables.len();
791 let mut results = Array2::zeros((batch_size, num_observables));
792
793 for (circuit_idx, circuit) in circuits.iter().enumerate() {
794 let params = parameters.row(circuit_idx % parameters.nrows());
795
796 for (obs_idx, observable) in observables.iter().enumerate() {
797 let expectation =
798 backend.expectation_value(circuit, params.as_slice().unwrap(), observable)?;
799 results[[circuit_idx, obs_idx]] = expectation;
800 }
801 }
802
803 Ok(results)
804 }
805}
806
807#[derive(Debug, Clone)]
809pub struct TFQCircuitFormat {
810 gates: Vec<TFQGate>,
812 num_qubits: usize,
814}
815
/// One gate entry in a `TFQCircuitFormat`.
#[derive(Debug, Clone)]
pub struct TFQGate {
    /// Gate name.
    gate_type: String,
    /// Qubit indices the gate acts on.
    qubits: Vec<usize>,
    /// Numeric gate parameters.
    parameters: Vec<f64>,
}
826
/// Encoding scheme used by `tfq_utils::create_data_encoding_circuit`.
#[derive(Debug, Clone)]
pub enum DataEncodingType {
    /// One `ry` rotation per qubit.
    Amplitude,
    /// One `rz` rotation per qubit.
    Angle,
    /// An `x` gate on every qubit.
    Basis,
}
837
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulator_backends::{BackendCapabilities, StatevectorBackend};

    #[test]
    #[ignore]
    fn test_quantum_circuit_layer() {
        let mut builder = CircuitBuilder::new();
        builder.ry(0, 0.0).unwrap();
        builder.ry(1, 0.0).unwrap();
        builder.cnot(0, 1).unwrap();
        let circuit = builder.build();

        let symbols = vec!["theta1".to_string(), "theta2".to_string()];
        let observable = Observable::PauliZ(vec![0, 1]);
        let backend = Arc::new(StatevectorBackend::new(8));

        let layer = QuantumCircuitLayer::new(circuit, symbols, observable, backend);

        let inputs = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let parameters = Array2::from_shape_vec((2, 2), vec![0.5, 0.6, 0.7, 0.8]).unwrap();

        let result = layer.forward(&inputs, &parameters);
        assert!(result.is_ok());
    }

    #[test]
    fn test_pqc_layer_initialization() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder.h(0)?;
        let circuit = builder.build();

        let symbols = vec!["param1".to_string()];
        let observable = Observable::PauliZ(vec![0]);
        let backend = Arc::new(StatevectorBackend::new(8));

        let pqc = PQCLayer::new(circuit, symbols, observable, backend).with_initialization(
            ParameterInitStrategy::RandomNormal {
                mean: 0.0,
                std: 0.1,
            },
        );

        let params = pqc.initialize_parameters(5, 3);
        assert_eq!(params.shape(), &[5, 3]);
        assert!(params.iter().all(|v| v.is_finite()));
        Ok(())
    }

    #[test]
    fn test_pqc_layer_custom_initialization() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder.h(0)?;
        let circuit = builder.build();

        let pqc = PQCLayer::new(
            circuit,
            vec!["param1".to_string()],
            Observable::PauliZ(vec![0]),
            Arc::new(StatevectorBackend::new(8)),
        )
        .with_initialization(ParameterInitStrategy::Custom(vec![1.0, 2.0]));

        // Custom values fill the leading columns of every row; the rest stay zero.
        let params = pqc.initialize_parameters(2, 3);
        let expected =
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 0.0, 1.0, 2.0, 0.0]).unwrap();
        assert_eq!(params, expected);
        Ok(())
    }

    #[test]
    #[ignore]
    fn test_tfq_utils() {
        let circuit = tfq_utils::create_data_encoding_circuit(3, DataEncodingType::Angle).unwrap();
        assert_eq!(circuit.num_qubits(), 3);

        let ansatz = tfq_utils::create_hardware_efficient_ansatz(4, 2).unwrap();
        assert_eq!(ansatz.num_qubits(), 4);
    }

    #[test]
    fn test_quantum_dataset() -> Result<()> {
        let circuits = vec![CircuitBuilder::new().build(), CircuitBuilder::new().build()];
        let parameters =
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let labels = Array1::from_vec(vec![0.0, 1.0]);

        let dataset = QuantumDataset::new(circuits, parameters, labels, 1)?;
        let batches: Vec<_> = dataset.batches().collect();

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].0.len(), 1);
        assert_eq!(batches[0].1.shape(), &[1, 3]);
        assert_eq!(batches[1].1[[0, 0]], 4.0);
        assert_eq!(batches[1].2[0], 1.0);
        Ok(())
    }

    #[test]
    #[ignore]
    fn test_tfq_model() {
        let mut model = TFQModel::new(vec![2, 2])
            .set_loss(TFQLossFunction::MeanSquaredError)
            .set_optimizer(TFQOptimizer::Adam {
                learning_rate: 0.01,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            });

        assert!(model.compile().is_ok());
    }
}