quantrs2_ml/
scirs2_integration.rs

//! SciRS2 integration layer for quantum machine learning
//!
//! This module provides integration with the SciRS2 scientific computing framework,
//! enabling quantum ML models to leverage SciRS2's optimized tensor operations,
//! distributed training capabilities, and serialization formats.

use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array, Array1, Array2, Array3, ArrayD, ArrayViewD, Dimension, IxDyn};
use std::collections::HashMap;

/// Trait for tensor operations compatible with SciRS2
pub trait SciRS2Tensor {
    /// Get tensor shape
    fn shape(&self) -> &[usize];

    /// Get tensor data as ArrayViewD
    fn view(&self) -> ArrayViewD<f64>;

    /// Convert to SciRS2 format (placeholder)
    fn to_scirs2(&self) -> Result<SciRS2Array>;

    /// Perform tensor operations using SciRS2 backend
    fn matmul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;

    /// Element-wise operations
    fn add(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;
    fn mul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;
    fn sub(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;

    /// Reduction operations
    fn sum(&self, axis: Option<usize>) -> Result<SciRS2Array>;
    fn mean(&self, axis: Option<usize>) -> Result<SciRS2Array>;
    fn max(&self, axis: Option<usize>) -> Result<SciRS2Array>;
    fn min(&self, axis: Option<usize>) -> Result<SciRS2Array>;
}

/// SciRS2 array wrapper for quantum ML operations
pub struct SciRS2Array {
    /// Array data
    pub data: ArrayD<f64>,
    /// Whether gradients are required
    pub requires_grad: bool,
    /// Gradient accumulator
    pub grad: Option<ArrayD<f64>>,
    /// Operation history for backpropagation
    pub grad_fn: Option<Box<dyn GradFunction>>,
}

impl std::fmt::Debug for SciRS2Array {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SciRS2Array")
            .field("data", &self.data)
            .field("requires_grad", &self.requires_grad)
            .field("grad", &self.grad)
            .field("grad_fn", &"<gradient_function>")
            .finish()
    }
}

impl Clone for SciRS2Array {
    fn clone(&self) -> Self {
        Self {
            data: self.data.clone(),
            requires_grad: self.requires_grad,
            grad: self.grad.clone(),
            grad_fn: None, // Cannot clone trait objects
        }
    }
}

impl SciRS2Array {
    /// Create a new SciRS2Array
    pub fn new(data: ArrayD<f64>, requires_grad: bool) -> Self {
        let grad = if requires_grad {
            Some(ArrayD::zeros(data.raw_dim()))
        } else {
            None
        };
        Self {
            data,
            requires_grad,
            grad,
            grad_fn: None,
        }
    }

    /// Create from ndarray
    pub fn from_array<D: Dimension>(arr: Array<f64, D>) -> Self {
        let data = arr.into_dyn();
        Self::new(data, false)
    }

    /// Create with gradient tracking
    pub fn with_grad<D: Dimension>(arr: Array<f64, D>) -> Self {
        let data = arr.into_dyn();
        Self::new(data, true)
    }

    /// Zero gradients
    pub fn zero_grad(&mut self) {
        if let Some(ref mut grad) = self.grad {
            grad.fill(0.0);
        }
    }

    /// Backward pass
    pub fn backward(&mut self) -> Result<()> {
        // Extract grad_fn to avoid borrow conflicts
        if let Some(grad_fn) = self.grad_fn.take() {
            grad_fn.backward(self)?;
            self.grad_fn = Some(grad_fn);
        }
        Ok(())
    }

    /// Matrix multiplication using SciRS2 backend
    pub fn matmul(&self, other: &SciRS2Array) -> Result<SciRS2Array> {
        // Placeholder - would use SciRS2 linalg operations
        let result_data = if self.data.ndim() == 2 && other.data.ndim() == 2 {
            let self_2d = self
                .data
                .view()
                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
            let other_2d = other
                .data
                .view()
                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
            self_2d.dot(&other_2d).into_dyn()
        } else {
            return Err(MLError::InvalidConfiguration(
                "Matrix multiplication requires 2D arrays".to_string(),
            ));
        };

        let requires_grad = self.requires_grad || other.requires_grad;
        let mut result = SciRS2Array::new(result_data, requires_grad);

        if requires_grad {
            result.grad_fn = Some(Box::new(MatmulGradFn {
                left_shape: self.data.raw_dim(),
                right_shape: other.data.raw_dim(),
            }));
        }

        Ok(result)
    }

    /// Element-wise addition
    pub fn add(&self, other: &SciRS2Array) -> Result<SciRS2Array> {
        let result_data = &self.data + &other.data;
        let requires_grad = self.requires_grad || other.requires_grad;
        let mut result = SciRS2Array::new(result_data, requires_grad);

        if requires_grad {
            result.grad_fn = Some(Box::new(AddGradFn));
        }

        Ok(result)
    }

    /// Element-wise multiplication
    pub fn mul(&self, other: &SciRS2Array) -> Result<SciRS2Array> {
        let result_data = &self.data * &other.data;
        let requires_grad = self.requires_grad || other.requires_grad;
        let mut result = SciRS2Array::new(result_data, requires_grad);

        if requires_grad {
            result.grad_fn = Some(Box::new(MulGradFn {
                left_data: self.data.clone(),
                right_data: other.data.clone(),
            }));
        }

        Ok(result)
    }

    /// Reduction sum
    pub fn sum(&self, axis: Option<usize>) -> Result<SciRS2Array> {
        let result_data = match axis {
            Some(ax) => self
                .data
                .sum_axis(scirs2_core::ndarray::Axis(ax))
                .into_dyn(),
            None => {
                let sum_val = self.data.sum();
                ArrayD::from_elem(IxDyn(&[]), sum_val)
            }
        };

        let mut result = SciRS2Array::new(result_data, self.requires_grad);

        if self.requires_grad {
            result.grad_fn = Some(Box::new(SumGradFn { axis }));
        }

        Ok(result)
    }
}
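
// Illustrative usage sketch: a minimal end-to-end pass through `SciRS2Array`,
// assuming the placeholder gradient machinery in this module. The helper name
// `_example_array_usage` is hypothetical and only demonstrates the API above.
#[allow(dead_code)]
fn _example_array_usage() -> Result<()> {
    // A 2x2 parameter matrix with gradient tracking enabled.
    let weights = SciRS2Array::with_grad(
        Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
            .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?,
    );
    // A constant 2x2 input without gradient tracking.
    let inputs = SciRS2Array::from_array(Array2::<f64>::eye(2));

    // The matrix product inherits `requires_grad` from `weights`.
    let product = weights.matmul(&inputs)?;
    // Scalar reduction; `None` sums over all elements.
    let mut loss = product.sum(None)?;
    // Placeholder backward pass (the GradFunction impls below are stubs).
    loss.backward()?;
    Ok(())
}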

impl SciRS2Tensor for SciRS2Array {
    fn shape(&self) -> &[usize] {
        self.data.shape()
    }

    fn view(&self) -> ArrayViewD<f64> {
        self.data.view()
    }

    fn to_scirs2(&self) -> Result<SciRS2Array> {
        Ok(self.clone())
    }

    fn matmul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
        // Convert other to SciRS2Array for computation
        let other_array = other.to_scirs2()?;
        self.matmul(&other_array)
    }

    fn add(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
        let other_array = other.to_scirs2()?;
        self.add(&other_array)
    }

    fn mul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
        let other_array = other.to_scirs2()?;
        self.mul(&other_array)
    }

    fn sub(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
        let other_array = other.to_scirs2()?;
        let result_data = &self.data - &other_array.data;
        let requires_grad = self.requires_grad || other_array.requires_grad;
        Ok(SciRS2Array::new(result_data, requires_grad))
    }

    fn sum(&self, axis: Option<usize>) -> Result<SciRS2Array> {
        self.sum(axis)
    }

    fn mean(&self, axis: Option<usize>) -> Result<SciRS2Array> {
        let result_data = match axis {
            Some(ax) => self
                .data
                .mean_axis(scirs2_core::ndarray::Axis(ax))
                .ok_or_else(|| {
                    MLError::ComputationError("Empty axis for mean computation".to_string())
                })?
                .into_dyn(),
            None => {
                let mean_val = self.data.mean().ok_or_else(|| {
                    MLError::ComputationError("Empty array for mean computation".to_string())
                })?;
                ArrayD::from_elem(IxDyn(&[]), mean_val)
            }
        };
        Ok(SciRS2Array::new(result_data, self.requires_grad))
    }

    fn max(&self, axis: Option<usize>) -> Result<SciRS2Array> {
        let result_data = match axis {
            Some(ax) => self
                .data
                .map_axis(scirs2_core::ndarray::Axis(ax), |view| {
                    *view
                        .iter()
                        .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                        .expect("map_axis guarantees non-empty view for valid axis")
                })
                .into_dyn(),
            None => {
                let max_val = *self
                    .data
                    .iter()
                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .ok_or_else(|| {
                        MLError::ComputationError("Empty array for max computation".to_string())
                    })?;
                ArrayD::from_elem(IxDyn(&[]), max_val)
            }
        };
        Ok(SciRS2Array::new(result_data, self.requires_grad))
    }

    fn min(&self, axis: Option<usize>) -> Result<SciRS2Array> {
        let result_data = match axis {
            Some(ax) => self
                .data
                .map_axis(scirs2_core::ndarray::Axis(ax), |view| {
                    *view
                        .iter()
                        .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                        .expect("map_axis guarantees non-empty view for valid axis")
                })
                .into_dyn(),
            None => {
                let min_val = *self
                    .data
                    .iter()
                    .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .ok_or_else(|| {
                        MLError::ComputationError("Empty array for min computation".to_string())
                    })?;
                ArrayD::from_elem(IxDyn(&[]), min_val)
            }
        };
        Ok(SciRS2Array::new(result_data, self.requires_grad))
    }
}

/// Trait for gradient functions
pub trait GradFunction: Send + Sync {
    fn backward(&self, output: &mut SciRS2Array) -> Result<()>;
}

/// Gradient function for matrix multiplication
#[derive(Debug)]
struct MatmulGradFn {
    left_shape: IxDyn,
    right_shape: IxDyn,
}

impl GradFunction for MatmulGradFn {
    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
        // Placeholder - would compute gradients for matmul inputs
        Ok(())
    }
}

/// Gradient function for addition
#[derive(Debug)]
struct AddGradFn;

impl GradFunction for AddGradFn {
    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
        // Gradient flows through unchanged for addition
        Ok(())
    }
}

/// Gradient function for multiplication
#[derive(Debug)]
struct MulGradFn {
    left_data: ArrayD<f64>,
    right_data: ArrayD<f64>,
}

impl GradFunction for MulGradFn {
    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
        // Placeholder - would compute gradients for element-wise multiplication
        Ok(())
    }
}

/// Gradient function for sum reduction
#[derive(Debug)]
struct SumGradFn {
    axis: Option<usize>,
}

impl GradFunction for SumGradFn {
    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
        // Placeholder - would broadcast gradients for sum reduction
        Ok(())
    }
}

/// SciRS2 optimization interface
pub struct SciRS2Optimizer {
    /// Optimizer type
    pub optimizer_type: String,
    /// Configuration parameters
    pub config: HashMap<String, f64>,
    /// Parameter state (for stateful optimizers like Adam)
    pub state: HashMap<String, ArrayD<f64>>,
}

impl SciRS2Optimizer {
    /// Create a new SciRS2 optimizer
    pub fn new(optimizer_type: impl Into<String>) -> Self {
        Self {
            optimizer_type: optimizer_type.into(),
            config: HashMap::new(),
            state: HashMap::new(),
        }
    }

    /// Set optimizer configuration
    pub fn with_config(mut self, key: impl Into<String>, value: f64) -> Self {
        self.config.insert(key.into(), value);
        self
    }

    /// Update parameters using computed gradients
    pub fn step(&mut self, params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
        match self.optimizer_type.as_str() {
            "adam" => self.adam_step(params),
            "sgd" => self.sgd_step(params),
            "lbfgs" => self.lbfgs_step(params),
            _ => Err(MLError::InvalidConfiguration(format!(
                "Unknown optimizer type: {}",
                self.optimizer_type
            ))),
        }
    }

    /// Adam optimizer step
    fn adam_step(&mut self, params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
        let learning_rate = self.config.get("learning_rate").unwrap_or(&0.001);
        let beta1 = self.config.get("beta1").unwrap_or(&0.9);
        let beta2 = self.config.get("beta2").unwrap_or(&0.999);
        let epsilon = self.config.get("epsilon").unwrap_or(&1e-8);

        for (name, param) in params.iter_mut() {
            if let Some(ref grad) = param.grad {
                // Initialize momentum and velocity if not present
                let m_key = format!("{}_m", name);
                let v_key = format!("{}_v", name);

                if !self.state.contains_key(&m_key) {
                    self.state
                        .insert(m_key.clone(), ArrayD::zeros(grad.raw_dim()));
                    self.state
                        .insert(v_key.clone(), ArrayD::zeros(grad.raw_dim()));
                }

                // Update first moment estimate
                {
                    let m = self
                        .state
                        .get_mut(&m_key)
                        .expect("m_key was just inserted if not present");
                    *m = *beta1 * &*m + (1.0 - *beta1) * grad;
                }

                // Update second moment estimate
                {
                    let v = self
                        .state
                        .get_mut(&v_key)
                        .expect("v_key was just inserted if not present");
                    *v = *beta2 * &*v + (1.0 - *beta2) * grad * grad;
                }

                // Clone the current moment estimates for the parameter update.
                // Note: bias correction (dividing by 1 - beta^t) is omitted in
                // this simplified step, so these are the raw moment estimates.
                let m_hat = self
                    .state
                    .get(&m_key)
                    .expect("m_key exists after update")
                    .clone();
                let v_hat = self
                    .state
                    .get(&v_key)
                    .expect("v_key exists after update")
                    .clone();

                // Update parameters
                param.data =
                    &param.data - *learning_rate * &m_hat / (v_hat.mapv(|x| x.sqrt()) + *epsilon);
            }
        }

        Ok(())
    }

    /// SGD optimizer step
    fn sgd_step(&mut self, params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
        let learning_rate = self.config.get("learning_rate").unwrap_or(&0.01);
        let momentum = self.config.get("momentum").unwrap_or(&0.0);

        for (name, param) in params.iter_mut() {
            if let Some(ref grad) = param.grad {
                if *momentum > 0.0 {
                    let v_key = format!("{}_v", name);
                    if !self.state.contains_key(&v_key) {
                        self.state
                            .insert(v_key.clone(), ArrayD::zeros(grad.raw_dim()));
                    }

                    let v = self
                        .state
                        .get_mut(&v_key)
                        .expect("v_key was just inserted if not present");
                    *v = *momentum * &*v + *learning_rate * grad;
                    param.data = &param.data - &*v;
                } else {
                    param.data = &param.data - *learning_rate * grad;
                }
            }
        }

        Ok(())
    }

    /// L-BFGS optimizer step (placeholder)
    fn lbfgs_step(&mut self, _params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
        // Placeholder - would implement L-BFGS using SciRS2
        Ok(())
    }
}
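
// Illustrative usage sketch: the typical zero_grad / backward / step cycle with
// `SciRS2Optimizer`. The helper name `_example_optimizer_usage` is hypothetical
// and only shows how the pieces above fit together.
#[allow(dead_code)]
fn _example_optimizer_usage() -> Result<()> {
    let mut optimizer = SciRS2Optimizer::new("sgd")
        .with_config("learning_rate", 0.05)
        .with_config("momentum", 0.9);

    let mut params = HashMap::new();
    params.insert(
        "weights".to_string(),
        SciRS2Array::with_grad(Array1::from_vec(vec![0.5, -0.25, 1.0])),
    );

    // One training step: clear old gradients, run the (placeholder) backward
    // pass, then apply the optimizer update using whatever is in `param.grad`.
    for param in params.values_mut() {
        param.zero_grad();
        param.backward()?;
    }
    optimizer.step(&mut params)?;
    Ok(())
}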

/// SciRS2 distributed training support
pub struct SciRS2DistributedTrainer {
    /// World size (number of processes)
    pub world_size: usize,
    /// Local rank
    pub rank: usize,
    /// Backend for communication
    pub backend: String,
}

impl SciRS2DistributedTrainer {
    /// Create a new distributed trainer
    pub fn new(world_size: usize, rank: usize) -> Self {
        Self {
            world_size,
            rank,
            backend: "nccl".to_string(),
        }
    }

    /// All-reduce operation for gradient synchronization
    pub fn all_reduce(&self, _tensor: &mut SciRS2Array) -> Result<()> {
        // Placeholder - would use SciRS2 distributed operations
        Ok(())
    }

    /// All-reduce scalar operation for metrics synchronization
    pub fn all_reduce_scalar(&self, value: f64) -> Result<f64> {
        // Placeholder - would use SciRS2 distributed operations
        // For now, just return the value unchanged (single-process behavior)
        Ok(value)
    }

    /// Broadcast operation
    pub fn broadcast(&self, _tensor: &mut SciRS2Array, _root: usize) -> Result<()> {
        // Placeholder - would use SciRS2 distributed operations
        Ok(())
    }

    /// All-gather operation
    pub fn all_gather(&self, tensor: &SciRS2Array) -> Result<Vec<SciRS2Array>> {
        // Placeholder - would use SciRS2 distributed operations
        Ok(vec![tensor.clone(); self.world_size])
    }

    /// Wrap a model for distributed training
    pub fn wrap_model<T>(&self, model: T) -> Result<T> {
        // Placeholder - would wrap the model with distributed training capabilities
        // For now, just return the model unchanged
        Ok(model)
    }
}

/// SciRS2 model serialization interface
pub struct SciRS2Serializer;

impl SciRS2Serializer {
    /// Serialize model parameters to SciRS2 format
    pub fn save_model(_params: &HashMap<String, SciRS2Array>, _path: &str) -> Result<()> {
        // Placeholder - would use SciRS2 serialization
        Ok(())
    }

    /// Load model parameters from SciRS2 format
    pub fn load_model(_path: &str) -> Result<HashMap<String, SciRS2Array>> {
        // Placeholder - would use SciRS2 deserialization
        Ok(HashMap::new())
    }

    /// Save checkpoint with optimizer state
    pub fn save_checkpoint(
        _params: &HashMap<String, SciRS2Array>,
        _optimizer: &SciRS2Optimizer,
        _epoch: usize,
        _path: &str,
    ) -> Result<()> {
        // Placeholder - would use SciRS2 checkpoint format
        Ok(())
    }

    /// Load checkpoint with optimizer state
    pub fn load_checkpoint(
        _path: &str,
    ) -> Result<(HashMap<String, SciRS2Array>, SciRS2Optimizer, usize)> {
        // Placeholder - would use SciRS2 checkpoint format
        Ok((HashMap::new(), SciRS2Optimizer::new("adam"), 0))
    }
}

/// SciRS2 Dataset wrapper for quantum ML
pub struct SciRS2Dataset {
    /// Training data
    pub data: ArrayD<f64>,
    /// Labels
    pub labels: ArrayD<f64>,
    /// Dataset size
    pub size: usize,
}

impl SciRS2Dataset {
    /// Create a new dataset
    pub fn new(data: ArrayD<f64>, labels: ArrayD<f64>) -> Result<Self> {
        let size = data.shape()[0];
        if labels.shape()[0] != size {
            return Err(MLError::InvalidConfiguration(
                "Data and labels must have same number of samples".to_string(),
            ));
        }

        Ok(Self { data, labels, size })
    }
}

/// SciRS2 DataLoader for batch processing
pub struct SciRS2DataLoader {
    /// Dataset reference
    pub dataset: SciRS2Dataset,
    /// Batch size
    pub batch_size: usize,
    /// Current index
    pub current_index: usize,
}

impl SciRS2DataLoader {
    /// Create a new data loader
    pub fn new(dataset: SciRS2Dataset, batch_size: usize) -> Self {
        Self {
            dataset,
            batch_size,
            current_index: 0,
        }
    }

    /// Iterator-like enumeration support
    pub fn enumerate(&mut self) -> DataLoaderIterator {
        DataLoaderIterator {
            loader: self,
            batch_idx: 0,
        }
    }
}

/// Iterator for DataLoader
pub struct DataLoaderIterator<'a> {
    loader: &'a mut SciRS2DataLoader,
    batch_idx: usize,
}

impl<'a> Iterator for DataLoaderIterator<'a> {
    type Item = (usize, (SciRS2Array, SciRS2Array));

    fn next(&mut self) -> Option<Self::Item> {
        if self.loader.current_index >= self.loader.dataset.size {
            return None;
        }

        let start = self.loader.current_index;
        let end = (start + self.loader.batch_size).min(self.loader.dataset.size);

        // Extract batch data and labels (assumes a 2D [samples, features] layout)
        let batch_data = self
            .loader
            .dataset
            .data
            .slice(scirs2_core::ndarray::s![start..end, ..])
            .to_owned();
        let batch_labels = self
            .loader
            .dataset
            .labels
            .slice(scirs2_core::ndarray::s![start..end, ..])
            .to_owned();

        let data_array = SciRS2Array::from_array(batch_data);
        let label_array = SciRS2Array::from_array(batch_labels);

        self.loader.current_index = end;
        let batch_idx = self.batch_idx;
        self.batch_idx += 1;

        Some((batch_idx, (data_array, label_array)))
    }
}
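
// Illustrative usage sketch: batching a small dataset with `SciRS2DataLoader`.
// The helper name `_example_dataloader_usage` is hypothetical; it assumes data
// and labels are both 2D ([samples, features] / [samples, targets]), which is
// what the slicing in `next()` above expects.
#[allow(dead_code)]
fn _example_dataloader_usage() -> Result<()> {
    // Four samples with two features each, and one target per sample.
    let data = ArrayD::zeros(IxDyn(&[4, 2]));
    let labels = ArrayD::zeros(IxDyn(&[4, 1]));
    let dataset = SciRS2Dataset::new(data, labels)?;

    let mut loader = SciRS2DataLoader::new(dataset, 2);
    for (batch_idx, (batch_data, batch_labels)) in loader.enumerate() {
        // With batch_size = 2 this yields two batches of two samples each.
        let _ = (batch_idx, batch_data, batch_labels);
    }
    Ok(())
}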

/// SciRS2 Device enumeration
#[derive(Debug, Clone, Copy)]
pub enum SciRS2Device {
    CPU,
    GPU,
    Quantum,
}

/// Additional SciRS2Array methods for compatibility
impl SciRS2Array {
    /// Create a random-valued array for the specified device
    /// (currently uniform in [-1, 1) rather than normally distributed)
    pub fn randn(shape: Vec<usize>, _device: SciRS2Device) -> Result<Self> {
        use scirs2_core::random::prelude::*;
        let total_size = shape.iter().product();
        let mut rng = thread_rng();
        let data: Vec<f64> = (0..total_size).map(|_| rng.gen_range(-1.0..1.0)).collect();
        let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
            .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
        Ok(Self::new(array, false))
    }

    /// Create an array of ones with the same shape as `self`
    pub fn ones_like(&self) -> Result<Self> {
        let ones = ArrayD::ones(self.data.raw_dim());
        Ok(Self::new(ones, false))
    }

    /// Create an array of random integers in `[low, high)`
    pub fn randint(low: i32, high: i32, shape: Vec<usize>, _device: SciRS2Device) -> Result<Self> {
        use scirs2_core::random::prelude::*;
        let total_size = shape.iter().product();
        let mut rng = thread_rng();
        let data: Vec<f64> = (0..total_size)
            .map(|_| rng.gen_range(low..high) as f64)
            .collect();
        let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
            .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
        Ok(Self::new(array, false))
    }

    /// Create a quantum observable matrix by name
    pub fn quantum_observable(name: &str, num_qubits: usize) -> Result<Self> {
        match name {
            "pauli_z_all" => {
                let size = 1 << num_qubits;
                let mut data = ArrayD::zeros(IxDyn(&[size, size]));
                for i in 0..size {
                    let parity = i.count_ones() % 2;
                    data[[i, i]] = if parity == 0 { 1.0 } else { -1.0 };
                }
                Ok(Self::new(data, false))
            }
            _ => Err(MLError::InvalidConfiguration(format!(
                "Unknown observable: {}",
                name
            ))),
        }
    }
}
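
// Illustrative usage sketch: building the `pauli_z_all` observable defined above
// for a 2-qubit system. The helper name `_example_observable_usage` is hypothetical.
// Each diagonal entry is +1 for basis states with even parity and -1 otherwise.
#[allow(dead_code)]
fn _example_observable_usage() -> Result<()> {
    let observable = SciRS2Array::quantum_observable("pauli_z_all", 2)?;
    // For 2 qubits the matrix is 4x4 with diagonal [+1, -1, -1, +1].
    assert_eq!(observable.data.shape(), &[4, 4]);
    assert_eq!(observable.data[[0, 0]], 1.0);
    assert_eq!(observable.data[[1, 1]], -1.0);
    Ok(())
}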

/// Integration helper functions
pub mod integration {
    use super::*;

    /// Convert ndarray to SciRS2Array
    pub fn from_ndarray<D: Dimension>(arr: Array<f64, D>) -> SciRS2Array {
        SciRS2Array::from_array(arr)
    }

    /// Convert SciRS2Array to ndarray
    pub fn to_ndarray<D: Dimension>(arr: &SciRS2Array) -> Result<Array<f64, D>> {
        arr.data
            .view()
            .into_dimensionality::<D>()
            .map(|v| v.to_owned())
            .map_err(|e| MLError::ComputationError(format!("Dimension error: {}", e)))
    }

    /// Create SciRS2 optimizer from configuration
    pub fn create_optimizer(optimizer_type: &str, config: HashMap<String, f64>) -> SciRS2Optimizer {
        let mut optimizer = SciRS2Optimizer::new(optimizer_type);
        for (key, value) in config {
            optimizer = optimizer.with_config(key, value);
        }
        optimizer
    }

    /// Setup distributed training
    pub fn setup_distributed(world_size: usize, rank: usize) -> SciRS2DistributedTrainer {
        SciRS2DistributedTrainer::new(world_size, rank)
    }
}
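
// Illustrative usage sketch: building an optimizer and a distributed trainer via
// the `integration` helpers. The helper name `_example_integration_usage` is hypothetical.
#[allow(dead_code)]
fn _example_integration_usage() {
    let mut config = HashMap::new();
    config.insert("learning_rate".to_string(), 0.001);
    config.insert("beta1".to_string(), 0.9);

    // Equivalent to chaining `with_config` calls on `SciRS2Optimizer::new("adam")`.
    let _optimizer = integration::create_optimizer("adam", config);

    // Single-process setup: world size 1, rank 0.
    let _trainer = integration::setup_distributed(1, 0);
}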

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_scirs2_array_creation() {
        let arr = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
            .expect("valid shape for 2x2 array");
        let scirs2_arr = SciRS2Array::from_array(arr);

        assert_eq!(scirs2_arr.data.shape(), &[2, 2]);
        assert!(!scirs2_arr.requires_grad);
    }

    #[test]
    fn test_scirs2_array_with_grad() {
        let arr = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
            .expect("valid shape for 2x2 array");
        let scirs2_arr = SciRS2Array::with_grad(arr);

        assert!(scirs2_arr.requires_grad);
        assert!(scirs2_arr.grad.is_some());
    }

    #[test]
    fn test_scirs2_matmul() {
        let arr1 = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid shape for 2x3 array");
        let arr2 = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid shape for 3x2 array");

        let scirs2_arr1 = SciRS2Array::from_array(arr1);
        let scirs2_arr2 = SciRS2Array::from_array(arr2);

        let result = scirs2_arr1
            .matmul(&scirs2_arr2)
            .expect("matmul should succeed for compatible shapes");
        assert_eq!(result.data.shape(), &[2, 2]);
    }

    #[test]
    fn test_scirs2_optimizer() {
        let mut optimizer = SciRS2Optimizer::new("adam")
            .with_config("learning_rate", 0.001)
            .with_config("beta1", 0.9);

        let mut params = HashMap::new();
        let param_arr = SciRS2Array::with_grad(Array1::from_vec(vec![1.0, 2.0, 3.0]));
        params.insert("weight".to_string(), param_arr);

        let result = optimizer.step(&mut params);
        assert!(result.is_ok());
    }

    #[test]
    fn test_integration_helpers() {
        let arr = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
            .expect("valid shape for 2x2 array");
        let scirs2_arr = integration::from_ndarray(arr.clone());

        let back_to_ndarray: Array2<f64> = integration::to_ndarray(&scirs2_arr)
            .expect("conversion back to ndarray should succeed");
        assert_eq!(arr, back_to_ndarray);
    }
}