use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array, Array1, Array2, Array3, ArrayD, ArrayViewD, Dimension, IxDyn};
use std::collections::HashMap;

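/// Common tensor interface used to bridge SciRS2-backed arrays with the ML code in this crate.
///
/// Implementors expose their data as a dynamic-dimensional `f64` view and provide the basic
/// linear-algebra and reduction operations used by the training utilities in this module.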
pub trait SciRS2Tensor {
    fn shape(&self) -> &[usize];

    fn view(&self) -> ArrayViewD<f64>;

    fn to_scirs2(&self) -> Result<SciRS2Array>;

    fn matmul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;

    fn add(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;
    fn mul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;
    fn sub(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array>;

    fn sum(&self, axis: Option<usize>) -> Result<SciRS2Array>;
    fn mean(&self, axis: Option<usize>) -> Result<SciRS2Array>;
    fn max(&self, axis: Option<usize>) -> Result<SciRS2Array>;
    fn min(&self, axis: Option<usize>) -> Result<SciRS2Array>;
}

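/// Dynamic-dimensional `f64` array with optional gradient tracking.
///
/// When `requires_grad` is set, a zero-initialized gradient buffer is allocated and operations
/// record a `grad_fn` so a (currently simplified) backward pass can be invoked later.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let a = SciRS2Array::with_grad(Array2::<f64>::ones((2, 3)));
/// let b = SciRS2Array::from_array(Array2::<f64>::ones((3, 2)));
/// let c = a.matmul(&b).unwrap(); // 2x2 result with requires_grad = true
/// let total = c.sum(None).unwrap(); // scalar (0-dimensional) sum
/// ```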
pub struct SciRS2Array {
    pub data: ArrayD<f64>,
    pub requires_grad: bool,
    pub grad: Option<ArrayD<f64>>,
    pub grad_fn: Option<Box<dyn GradFunction>>,
}

impl std::fmt::Debug for SciRS2Array {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SciRS2Array")
            .field("data", &self.data)
            .field("requires_grad", &self.requires_grad)
            .field("grad", &self.grad)
            .field("grad_fn", &"<gradient_function>")
            .finish()
    }
}

impl Clone for SciRS2Array {
    fn clone(&self) -> Self {
        Self {
            data: self.data.clone(),
            requires_grad: self.requires_grad,
            grad: self.grad.clone(),
            // Boxed gradient functions cannot be cloned, so a clone starts without one.
            grad_fn: None,
        }
    }
}

impl SciRS2Array {
    pub fn new(data: ArrayD<f64>, requires_grad: bool) -> Self {
        let grad = if requires_grad {
            Some(ArrayD::zeros(data.raw_dim()))
        } else {
            None
        };
        Self {
            data,
            requires_grad,
            grad,
            grad_fn: None,
        }
    }

    pub fn from_array<D: Dimension>(arr: Array<f64, D>) -> Self {
        let data = arr.into_dyn();
        Self::new(data, false)
    }

    pub fn with_grad<D: Dimension>(arr: Array<f64, D>) -> Self {
        let data = arr.into_dyn();
        Self::new(data, true)
    }

    pub fn zero_grad(&mut self) {
        if let Some(ref mut grad) = self.grad {
            grad.fill(0.0);
        }
    }

    pub fn backward(&mut self) -> Result<()> {
        // Temporarily take the recorded grad_fn, run it, then restore it so that
        // `backward` can be called again on the same array.
        if let Some(grad_fn) = self.grad_fn.take() {
            grad_fn.backward(self)?;
            self.grad_fn = Some(grad_fn);
        }
        Ok(())
    }

    pub fn matmul(&self, other: &SciRS2Array) -> Result<SciRS2Array> {
        let result_data = if self.data.ndim() == 2 && other.data.ndim() == 2 {
            let self_2d = self
                .data
                .view()
                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
            let other_2d = other
                .data
                .view()
                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
            self_2d.dot(&other_2d).into_dyn()
        } else {
            return Err(MLError::InvalidConfiguration(
                "Matrix multiplication requires 2D arrays".to_string(),
            ));
        };

        let requires_grad = self.requires_grad || other.requires_grad;
        let mut result = SciRS2Array::new(result_data, requires_grad);

        if requires_grad {
            result.grad_fn = Some(Box::new(MatmulGradFn {
                left_shape: self.data.raw_dim(),
                right_shape: other.data.raw_dim(),
            }));
        }

        Ok(result)
    }

    pub fn add(&self, other: &SciRS2Array) -> Result<SciRS2Array> {
        let result_data = &self.data + &other.data;
        let requires_grad = self.requires_grad || other.requires_grad;
        let mut result = SciRS2Array::new(result_data, requires_grad);

        if requires_grad {
            result.grad_fn = Some(Box::new(AddGradFn));
        }

        Ok(result)
    }

    pub fn mul(&self, other: &SciRS2Array) -> Result<SciRS2Array> {
        let result_data = &self.data * &other.data;
        let requires_grad = self.requires_grad || other.requires_grad;
        let mut result = SciRS2Array::new(result_data, requires_grad);

        if requires_grad {
            result.grad_fn = Some(Box::new(MulGradFn {
                left_data: self.data.clone(),
                right_data: other.data.clone(),
            }));
        }

        Ok(result)
    }

    pub fn sum(&self, axis: Option<usize>) -> Result<SciRS2Array> {
        let result_data = match axis {
            Some(ax) => self.data.sum_axis(scirs2_core::ndarray::Axis(ax)).into_dyn(),
            None => {
                let sum_val = self.data.sum();
                ArrayD::from_elem(IxDyn(&[]), sum_val)
            }
        };

        let mut result = SciRS2Array::new(result_data, self.requires_grad);

        if self.requires_grad {
            result.grad_fn = Some(Box::new(SumGradFn { axis }));
        }

        Ok(result)
    }
}

impl SciRS2Tensor for SciRS2Array {
    fn shape(&self) -> &[usize] {
        self.data.shape()
    }

    fn view(&self) -> ArrayViewD<f64> {
        self.data.view()
    }

    fn to_scirs2(&self) -> Result<SciRS2Array> {
        Ok(self.clone())
    }

    fn matmul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
        let other_array = other.to_scirs2()?;
        self.matmul(&other_array)
    }

    fn add(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
        let other_array = other.to_scirs2()?;
        self.add(&other_array)
    }

    fn mul(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
        let other_array = other.to_scirs2()?;
        self.mul(&other_array)
    }

    fn sub(&self, other: &dyn SciRS2Tensor) -> Result<SciRS2Array> {
        // Convert once and reuse for both the data and the requires_grad flag.
        let other_array = other.to_scirs2()?;
        let result_data = &self.data - &other_array.data;
        let requires_grad = self.requires_grad || other_array.requires_grad;
        Ok(SciRS2Array::new(result_data, requires_grad))
    }

    fn sum(&self, axis: Option<usize>) -> Result<SciRS2Array> {
        self.sum(axis)
    }

    fn mean(&self, axis: Option<usize>) -> Result<SciRS2Array> {
        let result_data = match axis {
            Some(ax) => self
                .data
                .mean_axis(scirs2_core::ndarray::Axis(ax))
                .ok_or_else(|| {
                    MLError::ComputationError("Cannot compute mean over an empty axis".to_string())
                })?
                .into_dyn(),
            None => {
                let mean_val = self.data.mean().ok_or_else(|| {
                    MLError::ComputationError("Cannot compute mean of an empty array".to_string())
                })?;
                ArrayD::from_elem(IxDyn(&[]), mean_val)
            }
        };
        Ok(SciRS2Array::new(result_data, self.requires_grad))
    }

    fn max(&self, axis: Option<usize>) -> Result<SciRS2Array> {
        // Fold with `f64::max` instead of `partial_cmp(...).unwrap()` so NaN values cannot
        // cause a panic; an empty input yields negative infinity.
        let result_data = match axis {
            Some(ax) => self
                .data
                .map_axis(scirs2_core::ndarray::Axis(ax), |view| {
                    view.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
                })
                .into_dyn(),
            None => {
                let max_val = self.data.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                ArrayD::from_elem(IxDyn(&[]), max_val)
            }
        };
        Ok(SciRS2Array::new(result_data, self.requires_grad))
    }

    fn min(&self, axis: Option<usize>) -> Result<SciRS2Array> {
        // Same NaN-safe folding strategy as `max`; an empty input yields positive infinity.
        let result_data = match axis {
            Some(ax) => self
                .data
                .map_axis(scirs2_core::ndarray::Axis(ax), |view| {
                    view.fold(f64::INFINITY, |acc, &x| acc.min(x))
                })
                .into_dyn(),
            None => {
                let min_val = self.data.fold(f64::INFINITY, |acc, &x| acc.min(x));
                ArrayD::from_elem(IxDyn(&[]), min_val)
            }
        };
        Ok(SciRS2Array::new(result_data, self.requires_grad))
    }
}

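/// Backward function recorded by operations on `SciRS2Array`.
///
/// The implementations below are placeholders: they preserve the structure of an autograd
/// graph but do not yet propagate gradients to the operands.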
pub trait GradFunction: Send + Sync {
    fn backward(&self, output: &mut SciRS2Array) -> Result<()>;
}

#[derive(Debug)]
struct MatmulGradFn {
    left_shape: IxDyn,
    right_shape: IxDyn,
}

impl GradFunction for MatmulGradFn {
    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
        Ok(())
    }
}

#[derive(Debug)]
struct AddGradFn;

impl GradFunction for AddGradFn {
    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
        Ok(())
    }
}

#[derive(Debug)]
struct MulGradFn {
    left_data: ArrayD<f64>,
    right_data: ArrayD<f64>,
}

impl GradFunction for MulGradFn {
    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
        Ok(())
    }
}

#[derive(Debug)]
struct SumGradFn {
    axis: Option<usize>,
}

impl GradFunction for SumGradFn {
    fn backward(&self, _output: &mut SciRS2Array) -> Result<()> {
        Ok(())
    }
}

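/// Gradient-based optimizer selected by a string identifier (`"adam"`, `"sgd"`, or `"lbfgs"`)
/// and configured through a flat map of hyperparameters.
///
/// Per-parameter state (Adam moment estimates, SGD momentum buffers) is kept in `state`,
/// keyed by parameter name.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// let mut params = HashMap::new();
/// params.insert(
///     "weight".to_string(),
///     SciRS2Array::with_grad(Array1::from_vec(vec![0.5, -0.5])),
/// );
/// let mut opt = SciRS2Optimizer::new("sgd").with_config("learning_rate", 0.01);
/// opt.step(&mut params).unwrap();
/// ```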
pub struct SciRS2Optimizer {
    pub optimizer_type: String,
    pub config: HashMap<String, f64>,
    pub state: HashMap<String, ArrayD<f64>>,
}

impl SciRS2Optimizer {
    pub fn new(optimizer_type: impl Into<String>) -> Self {
        Self {
            optimizer_type: optimizer_type.into(),
            config: HashMap::new(),
            state: HashMap::new(),
        }
    }

    pub fn with_config(mut self, key: impl Into<String>, value: f64) -> Self {
        self.config.insert(key.into(), value);
        self
    }

    pub fn step(&mut self, params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
        match self.optimizer_type.as_str() {
            "adam" => self.adam_step(params),
            "sgd" => self.sgd_step(params),
            "lbfgs" => self.lbfgs_step(params),
            _ => Err(MLError::InvalidConfiguration(format!(
                "Unknown optimizer type: {}",
                self.optimizer_type
            ))),
        }
    }

    fn adam_step(&mut self, params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
        let learning_rate = self.config.get("learning_rate").unwrap_or(&0.001);
        let beta1 = self.config.get("beta1").unwrap_or(&0.9);
        let beta2 = self.config.get("beta2").unwrap_or(&0.999);
        let epsilon = self.config.get("epsilon").unwrap_or(&1e-8);

        for (name, param) in params.iter_mut() {
            if let Some(ref grad) = param.grad {
                let m_key = format!("{}_m", name);
                let v_key = format!("{}_v", name);

                if !self.state.contains_key(&m_key) {
                    self.state
                        .insert(m_key.clone(), ArrayD::zeros(grad.raw_dim()));
                    self.state
                        .insert(v_key.clone(), ArrayD::zeros(grad.raw_dim()));
                }

                // First moment estimate.
                {
                    let m = self.state.get_mut(&m_key).unwrap();
                    *m = *beta1 * &*m + (1.0 - *beta1) * grad;
                }

                // Second moment estimate.
                {
                    let v = self.state.get_mut(&v_key).unwrap();
                    *v = *beta2 * &*v + (1.0 - *beta2) * grad * grad;
                }

                // Simplified Adam update: bias correction of the moment estimates is omitted.
                let m_hat = self.state.get(&m_key).unwrap().clone();
                let v_hat = self.state.get(&v_key).unwrap().clone();

                param.data =
                    &param.data - *learning_rate * &m_hat / (v_hat.mapv(|x| x.sqrt()) + *epsilon);
            }
        }

        Ok(())
    }

    fn sgd_step(&mut self, params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
        let learning_rate = self.config.get("learning_rate").unwrap_or(&0.01);
        let momentum = self.config.get("momentum").unwrap_or(&0.0);

        for (name, param) in params.iter_mut() {
            if let Some(ref grad) = param.grad {
                if *momentum > 0.0 {
                    let v_key = format!("{}_v", name);
                    if !self.state.contains_key(&v_key) {
                        self.state
                            .insert(v_key.clone(), ArrayD::zeros(grad.raw_dim()));
                    }

                    let v = self.state.get_mut(&v_key).unwrap();
                    *v = *momentum * &*v + *learning_rate * grad;
                    param.data = &param.data - &*v;
                } else {
                    param.data = &param.data - *learning_rate * grad;
                }
            }
        }

        Ok(())
    }

    fn lbfgs_step(&mut self, _params: &mut HashMap<String, SciRS2Array>) -> Result<()> {
        // Placeholder: L-BFGS is not yet implemented, so parameters are left unchanged.
        Ok(())
    }
}

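/// Placeholder interface for data-parallel training across `world_size` processes.
///
/// The collective operations below are currently single-process no-ops (for example,
/// `all_reduce` leaves the tensor unchanged), but they mirror the shape of a typical
/// distributed training API.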
pub struct SciRS2DistributedTrainer {
    pub world_size: usize,
    pub rank: usize,
    pub backend: String,
}

impl SciRS2DistributedTrainer {
    pub fn new(world_size: usize, rank: usize) -> Self {
        Self {
            world_size,
            rank,
            backend: "nccl".to_string(),
        }
    }

    pub fn all_reduce(&self, _tensor: &mut SciRS2Array) -> Result<()> {
        // Single-process placeholder: nothing to reduce.
        Ok(())
    }

    pub fn all_reduce_scalar(&self, value: f64) -> Result<f64> {
        // Single-process placeholder: the local value is already the global value.
        Ok(value)
    }

    pub fn broadcast(&self, _tensor: &mut SciRS2Array, _root: usize) -> Result<()> {
        // Single-process placeholder: the tensor is already in its broadcast state.
        Ok(())
    }

    pub fn all_gather(&self, tensor: &SciRS2Array) -> Result<Vec<SciRS2Array>> {
        // Single-process placeholder: every rank contributes the same local tensor.
        Ok(vec![tensor.clone(); self.world_size])
    }

    pub fn wrap_model<T>(&self, model: T) -> Result<T> {
        // Single-process placeholder: the model is returned unchanged.
        Ok(model)
    }
}

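/// Model and checkpoint (de)serialization helpers.
///
/// The current implementations are stubs that succeed without touching the filesystem; they
/// define the intended interface for saving and loading parameter maps and optimizer state.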
pub struct SciRS2Serializer;

impl SciRS2Serializer {
    pub fn save_model(_params: &HashMap<String, SciRS2Array>, _path: &str) -> Result<()> {
        // Placeholder: serialization is not yet implemented.
        Ok(())
    }

    pub fn load_model(_path: &str) -> Result<HashMap<String, SciRS2Array>> {
        // Placeholder: returns an empty parameter map.
        Ok(HashMap::new())
    }

    pub fn save_checkpoint(
        _params: &HashMap<String, SciRS2Array>,
        _optimizer: &SciRS2Optimizer,
        _epoch: usize,
        _path: &str,
    ) -> Result<()> {
        // Placeholder: checkpointing is not yet implemented.
        Ok(())
    }

    pub fn load_checkpoint(
        _path: &str,
    ) -> Result<(HashMap<String, SciRS2Array>, SciRS2Optimizer, usize)> {
        // Placeholder: returns an empty parameter map, a default Adam optimizer, and epoch 0.
        Ok((HashMap::new(), SciRS2Optimizer::new("adam"), 0))
    }
}

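/// In-memory dataset pairing a data tensor with a label tensor; samples are indexed along the
/// first axis of both arrays.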
pub struct SciRS2Dataset {
    pub data: ArrayD<f64>,
    pub labels: ArrayD<f64>,
    pub size: usize,
}

impl SciRS2Dataset {
    pub fn new(data: ArrayD<f64>, labels: ArrayD<f64>) -> Result<Self> {
        let size = data.shape()[0];
        if labels.shape()[0] != size {
            return Err(MLError::InvalidConfiguration(
                "Data and labels must have the same number of samples".to_string(),
            ));
        }

        Ok(Self { data, labels, size })
    }
}

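/// Sequential mini-batch loader over a `SciRS2Dataset`.
///
/// `enumerate` yields `(batch_index, (data_batch, label_batch))` pairs until the dataset is
/// exhausted; batches are sliced along the first (sample) axis.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// let dataset = SciRS2Dataset::new(data, labels).unwrap(); // data, labels: ArrayD<f64>
/// let mut loader = SciRS2DataLoader::new(dataset, 32);
/// for (batch_idx, (x_batch, y_batch)) in loader.enumerate() {
///     // x_batch and y_batch are SciRS2Array mini-batches of up to 32 samples.
/// }
/// ```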
pub struct SciRS2DataLoader {
    pub dataset: SciRS2Dataset,
    pub batch_size: usize,
    pub current_index: usize,
}

impl SciRS2DataLoader {
    pub fn new(dataset: SciRS2Dataset, batch_size: usize) -> Self {
        Self {
            dataset,
            batch_size,
            current_index: 0,
        }
    }

    pub fn enumerate(&mut self) -> DataLoaderIterator {
        // Note: iteration resumes from `current_index`, so reset it (or build a new loader)
        // before iterating over the dataset a second time.
        DataLoaderIterator {
            loader: self,
            batch_idx: 0,
        }
    }
}

pub struct DataLoaderIterator<'a> {
    loader: &'a mut SciRS2DataLoader,
    batch_idx: usize,
}

impl<'a> Iterator for DataLoaderIterator<'a> {
    type Item = (usize, (SciRS2Array, SciRS2Array));

    fn next(&mut self) -> Option<Self::Item> {
        if self.loader.current_index >= self.loader.dataset.size {
            return None;
        }

        let start = self.loader.current_index;
        let end = (start + self.loader.batch_size).min(self.loader.dataset.size);

        // Slice along the sample (first) axis; this works for any number of trailing dimensions.
        let batch_data = self
            .loader
            .dataset
            .data
            .slice_axis(
                scirs2_core::ndarray::Axis(0),
                scirs2_core::ndarray::Slice::from(start..end),
            )
            .to_owned();
        let batch_labels = self
            .loader
            .dataset
            .labels
            .slice_axis(
                scirs2_core::ndarray::Axis(0),
                scirs2_core::ndarray::Slice::from(start..end),
            )
            .to_owned();

        let data_array = SciRS2Array::from_array(batch_data);
        let label_array = SciRS2Array::from_array(batch_labels);

        self.loader.current_index = end;
        let batch_idx = self.batch_idx;
        self.batch_idx += 1;

        Some((batch_idx, (data_array, label_array)))
    }
}

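/// Logical device placement for array constructors. Currently informational only: all arrays
/// are allocated on the host regardless of the selected device.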
#[derive(Debug, Clone, Copy)]
pub enum SciRS2Device {
    CPU,
    GPU,
    Quantum,
}

impl SciRS2Array {
    pub fn randn(shape: Vec<usize>, _device: SciRS2Device) -> Result<Self> {
        use scirs2_core::random::prelude::*;
        // Note: values are drawn uniformly from [-1, 1), not from a standard normal
        // distribution; the device argument is currently unused.
        let total_size = shape.iter().product();
        let mut rng = thread_rng();
        let data: Vec<f64> = (0..total_size).map(|_| rng.gen_range(-1.0..1.0)).collect();
        let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
            .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
        Ok(Self::new(array, false))
    }

    pub fn ones_like(&self) -> Result<Self> {
        let ones = ArrayD::ones(self.data.raw_dim());
        Ok(Self::new(ones, false))
    }

    pub fn randint(low: i32, high: i32, shape: Vec<usize>, _device: SciRS2Device) -> Result<Self> {
        use scirs2_core::random::prelude::*;
        let total_size = shape.iter().product();
        let mut rng = thread_rng();
        let data: Vec<f64> = (0..total_size)
            .map(|_| rng.gen_range(low..high) as f64)
            .collect();
        let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
            .map_err(|e| MLError::ComputationError(format!("Shape error: {}", e)))?;
        Ok(Self::new(array, false))
    }

    pub fn quantum_observable(name: &str, num_qubits: usize) -> Result<Self> {
        match name {
            "pauli_z_all" => {
                // Diagonal matrix with entry +1 for basis states of even parity and -1 for odd
                // parity, i.e. the tensor product of Pauli-Z on every qubit.
                let size = 1 << num_qubits;
                let mut data = ArrayD::zeros(IxDyn(&[size, size]));
                for i in 0..size {
                    let parity = i.count_ones() % 2;
                    data[[i, i]] = if parity == 0 { 1.0 } else { -1.0 };
                }
                Ok(Self::new(data, false))
            }
            _ => Err(MLError::InvalidConfiguration(format!(
                "Unknown observable: {}",
                name
            ))),
        }
    }
}

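/// Convenience helpers for converting between plain `ndarray` arrays and `SciRS2Array`s and for
/// constructing optimizers and distributed trainers from simple configuration values.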
pub mod integration {
    use super::*;

    pub fn from_ndarray<D: Dimension>(arr: Array<f64, D>) -> SciRS2Array {
        SciRS2Array::from_array(arr)
    }

    pub fn to_ndarray<D: Dimension>(arr: &SciRS2Array) -> Result<Array<f64, D>> {
        arr.data
            .view()
            .into_dimensionality::<D>()
            .map(|v| v.to_owned())
            .map_err(|e| MLError::ComputationError(format!("Dimension error: {}", e)))
    }

    pub fn create_optimizer(optimizer_type: &str, config: HashMap<String, f64>) -> SciRS2Optimizer {
        let mut optimizer = SciRS2Optimizer::new(optimizer_type);
        for (key, value) in config {
            optimizer = optimizer.with_config(key, value);
        }
        optimizer
    }

    pub fn setup_distributed(world_size: usize, rank: usize) -> SciRS2DistributedTrainer {
        SciRS2DistributedTrainer::new(world_size, rank)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_scirs2_array_creation() {
        let arr = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let scirs2_arr = SciRS2Array::from_array(arr);

        assert_eq!(scirs2_arr.data.shape(), &[2, 2]);
        assert!(!scirs2_arr.requires_grad);
    }

    #[test]
    fn test_scirs2_array_with_grad() {
        let arr = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let scirs2_arr = SciRS2Array::with_grad(arr);

        assert!(scirs2_arr.requires_grad);
        assert!(scirs2_arr.grad.is_some());
    }

    #[test]
    fn test_scirs2_matmul() {
        let arr1 = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let arr2 = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let scirs2_arr1 = SciRS2Array::from_array(arr1);
        let scirs2_arr2 = SciRS2Array::from_array(arr2);

        let result = scirs2_arr1.matmul(&scirs2_arr2).unwrap();
        assert_eq!(result.data.shape(), &[2, 2]);
    }

    #[test]
    fn test_scirs2_optimizer() {
        let mut optimizer = SciRS2Optimizer::new("adam")
            .with_config("learning_rate", 0.001)
            .with_config("beta1", 0.9);

        let mut params = HashMap::new();
        let param_arr = SciRS2Array::with_grad(Array1::from_vec(vec![1.0, 2.0, 3.0]));
        params.insert("weight".to_string(), param_arr);

        let result = optimizer.step(&mut params);
        assert!(result.is_ok());
    }

    #[test]
    fn test_integration_helpers() {
        let arr = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let scirs2_arr = integration::from_ndarray(arr.clone());

        let back_to_ndarray: Array2<f64> = integration::to_ndarray(&scirs2_arr).unwrap();
        assert_eq!(arr, back_to_ndarray);
    }
}