1use std::sync::{Arc, RwLock};
2use rand::rng;
3use rand_distr::{Distribution, Uniform};
4use crate::{Error, IntDType, Result};
5use super::{DType, FloatDType, Layout, NumDType, Shape, WithDType};
6
/// Flat, owned buffer of tensor elements.
///
/// A `Storage` holds raw data only; how it is interpreted (shape, strides,
/// offset) is described separately by a `Layout` that indexes into it.
#[derive(Clone)]
pub struct Storage<T>(Vec<T>);
9
10impl<T: WithDType> Storage<T> {
11 pub fn zeros(shape: &Shape) -> Self {
12 Self(vec![T::ZERO; shape.element_count()])
13 }
14
15 pub fn ones(shape: &Shape) -> Self {
16 Self(vec![T::ONE; shape.element_count()])
17 }
18
19 pub fn full(value: T, shape: &Shape) -> Self {
20 Self(vec![value; shape.element_count()])
21 }
22
23 pub fn new<D: Into<Vec<T>>>(data: D) -> Self {
24 Self(data.into())
25 }
26
27 #[inline]
28 pub fn data(&self) -> &[T] {
29 &self.0
30 }
31
32 #[inline]
33 pub fn data_mut(&mut self) -> &mut [T] {
34 &mut self.0
35 }
36
37 #[inline]
38 pub fn dtype(&self) -> DType {
39 T::DTYPE
40 }
41
42 #[inline]
43 pub fn copy_data(&self) -> Vec<T> {
44 self.0.clone()
45 }
46
47 #[inline]
48 pub fn get(&self, index: usize) -> Option<T> {
49 self.0.get(index).copied()
50 }
51
52 #[inline]
53 pub fn get_unchecked(&self, index: usize) -> T {
54 self.0[index]
55 }
56
57 #[inline]
58 pub fn set(&mut self, index: usize, value: T) -> Option<()> {
59 if index >= self.len() {
60 None
61 } else {
62 self.0[index] = value;
63 Some(())
64 }
65 }
66
67 #[inline]
68 pub fn set_unchecked(&mut self, index: usize, value: T) {
69 self.0[index] = value;
70 }
71
72 pub fn len(&self) -> usize {
73 self.0.len()
74 }
75
76 pub fn copy(&self, layout: &Layout) -> Self {
77 let output: Vec<_> = layout.storage_indices()
78 .map(|i| self.0[i])
79 .collect();
80 Self(output)
81 }
82
83 pub fn copy_map<F, U>(&self, layout: &Layout, f: F) -> Storage<U>
84 where
85 U: WithDType,
86 F: Fn(T) -> U
87 {
88 let output: Vec<_> = layout.storage_indices()
89 .map(|i| f(self.0[i]))
90 .collect();
91 Storage(output)
92 }
93}
94
95impl<T: NumDType> Storage<T> {
96 pub fn rand_uniform(shape: &Shape, min: T, max: T) -> Result<Self> {
97 let elem_count = shape.element_count();
98 let mut rng = rng();
99 let uniform = Uniform::new(min, max).map_err(|e| Error::Rand(e.to_string()))?;
100 let v: Vec<T> = (0..elem_count)
101 .map(|_| uniform.sample(&mut rng))
102 .collect();
103
104 Ok(Self(v))
105 }
106}
107
108impl<F: FloatDType> Storage<F> {
109 pub fn rand_normal(shape: &Shape, mean: F, std: F) -> Result<Self>
110 {
111 let elem_count = shape.element_count();
112 let v = F::random_normal_vec(elem_count, mean, std)?;
113 Ok(Self(v))
114 }
115}
116
impl<T: WithDType> Storage<T> {
    /// Selects slices of `self` along `dim` according to the 1-D integer
    /// index tensor `ids`, returning a new contiguous storage whose shape is
    /// `self`'s shape with `dim` replaced by `ids.len()`.
    ///
    /// `I::max_value()` acts as a sentinel "padding" index: the matching
    /// output slice is filled with zeros instead of reading from `self`.
    pub(crate) fn index_select<I: IntDType>(
        &self,
        self_layout: &Layout,
        ids: &Storage<I>,
        ids_layout: &Layout,
        dim: usize,
    ) -> Result<Self> {
        let vec = Self::do_index_select(ids.data(), ids_layout, dim, self.data(), self_layout)?;
        Ok(Storage::new(vec))
    }

    /// Core of `index_select`. `layout` (the source) must be contiguous;
    /// `ids_l` must be 1-D (asserted below) but may be strided.
    fn do_index_select<I: IntDType>(ids: &[I], ids_l: &Layout, dim: usize, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        if !layout.is_contiguous() {
            Err(Error::RequiresContiguous { op: "index-select" })?
        }
        // Restrict to the source view: contiguous data starting at the
        // layout's offset and spanning exactly element_count() elements.
        let src = &src[layout.start_offset..layout.start_offset+layout.shape.element_count()];
        let n_ids = ids_l.dims();
        // Index tensors must be 1-D here; higher ranks are a caller bug.
        assert!(n_ids.len() == 1);
        let n_ids = n_ids[0];
        let stride_ids = ids_l.stride()[0];
        // Output shape = source shape with `dim` replaced by the id count.
        let mut dst_dims = layout.dims().to_vec();
        let src_dim = dst_dims[dim];
        dst_dims[dim] = n_ids;
        let dst_len: usize = dst_dims.iter().product();
        // Decompose the copy into (left, dim, right) blocks: `right_len` is
        // the contiguous run copied per selected index.
        let left_len: usize = dst_dims[..dim].iter().product();
        let right_len: usize = dst_dims[dim + 1..].iter().product();
        let mut dst = vec![T::ZERO; dst_len];
        for left_i in 0..left_len {
            let start_src_idx = left_i * right_len * src_dim;
            let start_dst_idx = left_i * right_len * n_ids;
            for i in 0..n_ids {
                // Shadowing: per-id destination offset within this block.
                let start_dst_idx = start_dst_idx + i * right_len;
                let index = ids[ids_l.start_offset() + stride_ids * i];
                if index == I::max_value() {
                    // Sentinel id: emit zeros for this slice (already zeroed,
                    // the fill keeps the intent explicit).
                    dst[start_dst_idx..start_dst_idx + right_len].fill(T::ZERO);
                } else {
                    let index = index.to_usize();
                    if index >= src_dim {
                        Err(Error::InvalidIndex {
                            index,
                            size: src_dim,
                            op: "index-select",
                        })?
                    }
                    // Shadowing: absolute source offset for this id.
                    let start_src_idx = start_src_idx + index * right_len;
                    dst[start_dst_idx..start_dst_idx + right_len]
                        .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
                }
            }
        }
        Ok(dst)
    }
}
171
172impl<T: NumDType> Storage<T> {
173 pub(crate) fn index_add<I: IntDType>(
174 &self,
175 self_layout: &Layout,
176 ids: &Storage<I>,
177 ids_layout: &Layout,
178 source: &Storage<T>,
179 source_layout: &Layout,
180 dim: usize,
181 ) -> Result<Self> {
182 if !self_layout.is_contiguous() || !source_layout.is_contiguous() {
183 return Err(Error::RequiresContiguous { op: "index-add" }.into());
184 }
185
186 let new_data = Self::do_index_add(
187 self.data(),
188 self_layout,
189 ids.data(),
190 ids_layout,
191 source.data(),
192 dim
193 )?;
194
195 Ok(Storage::new(new_data))
196 }
197
198 fn do_index_add<I: IntDType>(
199 dst_data: &[T],
200 dst_layout: &Layout,
201 ids: &[I],
202 ids_layout: &Layout,
203 src_data: &[T],
204 dim: usize,
205 ) -> Result<Vec<T>> {
206 let mut result = dst_data.to_vec();
211
212 let n_ids = ids_layout.dims()[0];
213 let stride_ids = ids_layout.stride()[0];
214
215 let dst_dims = dst_layout.dims();
216 let src_dim_size = dst_dims[dim]; let left_len: usize = dst_dims[..dim].iter().product();
219 let right_len: usize = dst_dims[dim + 1..].iter().product();
220
221 for left_i in 0..left_len {
225 let start_src_block = left_i * n_ids * right_len;
226 let start_dst_block = left_i * src_dim_size * right_len;
227
228 for i in 0..n_ids {
229 let index_val = ids[ids_layout.start_offset() + stride_ids * i];
231
232 if index_val == I::max_value() {
234 continue;
235 }
236
237 let idx = index_val.to_usize();
238 if idx >= src_dim_size {
239 return Err(Error::InvalidIndex {
240 index: idx,
241 size: src_dim_size,
242 op: "index-add",
243 }.into());
244 }
245
246 let src_offset = start_src_block + i * right_len;
248 let dst_offset = start_dst_block + idx * right_len;
249
250 for k in 0..right_len {
252 let s_val = src_data[src_offset + k];
253 let d_val = result[dst_offset + k];
254 result[dst_offset + k] = d_val + s_val;
255 }
256 }
257 }
258
259 Ok(result)
260 }
261}
262
263impl<T: WithDType> Storage<T> {
264 pub(crate) fn gather<I: IntDType>(
265 &self,
266 self_layout: &Layout,
267 ids: &Storage<I>,
268 ids_layout: &Layout,
269 dim: usize,
270 ) -> Result<Self> {
271 let new_data = Self::do_gather(self.data(), self_layout, ids.data(), ids_layout, dim)?;
272 Ok(Storage::new(new_data))
273 }
274
275 fn do_gather<I: IntDType>(
276 src: &[T],
277 src_layout: &Layout,
278 ids: &[I],
279 ids_layout: &Layout,
280 dim: usize,
281 ) -> Result<Vec<T>> {
282 if !src_layout.is_contiguous() || !ids_layout.is_contiguous() {
285 return Err(Error::RequiresContiguous { op: "gather" }.into());
286 }
287
288 let src_dims = src_layout.dims();
289 let ids_dims = ids_layout.dims();
290
291 if src_dims.len() != ids_dims.len() {
293 return Err(Error::ShapeMismatchBinaryOp {
294 lhs: src_layout.shape().clone(),
295 rhs: ids_layout.shape().clone(),
296 op: "gather"
297 }.into());
298 }
299
300 let dst_len = ids_layout.shape.element_count();
302 let mut dst = vec![T::ZERO; dst_len];
304
305 let left_len: usize = src_dims[..dim].iter().product();
308 let right_len: usize = src_dims[dim + 1..].iter().product();
310
311 let src_dim_size = src_dims[dim];
312 let ids_dim_size = ids_dims[dim];
313
314 for i in 0..left_len {
319 let src_block_start = i * src_dim_size * right_len;
321 let dst_block_start = i * ids_dim_size * right_len;
322
323 for j in 0..ids_dim_size {
324 for k in 0..right_len {
325 let dst_idx = dst_block_start + j * right_len + k;
328
329 let index_val = ids[dst_idx];
330
331 if index_val == I::max_value() {
333 dst[dst_idx] = T::ZERO;
334 continue;
335 }
336
337 let idx = index_val.to_usize();
338 if idx >= src_dim_size {
339 return Err(Error::InvalidIndex {
340 index: idx,
341 size: src_dim_size,
342 op: "gather",
343 }.into());
344 }
345
346 let src_idx = src_block_start + idx * right_len + k;
349
350 dst[dst_idx] = src[src_idx];
351 }
352 }
353 }
354
355 Ok(dst)
356 }
357}
358
359impl<T: NumDType> Storage<T> {
360 pub(crate) fn scatter_add<I: IntDType>(
361 &self,
362 self_layout: &Layout,
363 ids: &Storage<I>,
364 ids_layout: &Layout,
365 source: &Storage<T>,
366 source_layout: &Layout,
367 dim: usize,
368 ) -> Result<Self> {
369 if !self_layout.is_contiguous() || !ids_layout.is_contiguous() || !source_layout.is_contiguous() {
370 return Err(Error::RequiresContiguous { op: "scatter-add" }.into());
371 }
372
373 let new_data = Self::do_scatter_add(
375 self.data(),
376 self_layout,
377 ids.data(),
378 ids_layout,
379 source.data(),
380 dim
381 )?;
382
383 Ok(Storage::new(new_data))
384 }
385
386 fn do_scatter_add<I: IntDType>(
387 dst: &[T], dst_layout: &Layout,
389 ids: &[I], ids_layout: &Layout,
391 src: &[T], dim: usize,
393 ) -> Result<Vec<T>> {
394 let mut result = dst.to_vec();
396
397 let dst_dims = dst_layout.dims();
398 let src_dims = ids_layout.dims(); if dst_dims.len() != src_dims.len() {
402 return Err(Error::ShapeMismatchBinaryOp {
403 lhs: dst_layout.shape().clone(),
404 rhs: ids_layout.shape().clone(),
405 op: "scatter-add"
406 }.into());
407 }
408
409 let left_len: usize = src_dims[..dim].iter().product();
412 let right_len: usize = src_dims[dim + 1..].iter().product();
413
414 let src_dim_size = src_dims[dim]; let dst_dim_size = dst_dims[dim]; for i in 0..left_len {
422 let src_block_start = i * src_dim_size * right_len;
423 let dst_block_start = i * dst_dim_size * right_len;
424
425 for j in 0..src_dim_size {
426 for k in 0..right_len {
427 let linear_idx = src_block_start + j * right_len + k;
429
430 let index_val = ids[linear_idx];
431
432 if index_val == I::max_value() {
434 continue;
435 }
436
437 let idx = index_val.to_usize();
438 if idx >= dst_dim_size {
439 return Err(Error::InvalidIndex {
440 index: idx,
441 size: dst_dim_size,
442 op: "scatter-add",
443 }.into());
444 }
445
446 let dst_idx = dst_block_start + idx * right_len + k;
449
450 result[dst_idx] = result[dst_idx] + src[linear_idx];
452 }
453 }
454 }
455
456 Ok(result)
457 }
458}
459
/// Shared, thread-safe handle to a [`Storage`]; `Clone` produces another
/// handle to the same buffer (via `Arc`), not a copy of the data.
#[derive(Clone)]
pub struct StorageArc<T>(pub(crate) Arc<RwLock<Storage<T>>>);
462
impl<T: WithDType> StorageArc<T> {
    /// Wraps `storage` in a new shared, lock-protected handle.
    pub fn new(storage: Storage<T>) -> Self {
        Self(Arc::new(RwLock::new(storage)))
    }

    /// Acquires the read lock. Panics if the lock is poisoned.
    #[inline]
    pub fn read(&self) -> std::sync::RwLockReadGuard<'_, Storage<T>> {
        self.0.read().unwrap()
    }

    /// Acquires the write lock. Panics if the lock is poisoned.
    #[inline]
    pub fn write(&self) -> std::sync::RwLockWriteGuard<'_, Storage<T>> {
        self.0.write().unwrap()
    }

    /// Bounds-checked read of one element (briefly takes the read lock).
    #[inline]
    pub fn get(&self, index: usize) -> Option<T> {
        self.read().get(index)
    }

    /// Bounds-checked write of one element (briefly takes the write lock).
    #[inline]
    pub fn set(&mut self, index: usize, val: T) -> Option<()> {
        self.write().set(index, val)
    }

    /// Unchecked read; panics when `index` is out of bounds.
    #[inline]
    pub fn get_unchecked(&self, index: usize) -> T {
        self.read().get_unchecked(index)
    }

    /// Unchecked write; panics when `index` is out of bounds.
    /// Note this takes `&self` (interior mutability through the lock).
    #[inline]
    pub fn set_unchecked(&self, index: usize, val: T) {
        self.write().set_unchecked(index, val)
    }

    /// True when both handles point at the same underlying allocation.
    #[inline]
    pub fn ptr_eq(this: &Self, other: &Self) -> bool {
        Arc::ptr_eq(&this.0, &other.0)
    }

    /// Read-only view of the data starting at `start_offset`; the read lock
    /// is held for the lifetime of the returned guard.
    #[inline]
    pub fn get_ref(&self, start_offset: usize) -> StorageRef<'_, T> {
        StorageRef::Guard(std::sync::RwLockReadGuard::map(self.0.read().unwrap(), |s| &s.data()[start_offset..]))
    }

    /// Mutable view of the data starting at `start_offset`; the write lock
    /// is held for the lifetime of the returned guard.
    #[inline]
    pub fn get_mut(&self, start_offset: usize) -> StorageMut<'_, T> {
        StorageMut::Guard(std::sync::RwLockWriteGuard::map(self.0.write().unwrap(), |s| &mut s.data_mut()[start_offset..]))
    }

    /// Raw pointer to the element at `start_offset`.
    ///
    /// NOTE(review): the write lock is released when this function returns,
    /// so any dereference of the pointer is unsynchronized, and the pointer
    /// is invalidated if the underlying `Vec` ever reallocates — confirm
    /// that callers uphold these invariants.
    #[inline]
    pub fn get_ptr(&self, start_offset: usize) -> *mut T {
        let mut s = self.0.write().unwrap();
        let data = &mut s.data_mut()[start_offset..];
        data.as_mut_ptr()
    }
}
520
/// Read-only view into storage data: either a mapped read-lock guard
/// (keeps the `RwLock` held) or a plain borrowed slice.
pub enum StorageRef<'a, T> {
    Guard(std::sync::MappedRwLockReadGuard<'a, [T]>),
    Slice(&'a [T]),
}
525
/// Mutable view into storage data: either a mapped write-lock guard
/// (keeps the `RwLock` held) or a plain mutable slice.
pub enum StorageMut<'a, T> {
    Guard(std::sync::MappedRwLockWriteGuard<'a, [T]>),
    Slice(&'a mut[T]),
}
532
533impl<'a, T: WithDType> StorageRef<'a, T> {
534 pub fn clone(&'a self) -> Self {
535 Self::Slice(&self.data())
536 }
537
538 pub fn slice(&'a self, index: usize) -> Self {
539 Self::Slice(&self.data()[index..])
540 }
541
542 #[inline]
543 pub fn get(&self, index: usize) -> Option<T> {
544 self.data().get(index).copied()
545 }
546
547 #[inline]
548 pub fn get_unchecked(&self, index: usize) -> T {
549 self.data()[index]
550 }
551
552 #[inline]
553 pub fn len(&self) -> usize {
554 self.data().len()
555 }
556
557 pub fn data(&self) -> &[T] {
558 match self {
559 Self::Guard(gurad) => &gurad,
560 Self::Slice(s) => s,
561 }
562 }
563}
564
565impl<'a, T: WithDType> StorageMut<'a, T> {
566 pub fn clone(&'a self) -> StorageRef<'a, T> {
567 StorageRef::Slice(self.data())
568 }
569
570 #[inline]
571 pub fn get(&self, index: usize) -> Option<T> {
572 self.data().get(index).copied()
573 }
574
575 #[inline]
576 pub fn get_unchecked(&self, index: usize) -> T {
577 self.data()[index]
578 }
579
580 #[inline]
581 pub fn set(&mut self, index: usize, val: T) -> Option<()> {
582 if index >= self.len() {
583 None
584 } else {
585 self.set_unchecked(index, val);
586 Some(())
587 }
588 }
589
590 #[inline]
591 pub fn set_unchecked(&mut self, index: usize, val: T) {
592 self.data_mut()[index] = val;
593 }
594
595 #[inline]
596 pub fn len(&self) -> usize {
597 self.data().len()
598 }
599
600 pub fn data(&self) -> &[T] {
601 match self {
602 Self::Guard(gurad) => &gurad,
603 Self::Slice(s) => s,
604 }
605 }
606
607 pub fn data_mut(&mut self) -> &mut [T] {
608 match self {
609 Self::Guard(gurad) => &mut gurad[0..],
610 Self::Slice(s) => &mut s[0..],
611 }
612 }
613}