1use crate::backend::BackendDevice;
2use crate::cpu_backend::CpuDevice;
3use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
4
/// The physical place a device lives: the host CPU, or a specific CUDA or Metal GPU.
///
/// Unlike [`Device`], this is a plain `Copy` value that can be compared and hashed.
/// NOTE(review): several `Device` handles may presumably map to the same location
/// (e.g. two handles on one GPU) — confirm against the backend `location()` impls.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DeviceLocation {
    /// The host CPU.
    Cpu,
    /// A CUDA GPU; `gpu_id` is presumably the device ordinal — TODO confirm.
    Cuda { gpu_id: usize },
    /// A Metal GPU; `gpu_id` is presumably the device ordinal — TODO confirm.
    Metal { gpu_id: usize },
}
13
/// A compute device on which storage can be created: the CPU, a CUDA device, or a Metal device.
#[derive(Debug, Clone)]
pub enum Device {
    /// The host CPU; this variant carries no backend state.
    Cpu,
    /// A handle to a CUDA backend device.
    Cuda(crate::CudaDevice),
    /// A handle to a Metal backend device.
    Metal(crate::MetalDevice),
}
21
22pub trait NdArray {
23 fn shape(&self) -> Result<Shape>;
24
25 fn to_cpu_storage(&self) -> CpuStorage;
26}
27
28impl<S: WithDType> NdArray for S {
29 fn shape(&self) -> Result<Shape> {
30 Ok(Shape::from(()))
31 }
32
33 fn to_cpu_storage(&self) -> CpuStorage {
34 S::to_cpu_storage(&[*self])
35 }
36}
37
38impl<S: WithDType, const N: usize> NdArray for &[S; N] {
39 fn shape(&self) -> Result<Shape> {
40 Ok(Shape::from(self.len()))
41 }
42
43 fn to_cpu_storage(&self) -> CpuStorage {
44 S::to_cpu_storage(self.as_slice())
45 }
46}
47
48impl<S: WithDType> NdArray for &[S] {
49 fn shape(&self) -> Result<Shape> {
50 Ok(Shape::from(self.len()))
51 }
52
53 fn to_cpu_storage(&self) -> CpuStorage {
54 S::to_cpu_storage(self)
55 }
56}
57
58impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
59 fn shape(&self) -> Result<Shape> {
60 Ok(Shape::from((M, N)))
61 }
62
63 fn to_cpu_storage(&self) -> CpuStorage {
64 S::to_cpu_storage_owned(self.concat())
65 }
66}
67
68impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
69 for &[[[S; N3]; N2]; N1]
70{
71 fn shape(&self) -> Result<Shape> {
72 Ok(Shape::from((N1, N2, N3)))
73 }
74
75 fn to_cpu_storage(&self) -> CpuStorage {
76 let mut vec = Vec::with_capacity(N1 * N2 * N3);
77 for i1 in 0..N1 {
78 for i2 in 0..N2 {
79 vec.extend(self[i1][i2])
80 }
81 }
82 S::to_cpu_storage_owned(vec)
83 }
84}
85
86impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
87 for &[[[[S; N4]; N3]; N2]; N1]
88{
89 fn shape(&self) -> Result<Shape> {
90 Ok(Shape::from((N1, N2, N3, N4)))
91 }
92
93 fn to_cpu_storage(&self) -> CpuStorage {
94 let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
95 for i1 in 0..N1 {
96 for i2 in 0..N2 {
97 for i3 in 0..N3 {
98 vec.extend(self[i1][i2][i3])
99 }
100 }
101 }
102 S::to_cpu_storage_owned(vec)
103 }
104}
105
106impl<S: NdArray> NdArray for Vec<S> {
107 fn shape(&self) -> Result<Shape> {
108 if self.is_empty() {
109 crate::bail!("empty array")
110 }
111 let shape0 = self[0].shape()?;
112 let n = self.len();
113 for v in self.iter() {
114 let shape = v.shape()?;
115 if shape != shape0 {
116 crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
117 }
118 }
119 Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
120 }
121
122 fn to_cpu_storage(&self) -> CpuStorage {
123 let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
125 CpuStorage::concat(storages.as_slice()).unwrap()
126 }
127}
128
129impl Device {
130 pub fn new_cuda(ordinal: usize) -> Result<Self> {
131 Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
132 }
133
134 pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
135 match self {
136 Self::Cuda(d) => Ok(d),
137 Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
138 Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
139 }
140 }
141
142 pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
143 match self {
144 Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
145 Self::Cpu => crate::bail!("expected a metal device, got cpu"),
146 Self::Metal(d) => Ok(d),
147 }
148 }
149
150 pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
151 Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
152 }
153
154 pub fn new_metal(ordinal: usize) -> Result<Self> {
155 Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
156 }
157
158 pub fn set_seed(&self, seed: u64) -> Result<()> {
159 match self {
160 Self::Cpu => CpuDevice.set_seed(seed),
161 Self::Cuda(c) => c.set_seed(seed),
162 Self::Metal(m) => m.set_seed(seed),
163 }
164 }
165
166 pub fn same_device(&self, rhs: &Self) -> bool {
167 match (self, rhs) {
168 (Self::Cpu, Self::Cpu) => true,
169 (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
170 (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
171 _ => false,
172 }
173 }
174
175 pub fn location(&self) -> DeviceLocation {
176 match self {
177 Self::Cpu => DeviceLocation::Cpu,
178 Self::Cuda(device) => device.location(),
179 Device::Metal(device) => device.location(),
180 }
181 }
182
183 pub fn is_cpu(&self) -> bool {
184 matches!(self, Self::Cpu)
185 }
186
187 pub fn is_cuda(&self) -> bool {
188 matches!(self, Self::Cuda(_))
189 }
190
191 pub fn is_metal(&self) -> bool {
192 matches!(self, Self::Metal(_))
193 }
194
195 pub fn supports_bf16(&self) -> bool {
196 match self {
197 Self::Cuda(_) | Self::Metal(_) => true,
198 Self::Cpu => false,
199 }
200 }
201
202 pub fn bf16_default_to_f32(&self) -> DType {
204 if self.supports_bf16() {
205 DType::BF16
206 } else {
207 DType::F32
208 }
209 }
210
211 pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
212 if crate::utils::cuda_is_available() {
213 Self::new_cuda(ordinal)
214 } else {
215 Ok(Self::Cpu)
216 }
217 }
218
219 pub(crate) fn rand_uniform_f64(
220 &self,
221 lo: f64,
222 up: f64,
223 shape: &Shape,
224 dtype: DType,
225 ) -> Result<Storage> {
226 match self {
227 Device::Cpu => {
228 let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
229 Ok(Storage::Cpu(storage))
230 }
231 Device::Cuda(device) => {
232 if dtype == DType::F16 || dtype == DType::BF16 {
234 let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
235 Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
236 } else {
237 let storage = device.rand_uniform(shape, dtype, lo, up)?;
238 Ok(Storage::Cuda(storage))
239 }
240 }
241 Device::Metal(device) => {
242 let storage = device.rand_uniform(shape, dtype, lo, up)?;
243 Ok(Storage::Metal(storage))
244 }
245 }
246 }
247
248 pub(crate) fn rand_uniform<T: crate::FloatDType>(
249 &self,
250 lo: T,
251 up: T,
252 shape: &Shape,
253 ) -> Result<Storage> {
254 self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
255 }
256
257 pub(crate) fn rand_normal_f64(
258 &self,
259 mean: f64,
260 std: f64,
261 shape: &Shape,
262 dtype: DType,
263 ) -> Result<Storage> {
264 match self {
265 Device::Cpu => {
266 let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
267 Ok(Storage::Cpu(storage))
268 }
269 Device::Cuda(device) => {
270 if dtype == DType::F16 || dtype == DType::BF16 {
272 let storage = device.rand_normal(shape, DType::F32, mean, std)?;
273 Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
274 } else {
275 let storage = device.rand_normal(shape, dtype, mean, std)?;
276 Ok(Storage::Cuda(storage))
277 }
278 }
279 Device::Metal(device) => {
280 let storage = device.rand_normal(shape, dtype, mean, std)?;
281 Ok(Storage::Metal(storage))
282 }
283 }
284 }
285
286 pub(crate) fn rand_normal<T: crate::FloatDType>(
287 &self,
288 mean: T,
289 std: T,
290 shape: &Shape,
291 ) -> Result<Storage> {
292 self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
293 }
294
295 pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
296 match self {
297 Device::Cpu => {
298 let storage = CpuDevice.ones_impl(shape, dtype)?;
299 Ok(Storage::Cpu(storage))
300 }
301 Device::Cuda(device) => {
302 let storage = device.ones_impl(shape, dtype)?;
303 Ok(Storage::Cuda(storage))
304 }
305 Device::Metal(device) => {
306 let storage = device.ones_impl(shape, dtype)?;
307 Ok(Storage::Metal(storage))
308 }
309 }
310 }
311
312 pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
313 match self {
314 Device::Cpu => {
315 let storage = CpuDevice.zeros_impl(shape, dtype)?;
316 Ok(Storage::Cpu(storage))
317 }
318 Device::Cuda(device) => {
319 let storage = device.zeros_impl(shape, dtype)?;
320 Ok(Storage::Cuda(storage))
321 }
322 Device::Metal(device) => {
323 let storage = device.zeros_impl(shape, dtype)?;
324 Ok(Storage::Metal(storage))
325 }
326 }
327 }
328
329 pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
330 match self {
331 Device::Cpu => {
332 let storage = CpuDevice.alloc_uninit(shape, dtype)?;
333 Ok(Storage::Cpu(storage))
334 }
335 Device::Cuda(device) => {
336 let storage = device.alloc_uninit(shape, dtype)?;
337 Ok(Storage::Cuda(storage))
338 }
339 Device::Metal(device) => {
340 let storage = device.alloc_uninit(shape, dtype)?;
341 Ok(Storage::Metal(storage))
342 }
343 }
344 }
345
346 pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
347 match self {
348 Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
349 Device::Cuda(device) => {
350 let storage = device.storage_from_slice(data)?;
351 Ok(Storage::Cuda(storage))
352 }
353 Device::Metal(device) => {
354 let storage = device.storage_from_slice(data)?;
355 Ok(Storage::Metal(storage))
356 }
357 }
358 }
359
360 pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
361 match self {
362 Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
363 Device::Cuda(device) => {
364 let storage = array.to_cpu_storage();
365 let storage = device.storage_from_cpu_storage_owned(storage)?;
366 Ok(Storage::Cuda(storage))
367 }
368 Device::Metal(device) => {
369 let storage = array.to_cpu_storage();
370 let storage = device.storage_from_cpu_storage_owned(storage)?;
371 Ok(Storage::Metal(storage))
372 }
373 }
374 }
375
376 pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
377 match self {
378 Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
379 Device::Cuda(device) => {
380 let storage = S::to_cpu_storage_owned(data);
381 let storage = device.storage_from_cpu_storage_owned(storage)?;
382 Ok(Storage::Cuda(storage))
383 }
384 Device::Metal(device) => {
385 let storage = S::to_cpu_storage_owned(data);
386 let storage = device.storage_from_cpu_storage_owned(storage)?;
387 Ok(Storage::Metal(storage))
388 }
389 }
390 }
391
392 pub fn synchronize(&self) -> Result<()> {
393 match self {
394 Self::Cpu => Ok(()),
395 Self::Cuda(d) => d.synchronize(),
396 Self::Metal(d) => d.synchronize(),
397 }
398 }
399}