Skip to main content

sparse_ir_capi/
utils.rs

1//! Utility functions for C API
2//!
3//! This module provides helper functions for order conversion and dimension handling.
4
5#[allow(unused_imports)] // Used in test code
6use crate::{
7    SPIR_COMPUTATION_SUCCESS, SPIR_INTERNAL_ERROR, SPIR_ORDER_COLUMN_MAJOR, SPIR_ORDER_ROW_MAJOR,
8    SPIR_TWORK_FLOAT64, SPIR_TWORK_FLOAT64X2,
9};
10#[allow(unused_imports)]
11use mdarray::Shape;
12use sparse_ir::numeric::CustomNumeric; // Used in test code for with_dims
13
/// Check whether debug output is enabled via the `SPARSEIR_DEBUG` environment variable.
///
/// Returns `true` if `SPARSEIR_DEBUG` is set to any non-empty value. The value
/// itself is not interpreted, so `SPARSEIR_DEBUG=0` also enables debugging.
/// Returns `false` if the variable is unset or set to the empty string.
///
/// `var_os` is used instead of `var` so that a value which is not valid Unicode
/// still counts as "set" rather than being silently ignored.
pub fn is_debug_enabled() -> bool {
    std::env::var_os("SPARSEIR_DEBUG").is_some_and(|value| !value.is_empty())
}
20
/// Memory layout order of a multi-dimensional buffer exchanged through the C API.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryOrder {
    /// Row-major (C order): the rightmost dimension varies fastest (C, Python).
    RowMajor,
    /// Column-major (Fortran order): the leftmost dimension varies fastest
    /// (Fortran, Julia, MATLAB).
    ColumnMajor,
}
27
28impl MemoryOrder {
29    /// Convert from C int to MemoryOrder
30    pub fn from_c_int(order: libc::c_int) -> Result<Self, ()> {
31        match order {
32            SPIR_ORDER_ROW_MAJOR => Ok(Self::RowMajor),
33            SPIR_ORDER_COLUMN_MAJOR => Ok(Self::ColumnMajor),
34            _ => Err(()),
35        }
36    }
37}
38
39/// Convert dimensions and target_dim for row-major mdarray
40///
41/// mdarray uses row-major (C order) by default. When the C-API caller
42/// specifies column-major (Fortran/Julia order), we need to reverse
43/// dimensions and adjust target_dim to match mdarray's row-major layout.
44///
45/// This follows libsparseir's pattern for order handling.
46///
47/// # Arguments
48/// * `dims` - Original dimensions from C-API
49/// * `target_dim` - Original target dimension from C-API
50/// * `order` - Memory order specified by caller
51///
52/// # Returns
53/// (mdarray_dims, mdarray_target_dim) - Dimensions and target_dim for row-major mdarray
54///
55/// # Example
56/// ```text
57/// // Julia: dims=[5, 3], target_dim=0, order=COLUMN_MAJOR
58/// convert_dims_for_row_major(&[5, 3], 0, MemoryOrder::ColumnMajor)
59/// → ([3, 5], 1)  // For row-major mdarray
60/// ```
61pub fn convert_dims_for_row_major(
62    dims: &[usize],
63    target_dim: usize,
64    order: MemoryOrder,
65) -> (Vec<usize>, usize) {
66    match order {
67        MemoryOrder::RowMajor => {
68            // Already row-major, use as-is
69            (dims.to_vec(), target_dim)
70        }
71        MemoryOrder::ColumnMajor => {
72            // Convert column-major to row-major:
73            // Reverse dims and flip target_dim
74            let mut rev_dims = dims.to_vec();
75            rev_dims.reverse();
76            let rev_target_dim = dims.len() - 1 - target_dim;
77            (rev_dims, rev_target_dim)
78        }
79    }
80}
81
82/// Read N-dimensional tensor from raw pointer (row-major layout)
83///
84/// Reads a tensor from a raw pointer assuming row-major (C order) memory layout.
85/// The buffer is interpreted as a flat array and reshaped according to `dims`.
86///
87/// # Arguments
88/// * `ptr` - Raw pointer to the data buffer
89/// * `dims` - Dimensions of the tensor (e.g., `[num_points, basis_size]`)
90///
91/// # Returns
92/// A `Tensor<T, DynRank>` with the specified dimensions
93///
94/// # Safety
95/// Caller must ensure `ptr` is valid and points to at least `product(dims)` elements.
96pub(crate) unsafe fn _read_tensor_nd_row_major<T: Copy>(
97    ptr: *const T,
98    dims: &[usize],
99) -> sparse_ir::Tensor<T, sparse_ir::DynRank> {
100    assert!(!dims.is_empty(), "dims must not be empty");
101    let total: usize = dims.iter().product();
102
103    // Read buffer as slice
104    let slice = unsafe { std::slice::from_raw_parts(ptr, total) };
105    let data: Vec<T> = slice.to_vec();
106
107    // Create 1D tensor and reshape to specified dimensions
108    let flat = sparse_ir::Tensor::<T, (usize,)>::from(data);
109    flat.into_dyn().reshape(dims).to_tensor()
110}
111
112/// Read N-dimensional tensor from raw pointer (column-major layout)
113///
114/// Reads a tensor from a raw pointer assuming column-major (Fortran/Julia order) memory layout.
115/// The buffer is interpreted as a flat array with reversed dimensions, then permuted
116/// to restore the original axis order.
117///
118/// # Arguments
119/// * `ptr` - Raw pointer to the data buffer
120/// * `dims` - Dimensions of the tensor (e.g., `[num_points, basis_size]`)
121///
122/// # Returns
123/// A `Tensor<T, DynRank>` with the specified dimensions and correct axis order
124///
125/// # Safety
126/// Caller must ensure `ptr` is valid and points to at least `product(dims)` elements.
127pub(crate) unsafe fn _read_tensor_nd_column_major<T: Copy>(
128    ptr: *const T,
129    dims: &[usize],
130) -> sparse_ir::Tensor<T, sparse_ir::DynRank> {
131    assert!(!dims.is_empty(), "dims must not be empty");
132
133    // 1. Reverse dimensions to read as row-major
134    let mut rev_dims = dims.to_vec();
135    rev_dims.reverse();
136    let tmp = unsafe { _read_tensor_nd_row_major(ptr, &rev_dims) };
137
138    // 2. Permute axes to restore original order
139    // For example: if dims=[5, 3], we read as [3, 5] (row-major),
140    // then permute [0, 1] -> [1, 0] to get back [5, 3]
141    let rank = dims.len();
142    let perm: Vec<usize> = (0..rank).rev().collect();
143
144    // Tensor implements Borrow<Slice>, so &tmp can be used as &Slice
145    use mdarray::Slice;
146    (&tmp as &Slice<T, sparse_ir::DynRank>)
147        .permute(&perm[..])
148        .to_tensor()
149}
150
151/// Read N-dimensional tensor from raw pointer
152///
153/// Reads a tensor from a raw pointer, handling both row-major and column-major memory layouts.
154/// This is a convenience wrapper that dispatches to the appropriate internal function
155/// based on the memory order.
156///
157/// # Arguments
158/// * `ptr` - Raw pointer to the data buffer
159/// * `dims` - Dimensions of the tensor (e.g., `[num_points, basis_size]`)
160/// * `order` - Memory layout order (RowMajor or ColumnMajor)
161///
162/// # Returns
163/// A `Tensor<T, DynRank>` with the specified dimensions
164///
165/// # Safety
166/// Caller must ensure `ptr` is valid and points to at least `product(dims)` elements.
167pub(crate) unsafe fn read_tensor_nd<T: Copy>(
168    ptr: *const T,
169    dims: &[usize],
170    order: MemoryOrder,
171) -> sparse_ir::Tensor<T, sparse_ir::DynRank> {
172    match order {
173        MemoryOrder::RowMajor => unsafe { _read_tensor_nd_row_major(ptr, dims) },
174        MemoryOrder::ColumnMajor => unsafe { _read_tensor_nd_column_major(ptr, dims) },
175    }
176}
177
178/// Copy N-dimensional tensor to C array
179///
180/// Flattens the tensor and copies all elements to the output pointer.
181/// For column-major order, the tensor dimensions are permuted before flattening
182/// to match the expected memory layout.
183///
184/// # Arguments
185/// * `tensor` - Source tensor (any rank)
186/// * `out` - Destination C array pointer
187/// * `order` - Memory layout order for output (RowMajor or ColumnMajor)
188///
189/// # Safety
190/// Caller must ensure `out` has space for `tensor.len()` elements
191pub(crate) unsafe fn copy_tensor_to_c_array<T: Copy>(
192    tensor: sparse_ir::Tensor<T, sparse_ir::DynRank>,
193    out: *mut T,
194    order: MemoryOrder,
195) {
196    let total = tensor.len();
197
198    // For column-major, permute dimensions to reverse order before flattening
199    let flat = match order {
200        MemoryOrder::RowMajor => {
201            // Row-major: flatten directly
202            tensor.into_dyn().reshape(&[total]).to_tensor()
203        }
204        MemoryOrder::ColumnMajor => {
205            // Column-major: permute dimensions to reverse order, then flatten
206            // This is the inverse of read_tensor_nd_column_major
207            use mdarray::Slice;
208            let rank = tensor.rank();
209            let perm: Vec<usize> = (0..rank).rev().collect();
210            let permuted = (&tensor as &Slice<T, sparse_ir::DynRank>)
211                .permute(&perm[..])
212                .to_tensor();
213            permuted.into_dyn().reshape(&[total]).to_tensor()
214        }
215    };
216
217    for i in 0..total {
218        unsafe {
219            *out.add(i) = flat[i];
220        }
221    }
222}
223
/// Build the output shape for an operation applied along one axis.
///
/// Returns a copy of `input_dims` in which the extent at `target_dim` is
/// replaced by `new_size`. All other extents are unchanged.
///
/// # Panics
/// Panics if `target_dim >= input_dims.len()`.
pub(crate) fn build_output_dims(
    input_dims: &[usize],
    target_dim: usize,
    new_size: usize,
) -> Vec<usize> {
    let mut shape: Vec<usize> = input_dims.iter().copied().collect();
    shape[target_dim] = new_size;
    shape
}
234
235/// Create a DView (immutable) from raw pointer with DynRank dimensions
236///
237/// Zero-copy: directly interprets the buffer as a tensor with the given dimensions.
238/// For column-major data, pass reversed dimensions (via convert_dims_for_row_major).
239///
240/// # Safety
241/// - `ptr` must be valid and point to at least `product(dims)` elements
242/// - The memory must remain valid for the lifetime of the returned view
243pub(crate) unsafe fn create_dview_from_ptr<'a, T>(
244    ptr: *const T,
245    dims: &[usize],
246) -> mdarray::View<'a, T, sparse_ir::DynRank, mdarray::Dense> {
247    use mdarray::Shape;
248    let shape = sparse_ir::DynRank::from_dims(dims);
249    let mapping = mdarray::DenseMapping::new(shape);
250    unsafe { mdarray::View::new_unchecked(ptr, mapping) }
251}
252
253/// Create a DViewMut (mutable) from raw pointer with DynRank dimensions
254///
255/// Zero-copy: directly interprets the buffer as a mutable tensor with the given dimensions.
256/// For column-major data, pass reversed dimensions (via convert_dims_for_row_major).
257///
258/// # Safety
259/// - `ptr` must be valid and point to at least `product(dims)` elements
260/// - The memory must remain valid for the lifetime of the returned view
261/// - The caller must ensure no aliasing occurs
262pub(crate) unsafe fn create_dviewmut_from_ptr<'a, T>(
263    ptr: *mut T,
264    dims: &[usize],
265) -> mdarray::ViewMut<'a, T, sparse_ir::DynRank> {
266    use mdarray::Shape;
267    let shape = sparse_ir::DynRank::from_dims(dims);
268    let mapping = mdarray::DenseMapping::new(shape);
269    unsafe { mdarray::ViewMut::new_unchecked(ptr, mapping) }
270}
271
272/// Choose the working type (Twork) based on epsilon value
273///
274/// This function determines the appropriate working precision type based on the
275/// target accuracy epsilon. It follows the same logic as SPIR_TWORK_AUTO:
276/// - Returns SPIR_TWORK_FLOAT64X2 if epsilon < 1e-8 or epsilon is NaN
277/// - Returns SPIR_TWORK_FLOAT64 otherwise
278///
279/// # Arguments
280/// * `epsilon` - Target accuracy (must be non-negative, or NaN for auto-selection)
281///
282/// # Returns
283/// Working type constant:
284/// - SPIR_TWORK_FLOAT64 (0): Use double precision (64-bit)
285/// - SPIR_TWORK_FLOAT64X2 (1): Use extended precision (128-bit)
286#[unsafe(no_mangle)]
287pub extern "C" fn spir_choose_working_type(epsilon: f64) -> libc::c_int {
288    if epsilon.is_nan() || epsilon < 1e-8 {
289        SPIR_TWORK_FLOAT64X2
290    } else {
291        SPIR_TWORK_FLOAT64
292    }
293}
294
295/// Compute piecewise Gauss-Legendre quadrature rule (double precision)
296///
297/// Generates a piecewise Gauss-Legendre quadrature rule with n points per segment.
298/// The rule is concatenated across all segments, with points and weights properly
299/// scaled for each segment interval.
300///
301/// # Arguments
302/// * `n` - Number of Gauss points per segment (must be >= 1)
303/// * `segments` - Array of segment boundaries (n_segments + 1 elements).
304///                Must be monotonically increasing.
305/// * `n_segments` - Number of segments (must be >= 1)
306/// * `x` - Output array for Gauss points (size n * n_segments). Must be pre-allocated.
307/// * `w` - Output array for Gauss weights (size n * n_segments). Must be pre-allocated.
308/// * `status` - Pointer to store the status code
309///
310/// # Returns
311/// Status code:
312/// - SPIR_COMPUTATION_SUCCESS (0) on success
313/// - Non-zero error code on failure
314#[unsafe(no_mangle)]
315pub extern "C" fn spir_gauss_legendre_rule_piecewise_double(
316    n: libc::c_int,
317    segments: *const f64,
318    n_segments: libc::c_int,
319    x: *mut f64,
320    w: *mut f64,
321    status: *mut crate::StatusCode,
322) -> crate::StatusCode {
323    use crate::{SPIR_COMPUTATION_SUCCESS, SPIR_INTERNAL_ERROR, SPIR_INVALID_ARGUMENT};
324    use sparse_ir::legendre;
325    use std::panic::catch_unwind;
326
327    if status.is_null() {
328        return SPIR_INVALID_ARGUMENT;
329    }
330
331    if segments.is_null() || x.is_null() || w.is_null() {
332        unsafe {
333            *status = SPIR_INVALID_ARGUMENT;
334        }
335        return SPIR_INVALID_ARGUMENT;
336    }
337
338    if n < 1 || n_segments < 1 {
339        unsafe {
340            *status = SPIR_INVALID_ARGUMENT;
341        }
342        return SPIR_INVALID_ARGUMENT;
343    }
344
345    let result = catch_unwind(|| {
346        // Convert segments to Vec
347        let segments_slice =
348            unsafe { std::slice::from_raw_parts(segments, (n_segments + 1) as usize) };
349        let segs_vec = segments_slice.to_vec();
350
351        // Verify segments are monotonically increasing
352        for i in 1..segs_vec.len() {
353            if segs_vec[i] <= segs_vec[i - 1] {
354                unsafe {
355                    *status = SPIR_INVALID_ARGUMENT;
356                }
357                return SPIR_INVALID_ARGUMENT;
358            }
359        }
360
361        // Generate base rule with DDouble precision, then convert to double
362        let rule_dd = legendre::<sparse_ir::Df64>(n as usize);
363        let rule = sparse_ir::gauss::Rule::from_vectors(
364            rule_dd.x.iter().map(|&x| x.to_f64()).collect(),
365            rule_dd.w.iter().map(|&w| w.to_f64()).collect(),
366            rule_dd.a.to_f64(),
367            rule_dd.b.to_f64(),
368        );
369
370        // Create piecewise rule
371        let piecewise_rule = rule.piecewise(&segs_vec);
372
373        // Copy to output arrays
374        for i in 0..piecewise_rule.x.len() {
375            unsafe {
376                *x.add(i) = piecewise_rule.x[i];
377                *w.add(i) = piecewise_rule.w[i];
378            }
379        }
380
381        unsafe {
382            *status = SPIR_COMPUTATION_SUCCESS;
383        }
384        SPIR_COMPUTATION_SUCCESS
385    });
386
387    result.unwrap_or_else(|_| {
388        unsafe {
389            *status = SPIR_INTERNAL_ERROR;
390        }
391        SPIR_INTERNAL_ERROR
392    })
393}
394
395/// Compute piecewise Gauss-Legendre quadrature rule (DDouble precision)
396///
397/// Generates a piecewise Gauss-Legendre quadrature rule with n points per segment,
398/// computed using extended precision (DDouble). Returns high and low parts separately
399/// for maximum precision.
400///
401/// # Arguments
402/// * `n` - Number of Gauss points per segment (must be >= 1)
403/// * `segments` - Array of segment boundaries (n_segments + 1 elements).
404///                Must be monotonically increasing.
405/// * `n_segments` - Number of segments (must be >= 1)
406/// * `x_high` - Output array for high part of Gauss points (size n * n_segments).
407///              Must be pre-allocated.
408/// * `x_low` - Output array for low part of Gauss points (size n * n_segments).
409///             Must be pre-allocated.
410/// * `w_high` - Output array for high part of Gauss weights (size n * n_segments).
411///              Must be pre-allocated.
412/// * `w_low` - Output array for low part of Gauss weights (size n * n_segments).
413///            Must be pre-allocated.
414/// * `status` - Pointer to store the status code
415///
416/// # Returns
417/// Status code:
418/// - SPIR_COMPUTATION_SUCCESS (0) on success
419/// - Non-zero error code on failure
420#[unsafe(no_mangle)]
421pub extern "C" fn spir_gauss_legendre_rule_piecewise_ddouble(
422    n: libc::c_int,
423    segments: *const f64,
424    n_segments: libc::c_int,
425    x_high: *mut f64,
426    x_low: *mut f64,
427    w_high: *mut f64,
428    w_low: *mut f64,
429    status: *mut crate::StatusCode,
430) -> crate::StatusCode {
431    use crate::{SPIR_COMPUTATION_SUCCESS, SPIR_INTERNAL_ERROR, SPIR_INVALID_ARGUMENT};
432    use sparse_ir::legendre;
433    use std::panic::catch_unwind;
434
435    if status.is_null() {
436        return SPIR_INVALID_ARGUMENT;
437    }
438
439    if segments.is_null()
440        || x_high.is_null()
441        || x_low.is_null()
442        || w_high.is_null()
443        || w_low.is_null()
444    {
445        unsafe {
446            *status = SPIR_INVALID_ARGUMENT;
447        }
448        return SPIR_INVALID_ARGUMENT;
449    }
450
451    if n < 1 || n_segments < 1 {
452        unsafe {
453            *status = SPIR_INVALID_ARGUMENT;
454        }
455        return SPIR_INVALID_ARGUMENT;
456    }
457
458    let result = catch_unwind(|| {
459        // Convert segments to Vec
460        let segments_slice =
461            unsafe { std::slice::from_raw_parts(segments, (n_segments + 1) as usize) };
462        let segs_vec: Vec<sparse_ir::Df64> = segments_slice
463            .iter()
464            .map(|&x| sparse_ir::Df64::new(x))
465            .collect();
466
467        // Verify segments are monotonically increasing
468        for i in 1..segs_vec.len() {
469            if segs_vec[i] <= segs_vec[i - 1] {
470                unsafe {
471                    *status = SPIR_INVALID_ARGUMENT;
472                }
473                return SPIR_INVALID_ARGUMENT;
474            }
475        }
476
477        // Generate base rule with DDouble precision
478        let rule_dd = legendre::<sparse_ir::Df64>(n as usize);
479
480        // Create piecewise rule
481        let piecewise_rule = rule_dd.piecewise(&segs_vec);
482
483        // Extract high and low parts
484        for i in 0..piecewise_rule.x.len() {
485            unsafe {
486                *x_high.add(i) = piecewise_rule.x[i].hi();
487                *x_low.add(i) = piecewise_rule.x[i].lo();
488                *w_high.add(i) = piecewise_rule.w[i].hi();
489                *w_low.add(i) = piecewise_rule.w[i].lo();
490            }
491        }
492
493        unsafe {
494            *status = SPIR_COMPUTATION_SUCCESS;
495        }
496        SPIR_COMPUTATION_SUCCESS
497    });
498
499    result.unwrap_or_else(|_| {
500        unsafe {
501            *status = SPIR_INTERNAL_ERROR;
502        }
503        SPIR_INTERNAL_ERROR
504    })
505}
506
507#[cfg(test)]
508mod tests {
509    use super::*;
510
511    #[test]
512    fn test_memory_order_conversion() {
513        assert_eq!(
514            MemoryOrder::from_c_int(SPIR_ORDER_ROW_MAJOR),
515            Ok(MemoryOrder::RowMajor)
516        );
517        assert_eq!(
518            MemoryOrder::from_c_int(SPIR_ORDER_COLUMN_MAJOR),
519            Ok(MemoryOrder::ColumnMajor)
520        );
521        assert_eq!(MemoryOrder::from_c_int(99), Err(()));
522    }
523
524    #[test]
525    fn test_choose_working_type() {
526        // Test with epsilon >= 1e-8 -> should return FLOAT64
527        {
528            let twork = spir_choose_working_type(1e-6);
529            assert_eq!(twork, SPIR_TWORK_FLOAT64);
530        }
531
532        {
533            let twork = spir_choose_working_type(1e-8);
534            assert_eq!(twork, SPIR_TWORK_FLOAT64);
535        }
536
537        // Test with epsilon < 1e-8 -> should return FLOAT64X2
538        {
539            let twork = spir_choose_working_type(1e-10);
540            assert_eq!(twork, SPIR_TWORK_FLOAT64X2);
541        }
542
543        {
544            let twork = spir_choose_working_type(1e-15);
545            assert_eq!(twork, SPIR_TWORK_FLOAT64X2);
546        }
547
548        // Test with NaN -> should return FLOAT64X2
549        {
550            let twork = spir_choose_working_type(f64::NAN);
551            assert_eq!(twork, SPIR_TWORK_FLOAT64X2);
552        }
553
554        // Test boundary case: epsilon = 1e-8 exactly
555        {
556            let twork = spir_choose_working_type(1e-8);
557            assert_eq!(twork, SPIR_TWORK_FLOAT64);
558        }
559
560        // Test boundary case: epsilon just below 1e-8
561        {
562            let twork = spir_choose_working_type(0.99e-8);
563            assert_eq!(twork, SPIR_TWORK_FLOAT64X2);
564        }
565    }
566
567    #[test]
568    fn test_gauss_legendre_rule_piecewise_double() {
569        // Test with single segment [-1, 1]
570        {
571            let n = 5;
572            let segments = [-1.0, 1.0];
573            let n_segments = 1;
574            let mut x = vec![0.0; n as usize];
575            let mut w = vec![0.0; n as usize];
576            let mut status = SPIR_INTERNAL_ERROR;
577
578            let result = spir_gauss_legendre_rule_piecewise_double(
579                n,
580                segments.as_ptr(),
581                n_segments,
582                x.as_mut_ptr(),
583                w.as_mut_ptr(),
584                &mut status,
585            );
586            assert_eq!(result, SPIR_COMPUTATION_SUCCESS);
587            assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
588
589            // Verify we got n points
590            // Points should be in [-1, 1] and sorted
591            assert!(x[0] >= -1.0);
592            assert!(x[(n - 1) as usize] <= 1.0);
593            for i in 1..(n as usize) {
594                assert!(x[i] > x[i - 1]);
595            }
596
597            // Weights should be positive
598            for i in 0..(n as usize) {
599                assert!(w[i] > 0.0);
600            }
601
602            // Weight sum should be approximately 2.0 (integral over [-1, 1] is 2)
603            let weight_sum: f64 = w.iter().sum();
604            assert!((weight_sum - 2.0).abs() < 1e-10);
605        }
606
607        // Test with two segments [-1, 0, 1]
608        {
609            let n = 3;
610            let segments = [-1.0, 0.0, 1.0];
611            let n_segments = 2;
612            let mut x = vec![0.0; (n * n_segments) as usize];
613            let mut w = vec![0.0; (n * n_segments) as usize];
614            let mut status = SPIR_INTERNAL_ERROR;
615
616            let result = spir_gauss_legendre_rule_piecewise_double(
617                n,
618                segments.as_ptr(),
619                n_segments,
620                x.as_mut_ptr(),
621                w.as_mut_ptr(),
622                &mut status,
623            );
624            assert_eq!(result, SPIR_COMPUTATION_SUCCESS);
625            assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
626
627            // Verify we got n * n_segments points
628            // Points should be sorted across segments
629            assert!(x[0] >= -1.0);
630            assert!(x[5] <= 1.0);
631            for i in 1..6 {
632                assert!(x[i] > x[i - 1]);
633            }
634
635            // Weights should be positive
636            for i in 0..6 {
637                assert!(w[i] > 0.0);
638            }
639
640            // Weight sum should be approximately 2.0 (integral over [-1, 1])
641            let weight_sum: f64 = w.iter().sum();
642            assert!((weight_sum - 2.0).abs() < 1e-10);
643        }
644
645        // Test error handling
646        {
647            let mut status = SPIR_INTERNAL_ERROR;
648            let result = spir_gauss_legendre_rule_piecewise_double(
649                5,
650                std::ptr::null(),
651                1,
652                std::ptr::null_mut(),
653                std::ptr::null_mut(),
654                &mut status,
655            );
656            assert_ne!(result, SPIR_COMPUTATION_SUCCESS);
657        }
658
659        {
660            let segments = [-1.0, 1.0];
661            let mut x = vec![0.0; 5];
662            let mut w = vec![0.0; 5];
663            let mut status = SPIR_INTERNAL_ERROR;
664            let result = spir_gauss_legendre_rule_piecewise_double(
665                0,
666                segments.as_ptr(),
667                1,
668                x.as_mut_ptr(),
669                w.as_mut_ptr(),
670                &mut status,
671            );
672            assert_ne!(result, SPIR_COMPUTATION_SUCCESS);
673        }
674
675        {
676            let segments = [1.0, -1.0]; // Wrong order
677            let mut x = vec![0.0; 5];
678            let mut w = vec![0.0; 5];
679            let mut status = SPIR_INTERNAL_ERROR;
680            let result = spir_gauss_legendre_rule_piecewise_double(
681                5,
682                segments.as_ptr(),
683                1,
684                x.as_mut_ptr(),
685                w.as_mut_ptr(),
686                &mut status,
687            );
688            assert_ne!(result, SPIR_COMPUTATION_SUCCESS);
689        }
690    }
691
692    #[test]
693    fn test_read_tensor_nd_row_major() {
694        use num_complex::Complex64;
695
696        // Test 2D tensor: 3x4 matrix
697        {
698            // Create test data: row-major order
699            // [[1, 2, 3, 4],
700            //  [5, 6, 7, 8],
701            //  [9, 10, 11, 12]]
702            let data = vec![
703                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
704            ];
705            let tensor = unsafe { _read_tensor_nd_row_major(data.as_ptr(), &[3, 4]) };
706
707            let shape_dims = tensor.shape().with_dims(|dims| dims.to_vec());
708            assert_eq!(shape_dims, &[3, 4]);
709            assert_eq!(tensor[[0, 0]], 1.0);
710            assert_eq!(tensor[[0, 3]], 4.0);
711            assert_eq!(tensor[[1, 0]], 5.0);
712            assert_eq!(tensor[[2, 3]], 12.0);
713        }
714
715        // Test 3D tensor: 2x3x4
716        {
717            let data: Vec<f64> = (1..=24).map(|x| x as f64).collect();
718            let tensor = unsafe { _read_tensor_nd_row_major(data.as_ptr(), &[2, 3, 4]) };
719
720            let shape_dims = tensor.shape().with_dims(|dims| dims.to_vec());
721            assert_eq!(shape_dims, &[2, 3, 4]);
722            assert_eq!(tensor[[0, 0, 0]], 1.0);
723            assert_eq!(tensor[[0, 0, 3]], 4.0);
724            assert_eq!(tensor[[0, 1, 0]], 5.0);
725            assert_eq!(tensor[[1, 2, 3]], 24.0);
726        }
727
728        // Test complex numbers
729        {
730            let data = vec![
731                Complex64::new(1.0, 2.0),
732                Complex64::new(3.0, 4.0),
733                Complex64::new(5.0, 6.0),
734                Complex64::new(7.0, 8.0),
735            ];
736            let tensor = unsafe { _read_tensor_nd_row_major(data.as_ptr(), &[2, 2]) };
737
738            let shape_dims = tensor.shape().with_dims(|dims| dims.to_vec());
739            assert_eq!(shape_dims, &[2, 2]);
740            assert_eq!(tensor[[0, 0]], Complex64::new(1.0, 2.0));
741            assert_eq!(tensor[[1, 1]], Complex64::new(7.0, 8.0));
742        }
743    }
744
745    #[test]
746    fn test_read_tensor_nd_column_major() {
747        use num_complex::Complex64;
748
749        // Test 2D tensor: 3x4 matrix
750        // Column-major order means:
751        // [[1, 4, 7, 10],
752        //  [2, 5, 8, 11],
753        //  [3, 6, 9, 12]]
754        // But we want to read it as [3, 4] shape
755        {
756            // Create test data: column-major order
757            // First column: [1, 2, 3]
758            // Second column: [4, 5, 6]
759            // Third column: [7, 8, 9]
760            // Fourth column: [10, 11, 12]
761            let data = vec![
762                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
763            ];
764            let tensor = unsafe { _read_tensor_nd_column_major(data.as_ptr(), &[3, 4]) };
765
766            let shape_dims = tensor.shape().with_dims(|dims| dims.to_vec());
767            assert_eq!(shape_dims, &[3, 4]);
768            // After reading as [4, 3] (reversed) and permuting back, we should get:
769            // [[1, 4, 7, 10],
770            //  [2, 5, 8, 11],
771            //  [3, 6, 9, 12]]
772            assert_eq!(tensor[[0, 0]], 1.0);
773            assert_eq!(tensor[[0, 1]], 4.0);
774            assert_eq!(tensor[[0, 3]], 10.0);
775            assert_eq!(tensor[[1, 0]], 2.0);
776            assert_eq!(tensor[[2, 3]], 12.0);
777        }
778
779        // Test 3D tensor: 2x3x4
780        // Column-major: first all elements with index [0,0,0], [1,0,0], then [0,1,0], [1,1,0], etc.
781        {
782            // For 2x3x4, column-major order:
783            // [0,0,0]=1, [1,0,0]=2, [0,1,0]=3, [1,1,0]=4, [0,2,0]=5, [1,2,0]=6,
784            // [0,0,1]=7, [1,0,1]=8, ...
785            let data: Vec<f64> = (1..=24).map(|x| x as f64).collect();
786            let tensor = unsafe { _read_tensor_nd_column_major(data.as_ptr(), &[2, 3, 4]) };
787
788            let shape_dims = tensor.shape().with_dims(|dims| dims.to_vec());
789            assert_eq!(shape_dims, &[2, 3, 4]);
790            // Verify first few elements
791            assert_eq!(tensor[[0, 0, 0]], 1.0);
792            assert_eq!(tensor[[1, 0, 0]], 2.0);
793            assert_eq!(tensor[[0, 1, 0]], 3.0);
794        }
795
796        // Test complex numbers
797        {
798            // Column-major: [1+2i, 3+4i] in first column, [5+6i, 7+8i] in second column
799            let data = vec![
800                Complex64::new(1.0, 2.0),
801                Complex64::new(3.0, 4.0),
802                Complex64::new(5.0, 6.0),
803                Complex64::new(7.0, 8.0),
804            ];
805            let tensor = unsafe { _read_tensor_nd_column_major(data.as_ptr(), &[2, 2]) };
806
807            let shape_dims = tensor.shape().with_dims(|dims| dims.to_vec());
808            assert_eq!(shape_dims, &[2, 2]);
809            // After permute: [[1+2i, 5+6i], [3+4i, 7+8i]]
810            assert_eq!(tensor[[0, 0]], Complex64::new(1.0, 2.0));
811            assert_eq!(tensor[[1, 0]], Complex64::new(3.0, 4.0));
812            assert_eq!(tensor[[0, 1]], Complex64::new(5.0, 6.0));
813            assert_eq!(tensor[[1, 1]], Complex64::new(7.0, 8.0));
814        }
815    }
816
817    #[test]
818    fn test_read_tensor_nd_roundtrip() {
819        // Test that row-major and column-major produce consistent results
820        // when the data is transposed appropriately
821
822        // Create a 3x4 matrix in row-major
823        let row_major_data = vec![
824            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
825        ];
826        let row_tensor = unsafe { _read_tensor_nd_row_major(row_major_data.as_ptr(), &[3, 4]) };
827
828        // Create the same matrix in column-major (transposed storage)
829        // [[1, 4, 7, 10], [2, 5, 8, 11], [3, 6, 9, 12]] stored as:
830        // [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] (column-major)
831        let col_major_data = vec![
832            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
833        ];
834        let col_tensor = unsafe { _read_tensor_nd_column_major(col_major_data.as_ptr(), &[3, 4]) };
835
836        // They should have the same shape
837        let row_shape = row_tensor.shape().with_dims(|dims| dims.to_vec());
838        let col_shape = col_tensor.shape().with_dims(|dims| dims.to_vec());
839        assert_eq!(row_shape, col_shape);
840
841        // But different values (because storage order is different)
842        // row_tensor: [[1,2,3,4], [5,6,7,8], [9,10,11,12]]
843        // col_tensor: [[1,4,7,10], [2,5,8,11], [3,6,9,12]]
844        assert_eq!(row_tensor[[0, 0]], 1.0);
845        assert_eq!(col_tensor[[0, 0]], 1.0);
846        assert_eq!(row_tensor[[0, 1]], 2.0);
847        assert_eq!(col_tensor[[0, 1]], 4.0); // Different!
848    }
849
#[test]
fn test_gauss_legendre_rule_piecewise_ddouble() {
    /// Collapse a double-double array (high + low parts) into plain `f64` values.
    fn combine(high: &[f64], low: &[f64]) -> Vec<f64> {
        high.iter().zip(low).map(|(h, l)| h + l).collect()
    }

    /// Validate a piecewise quadrature rule with `n` nodes per segment.
    ///
    /// The checks are:
    /// - nodes are strictly increasing over the whole interval;
    /// - each segment's nodes lie inside that segment;
    /// - each segment's weights are positive and sum to the segment length
    ///   (the rule integrates f = 1 exactly);
    /// - the total weight equals the full interval length.
    fn check_piecewise_rule(segments: &[f64], n: usize, x: &[f64], w: &[f64]) {
        let n_segments = segments.len() - 1;
        assert_eq!(x.len(), n * n_segments);
        assert_eq!(w.len(), x.len());

        assert!(
            x.windows(2).all(|p| p[0] < p[1]),
            "nodes not strictly increasing: {:?}",
            x
        );

        for (k, (xs, ws)) in x.chunks_exact(n).zip(w.chunks_exact(n)).enumerate() {
            let (a, b) = (segments[k], segments[k + 1]);
            assert!(
                xs.iter().all(|xi| (a..=b).contains(xi)),
                "segment {} nodes outside [{}, {}]: {:?}",
                k,
                a,
                b,
                xs
            );
            assert!(
                ws.iter().all(|&wi| wi > 0.0),
                "segment {} has a non-positive weight: {:?}",
                k,
                ws
            );
            let segment_sum: f64 = ws.iter().sum();
            assert!(
                (segment_sum - (b - a)).abs() < 1e-10,
                "segment {} weight sum {} != length {}",
                k,
                segment_sum,
                b - a
            );
        }

        let total: f64 = w.iter().sum();
        let length = segments[n_segments] - segments[0];
        assert!(
            (total - length).abs() < 1e-10,
            "total weight {} != interval length {}",
            total,
            length
        );
    }

    // Single segment [-1, 1]
    {
        let n = 5;
        let segments = [-1.0, 1.0];
        let n_segments = 1;
        let total = (n * n_segments) as usize;
        let mut x_high = vec![0.0; total];
        let mut x_low = vec![0.0; total];
        let mut w_high = vec![0.0; total];
        let mut w_low = vec![0.0; total];
        let mut status = SPIR_INTERNAL_ERROR;

        let result = spir_gauss_legendre_rule_piecewise_ddouble(
            n,
            segments.as_ptr(),
            n_segments,
            x_high.as_mut_ptr(),
            x_low.as_mut_ptr(),
            w_high.as_mut_ptr(),
            w_low.as_mut_ptr(),
            &mut status,
        );
        assert_eq!(result, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);

        check_piecewise_rule(
            &segments,
            n as usize,
            &combine(&x_high, &x_low),
            &combine(&w_high, &w_low),
        );
    }

    // Two segments [-1, 0] and [0, 1]; output length is derived, not hard-coded.
    {
        let n = 3;
        let segments = [-1.0, 0.0, 1.0];
        let n_segments = 2;
        let total = (n * n_segments) as usize;
        let mut x_high = vec![0.0; total];
        let mut x_low = vec![0.0; total];
        let mut w_high = vec![0.0; total];
        let mut w_low = vec![0.0; total];
        let mut status = SPIR_INTERNAL_ERROR;

        let result = spir_gauss_legendre_rule_piecewise_ddouble(
            n,
            segments.as_ptr(),
            n_segments,
            x_high.as_mut_ptr(),
            x_low.as_mut_ptr(),
            w_high.as_mut_ptr(),
            w_low.as_mut_ptr(),
            &mut status,
        );
        assert_eq!(result, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);

        check_piecewise_rule(
            &segments,
            n as usize,
            &combine(&x_high, &x_low),
            &combine(&w_high, &w_low),
        );
    }

    // Null input/output pointers are rejected with a non-success result and status.
    {
        let mut status = SPIR_INTERNAL_ERROR;
        let result = spir_gauss_legendre_rule_piecewise_ddouble(
            5,
            std::ptr::null(),
            1,
            std::ptr::null_mut(),
            std::ptr::null_mut(),
            std::ptr::null_mut(),
            std::ptr::null_mut(),
            &mut status,
        );
        assert_ne!(result, SPIR_COMPUTATION_SUCCESS);
        assert_ne!(status, SPIR_COMPUTATION_SUCCESS);
    }
}
955}