1use crate::backend::BackendDevice;
2use crate::cpu_backend::CpuDevice;
3use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
4
/// Where a [`Device`] lives: the CPU, or a CUDA / Metal GPU identified by `gpu_id`.
///
/// This is the value returned by [`Device::location`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceLocation {
    /// The host CPU.
    Cpu,
    /// A CUDA GPU, distinguished by `gpu_id`.
    Cuda { gpu_id: usize },
    /// A Metal GPU, distinguished by `gpu_id`.
    Metal { gpu_id: usize },
}
13
14#[derive(Debug, Clone)]
16pub enum Device {
17 Cpu,
18 Cuda(crate::CudaDevice),
19 Metal(crate::MetalDevice),
20}
21
22pub trait NdArray {
23 fn shape(&self) -> Result<Shape>;
24
25 fn to_cpu_storage(&self) -> CpuStorage;
26}
27
28impl<S: WithDType> NdArray for S {
29 fn shape(&self) -> Result<Shape> {
30 Ok(Shape::from(()))
31 }
32
33 fn to_cpu_storage(&self) -> CpuStorage {
34 S::to_cpu_storage(&[*self])
35 }
36}
37
38impl<S: WithDType, const N: usize> NdArray for &[S; N] {
39 fn shape(&self) -> Result<Shape> {
40 Ok(Shape::from(self.len()))
41 }
42
43 fn to_cpu_storage(&self) -> CpuStorage {
44 S::to_cpu_storage(self.as_slice())
45 }
46}
47
48impl<S: WithDType> NdArray for &[S] {
49 fn shape(&self) -> Result<Shape> {
50 Ok(Shape::from(self.len()))
51 }
52
53 fn to_cpu_storage(&self) -> CpuStorage {
54 S::to_cpu_storage(self)
55 }
56}
57
58impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
59 fn shape(&self) -> Result<Shape> {
60 Ok(Shape::from((M, N)))
61 }
62
63 fn to_cpu_storage(&self) -> CpuStorage {
64 S::to_cpu_storage_owned(self.concat())
65 }
66}
67
68impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
69 for &[[[S; N3]; N2]; N1]
70{
71 fn shape(&self) -> Result<Shape> {
72 Ok(Shape::from((N1, N2, N3)))
73 }
74
75 fn to_cpu_storage(&self) -> CpuStorage {
76 let mut vec = Vec::with_capacity(N1 * N2 * N3);
77 for i1 in 0..N1 {
78 for i2 in 0..N2 {
79 vec.extend(self[i1][i2])
80 }
81 }
82 S::to_cpu_storage_owned(vec)
83 }
84}
85
86impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
87 for &[[[[S; N4]; N3]; N2]; N1]
88{
89 fn shape(&self) -> Result<Shape> {
90 Ok(Shape::from((N1, N2, N3, N4)))
91 }
92
93 fn to_cpu_storage(&self) -> CpuStorage {
94 let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
95 for i1 in 0..N1 {
96 for i2 in 0..N2 {
97 for i3 in 0..N3 {
98 vec.extend(self[i1][i2][i3])
99 }
100 }
101 }
102 S::to_cpu_storage_owned(vec)
103 }
104}
105
106impl<S: WithDType> NdArray for Vec<S> {
107 fn shape(&self) -> Result<Shape> {
108 Ok(Shape::from(self.len()))
109 }
110
111 fn to_cpu_storage(&self) -> CpuStorage {
112 S::to_cpu_storage(self.as_slice())
113 }
114}
115
116impl<S: WithDType> NdArray for Vec<&[S]> {
117 fn shape(&self) -> Result<Shape> {
118 if self.is_empty() {
119 crate::bail!("empty array")
120 }
121 let n = self.len();
122 let m = self[0].len();
123 for v in self.iter() {
124 if v.len() != m {
125 crate::bail!("two elements have different len {m} {}", v.len())
126 }
127 }
128 Ok(Shape::from((n, m)))
129 }
130
131 fn to_cpu_storage(&self) -> CpuStorage {
132 let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
133 S::to_cpu_storage_owned(data)
134 }
135}
136
137impl<S: WithDType> NdArray for Vec<Vec<S>> {
138 fn shape(&self) -> Result<Shape> {
139 if self.is_empty() {
140 crate::bail!("empty array")
141 }
142 let n = self.len();
143 let m = self[0].len();
144 for v in self.iter() {
145 if v.len() != m {
146 crate::bail!("two elements have different len {m} {}", v.len())
147 }
148 }
149 Ok(Shape::from((n, m)))
150 }
151
152 fn to_cpu_storage(&self) -> CpuStorage {
153 let len: usize = self.iter().map(|v| v.len()).sum();
154 let mut dst = Vec::with_capacity(len);
155 for v in self.iter() {
156 dst.extend(v.iter().copied());
157 }
158 S::to_cpu_storage_owned(dst)
159 }
160}
161
162impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
163 fn shape(&self) -> Result<Shape> {
164 if self.is_empty() {
165 crate::bail!("empty array")
166 }
167 let shape0 = self[0].shape()?;
168 let n = self.len();
169 for v in self.iter() {
170 let shape = v.shape()?;
171 if shape != shape0 {
172 crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
173 }
174 }
175 Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
176 }
177
178 fn to_cpu_storage(&self) -> CpuStorage {
179 if self.is_empty() {
180 return S::to_cpu_storage_owned(vec![]);
181 }
182 let len: usize = self
183 .iter()
184 .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
185 .sum();
186 let mut dst = Vec::with_capacity(len);
187 for v1 in self.iter() {
188 for v2 in v1.iter() {
189 dst.extend(v2.iter().copied());
190 }
191 }
192 S::to_cpu_storage_owned(dst)
193 }
194}
195
196impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
197 fn shape(&self) -> Result<Shape> {
198 if self.is_empty() {
199 crate::bail!("empty array")
200 }
201 let shape0 = self[0].shape()?;
202 let n = self.len();
203 for v in self.iter() {
204 let shape = v.shape()?;
205 if shape != shape0 {
206 crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
207 }
208 }
209 Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
210 }
211
212 fn to_cpu_storage(&self) -> CpuStorage {
213 let len: usize = self
214 .iter()
215 .map(|v| {
216 v.iter()
217 .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
218 .sum::<usize>()
219 })
220 .sum();
221 let mut dst = Vec::with_capacity(len);
222 for v1 in self.iter() {
223 for v2 in v1.iter() {
224 for v3 in v2.iter() {
225 dst.extend(v3.iter().copied());
226 }
227 }
228 }
229 S::to_cpu_storage_owned(dst)
230 }
231}
232
233impl Device {
234 pub fn new_cuda(ordinal: usize) -> Result<Self> {
235 Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
236 }
237
238 pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
239 match self {
240 Self::Cuda(d) => Ok(d),
241 Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
242 Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
243 }
244 }
245
246 pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
247 match self {
248 Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
249 Self::Cpu => crate::bail!("expected a metal device, got cpu"),
250 Self::Metal(d) => Ok(d),
251 }
252 }
253
254 pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
255 Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
256 }
257
258 pub fn new_metal(ordinal: usize) -> Result<Self> {
259 Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
260 }
261
262 pub fn set_seed(&self, seed: u64) -> Result<()> {
263 match self {
264 Self::Cpu => CpuDevice.set_seed(seed),
265 Self::Cuda(c) => c.set_seed(seed),
266 Self::Metal(m) => m.set_seed(seed),
267 }
268 }
269
270 pub fn get_current_seed(&self) -> Result<u64> {
271 match self {
272 Self::Cpu => CpuDevice.get_current_seed(),
273 Self::Cuda(c) => c.get_current_seed(),
274 Self::Metal(m) => m.get_current_seed(),
275 }
276 }
277
278 pub fn same_device(&self, rhs: &Self) -> bool {
279 match (self, rhs) {
280 (Self::Cpu, Self::Cpu) => true,
281 (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
282 (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
283 _ => false,
284 }
285 }
286
287 pub fn location(&self) -> DeviceLocation {
288 match self {
289 Self::Cpu => DeviceLocation::Cpu,
290 Self::Cuda(device) => device.location(),
291 Device::Metal(device) => device.location(),
292 }
293 }
294
295 pub fn is_cpu(&self) -> bool {
296 matches!(self, Self::Cpu)
297 }
298
299 pub fn is_cuda(&self) -> bool {
300 matches!(self, Self::Cuda(_))
301 }
302
303 pub fn is_metal(&self) -> bool {
304 matches!(self, Self::Metal(_))
305 }
306
307 pub fn supports_bf16(&self) -> bool {
308 match self {
309 Self::Cuda(_) | Self::Metal(_) => true,
310 Self::Cpu => false,
311 }
312 }
313
314 pub fn bf16_default_to_f32(&self) -> DType {
316 if self.supports_bf16() {
317 DType::BF16
318 } else {
319 DType::F32
320 }
321 }
322
323 pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
324 if crate::utils::cuda_is_available() {
325 Self::new_cuda(ordinal)
326 } else {
327 Ok(Self::Cpu)
328 }
329 }
330
331 pub fn metal_if_available(ordinal: usize) -> Result<Self> {
332 if crate::utils::metal_is_available() {
333 Self::new_metal(ordinal)
334 } else {
335 Ok(Self::Cpu)
336 }
337 }
338
339 pub(crate) fn rand_uniform_f64(
340 &self,
341 lo: f64,
342 up: f64,
343 shape: &Shape,
344 dtype: DType,
345 ) -> Result<Storage> {
346 match self {
347 Device::Cpu => {
348 let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
349 Ok(Storage::Cpu(storage))
350 }
351 Device::Cuda(device) => {
352 if dtype == DType::F16 || dtype == DType::BF16 {
354 let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
355 Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
356 } else {
357 let storage = device.rand_uniform(shape, dtype, lo, up)?;
358 Ok(Storage::Cuda(storage))
359 }
360 }
361 Device::Metal(device) => {
362 let storage = device.rand_uniform(shape, dtype, lo, up)?;
363 Ok(Storage::Metal(storage))
364 }
365 }
366 }
367
368 pub(crate) fn rand_uniform<T: crate::FloatDType>(
369 &self,
370 lo: T,
371 up: T,
372 shape: &Shape,
373 ) -> Result<Storage> {
374 self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
375 }
376
377 pub(crate) fn rand_normal_f64(
378 &self,
379 mean: f64,
380 std: f64,
381 shape: &Shape,
382 dtype: DType,
383 ) -> Result<Storage> {
384 match self {
385 Device::Cpu => {
386 let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
387 Ok(Storage::Cpu(storage))
388 }
389 Device::Cuda(device) => {
390 if dtype == DType::F16 || dtype == DType::BF16 {
392 let storage = device.rand_normal(shape, DType::F32, mean, std)?;
393 Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
394 } else {
395 let storage = device.rand_normal(shape, dtype, mean, std)?;
396 Ok(Storage::Cuda(storage))
397 }
398 }
399 Device::Metal(device) => {
400 let storage = device.rand_normal(shape, dtype, mean, std)?;
401 Ok(Storage::Metal(storage))
402 }
403 }
404 }
405
406 pub(crate) fn rand_normal<T: crate::FloatDType>(
407 &self,
408 mean: T,
409 std: T,
410 shape: &Shape,
411 ) -> Result<Storage> {
412 self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
413 }
414
415 pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
416 match self {
417 Device::Cpu => {
418 let storage = CpuDevice.zeros_impl(shape, dtype)?;
419 Ok(Storage::Cpu(storage))
420 }
421 Device::Cuda(device) => {
422 let storage = device.zeros_impl(shape, dtype)?;
423 Ok(Storage::Cuda(storage))
424 }
425 Device::Metal(device) => {
426 let storage = device.zeros_impl(shape, dtype)?;
427 Ok(Storage::Metal(storage))
428 }
429 }
430 }
431
432 pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
433 match self {
434 Device::Cpu => {
435 let storage = CpuDevice.alloc_uninit(shape, dtype)?;
436 Ok(Storage::Cpu(storage))
437 }
438 Device::Cuda(device) => {
439 let storage = device.alloc_uninit(shape, dtype)?;
440 Ok(Storage::Cuda(storage))
441 }
442 Device::Metal(device) => {
443 let storage = device.alloc_uninit(shape, dtype)?;
444 Ok(Storage::Metal(storage))
445 }
446 }
447 }
448
449 pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
450 match self {
451 Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
452 Device::Cuda(device) => {
453 let storage = device.storage_from_slice(data)?;
454 Ok(Storage::Cuda(storage))
455 }
456 Device::Metal(device) => {
457 let storage = device.storage_from_slice(data)?;
458 Ok(Storage::Metal(storage))
459 }
460 }
461 }
462
463 pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
464 match self {
465 Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
466 Device::Cuda(device) => {
467 let storage = array.to_cpu_storage();
468 let storage = device.storage_from_cpu_storage_owned(storage)?;
469 Ok(Storage::Cuda(storage))
470 }
471 Device::Metal(device) => {
472 let storage = array.to_cpu_storage();
473 let storage = device.storage_from_cpu_storage_owned(storage)?;
474 Ok(Storage::Metal(storage))
475 }
476 }
477 }
478
479 pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
480 match self {
481 Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
482 Device::Cuda(device) => {
483 let storage = S::to_cpu_storage_owned(data);
484 let storage = device.storage_from_cpu_storage_owned(storage)?;
485 Ok(Storage::Cuda(storage))
486 }
487 Device::Metal(device) => {
488 let storage = S::to_cpu_storage_owned(data);
489 let storage = device.storage_from_cpu_storage_owned(storage)?;
490 Ok(Storage::Metal(storage))
491 }
492 }
493 }
494
495 pub fn synchronize(&self) -> Result<()> {
496 match self {
497 Self::Cpu => Ok(()),
498 Self::Cuda(d) => d.synchronize(),
499 Self::Metal(d) => d.synchronize(),
500 }
501 }
502}