//! CPU storage backing for `lumen_core` tensors (`lumen_core/storage.rs`).

use std::sync::{Arc, RwLock};
use rand::rng;
use rand_distr::{Distribution, Uniform};
use crate::{Error, IntDType, Result};
use super::{DType, FloatDType, Layout, NumDType, Shape, WithDType};

/// Dense, contiguous CPU buffer holding the raw elements of a tensor.
///
/// Striding/indexing semantics live in `Layout`; this type is only the flat
/// `Vec<T>` plus element-wise accessors.
#[derive(Clone)]
pub struct Storage<T>(Vec<T>);

10impl<T: WithDType> Storage<T> {
11    pub fn zeros(shape: &Shape) -> Self {
12        Self(vec![T::ZERO; shape.element_count()])
13    }
14
15    pub fn ones(shape: &Shape) -> Self {
16        Self(vec![T::ONE; shape.element_count()])
17    }
18
19    pub fn full(value: T, shape: &Shape) -> Self {
20        Self(vec![value; shape.element_count()])
21    }
22
23    pub fn new<D: Into<Vec<T>>>(data: D) -> Self {
24        Self(data.into())
25    }
26
27    #[inline]
28    pub fn data(&self) -> &[T] {
29        &self.0
30    }
31
32    #[inline]
33    pub fn data_mut(&mut self) -> &mut [T] {
34        &mut self.0
35    }
36
37    #[inline]
38    pub fn dtype(&self) -> DType {
39        T::DTYPE
40    }
41
42    #[inline]
43    pub fn copy_data(&self) -> Vec<T> {
44        self.0.clone()
45    }
46
47    #[inline]
48    pub fn get(&self, index: usize) -> Option<T> {
49        self.0.get(index).copied()
50    }
51
52    #[inline]
53    pub fn get_unchecked(&self, index: usize) -> T {
54        self.0[index]
55    }
56
57    #[inline]
58    pub fn set(&mut self, index: usize, value: T) -> Option<()> {
59        if index >= self.len() {
60            None
61        } else {
62            self.0[index] = value;
63            Some(())
64        }
65    }
66
67    #[inline]
68    pub fn set_unchecked(&mut self, index: usize, value: T) {
69        self.0[index] = value;
70    }
71
72    pub fn len(&self) -> usize {
73        self.0.len()
74    }
75
76    pub fn copy(&self, layout: &Layout) -> Self {
77        let output: Vec<_> = layout.storage_indices()
78            .map(|i| self.0[i])
79            .collect();
80        Self(output)
81    }
82
83    pub fn copy_map<F, U>(&self, layout: &Layout, f: F) -> Storage<U> 
84    where 
85        U: WithDType,
86        F: Fn(T) -> U
87    {
88        let output: Vec<_> = layout.storage_indices()
89            .map(|i| f(self.0[i]))
90            .collect();
91        Storage(output)
92    }
93}
94
95impl<T: NumDType> Storage<T> {
96    pub fn rand_uniform(shape: &Shape, min: T, max: T) -> Result<Self> {
97        let elem_count = shape.element_count();
98        let mut rng = rng();
99        let uniform = Uniform::new(min, max).map_err(|e| Error::Rand(e.to_string()))?;
100        let v: Vec<T> = (0..elem_count)
101            .map(|_| uniform.sample(&mut rng))
102            .collect();
103
104        Ok(Self(v))
105    }
106}
107
108impl<F: FloatDType> Storage<F> {
109    pub fn rand_normal(shape: &Shape, mean: F, std: F) -> Result<Self> 
110    {
111        let elem_count = shape.element_count();
112        let v = F::random_normal_vec(elem_count, mean, std)?;
113        Ok(Self(v))
114    }
115}
116
impl<T: WithDType> Storage<T> {
    /// Select slices of `self` along `dim` according to the 1-D index
    /// tensor `ids` (CPU implementation of `index_select`).
    ///
    /// Returns a new contiguous storage shaped like `self` with `dim`
    /// replaced by the length of `ids`.
    pub(crate) fn index_select<I: IntDType>(
        &self,
        self_layout: &Layout,
        ids: &Storage<I>,
        ids_layout: &Layout,
        dim: usize,
    ) -> Result<Self> {
        let vec = Self::do_index_select(ids.data(), ids_layout, dim, self.data(), self_layout)?;
        Ok(Storage::new(vec))
    }

    /// Core of `index_select`.
    ///
    /// `src`/`layout` must be contiguous (returns `Error::RequiresContiguous`
    /// otherwise) and `ids`/`ids_l` must be rank-1.
    ///
    /// An id equal to `I::max_value()` acts as a padding sentinel: the
    /// corresponding output slice is filled with `T::ZERO` instead of reading
    /// from `src`. Any other id `>= dims[dim]` is an `Error::InvalidIndex`.
    fn do_index_select<I: IntDType>(ids: &[I], ids_l: &Layout, dim: usize, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        if !layout.is_contiguous() {
            Err(Error::RequiresContiguous { op: "index-select" })?
        }
        // Restrict to the logical elements of the (contiguous) source view.
        let src = &src[layout.start_offset..layout.start_offset+layout.shape.element_count()];
        let n_ids = ids_l.dims();
        // NOTE(review): non-1-D ids panic here instead of returning an error.
        assert!(n_ids.len() == 1);
        let n_ids = n_ids[0];
        let stride_ids = ids_l.stride()[0];
        // Output shape = source shape with `dim` replaced by the id count.
        let mut dst_dims = layout.dims().to_vec();
        let src_dim = dst_dims[dim];
        dst_dims[dim] = n_ids;
        let dst_len: usize = dst_dims.iter().product();
        // View the tensor as [left, dim, right] segments.
        let left_len: usize = dst_dims[..dim].iter().product();
        let right_len: usize = dst_dims[dim + 1..].iter().product();
        let mut dst = vec![T::ZERO; dst_len];
        for left_i in 0..left_len {
            let start_src_idx = left_i * right_len * src_dim;
            let start_dst_idx = left_i * right_len * n_ids;
            for i in 0..n_ids {
                let start_dst_idx = start_dst_idx + i * right_len;
                let index = ids[ids_l.start_offset() + stride_ids * i];
                if index == I::max_value() {
                    // Padding sentinel: emit zeros for this slice.
                    dst[start_dst_idx..start_dst_idx + right_len].fill(T::ZERO);
                } else {
                    let index = index.to_usize();
                    if index >= src_dim {
                        Err(Error::InvalidIndex {
                            index,
                            size: src_dim,
                            op: "index-select",
                        })?
                    }
                    // Copy the whole contiguous `right_len` run at once.
                    let start_src_idx = start_src_idx + index * right_len;
                    dst[start_dst_idx..start_dst_idx + right_len]
                        .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
                }
            }
        }
        Ok(dst)
    }
}

impl<T: NumDType> Storage<T> {
    /// Accumulate `source` into a copy of `self` along `dim`, routing rows
    /// through the 1-D index tensor `ids`:
    /// `out[.., ids[i], ..] += source[.., i, ..]`.
    ///
    /// Both `self` and `source` must be contiguous.
    pub(crate) fn index_add<I: IntDType>(
        &self,
        self_layout: &Layout,
        ids: &Storage<I>,
        ids_layout: &Layout,
        source: &Storage<T>,
        source_layout: &Layout,
        dim: usize,
    ) -> Result<Self> {
        if !self_layout.is_contiguous() || !source_layout.is_contiguous() {
            return Err(Error::RequiresContiguous { op: "index-add" }.into());
        }

        let new_data = Self::do_index_add(
            self.data(),
            self_layout,
            ids.data(),
            ids_layout,
            source.data(),
            dim
        )?;

        Ok(Storage::new(new_data))
    }

    /// Core of `index_add`; returns a fresh buffer rather than mutating.
    fn do_index_add<I: IntDType>(
        dst_data: &[T],
        dst_layout: &Layout,
        ids: &[I],
        ids_layout: &Layout,
        src_data: &[T],
        dim: usize,
    ) -> Result<Vec<T>> {
        // 1. Clone dst_data as the result buffer, since we accumulate into it.
        // Note: for large tensors this copy has a cost. If the system supports
        // in-place mutation and the refcount is known to be 1, the copy could
        // be skipped; we clone here for safety.
        let mut result = dst_data.to_vec();

        let n_ids = ids_layout.dims()[0];
        let stride_ids = ids_layout.stride()[0];

        let dst_dims = dst_layout.dims();
        let src_dim_size = dst_dims[dim]; // length of the original tensor along `dim`

        let left_len: usize = dst_dims[..dim].iter().product();
        let right_len: usize = dst_dims[dim + 1..].iter().product();

        // src (grad) is shaped [left, n_ids, right];
        // dst (self) is shaped [left, src_dim_size, right].

        for left_i in 0..left_len {
            let start_src_block = left_i * n_ids * right_len;
            let start_dst_block = left_i * src_dim_size * right_len;

            for i in 0..n_ids {
                // Fetch the index value.
                let index_val = ids[ids_layout.start_offset() + stride_ids * i];

                // Handle the mask (padding index): skip sentinel ids.
                if index_val == I::max_value() {
                    continue;
                }

                let idx = index_val.to_usize();
                if idx >= src_dim_size {
                    return Err(Error::InvalidIndex {
                        index: idx,
                        size: src_dim_size,
                        op: "index-add",
                    }.into());
                }

                // Offsets of the two `right_len` runs.
                let src_offset = start_src_block + i * right_len;
                let dst_offset = start_dst_block + idx * right_len;

                // Accumulate: dst[idx] += src[i]
                for k in 0..right_len {
                    let s_val = src_data[src_offset + k];
                    let d_val = result[dst_offset + k];
                    result[dst_offset + k] = d_val + s_val;
                }
            }
        }

        Ok(result)
    }
}

impl<T: WithDType> Storage<T> {
    /// Gather values from `self` along `dim` using an index tensor of the
    /// same rank: `out[.., j, ..] = self[.., ids[.., j, ..], ..]`.
    pub(crate) fn gather<I: IntDType>(
        &self,
        self_layout: &Layout,
        ids: &Storage<I>,
        ids_layout: &Layout,
        dim: usize,
    ) -> Result<Self> {
        let new_data = Self::do_gather(self.data(), self_layout, ids.data(), ids_layout, dim)?;
        Ok(Storage::new(new_data))
    }

    /// Core of `gather`.
    fn do_gather<I: IntDType>(
        src: &[T],
        src_layout: &Layout,
        ids: &[I],
        ids_layout: &Layout,
        dim: usize,
    ) -> Result<Vec<T>> {
        // 1. Contiguity check: to keep the multi-dimensional offset math
        // simple, both src and ids must be contiguous. Supporting
        // non-contiguous inputs would require stride-based physical offsets
        // and add significant complexity.
        if !src_layout.is_contiguous() || !ids_layout.is_contiguous() {
            return Err(Error::RequiresContiguous { op: "gather" }.into());
        }

        let src_dims = src_layout.dims();
        let ids_dims = ids_layout.dims();

        // 2. Rank check: src and ids must have the same number of dimensions.
        if src_dims.len() != ids_dims.len() {
             return Err(Error::ShapeMismatchBinaryOp {
                lhs: src_layout.shape().clone(),
                rhs: ids_layout.shape().clone(),
                op: "gather"
             }.into());
        }

        // 3. Prepare the result buffer.
        let dst_len = ids_layout.shape.element_count();
        // As in index_select, pre-fill with the default value (e.g. 0/false)
        // so padding-mask entries are handled naturally.
        let mut dst = vec![T::ZERO; dst_len];

        // 4. Split the shape into three segments: [Left, Dim, Right].
        // Left: product of all dimensions before `dim`.
        let left_len: usize = src_dims[..dim].iter().product();
        // Right: product of all dimensions after `dim`.
        let right_len: usize = src_dims[dim + 1..].iter().product();

        let src_dim_size = src_dims[dim];
        let ids_dim_size = ids_dims[dim];

        // 5. Perform the gather.
        // Logically src is [left, src_dim, right] and ids is
        // [left, ids_dim, right]; we walk every position of ids (= dst).

        for i in 0..left_len {
            // Start offsets of the current Left block in src and dst.
            let src_block_start = i * src_dim_size * right_len;
            let dst_block_start = i * ids_dim_size * right_len;

            for j in 0..ids_dim_size {
                for k in 0..right_len {
                    // Linear index into ids and dst (contiguous memory):
                    // dst_idx = (i * ids_dim * right) + (j * right) + k
                    let dst_idx = dst_block_start + j * right_len + k;

                    let index_val = ids[dst_idx];

                    // Handle the mask (padding index): leave a zero in place.
                    if index_val == I::max_value() {
                        dst[dst_idx] = T::ZERO;
                        continue;
                    }

                    let idx = index_val.to_usize();
                    if idx >= src_dim_size {
                        return Err(Error::InvalidIndex {
                            index: idx,
                            size: src_dim_size,
                            op: "gather",
                        }.into());
                    }

                    // Linear index into src: keep Left(i) and Right(k), and
                    // replace the Dim coordinate with the idx read from ids.
                    let src_idx = src_block_start + idx * right_len + k;

                    dst[dst_idx] = src[src_idx];
                }
            }
        }

        Ok(dst)
    }
}

impl<T: NumDType> Storage<T> {
    /// Scatter-accumulate `source` into a copy of `self` along `dim`:
    /// `out[.., ids[.., j, ..], ..] += source[.., j, ..]`.
    ///
    /// `ids` and `source` share a shape; all three layouts must be
    /// contiguous.
    pub(crate) fn scatter_add<I: IntDType>(
        &self,
        self_layout: &Layout,
        ids: &Storage<I>,
        ids_layout: &Layout,
        source: &Storage<T>,
        source_layout: &Layout,
        dim: usize,
    ) -> Result<Self> {
        if !self_layout.is_contiguous() || !ids_layout.is_contiguous() || !source_layout.is_contiguous() {
            return Err(Error::RequiresContiguous { op: "scatter-add" }.into());
        }

        // Here `self` is the destination (arg_grad) and `source` is the gradient.
        let new_data = Self::do_scatter_add(
            self.data(),
            self_layout,
            ids.data(),
            ids_layout,
            source.data(),
            dim
        )?;

        Ok(Storage::new(new_data))
    }

    /// Core of `scatter_add`; returns a fresh buffer rather than mutating.
    fn do_scatter_add<I: IntDType>(
        dst: &[T],          // current data of arg_grad (typically a zero buffer)
        dst_layout: &Layout,
        ids: &[I],          // indices
        ids_layout: &Layout,
        src: &[T],          // incoming gradient
        dim: usize,
    ) -> Result<Vec<T>> {
        // 1. Copy dst so we can accumulate (simulated in-place update).
        let mut result = dst.to_vec();

        let dst_dims = dst_layout.dims();
        let src_dims = ids_layout.dims(); // src and ids are expected to share a shape

        // Rank check: dst and ids must have the same number of dimensions.
        if dst_dims.len() != src_dims.len() {
             return Err(Error::ShapeMismatchBinaryOp {
                lhs: dst_layout.shape().clone(),
                rhs: ids_layout.shape().clone(),
                op: "scatter-add"
             }.into());
        }

        // 2. Split the shape into [Left, Dim, Right] segments; we iterate
        // over source (grad) and ids, which share a shape.
        let left_len: usize = src_dims[..dim].iter().product();
        let right_len: usize = src_dims[dim + 1..].iter().product();

        let src_dim_size = src_dims[dim]; // length of ids along `dim`
        let dst_dim_size = dst_dims[dim]; // length of dst (arg) along `dim`

        // 3. Perform the scatter-add.
        // For each element of src(grad), read its target index idx from ids,
        // then result[left, idx, right] += src[left, i, right].

        for i in 0..left_len {
            let src_block_start = i * src_dim_size * right_len;
            let dst_block_start = i * dst_dim_size * right_len;

            for j in 0..src_dim_size {
                for k in 0..right_len {
                    // src and ids are walked in lockstep.
                    let linear_idx = src_block_start + j * right_len + k;

                    let index_val = ids[linear_idx];

                    // Handle the mask (padding index): skip sentinel ids.
                    if index_val == I::max_value() {
                        continue;
                    }

                    let idx = index_val.to_usize();
                    if idx >= dst_dim_size {
                        return Err(Error::InvalidIndex {
                            index: idx,
                            size: dst_dim_size,
                            op: "scatter-add",
                        }.into());
                    }

                    // Target position: keep Left(i) and Right(k), replace the
                    // Dim coordinate with the idx read from ids.
                    let dst_idx = dst_block_start + idx * right_len + k;

                    // Accumulate the gradient.
                    result[dst_idx] = result[dst_idx] + src[linear_idx];
                }
            }
        }

        Ok(result)
    }
}

/// Shared, thread-safe handle to a [`Storage`]: an `Arc<RwLock<..>>` so that
/// multiple tensors/views can alias the same underlying buffer. Cloning the
/// handle is cheap and shares the buffer.
#[derive(Clone)]
pub struct StorageArc<T>(pub(crate) Arc<RwLock<Storage<T>>>);

impl<T: WithDType> StorageArc<T> {
    /// Wrap a storage in a fresh shared handle.
    pub fn new(storage: Storage<T>) -> Self {
        Self(Arc::new(RwLock::new(storage)))
    }

    /// Acquire a shared read guard. Panics if the lock is poisoned.
    #[inline]
    pub fn read(&self) -> std::sync::RwLockReadGuard<'_, Storage<T>> {
        self.0.read().unwrap()
    }

    /// Acquire an exclusive write guard. Panics if the lock is poisoned.
    #[inline]
    pub fn write(&self) -> std::sync::RwLockWriteGuard<'_, Storage<T>> {
        self.0.write().unwrap()
    }

    /// Bounds-checked read of a single element (takes the read lock).
    #[inline]
    pub fn get(&self, index: usize) -> Option<T> {
        self.read().get(index)
    }

    /// Bounds-checked write of a single element (takes the write lock).
    #[inline]
    pub fn set(&mut self, index: usize, val: T) -> Option<()> {
        self.write().set(index, val)
    }

    /// Panicking read of a single element (takes the read lock).
    #[inline]
    pub fn get_unchecked(&self, index: usize) -> T {
        self.read().get_unchecked(index)
    }

    /// Panicking write of a single element. Note this only needs `&self`
    /// because the interior `RwLock` provides the mutability.
    #[inline]
    pub fn set_unchecked(&self, index: usize, val: T) {
        self.write().set_unchecked(index, val)
    }

    /// True when both handles point at the same underlying buffer.
    #[inline]
    pub fn ptr_eq(this: &Self, other: &Self) -> bool {
        Arc::ptr_eq(&this.0, &other.0)
    }

    /// Read-only view of the buffer starting at `start_offset`; the read
    /// lock is held for the lifetime of the returned guard.
    // NOTE(review): relies on the mapped RwLock guard API
    // (`RwLockReadGuard::map`) — confirm the toolchain this crate builds
    // with provides it.
    #[inline]
    pub fn get_ref(&self, start_offset: usize) -> StorageRef<'_, T> {
        StorageRef::Guard(std::sync::RwLockReadGuard::map(self.0.read().unwrap(), |s| &s.data()[start_offset..]))
    }

    /// Mutable view of the buffer starting at `start_offset`; the write
    /// lock is held for the lifetime of the returned guard.
    #[inline]
    pub fn get_mut(&self, start_offset: usize) -> StorageMut<'_, T> {
        StorageMut::Guard(std::sync::RwLockWriteGuard::map(self.0.write().unwrap(), |s| &mut s.data_mut()[start_offset..]))
    }

    /// Raw pointer to the element at `start_offset`.
    // NOTE(review): the write guard is released when this returns, so the
    // pointer escapes all locking — synchronization and aliasing become the
    // caller's responsibility; confirm every caller upholds exclusivity.
    #[inline]
    pub fn get_ptr(&self, start_offset: usize) -> *mut T {
        let mut s = self.0.write().unwrap();
        let data = &mut s.data_mut()[start_offset..];
        data.as_mut_ptr()
    }
}

/// Read-only view into a storage, starting at some offset: either a mapped
/// read-lock guard (keeps the lock held) or a plain borrowed slice.
pub enum StorageRef<'a, T> {
    Guard(std::sync::MappedRwLockReadGuard<'a, [T]>),
    Slice(&'a [T]),
}

// pub struct StorageMut<'a, T>(std::sync::MappedRwLockWriteGuard<'a, [T]>);

/// Mutable view into a storage, starting at some offset: either a mapped
/// write-lock guard (keeps the lock held) or a plain mutable slice.
pub enum StorageMut<'a, T> {
    Guard(std::sync::MappedRwLockWriteGuard<'a, [T]>),
    Slice(&'a mut[T]),
}

533impl<'a, T: WithDType> StorageRef<'a, T> {
534    pub fn clone(&'a self) -> Self {
535        Self::Slice(&self.data())
536    }
537
538    pub fn slice(&'a self, index: usize) -> Self {
539        Self::Slice(&self.data()[index..])
540    }
541
542    #[inline]
543    pub fn get(&self, index: usize) -> Option<T> {
544        self.data().get(index).copied()
545    }
546
547    #[inline]
548    pub fn get_unchecked(&self, index: usize) -> T {
549        self.data()[index]
550    }
551
552    #[inline]
553    pub fn len(&self) -> usize {
554        self.data().len()
555    }
556
557    pub fn data(&self) -> &[T] {
558        match self {
559            Self::Guard(gurad) => &gurad,
560            Self::Slice(s) => s,
561        }
562    }
563}
564
565impl<'a, T: WithDType> StorageMut<'a, T> {
566    pub fn clone(&'a self) -> StorageRef<'a, T> {
567        StorageRef::Slice(self.data())
568    }
569
570    #[inline]
571    pub fn get(&self, index: usize) -> Option<T> {
572        self.data().get(index).copied()
573    }
574
575    #[inline]
576    pub fn get_unchecked(&self, index: usize) -> T {
577        self.data()[index]
578    }
579
580    #[inline]
581    pub fn set(&mut self, index: usize, val: T) -> Option<()> {
582        if index >= self.len() {
583            None
584        } else {
585            self.set_unchecked(index, val);
586            Some(())
587        }
588    }
589
590    #[inline]
591    pub fn set_unchecked(&mut self, index: usize, val: T) {
592        self.data_mut()[index] = val;
593    }
594
595    #[inline]
596    pub fn len(&self) -> usize {
597        self.data().len()
598    }
599
600    pub fn data(&self) -> &[T] {
601        match self {
602            Self::Guard(gurad) => &gurad,
603            Self::Slice(s) => s,
604        }
605    }
606
607    pub fn data_mut(&mut self) -> &mut [T] {
608        match self {
609            Self::Guard(gurad) => &mut gurad[0..],
610            Self::Slice(s) => &mut s[0..],
611        }
612    }
613}