Skip to main content

scirs2_sparse/
dok_array.rs

1// Dictionary of Keys (DOK) Array implementation
2//
3// This module provides the DOK (Dictionary of Keys) array format,
4// which is efficient for incremental construction of sparse arrays.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement};
8use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::ops::{Add, Div, Mul, Sub};
12
13use crate::coo_array::CooArray;
14use crate::error::{SparseError, SparseResult};
15use crate::lil_array::LilArray;
16use crate::sparray::{SparseArray, SparseSum};
17
18/// DOK Array format
19///
20/// The DOK (Dictionary of Keys) format stores a sparse array in a dictionary (HashMap)
21/// mapping (row, col) coordinate tuples to values.
22///
23/// # Notes
24///
25/// - Efficient for incremental construction (setting elements one by one)
26/// - Fast random access to individual elements (get/set)
27/// - Slow operations that require iterating over all elements
28/// - Slow arithmetic operations
29/// - Not suitable for large-scale computational operations
30///
31#[derive(Clone)]
32pub struct DokArray<T>
33where
34    T: SparseElement + Div<Output = T> + 'static,
35{
36    /// Dictionary mapping (row, col) to value
37    data: HashMap<(usize, usize), T>,
38    /// Shape of the sparse array
39    shape: (usize, usize),
40}
41
42impl<T> DokArray<T>
43where
44    T: SparseElement + Div<Output = T> + 'static,
45{
46    /// Creates a new DOK array with the given shape
47    ///
48    /// # Arguments
49    /// * `shape` - Shape of the sparse array (rows, cols)
50    ///
51    /// # Returns
52    /// A new empty `DokArray`
53    pub fn new(shape: (usize, usize)) -> Self {
54        Self {
55            data: HashMap::new(),
56            shape,
57        }
58    }
59
60    /// Creates a DOK array from triplet format (COO-like)
61    ///
62    /// # Arguments
63    /// * `rows` - Row indices
64    /// * `cols` - Column indices
65    /// * `data` - Values
66    /// * `shape` - Shape of the sparse array
67    ///
68    /// # Returns
69    /// A new `DokArray`
70    ///
71    /// # Errors
72    /// Returns an error if the data is not consistent
73    pub fn from_triplets(
74        rows: &[usize],
75        cols: &[usize],
76        data: &[T],
77        shape: (usize, usize),
78    ) -> SparseResult<Self> {
79        if rows.len() != cols.len() || rows.len() != data.len() {
80            return Err(SparseError::InconsistentData {
81                reason: "rows, cols, and data must have the same length".to_string(),
82            });
83        }
84
85        let mut dok = Self::new(shape);
86        for i in 0..rows.len() {
87            if rows[i] >= shape.0 || cols[i] >= shape.1 {
88                return Err(SparseError::IndexOutOfBounds {
89                    index: (rows[i], cols[i]),
90                    shape,
91                });
92            }
93            // Only set non-zero values
94            if !SparseElement::is_zero(&data[i]) {
95                dok.data.insert((rows[i], cols[i]), data[i]);
96            }
97        }
98
99        Ok(dok)
100    }
101
102    /// Returns a reference to the internal HashMap
103    pub fn get_data(&self) -> &HashMap<(usize, usize), T> {
104        &self.data
105    }
106
107    /// Returns the triplet representation (row indices, column indices, data)
108    pub fn to_triplets(&self) -> (Array1<usize>, Array1<usize>, Array1<T>)
109    where
110        T: Float + PartialOrd,
111    {
112        let nnz = self.nnz();
113        let mut row_indices = Vec::with_capacity(nnz);
114        let mut col_indices = Vec::with_capacity(nnz);
115        let mut values = Vec::with_capacity(nnz);
116
117        // Sort by row, then column for deterministic output
118        let mut entries: Vec<_> = self.data.iter().collect();
119        entries.sort_by_key(|(&(row, col), _)| (row, col));
120
121        for (&(row, col), &value) in entries {
122            row_indices.push(row);
123            col_indices.push(col);
124            values.push(value);
125        }
126
127        (
128            Array1::from_vec(row_indices),
129            Array1::from_vec(col_indices),
130            Array1::from_vec(values),
131        )
132    }
133
134    /// Creates a DOK array from a dense ndarray
135    ///
136    /// # Arguments
137    /// * `array` - Dense ndarray
138    ///
139    /// # Returns
140    /// A new `DokArray` containing non-zero elements from the input array
141    pub fn from_array(array: &Array2<T>) -> Self {
142        let shape = (array.shape()[0], array.shape()[1]);
143        let mut dok = Self::new(shape);
144
145        for ((i, j), &value) in array.indexed_iter() {
146            if !SparseElement::is_zero(&value) {
147                dok.data.insert((i, j), value);
148            }
149        }
150
151        dok
152    }
153}
154
155impl<T> SparseArray<T> for DokArray<T>
156where
157    T: SparseElement + Div<Output = T> + Float + PartialOrd + 'static,
158{
159    fn shape(&self) -> (usize, usize) {
160        self.shape
161    }
162
163    fn nnz(&self) -> usize {
164        self.data.len()
165    }
166
167    fn dtype(&self) -> &str {
168        "float" // This is a placeholder; ideally, we'd return the actual type
169    }
170
171    fn to_array(&self) -> Array2<T> {
172        let (rows, cols) = self.shape;
173        let mut result = Array2::zeros((rows, cols));
174
175        for (&(row, col), &value) in &self.data {
176            result[[row, col]] = value;
177        }
178
179        result
180    }
181
182    fn toarray(&self) -> Array2<T> {
183        self.to_array()
184    }
185
186    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
187        let (row_indices, col_indices, data) = self.to_triplets();
188        CooArray::new(data, row_indices, col_indices, self.shape, true)
189            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
190    }
191
192    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
193        // First convert to COO, then to CSR
194        match self.to_coo() {
195            Ok(coo) => coo.to_csr(),
196            Err(e) => Err(e),
197        }
198    }
199
200    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
201        // First convert to COO, then to CSC
202        match self.to_coo() {
203            Ok(coo) => coo.to_csc(),
204            Err(e) => Err(e),
205        }
206    }
207
208    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
209        // We're already a DOK array
210        Ok(Box::new(self.clone()))
211    }
212
213    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
214        let (rows_arr, cols_arr, vals_arr) = self.to_triplets();
215        let rows_slice = rows_arr
216            .as_slice()
217            .ok_or_else(|| SparseError::ValueError("non-contiguous row indices".to_string()))?;
218        let cols_slice = cols_arr
219            .as_slice()
220            .ok_or_else(|| SparseError::ValueError("non-contiguous col indices".to_string()))?;
221        let vals_slice = vals_arr
222            .as_slice()
223            .ok_or_else(|| SparseError::ValueError("non-contiguous values".to_string()))?;
224        let lil = LilArray::from_triplets(rows_slice, cols_slice, vals_slice, self.shape)?;
225        Ok(Box::new(lil))
226    }
227
228    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
229        self.to_csr()?.to_dia()
230    }
231
232    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
233        self.to_csr()?.to_bsr()
234    }
235
236    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
237        if self.shape() != other.shape() {
238            return Err(SparseError::DimensionMismatch {
239                expected: self.shape().0,
240                found: other.shape().0,
241            });
242        }
243
244        let mut result = self.clone();
245        let other_array = other.to_array();
246
247        // Add existing values from self
248        for (&(row, col), &value) in &self.data {
249            result.set(row, col, value + other_array[[row, col]])?;
250        }
251
252        // Add values from other that aren't in self
253        for ((row, col), &value) in other_array.indexed_iter() {
254            if !self.data.contains_key(&(row, col)) && !SparseElement::is_zero(&value) {
255                result.set(row, col, value)?;
256            }
257        }
258
259        Ok(Box::new(result))
260    }
261
262    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
263        if self.shape() != other.shape() {
264            return Err(SparseError::DimensionMismatch {
265                expected: self.shape().0,
266                found: other.shape().0,
267            });
268        }
269
270        let mut result = self.clone();
271        let other_array = other.to_array();
272
273        // Subtract existing values from self
274        for (&(row, col), &value) in &self.data {
275            result.set(row, col, value - other_array[[row, col]])?;
276        }
277
278        // Subtract values from other that aren't in self
279        for ((row, col), &value) in other_array.indexed_iter() {
280            if !self.data.contains_key(&(row, col)) && !SparseElement::is_zero(&value) {
281                result.set(row, col, -value)?;
282            }
283        }
284
285        Ok(Box::new(result))
286    }
287
288    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
289        if self.shape() != other.shape() {
290            return Err(SparseError::DimensionMismatch {
291                expected: self.shape().0,
292                found: other.shape().0,
293            });
294        }
295
296        let mut result = DokArray::new(self.shape());
297        let other_array = other.to_array();
298
299        // Only need to process entries in self
300        // since a*0 = 0 for any a
301        for (&(row, col), &value) in &self.data {
302            let product = value * other_array[[row, col]];
303            if !SparseElement::is_zero(&product) {
304                result.set(row, col, product)?;
305            }
306        }
307
308        Ok(Box::new(result))
309    }
310
311    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
312        if self.shape() != other.shape() {
313            return Err(SparseError::DimensionMismatch {
314                expected: self.shape().0,
315                found: other.shape().0,
316            });
317        }
318
319        let mut result = DokArray::new(self.shape());
320        let other_array = other.to_array();
321
322        for (&(row, col), &value) in &self.data {
323            let divisor = other_array[[row, col]];
324            if SparseElement::is_zero(&divisor) {
325                return Err(SparseError::ComputationError(
326                    "Division by zero".to_string(),
327                ));
328            }
329
330            let quotient = value / divisor;
331            if !SparseElement::is_zero(&quotient) {
332                result.set(row, col, quotient)?;
333            }
334        }
335
336        Ok(Box::new(result))
337    }
338
339    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
340        let (_m, n) = self.shape();
341        let (p, q) = other.shape();
342
343        if n != p {
344            return Err(SparseError::DimensionMismatch {
345                expected: n,
346                found: p,
347            });
348        }
349
350        // Convert to CSR for efficient matrix multiplication
351        let csr_self = self.to_csr()?;
352        let csr_other = other.to_csr()?;
353
354        csr_self.dot(&*csr_other)
355    }
356
357    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
358        let (m, n) = self.shape();
359        if n != other.len() {
360            return Err(SparseError::DimensionMismatch {
361                expected: n,
362                found: other.len(),
363            });
364        }
365
366        let mut result = Array1::zeros(m);
367
368        for (&(row, col), &value) in &self.data {
369            result[row] = result[row] + value * other[col];
370        }
371
372        Ok(result)
373    }
374
375    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
376        let (rows, cols) = self.shape;
377        let mut result = DokArray::new((cols, rows));
378
379        for (&(row, col), &value) in &self.data {
380            result.set(col, row, value)?;
381        }
382
383        Ok(Box::new(result))
384    }
385
386    fn copy(&self) -> Box<dyn SparseArray<T>> {
387        Box::new(self.clone())
388    }
389
390    fn get(&self, i: usize, j: usize) -> T {
391        if i >= self.shape.0 || j >= self.shape.1 {
392            return T::sparse_zero();
393        }
394
395        *self.data.get(&(i, j)).unwrap_or(&T::sparse_zero())
396    }
397
398    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
399        if i >= self.shape.0 || j >= self.shape.1 {
400            return Err(SparseError::IndexOutOfBounds {
401                index: (i, j),
402                shape: self.shape,
403            });
404        }
405
406        if SparseElement::is_zero(&value) {
407            // Remove zero entries
408            self.data.remove(&(i, j));
409        } else {
410            // Set non-zero value
411            self.data.insert((i, j), value);
412        }
413
414        Ok(())
415    }
416
417    fn eliminate_zeros(&mut self) {
418        // DOK format already doesn't store zeros, but just in case
419        self.data
420            .retain(|_, &mut value| !SparseElement::is_zero(&value));
421    }
422
423    fn sort_indices(&mut self) {
424        // No-op for DOK format since it's a HashMap
425    }
426
427    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
428        // DOK doesn't have the concept of sorted indices
429        self.copy()
430    }
431
432    fn has_sorted_indices(&self) -> bool {
433        true // DOK format doesn't have the concept of sorted indices
434    }
435
436    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
437        match axis {
438            None => {
439                // Sum all elements
440                let mut sum = T::sparse_zero();
441                for &value in self.data.values() {
442                    sum = sum + value;
443                }
444                Ok(SparseSum::Scalar(sum))
445            }
446            Some(0) => {
447                // Sum along rows
448                let (_, cols) = self.shape();
449                let mut result = DokArray::new((1, cols));
450
451                for (&(_row, col), &value) in &self.data {
452                    let current = result.get(0, col);
453                    result.set(0, col, current + value)?;
454                }
455
456                Ok(SparseSum::SparseArray(Box::new(result)))
457            }
458            Some(1) => {
459                // Sum along columns
460                let (rows, _) = self.shape();
461                let mut result = DokArray::new((rows, 1));
462
463                for (&(row, col), &value) in &self.data {
464                    let current = result.get(row, 0);
465                    result.set(row, 0, current + value)?;
466                }
467
468                Ok(SparseSum::SparseArray(Box::new(result)))
469            }
470            _ => Err(SparseError::InvalidAxis),
471        }
472    }
473
474    fn max(&self) -> T {
475        if self.data.is_empty() {
476            return T::nan();
477        }
478
479        self.data
480            .values()
481            .fold(T::neg_infinity(), |acc, &x| acc.max(x))
482    }
483
484    fn min(&self) -> T {
485        if self.data.is_empty() {
486            return T::nan();
487        }
488
489        self.data
490            .values()
491            .fold(T::sparse_zero(), |acc, &x| acc.min(x))
492    }
493
494    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
495        self.to_triplets()
496    }
497
498    fn slice(
499        &self,
500        row_range: (usize, usize),
501        col_range: (usize, usize),
502    ) -> SparseResult<Box<dyn SparseArray<T>>> {
503        let (start_row, end_row) = row_range;
504        let (start_col, end_col) = col_range;
505        let (rows, cols) = self.shape;
506
507        if start_row >= rows
508            || end_row > rows
509            || start_col >= cols
510            || end_col > cols
511            || start_row >= end_row
512            || start_col >= end_col
513        {
514            return Err(SparseError::InvalidSliceRange);
515        }
516
517        let sliceshape = (end_row - start_row, end_col - start_col);
518        let mut result = DokArray::new(sliceshape);
519
520        for (&(row, col), &value) in &self.data {
521            if row >= start_row && row < end_row && col >= start_col && col < end_col {
522                result.set(row - start_row, col - start_col, value)?;
523            }
524        }
525
526        Ok(Box::new(result))
527    }
528
529    fn as_any(&self) -> &dyn Any {
530        self
531    }
532}
533
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds an `f64` DOK array of `shape`, storing each `(row, col, value)` through `set`.
    fn build(shape: (usize, usize), entries: &[(usize, usize, f64)]) -> DokArray<f64> {
        let mut array = DokArray::<f64>::new(shape);
        for &(row, col, value) in entries {
            array
                .set(row, col, value)
                .expect("Test: failed to set array element");
        }
        array
    }

    /// Entries of the dense 2x2 matrix `[[1, 2], [3, 4]]`.
    const LHS_2X2: [(usize, usize, f64); 4] =
        [(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)];

    /// Entries of the dense 2x2 matrix `[[5, 6], [7, 8]]`.
    const RHS_2X2: [(usize, usize, f64); 4] =
        [(0, 0, 5.0), (0, 1, 6.0), (1, 0, 7.0), (1, 1, 8.0)];

    /// Triplets of the 3x3 matrix `[[1, 0, 2], [0, 0, 3], [4, 5, 0]]`.
    const TRIPLET_ROWS: [usize; 5] = [0, 0, 1, 2, 2];
    const TRIPLET_COLS: [usize; 5] = [0, 2, 2, 0, 1];
    const TRIPLET_DATA: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];

    /// Row-major dense contents of the triplet matrix above.
    const DENSE_3X3: [f64; 9] = [1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0];

    /// Asserts that `array` stores exactly the five triplet entries.
    fn assert_triplet_entries(array: &DokArray<f64>) {
        assert_eq!(array.nnz(), 5);
        for ((&row, &col), &value) in TRIPLET_ROWS
            .iter()
            .zip(&TRIPLET_COLS)
            .zip(&TRIPLET_DATA)
        {
            assert_eq!(array.get(row, col), value);
        }
    }

    #[test]
    fn test_dok_array_create_and_access() {
        let mut array = build(
            (3, 3),
            &[(0, 0, 1.0), (0, 2, 2.0), (1, 2, 3.0), (2, 0, 4.0), (2, 1, 5.0)],
        );

        assert_eq!(array.nnz(), 5);

        // Stored entries read back; an unset in-bounds coordinate reads as zero.
        let expected: [(usize, usize, f64); 6] = [
            (0, 0, 1.0),
            (0, 1, 0.0),
            (0, 2, 2.0),
            (1, 2, 3.0),
            (2, 0, 4.0),
            (2, 1, 5.0),
        ];
        for &(row, col, value) in &expected {
            assert_eq!(array.get(row, col), value);
        }

        // Writing zero removes the stored entry.
        array
            .set(0, 0, 0.0)
            .expect("Test: failed to set array element");
        assert_eq!(array.nnz(), 4);
        assert_eq!(array.get(0, 0), 0.0);

        // Reads outside the shape fall back to zero.
        assert_eq!(array.get(3, 0), 0.0);
        assert_eq!(array.get(0, 3), 0.0);
    }

    #[test]
    fn test_dok_array_from_triplets() {
        let array = DokArray::from_triplets(&TRIPLET_ROWS, &TRIPLET_COLS, &TRIPLET_DATA, (3, 3))
            .expect("Test: failed to create DokArray from triplets");

        assert_triplet_entries(&array);
    }

    #[test]
    fn test_dok_array_to_array() {
        let array = DokArray::from_triplets(&TRIPLET_ROWS, &TRIPLET_COLS, &TRIPLET_DATA, (3, 3))
            .expect("Test: failed to create DokArray from triplets");

        let expected = Array::from_shape_vec((3, 3), DENSE_3X3.to_vec())
            .expect("Test: failed to create array from shape vec");

        assert_eq!(array.to_array(), expected);
    }

    #[test]
    fn test_dok_array_from_array() {
        let dense = Array::from_shape_vec((3, 3), DENSE_3X3.to_vec())
            .expect("Test: failed to create array from shape vec");

        let array = DokArray::from_array(&dense);

        assert_triplet_entries(&array);
    }

    #[test]
    fn test_dok_array_add() {
        let array1 = build((2, 2), &[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0)]);
        let array2 = build((2, 2), &[(0, 0, 4.0), (1, 1, 5.0)]);

        let dense_result = array1
            .add(&array2)
            .expect("Test: array addition failed")
            .to_array();

        let expected: [(usize, usize, f64); 4] =
            [(0, 0, 5.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 5.0)];
        for &(row, col, value) in &expected {
            assert_eq!(dense_result[[row, col]], value);
        }
    }

    #[test]
    fn test_dok_array_mul() {
        let array1 = build((2, 2), &LHS_2X2);
        let array2 = build((2, 2), &RHS_2X2);

        // Element-wise multiplication
        let dense_result = array1
            .mul(&array2)
            .expect("Test: array multiplication failed")
            .to_array();

        let expected: [(usize, usize, f64); 4] =
            [(0, 0, 5.0), (0, 1, 12.0), (1, 0, 21.0), (1, 1, 32.0)];
        for &(row, col, value) in &expected {
            assert_eq!(dense_result[[row, col]], value);
        }
    }

    #[test]
    fn test_dok_array_dot() {
        let array1 = build((2, 2), &LHS_2X2);
        let array2 = build((2, 2), &RHS_2X2);

        // Matrix multiplication
        let dense_result = array1
            .dot(&array2)
            .expect("Test: array dot product failed")
            .to_array();

        // [1 2] [5 6] = [1*5 + 2*7, 1*6 + 2*8] = [19, 22]
        // [3 4] [7 8]   [3*5 + 4*7, 3*6 + 4*8]   [43, 50]
        let expected: [(usize, usize, f64); 4] =
            [(0, 0, 19.0), (0, 1, 22.0), (1, 0, 43.0), (1, 1, 50.0)];
        for &(row, col, value) in &expected {
            assert_eq!(dense_result[[row, col]], value);
        }
    }

    #[test]
    fn test_dok_array_transpose() {
        let entries: [(usize, usize, f64); 6] = [
            (0, 0, 1.0),
            (0, 1, 2.0),
            (0, 2, 3.0),
            (1, 0, 4.0),
            (1, 1, 5.0),
            (1, 2, 6.0),
        ];
        let array = build((2, 3), &entries);

        let transposed = array.transpose().expect("Test: array transpose failed");

        assert_eq!(transposed.shape(), (3, 2));
        // Element (r, c) of the original must reappear at (c, r).
        for &(row, col, value) in &entries {
            assert_eq!(transposed.get(col, row), value);
        }
    }

    #[test]
    fn test_dok_array_slice() {
        // Fill a 3x3 array with 1..=9 in row-major order.
        let entries: Vec<(usize, usize, f64)> = (0..3)
            .flat_map(|row| (0..3).map(move |col| (row, col, (3 * row + col + 1) as f64)))
            .collect();
        let array = build((3, 3), &entries);

        let slice = array
            .slice((0, 2), (1, 3))
            .expect("Test: array slice failed");

        assert_eq!(slice.shape(), (2, 2));
        let expected: [(usize, usize, f64); 4] =
            [(0, 0, 2.0), (0, 1, 3.0), (1, 0, 5.0), (1, 1, 6.0)];
        for &(row, col, value) in &expected {
            assert_eq!(slice.get(row, col), value);
        }
    }

    #[test]
    fn test_dok_array_sum() {
        let array = build(
            (2, 3),
            &[
                (0, 0, 1.0),
                (0, 1, 2.0),
                (0, 2, 3.0),
                (1, 0, 4.0),
                (1, 1, 5.0),
                (1, 2, 6.0),
            ],
        );

        // Sum all elements
        if let SparseSum::Scalar(total) = array.sum(None).expect("Test: array sum failed") {
            assert_eq!(total, 21.0);
        } else {
            panic!("Expected scalar sum");
        }

        // Sum along rows (axis 0)
        if let SparseSum::SparseArray(column_totals) =
            array.sum(Some(0)).expect("Test: array sum failed")
        {
            assert_eq!(column_totals.shape(), (1, 3));
            assert_eq!(column_totals.get(0, 0), 5.0);
            assert_eq!(column_totals.get(0, 1), 7.0);
            assert_eq!(column_totals.get(0, 2), 9.0);
        } else {
            panic!("Expected sparse array");
        }

        // Sum along columns (axis 1)
        if let SparseSum::SparseArray(row_totals) =
            array.sum(Some(1)).expect("Test: array sum failed")
        {
            assert_eq!(row_totals.shape(), (2, 1));
            assert_eq!(row_totals.get(0, 0), 6.0);
            assert_eq!(row_totals.get(1, 0), 15.0);
        } else {
            panic!("Expected sparse array");
        }
    }
}