1use crate::backend::BackendDevice;
2use crate::cpu_backend::CpuDevice;
3use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
4
/// Lightweight, copyable description of where a device lives: on the CPU, or
/// on a CUDA / Metal GPU identified by `gpu_id`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceLocation {
    /// The CPU.
    Cpu,
    /// A CUDA GPU, identified by `gpu_id`.
    Cuda { gpu_id: usize },
    /// A Metal GPU, identified by `gpu_id`.
    Metal { gpu_id: usize },
}
13
14#[derive(Debug, Clone)]
16pub enum Device {
17 Cpu,
18 Cuda(crate::CudaDevice),
19 Metal(crate::MetalDevice),
20}
21
22pub trait NdArray {
23 fn shape(&self) -> Result<Shape>;
24
25 fn to_cpu_storage(&self) -> CpuStorage;
26}
27
28impl<S: WithDType> NdArray for S {
29 fn shape(&self) -> Result<Shape> {
30 Ok(Shape::from(()))
31 }
32
33 fn to_cpu_storage(&self) -> CpuStorage {
34 S::to_cpu_storage(&[*self])
35 }
36}
37
38impl<S: WithDType, const N: usize> NdArray for &[S; N] {
39 fn shape(&self) -> Result<Shape> {
40 Ok(Shape::from(self.len()))
41 }
42
43 fn to_cpu_storage(&self) -> CpuStorage {
44 S::to_cpu_storage(self.as_slice())
45 }
46}
47
48impl<S: WithDType> NdArray for &[S] {
49 fn shape(&self) -> Result<Shape> {
50 Ok(Shape::from(self.len()))
51 }
52
53 fn to_cpu_storage(&self) -> CpuStorage {
54 S::to_cpu_storage(self)
55 }
56}
57
58impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
59 fn shape(&self) -> Result<Shape> {
60 Ok(Shape::from((M, N)))
61 }
62
63 fn to_cpu_storage(&self) -> CpuStorage {
64 S::to_cpu_storage_owned(self.concat())
65 }
66}
67
68impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
69 for &[[[S; N3]; N2]; N1]
70{
71 fn shape(&self) -> Result<Shape> {
72 Ok(Shape::from((N1, N2, N3)))
73 }
74
75 fn to_cpu_storage(&self) -> CpuStorage {
76 let mut vec = Vec::with_capacity(N1 * N2 * N3);
77 for i1 in 0..N1 {
78 for i2 in 0..N2 {
79 vec.extend(self[i1][i2])
80 }
81 }
82 S::to_cpu_storage_owned(vec)
83 }
84}
85
86impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
87 for &[[[[S; N4]; N3]; N2]; N1]
88{
89 fn shape(&self) -> Result<Shape> {
90 Ok(Shape::from((N1, N2, N3, N4)))
91 }
92
93 fn to_cpu_storage(&self) -> CpuStorage {
94 let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
95 for i1 in 0..N1 {
96 for i2 in 0..N2 {
97 for i3 in 0..N3 {
98 vec.extend(self[i1][i2][i3])
99 }
100 }
101 }
102 S::to_cpu_storage_owned(vec)
103 }
104}
105
106impl<S: WithDType> NdArray for Vec<S> {
107 fn shape(&self) -> Result<Shape> {
108 Ok(Shape::from(self.len()))
109 }
110
111 fn to_cpu_storage(&self) -> CpuStorage {
112 S::to_cpu_storage(self.as_slice())
113 }
114}
115
116impl<S: WithDType> NdArray for Vec<&[S]> {
117 fn shape(&self) -> Result<Shape> {
118 if self.is_empty() {
119 crate::bail!("empty array")
120 }
121 let n = self.len();
122 let m = self[0].len();
123 for v in self.iter() {
124 if v.len() != m {
125 crate::bail!("two elements have different len {m} {}", v.len())
126 }
127 }
128 Ok(Shape::from((n, m)))
129 }
130
131 fn to_cpu_storage(&self) -> CpuStorage {
132 let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
133 S::to_cpu_storage_owned(data)
134 }
135}
136
137impl<S: WithDType> NdArray for Vec<Vec<S>> {
138 fn shape(&self) -> Result<Shape> {
139 if self.is_empty() {
140 crate::bail!("empty array")
141 }
142 let n = self.len();
143 let m = self[0].len();
144 for v in self.iter() {
145 if v.len() != m {
146 crate::bail!("two elements have different len {m} {}", v.len())
147 }
148 }
149 Ok(Shape::from((n, m)))
150 }
151
152 fn to_cpu_storage(&self) -> CpuStorage {
153 let len: usize = self.iter().map(|v| v.len()).sum();
154 let mut dst = Vec::with_capacity(len);
155 for v in self.iter() {
156 dst.extend(v.iter().copied());
157 }
158 S::to_cpu_storage_owned(dst)
159 }
160}
161
162impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
163 fn shape(&self) -> Result<Shape> {
164 if self.is_empty() {
165 crate::bail!("empty array")
166 }
167 let shape0 = self[0].shape()?;
168 let n = self.len();
169 for v in self.iter() {
170 let shape = v.shape()?;
171 if shape != shape0 {
172 crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
173 }
174 }
175 Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
176 }
177
178 fn to_cpu_storage(&self) -> CpuStorage {
179 if self.is_empty() {
180 return S::to_cpu_storage_owned(vec![]);
181 }
182 let len: usize = self
183 .iter()
184 .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
185 .sum();
186 let mut dst = Vec::with_capacity(len);
187 for v1 in self.iter() {
188 for v2 in v1.iter() {
189 dst.extend(v2.iter().copied());
190 }
191 }
192 S::to_cpu_storage_owned(dst)
193 }
194}
195
196impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
197 fn shape(&self) -> Result<Shape> {
198 if self.is_empty() {
199 crate::bail!("empty array")
200 }
201 let shape0 = self[0].shape()?;
202 let n = self.len();
203 for v in self.iter() {
204 let shape = v.shape()?;
205 if shape != shape0 {
206 crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
207 }
208 }
209 Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
210 }
211
212 fn to_cpu_storage(&self) -> CpuStorage {
213 let len: usize = self
214 .iter()
215 .map(|v| {
216 v.iter()
217 .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
218 .sum::<usize>()
219 })
220 .sum();
221 let mut dst = Vec::with_capacity(len);
222 for v1 in self.iter() {
223 for v2 in v1.iter() {
224 for v3 in v2.iter() {
225 dst.extend(v3.iter().copied());
226 }
227 }
228 }
229 S::to_cpu_storage_owned(dst)
230 }
231}
232
233impl Device {
234 pub fn new_cuda(ordinal: usize) -> Result<Self> {
235 Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
236 }
237
238 pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
239 match self {
240 Self::Cuda(d) => Ok(d),
241 Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
242 Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
243 }
244 }
245
246 pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
247 match self {
248 Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
249 Self::Cpu => crate::bail!("expected a metal device, got cpu"),
250 Self::Metal(d) => Ok(d),
251 }
252 }
253
254 pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
255 Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
256 }
257
258 pub fn new_metal(ordinal: usize) -> Result<Self> {
259 Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
260 }
261
262 pub fn set_seed(&self, seed: u64) -> Result<()> {
263 match self {
264 Self::Cpu => CpuDevice.set_seed(seed),
265 Self::Cuda(c) => c.set_seed(seed),
266 Self::Metal(m) => m.set_seed(seed),
267 }
268 }
269
270 pub fn same_device(&self, rhs: &Self) -> bool {
271 match (self, rhs) {
272 (Self::Cpu, Self::Cpu) => true,
273 (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
274 (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
275 _ => false,
276 }
277 }
278
279 pub fn location(&self) -> DeviceLocation {
280 match self {
281 Self::Cpu => DeviceLocation::Cpu,
282 Self::Cuda(device) => device.location(),
283 Device::Metal(device) => device.location(),
284 }
285 }
286
287 pub fn is_cpu(&self) -> bool {
288 matches!(self, Self::Cpu)
289 }
290
291 pub fn is_cuda(&self) -> bool {
292 matches!(self, Self::Cuda(_))
293 }
294
295 pub fn is_metal(&self) -> bool {
296 matches!(self, Self::Metal(_))
297 }
298
299 pub fn supports_bf16(&self) -> bool {
300 match self {
301 Self::Cuda(_) | Self::Metal(_) => true,
302 Self::Cpu => false,
303 }
304 }
305
306 pub fn bf16_default_to_f32(&self) -> DType {
308 if self.supports_bf16() {
309 DType::BF16
310 } else {
311 DType::F32
312 }
313 }
314
315 pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
316 if crate::utils::cuda_is_available() {
317 Self::new_cuda(ordinal)
318 } else {
319 Ok(Self::Cpu)
320 }
321 }
322
323 pub(crate) fn rand_uniform_f64(
324 &self,
325 lo: f64,
326 up: f64,
327 shape: &Shape,
328 dtype: DType,
329 ) -> Result<Storage> {
330 match self {
331 Device::Cpu => {
332 let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
333 Ok(Storage::Cpu(storage))
334 }
335 Device::Cuda(device) => {
336 if dtype == DType::F16 || dtype == DType::BF16 {
338 let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
339 Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
340 } else {
341 let storage = device.rand_uniform(shape, dtype, lo, up)?;
342 Ok(Storage::Cuda(storage))
343 }
344 }
345 Device::Metal(device) => {
346 let storage = device.rand_uniform(shape, dtype, lo, up)?;
347 Ok(Storage::Metal(storage))
348 }
349 }
350 }
351
352 pub(crate) fn rand_uniform<T: crate::FloatDType>(
353 &self,
354 lo: T,
355 up: T,
356 shape: &Shape,
357 ) -> Result<Storage> {
358 self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
359 }
360
361 pub(crate) fn rand_normal_f64(
362 &self,
363 mean: f64,
364 std: f64,
365 shape: &Shape,
366 dtype: DType,
367 ) -> Result<Storage> {
368 match self {
369 Device::Cpu => {
370 let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
371 Ok(Storage::Cpu(storage))
372 }
373 Device::Cuda(device) => {
374 if dtype == DType::F16 || dtype == DType::BF16 {
376 let storage = device.rand_normal(shape, DType::F32, mean, std)?;
377 Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
378 } else {
379 let storage = device.rand_normal(shape, dtype, mean, std)?;
380 Ok(Storage::Cuda(storage))
381 }
382 }
383 Device::Metal(device) => {
384 let storage = device.rand_normal(shape, dtype, mean, std)?;
385 Ok(Storage::Metal(storage))
386 }
387 }
388 }
389
390 pub(crate) fn rand_normal<T: crate::FloatDType>(
391 &self,
392 mean: T,
393 std: T,
394 shape: &Shape,
395 ) -> Result<Storage> {
396 self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
397 }
398
399 pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
400 match self {
401 Device::Cpu => {
402 let storage = CpuDevice.zeros_impl(shape, dtype)?;
403 Ok(Storage::Cpu(storage))
404 }
405 Device::Cuda(device) => {
406 let storage = device.zeros_impl(shape, dtype)?;
407 Ok(Storage::Cuda(storage))
408 }
409 Device::Metal(device) => {
410 let storage = device.zeros_impl(shape, dtype)?;
411 Ok(Storage::Metal(storage))
412 }
413 }
414 }
415
416 pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
417 match self {
418 Device::Cpu => {
419 let storage = CpuDevice.alloc_uninit(shape, dtype)?;
420 Ok(Storage::Cpu(storage))
421 }
422 Device::Cuda(device) => {
423 let storage = device.alloc_uninit(shape, dtype)?;
424 Ok(Storage::Cuda(storage))
425 }
426 Device::Metal(device) => {
427 let storage = device.alloc_uninit(shape, dtype)?;
428 Ok(Storage::Metal(storage))
429 }
430 }
431 }
432
433 pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
434 match self {
435 Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
436 Device::Cuda(device) => {
437 let storage = device.storage_from_slice(data)?;
438 Ok(Storage::Cuda(storage))
439 }
440 Device::Metal(device) => {
441 let storage = device.storage_from_slice(data)?;
442 Ok(Storage::Metal(storage))
443 }
444 }
445 }
446
447 pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
448 match self {
449 Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
450 Device::Cuda(device) => {
451 let storage = array.to_cpu_storage();
452 let storage = device.storage_from_cpu_storage_owned(storage)?;
453 Ok(Storage::Cuda(storage))
454 }
455 Device::Metal(device) => {
456 let storage = array.to_cpu_storage();
457 let storage = device.storage_from_cpu_storage_owned(storage)?;
458 Ok(Storage::Metal(storage))
459 }
460 }
461 }
462
463 pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
464 match self {
465 Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
466 Device::Cuda(device) => {
467 let storage = S::to_cpu_storage_owned(data);
468 let storage = device.storage_from_cpu_storage_owned(storage)?;
469 Ok(Storage::Cuda(storage))
470 }
471 Device::Metal(device) => {
472 let storage = S::to_cpu_storage_owned(data);
473 let storage = device.storage_from_cpu_storage_owned(storage)?;
474 Ok(Storage::Metal(storage))
475 }
476 }
477 }
478
479 pub fn synchronize(&self) -> Result<()> {
480 match self {
481 Self::Cpu => Ok(()),
482 Self::Cuda(d) => d.synchronize(),
483 Self::Metal(d) => d.synchronize(),
484 }
485 }
486}