1use std::fs::File;
10use std::io::ErrorKind;
11use std::io::Read;
12use std::path::Path;
13#[cfg(any(feature = "opencl", feature = "cuda"))]
14use std::sync::Arc;
15#[cfg(feature = "opencl")]
16use crate::matrix;
17#[cfg(feature = "opencl")]
18use crate::matrix::opencl::CL_DEVICE_TYPE_ALL;
19#[cfg(feature = "opencl")]
20use crate::matrix::opencl::ClBackend;
21#[cfg(feature = "opencl")]
22use crate::matrix::opencl::Context;
23#[cfg(feature = "opencl")]
24use crate::matrix::opencl::Device;
25#[cfg(feature = "opencl")]
26use crate::matrix::opencl::get_platforms;
27#[cfg(feature = "cuda")]
28use crate::matrix::cuda::CudaBackend;
29#[cfg(any(feature = "opencl", feature = "cuda"))]
30use crate::matrix::set_default_backend;
31use crate::matrix::unset_default_backend;
32use crate::serde::Deserialize;
33use crate::serde::Serialize;
34use crate::toml;
35use crate::error::*;
36
37#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Serialize, Deserialize)]
39pub enum Backend
40{
41 #[serde(rename = "OpenCL")]
43 OpenCl,
44 #[serde(rename = "CUDA")]
46 Cuda,
47}
48
49#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
51pub struct BackendConfig
52{
53 pub backend: Option<Backend>,
55 pub ordinal: Option<usize>,
57 pub platform: Option<usize>,
59 pub device: Option<usize>,
61 pub cublas: Option<bool>,
63 pub mma: Option<bool>,
65}
66
67impl BackendConfig
68{
69 pub fn read(r: &mut dyn Read) -> Result<Self>
71 {
72 let mut s = String::new();
73 match r.read_to_string(&mut s) {
74 Ok(_) => {
75 match toml::from_str(s.as_str()) {
76 Ok(config) => Ok(config),
77 Err(err) => Err(Error::TomlDe(err)),
78 }
79 },
80 Err(err) => Err(Error::Io(err)),
81 }
82 }
83
84 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self>
86 {
87 match File::open(path) {
88 Ok(mut file) => Self::read(&mut file),
89 Err(err) => Err(Error::Io(err)),
90 }
91 }
92
93 pub fn load_opt<P: AsRef<Path>>(path: P) -> Result<Option<Self>>
96 {
97 match File::open(path) {
98 Ok(mut file) => Ok(Some(Self::read(&mut file)?)),
99 Err(err) if err.kind() == ErrorKind::NotFound => Ok(None),
100 Err(err) => Err(Error::Io(err)),
101 }
102 }
103}
104
/// Creates an OpenCL backend and installs it with `set_default_backend`.
///
/// `platform_idx` selects an entry from `get_platforms()`. `device_idx` selects an entry from
/// that platform's device list, queried with `CL_DEVICE_TYPE_ALL`.
///
/// # Errors
///
/// All errors are wrapped in [`Error::Matrix`]:
/// * `matrix::Error::OpenCl` if querying platforms or devices, or creating the context, fails.
/// * `matrix::Error::NoPlatform` if `platform_idx` is out of range.
/// * `matrix::Error::NoDevice` if `device_idx` is out of range.
/// * Any error returned by `ClBackend::new_with_context` or `set_default_backend`.
#[cfg(feature = "opencl")]
fn initialize_opencl_backend(platform_idx: usize, device_idx: usize) -> Result<()>
{
    // Every raw OpenCL API failure is reported through the same wrapper. The closure captures
    // nothing, so it is `Copy` and can be passed to several `map_err` calls.
    let opencl_err = |err| Error::Matrix(matrix::Error::OpenCl(err));

    let platforms = get_platforms().map_err(opencl_err)?;
    let platform = platforms
        .get(platform_idx)
        .ok_or_else(|| Error::Matrix(matrix::Error::NoPlatform))?;
    let device_ids = platform.get_devices(CL_DEVICE_TYPE_ALL).map_err(opencl_err)?;
    let device_id = device_ids
        .get(device_idx)
        .ok_or_else(|| Error::Matrix(matrix::Error::NoDevice))?;
    let device = Device::new(*device_id);
    let context = Context::from_device(&device).map_err(opencl_err)?;
    let backend = ClBackend::new_with_context(context).map_err(Error::Matrix)?;
    set_default_backend(Arc::new(backend)).map_err(Error::Matrix)
}
138
139#[cfg(not(feature = "opencl"))]
140fn initialize_opencl_backend(_platform_idx: usize, _device_idx: usize) -> Result<()>
141{ Err(Error::NoOpenClBackend) }
142
/// Creates a CUDA backend for device `ordinal` and installs it with `set_default_backend`.
///
/// `is_cublas` and `is_mma` are forwarded unchanged to `CudaBackend::new_with_ordinal_and_flags`.
///
/// # Errors
///
/// Any error from backend creation or from `set_default_backend`, wrapped in [`Error::Matrix`].
#[cfg(feature = "cuda")]
fn initialize_cuda_backend(ordinal: usize, is_cublas: bool, is_mma: bool) -> Result<()>
{
    let backend = CudaBackend::new_with_ordinal_and_flags(ordinal, is_cublas, is_mma)
        .map_err(Error::Matrix)?;
    set_default_backend(Arc::new(backend)).map_err(Error::Matrix)
}
156
157#[cfg(not(feature = "cuda"))]
158fn initialize_cuda_backend(_ordinal: usize, _is_cublas: bool, _is_mma: bool) -> Result<()>
159{ Err(Error::NoCudaBackend) }
160
161pub fn initialize_backend_with_config(config: &Option<BackendConfig>) -> Result<()>
166{
167 #[cfg(feature = "cuda")]
168 let mut backend = Backend::Cuda;
169 #[cfg(not(feature = "cuda"))]
170 let mut backend = Backend::OpenCl;
171 let mut ordinal = 0usize;
172 let mut platform_idx = 0usize;
173 let mut device_idx = 0usize;
174 let mut is_cublas = true;
175 let mut is_mma = false;
176 match config {
177 Some(config) => {
178 backend = config.backend.unwrap_or(backend);
179 ordinal = config.ordinal.unwrap_or(ordinal);
180 platform_idx = config.platform.unwrap_or(platform_idx);
181 device_idx = config.device.unwrap_or(device_idx);
182 is_cublas = config.cublas.unwrap_or(is_cublas);
183 is_mma = config.mma.unwrap_or(is_mma);
184 },
185 None => (),
186 }
187 match backend {
188 Backend::OpenCl => initialize_opencl_backend(platform_idx, device_idx),
189 Backend::Cuda => initialize_cuda_backend(ordinal, is_cublas, is_mma),
190 }
191}
192
193pub fn initialize_backend<P: AsRef<Path>>(path: P) -> Result<()>
198{
199 let config = BackendConfig::load_opt(path)?;
200 initialize_backend_with_config(&config)
201}
202
203pub fn finalize_backend() -> Result<()>
205{
206 match unset_default_backend() {
207 Ok(()) => Ok(()),
208 Err(err) => Err(Error::Matrix(err)),
209 }
210}
211
// Unit tests: compiled only for test builds and loaded from this module's `tests` submodule file.
#[cfg(test)]
mod tests;