Skip to main content

hanzo_ml/
tensor_cat.rs

1use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};
2
3impl Tensor {
4    /// Concatenates two or more tensors along a particular dimension.
5    ///
6    /// All tensors must of the same rank, and the output will have
7    /// the same rank
8    ///
9    /// ```rust
10    /// # use hanzo_ml::{Tensor, DType, Device};
11    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
12    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
13    ///
14    /// let c = Tensor::cat(&[&a, &b], 0)?;
15    /// assert_eq!(c.shape().dims(), &[4, 3]);
16    ///
17    /// let c = Tensor::cat(&[&a, &b], 1)?;
18    /// assert_eq!(c.shape().dims(), &[2, 6]);
19    /// # Ok::<(), hanzo_ml::Error>(())
20    /// ```
21    pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
22        if args.is_empty() {
23            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
24        }
25        let arg0 = args[0].as_ref();
26        if args.len() == 1 {
27            return Ok(arg0.clone());
28        }
29        let dim = dim.to_index(arg0.shape(), "cat")?;
30        for arg in args {
31            arg.as_ref().check_dim(dim, "cat")?;
32        }
33        for (arg_idx, arg) in args.iter().enumerate() {
34            let arg = arg.as_ref();
35            if arg0.rank() != arg.rank() {
36                Err(Error::UnexpectedNumberOfDims {
37                    expected: arg0.rank(),
38                    got: arg.rank(),
39                    shape: arg.shape().clone(),
40                }
41                .bt())?
42            }
43            for (dim_idx, (v1, v2)) in arg0
44                .shape()
45                .dims()
46                .iter()
47                .zip(arg.shape().dims().iter())
48                .enumerate()
49            {
50                if dim_idx != dim && v1 != v2 {
51                    Err(Error::ShapeMismatchCat {
52                        dim: dim_idx,
53                        first_shape: arg0.shape().clone(),
54                        n: arg_idx + 1,
55                        nth_shape: arg.shape().clone(),
56                    }
57                    .bt())?
58                }
59            }
60        }
61        let all_copy2d_compat = args.iter().all(|v| {
62            let t = v.as_ref();
63            t.is_contiguous() || t.layout().outer_stride_for_dim(dim).is_some()
64        });
65        if all_copy2d_compat {
66            Self::cat_contiguous(args, dim)
67        } else if dim == 0 {
68            Self::cat0(args)
69        } else {
70            let args: Vec<Tensor> = args
71                .iter()
72                .map(|a| a.as_ref().transpose(0, dim))
73                .collect::<Result<Vec<_>>>()?;
74            let cat = Self::cat0(&args)?;
75            cat.transpose(0, dim)
76        }
77    }
78
79    fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
80        if args.is_empty() {
81            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
82        }
83        let arg0 = args[0].as_ref();
84        if args.len() == 1 {
85            return Ok(arg0.clone());
86        }
87        let rank = arg0.rank();
88        let device = arg0.device();
89        let dtype = arg0.dtype();
90        let first_dims = arg0.shape().dims();
91        let mut cat_dims = first_dims.to_vec();
92        cat_dims[0] = 0;
93        let mut offsets = vec![0usize];
94        for (arg_idx, arg) in args.iter().enumerate() {
95            let arg = arg.as_ref();
96            if arg.dtype() != dtype {
97                Err(Error::DTypeMismatchBinaryOp {
98                    lhs: dtype,
99                    rhs: arg.dtype(),
100                    op: "cat",
101                }
102                .bt())?
103            }
104            if arg.device().location() != device.location() {
105                Err(Error::DeviceMismatchBinaryOp {
106                    lhs: device.location(),
107                    rhs: arg.device().location(),
108                    op: "cat",
109                }
110                .bt())?
111            }
112            if rank != arg.rank() {
113                Err(Error::UnexpectedNumberOfDims {
114                    expected: rank,
115                    got: arg.rank(),
116                    shape: arg.shape().clone(),
117                }
118                .bt())?
119            }
120            for (dim_idx, (v1, v2)) in arg0
121                .shape()
122                .dims()
123                .iter()
124                .zip(arg.shape().dims().iter())
125                .enumerate()
126            {
127                if dim_idx == 0 {
128                    cat_dims[0] += v2;
129                }
130                if dim_idx != 0 && v1 != v2 {
131                    Err(Error::ShapeMismatchCat {
132                        dim: dim_idx,
133                        first_shape: arg0.shape().clone(),
134                        n: arg_idx + 1,
135                        nth_shape: arg.shape().clone(),
136                    }
137                    .bt())?
138                }
139            }
140            let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
141            offsets.push(next_offset);
142        }
143        let shape = Shape::from(cat_dims);
144        let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
145        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
146        for (arg, &offset) in args.iter().zip(offsets.iter()) {
147            let arg = arg.as_ref();
148            arg.storage()
149                .copy_strided_src(&mut storage, offset, arg.layout())?;
150        }
151        Ok(crate::tensor::from_storage(storage, shape, op, false))
152    }
153
154    fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
155        if args.is_empty() {
156            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
157        }
158        let arg0 = args[0].as_ref();
159        if args.len() == 1 {
160            return Ok(arg0.clone());
161        }
162        let rank = arg0.rank();
163        let device = arg0.device();
164        let dtype = arg0.dtype();
165        let first_dims = arg0.shape().dims();
166        let mut cat_dims = first_dims.to_vec();
167        cat_dims[dim] = 0;
168        for (arg_idx, arg) in args.iter().enumerate() {
169            let arg = arg.as_ref();
170            if arg.dtype() != dtype {
171                Err(Error::DTypeMismatchBinaryOp {
172                    lhs: dtype,
173                    rhs: arg.dtype(),
174                    op: "cat",
175                }
176                .bt())?
177            }
178            if arg.device().location() != device.location() {
179                Err(Error::DeviceMismatchBinaryOp {
180                    lhs: device.location(),
181                    rhs: arg.device().location(),
182                    op: "cat",
183                }
184                .bt())?
185            }
186            if rank != arg.rank() {
187                Err(Error::UnexpectedNumberOfDims {
188                    expected: rank,
189                    got: arg.rank(),
190                    shape: arg.shape().clone(),
191                }
192                .bt())?
193            }
194            for (dim_idx, (v1, v2)) in arg0
195                .shape()
196                .dims()
197                .iter()
198                .zip(arg.shape().dims().iter())
199                .enumerate()
200            {
201                if dim_idx == dim {
202                    cat_dims[dim] += v2;
203                }
204                if dim_idx != dim && v1 != v2 {
205                    Err(Error::ShapeMismatchCat {
206                        dim: dim_idx,
207                        first_shape: arg0.shape().clone(),
208                        n: arg_idx + 1,
209                        nth_shape: arg.shape().clone(),
210                    }
211                    .bt())?
212                }
213            }
214        }
215        let cat_target_dim_len = cat_dims[dim];
216        let block_size: usize = cat_dims.iter().skip(1 + dim).product();
217        let shape = Shape::from(cat_dims);
218        let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
219        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
220        let mut dst_o = 0;
221        for arg in args.iter() {
222            let arg = arg.as_ref();
223            let arg_dims = arg.shape().dims();
224            let d1: usize = arg_dims.iter().take(dim).product();
225            let d2 = block_size * arg_dims[dim];
226            let dst_s = block_size * cat_target_dim_len;
227            let src_o = arg.layout().start_offset();
228            let src_s = if arg.is_contiguous() {
229                d2
230            } else {
231                arg.layout().outer_stride_for_dim(dim).unwrap_or(d2)
232            };
233            arg.storage()
234                .copy2d(&mut storage, d1, d2, src_s, dst_s, src_o, dst_o)?;
235            dst_o += d2;
236        }
237        Ok(crate::tensor::from_storage(storage, shape, op, false))
238    }
239
240    /// Set the values on `self` using values from `src`. The copy starts at the specified
241    /// `offset` for the target dimension `dim` on `self`.
242    /// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
243    /// has to be greater than or equal to `offset` plus the `src` size.
244    ///
245    /// Note that this modifies `self` in place and as such is not compatible with
246    /// back-propagation.
247    pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
248        let dim = dim.to_index(self.shape(), "slice-set")?;
249        if !self.is_contiguous() || !src.is_contiguous() {
250            Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
251        }
252        if self.same_storage(src) {
253            crate::bail!("cannot use slice_set when self and src share their storage")
254        }
255        if self.dtype() != src.dtype() {
256            Err(Error::DTypeMismatchBinaryOp {
257                lhs: self.dtype(),
258                rhs: src.dtype(),
259                op: "slice-set",
260            }
261            .bt())?
262        }
263        if self.device().location() != src.device().location() {
264            Err(Error::DeviceMismatchBinaryOp {
265                lhs: self.device().location(),
266                rhs: src.device().location(),
267                op: "slice-set",
268            }
269            .bt())?
270        }
271        if self.rank() != src.rank() {
272            Err(Error::UnexpectedNumberOfDims {
273                expected: self.rank(),
274                got: src.rank(),
275                shape: self.shape().clone(),
276            }
277            .bt())?
278        }
279        for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
280            if dim_idx == dim && *v2 + offset > *v1 {
281                crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
282            }
283            if dim_idx != dim && v1 != v2 {
284                crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
285            }
286        }
287        let block_size: usize = src.dims().iter().skip(1 + dim).product();
288        let d1: usize = src.dims().iter().take(dim).product();
289        let d2 = block_size * src.dims()[dim];
290        let dst_o = self.layout().start_offset() + offset * block_size;
291        let src_o = src.layout().start_offset();
292        src.storage().copy2d(
293            &mut self.storage_mut(),
294            d1,
295            d2,
296            /* src_s */ d2,
297            /* dst_s */ block_size * self.dims()[dim],
298            src_o,
299            dst_o,
300        )?;
301
302        Ok(())
303    }
304}