1use crate::backend::BackendDevice;
2use crate::cpu_backend::CpuDevice;
3use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
4
/// Lightweight, copyable description of where a [`Device`] lives.
///
/// Unlike [`Device`], this carries no backend handle. It is what
/// [`Device::location`] returns, so it can be compared or hashed cheaply.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DeviceLocation {
    /// The host CPU.
    Cpu,
    /// A CUDA GPU. `gpu_id` is whatever the CUDA backend's `location()` reports
    /// (presumably the ordinal passed to `Device::new_cuda` — confirm).
    Cuda { gpu_id: usize },
    /// A Metal GPU. `gpu_id` is whatever the Metal backend's `location()` reports.
    Metal { gpu_id: usize },
}
13
14#[derive(Debug, Clone)]
16pub enum Device {
17 Cpu,
18 Cuda(crate::CudaDevice),
19 Metal(crate::MetalDevice),
20}
21
22pub trait NdArray {
23 fn shape(&self) -> Result<Shape>;
24
25 fn to_cpu_storage(&self) -> CpuStorage;
26}
27
28impl<S: WithDType> NdArray for S {
29 fn shape(&self) -> Result<Shape> {
30 Ok(Shape::from(()))
31 }
32
33 fn to_cpu_storage(&self) -> CpuStorage {
34 S::to_cpu_storage(&[*self])
35 }
36}
37
38impl<S: WithDType, const N: usize> NdArray for &[S; N] {
39 fn shape(&self) -> Result<Shape> {
40 Ok(Shape::from(self.len()))
41 }
42
43 fn to_cpu_storage(&self) -> CpuStorage {
44 S::to_cpu_storage(self.as_slice())
45 }
46}
47
48impl<S: WithDType> NdArray for &[S] {
49 fn shape(&self) -> Result<Shape> {
50 Ok(Shape::from(self.len()))
51 }
52
53 fn to_cpu_storage(&self) -> CpuStorage {
54 S::to_cpu_storage(self)
55 }
56}
57
58impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
59 fn shape(&self) -> Result<Shape> {
60 Ok(Shape::from((M, N)))
61 }
62
63 fn to_cpu_storage(&self) -> CpuStorage {
64 S::to_cpu_storage_owned(self.concat())
65 }
66}
67
68impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
69 for &[[[S; N3]; N2]; N1]
70{
71 fn shape(&self) -> Result<Shape> {
72 Ok(Shape::from((N1, N2, N3)))
73 }
74
75 fn to_cpu_storage(&self) -> CpuStorage {
76 let mut vec = Vec::with_capacity(N1 * N2 * N3);
77 for i1 in 0..N1 {
78 for i2 in 0..N2 {
79 vec.extend(self[i1][i2])
80 }
81 }
82 S::to_cpu_storage_owned(vec)
83 }
84}
85
86impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
87 for &[[[[S; N4]; N3]; N2]; N1]
88{
89 fn shape(&self) -> Result<Shape> {
90 Ok(Shape::from((N1, N2, N3, N4)))
91 }
92
93 fn to_cpu_storage(&self) -> CpuStorage {
94 let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
95 for i1 in 0..N1 {
96 for i2 in 0..N2 {
97 for i3 in 0..N3 {
98 vec.extend(self[i1][i2][i3])
99 }
100 }
101 }
102 S::to_cpu_storage_owned(vec)
103 }
104}
105
106impl<S: WithDType> NdArray for Vec<S> {
107 fn shape(&self) -> Result<Shape> {
108 Ok(Shape::from(self.len()))
109 }
110
111 fn to_cpu_storage(&self) -> CpuStorage {
112 S::to_cpu_storage(self.as_slice())
113 }
114}
115
116impl<S: WithDType> NdArray for Vec<&[S]> {
117 fn shape(&self) -> Result<Shape> {
118 if self.is_empty() {
119 crate::bail!("empty array")
120 }
121 let n = self.len();
122 let m = self[0].len();
123 for v in self.iter() {
124 if v.len() != m {
125 crate::bail!("two elements have different len {m} {}", v.len())
126 }
127 }
128 Ok(Shape::from((n, m)))
129 }
130
131 fn to_cpu_storage(&self) -> CpuStorage {
132 let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
133 S::to_cpu_storage_owned(data)
134 }
135}
136
137impl<S: WithDType> NdArray for Vec<Vec<S>> {
138 fn shape(&self) -> Result<Shape> {
139 if self.is_empty() {
140 crate::bail!("empty array")
141 }
142 let n = self.len();
143 let m = self[0].len();
144 for v in self.iter() {
145 if v.len() != m {
146 crate::bail!("two elements have different len {m} {}", v.len())
147 }
148 }
149 Ok(Shape::from((n, m)))
150 }
151
152 fn to_cpu_storage(&self) -> CpuStorage {
153 let len: usize = self.iter().map(|v| v.len()).sum();
154 let mut dst = Vec::with_capacity(len);
155 for v in self.iter() {
156 dst.extend(v.iter().copied());
157 }
158 S::to_cpu_storage_owned(dst)
159 }
160}
161
162impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
163 fn shape(&self) -> Result<Shape> {
164 if self.is_empty() {
165 crate::bail!("empty array")
166 }
167 let shape0 = self[0].shape()?;
168 let n = self.len();
169 for v in self.iter() {
170 let shape = v.shape()?;
171 if shape != shape0 {
172 crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
173 }
174 }
175 Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
176 }
177
178 fn to_cpu_storage(&self) -> CpuStorage {
179 if self.is_empty() {
180 return S::to_cpu_storage_owned(vec![]);
181 }
182 let len: usize = self
183 .iter()
184 .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
185 .sum();
186 let mut dst = Vec::with_capacity(len);
187 for v1 in self.iter() {
188 for v2 in v1.iter() {
189 dst.extend(v2.iter().copied());
190 }
191 }
192 S::to_cpu_storage_owned(dst)
193 }
194}
195
196impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
197 fn shape(&self) -> Result<Shape> {
198 if self.is_empty() {
199 crate::bail!("empty array")
200 }
201 let shape0 = self[0].shape()?;
202 let n = self.len();
203 for v in self.iter() {
204 let shape = v.shape()?;
205 if shape != shape0 {
206 crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
207 }
208 }
209 Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
210 }
211
212 fn to_cpu_storage(&self) -> CpuStorage {
213 let len: usize = self
214 .iter()
215 .map(|v| {
216 v.iter()
217 .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
218 .sum::<usize>()
219 })
220 .sum();
221 let mut dst = Vec::with_capacity(len);
222 for v1 in self.iter() {
223 for v2 in v1.iter() {
224 for v3 in v2.iter() {
225 dst.extend(v3.iter().copied());
226 }
227 }
228 }
229 S::to_cpu_storage_owned(dst)
230 }
231}
232
233impl Device {
234 pub fn new_cuda(ordinal: usize) -> Result<Self> {
235 Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
236 }
237
238 pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
239 match self {
240 Self::Cuda(d) => Ok(d),
241 Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
242 Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
243 }
244 }
245
246 pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
247 match self {
248 Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
249 Self::Cpu => crate::bail!("expected a metal device, got cpu"),
250 Self::Metal(d) => Ok(d),
251 }
252 }
253
254 pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
255 Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
256 }
257
258 pub fn new_metal(ordinal: usize) -> Result<Self> {
259 Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
260 }
261
262 pub fn with_context<F, R>(&self, f: F) -> R
268 where
269 F: FnOnce() -> R + Send,
270 R: Send,
271 {
272 match self {
273 Self::Cpu => crate::utils::with_threadpool(f),
274 _ => f(),
275 }
276 }
277
278 pub fn set_seed(&self, seed: u64) -> Result<()> {
279 match self {
280 Self::Cpu => CpuDevice.set_seed(seed),
281 Self::Cuda(c) => c.set_seed(seed),
282 Self::Metal(m) => m.set_seed(seed),
283 }
284 }
285
286 pub fn get_current_seed(&self) -> Result<u64> {
287 match self {
288 Self::Cpu => CpuDevice.get_current_seed(),
289 Self::Cuda(c) => c.get_current_seed(),
290 Self::Metal(m) => m.get_current_seed(),
291 }
292 }
293
294 pub fn same_device(&self, rhs: &Self) -> bool {
295 match (self, rhs) {
296 (Self::Cpu, Self::Cpu) => true,
297 (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
298 (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
299 _ => false,
300 }
301 }
302
303 pub fn location(&self) -> DeviceLocation {
304 match self {
305 Self::Cpu => DeviceLocation::Cpu,
306 Self::Cuda(device) => device.location(),
307 Device::Metal(device) => device.location(),
308 }
309 }
310
311 pub fn is_cpu(&self) -> bool {
312 matches!(self, Self::Cpu)
313 }
314
315 pub fn is_cuda(&self) -> bool {
316 matches!(self, Self::Cuda(_))
317 }
318
319 pub fn is_metal(&self) -> bool {
320 matches!(self, Self::Metal(_))
321 }
322
323 pub fn supports_bf16(&self) -> bool {
324 match self {
325 Self::Cuda(_) | Self::Metal(_) => true,
326 Self::Cpu => false,
327 }
328 }
329
330 pub fn bf16_default_to_f32(&self) -> DType {
332 if self.supports_bf16() {
333 DType::BF16
334 } else {
335 DType::F32
336 }
337 }
338
339 pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
340 if crate::utils::cuda_is_available() {
341 Self::new_cuda(ordinal)
342 } else {
343 Ok(Self::Cpu)
344 }
345 }
346
347 pub fn metal_if_available(ordinal: usize) -> Result<Self> {
348 if crate::utils::metal_is_available() {
349 Self::new_metal(ordinal)
350 } else {
351 Ok(Self::Cpu)
352 }
353 }
354
355 pub(crate) fn rand_uniform_f64(
356 &self,
357 lo: f64,
358 up: f64,
359 shape: &Shape,
360 dtype: DType,
361 ) -> Result<Storage> {
362 match self {
363 Device::Cpu => {
364 let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
365 Ok(Storage::Cpu(storage))
366 }
367 Device::Cuda(device) => {
368 if dtype == DType::F16 || dtype == DType::BF16 {
370 let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
371 Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
372 } else {
373 let storage = device.rand_uniform(shape, dtype, lo, up)?;
374 Ok(Storage::Cuda(storage))
375 }
376 }
377 Device::Metal(device) => {
378 let storage = device.rand_uniform(shape, dtype, lo, up)?;
379 Ok(Storage::Metal(storage))
380 }
381 }
382 }
383
384 pub(crate) fn rand_uniform<T: crate::FloatDType>(
385 &self,
386 lo: T,
387 up: T,
388 shape: &Shape,
389 ) -> Result<Storage> {
390 self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
391 }
392
393 pub(crate) fn rand_normal_f64(
394 &self,
395 mean: f64,
396 std: f64,
397 shape: &Shape,
398 dtype: DType,
399 ) -> Result<Storage> {
400 match self {
401 Device::Cpu => {
402 let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
403 Ok(Storage::Cpu(storage))
404 }
405 Device::Cuda(device) => {
406 if dtype == DType::F16 || dtype == DType::BF16 {
408 let storage = device.rand_normal(shape, DType::F32, mean, std)?;
409 Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
410 } else {
411 let storage = device.rand_normal(shape, dtype, mean, std)?;
412 Ok(Storage::Cuda(storage))
413 }
414 }
415 Device::Metal(device) => {
416 let storage = device.rand_normal(shape, dtype, mean, std)?;
417 Ok(Storage::Metal(storage))
418 }
419 }
420 }
421
422 pub(crate) fn rand_normal<T: crate::FloatDType>(
423 &self,
424 mean: T,
425 std: T,
426 shape: &Shape,
427 ) -> Result<Storage> {
428 self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
429 }
430
431 pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
432 match self {
433 Device::Cpu => {
434 let storage = CpuDevice.zeros_impl(shape, dtype)?;
435 Ok(Storage::Cpu(storage))
436 }
437 Device::Cuda(device) => {
438 let storage = device.zeros_impl(shape, dtype)?;
439 Ok(Storage::Cuda(storage))
440 }
441 Device::Metal(device) => {
442 let storage = device.zeros_impl(shape, dtype)?;
443 Ok(Storage::Metal(storage))
444 }
445 }
446 }
447
448 pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
449 match self {
450 Device::Cpu => {
451 let storage = CpuDevice.alloc_uninit(shape, dtype)?;
452 Ok(Storage::Cpu(storage))
453 }
454 Device::Cuda(device) => {
455 let storage = device.alloc_uninit(shape, dtype)?;
456 Ok(Storage::Cuda(storage))
457 }
458 Device::Metal(device) => {
459 let storage = device.alloc_uninit(shape, dtype)?;
460 Ok(Storage::Metal(storage))
461 }
462 }
463 }
464
465 pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
466 match self {
467 Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
468 Device::Cuda(device) => {
469 let storage = device.storage_from_slice(data)?;
470 Ok(Storage::Cuda(storage))
471 }
472 Device::Metal(device) => {
473 let storage = device.storage_from_slice(data)?;
474 Ok(Storage::Metal(storage))
475 }
476 }
477 }
478
479 pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
480 match self {
481 Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
482 Device::Cuda(device) => {
483 let storage = array.to_cpu_storage();
484 let storage = device.storage_from_cpu_storage_owned(storage)?;
485 Ok(Storage::Cuda(storage))
486 }
487 Device::Metal(device) => {
488 let storage = array.to_cpu_storage();
489 let storage = device.storage_from_cpu_storage_owned(storage)?;
490 Ok(Storage::Metal(storage))
491 }
492 }
493 }
494
495 pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
496 match self {
497 Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
498 Device::Cuda(device) => {
499 let storage = S::to_cpu_storage_owned(data);
500 let storage = device.storage_from_cpu_storage_owned(storage)?;
501 Ok(Storage::Cuda(storage))
502 }
503 Device::Metal(device) => {
504 let storage = S::to_cpu_storage_owned(data);
505 let storage = device.storage_from_cpu_storage_owned(storage)?;
506 Ok(Storage::Metal(storage))
507 }
508 }
509 }
510
511 pub fn synchronize(&self) -> Result<()> {
512 match self {
513 Self::Cpu => Ok(()),
514 Self::Cuda(d) => d.synchronize(),
515 Self::Metal(d) => d.synchronize(),
516 }
517 }
518}