1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use std::fmt::Debug;
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12
13pub trait SparseArray<T>: std::any::Any
35where
36 T: SparseElement + Div<Output = T> + 'static,
37{
38 fn shape(&self) -> (usize, usize);
40
41 fn nnz(&self) -> usize;
43
44 fn dtype(&self) -> &str;
46
47 fn to_array(&self) -> Array2<T>;
49
50 fn toarray(&self) -> Array2<T>;
52
53 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
55
56 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
58
59 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
61
62 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
64
65 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
67
68 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
70
71 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
73
74 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
76
77 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
79
80 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
82
83 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
85
86 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
88
89 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>>;
91
92 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
94
95 fn copy(&self) -> Box<dyn SparseArray<T>>;
97
98 fn get(&self, i: usize, j: usize) -> T;
100
101 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()>;
103
104 fn eliminate_zeros(&mut self);
106
107 fn sort_indices(&mut self);
109
110 fn sorted_indices(&self) -> Box<dyn SparseArray<T>>;
112
113 fn has_sorted_indices(&self) -> bool;
115
116 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>>;
123
124 fn max(&self) -> T;
126
127 fn min(&self) -> T;
129
130 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>);
132
133 fn slice(
135 &self,
136 row_range: (usize, usize),
137 col_range: (usize, usize),
138 ) -> SparseResult<Box<dyn SparseArray<T>>>;
139
140 fn as_any(&self) -> &dyn std::any::Any;
142
143 fn get_indptr(&self) -> Option<&Array1<usize>> {
146 None
147 }
148
149 fn indptr(&self) -> Option<&Array1<usize>> {
152 None
153 }
154}
155
156pub enum SparseSum<T>
159where
160 T: SparseElement + Div<Output = T> + 'static,
161{
162 SparseArray(Box<dyn SparseArray<T>>),
164
165 Scalar(T),
167}
168
169impl<T> Debug for SparseSum<T>
170where
171 T: SparseElement + Div<Output = T> + 'static,
172{
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 match self {
175 SparseSum::SparseArray(_) => write!(f, "SparseSum::SparseArray(...)"),
176 SparseSum::Scalar(value) => write!(f, "SparseSum::Scalar({value:?})"),
177 }
178 }
179}
180
181impl<T> Clone for SparseSum<T>
182where
183 T: SparseElement + Div<Output = T> + 'static,
184{
185 fn clone(&self) -> Self {
186 match self {
187 SparseSum::SparseArray(array) => SparseSum::SparseArray(array.copy()),
188 SparseSum::Scalar(value) => SparseSum::Scalar(*value),
189 }
190 }
191}
192
193#[allow(dead_code)]
195pub fn is_sparse<T>(obj: &dyn SparseArray<T>) -> bool
196where
197 T: SparseElement + Div<Output = T> + 'static,
198{
199 true }
201
202pub struct SparseArrayBase<T>
204where
205 T: SparseElement + Div<Output = T> + 'static,
206{
207 data: Array2<T>,
208}
209
210impl<T> SparseArrayBase<T>
211where
212 T: SparseElement + Div<Output = T> + Zero + 'static,
213{
214 pub fn new(data: Array2<T>) -> Self {
216 Self { data }
217 }
218}
219
220impl<T> SparseArray<T> for SparseArrayBase<T>
221where
222 T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
223{
224 fn shape(&self) -> (usize, usize) {
225 let shape = self.data.shape();
226 (shape[0], shape[1])
227 }
228
229 fn nnz(&self) -> usize {
230 self.data.iter().filter(|&&x| x != T::sparse_zero()).count()
231 }
232
233 fn dtype(&self) -> &str {
234 "float" }
236
237 fn to_array(&self) -> Array2<T> {
238 self.data.clone()
239 }
240
241 fn toarray(&self) -> Array2<T> {
242 self.data.clone()
243 }
244
245 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
246 Ok(Box::new(self.clone()))
248 }
249
250 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
251 Ok(Box::new(self.clone()))
253 }
254
255 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
256 Ok(Box::new(self.clone()))
258 }
259
260 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
261 Ok(Box::new(self.clone()))
263 }
264
265 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
266 Ok(Box::new(self.clone()))
268 }
269
270 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
271 Ok(Box::new(self.clone()))
273 }
274
275 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
276 Ok(Box::new(self.clone()))
278 }
279
280 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
281 let other_array = other.to_array();
282 let result = &self.data + &other_array;
283 Ok(Box::new(SparseArrayBase::new(result)))
284 }
285
286 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
287 let other_array = other.to_array();
288 let result = &self.data - &other_array;
289 Ok(Box::new(SparseArrayBase::new(result)))
290 }
291
292 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
293 let other_array = other.to_array();
294 let result = &self.data * &other_array;
295 Ok(Box::new(SparseArrayBase::new(result)))
296 }
297
298 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
299 let other_array = other.to_array();
300 let result = &self.data / &other_array;
301 Ok(Box::new(SparseArrayBase::new(result)))
302 }
303
304 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
305 let other_array = other.to_array();
306 let (m, n) = self.shape();
307 let (p, q) = other.shape();
308
309 if n != p {
310 return Err(SparseError::DimensionMismatch {
311 expected: n,
312 found: p,
313 });
314 }
315
316 let mut result = Array2::zeros((m, q));
317 for i in 0..m {
318 for j in 0..q {
319 let mut sum = T::sparse_zero();
320 for k in 0..n {
321 sum = sum + self.data[[i, k]] * other_array[[k, j]];
322 }
323 result[[i, j]] = sum;
324 }
325 }
326
327 Ok(Box::new(SparseArrayBase::new(result)))
328 }
329
330 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
331 let (m, n) = self.shape();
332 if n != other.len() {
333 return Err(SparseError::DimensionMismatch {
334 expected: n,
335 found: other.len(),
336 });
337 }
338
339 let mut result = Array1::zeros(m);
340 for i in 0..m {
341 let mut sum = T::sparse_zero();
342 for j in 0..n {
343 sum = sum + self.data[[i, j]] * other[j];
344 }
345 result[i] = sum;
346 }
347
348 Ok(result)
349 }
350
351 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
352 Ok(Box::new(SparseArrayBase::new(self.data.t().to_owned())))
353 }
354
355 fn copy(&self) -> Box<dyn SparseArray<T>> {
356 Box::new(self.clone())
357 }
358
359 fn get(&self, i: usize, j: usize) -> T {
360 self.data[[i, j]]
361 }
362
363 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
364 let (m, n) = self.shape();
365 if i >= m || j >= n {
366 return Err(SparseError::IndexOutOfBounds {
367 index: (i, j),
368 shape: (m, n),
369 });
370 }
371 self.data[[i, j]] = value;
372 Ok(())
373 }
374
375 fn eliminate_zeros(&mut self) {
376 }
378
379 fn sort_indices(&mut self) {
380 }
382
383 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
384 self.copy()
385 }
386
387 fn has_sorted_indices(&self) -> bool {
388 true }
390
391 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
392 match axis {
393 None => {
394 let mut sum = T::sparse_zero();
395 for &val in self.data.iter() {
396 sum = sum + val;
397 }
398 Ok(SparseSum::Scalar(sum))
399 }
400 Some(0) => {
401 let (_, n) = self.shape();
402 let mut result = Array2::zeros((1, n));
403 for j in 0..n {
404 let mut sum = T::sparse_zero();
405 for i in 0..self.data.shape()[0] {
406 sum = sum + self.data[[i, j]];
407 }
408 result[[0, j]] = sum;
409 }
410 Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
411 result,
412 ))))
413 }
414 Some(1) => {
415 let (m_, _) = self.shape();
416 let mut result = Array2::zeros((m_, 1));
417 for i in 0..m_ {
418 let mut sum = T::sparse_zero();
419 for j in 0..self.data.shape()[1] {
420 sum = sum + self.data[[i, j]];
421 }
422 result[[i, 0]] = sum;
423 }
424 Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
425 result,
426 ))))
427 }
428 _ => Err(SparseError::InvalidAxis),
429 }
430 }
431
432 fn max(&self) -> T {
433 if self.data.is_empty() {
434 return T::sparse_zero();
435 }
436 let mut max_val = self.data[[0, 0]];
437 for &val in self.data.iter() {
438 if val > max_val {
439 max_val = val;
440 }
441 }
442 max_val
443 }
444
445 fn min(&self) -> T {
446 if self.data.is_empty() {
447 return T::sparse_zero();
448 }
449 let mut min_val = self.data[[0, 0]];
450 for &val in self.data.iter() {
451 if val < min_val {
452 min_val = val;
453 }
454 }
455 min_val
456 }
457
458 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
459 let (m, n) = self.shape();
460 let nnz = self.nnz();
461 let mut rows = Vec::with_capacity(nnz);
462 let mut cols = Vec::with_capacity(nnz);
463 let mut values = Vec::with_capacity(nnz);
464
465 for i in 0..m {
466 for j in 0..n {
467 let value = self.data[[i, j]];
468 if value != T::sparse_zero() {
469 rows.push(i);
470 cols.push(j);
471 values.push(value);
472 }
473 }
474 }
475
476 (
477 Array1::from_vec(rows),
478 Array1::from_vec(cols),
479 Array1::from_vec(values),
480 )
481 }
482
483 fn slice(
484 &self,
485 row_range: (usize, usize),
486 col_range: (usize, usize),
487 ) -> SparseResult<Box<dyn SparseArray<T>>> {
488 let (start_row, end_row) = row_range;
489 let (start_col, end_col) = col_range;
490 let (m, n) = self.shape();
491
492 if start_row >= m
493 || end_row > m
494 || start_col >= n
495 || end_col > n
496 || start_row >= end_row
497 || start_col >= end_col
498 {
499 return Err(SparseError::InvalidSliceRange);
500 }
501
502 let view = self.data.slice(scirs2_core::ndarray::s![
503 start_row..end_row,
504 start_col..end_col
505 ]);
506 Ok(Box::new(SparseArrayBase::new(view.to_owned())))
507 }
508
509 fn as_any(&self) -> &dyn std::any::Any {
510 self
511 }
512}
513
514impl<T> Clone for SparseArrayBase<T>
515where
516 T: SparseElement + Div<Output = T> + 'static,
517{
518 fn clone(&self) -> Self {
519 Self {
520 data: self.data.clone(),
521 }
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Asserts that `actual` matches the row-major `expected` matrix entry by entry.
    fn assert_entries(actual: &Array2<f64>, expected: &[[f64; 2]]) {
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(actual[[i, j]], value, "mismatch at ({i}, {j})");
            }
        }
    }

    #[test]
    fn test_sparse_array_base() {
        let values = vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0];
        let sparse = SparseArrayBase::new(Array::from_shape_vec((3, 3), values).unwrap());

        assert_eq!(sparse.shape(), (3, 3));
        assert_eq!(sparse.nnz(), 5);

        // (row, col, expected) probes: the diagonal plus one explicit zero.
        let probes = [(0, 0, 1.0), (1, 1, 3.0), (2, 2, 5.0), (0, 1, 0.0)];
        for &(i, j, expected) in probes.iter() {
            assert_eq!(sparse.get(i, j), expected);
        }
    }

    #[test]
    fn test_sparse_array_operations() {
        let lhs =
            SparseArrayBase::new(Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap());
        let rhs =
            SparseArrayBase::new(Array::from_shape_vec((2, 2), vec![5.0, 6.0, 7.0, 8.0]).unwrap());

        let elementwise_sum = lhs.add(&rhs).unwrap().to_array();
        assert_entries(&elementwise_sum, &[[6.0, 8.0], [10.0, 12.0]]);

        let matrix_product = lhs.dot(&rhs).unwrap().to_array();
        assert_entries(&matrix_product, &[[19.0, 22.0], [43.0, 50.0]]);
    }
}