quantrs2_ml/
scirs2_integration.rs

1//! SciRS2 integration layer for quantum machine learning
2//!
3//! This module provides integration with the SciRS2 scientific computing framework,
4//! enabling quantum ML models to leverage SciRS2's optimized tensor operations,
5//! distributed training capabilities, and serialization formats.
6
7use crate::error::{MLError, Result};
8use scirs2_core::ndarray::{Array, Array1, Array2, Array3, ArrayD, ArrayViewD, Dimension, IxDyn};
9use std::collections::HashMap;
10
11/// Trait for tensor operations compatible with SciRS2
12pub trait SciRS2Tensor {
13    /// Get tensor shape
14    fn shape(&self) -> &[usize];
15
16    /// Get tensor data as ArrayViewD
17    fn view(&self) -> ArrayViewD<f64>;
18
19    /// Convert to SciRS2 format (placeholder)
20    fn to_scirs2(&self) -> Result<SciRS2Array>;
21
22    /// Perform tensor operations using SciRS2 backend
23    fn matmul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;
24
25    /// Element-wise operations
26    fn add(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;
27    fn mul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;
28    fn sub(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;
29
30    /// Reduction operations
31    fn sum(&self, axis: Option<usize>) -> Result<SciRS2Array>;
32    fn mean(&self, axis: Option<usize>) -> Result<SciRS2Array>;
33    fn max(&self, axis: Option<usize>) -> Result<SciRS2Array>;
34    fn min(&self, axis: Option<usize>) -> Result<SciRS2Array>;
35}
36
37/// SciRS2 array wrapper for quantum ML operations
38pub struct SciRS2Array {
39    /// Array data
40    pub data: ArrayD<f64>,
41    /// Whether gradients are required
42    pub requires_grad: bool,
43    /// Gradient accumulator
44    pub grad: Option<ArrayD<f64>>,
45    /// Operation history for backpropagation
46    pub grad_fn: Option<Box<dyn GradFunction>>,
47}
48
49impl std::fmt::Debug for SciRS2Array {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.debug_struct("SciRS2Array")
52            .field("data", &self.data)
53            .field("requires_grad", &self.requires_grad)
54            .field("grad", &self.grad)
55            .field("grad_fn", &"<gradient_function>")
56            .finish()
57    }
58}
59
60impl Clone for SciRS2Array {
61    fn clone(&self) -> Self {
62        Self {
63            data: self.data.clone(),
64            requires_grad: self.requires_grad,
65            grad: self.grad.clone(),
66            grad_fn: None, // Cannot clone trait objects
67        }
68    }
69}
70
71impl SciRS2Array {
72    /// Create a new SciRS2Array
73    pub fn new(data: ArrayD<f64>, requires_grad: bool) -> Self {
74        let grad = if requires_grad {
75            Some(ArrayD::zeros(data.raw_dim()))
76        } else {
77            None
78        };
79        Self {
80            data,
81            requires_grad,
82            grad,
83            grad_fn: None,
84        }
85    }
86
87    /// Create from ndarray
88    pub fn from_array<D: Dimension>(arr: Array<f64, D>) -> Self {
89        let data = arr.into_dyn();
90        Self::new(data, false)
91    }
92
93    /// Create with gradient tracking
94    pub fn with_grad<D: Dimension>(arr: Array<f64, D>) -> Self {
95        let data = arr.into_dyn();
96        Self::new(data, true)
97    }
98
99    /// Zero gradients
100    pub fn zero_grad(&mut self) {
101        if let Some(ref mut grad) = self.grad {
102            grad.fill(0.0);
103        }
104    }
105
106    /// Backward pass
107    pub fn backward(&mut self) -> Result<()> {
108        // Extract grad_fn to avoid borrow conflicts
109        if let Some(grad_fn) = self.grad_fn.take() {
110            grad_fn.backward(self)?;
111            self.grad_fn = Some(grad_fn);
112        }
113        Ok(())
114    }
115
116    /// Matrix multiplication using SciRS2 backend
117    pub fn matmul(&self, other: &SciRS2Array) -> Result<SciRS2Array> {
118        // Placeholder - would use SciRS2 linalg operations
119        let result_data = if self.data.ndim() == 2 && other.data.ndim() == 2 {
120            let self_2d = self
121                .data
122                .view()
123                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
124                .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
125            let other_2d = other
126                .data
127                .view()
128                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
129                .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
130            self_2d.dot(&other_2d).into_dyn()
131        } else {
132            return Err(MLError::InvalidConfiguration(
133                "Matrix multiplication requires 2D arrays".to_string(),
134            ));
135        };
136
137        let requires_grad = self.requires_grad || other.requires_grad;
138        let mut result = SciRS2Array::new(result_data, requires_grad);
139
140        if requires_grad {
141            result.grad_fn = Some(Box::new(MatmulGradFn {
142                left_shape: self.data.raw_dim(),
143                right_shape: other.data.raw_dim(),
144            }));
145        }
146
147        Ok(result)
148    }
149
150    /// Element-wise addition
151    pub fn add(&self, other: &SciRS2Array) -> Result<SciRS2Array> {
152        let result_data = &self.data + &other.data;
153        let requires_grad = self.requires_grad || other.requires_grad;
154        let mut result = SciRS2Array::new(result_data, requires_grad);
155
156        if requires_grad {
157            result.grad_fn = Some(Box::new(AddGradFn));
158        }
159
160        Ok(result)
161    }
162
163    /// Element-wise multiplication
164    pub fn mul(&self, other: &SciRS2Array) -> Result<SciRS2Array> {
165        let result_data = &self.data * &other.data;
166        let requires_grad = self.requires_grad || other.requires_grad;
167        let mut result = SciRS2Array::new(result_data, requires_grad);
168
169        if requires_grad {
170            result.grad_fn = Some(Box::new(MulGradFn {
171                left_data: self.data.clone(),
172                right_data: other.data.clone(),
173            }));
174        }
175
176        Ok(result)
177    }
178
179    /// Reduction sum
180    pub fn sum(&self, axis: Option<usize>) -> Result<SciRS2Array> {
181        let result_data = match axis {
182            Some(ax) => self.data.sum_axis(scirs2_core::ndarray::Axis(ax)).into_dyn(),
183            None => {
184                let sum_val = self.data.sum();
185                ArrayD::from_elem(IxDyn(&[]), sum_val)
186            }
187        };
188
189        let mut result = SciRS2Array::new(result_data, self.requires_grad);
190
191        if self.requires_grad {
192            result.grad_fn = Some(Box::new(SumGradFn { axis }));
193        }
194
195        Ok(result)
196    }
197}
198
199impl SciRS2Tensor for SciRS2Array {
200    fn shape(&self) -> &[usize] {
201        self.data.shape()
202    }
203
204    fn view(&self) -> ArrayViewD<f64> {
205        self.data.view()
206    }
207
208    fn to_scirs2(&self) -> Result<SciRS2Array> {
209        Ok(self.clone())
210    }
211
212    fn matmul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
213        // Convert other to SciRS2Array for computation
214        let other_array = other.to_scirs2()?;
215        self.matmul(&other_array)
216    }
217
218    fn add(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
219        let other_array = other.to_scirs2()?;
220        self.add(&other_array)
221    }
222
223    fn mul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
224        let other_array = other.to_scirs2()?;
225        self.mul(&other_array)
226    }
227
228    fn sub(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
229        let result_data = &self.data - &other.to_scirs2()?.data;
230        let requires_grad = self.requires_grad || other.to_scirs2()?.requires_grad;
231        Ok(SciRS2Array::new(result_data, requires_grad))
232    }
233
234    fn sum(&self, axis: Option<usize>) -> Result<SciRS2Array> {
235        self.sum(axis)
236    }
237
238    fn mean(&self, axis: Option<usize>) -> Result<SciRS2Array> {
239        let result_data = match axis {
240            Some(ax) => self.data.mean_axis(scirs2_core::ndarray::Axis(ax)).unwrap().into_dyn(),
241            None => {
242                let mean_val = self.data.mean().unwrap();
243                ArrayD::from_elem(IxDyn(&[]), mean_val)
244            }
245        };
246        Ok(SciRS2Array::new(result_data, self.requires_grad))
247    }
248
249    fn max(&self, axis: Option<usize>) -> Result<SciRS2Array> {
250        let result_data = match axis {
251            Some(ax) => self
252                .data
253                .map_axis(scirs2_core::ndarray::Axis(ax), |view| {
254                    *view
255                        .iter()
256                        .max_by(|a, b| a.partial_cmp(b).unwrap())
257                        .unwrap()
258                })
259                .into_dyn(),
260            None => {
261                let max_val = *self
262                    .data
263                    .iter()
264                    .max_by(|a, b| a.partial_cmp(b).unwrap())
265                    .unwrap();
266                ArrayD::from_elem(IxDyn(&[]), max_val)
267            }
268        };
269        Ok(SciRS2Array::new(result_data, self.requires_grad))
270    }
271
272    fn min(&self, axis: Option<usize>) -> Result<SciRS2Array> {
273        let result_data = match axis {
274            Some(ax) => self
275                .data
276                .map_axis(scirs2_core::ndarray::Axis(ax), |view| {
277                    *view
278                        .iter()
279                        .min_by(|a, b| a.partial_cmp(b).unwrap())
280                        .unwrap()
281                })
282                .into_dyn(),
283            None => {
284                let min_val = *self
285                    .data
286                    .iter()
287                    .min_by(|a, b| a.partial_cmp(b).unwrap())
288                    .unwrap();
289                ArrayD::from_elem(IxDyn(&[]), min_val)
290            }
291        };
292        Ok(SciRS2Array::new(result_data, self.requires_grad))
293    }
294}
295
296/// Trait for gradient functions
297pub trait GradFunction: Send + Sync {
298    fn backward(&self, output: &mut SciRS2Array) -> Result<()>;
299}
300
301/// Gradient function for matrix multiplication
302#[derive(Debug)]
303struct MatmulGradFn {
304    left_shape: IxDyn,
305    right_shape: IxDyn,
306}
307
308impl GradFunction for MatmulGradFn {
309    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
310        // Placeholder - would compute gradients for matmul inputs
311        Ok(())
312    }
313}
314
315/// Gradient function for addition
316#[derive(Debug)]
317struct AddGradFn;
318
319impl GradFunction for AddGradFn {
320    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
321        // Gradient flows through unchanged for addition
322        Ok(())
323    }
324}
325
326/// Gradient function for multiplication
327#[derive(Debug)]
328struct MulGradFn {
329    left_data: ArrayD<f64>,
330    right_data: ArrayD<f64>,
331}
332
333impl GradFunction for MulGradFn {
334    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
335        // Placeholder - would compute gradients for element-wise multiplication
336        Ok(())
337    }
338}
339
340/// Gradient function for sum reduction
341#[derive(Debug)]
342struct SumGradFn {
343    axis: Option<usize>,
344}
345
346impl GradFunction for SumGradFn {
347    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
348        // Placeholder - would broadcast gradients for sum reduction
349        Ok(())
350    }
351}
352
353/// SciRS2 optimization interface
354pub struct SciRS2Optimizer {
355    /// Optimizer type
356    pub optimizer_type: String,
357    /// Configuration parameters
358    pub config: HashMap<String, f64>,
359    /// Parameter state (for stateful optimizers like Adam)
360    pub state: HashMap<String, ArrayD<f64>>,
361}
362
363impl SciRS2Optimizer {
364    /// Create a new SciRS2 optimizer
365    pub fn new(optimizer_type: impl Into<String>) -> Self {
366        Self {
367            optimizer_type: optimizer_type.into(),
368            config: HashMap::new(),
369            state: HashMap::new(),
370        }
371    }
372
373    /// Set optimizer configuration
374    pub fn with_config(mut self, key: impl Into<String>, value: f64) -> Self {
375        self.config.insert(key.into(), value);
376        self
377    }
378
379    /// Update parameters using computed gradients
380    pub fn step(&mut self, params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
381        match self.optimizer_type.as_str() {
382            "adam" => self.adam_step(params),
383            "sgd" => self.sgd_step(params),
384            "lbfgs" => self.lbfgs_step(params),
385            _ => Err(MLError::InvalidConfiguration(format!(
386                "Unknown optimizer type: {}",
387                self.optimizer_type
388            ))),
389        }
390    }
391
392    /// Adam optimizer step
393    fn adam_step(&mut self, params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
394        let learning_rate = self.config.get("learning_rate").unwrap_or(&0.001);
395        let beta1 = self.config.get("beta1").unwrap_or(&0.9);
396        let beta2 = self.config.get("beta2").unwrap_or(&0.999);
397        let epsilon = self.config.get("epsilon").unwrap_or(&1e-8);
398
399        for (name, param) in params.iter_mut() {
400            if let Some(ref grad) = param.grad {
401                // Initialize momentum and velocity if not present
402                let m_key = format!("{}_m", name);
403                let v_key = format!("{}_v", name);
404
405                if !self.state.contains_key(&m_key) {
406                    self.state
407                        .insert(m_key.clone(), ArrayD::zeros(grad.raw_dim()));
408                    self.state
409                        .insert(v_key.clone(), ArrayD::zeros(grad.raw_dim()));
410                }
411
412                // Update first moment estimate
413                {
414                    let m = self.state.get_mut(&m_key).unwrap();
415                    *m = *beta1 * &*m + (1.0 - *beta1) * grad;
416                }
417
418                // Update second moment estimate
419                {
420                    let v = self.state.get_mut(&v_key).unwrap();
421                    *v = *beta2 * &*v + (1.0 - *beta2) * grad * grad;
422                }
423
424                // Get references for bias correction
425                let m_hat = self.state.get(&m_key).unwrap().clone();
426                let v_hat = self.state.get(&v_key).unwrap().clone();
427
428                // Update parameters
429                param.data =
430                    &param.data - *learning_rate * &m_hat / (v_hat.mapv(|x| x.sqrt()) + *epsilon);
431            }
432        }
433
434        Ok(())
435    }
436
437    /// SGD optimizer step
438    fn sgd_step(&mut self, params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
439        let learning_rate = self.config.get("learning_rate").unwrap_or(&0.01);
440        let momentum = self.config.get("momentum").unwrap_or(&0.0);
441
442        for (name, param) in params.iter_mut() {
443            if let Some(ref grad) = param.grad {
444                if *momentum > 0.0 {
445                    let v_key = format!("{}_v", name);
446                    if !self.state.contains_key(&v_key) {
447                        self.state
448                            .insert(v_key.clone(), ArrayD::zeros(grad.raw_dim()));
449                    }
450
451                    let v = self.state.get_mut(&v_key).unwrap();
452                    *v = *momentum * &*v + *learning_rate * grad;
453                    param.data = &param.data - &*v;
454                } else {
455                    param.data = &param.data - *learning_rate * grad;
456                }
457            }
458        }
459
460        Ok(())
461    }
462
463    /// L-BFGS optimizer step (placeholder)
464    fn lbfgs_step(&mut self, _params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
465        // Placeholder - would implement L-BFGS using SciRS2
466        Ok(())
467    }
468}
469
470/// SciRS2 distributed training support
471pub struct SciRS2DistributedTrainer {
472    /// World size (number of processes)
473    pub world_size: usize,
474    /// Local rank
475    pub rank: usize,
476    /// Backend for communication
477    pub backend: String,
478}
479
480impl SciRS2DistributedTrainer {
481    /// Create a new distributed trainer
482    pub fn new(world_size: usize, rank: usize) -> Self {
483        Self {
484            world_size,
485            rank,
486            backend: "nccl".to_string(),
487        }
488    }
489
490    /// All-reduce operation for gradient synchronization
491    pub fn all_reduce(&self, tensor: &mut SciRS2Array) -> Result<()> {
492        // Placeholder - would use SciRS2 distributed operations
493        Ok(())
494    }
495
496    /// All-reduce scalar operation for metrics synchronization
497    pub fn all_reduce_scalar(&self, value: f64) -> Result<f64> {
498        // Placeholder - would use SciRS2 distributed operations
499        // For now, just return the value unchanged (single process behavior)
500        Ok(value)
501    }
502
503    /// Broadcast operation
504    pub fn broadcast(&self, tensor: &mut SciRS2Array, root: usize) -> Result<()> {
505        // Placeholder - would use SciRS2 distributed operations
506        Ok(())
507    }
508
509    /// All-gather operation
510    pub fn all_gather(&self, tensor: &SciRS2Array) -> Result<Vec<SciRS2Array>> {
511        // Placeholder - would use SciRS2 distributed operations
512        Ok(vec![tensor.clone(); self.world_size])
513    }
514
515    /// Wrap a model for distributed training
516    pub fn wrap_model<T>(&self, model: T) -> Result<T> {
517        // Placeholder - would wrap the model with distributed training capabilities
518        // For now, just return the model unchanged
519        Ok(model)
520    }
521}
522
523/// SciRS2 model serialization interface
524pub struct SciRS2Serializer;
525
526impl SciRS2Serializer {
527    /// Serialize model parameters to SciRS2 format
528    pub fn save_model(params: &HashMap<String, SciRS2Array>, path: &str) -> Result<()> {
529        // Placeholder - would use SciRS2 serialization
530        Ok(())
531    }
532
533    /// Load model parameters from SciRS2 format
534    pub fn load_model(path: &str) -> Result<HashMap<String, SciRS2Array>> {
535        // Placeholder - would use SciRS2 deserialization
536        Ok(HashMap::new())
537    }
538
539    /// Save checkpoint with optimizer state
540    pub fn save_checkpoint(
541        params: &HashMap<String, SciRS2Array>,
542        optimizer: &SciRS2Optimizer,
543        epoch: usize,
544        path: &str,
545    ) -> Result<()> {
546        // Placeholder - would use SciRS2 checkpoint format
547        Ok(())
548    }
549
550    /// Load checkpoint with optimizer state
551    pub fn load_checkpoint(
552        path: &str,
553    ) -> Result<(HashMap<String, SciRS2Array>, SciRS2Optimizer, usize)> {
554        // Placeholder - would use SciRS2 checkpoint format
555        Ok((HashMap::new(), SciRS2Optimizer::new("adam"), 0))
556    }
557}
558
559/// SciRS2 Dataset wrapper for quantum ML
560pub struct SciRS2Dataset {
561    /// Training data
562    pub data: ArrayD<f64>,
563    /// Labels
564    pub labels: ArrayD<f64>,
565    /// Dataset size
566    pub size: usize,
567}
568
569impl SciRS2Dataset {
570    /// Create a new dataset
571    pub fn new(data: ArrayD<f64>, labels: ArrayD<f64>) -> Result<Self> {
572        let size = data.shape()[0];
573        if labels.shape()[0] != size {
574            return Err(MLError::InvalidConfiguration(
575                "Data and labels must have same number of samples".to_string(),
576            ));
577        }
578
579        Ok(Self { data, labels, size })
580    }
581}
582
583/// SciRS2 DataLoader for batch processing
584pub struct SciRS2DataLoader {
585    /// Dataset reference
586    pub dataset: SciRS2Dataset,
587    /// Batch size
588    pub batch_size: usize,
589    /// Current index
590    pub current_index: usize,
591}
592
593impl SciRS2DataLoader {
594    /// Create a new data loader
595    pub fn new(dataset: SciRS2Dataset, batch_size: usize) -> Self {
596        Self {
597            dataset,
598            batch_size,
599            current_index: 0,
600        }
601    }
602
603    /// Iterator-like enumeration support
604    pub fn enumerate(&mut self) -> DataLoaderIterator {
605        DataLoaderIterator {
606            loader: self,
607            batch_idx: 0,
608        }
609    }
610}
611
612/// Iterator for DataLoader
613pub struct DataLoaderIterator<'a> {
614    loader: &'a mut SciRS2DataLoader,
615    batch_idx: usize,
616}
617
618impl<'a> Iterator for DataLoaderIterator<'a> {
619    type Item = (usize, (SciRS2Array, SciRS2Array));
620
621    fn next(&mut self) -> Option<Self::Item> {
622        if self.loader.current_index >= self.loader.dataset.size {
623            return None;
624        }
625
626        let start = self.loader.current_index;
627        let end = (start + self.loader.batch_size).min(self.loader.dataset.size);
628
629        // Extract batch data and labels
630        let batch_data = self
631            .loader
632            .dataset
633            .data
634            .slice(scirs2_core::ndarray::s![start..end, ..])
635            .to_owned();
636        let batch_labels = self
637            .loader
638            .dataset
639            .labels
640            .slice(scirs2_core::ndarray::s![start..end, ..])
641            .to_owned();
642
643        let data_array = SciRS2Array::from_array(batch_data);
644        let label_array = SciRS2Array::from_array(batch_labels);
645
646        self.loader.current_index = end;
647        let batch_idx = self.batch_idx;
648        self.batch_idx += 1;
649
650        Some((batch_idx, (data_array, label_array)))
651    }
652}
653
/// Device on which a tensor is requested to be created.
///
/// The tensor constructors in this module accept a device argument but do not
/// currently act on it.
#[derive(Debug, Clone, Copy)]
pub enum SciRS2Device {
    /// Host CPU.
    CPU,
    /// GPU accelerator.
    GPU,
    /// Quantum backend.
    Quantum,
}
661
662/// Additional SciRS2Array methods for compatibility
663impl SciRS2Array {
664    /// Create array with specified device
665    pub fn randn(shape: Vec<usize>, device: SciRS2Device) -> Result<Self> {
666        use scirs2_core::random::prelude::*;
667        let total_size = shape.iter().product();
668        let mut rng = thread_rng();
669        let data: Vec<f64> = (0..total_size).map(|_| rng.gen_range(-1.0..1.0)).collect();
670        let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
671            .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
672        Ok(Self::new(array, false))
673    }
674
675    /// Create ones_like array
676    pub fn ones_like(&self) -> Result<Self> {
677        let ones = ArrayD::ones(self.data.raw_dim());
678        Ok(Self::new(ones, false))
679    }
680
681    /// Create random integers
682    pub fn randint(low: i32, high: i32, shape: Vec<usize>, device: SciRS2Device) -> Result<Self> {
683        use scirs2_core::random::prelude::*;
684        let total_size = shape.iter().product();
685        let mut rng = thread_rng();
686        let data: Vec<f64> = (0..total_size)
687            .map(|_| rng.gen_range(low..high) as f64)
688            .collect();
689        let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
690            .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
691        Ok(Self::new(array, false))
692    }
693
694    /// Create quantum observable
695    pub fn quantum_observable(name: &str, num_qubits: usize) -> Result<Self> {
696        match name {
697            "pauli_z_all" => {
698                let size = 1 << num_qubits;
699                let mut data = ArrayD::zeros(IxDyn(&[size, size]));
700                for i in 0..size {
701                    let parity = i.count_ones() % 2;
702                    data[[i, i]] = if parity == 0 { 1.0 } else { -1.0 };
703                }
704                Ok(Self::new(data, false))
705            }
706            _ => Err(MLError::InvalidConfiguration(format!(
707                "Unknown observable: {}",
708                name
709            ))),
710        }
711    }
712}
713
714/// Integration helper functions
715pub mod integration {
716    use super::*;
717
718    /// Convert ndarray to SciRS2Array
719    pub fn from_ndarray<D: Dimension>(arr: Array<f64, D>) -> SciRS2Array {
720        SciRS2Array::from_array(arr)
721    }
722
723    /// Convert SciRS2Array to ndarray
724    pub fn to_ndarray<D: Dimension>(arr: &SciRS2Array) -> Result<Array<f64, D>> {
725        arr.data
726            .view()
727            .into_dimensionality::<D>()
728            .map(|v| v.to_owned())
729            .map_err(|e| MLError::ComputationError(format!("Dimension error: {}", e)))
730    }
731
732    /// Create SciRS2 optimizer from configuration
733    pub fn create_optimizer(optimizer_type: &str, config: HashMap<String, f64>) -> SciRS2Optimizer {
734        let mut optimizer = SciRS2Optimizer::new(optimizer_type);
735        for (key, value) in config {
736            optimizer = optimizer.with_config(key, value);
737        }
738        optimizer
739    }
740
741    /// Setup distributed training
742    pub fn setup_distributed(world_size: usize, rank: usize) -> SciRS2DistributedTrainer {
743        SciRS2DistributedTrainer::new(world_size, rank)
744    }
745}
746
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// The 2x2 matrix [[1, 2], [3, 4]] used by several tests.
    fn sample_2x2() -> Array2<f64> {
        Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
            .expect("four elements fill a 2x2 shape")
    }

    #[test]
    fn test_scirs2_array_creation() {
        let wrapped = SciRS2Array::from_array(sample_2x2());

        assert_eq!(wrapped.data.shape(), &[2, 2]);
        assert!(!wrapped.requires_grad);
    }

    #[test]
    fn test_scirs2_array_with_grad() {
        let tracked = SciRS2Array::with_grad(sample_2x2());

        assert!(tracked.requires_grad);
        assert!(tracked.grad.is_some());
    }

    #[test]
    fn test_scirs2_matmul() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let lhs = SciRS2Array::from_array(Array2::from_shape_vec((2, 3), values.clone()).unwrap());
        let rhs = SciRS2Array::from_array(Array2::from_shape_vec((3, 2), values).unwrap());

        let product = lhs.matmul(&rhs).unwrap();
        assert_eq!(product.data.shape(), &[2, 2]);
    }

    #[test]
    fn test_scirs2_optimizer() {
        let mut optimizer = SciRS2Optimizer::new("adam")
            .with_config("learning_rate", 0.001)
            .with_config("beta1", 0.9);

        let mut params = HashMap::new();
        params.insert(
            "weight".to_string(),
            SciRS2Array::with_grad(Array1::from_vec(vec![1.0, 2.0, 3.0])),
        );

        assert!(optimizer.step(&mut params).is_ok());
    }

    #[test]
    fn test_integration_helpers() {
        let original = sample_2x2();
        let wrapped = integration::from_ndarray(original.clone());

        let round_tripped: Array2<f64> = integration::to_ndarray(&wrapped).unwrap();
        assert_eq!(original, round_tripped);
    }
}