pardiso_wrapper/mkl/interface.rs

1use super::loader::*;
2use crate::{MKLPardisoError, PardisoData, PardisoError, PardisoInterface};
3use std::ffi::c_void;
4
/// MKL threading-domain identifier for PARDISO.
///
/// Mirrors `#define MKL_DOMAIN_PARDISO  4` from `mkl_types.h`; passed to the
/// `mkl_domain_*` threading functions to target the PARDISO domain only.
pub(crate) const MKL_DOMAIN_PARDISO: i32 = 4;
7
8pub struct MKLPardisoSolver {
9    _data: PardisoData,
10}
11
12impl PardisoInterface for MKLPardisoSolver {
13    fn data(&self) -> &PardisoData {
14        &self._data
15    }
16    fn data_mut(&mut self) -> &mut PardisoData {
17        &mut self._data
18    }
19
20    fn new() -> Result<Self, PardisoError> {
21        if !MKLPardisoSolver::is_loaded() {
22            return Err(MKLPardisoError::LibraryLoadFailure)?;
23        }
24        let data = PardisoData::default();
25        Ok(Self { _data: data })
26    }
27
28    fn pardisoinit(&mut self) -> Result<(), PardisoError> {
29        let ptrs = mkl_ptrs()?;
30
31        let pt = self.data_mut().pt.as_mut_ptr() as *mut c_void;
32        let mtype = self.get_matrix_type() as i32;
33        let iparm = self.data_mut().iparm.as_mut_ptr();
34
35        (ptrs.pardisoinit)(pt, &mtype, iparm);
36
37        Ok(())
38    }
39
40    fn pardiso(
41        &mut self,
42        a: &[f64],
43        ia: &[i32],
44        ja: &[i32],
45        b: &mut [f64],
46        x: &mut [f64],
47        n: i32,
48        nrhs: i32,
49    ) -> Result<(), PardisoError> {
50        let ptrs = mkl_ptrs()?;
51
52        let mut error = 0;
53        let pt = self.data_mut().pt.as_mut_ptr() as *mut c_void;
54        let maxfct = self.data().maxfct;
55        let mnum = self.data().mnum;
56        let mtype = self.get_matrix_type() as i32;
57        let phase = self.data().phase as i32;
58        let a = a.as_ptr();
59        let ia = ia.as_ptr();
60        let ja = ja.as_ptr();
61        let b = b.as_mut_ptr();
62        let x = x.as_mut_ptr();
63        let perm = self.data_mut().perm.as_mut_ptr();
64        let iparm = self.data_mut().iparm.as_mut_ptr();
65        let msglvl = self.data().msglvl as i32;
66
67        (ptrs.pardiso)(
68            pt, &maxfct, &mnum, &mtype, &phase, &n, a, ia, ja, perm, &nrhs, iparm, &msglvl, b, x,
69            &mut error,
70        );
71
72        if error != 0 {
73            let error = MKLPardisoError::from(error);
74            return Err(PardisoError::from(error));
75        }
76        Ok(())
77    }
78
79    fn name(&self) -> &'static str {
80        "mkl"
81    }
82
83    fn is_licensed() -> bool {
84        true //MKL doesn't do license checks
85    }
86
87    fn is_loaded() -> bool {
88        mkl_ptrs().is_ok()
89    }
90
91    fn get_num_threads(&self) -> Result<i32, PardisoError> {
92        Ok(MKLPardisoSolver::mkl_get_max_threads()?)
93    }
94}
95
96// additional MKL specific functions
97impl MKLPardisoSolver {
98    pub fn set_num_threads(&mut self, num_threads: i32) -> Result<i32, PardisoError> {
99        Ok(MKLPardisoSolver::mkl_set_num_threads_local(num_threads)?)
100    }
101    pub fn mkl_set_num_threads(num_threads: i32) -> Result<i32, MKLPardisoError> {
102        Ok((mkl_ptrs()?.mkl_set_num_threads)(&num_threads))
103    }
104    // sets threads for the current execution thread
105    // overrides global settings, so we use this for the
106    // the default set_num_threads above.  It should
107    // be reported correctly by mkl_get_max_threads
108    pub fn mkl_set_num_threads_local(num_threads: i32) -> Result<i32, MKLPardisoError> {
109        Ok((mkl_ptrs()?.mkl_set_num_threads_local)(&num_threads))
110    }
111    // sets the number of threads in MKL_DOMAIN_PARDISO only
112    pub fn mkl_set_num_threads_pardiso(num_threads: i32) -> Result<i32, MKLPardisoError> {
113        Ok((mkl_ptrs()?.mkl_domain_set_num_threads)(
114            &num_threads,
115            &MKL_DOMAIN_PARDISO,
116        ))
117    }
118    // max threads available to MKL
119    pub fn mkl_get_max_threads() -> Result<i32, MKLPardisoError> {
120        Ok((mkl_ptrs()?.mkl_get_max_threads)())
121    }
122    // max threads available to MKL_DOMAIN_PARDISO, possibly limited
123    // by environment variables or thread local settings
124    pub fn mkl_get_max_threads_pardiso() -> Result<i32, MKLPardisoError> {
125        Ok((mkl_ptrs()?.mkl_domain_get_max_threads)(
126            &MKL_DOMAIN_PARDISO,
127        ))
128    }
129    pub fn mkl_set_dynamic(dynamic: i32) -> Result<(), MKLPardisoError> {
130        (mkl_ptrs()?.mkl_set_dynamic)(&dynamic);
131        Ok(())
132    }
133}
134
135impl Drop for MKLPardisoSolver {
136    fn drop(&mut self) {
137        self.release();
138    }
139}