Skip to main content

rstsr_core/tensor/linalg/
vecdot.rs

1use crate::prelude_dev::*;
2use core::mem::transmute;
3
4/// Vector dot product of two arrays.
5///
6/// Let $\mathbf{a}$ be a vector in `a` and $\mathbf{b}$ be
7/// a corresponding vector in `b`. The dot product is defined as:
8///
9/// $$\mathbf{a} \cdot \mathbf{b} = \sum_{i=0}^{n-1} \overline{a_i}b_i$$
10///
11/// where the sum is over the dimension specified by `axis` (default: last axis)
12/// and where $\overline{a_i}$ denotes the complex conjugate if $a_i$
13/// is complex and the identity otherwise.
14///
15/// # Parameters
16///
17/// - `a`: impl [`TensorViewAPI`]
18///
19///   - The first input array. Note this array is conjugated if it has a complex data type.
20///   - Scalar not allowed.
21///
22/// - `b`: impl [`TensorViewAPI`]
23///
24///   - The second input array.
25///   - Scalar not allowed.
26///
27/// - `axis`: `impl Into<Option<isize>>`
28///
29///   - The axis over which to compute the dot product.
30///   - Default: `-1` (the last axis).
31///   - If negative, the axis is counted from the last axis of each input array.
32///
33/// # Returns
34///
35/// [`Tensor<T::Output, B, DA::Max>`]
36///
37/// - The result shape is the broadcast of the input shapes with the contracted axis removed.
38///
39/// # Examples
40///
41/// Basic vector dot product:
42///
43/// ```rust
44/// # use rstsr::prelude::*;
45/// # let mut device = DeviceCpu::default();
46/// # device.set_default_order(RowMajor);
47/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
48/// let b = rt::tensor_from_nested!([4, 5, 6], &device);
49/// let result = rt::vecdot(&a, &b, None);
50/// println!("{result}");
51/// // 32
52/// ```
53///
54/// 2-dim dot product:
55///
56/// ```rust
57/// # use rstsr::prelude::*;
58/// # let mut device = DeviceCpu::default();
59/// # device.set_default_order(RowMajor);
60/// let a = rt::tensor_from_nested!([[1, 2], [3, 4]], &device);
61/// let b = rt::tensor_from_nested!([[5, 6], [7, 8]], &device);
62/// let result = rt::vecdot(&a, &b, None);
63/// println!("{result}");
64/// // [ 17 53]
65/// ```
66///
67/// 2-dim broadcasted dot product (note in this case, the following two tensors only can be
68/// broadcasted row-major):
69///
70/// ```rust
71/// # use rstsr::prelude::*;
72/// # let mut device = DeviceCpu::default();
73/// device.set_default_order(RowMajor);
74/// let a = rt::tensor_from_nested!([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.]], &device);
75/// let b = rt::tensor_from_nested!([0., 0.6, 0.8], &device);
76/// let result = rt::vecdot(&a, &b, None);
77/// println!("{result}");
78/// // [ 3 8 10]
79/// ```
80///
81/// Complex vector dot product (conjugates first argument):
82///
83/// ```rust
84/// # use rstsr::prelude::*;
85/// # let mut device = DeviceCpu::default();
86/// # device.set_default_order(RowMajor);
87/// use num::complex::c64;
88/// let a = rt::tensor_from_nested!([c64(1., 0.), c64(2., 2.), c64(3., 0.)], &device);
89/// let b = rt::tensor_from_nested!([c64(1., 0.), c64(2., 0.), c64(3., 3.)], &device);
90/// let result = rt::vecdot(&a, &b, None);
91/// println!("{result}");
92/// // 14+5i
93/// ```
94///
95/// # Notes of API accordance
96///
97/// - Array-API: `vecdot(x1, x2, /, *, axis=-1)` ([`vecdot`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.vecdot.html))
98/// - NumPy: `vecdot(x1, x2, /, out=None, *, casting='same_kind', order='K', dtype=None, subok=True[, signature, axes, axis])` ([`numpy.vecdot`](https://numpy.org/doc/stable/reference/generated/numpy.vecdot.html))
99/// - RSTSR: `rt::vecdot(a, b, axis)`
100///
101/// # Panics
102///
103/// - The contracted axis dimensions do not match.
104/// - The input tensors cannot be broadcast together.
105///
106/// For a fallible version, use [`vecdot_f`].
107///
108/// # See Also
109///
110/// ## Related functions in RSTSR
111///
112/// - [`matmul`] - Matrix-matrix product.
113/// - [`rt::tblis::tensordot`](https://docs.rs/rstsr-tblis/latest/rstsr_tblis/tensordot_impl/fn.tensordot.html)
114///   - Tensor dot product along specified axes.
115/// - [`rt::tblis::einsum`](https://docs.rs/rstsr-tblis/latest/rstsr_tblis/einsum_impl/fn.einsum.html)
116///   - Einstein summation for tensors.
117///
118/// ## Variants of this function
119///
120/// - [`vecdot`] / [`vecdot_f`]: Returning a new tensor.
121/// - [`vecdot_from`] / [`vecdot_from_f`]: Writing result to existing tensor.
122pub fn vecdot<TA, TB, DA, DB, B>(
123    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
124    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
125    axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
126) -> Tensor<TA::Output, B, IxD>
127where
128    TA: Mul<TB>,
129    DA: DimAPI,
130    DB: DimAPI,
131    B: DeviceVecdotAPI<TA, TB, TA::Output, DA, DB, IxD>
132        + DeviceAPI<TA>
133        + DeviceAPI<TB>
134        + DeviceAPI<TA::Output>
135        + DeviceCreationAnyAPI<TA::Output>,
136{
137    vecdot_f(a, b, axes_pair).rstsr_unwrap()
138}
139
140/// Vector dot product of two arrays.
141///
142/// See also [`vecdot`].
143pub fn vecdot_from<TA, TB, TC, DA, DB, DC, B>(
144    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
145    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
146    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
147    axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
148) where
149    DA: DimAPI,
150    DB: DimAPI,
151    DC: DimAPI,
152    B: DeviceVecdotAPI<TA, TB, TC, DA, DB, DC> + DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
153{
154    vecdot_from_f(c, a, b, axes_pair).rstsr_unwrap()
155}
156
157/// Vector dot product of two arrays.
158///
159/// See also [`vecdot`].
160pub fn vecdot_f<TA, TB, DA, DB, B>(
161    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
162    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
163    axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
164) -> Result<Tensor<TA::Output, B, IxD>>
165where
166    TA: Mul<TB>,
167    DA: DimAPI,
168    DB: DimAPI,
169    B: DeviceVecdotAPI<TA, TB, TA::Output, DA, DB, IxD>
170        + DeviceAPI<TA>
171        + DeviceAPI<TB>
172        + DeviceAPI<TA::Output>
173        + DeviceCreationAnyAPI<TA::Output>,
174{
175    let (a, b) = (a.view(), b.view());
176
177    // check devices
178    let device = a.device().clone();
179    rstsr_assert!(device.same_device(b.device()), DeviceMismatch)?;
180
181    // check axis
182    let mut axes_pair = axes_pair.try_into().map_err(Into::into)?;
183    if axes_pair == AxesPairIndex::None {
184        axes_pair = AxesPairIndex::Val(-1);
185    }
186
187    let (axes_a, axes_b) = match axes_pair {
188        AxesPairIndex::None => unreachable!("already handled above"),
189        AxesPairIndex::Val(axis) => {
190            if axis < 0 {
191                rstsr_pattern!(
192                    axis,
193                    -(a.ndim().min(b.ndim()) as isize)..=-1,
194                    InvalidValue,
195                    "axis should be [-N, -1] where N is min(a.ndim, b.ndim)"
196                )?;
197                let axis_a = axis + a.ndim() as isize;
198                let axis_b = axis + b.ndim() as isize;
199                (vec![axis_a], vec![axis_b])
200            } else {
201                rstsr_pattern!(
202                    axis,
203                    0..(a.ndim().min(b.ndim()) as isize),
204                    InvalidValue,
205                    "axis should be [0, N) where N is min(a.ndim, b.ndim)"
206                )?;
207                (vec![axis], vec![axis])
208            }
209        },
210        AxesPairIndex::Pair(axes_a, axes_b) => {
211            let axes_a = normalize_axes_index(axes_a, a.ndim(), false, false)?;
212            let axes_b = normalize_axes_index(axes_b, b.ndim(), false, false)?;
213            rstsr_assert_eq!(
214                axes_a.len(),
215                axes_b.len(),
216                InvalidValue,
217                "axes_a and axes_b should have the same length"
218            )?;
219            (axes_a, axes_b)
220        },
221    };
222
223    let (las, lam) = a.layout().dim_split_axes(&axes_a)?;
224    let (lbs, lbm) = b.layout().dim_split_axes(&axes_b)?;
225
226    rstsr_assert_eq!(
227        las.shape(),
228        lbs.shape(),
229        InvalidLayout,
230        "the dimensions of a and b along the contracted axis should be the same"
231    )?;
232
233    let default_order = a.device().default_order();
234    let (lam_b, lbm_b) = broadcast_layout(&lam, &lbm, default_order)?;
235    // generate output layout
236    let layout_c = match TensorIterOrder::default() {
237        TensorIterOrder::C => lam_b.shape().c(),
238        TensorIterOrder::F => lam_b.shape().f(),
239        _ => get_layout_for_binary_op(&lam_b, &lbm_b, default_order)?,
240    };
241    let mut storage_c = device.uninit_impl(layout_c.bounds_index()?.1)?;
242    device.vecdot(storage_c.raw_mut(), &layout_c, a.raw(), a.layout(), b.raw(), b.layout(), &axes_a, &axes_b)?;
243    unsafe { Tensor::new_f(B::assume_init_impl(storage_c)?, layout_c) }
244}
245
246/// Vector dot product of two arrays.
247///
248/// See also [`vecdot`].
249pub fn vecdot_from_f<TA, TB, TC, DA, DB, DC, B>(
250    mut c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
251    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
252    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
253    axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
254) -> Result<()>
255where
256    DA: DimAPI,
257    DB: DimAPI,
258    DC: DimAPI,
259    B: DeviceVecdotAPI<TA, TB, TC, DA, DB, DC> + DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
260{
261    let (a, b, mut c) = (a.view(), b.view(), c.view_mut());
262
263    // check devices
264    let device = c.device().clone();
265    rstsr_assert!(device.same_device(a.device()), DeviceMismatch)?;
266    rstsr_assert!(device.same_device(b.device()), DeviceMismatch)?;
267
268    // check axis
269    let mut axes_pair = axes_pair.try_into().map_err(Into::into)?;
270    if axes_pair == AxesPairIndex::None {
271        axes_pair = AxesPairIndex::Val(-1);
272    }
273
274    let (axes_a, axes_b) = match axes_pair {
275        AxesPairIndex::None => unreachable!("already handled above"),
276        AxesPairIndex::Val(axis) => {
277            if axis < 0 {
278                rstsr_pattern!(
279                    axis,
280                    -(a.ndim().min(b.ndim()) as isize)..=-1,
281                    InvalidValue,
282                    "axis should be [-N, -1] where N is min(a.ndim, b.ndim)"
283                )?;
284                let axis_a = axis + a.ndim() as isize;
285                let axis_b = axis + b.ndim() as isize;
286                (vec![axis_a], vec![axis_b])
287            } else {
288                rstsr_pattern!(
289                    axis,
290                    0..(a.ndim().min(b.ndim()) as isize),
291                    InvalidValue,
292                    "axis should be [0, N) where N is min(a.ndim, b.ndim)"
293                )?;
294                (vec![axis], vec![axis])
295            }
296        },
297        AxesPairIndex::Pair(axes_a, axes_b) => {
298            let axes_a = normalize_axes_index(axes_a, a.ndim(), false, false)?;
299            let axes_b = normalize_axes_index(axes_b, b.ndim(), false, false)?;
300            rstsr_assert_eq!(
301                axes_a.len(),
302                axes_b.len(),
303                InvalidValue,
304                "axes_a and axes_b should have the same length"
305            )?;
306            (axes_a, axes_b)
307        },
308    };
309
310    let (las, lam) = a.layout().dim_split_axes(&axes_a)?;
311    let (lbs, lbm) = b.layout().dim_split_axes(&axes_b)?;
312
313    rstsr_assert_eq!(
314        las.shape(),
315        lbs.shape(),
316        InvalidLayout,
317        "the dimensions of a and b along the contracted axis should be the same"
318    )?;
319
320    let shape_c_expect = broadcast_shapes_f(&[lam.shape().to_vec(), lbm.shape().to_vec()], device.default_order())?;
321    let shape_c = c.shape();
322    rstsr_assert_eq!(shape_c_expect, shape_c.as_ref(), InvalidLayout, "incompatible shapes in vecdot")?;
323
324    let c_layout = c.layout().clone();
325    let c_raw_mut = unsafe {
326        transmute::<&mut <B as DeviceRawAPI<TC>>::Raw, &mut <B as DeviceRawAPI<MaybeUninit<TC>>>::Raw>(c.raw_mut())
327    };
328    device.vecdot(c_raw_mut, &c_layout, a.raw(), a.layout(), b.raw(), b.layout(), &axes_a, &axes_b)
329}
330
331impl<R, T, B, D> TensorAny<R, T, B, D>
332where
333    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
334    B: DeviceAPI<T>,
335    D: DimAPI,
336{
337    /// Vector dot product of two arrays.
338    ///
339    /// See also [`vecdot`].
340    pub fn vecdot<TB, DB, TC>(
341        &self,
342        b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
343        axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
344    ) -> Tensor<TC, B, IxD>
345    where
346        T: Mul<TB, Output = TC>,
347        DB: DimAPI,
348        B: DeviceVecdotAPI<T, TB, TC, D, DB, IxD>
349            + DeviceAPI<T>
350            + DeviceAPI<TB>
351            + DeviceAPI<TC>
352            + DeviceCreationAnyAPI<TC>,
353    {
354        vecdot_f(self.view(), b, axes_pair).rstsr_unwrap()
355    }
356
357    /// Vector dot product of two arrays.
358    ///
359    /// See also [`vecdot`].
360    pub fn vecdot_f<TB, DB, TC>(
361        &self,
362        b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
363        axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
364    ) -> Result<Tensor<TC, B, IxD>>
365    where
366        T: Mul<TB, Output = TC>,
367        DB: DimAPI,
368        B: DeviceVecdotAPI<T, TB, TC, D, DB, IxD>
369            + DeviceAPI<T>
370            + DeviceAPI<TB>
371            + DeviceAPI<TC>
372            + DeviceCreationAnyAPI<TC>,
373    {
374        vecdot_f(self.view(), b, axes_pair)
375    }
376
377    /// Vector dot product of two arrays.
378    ///
379    /// See also [`vecdot`].
380    pub fn vecdot_from<TA, TB, DA, DB>(
381        &mut self,
382        a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
383        b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
384        axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
385    ) where
386        DA: DimAPI,
387        DB: DimAPI,
388        B: DeviceVecdotAPI<TA, TB, T, DA, DB, D> + DeviceAPI<TA> + DeviceAPI<TB>,
389        R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
390    {
391        vecdot_from_f(self, a, b, axes_pair).rstsr_unwrap()
392    }
393
394    /// Vector dot product of two arrays.
395    ///
396    /// See also [`vecdot`].
397    pub fn vecdot_from_f<TA, TB, DA, DB>(
398        &mut self,
399        a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
400        b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
401        axes_pair: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
402    ) -> Result<()>
403    where
404        DA: DimAPI,
405        DB: DimAPI,
406        B: DeviceVecdotAPI<TA, TB, T, DA, DB, D> + DeviceAPI<TA> + DeviceAPI<TB>,
407        R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
408    {
409        vecdot_from_f(self, a, b, axes_pair)
410    }
411}
412
#[cfg(test)]
mod test {
    use rstsr::prelude::*;

    /// Checks `rt::vecdot` along the default (last) axis against hand-computed values.
    #[test]
    fn test_vecdot() {
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(RowMajor);

        // Equal shapes: rows of [[0, 1, 2], [3, 4, 5]] dotted with rows of
        // [[6, 7, 8], [9, 10, 11]] give 0*6 + 1*7 + 2*8 = 23 and 3*9 + 4*10 + 5*11 = 122.
        let lhs = rt::arange((6, &device)).into_shape((2, 3));
        let rhs = rt::arange((6, 12, &device)).into_shape((2, 3));
        let same_shape = rt::vecdot(&lhs, &rhs, None);
        println!("Result c: {same_shape}");
        let expected = rt::tensor_from_nested!([23, 122], &device);
        assert!(rt::allclose(&same_shape, &expected, None));

        // Broadcasting: the 1-D `rhs` is dotted with every row of the 2-D `lhs`.
        let lhs = rt::tensor_from_nested!([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.]], &device);
        let rhs = rt::tensor_from_nested!([0., 0.6, 0.8], &device);
        let broadcast = rt::vecdot(&lhs, &rhs, None);
        println!("Result c: {broadcast}");
        let expected = rt::tensor_from_nested!([3., 8., 10.], &device);
        assert!(rt::allclose(&broadcast, &expected, None));
    }
}