// candle_core/tensor_cat.rs

use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};

impl Tensor {
    /// Concatenates two or more tensors along a particular dimension.
    ///
    /// All tensors must have the same rank, and the output will have the same rank. The input
    /// shapes must agree on every dimension except the one being concatenated.
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    ///
    /// let c = Tensor::cat(&[&a, &b], 0)?;
    /// assert_eq!(c.shape().dims(), &[4, 3]);
    ///
    /// let c = Tensor::cat(&[&a, &b], 1)?;
    /// assert_eq!(c.shape().dims(), &[2, 6]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let dim = dim.to_index(arg0.shape(), "cat")?;
        for arg in args {
            arg.as_ref().check_dim(dim, "cat")?;
        }
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg0.rank() != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: arg0.rank(),
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx != dim && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
        }
        let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
        if all_contiguous {
            Self::cat_contiguous(args, dim)
        } else if dim == 0 {
            Self::cat0(args)
        } else {
            let args: Vec<Tensor> = args
                .iter()
                .map(|a| a.as_ref().transpose(0, dim))
                .collect::<Result<Vec<_>>>()?;
            let cat = Self::cat0(&args)?;
            cat.transpose(0, dim)
        }
    }

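    /// Concatenates along dimension 0 by copying each tensor's (possibly strided) data into a
    /// freshly allocated buffer; `cat` falls back to this when the inputs are not all
    /// contiguous.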
    fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let rank = arg0.rank();
        let device = arg0.device();
        let dtype = arg0.dtype();
        let first_dims = arg0.shape().dims();
        let mut cat_dims = first_dims.to_vec();
        cat_dims[0] = 0;
        let mut offsets = vec![0usize];
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg.dtype() != dtype {
                Err(Error::DTypeMismatchBinaryOp {
                    lhs: dtype,
                    rhs: arg.dtype(),
                    op: "cat",
                }
                .bt())?
            }
            if arg.device().location() != device.location() {
                Err(Error::DeviceMismatchBinaryOp {
                    lhs: device.location(),
                    rhs: arg.device().location(),
                    op: "cat",
                }
                .bt())?
            }
            if rank != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: rank,
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx == 0 {
                    cat_dims[0] += v2;
                }
                if dim_idx != 0 && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
            let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
            offsets.push(next_offset);
        }
        let shape = Shape::from(cat_dims);
        let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
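        // Copy each tensor into the destination starting at its precomputed element offset;
        // `copy_strided_src` follows the source layout, so non-contiguous inputs are handled.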
        for (arg, &offset) in args.iter().zip(offsets.iter()) {
            let arg = arg.as_ref();
            arg.storage()
                .copy_strided_src(&mut storage, offset, arg.layout())?;
        }
        Ok(crate::tensor::from_storage(storage, shape, op, false))
    }

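    /// Concatenates contiguous tensors along an arbitrary dimension using a single batched 2D
    /// copy per input, interleaving each input's blocks into the destination buffer.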
    fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let rank = arg0.rank();
        let device = arg0.device();
        let dtype = arg0.dtype();
        let first_dims = arg0.shape().dims();
        let mut cat_dims = first_dims.to_vec();
        cat_dims[dim] = 0;
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg.dtype() != dtype {
                Err(Error::DTypeMismatchBinaryOp {
                    lhs: dtype,
                    rhs: arg.dtype(),
                    op: "cat",
                }
                .bt())?
            }
            if arg.device().location() != device.location() {
                Err(Error::DeviceMismatchBinaryOp {
                    lhs: device.location(),
                    rhs: arg.device().location(),
                    op: "cat",
                }
                .bt())?
            }
            if rank != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: rank,
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx == dim {
                    cat_dims[dim] += v2;
                }
                if dim_idx != dim && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
        }
        let cat_target_dim_len = cat_dims[dim];
        let block_size: usize = cat_dims.iter().skip(1 + dim).product();
        let shape = Shape::from(cat_dims);
        let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
        let mut dst_o = 0;
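        // View each source as `d1` rows of `d2` contiguous elements, where `d2` is its own
        // length along `dim` times the trailing block size. Destination rows are spaced by
        // `block_size * cat_target_dim_len`, and `dst_o` tracks the growing column offset at
        // which each successive source is written.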
        for arg in args.iter() {
            let arg = arg.as_ref();
            let arg_dims = arg.shape().dims();
            let d1: usize = arg_dims.iter().take(dim).product();
            let d2 = block_size * arg_dims[dim];
            let dst_s = block_size * cat_target_dim_len;
            let src_o = arg.layout().start_offset();
            arg.storage().copy2d(
                &mut storage,
                d1,
                d2,
                /* src_s */ d2,
                dst_s,
                src_o,
                dst_o,
            )?;
            dst_o += d2;
        }
        Ok(crate::tensor::from_storage(storage, shape, op, false))
    }

    /// Set the values on `self` using values from `src`. The copy starts at the specified
    /// `offset` for the target dimension `dim` on `self`.
    /// `self` and `src` must have the same shape except on dimension `dim`, where the size of
    /// `self` must be greater than or equal to `offset` plus the size of `src`.
    ///
    /// Note that this modifies `self` in place and as such is not compatible with
    /// back-propagation.
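    ///
    /// A minimal usage sketch (the shapes and values below are illustrative, not taken from
    /// the surrounding code):
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let dst = Tensor::zeros((4, 3), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    /// // Copy `src` into rows 1 and 2 of `dst` (offset 1 along dimension 0).
    /// dst.slice_set(&src, 0, 1)?;
    /// assert_eq!(dst.to_vec2::<f32>()?[1], vec![1.0, 1.0, 1.0]);
    /// assert_eq!(dst.to_vec2::<f32>()?[3], vec![0.0, 0.0, 0.0]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```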
    pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
        let dim = dim.to_index(self.shape(), "slice-set")?;
        if !self.is_contiguous() || !src.is_contiguous() {
            Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
        }
        if self.same_storage(src) {
            crate::bail!("cannot use slice_set when self and src share their storage")
        }
        if self.dtype() != src.dtype() {
            Err(Error::DTypeMismatchBinaryOp {
                lhs: self.dtype(),
                rhs: src.dtype(),
                op: "slice-set",
            }
            .bt())?
        }
        if self.device().location() != src.device().location() {
            Err(Error::DeviceMismatchBinaryOp {
                lhs: self.device().location(),
                rhs: src.device().location(),
                op: "slice-set",
            }
            .bt())?
        }
        if self.rank() != src.rank() {
            Err(Error::UnexpectedNumberOfDims {
                expected: self.rank(),
                got: src.rank(),
                shape: self.shape().clone(),
            }
            .bt())?
        }
        for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
            if dim_idx == dim && *v2 + offset > *v1 {
                crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
            }
            if dim_idx != dim && v1 != v2 {
                crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
            }
        }
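        // Perform the copy as `d1` rows of `d2` contiguous elements, as in `cat_contiguous`:
        // source rows are dense, while destination rows are spaced by
        // `block_size * self.dims()[dim]` and shifted by `offset * block_size` elements.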
        let block_size: usize = src.dims().iter().skip(1 + dim).product();
        let d1: usize = src.dims().iter().take(dim).product();
        let d2 = block_size * src.dims()[dim];
        let dst_o = self.layout().start_offset() + offset * block_size;
        let src_o = src.layout().start_offset();
        src.storage().copy2d(
            &mut self.storage_mut(),
            d1,
            d2,
            /* src_s */ d2,
            /* dst_s */ block_size * self.dims()[dim],
            src_o,
            dst_o,
        )?;

        Ok(())
    }
}