scirs2_sparse/
construct.rs

1// Construction utilities for sparse arrays
2//
3// This module provides functions for constructing sparse arrays,
4// including identity matrices, diagonal matrices, random arrays, etc.
5
6#![allow(unused_variables)]
7#![allow(unused_assignments)]
8#![allow(unused_mut)]
9
10use scirs2_core::ndarray::Array1;
11use scirs2_core::numeric::{Float, SparseElement};
12use scirs2_core::random::seq::SliceRandom;
13use scirs2_core::random::{Rng, SeedableRng};
14use std::fmt::Debug;
15use std::ops::{Add, Div, Mul, Sub};
16
17use crate::coo_array::CooArray;
18use crate::csr_array::CsrArray;
19use crate::dok_array::DokArray;
20use crate::error::{SparseError, SparseResult};
21use crate::lil_array::LilArray;
22use crate::sparray::SparseArray;
23
24// Import parallel operations from scirs2-core
25use scirs2_core::parallel_ops::*;
26
27/// Creates a sparse identity array of size n x n
28///
29/// # Arguments
30/// * `n` - Size of the square array
31/// * `format` - Format of the output array ("csr" or "coo")
32///
33/// # Returns
34/// A sparse array representing the identity matrix
35///
36/// # Examples
37///
38/// ```
39/// use scirs2_sparse::construct::eye_array;
40///
41/// let eye: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array(3, "csr").unwrap();
42/// assert_eq!(eye.shape(), (3, 3));
43/// assert_eq!(eye.nnz(), 3);
44/// assert_eq!(eye.get(0, 0), 1.0);
45/// assert_eq!(eye.get(1, 1), 1.0);
46/// assert_eq!(eye.get(2, 2), 1.0);
47/// assert_eq!(eye.get(0, 1), 0.0);
48/// ```
49#[allow(dead_code)]
50pub fn eye_array<T>(n: usize, format: &str) -> SparseResult<Box<dyn SparseArray<T>>>
51where
52    T: SparseElement + Div<Output = T> + Float + 'static,
53{
54    if n == 0 {
55        return Err(SparseError::ValueError(
56            "Matrix dimension must be positive".to_string(),
57        ));
58    }
59
60    let mut rows = Vec::with_capacity(n);
61    let mut cols = Vec::with_capacity(n);
62    let mut data = Vec::with_capacity(n);
63
64    for i in 0..n {
65        rows.push(i);
66        cols.push(i);
67        data.push(T::sparse_one());
68    }
69
70    match format.to_lowercase().as_str() {
71        "csr" => CsrArray::from_triplets(&rows, &cols, &data, (n, n), true)
72            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
73        "coo" => CooArray::from_triplets(&rows, &cols, &data, (n, n), true)
74            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
75        "dok" => DokArray::from_triplets(&rows, &cols, &data, (n, n))
76            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
77        "lil" => LilArray::from_triplets(&rows, &cols, &data, (n, n))
78            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
79        _ => Err(SparseError::ValueError(format!(
80            "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
81        ))),
82    }
83}
84
85/// Creates a sparse identity array of size m x n with k-th diagonal filled with ones
86///
87/// # Arguments
88/// * `m` - Number of rows
89/// * `n` - Number of columns
90/// * `k` - Diagonal index (0 = main diagonal, >0 = above main, <0 = below main)
91/// * `format` - Format of the output array ("csr" or "coo")
92///
93/// # Returns
94/// A sparse array with ones on the specified diagonal
95///
96/// # Examples
97///
98/// ```
99/// use scirs2_sparse::construct::eye_array_k;
100///
101/// // Identity with main diagonal (k=0)
102/// let eye: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(3, 3, 0, "csr").unwrap();
103/// assert_eq!(eye.get(0, 0), 1.0);
104/// assert_eq!(eye.get(1, 1), 1.0);
105/// assert_eq!(eye.get(2, 2), 1.0);
106///
107/// // Superdiagonal (k=1)
108/// let superdiag: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(3, 4, 1, "csr").unwrap();
109/// assert_eq!(superdiag.get(0, 1), 1.0);
110/// assert_eq!(superdiag.get(1, 2), 1.0);
111/// assert_eq!(superdiag.get(2, 3), 1.0);
112///
113/// // Subdiagonal (k=-1)
114/// let subdiag: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(4, 3, -1, "csr").unwrap();
115/// assert_eq!(subdiag.get(1, 0), 1.0);
116/// assert_eq!(subdiag.get(2, 1), 1.0);
117/// assert_eq!(subdiag.get(3, 2), 1.0);
118/// ```
119#[allow(dead_code)]
120pub fn eye_array_k<T>(
121    m: usize,
122    n: usize,
123    k: isize,
124    format: &str,
125) -> SparseResult<Box<dyn SparseArray<T>>>
126where
127    T: SparseElement + Div<Output = T> + Float + 'static,
128{
129    if m == 0 || n == 0 {
130        return Err(SparseError::ValueError(
131            "Matrix dimensions must be positive".to_string(),
132        ));
133    }
134
135    let mut rows = Vec::new();
136    let mut cols = Vec::new();
137    let mut data = Vec::new();
138
139    // Calculate diagonal elements
140    if k >= 0 {
141        let k_usize = k as usize;
142        let len = std::cmp::min(m, n.saturating_sub(k_usize));
143
144        for i in 0..len {
145            rows.push(i);
146            cols.push(i + k_usize);
147            data.push(T::sparse_one());
148        }
149    } else {
150        let k_abs = (-k) as usize;
151        let len = std::cmp::min(m.saturating_sub(k_abs), n);
152
153        for i in 0..len {
154            rows.push(i + k_abs);
155            cols.push(i);
156            data.push(T::sparse_one());
157        }
158    }
159
160    match format.to_lowercase().as_str() {
161        "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), true)
162            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
163        "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), true)
164            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
165        "dok" => DokArray::from_triplets(&rows, &cols, &data, (m, n))
166            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
167        "lil" => LilArray::from_triplets(&rows, &cols, &data, (m, n))
168            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
169        _ => Err(SparseError::ValueError(format!(
170            "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
171        ))),
172    }
173}
174
175/// Creates a sparse array from the specified diagonals
176///
177/// # Arguments
178/// * `diagonals` - Data for the diagonals
179/// * `offsets` - Offset for each diagonal (0 = main, >0 = above main, <0 = below main)
180/// * `shape` - Shape of the output array (m, n)
181/// * `format` - Format of the output array ("csr" or "coo")
182///
183/// # Returns
184/// A sparse array with the specified diagonals
185///
186/// # Examples
187///
188/// ```
189/// use scirs2_sparse::construct::diags_array;
190/// use scirs2_core::ndarray::Array1;
191///
192/// let diags = vec![
193///     Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
194///     Array1::from_vec(vec![4.0, 5.0])       // superdiagonal
195/// ];
196/// let offsets = vec![0, 1];
197/// let shape = (3, 3);
198///
199/// let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
200/// assert_eq!(result.shape(), (3, 3));
201/// assert_eq!(result.get(0, 0), 1.0);
202/// assert_eq!(result.get(1, 1), 2.0);
203/// assert_eq!(result.get(2, 2), 3.0);
204/// assert_eq!(result.get(0, 1), 4.0);
205/// assert_eq!(result.get(1, 2), 5.0);
206/// ```
207#[allow(dead_code)]
208pub fn diags_array<T>(
209    diagonals: &[Array1<T>],
210    offsets: &[isize],
211    shape: (usize, usize),
212    format: &str,
213) -> SparseResult<Box<dyn SparseArray<T>>>
214where
215    T: SparseElement + Div<Output = T> + Float + 'static,
216{
217    if diagonals.len() != offsets.len() {
218        return Err(SparseError::InconsistentData {
219            reason: "Number of diagonals must match number of offsets".to_string(),
220        });
221    }
222
223    if shape.0 == 0 || shape.1 == 0 {
224        return Err(SparseError::ValueError(
225            "Matrix dimensions must be positive".to_string(),
226        ));
227    }
228
229    let (m, n) = shape;
230    let mut rows = Vec::new();
231    let mut cols = Vec::new();
232    let mut data = Vec::new();
233
234    for (i, (diag, &offset)) in diagonals.iter().zip(offsets.iter()).enumerate() {
235        if offset >= 0 {
236            let offset_usize = offset as usize;
237            let max_len = std::cmp::min(m, n.saturating_sub(offset_usize));
238
239            if diag.len() > max_len {
240                return Err(SparseError::InconsistentData {
241                    reason: format!("Diagonal {i} is too long ({} > {})", diag.len(), max_len),
242                });
243            }
244
245            for (j, &value) in diag.iter().enumerate() {
246                if !SparseElement::is_zero(&value) {
247                    rows.push(j);
248                    cols.push(j + offset_usize);
249                    data.push(value);
250                }
251            }
252        } else {
253            let offset_abs = (-offset) as usize;
254            let max_len = std::cmp::min(m.saturating_sub(offset_abs), n);
255
256            if diag.len() > max_len {
257                return Err(SparseError::InconsistentData {
258                    reason: format!("Diagonal {i} is too long ({} > {})", diag.len(), max_len),
259                });
260            }
261
262            for (j, &value) in diag.iter().enumerate() {
263                if !SparseElement::is_zero(&value) {
264                    rows.push(j + offset_abs);
265                    cols.push(j);
266                    data.push(value);
267                }
268            }
269        }
270    }
271
272    match format.to_lowercase().as_str() {
273        "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
274            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
275        "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
276            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
277        "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
278            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
279        "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
280            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
281        _ => Err(SparseError::ValueError(format!(
282            "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
283        ))),
284    }
285}
286
287/// Creates a random sparse array with specified density
288///
289/// # Arguments
290/// * `shape` - Shape of the output array (m, n)
291/// * `density` - Density of non-zero elements (between 0.0 and 1.0)
292/// * `seed` - Optional seed for the random number generator
293/// * `format` - Format of the output array ("csr" or "coo")
294///
295/// # Returns
296/// A sparse array with random non-zero elements
297///
298/// # Examples
299///
300/// ```
301/// use scirs2_sparse::construct::random_array;
302///
303/// // Create a 10x10 array with 30% non-zero elements
304/// let random = random_array::<f64>((10, 10), 0.3, None, "csr").unwrap();
305/// assert_eq!(random.shape(), (10, 10));
306///
307/// // Create a random array with a specific seed
308/// let seeded = random_array::<f64>((5, 5), 0.5, Some(42), "coo").unwrap();
309/// assert_eq!(seeded.shape(), (5, 5));
310/// ```
311#[allow(dead_code)]
312pub fn random_array<T>(
313    shape: (usize, usize),
314    density: f64,
315    seed: Option<u64>,
316    format: &str,
317) -> SparseResult<Box<dyn SparseArray<T>>>
318where
319    T: Float + SparseElement + Div<Output = T> + 'static,
320{
321    let (m, n) = shape;
322
323    if !(0.0..=1.0).contains(&density) {
324        return Err(SparseError::ValueError(
325            "Density must be between 0.0 and 1.0".to_string(),
326        ));
327    }
328
329    if m == 0 || n == 0 {
330        return Err(SparseError::ValueError(
331            "Matrix dimensions must be positive".to_string(),
332        ));
333    }
334
335    // Calculate the number of non-zero elements
336    let nnz = (m * n) as f64 * density;
337    let nnz = nnz.round() as usize;
338
339    // Create random indices
340    let mut rows = Vec::with_capacity(nnz);
341    let mut cols = Vec::with_capacity(nnz);
342    let mut data = Vec::with_capacity(nnz);
343
344    // Create RNG
345    let mut rng = if let Some(seed_value) = seed {
346        scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value)
347    } else {
348        // For a random seed, use rng
349        let seed = scirs2_core::random::random::<u64>();
350        scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
351    };
352
353    // Generate random elements
354    let total = m * n;
355
356    if density > 0.4 {
357        // For high densities, more efficient to generate a mask
358        let mut indices: Vec<usize> = (0..total).collect();
359        indices.shuffle(&mut rng);
360
361        for &idx in indices.iter().take(nnz) {
362            let row = idx / n;
363            let col = idx % n;
364
365            rows.push(row);
366            cols.push(col);
367
368            // Generate random non-zero value
369            // For simplicity, using values between -1 and 1
370            let mut val: f64 = rng.random_range(-1.0..1.0);
371            // Make sure the value is not zero
372            while val.abs() < 1e-10 {
373                val = rng.random_range(-1.0..1.0);
374            }
375            data.push(T::from(val).unwrap());
376        }
377    } else {
378        // For low densities..use a set to track already-chosen positions
379        let mut positions = std::collections::HashSet::with_capacity(nnz);
380
381        while positions.len() < nnz {
382            let row = rng.random_range(0..m);
383            let col = rng.random_range(0..n);
384            let pos = row * n + col; // Using row/col as usize indices
385
386            if positions.insert(pos) {
387                rows.push(row);
388                cols.push(col);
389
390                // Generate random non-zero value
391                let mut val: f64 = rng.random_range(-1.0..1.0);
392                // Make sure the value is not zero
393                while val.abs() < 1e-10 {
394                    val = rng.random_range(-1.0..1.0);
395                }
396                data.push(T::from(val).unwrap());
397            }
398        }
399    }
400
401    // Create the output array
402    match format.to_lowercase().as_str() {
403        "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
404            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
405        "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
406            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
407        "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
408            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
409        "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
410            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
411        _ => Err(SparseError::ValueError(format!(
412            "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
413        ))),
414    }
415}
416
417/// Creates a large sparse random array using parallel processing
418///
419/// This function uses parallel construction for improved performance when creating
420/// large sparse arrays with many non-zero elements.
421///
422/// # Arguments
423/// * `shape` - Shape of the array (rows, cols)
424/// * `density` - Density of non-zero elements (0.0 to 1.0)
425/// * `seed` - Optional random seed for reproducibility
426/// * `format` - Format of the output array ("csr" or "coo")
427/// * `parallel_threshold` - Minimum number of elements to use parallel construction
428///
429/// # Returns
430/// A sparse array with randomly distributed non-zero elements
431///
432/// # Examples
433///
434/// ```
435/// use scirs2_sparse::construct::random_array_parallel;
436///
437/// // Create a large random sparse array
438/// let large_random = random_array_parallel::<f64>((1000, 1000), 0.01, Some(42), "csr", 10000).unwrap();
439/// assert_eq!(large_random.shape(), (1000, 1000));
440/// assert!(large_random.nnz() > 5000); // Approximately 10000 non-zeros expected
441/// ```
442#[allow(dead_code)]
443pub fn random_array_parallel<T>(
444    shape: (usize, usize),
445    density: f64,
446    seed: Option<u64>,
447    format: &str,
448    parallel_threshold: usize,
449) -> SparseResult<Box<dyn SparseArray<T>>>
450where
451    T: Float + SparseElement + Div<Output = T> + Send + Sync + 'static,
452{
453    if !(0.0..=1.0).contains(&density) {
454        return Err(SparseError::ValueError(
455            "Density must be between 0.0 and 1.0".to_string(),
456        ));
457    }
458
459    let (rows, cols) = shape;
460    if rows == 0 || cols == 0 {
461        return Err(SparseError::ValueError(
462            "Matrix dimensions must be positive".to_string(),
463        ));
464    }
465
466    let total_elements = rows * cols;
467    let expected_nnz = (total_elements as f64 * density) as usize;
468
469    // Use parallel construction for large matrices
470    if total_elements >= parallel_threshold && expected_nnz >= 1000 {
471        parallel_random_construction(shape, density, seed, format)
472    } else {
473        // Fall back to sequential construction for small matrices
474        random_array(shape, density, seed, format)
475    }
476}
477
478/// Internal parallel construction function
479#[allow(dead_code)]
480fn parallel_random_construction<T>(
481    shape: (usize, usize),
482    density: f64,
483    seed: Option<u64>,
484    format: &str,
485) -> SparseResult<Box<dyn SparseArray<T>>>
486where
487    T: Float + SparseElement + Div<Output = T> + Send + Sync + 'static,
488{
489    let (rows, cols) = shape;
490    let total_elements = rows * cols;
491    let expected_nnz = (total_elements as f64 * density) as usize;
492
493    // Determine number of chunks based on available parallelism
494    let num_chunks = std::cmp::min(scirs2_core::parallel_ops::get_num_threads(), rows.min(cols));
495    let chunk_size = std::cmp::max(1, rows / num_chunks);
496
497    // Create row chunks for parallel processing
498    let row_chunks: Vec<_> = (0..rows)
499        .collect::<Vec<_>>()
500        .chunks(chunk_size)
501        .map(|chunk| chunk.to_vec())
502        .collect();
503
504    // Generate random elements in parallel using enumerate to get chunk index
505    let chunk_data: Vec<_> = row_chunks.iter().enumerate().collect();
506    let results: Vec<_> = parallel_map(&chunk_data, |(chunk_idx, row_chunk)| {
507        let mut local_rows = Vec::new();
508        let mut local_cols = Vec::new();
509        let mut local_data = Vec::new();
510
511        // Use a different seed for each chunk to ensure good randomization
512        let chunk_seed = seed.unwrap_or(42) + *chunk_idx as u64 * 1000007; // Large prime offset
513        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(chunk_seed);
514
515        for &row in row_chunk.iter() {
516            // Determine how many elements to generate for this row
517            let row_elements = cols;
518            let row_expected_nnz = std::cmp::max(1, (row_elements as f64 * density) as usize);
519
520            // Generate random column indices for this row
521            let mut col_indices: Vec<usize> = (0..cols).collect();
522            col_indices.shuffle(&mut rng);
523
524            // Take the first row_expected_nnz columns
525            for &col in col_indices.iter().take(row_expected_nnz) {
526                // Generate random value
527                let mut val = rng.random_range(-1.0..1.0);
528                // Make sure the value is not zero
529                while val.abs() < 1e-10 {
530                    val = rng.random_range(-1.0..1.0);
531                }
532
533                local_rows.push(row);
534                local_cols.push(col);
535                local_data.push(T::from(val).unwrap());
536            }
537        }
538
539        (local_rows, local_cols, local_data)
540    });
541
542    // Combine results from all chunks
543    let mut all_rows = Vec::new();
544    let mut all_cols = Vec::new();
545    let mut all_data = Vec::new();
546
547    for (mut rowschunk, mut cols_chunk, mut data_chunk) in results {
548        all_rows.extend(rowschunk);
549        all_cols.append(&mut cols_chunk);
550        all_data.append(&mut data_chunk);
551    }
552
553    // Create the output array
554    match format.to_lowercase().as_str() {
555        "csr" => CsrArray::from_triplets(&all_rows, &all_cols, &all_data, shape, false)
556            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
557        "coo" => CooArray::from_triplets(&all_rows, &all_cols, &all_data, shape, false)
558            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
559        "dok" => DokArray::from_triplets(&all_rows, &all_cols, &all_data, shape)
560            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
561        "lil" => LilArray::from_triplets(&all_rows, &all_cols, &all_data, shape)
562            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
563        _ => Err(SparseError::ValueError(format!(
564            "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
565        ))),
566    }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eye_array() {
        let eye = eye_array::<f64>(3, "csr").unwrap();

        assert_eq!(eye.shape(), (3, 3));
        assert_eq!(eye.nnz(), 3);
        assert_eq!(eye.get(0, 0), 1.0);
        assert_eq!(eye.get(1, 1), 1.0);
        assert_eq!(eye.get(2, 2), 1.0);
        assert_eq!(eye.get(0, 1), 0.0);

        // Try COO format
        let eye_coo = eye_array::<f64>(3, "coo").unwrap();
        assert_eq!(eye_coo.shape(), (3, 3));
        assert_eq!(eye_coo.nnz(), 3);

        // Try DOK format
        let eye_dok = eye_array::<f64>(3, "dok").unwrap();
        assert_eq!(eye_dok.shape(), (3, 3));
        assert_eq!(eye_dok.nnz(), 3);
        assert_eq!(eye_dok.get(0, 0), 1.0);
        assert_eq!(eye_dok.get(1, 1), 1.0);
        assert_eq!(eye_dok.get(2, 2), 1.0);

        // Try LIL format
        let eye_lil = eye_array::<f64>(3, "lil").unwrap();
        assert_eq!(eye_lil.shape(), (3, 3));
        assert_eq!(eye_lil.nnz(), 3);
        assert_eq!(eye_lil.get(0, 0), 1.0);
        assert_eq!(eye_lil.get(1, 1), 1.0);
        assert_eq!(eye_lil.get(2, 2), 1.0);
    }

    #[test]
    fn test_eye_array_k() {
        // Identity with main diagonal (k=0)
        let eye = eye_array_k::<f64>(3, 3, 0, "csr").unwrap();
        assert_eq!(eye.get(0, 0), 1.0);
        assert_eq!(eye.get(1, 1), 1.0);
        assert_eq!(eye.get(2, 2), 1.0);

        // Superdiagonal (k=1)
        let superdiag = eye_array_k::<f64>(3, 4, 1, "csr").unwrap();
        assert_eq!(superdiag.get(0, 1), 1.0);
        assert_eq!(superdiag.get(1, 2), 1.0);
        assert_eq!(superdiag.get(2, 3), 1.0);

        // Subdiagonal (k=-1)
        let subdiag = eye_array_k::<f64>(4, 3, -1, "csr").unwrap();
        assert_eq!(subdiag.get(1, 0), 1.0);
        assert_eq!(subdiag.get(2, 1), 1.0);
        assert_eq!(subdiag.get(3, 2), 1.0);

        // Try LIL format
        let eye_lil = eye_array_k::<f64>(3, 3, 0, "lil").unwrap();
        assert_eq!(eye_lil.get(0, 0), 1.0);
        assert_eq!(eye_lil.get(1, 1), 1.0);
        assert_eq!(eye_lil.get(2, 2), 1.0);
    }

    #[test]
    fn test_diags_array() {
        let diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
            Array1::from_vec(vec![4.0, 5.0]),      // superdiagonal
        ];
        let offsets = vec![0, 1];
        let shape = (3, 3);

        let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
        assert_eq!(result.shape(), (3, 3));
        assert_eq!(result.get(0, 0), 1.0);
        assert_eq!(result.get(1, 1), 2.0);
        assert_eq!(result.get(2, 2), 3.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 2), 5.0);

        // Try with multiple diagonals and subdiagonals
        let diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
            Array1::from_vec(vec![4.0, 5.0]),      // superdiagonal
            Array1::from_vec(vec![6.0, 7.0]),      // subdiagonal
        ];
        let offsets = vec![0, 1, -1];
        let shape = (3, 3);

        let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
        assert_eq!(result.shape(), (3, 3));
        assert_eq!(result.get(0, 0), 1.0);
        assert_eq!(result.get(1, 1), 2.0);
        assert_eq!(result.get(2, 2), 3.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 2), 5.0);
        assert_eq!(result.get(1, 0), 6.0);
        assert_eq!(result.get(2, 1), 7.0);

        // Try LIL format
        let result_lil = diags_array(&diags, &offsets, shape, "lil").unwrap();
        assert_eq!(result_lil.shape(), (3, 3));
        assert_eq!(result_lil.get(0, 0), 1.0);
        assert_eq!(result_lil.get(1, 1), 2.0);
        assert_eq!(result_lil.get(2, 2), 3.0);
        assert_eq!(result_lil.get(0, 1), 4.0);
        assert_eq!(result_lil.get(1, 2), 5.0);
        assert_eq!(result_lil.get(1, 0), 6.0);
        assert_eq!(result_lil.get(2, 1), 7.0);
    }

    #[test]
    fn test_random_array() {
        let shape = (10, 10);
        let density = 0.3;

        let random = random_array::<f64>(shape, density, None, "csr").unwrap();

        // Check shape and sparsity
        assert_eq!(random.shape(), shape);
        let nnz = random.nnz();
        let expected_nnz = (shape.0 * shape.1) as f64 * density;

        // Allow for some random variation, but should be close to expected density
        assert!(
            (nnz as f64) > expected_nnz * 0.7,
            "Too few non-zeros: {nnz}"
        );
        assert!(
            (nnz as f64) < expected_nnz * 1.3,
            "Too many non-zeros: {nnz}"
        );

        // Test with custom RNG seed
        let random_seeded = random_array::<f64>(shape, density, Some(42), "csr").unwrap();
        assert_eq!(random_seeded.shape(), shape);

        // Test LIL format
        let random_lil = random_array::<f64>((5, 5), 0.5, Some(42), "lil").unwrap();
        assert_eq!(random_lil.shape(), (5, 5));
        let nnz_lil = random_lil.nnz();
        let expected_nnz_lil = 25.0 * 0.5;
        assert!(
            (nnz_lil as f64) > expected_nnz_lil * 0.7,
            "Too few non-zeros in LIL: {nnz_lil}"
        );
        assert!(
            (nnz_lil as f64) < expected_nnz_lil * 1.3,
            "Too many non-zeros in LIL: {nnz_lil}"
        );
    }

    #[test]
    fn test_format_name_is_case_insensitive_and_validated() {
        assert_eq!(eye_array::<f64>(2, "CSR").unwrap().nnz(), 2);
        assert_eq!(eye_array::<f64>(2, "Coo").unwrap().nnz(), 2);
        assert!(eye_array::<f64>(3, "bsr").is_err());
        assert!(eye_array_k::<f64>(3, 3, 0, "dense").is_err());
    }

    #[test]
    fn test_invalid_dimensions_and_density_are_rejected() {
        assert!(eye_array::<f64>(0, "csr").is_err());
        assert!(eye_array_k::<f64>(0, 3, 0, "csr").is_err());
        assert!(eye_array_k::<f64>(3, 0, 0, "coo").is_err());
        assert!(random_array::<f64>((3, 3), 1.5, Some(1), "csr").is_err());
        assert!(random_array::<f64>((3, 3), -0.1, Some(1), "csr").is_err());
        assert!(random_array::<f64>((0, 3), 0.5, Some(1), "csr").is_err());
        assert!(random_array_parallel::<f64>((3, 3), 2.0, Some(1), "csr", 1).is_err());
        assert!(random_array_parallel::<f64>((3, 0), 0.5, Some(1), "csr", 1).is_err());
    }

    #[test]
    fn test_diags_array_validation() {
        let single = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];

        // Offset count must match diagonal count.
        assert!(diags_array(&single, &[0, 1], (3, 3), "csr").is_err());
        // A length-3 diagonal does not fit at offset +1 or -1 in a 3x3 matrix.
        assert!(diags_array(&single, &[1], (3, 3), "csr").is_err());
        assert!(diags_array(&single, &[-1], (3, 3), "csr").is_err());
        // Zero-sized shapes are rejected.
        assert!(diags_array(&single, &[0], (0, 3), "csr").is_err());
    }

    #[test]
    fn test_diags_array_skips_explicit_zeros() {
        let diags = vec![Array1::from_vec(vec![1.0, 0.0, 3.0])];
        let result = diags_array(&diags, &[0], (3, 3), "csr").unwrap();
        assert_eq!(result.nnz(), 2);
        assert_eq!(result.get(0, 0), 1.0);
        assert_eq!(result.get(1, 1), 0.0);
        assert_eq!(result.get(2, 2), 3.0);
    }

    #[test]
    fn test_random_array_is_reproducible_and_exact() {
        let a = random_array::<f64>((6, 7), 0.3, Some(123), "csr").unwrap();
        let b = random_array::<f64>((6, 7), 0.3, Some(123), "csr").unwrap();
        assert_eq!(a.nnz(), b.nnz());
        for i in 0..6 {
            for j in 0..7 {
                assert_eq!(a.get(i, j), b.get(i, j));
                assert!(a.get(i, j).abs() <= 1.0);
            }
        }

        // The number of non-zeros is round(m * n * density), not approximate.
        assert_eq!(
            random_array::<f64>((10, 10), 0.3, Some(9), "coo").unwrap().nnz(),
            30
        );
    }

    #[test]
    fn test_random_array_parallel_nnz() {
        // Large enough for the parallel path: 40000 elements, 2000 expected nnz.
        let arr = random_array_parallel::<f64>((200, 200), 0.05, Some(7), "csr", 1000).unwrap();
        assert_eq!(arr.shape(), (200, 200));
        assert_eq!(arr.nnz(), 2000);

        // Below the threshold it falls back to the sequential constructor.
        let small =
            random_array_parallel::<f64>((10, 10), 0.3, Some(7), "csr", 1_000_000).unwrap();
        assert_eq!(small.shape(), (10, 10));
        assert_eq!(small.nnz(), 30);
    }
}