rstsr_core/tensor/assignment.rs

1use crate::prelude_dev::*;
2
3/* #region assign */
4
5pub trait TensorAssignAPI<TRB> {
6    fn assign_f(a: &mut Self, b: TRB) -> Result<()>;
7    fn assign(a: &mut Self, b: TRB) {
8        Self::assign_f(a, b).rstsr_unwrap()
9    }
10}
11
12pub fn assign_f<TRA, TRB>(a: &mut TRA, b: TRB) -> Result<()>
13where
14    TRA: TensorAssignAPI<TRB>,
15{
16    TRA::assign_f(a, b)
17}
18
19pub fn assign<TRA, TRB>(a: &mut TRA, b: TRB)
20where
21    TRA: TensorAssignAPI<TRB>,
22{
23    TRA::assign(a, b)
24}
25
26impl<RA, DA, RB, DB, TA, TB, B> TensorAssignAPI<TensorAny<RB, TB, B, DB>> for TensorAny<RA, TA, B, DA>
27where
28    RA: DataMutAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
29    RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
30    DA: DimAPI,
31    DB: DimAPI,
32    B: DeviceAPI<TA> + DeviceAPI<TB> + OpAssignAPI<TA, DA, TB>,
33{
34    fn assign_f(a: &mut Self, b: TensorAny<RB, TB, B, DB>) -> Result<()> {
35        // get tensor views
36        let mut a = a.view_mut();
37        let b = b.view();
38        // check device
39        rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
40        let device = a.device().clone();
41        // check layout
42        rstsr_assert!(!a.layout().is_broadcasted(), InvalidLayout, "cannot assign to broadcasted tensor")?;
43        let la = a.layout().to_dim::<IxD>()?;
44        let lb = b.layout().to_dim::<IxD>()?;
45        let default_order = a.device().default_order();
46        let (la_b, lb_b) = broadcast_layout_to_first(&la, &lb, default_order)?;
47        let la_b = la_b.into_dim::<DA>()?;
48        let lb_b = lb_b.into_dim::<DA>()?;
49        // assign
50        device.assign(a.raw_mut(), &la_b, b.raw(), &lb_b)
51    }
52}
53
54impl<RA, DA, RB, DB, TA, TB, B> TensorAssignAPI<&TensorAny<RB, TB, B, DB>> for TensorAny<RA, TA, B, DA>
55where
56    RA: DataMutAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
57    RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
58    DA: DimAPI,
59    DB: DimAPI,
60    B: DeviceAPI<TA> + DeviceAPI<TB> + OpAssignAPI<TA, DA, TB>,
61{
62    fn assign_f(a: &mut Self, b: &TensorAny<RB, TB, B, DB>) -> Result<()> {
63        TensorAssignAPI::assign_f(a, b.view())
64    }
65}
66
67impl<S, D> TensorBase<S, D>
68where
69    D: DimAPI,
70{
71    pub fn assign_f<TRB>(&mut self, b: TRB) -> Result<()>
72    where
73        Self: TensorAssignAPI<TRB>,
74    {
75        assign_f(self, b)
76    }
77
78    pub fn assign<TRB>(&mut self, b: TRB)
79    where
80        Self: TensorAssignAPI<TRB>,
81    {
82        assign(self, b)
83    }
84}
85
86/* #endregion */
87
88/* #region fill */
89
90pub trait TensorFillAPI<T> {
91    fn fill_f(a: &mut Self, b: T) -> Result<()>;
92    fn fill(a: &mut Self, b: T) {
93        Self::fill_f(a, b).rstsr_unwrap()
94    }
95}
96
97pub fn fill_f<TRA, T>(a: &mut TRA, b: T) -> Result<()>
98where
99    TRA: TensorFillAPI<T>,
100{
101    TRA::fill_f(a, b)
102}
103
104pub fn fill<TRA, T>(a: &mut TRA, b: T)
105where
106    TRA: TensorFillAPI<T>,
107{
108    TRA::fill(a, b)
109}
110
111impl<RA, DA, TA, TB, B> TensorFillAPI<TB> for TensorAny<RA, TA, B, DA>
112where
113    RA: DataMutAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
114    DA: DimAPI,
115    B: DeviceAPI<TA> + OpAssignAPI<TA, DA, TB>,
116{
117    fn fill_f(a: &mut Self, b: TB) -> Result<()> {
118        // check layout
119        rstsr_assert!(!a.layout().is_broadcasted(), InvalidLayout, "cannot fill broadcasted tensor")?;
120        let la = a.layout().clone();
121        let device = a.device().clone();
122        device.fill(a.raw_mut(), &la, b)
123    }
124}
125
126impl<S, D> TensorBase<S, D>
127where
128    D: DimAPI,
129{
130    pub fn fill_f<T>(&mut self, b: T) -> Result<()>
131    where
132        Self: TensorFillAPI<T>,
133    {
134        fill_f(self, b)
135    }
136
137    pub fn fill<T>(&mut self, b: T)
138    where
139        Self: TensorFillAPI<T>,
140    {
141        fill(self, b)
142    }
143}
144
145/* #endregion */
146
#[cfg(test)]
mod tests {
    use super::*;

    // Assigning an i32 tensor into an f32 tensor, then filling it with an i32
    // scalar, must store the values converted into the f32 destination.
    #[test]
    fn test_assign_with_cast() {
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(RowMajor);

        let mut dst: Tensor<f32, _> = zeros(([2, 3], &device));
        let src = arange((6i32, &device)).into_shape((2, 3));

        dst.assign(&src);
        let expected: Vec<f32> = (0..6).map(|v| v as f32).collect();
        assert_eq!(dst.raw(), &expected);

        dst.fill(10i32);
        assert_eq!(dst.raw(), &vec![10.0f32; 6]);
    }

    // Same scenario as above, on the faer-backed device.
    #[test]
    #[cfg(feature = "faer")]
    fn test_assign_with_cast_faer() {
        let mut device = DeviceFaer::default();
        device.set_default_order(RowMajor);

        let mut dst: Tensor<f32, _> = zeros(([2, 3], &device));
        let src = arange((6i32, &device)).into_shape((2, 3));

        dst.assign(&src);
        let expected: Vec<f32> = (0..6).map(|v| v as f32).collect();
        assert_eq!(dst.raw(), &expected);

        dst.fill(10i32);
        assert_eq!(dst.raw(), &vec![10.0f32; 6]);
    }
}