1use crate::prelude_dev::*;
2
3pub trait TensorAssignAPI<TRB> {
6 fn assign_f(a: &mut Self, b: TRB) -> Result<()>;
7 fn assign(a: &mut Self, b: TRB) {
8 Self::assign_f(a, b).rstsr_unwrap()
9 }
10}
11
12pub fn assign_f<TRA, TRB>(a: &mut TRA, b: TRB) -> Result<()>
13where
14 TRA: TensorAssignAPI<TRB>,
15{
16 TRA::assign_f(a, b)
17}
18
19pub fn assign<TRA, TRB>(a: &mut TRA, b: TRB)
20where
21 TRA: TensorAssignAPI<TRB>,
22{
23 TRA::assign(a, b)
24}
25
26impl<RA, DA, RB, DB, TA, TB, B> TensorAssignAPI<TensorAny<RB, TB, B, DB>> for TensorAny<RA, TA, B, DA>
27where
28 RA: DataMutAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
29 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
30 DA: DimAPI,
31 DB: DimAPI,
32 B: DeviceAPI<TA> + DeviceAPI<TB> + OpAssignAPI<TA, DA, TB>,
33{
34 fn assign_f(a: &mut Self, b: TensorAny<RB, TB, B, DB>) -> Result<()> {
35 let mut a = a.view_mut();
37 let b = b.view();
38 rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
40 let device = a.device().clone();
41 rstsr_assert!(!a.layout().is_broadcasted(), InvalidLayout, "cannot assign to broadcasted tensor")?;
43 let la = a.layout().to_dim::<IxD>()?;
44 let lb = b.layout().to_dim::<IxD>()?;
45 let default_order = a.device().default_order();
46 let (la_b, lb_b) = broadcast_layout_to_first(&la, &lb, default_order)?;
47 let la_b = la_b.into_dim::<DA>()?;
48 let lb_b = lb_b.into_dim::<DA>()?;
49 device.assign(a.raw_mut(), &la_b, b.raw(), &lb_b)
51 }
52}
53
54impl<RA, DA, RB, DB, TA, TB, B> TensorAssignAPI<&TensorAny<RB, TB, B, DB>> for TensorAny<RA, TA, B, DA>
55where
56 RA: DataMutAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
57 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
58 DA: DimAPI,
59 DB: DimAPI,
60 B: DeviceAPI<TA> + DeviceAPI<TB> + OpAssignAPI<TA, DA, TB>,
61{
62 fn assign_f(a: &mut Self, b: &TensorAny<RB, TB, B, DB>) -> Result<()> {
63 TensorAssignAPI::assign_f(a, b.view())
64 }
65}
66
67impl<S, D> TensorBase<S, D>
68where
69 D: DimAPI,
70{
71 pub fn assign_f<TRB>(&mut self, b: TRB) -> Result<()>
72 where
73 Self: TensorAssignAPI<TRB>,
74 {
75 assign_f(self, b)
76 }
77
78 pub fn assign<TRB>(&mut self, b: TRB)
79 where
80 Self: TensorAssignAPI<TRB>,
81 {
82 assign(self, b)
83 }
84}
85
86pub trait TensorFillAPI<T> {
91 fn fill_f(a: &mut Self, b: T) -> Result<()>;
92 fn fill(a: &mut Self, b: T) {
93 Self::fill_f(a, b).rstsr_unwrap()
94 }
95}
96
97pub fn fill_f<TRA, T>(a: &mut TRA, b: T) -> Result<()>
98where
99 TRA: TensorFillAPI<T>,
100{
101 TRA::fill_f(a, b)
102}
103
104pub fn fill<TRA, T>(a: &mut TRA, b: T)
105where
106 TRA: TensorFillAPI<T>,
107{
108 TRA::fill(a, b)
109}
110
111impl<RA, DA, TA, TB, B> TensorFillAPI<TB> for TensorAny<RA, TA, B, DA>
112where
113 RA: DataMutAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
114 DA: DimAPI,
115 B: DeviceAPI<TA> + OpAssignAPI<TA, DA, TB>,
116{
117 fn fill_f(a: &mut Self, b: TB) -> Result<()> {
118 rstsr_assert!(!a.layout().is_broadcasted(), InvalidLayout, "cannot fill broadcasted tensor")?;
120 let la = a.layout().clone();
121 let device = a.device().clone();
122 device.fill(a.raw_mut(), &la, b)
123 }
124}
125
126impl<S, D> TensorBase<S, D>
127where
128 D: DimAPI,
129{
130 pub fn fill_f<T>(&mut self, b: T) -> Result<()>
131 where
132 Self: TensorFillAPI<T>,
133 {
134 fill_f(self, b)
135 }
136
137 pub fn fill<T>(&mut self, b: T)
138 where
139 Self: TensorFillAPI<T>,
140 {
141 fill(self, b)
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Assigning an `i32` tensor and filling with an `i32` scalar must both
    /// land as `f32` values in the destination (serial CPU device).
    #[test]
    fn test_assign_with_cast() {
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(RowMajor);

        let mut target: Tensor<f32, _> = zeros(([2, 3], &device));
        let source = arange((6i32, &device)).into_shape((2, 3));
        target.assign(&source);
        let expected_assigned = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(target.raw(), &expected_assigned);

        target.fill(10i32);
        let expected_filled = vec![10.0f32; 6];
        assert_eq!(target.raw(), &expected_filled);
    }

    /// Same scenario as `test_assign_with_cast`, run on the Faer device.
    #[test]
    #[cfg(feature = "faer")]
    fn test_assign_with_cast_faer() {
        let mut device = DeviceFaer::default();
        device.set_default_order(RowMajor);

        let mut target: Tensor<f32, _> = zeros(([2, 3], &device));
        let source = arange((6i32, &device)).into_shape((2, 3));
        target.assign(&source);
        let expected_assigned = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(target.raw(), &expected_assigned);

        target.fill(10i32);
        let expected_filled = vec![10.0f32; 6];
        assert_eq!(target.raw(), &expected_filled);
    }
}