Skip to main content

gam_gpu/
driver.rs

1//! Shared CUDA driver presence/loading helpers used by every cuBLAS / cuSPARSE
2//! / cuSOLVER routing module.
3//!
4//! The GPU path uses ONE context model: cudarc's device PRIMARY context
5//! (`cuDevicePrimaryCtxRetain`, bound in `device_runtime::cuda_context_for`).
6//! cuBLAS/cuSOLVER/cuSPARSE handles attach to that current context; there is no
7//! separate user `cuCtxCreate` context (its removal fixed the #1017
8//! NOT_INITIALIZED handle failures). This module keeps only the libcuda
9//! presence probes, byte-size/layout helpers, and the `check_cuda` status wrap.
10
11use libloading::Library;
12#[cfg(target_os = "linux")]
13use libloading::os::unix::{Library as UnixLibrary, RTLD_GLOBAL, RTLD_NOW};
14use ndarray::{Array2, ArrayBase, Data, Ix2};
15use std::borrow::Cow;
16use std::path::Path;
17#[cfg(target_os = "linux")]
18use std::path::PathBuf;
19use std::sync::OnceLock;
20
21use super::gpu_error::GpuError;
22
23pub type CuResult = i32;
24// NOTE (#1017): the `DriverApi` / `CudaWorkingState` / `DeviceAllocation` cluster
25// that lived here was REMOVED. It created a SEPARATE user CUDA context via
26// `cuCtxCreate` — distinct from cudarc's device PRIMARY context (cuDevicePrimaryCtxRetain)
27// that the live GPU path actually uses — which is the documented cause of the
28// cuBLAS/cuSOLVER NOT_INITIALIZED handle failures (handles bind to whichever
29// context is current). The cluster had ZERO consumers once the runtime routed
30// through `cuda_context_for` (the primary context) in `device_runtime.rs`, so it
31// was dead dual-context code. Keep ONE context model: the cudarc primary context.
32// Do not reintroduce `cuCtxCreate` for issuing work.
33
34#[inline]
35pub fn check_cuda(result: CuResult, name: &str) -> Result<(), GpuError> {
36    if result == 0 {
37        Ok(())
38    } else {
39        Err(GpuError::DriverCallFailed {
40            reason: format!("{name} failed with CUDA driver error {result}"),
41        })
42    }
43}
44
45/// Returns whether the platform loader can open a CUDA driver library.
46///
47/// This deliberately uses gam's own `libloading` probe rather than
48/// `cudarc::driver::sys::is_culib_present()`: cudarc 0.19's generated
49/// dynamic-loader helpers are exactly what emit the noisy
50/// `panic_no_lib_found` message when a CPU-only host lacks `libcuda`.
51/// Runtime availability checks need to stay completely outside cudarc until
52/// this function has established that the driver shared library exists.
53#[must_use]
54pub fn cuda_driver_library_present() -> bool {
55    load_library_names(&cuda_library_candidate_names()).is_ok()
56}
57
58fn load_library_names(candidates: &[String]) -> Result<Library, GpuError> {
59    for candidate in candidates {
60        // SAFETY: Library::new runs the library's loader initializer; we
61        // only pass CUDA driver candidates discovered from fixed NVIDIA
62        // driver directories or canonical libcuda sonames.
63        if let Ok(library) = unsafe { Library::new(candidate) } {
64            return Ok(library);
65        }
66    }
67    Err(GpuError::DriverLibraryUnavailable {
68        reason: format!("could not load any of: {}", candidates.join(", ")),
69    })
70}
71
72fn load_static_cuda_driver_library() -> Result<&'static Library, GpuError> {
73    static LIBRARY: OnceLock<Result<Library, GpuError>> = OnceLock::new();
74    LIBRARY
75        .get_or_init(|| load_library_names(&cuda_library_candidate_names()))
76        .as_ref()
77        .map_err(Clone::clone)
78}
79
80pub fn preload_cuda_driver() -> Result<(), String> {
81    static PRELOAD: OnceLock<Result<(), String>> = OnceLock::new();
82    PRELOAD
83        .get_or_init(|| {
84            load_static_cuda_driver_library()
85                .map(|_| ())
86                .map_err(|err| err.to_string())
87        })
88        .clone()
89}
90
91#[cfg(target_os = "linux")]
92fn preload_cuda_userspace_libraries() -> Result<(), String> {
93    static PRELOAD: OnceLock<Result<Vec<UnixLibrary>, String>> = OnceLock::new();
94    PRELOAD
95        .get_or_init(|| {
96            let paths = cuda_userspace_preload_paths();
97            if paths.is_empty() {
98                return Ok(Vec::new());
99            }
100            let mut loaded = Vec::new();
101            for path in paths {
102                // SAFETY: these candidates are CUDA userspace libraries found
103                // in canonical toolkit directories or pip's nvidia-*-cu12
104                // wheel layout. RTLD_GLOBAL is required so transitive deps
105                // such as libcusolver -> libnvJitLink resolve without an
106                // LD_LIBRARY_PATH mutation.
107                match unsafe { UnixLibrary::open(Some(&path), RTLD_NOW | RTLD_GLOBAL) } {
108                    Ok(library) => loaded.push(library),
109                    Err(err) => {
110                        return Err(format!(
111                            "could not preload CUDA userspace library {}: {err}",
112                            path.display()
113                        ));
114                    }
115                }
116            }
117            Ok(loaded)
118        })
119        .as_ref()
120        .map(|_| ())
121        .map_err(Clone::clone)
122}
123
124/// Returns whether the platform loader can open the named CUDA compute
125/// library (`cublas`, `cusolver`, `cusparse`).
126///
127/// cudarc 0.19 attempts to lazy-load these via its own generated
128/// `panic_no_lib_found` helpers the first time `CudaBlas::new` /
129/// `DnHandle::new` / cuSPARSE handle creation is invoked. On a host that
130/// has only the CUDA *driver* (e.g. large-scale workbench images expose
131/// `libcuda.so.1` but no cuBLAS at all), those calls panic out of the
132/// PyO3 FFI boundary instead of returning a typed error.
133///
134/// `GpuRuntime::probe()` calls this for every compute library it depends
135/// on; failure to load any of them downgrades the runtime to CPU with a
136/// `DriverLibraryUnavailable { reason: "lib<name> unavailable" }`, which
137/// keeps the panic completely off the call path.
138#[must_use]
139pub fn cuda_compute_library_present(stem: &str) -> bool {
140    #[cfg(target_os = "linux")]
141    {
142        if preload_cuda_userspace_libraries().is_err() {
143            return false;
144        }
145    }
146    // Cache the probe per stem and KEEP the loaded handle alive for the process
147    // lifetime. Dropping the `Library` here dlclose's it; that dlopen+dlclose
148    // cycle tears down the compute library's global init state, after which
149    // cudarc's own cublasCreate / cusolverDnCreate fail
150    // CUBLAS/CUSOLVER_STATUS_NOT_INITIALIZED on the next handle creation (the GPU
151    // then silently declines and falls back to CPU). Holding the handle keeps the
152    // library mapped and initialized so cudarc reuses it intact.
153    static PROBED: OnceLock<std::sync::Mutex<std::collections::HashMap<String, bool>>> =
154        OnceLock::new();
155    static KEEP_ALIVE: OnceLock<std::sync::Mutex<Vec<Library>>> = OnceLock::new();
156    let probed = PROBED.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()));
157    if let Ok(cache) = probed.lock() {
158        if let Some(&present) = cache.get(stem) {
159            return present;
160        }
161    }
162    let present = match load_library_names(&cuda_compute_library_candidate_names(stem)) {
163        Ok(library) => {
164            if let Ok(mut keep) = KEEP_ALIVE
165                .get_or_init(|| std::sync::Mutex::new(Vec::new()))
166                .lock()
167            {
168                keep.push(library);
169            }
170            true
171        }
172        Err(_) => false,
173    };
174    if let Ok(mut cache) = probed.lock() {
175        cache.insert(stem.to_string(), present);
176    }
177    present
178}
179
180#[cfg(target_os = "linux")]
181fn cuda_userspace_preload_paths() -> Vec<PathBuf> {
182    let system_dirs = cuda_system_library_dirs();
183    for dir in &system_dirs {
184        if let Some(stack) = complete_system_cuda_stack(dir) {
185            return dedup_paths(stack);
186        }
187        if let Some(stack) = system_cuda_stack_with_packaged_nvjitlink(dir) {
188            return dedup_paths(stack);
189        }
190    }
191    for root in nvidia_package_roots() {
192        if let Some(stack) = complete_nvidia_cuda_stack(&root) {
193            return dedup_paths(stack);
194        }
195    }
196    Vec::new()
197}
198
199fn cuda_compute_library_candidate_names(stem: &str) -> Vec<String> {
200    let base = format!("lib{stem}");
201    let mut out: Vec<String> = Vec::new();
202    // Bare soname forms — exercised by the platform loader against
203    // LD_LIBRARY_PATH and the default search dirs.
204    out.push(format!("{base}.so"));
205    out.push(format!("{base}.so.1"));
206    // Major-version walk mirroring cudarc's own candidate list so the
207    // preflight agrees with whatever cudarc would have tried next.
208    for major in (9..=13).rev() {
209        out.push(format!("{base}.so.{major}"));
210    }
211    #[cfg(target_os = "linux")]
212    {
213        for dir in cuda_system_library_dirs() {
214            out.push(format!("{dir}/{base}.so"));
215            for major in (9..=13).rev() {
216                out.push(format!("{dir}/{base}.so.{major}"));
217            }
218            append_versioned_linux_so_candidates(&mut out, Path::new(dir), &base);
219        }
220        for root in nvidia_package_roots() {
221            let lib_dir = root.join(nvidia_component_for_stem(stem)).join("lib");
222            out.push(format!("{}/{}.so", lib_dir.display(), base));
223            for major in (9..=13).rev() {
224                out.push(format!("{}/{}.so.{major}", lib_dir.display(), base));
225            }
226            append_versioned_linux_so_candidates(&mut out, &lib_dir, &base);
227        }
228    }
229    out
230}
231
/// Well-known Linux directories that may hold CUDA userspace libraries, in
/// probe-preference order: `/usr/local/cuda` trees first, then distro lib
/// dirs, WSL's driver dir, and finally `/opt/cuda/lib64`.
#[cfg(target_os = "linux")]
fn cuda_system_library_dirs() -> Vec<&'static str> {
    const DIRS: [&str; 7] = [
        "/usr/local/cuda/lib64",
        "/usr/local/cuda/lib",
        "/usr/local/cuda/targets/x86_64-linux/lib",
        "/usr/lib/x86_64-linux-gnu",
        "/usr/lib64",
        "/usr/lib/wsl/lib",
        "/opt/cuda/lib64",
    ];
    DIRS.to_vec()
}
244
245#[cfg(target_os = "linux")]
246fn complete_system_cuda_stack(dir: &str) -> Option<Vec<PathBuf>> {
247    let dir = Path::new(dir);
248    let stack = vec![
249        first_existing(dir, &["libcudart.so.13", "libcudart.so.12", "libcudart.so"])?,
250        first_existing(
251            dir,
252            &[
253                "libnvJitLink.so.13",
254                "libnvJitLink.so.12",
255                "libnvJitLink.so",
256            ],
257        )?,
258        first_existing(
259            dir,
260            &["libcublasLt.so.13", "libcublasLt.so.12", "libcublasLt.so"],
261        )?,
262        first_existing(dir, &["libcublas.so.13", "libcublas.so.12", "libcublas.so"])?,
263        first_existing(
264            dir,
265            &["libcusparse.so.13", "libcusparse.so.12", "libcusparse.so"],
266        )?,
267        first_existing(
268            dir,
269            &[
270                "libcusolver.so.13",
271                "libcusolver.so.12",
272                "libcusolver.so.11",
273                "libcusolver.so",
274            ],
275        )?,
276    ];
277    Some(stack)
278}
279
280#[cfg(target_os = "linux")]
281fn system_cuda_stack_with_packaged_nvjitlink(dir: &str) -> Option<Vec<PathBuf>> {
282    let dir = Path::new(dir);
283    let nvjitlink = packaged_nvjitlink_library()?;
284    let stack = vec![
285        first_existing(dir, &["libcudart.so.13", "libcudart.so.12", "libcudart.so"])?,
286        nvjitlink,
287        first_existing(
288            dir,
289            &["libcublasLt.so.13", "libcublasLt.so.12", "libcublasLt.so"],
290        )?,
291        first_existing(dir, &["libcublas.so.13", "libcublas.so.12", "libcublas.so"])?,
292        first_existing(
293            dir,
294            &["libcusparse.so.13", "libcusparse.so.12", "libcusparse.so"],
295        )?,
296        first_existing(
297            dir,
298            &[
299                "libcusolver.so.13",
300                "libcusolver.so.12",
301                "libcusolver.so.11",
302                "libcusolver.so",
303            ],
304        )?,
305    ];
306    Some(stack)
307}
308
309#[cfg(target_os = "linux")]
310fn complete_nvidia_cuda_stack(root: &Path) -> Option<Vec<PathBuf>> {
311    let stack = vec![
312        first_existing(
313            &root.join("cuda_runtime").join("lib"),
314            &["libcudart.so.13", "libcudart.so.12", "libcudart.so"],
315        )?,
316        first_existing(
317            &root.join("nvjitlink").join("lib"),
318            &[
319                "libnvJitLink.so.13",
320                "libnvJitLink.so.12",
321                "libnvJitLink.so",
322            ],
323        )?,
324        first_existing(
325            &root.join("cublas").join("lib"),
326            &["libcublasLt.so.13", "libcublasLt.so.12", "libcublasLt.so"],
327        )?,
328        first_existing(
329            &root.join("cublas").join("lib"),
330            &["libcublas.so.13", "libcublas.so.12", "libcublas.so"],
331        )?,
332        first_existing(
333            &root.join("cusparse").join("lib"),
334            &["libcusparse.so.13", "libcusparse.so.12", "libcusparse.so"],
335        )?,
336        first_existing(
337            &root.join("cusolver").join("lib"),
338            &[
339                "libcusolver.so.13",
340                "libcusolver.so.12",
341                "libcusolver.so.11",
342                "libcusolver.so",
343            ],
344        )?,
345    ];
346    Some(stack)
347}
348
349#[cfg(target_os = "linux")]
350fn packaged_nvjitlink_library() -> Option<PathBuf> {
351    for root in nvidia_package_roots() {
352        let lib_dir = root.join("nvjitlink").join("lib");
353        if let Some(path) = first_existing(
354            &lib_dir,
355            &[
356                "libnvJitLink.so.13",
357                "libnvJitLink.so.12",
358                "libnvJitLink.so",
359            ],
360        ) {
361            return Some(path);
362        }
363    }
364    None
365}
366
/// Map a library stem to the pip `nvidia/<component>` directory that ships it.
///
/// `nvJitLink` / `nvjitlink` map to `nvjitlink`, and `cudart` / `cuda_runtime`
/// map to `cuda_runtime`. Every other stem (including `cublas`, `cusolver` and
/// `cusparse`) is its own component name.
#[cfg(target_os = "linux")]
fn nvidia_component_for_stem(stem: &str) -> String {
    let component = match stem {
        "nvJitLink" | "nvjitlink" => "nvjitlink",
        "cudart" | "cuda_runtime" => "cuda_runtime",
        other => other,
    };
    component.to_owned()
}
378
379#[cfg(target_os = "linux")]
380fn nvidia_package_roots() -> Vec<PathBuf> {
381    let mut roots = Vec::new();
382    if let Some(home) = current_user_home_dir() {
383        collect_python_nvidia_roots(home.join(".local/lib"), &mut roots);
384    }
385    collect_python_nvidia_roots(Path::new("/usr/local/lib").to_path_buf(), &mut roots);
386    collect_python_nvidia_roots(Path::new("/usr/lib").to_path_buf(), &mut roots);
387    dedup_paths(roots)
388}
389
/// Home directory of the process's user, looked up in `/etc/passwd`.
///
/// The UID is the first field of the `Uid:` line in `/proc/self/status` (the
/// real UID per proc(5)). The matching `/etc/passwd` entry's sixth field
/// (`name:passwd:uid:gid:gecos:home:shell`) is the home directory.
///
/// Lines without that shape (blank lines, comments, NIS `+`/`-` compat
/// entries) are skipped instead of aborting the whole lookup. A matching
/// entry whose home field is empty or relative is also skipped, because the
/// caller would otherwise scan a CWD-relative directory for libraries to
/// `dlopen`.
///
/// Returns `None` when either file is unreadable or no usable entry matches.
#[cfg(target_os = "linux")]
fn current_user_home_dir() -> Option<PathBuf> {
    let status = std::fs::read_to_string("/proc/self/status").ok()?;
    let uid = status
        .lines()
        .find_map(|line| line.strip_prefix("Uid:"))?
        .split_whitespace()
        .next()?;
    let passwd = std::fs::read_to_string("/etc/passwd").ok()?;
    passwd.lines().find_map(|line| {
        // `?` inside this closure only rejects the current line; `find_map`
        // then moves on to the next one.
        let mut fields = line.split(':');
        let _name = fields.next()?;
        let _password = fields.next()?;
        if fields.next()? != uid {
            return None;
        }
        let _gid = fields.next()?;
        let _gecos = fields.next()?;
        let home = Path::new(fields.next()?);
        home.is_absolute().then(|| home.to_path_buf())
    })
}
412
/// Append every existing `<base>/python*/{site-packages,dist-packages}/nvidia`
/// directory to `out`.
///
/// `read_dir` yields entries in a filesystem-dependent order, so the matching
/// `python*` directories are sorted first (byte-wise path order, so
/// `python3.11` precedes `python3.9`). This makes the preferred interpreter
/// deterministic when several carry `nvidia-*` wheels, matching the sort
/// already done in `append_versioned_linux_so_candidates`. Within one
/// interpreter, `site-packages` precedes `dist-packages`. A missing or
/// unreadable `base` contributes nothing.
#[cfg(target_os = "linux")]
fn collect_python_nvidia_roots(base: PathBuf, out: &mut Vec<PathBuf>) {
    let Ok(entries) = std::fs::read_dir(base) else {
        return;
    };
    let mut python_dirs: Vec<PathBuf> = entries
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| {
            path.file_name()
                .and_then(|name| name.to_str())
                .is_some_and(|name| name.starts_with("python"))
        })
        .collect();
    python_dirs.sort();
    for python_dir in python_dirs {
        for site_dir in ["site-packages", "dist-packages"] {
            let root = python_dir.join(site_dir).join("nvidia");
            if root.exists() {
                out.push(root);
            }
        }
    }
}
434
/// Return `dir/<name>` for the first entry of `names` that exists on disk, or
/// `None` if none of them exist. `Path::exists` follows symlinks, so a
/// dangling symlink counts as absent.
#[cfg(target_os = "linux")]
fn first_existing(dir: &Path, names: &[&str]) -> Option<PathBuf> {
    names
        .iter()
        .map(|name| dir.join(name))
        .find(|candidate| candidate.exists())
}
445
/// Canonicalize each path and drop later duplicates, preserving first-seen
/// order. A path is kept unchanged when canonicalization fails (for example,
/// when it does not exist).
#[cfg(target_os = "linux")]
fn dedup_paths(paths: Vec<PathBuf>) -> Vec<PathBuf> {
    let mut unique: Vec<PathBuf> = Vec::with_capacity(paths.len());
    let canonicalized = paths
        .into_iter()
        .map(|path| path.canonicalize().unwrap_or(path));
    for canonical in canonicalized {
        if !unique.contains(&canonical) {
            unique.push(canonical);
        }
    }
    unique
}
457
/// Append every `dir/<base>.so.*` entry to `out`, sorted by path.
///
/// Any path string already present in `out` is skipped. A missing or
/// unreadable `dir` adds nothing.
#[cfg(target_os = "linux")]
fn append_versioned_linux_so_candidates(out: &mut Vec<String>, dir: &Path, base: &str) {
    let prefix = format!("{base}.so.");
    let mut versioned: Vec<_> = match std::fs::read_dir(dir) {
        Ok(entries) => entries
            .flatten()
            .map(|entry| entry.path())
            .filter(|path| {
                path.file_name()
                    .and_then(|name| name.to_str())
                    .is_some_and(|name| name.starts_with(&prefix))
            })
            .collect(),
        Err(_) => return,
    };
    versioned.sort();
    for candidate in versioned
        .iter()
        .map(|path| path.to_string_lossy().into_owned())
    {
        if !out.contains(&candidate) {
            out.push(candidate);
        }
    }
}
482
483fn cuda_library_candidate_names() -> Vec<String> {
484    let mut out: Vec<String> = cuda_library_candidates()
485        .iter()
486        .map(|candidate| (*candidate).to_string())
487        .collect();
488    if cfg!(target_os = "linux") {
489        for dir in [
490            "/usr/local/nvidia/lib64",
491            "/usr/local/nvidia/lib",
492            "/usr/local/cuda/compat",
493            "/usr/lib/x86_64-linux-gnu",
494            "/usr/lib64",
495            "/usr/lib/wsl/lib",
496        ] {
497            append_versioned_linux_libcuda_candidates(&mut out, Path::new(dir));
498        }
499    }
500    out
501}
502
/// Append every `dir/libcuda.so.*` entry except `libcuda.so.1` to `out`,
/// sorted by path.
///
/// `libcuda.so.1` is excluded because the fixed candidate list already names
/// it. Path strings already present in `out` are skipped. A missing or
/// unreadable `dir` adds nothing.
fn append_versioned_linux_libcuda_candidates(out: &mut Vec<String>, dir: &Path) {
    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };
    let is_extra_libcuda =
        |name: &str| name.starts_with("libcuda.so.") && name != "libcuda.so.1";
    let mut versioned: Vec<_> = entries
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| {
            path.file_name()
                .and_then(|name| name.to_str())
                .is_some_and(|name| is_extra_libcuda(name))
        })
        .collect();
    versioned.sort();
    for path in &versioned {
        let candidate = path.to_string_lossy().into_owned();
        if !out.contains(&candidate) {
            out.push(candidate);
        }
    }
}
525
/// Fixed CUDA driver library candidates for the compilation target, in try order.
///
/// On Windows this is `nvcuda.dll`. On macOS it is the `/usr/local/cuda`
/// dylib, then the bare dylib name. On every other target it is the absolute
/// `libcuda.so.1` / `libcuda.so` paths in the NVIDIA container, CUDA compat,
/// distro and WSL dirs, followed by the bare sonames for the loader's own
/// search.
pub fn cuda_library_candidates() -> &'static [&'static str] {
    const WINDOWS: &[&str] = &["nvcuda.dll"];
    const MACOS: &[&str] = &["/usr/local/cuda/lib/libcuda.dylib", "libcuda.dylib"];
    const UNIX: &[&str] = &[
        "/usr/local/nvidia/lib64/libcuda.so.1",
        "/usr/local/nvidia/lib64/libcuda.so",
        "/usr/local/nvidia/lib/libcuda.so.1",
        "/usr/local/nvidia/lib/libcuda.so",
        "/usr/local/cuda/compat/libcuda.so.1",
        "/usr/local/cuda/compat/libcuda.so",
        "/usr/lib/x86_64-linux-gnu/libcuda.so.1",
        "/usr/lib/x86_64-linux-gnu/libcuda.so",
        "/usr/lib64/libcuda.so.1",
        "/usr/lib64/libcuda.so",
        "/usr/lib/wsl/lib/libcuda.so.1",
        "/usr/lib/wsl/lib/libcuda.so",
        "libcuda.so.1",
        "libcuda.so",
    ];
    if cfg!(target_os = "windows") {
        WINDOWS
    } else if cfg!(target_os = "macos") {
        MACOS
    } else {
        UNIX
    }
}
550
/// Narrow a `usize` (dimension, leading dimension, count) to the `i32` that
/// cuBLAS / cuSOLVER entry points take. Returns `None` above `i32::MAX`,
/// where an `as` cast would silently wrap.
#[inline]
pub fn to_i32(value: usize) -> Option<i32> {
    match i32::try_from(value) {
        Ok(narrowed) => Some(narrowed),
        Err(_) => None,
    }
}
555
556/// Repack a 2D `ndarray::ArrayBase` (row-major) into the column-major
557/// layout expected by every cuBLAS / cuSOLVER entry point.
558///
559/// Walks each column once via ndarray's iter (no per-element bounds checks)
560/// and extends into a pre-sized `Vec`. On large-scale inputs (n≈3×10⁵,
561/// p≈35) this replaces a per-element `a[[row, col]]` indexing loop that
562/// dominated the host side of every GPU dispatch.
563///
564/// Fast path: if the input is already F-order (column-major, contiguous in
565/// memory-order), borrow its raw buffer directly — no allocation, no copy.
566/// Standard row-major ndarrays still go through the permutation path.
567pub fn to_col_major<'a, S: Data<Elem = f64>>(a: &'a ArrayBase<S, Ix2>) -> Cow<'a, [f64]> {
568    let (rows, cols) = a.dim();
569    let strides = a.strides();
570    // F-order contiguous: column stride == 1, row stride == rows.
571    // `as_slice_memory_order` confirms the buffer is contiguous in memory.
572    if rows > 0
573        && cols > 0
574        && strides[0] == 1
575        && strides[1] == rows as isize
576        && let Some(slice) = a.as_slice_memory_order()
577    {
578        return Cow::Borrowed(slice);
579    }
580    let mut out: Vec<f64> = Vec::with_capacity(rows.saturating_mul(cols));
581    for col in 0..cols {
582        out.extend(a.column(col).iter().copied());
583    }
584    Cow::Owned(out)
585}
586
587/// Borrow (or pack) a 2D array's buffer in ROW-major (C) order.
588///
589/// The col-major dual of [`to_col_major`]: when the input is already
590/// C-contiguous its raw buffer IS the row-major flat layout, so this borrows
591/// it with no allocation or copy. Non-contiguous / F-order inputs are packed
592/// row by row.
593///
594/// This is the host-transpose-free upload path. A row-major `(r × c)` buffer,
595/// reinterpreted as a column-major buffer, is exactly the transpose `(c × r)`
596/// of the logical matrix — which is what the swapped-operand cuBLAS GEMM
597/// (`Cᵀ = Bᵀ·Aᵀ`) consumes, letting both the design upload and the result
598/// download skip the O(r·c) scalar permutation that dominated tall-skinny
599/// GEMMs on the host.
600pub fn to_row_major<'a, S: Data<Elem = f64>>(a: &'a ArrayBase<S, Ix2>) -> Cow<'a, [f64]> {
601    let (rows, cols) = a.dim();
602    let strides = a.strides();
603    // C-order contiguous: row stride == cols, column stride == 1.
604    if rows > 0
605        && cols > 0
606        && strides[1] == 1
607        && strides[0] == cols as isize
608        && let Some(slice) = a.as_slice_memory_order()
609    {
610        return Cow::Borrowed(slice);
611    }
612    let mut out: Vec<f64> = Vec::with_capacity(rows.saturating_mul(cols));
613    for row in 0..rows {
614        out.extend(a.row(row).iter().copied());
615    }
616    Cow::Owned(out)
617}
618
619/// Wrap a row-major flat buffer of shape `(rows, cols)` as an `Array2<f64>`
620/// without permutation. The buffer is consumed (no copy when its length
621/// matches). Returns `None` on a length mismatch.
622pub fn array_from_row_major(values: Vec<f64>, rows: usize, cols: usize) -> Option<Array2<f64>> {
623    if values.len() != rows.checked_mul(cols)? {
624        return None;
625    }
626    Array2::from_shape_vec((rows, cols), values).ok()
627}
628
629/// Convert a column-major flat buffer back into row-major `Array2<f64>`.
630pub fn from_col_major_inplace(values: &[f64], out: &mut Array2<f64>) -> Option<()> {
631    let (rows, cols) = out.dim();
632    if values.len() != rows.checked_mul(cols)? {
633        return None;
634    }
635    for col in 0..cols {
636        let src = ndarray::ArrayView1::from(&values[col * rows..(col + 1) * rows]);
637        out.column_mut(col).assign(&src);
638    }
639    Some(())
640}
641
642pub fn from_col_major(values: &[f64], rows: usize, cols: usize) -> Option<Array2<f64>> {
643    let mut out = Array2::<f64>::zeros((rows, cols));
644    from_col_major_inplace(values, &mut out)?;
645    Some(out)
646}
647
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn to_i32_fits_small_value() {
        assert_eq!(to_i32(0), Some(0));
        assert_eq!(to_i32(42), Some(42));
        assert_eq!(to_i32(i32::MAX as usize), Some(i32::MAX));
    }

    #[test]
    fn to_i32_overflows_returns_none() {
        assert_eq!(to_i32(i32::MAX as usize + 1), None);
    }

    #[test]
    fn to_col_major_2x3_row_major() {
        // Row-major [[1,2,3],[4,5,6]] → col-major [1,4,2,5,3,6]
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let col = to_col_major(&a);
        assert_eq!(&*col, &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn to_col_major_identity_roundtrip() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let col = to_col_major(&a);
        assert_eq!(&*col, &[1.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn to_col_major_borrows_f_order_input() {
        // `reversed_axes` of a C-order array is F-order: logical [[1,3],[2,4]]
        // stored in memory as [1,2,3,4], which is already column-major.
        let f = array![[1.0_f64, 2.0], [3.0, 4.0]].reversed_axes();
        let col = to_col_major(&f);
        assert!(matches!(col, std::borrow::Cow::Borrowed(_)));
        assert_eq!(&*col, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn to_row_major_borrows_c_order_input() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let row = to_row_major(&a);
        assert!(matches!(row, std::borrow::Cow::Borrowed(_)));
        assert_eq!(&*row, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn to_row_major_packs_f_order_input() {
        // Logical [[1,3],[2,4]] stored F-order; row-major packing is [1,3,2,4].
        let f = array![[1.0_f64, 2.0], [3.0, 4.0]].reversed_axes();
        let row = to_row_major(&f);
        assert!(matches!(row, std::borrow::Cow::Owned(_)));
        assert_eq!(&*row, &[1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn array_from_row_major_roundtrip_and_mismatch() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let flat = to_row_major(&a).into_owned();
        assert_eq!(array_from_row_major(flat, 2, 3), Some(a));
        assert!(array_from_row_major(vec![1.0; 5], 2, 3).is_none());
        assert!(array_from_row_major(Vec::new(), usize::MAX, 2).is_none());
    }

    #[test]
    fn from_col_major_2x3_roundtrip() {
        let original = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let col = to_col_major(&original);
        let recovered = from_col_major(&col, 2, 3).expect("should succeed");
        assert_eq!(recovered, original);
    }

    #[test]
    fn from_col_major_wrong_length_returns_none() {
        // 2x3 = 6 elements, but only 5 provided
        assert!(from_col_major(&[1.0, 2.0, 3.0, 4.0, 5.0], 2, 3).is_none());
    }

    #[test]
    fn from_col_major_handles_empty_shapes() {
        assert_eq!(from_col_major(&[], 0, 3).map(|m| m.dim()), Some((0, 3)));
        assert_eq!(from_col_major(&[], 3, 0).map(|m| m.dim()), Some((3, 0)));
    }

    #[test]
    fn from_col_major_inplace_mismatched_buffer_returns_none() {
        let mut out = Array2::<f64>::zeros((3, 3));
        let short = vec![1.0_f64; 8]; // 9 expected, 8 given
        assert!(from_col_major_inplace(&short, &mut out).is_none());
    }

    #[test]
    fn from_col_major_single_element() {
        let result = from_col_major(&[7.0], 1, 1).expect("should succeed");
        assert_eq!(result[[0, 0]], 7.0);
    }
}