1use crate::prelude_dev::*;
2use core::mem::transmute;
3
4pub fn vecdot<TA, TB, DA, DB, B>(
123 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
124 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
125 axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
126) -> Tensor<TA::Output, B, IxD>
127where
128 TA: Mul<TB>,
129 DA: DimAPI,
130 DB: DimAPI,
131 B: DeviceVecdotAPI<TA, TB, TA::Output, DA, DB, IxD>
132 + DeviceAPI<TA>
133 + DeviceAPI<TB>
134 + DeviceAPI<TA::Output>
135 + DeviceCreationAnyAPI<TA::Output>,
136{
137 vecdot_f(a, b, axes_pair).rstsr_unwrap()
138}
139
140pub fn vecdot_from<TA, TB, TC, DA, DB, DC, B>(
144 c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
145 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
146 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
147 axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
148) where
149 DA: DimAPI,
150 DB: DimAPI,
151 DC: DimAPI,
152 B: DeviceVecdotAPI<TA, TB, TC, DA, DB, DC> + DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
153{
154 vecdot_from_f(c, a, b, axes_pair).rstsr_unwrap()
155}
156
157pub fn vecdot_f<TA, TB, DA, DB, B>(
161 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
162 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
163 axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
164) -> Result<Tensor<TA::Output, B, IxD>>
165where
166 TA: Mul<TB>,
167 DA: DimAPI,
168 DB: DimAPI,
169 B: DeviceVecdotAPI<TA, TB, TA::Output, DA, DB, IxD>
170 + DeviceAPI<TA>
171 + DeviceAPI<TB>
172 + DeviceAPI<TA::Output>
173 + DeviceCreationAnyAPI<TA::Output>,
174{
175 let (a, b) = (a.view(), b.view());
176
177 let device = a.device().clone();
179 rstsr_assert!(device.same_device(b.device()), DeviceMismatch)?;
180
181 let mut axes_pair = axes_pair.try_into().map_err(Into::into)?;
183 if axes_pair == AxesPairIndex::None {
184 axes_pair = AxesPairIndex::Val(-1);
185 }
186
187 let (axes_a, axes_b) = match axes_pair {
188 AxesPairIndex::None => unreachable!("already handled above"),
189 AxesPairIndex::Val(axis) => {
190 if axis < 0 {
191 rstsr_pattern!(
192 axis,
193 -(a.ndim().min(b.ndim()) as isize)..=-1,
194 InvalidValue,
195 "axis should be [-N, -1] where N is min(a.ndim, b.ndim)"
196 )?;
197 let axis_a = axis + a.ndim() as isize;
198 let axis_b = axis + b.ndim() as isize;
199 (vec![axis_a], vec![axis_b])
200 } else {
201 rstsr_pattern!(
202 axis,
203 0..(a.ndim().min(b.ndim()) as isize),
204 InvalidValue,
205 "axis should be [0, N) where N is min(a.ndim, b.ndim)"
206 )?;
207 (vec![axis], vec![axis])
208 }
209 },
210 AxesPairIndex::Pair(axes_a, axes_b) => {
211 let axes_a = normalize_axes_index(axes_a, a.ndim(), false, false)?;
212 let axes_b = normalize_axes_index(axes_b, b.ndim(), false, false)?;
213 rstsr_assert_eq!(
214 axes_a.len(),
215 axes_b.len(),
216 InvalidValue,
217 "axes_a and axes_b should have the same length"
218 )?;
219 (axes_a, axes_b)
220 },
221 };
222
223 let (las, lam) = a.layout().dim_split_axes(&axes_a)?;
224 let (lbs, lbm) = b.layout().dim_split_axes(&axes_b)?;
225
226 rstsr_assert_eq!(
227 las.shape(),
228 lbs.shape(),
229 InvalidLayout,
230 "the dimensions of a and b along the contracted axis should be the same"
231 )?;
232
233 let default_order = a.device().default_order();
234 let (lam_b, lbm_b) = broadcast_layout(&lam, &lbm, default_order)?;
235 let layout_c = match TensorIterOrder::default() {
237 TensorIterOrder::C => lam_b.shape().c(),
238 TensorIterOrder::F => lam_b.shape().f(),
239 _ => get_layout_for_binary_op(&lam_b, &lbm_b, default_order)?,
240 };
241 let mut storage_c = device.uninit_impl(layout_c.bounds_index()?.1)?;
242 device.vecdot(storage_c.raw_mut(), &layout_c, a.raw(), a.layout(), b.raw(), b.layout(), &axes_a, &axes_b)?;
243 unsafe { Tensor::new_f(B::assume_init_impl(storage_c)?, layout_c) }
244}
245
246pub fn vecdot_from_f<TA, TB, TC, DA, DB, DC, B>(
250 mut c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
251 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
252 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
253 axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
254) -> Result<()>
255where
256 DA: DimAPI,
257 DB: DimAPI,
258 DC: DimAPI,
259 B: DeviceVecdotAPI<TA, TB, TC, DA, DB, DC> + DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
260{
261 let (a, b, mut c) = (a.view(), b.view(), c.view_mut());
262
263 let device = c.device().clone();
265 rstsr_assert!(device.same_device(a.device()), DeviceMismatch)?;
266 rstsr_assert!(device.same_device(b.device()), DeviceMismatch)?;
267
268 let mut axes_pair = axes_pair.try_into().map_err(Into::into)?;
270 if axes_pair == AxesPairIndex::None {
271 axes_pair = AxesPairIndex::Val(-1);
272 }
273
274 let (axes_a, axes_b) = match axes_pair {
275 AxesPairIndex::None => unreachable!("already handled above"),
276 AxesPairIndex::Val(axis) => {
277 if axis < 0 {
278 rstsr_pattern!(
279 axis,
280 -(a.ndim().min(b.ndim()) as isize)..=-1,
281 InvalidValue,
282 "axis should be [-N, -1] where N is min(a.ndim, b.ndim)"
283 )?;
284 let axis_a = axis + a.ndim() as isize;
285 let axis_b = axis + b.ndim() as isize;
286 (vec![axis_a], vec![axis_b])
287 } else {
288 rstsr_pattern!(
289 axis,
290 0..(a.ndim().min(b.ndim()) as isize),
291 InvalidValue,
292 "axis should be [0, N) where N is min(a.ndim, b.ndim)"
293 )?;
294 (vec![axis], vec![axis])
295 }
296 },
297 AxesPairIndex::Pair(axes_a, axes_b) => {
298 let axes_a = normalize_axes_index(axes_a, a.ndim(), false, false)?;
299 let axes_b = normalize_axes_index(axes_b, b.ndim(), false, false)?;
300 rstsr_assert_eq!(
301 axes_a.len(),
302 axes_b.len(),
303 InvalidValue,
304 "axes_a and axes_b should have the same length"
305 )?;
306 (axes_a, axes_b)
307 },
308 };
309
310 let (las, lam) = a.layout().dim_split_axes(&axes_a)?;
311 let (lbs, lbm) = b.layout().dim_split_axes(&axes_b)?;
312
313 rstsr_assert_eq!(
314 las.shape(),
315 lbs.shape(),
316 InvalidLayout,
317 "the dimensions of a and b along the contracted axis should be the same"
318 )?;
319
320 let shape_c_expect = broadcast_shapes_f(&[lam.shape().to_vec(), lbm.shape().to_vec()], device.default_order())?;
321 let shape_c = c.shape();
322 rstsr_assert_eq!(shape_c_expect, shape_c.as_ref(), InvalidLayout, "incompatible shapes in vecdot")?;
323
324 let c_layout = c.layout().clone();
325 let c_raw_mut = unsafe {
326 transmute::<&mut <B as DeviceRawAPI<TC>>::Raw, &mut <B as DeviceRawAPI<MaybeUninit<TC>>>::Raw>(c.raw_mut())
327 };
328 device.vecdot(c_raw_mut, &c_layout, a.raw(), a.layout(), b.raw(), b.layout(), &axes_a, &axes_b)
329}
330
331impl<R, T, B, D> TensorAny<R, T, B, D>
332where
333 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
334 B: DeviceAPI<T>,
335 D: DimAPI,
336{
337 pub fn vecdot<TB, DB, TC>(
341 &self,
342 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
343 axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
344 ) -> Tensor<TC, B, IxD>
345 where
346 T: Mul<TB, Output = TC>,
347 DB: DimAPI,
348 B: DeviceVecdotAPI<T, TB, TC, D, DB, IxD>
349 + DeviceAPI<T>
350 + DeviceAPI<TB>
351 + DeviceAPI<TC>
352 + DeviceCreationAnyAPI<TC>,
353 {
354 vecdot_f(self.view(), b, axes_pair).rstsr_unwrap()
355 }
356
357 pub fn vecdot_f<TB, DB, TC>(
361 &self,
362 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
363 axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
364 ) -> Result<Tensor<TC, B, IxD>>
365 where
366 T: Mul<TB, Output = TC>,
367 DB: DimAPI,
368 B: DeviceVecdotAPI<T, TB, TC, D, DB, IxD>
369 + DeviceAPI<T>
370 + DeviceAPI<TB>
371 + DeviceAPI<TC>
372 + DeviceCreationAnyAPI<TC>,
373 {
374 vecdot_f(self.view(), b, axes_pair)
375 }
376
377 pub fn vecdot_from<TA, TB, DA, DB>(
381 &mut self,
382 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
383 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
384 axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
385 ) where
386 DA: DimAPI,
387 DB: DimAPI,
388 B: DeviceVecdotAPI<TA, TB, T, DA, DB, D> + DeviceAPI<TA> + DeviceAPI<TB>,
389 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
390 {
391 vecdot_from_f(self, a, b, axes_pair).rstsr_unwrap()
392 }
393
394 pub fn vecdot_from_f<TA, TB, DA, DB>(
398 &mut self,
399 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
400 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
401 axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
402 ) -> Result<()>
403 where
404 DA: DimAPI,
405 DB: DimAPI,
406 B: DeviceVecdotAPI<TA, TB, T, DA, DB, D> + DeviceAPI<TA> + DeviceAPI<TB>,
407 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
408 {
409 vecdot_from_f(self, a, b, axes_pair)
410 }
411}
412
#[cfg(test)]
mod test {
    use rstsr::prelude::*;

    #[test]
    fn test_vecdot() {
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(RowMajor);

        // Two integer (2, 3) operands of equal shape: one dot product per row.
        // [0, 1, 2]·[6, 7, 8] = 23 and [3, 4, 5]·[9, 10, 11] = 122.
        {
            let lhs = rt::arange((6, &device)).into_shape((2, 3));
            let rhs = rt::arange((6, 12, &device)).into_shape((2, 3));
            let c = rt::vecdot(&lhs, &rhs, None);
            println!("Result c: {c}");
            let expected = rt::tensor_from_nested!([23, 122], &device);
            assert!(rt::allclose(&c, &expected, None));
        }

        // A float (3, 3) matrix against a (3,) vector. The vector has no remaining dimensions
        // after contraction, so it broadcasts against each row of the matrix.
        {
            let lhs = rt::tensor_from_nested!([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.]], &device);
            let rhs = rt::tensor_from_nested!([0., 0.6, 0.8], &device);
            let c = rt::vecdot(&lhs, &rhs, None);
            println!("Result c: {c}");
            let expected = rt::tensor_from_nested!([3., 8., 10.], &device);
            assert!(rt::allclose(&c, &expected, None));
        }
    }
}