1use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};
2
3impl Tensor {
4 pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
22 if args.is_empty() {
23 Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
24 }
25 let arg0 = args[0].as_ref();
26 if args.len() == 1 {
27 return Ok(arg0.clone());
28 }
29 let dim = dim.to_index(arg0.shape(), "cat")?;
30 for arg in args {
31 arg.as_ref().check_dim(dim, "cat")?;
32 }
33 for (arg_idx, arg) in args.iter().enumerate() {
34 let arg = arg.as_ref();
35 if arg0.rank() != arg.rank() {
36 Err(Error::UnexpectedNumberOfDims {
37 expected: arg0.rank(),
38 got: arg.rank(),
39 shape: arg.shape().clone(),
40 }
41 .bt())?
42 }
43 for (dim_idx, (v1, v2)) in arg0
44 .shape()
45 .dims()
46 .iter()
47 .zip(arg.shape().dims().iter())
48 .enumerate()
49 {
50 if dim_idx != dim && v1 != v2 {
51 Err(Error::ShapeMismatchCat {
52 dim: dim_idx,
53 first_shape: arg0.shape().clone(),
54 n: arg_idx + 1,
55 nth_shape: arg.shape().clone(),
56 }
57 .bt())?
58 }
59 }
60 }
61 let all_copy2d_compat = args.iter().all(|v| {
62 let t = v.as_ref();
63 t.is_contiguous() || t.layout().outer_stride_for_dim(dim).is_some()
64 });
65 if all_copy2d_compat {
66 Self::cat_contiguous(args, dim)
67 } else if dim == 0 {
68 Self::cat0(args)
69 } else {
70 let args: Vec<Tensor> = args
71 .iter()
72 .map(|a| a.as_ref().transpose(0, dim))
73 .collect::<Result<Vec<_>>>()?;
74 let cat = Self::cat0(&args)?;
75 cat.transpose(0, dim)
76 }
77 }
78
79 fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
80 if args.is_empty() {
81 Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
82 }
83 let arg0 = args[0].as_ref();
84 if args.len() == 1 {
85 return Ok(arg0.clone());
86 }
87 let rank = arg0.rank();
88 let device = arg0.device();
89 let dtype = arg0.dtype();
90 let first_dims = arg0.shape().dims();
91 let mut cat_dims = first_dims.to_vec();
92 cat_dims[0] = 0;
93 let mut offsets = vec![0usize];
94 for (arg_idx, arg) in args.iter().enumerate() {
95 let arg = arg.as_ref();
96 if arg.dtype() != dtype {
97 Err(Error::DTypeMismatchBinaryOp {
98 lhs: dtype,
99 rhs: arg.dtype(),
100 op: "cat",
101 }
102 .bt())?
103 }
104 if arg.device().location() != device.location() {
105 Err(Error::DeviceMismatchBinaryOp {
106 lhs: device.location(),
107 rhs: arg.device().location(),
108 op: "cat",
109 }
110 .bt())?
111 }
112 if rank != arg.rank() {
113 Err(Error::UnexpectedNumberOfDims {
114 expected: rank,
115 got: arg.rank(),
116 shape: arg.shape().clone(),
117 }
118 .bt())?
119 }
120 for (dim_idx, (v1, v2)) in arg0
121 .shape()
122 .dims()
123 .iter()
124 .zip(arg.shape().dims().iter())
125 .enumerate()
126 {
127 if dim_idx == 0 {
128 cat_dims[0] += v2;
129 }
130 if dim_idx != 0 && v1 != v2 {
131 Err(Error::ShapeMismatchCat {
132 dim: dim_idx,
133 first_shape: arg0.shape().clone(),
134 n: arg_idx + 1,
135 nth_shape: arg.shape().clone(),
136 }
137 .bt())?
138 }
139 }
140 let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
141 offsets.push(next_offset);
142 }
143 let shape = Shape::from(cat_dims);
144 let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
145 let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
146 for (arg, &offset) in args.iter().zip(offsets.iter()) {
147 let arg = arg.as_ref();
148 arg.storage()
149 .copy_strided_src(&mut storage, offset, arg.layout())?;
150 }
151 Ok(crate::tensor::from_storage(storage, shape, op, false))
152 }
153
154 fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
155 if args.is_empty() {
156 Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
157 }
158 let arg0 = args[0].as_ref();
159 if args.len() == 1 {
160 return Ok(arg0.clone());
161 }
162 let rank = arg0.rank();
163 let device = arg0.device();
164 let dtype = arg0.dtype();
165 let first_dims = arg0.shape().dims();
166 let mut cat_dims = first_dims.to_vec();
167 cat_dims[dim] = 0;
168 for (arg_idx, arg) in args.iter().enumerate() {
169 let arg = arg.as_ref();
170 if arg.dtype() != dtype {
171 Err(Error::DTypeMismatchBinaryOp {
172 lhs: dtype,
173 rhs: arg.dtype(),
174 op: "cat",
175 }
176 .bt())?
177 }
178 if arg.device().location() != device.location() {
179 Err(Error::DeviceMismatchBinaryOp {
180 lhs: device.location(),
181 rhs: arg.device().location(),
182 op: "cat",
183 }
184 .bt())?
185 }
186 if rank != arg.rank() {
187 Err(Error::UnexpectedNumberOfDims {
188 expected: rank,
189 got: arg.rank(),
190 shape: arg.shape().clone(),
191 }
192 .bt())?
193 }
194 for (dim_idx, (v1, v2)) in arg0
195 .shape()
196 .dims()
197 .iter()
198 .zip(arg.shape().dims().iter())
199 .enumerate()
200 {
201 if dim_idx == dim {
202 cat_dims[dim] += v2;
203 }
204 if dim_idx != dim && v1 != v2 {
205 Err(Error::ShapeMismatchCat {
206 dim: dim_idx,
207 first_shape: arg0.shape().clone(),
208 n: arg_idx + 1,
209 nth_shape: arg.shape().clone(),
210 }
211 .bt())?
212 }
213 }
214 }
215 let cat_target_dim_len = cat_dims[dim];
216 let block_size: usize = cat_dims.iter().skip(1 + dim).product();
217 let shape = Shape::from(cat_dims);
218 let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
219 let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
220 let mut dst_o = 0;
221 for arg in args.iter() {
222 let arg = arg.as_ref();
223 let arg_dims = arg.shape().dims();
224 let d1: usize = arg_dims.iter().take(dim).product();
225 let d2 = block_size * arg_dims[dim];
226 let dst_s = block_size * cat_target_dim_len;
227 let src_o = arg.layout().start_offset();
228 let src_s = if arg.is_contiguous() {
229 d2
230 } else {
231 arg.layout().outer_stride_for_dim(dim).unwrap_or(d2)
232 };
233 arg.storage()
234 .copy2d(&mut storage, d1, d2, src_s, dst_s, src_o, dst_o)?;
235 dst_o += d2;
236 }
237 Ok(crate::tensor::from_storage(storage, shape, op, false))
238 }
239
240 pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
248 let dim = dim.to_index(self.shape(), "slice-set")?;
249 if !self.is_contiguous() || !src.is_contiguous() {
250 Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
251 }
252 if self.same_storage(src) {
253 crate::bail!("cannot use slice_set when self and src share their storage")
254 }
255 if self.dtype() != src.dtype() {
256 Err(Error::DTypeMismatchBinaryOp {
257 lhs: self.dtype(),
258 rhs: src.dtype(),
259 op: "slice-set",
260 }
261 .bt())?
262 }
263 if self.device().location() != src.device().location() {
264 Err(Error::DeviceMismatchBinaryOp {
265 lhs: self.device().location(),
266 rhs: src.device().location(),
267 op: "slice-set",
268 }
269 .bt())?
270 }
271 if self.rank() != src.rank() {
272 Err(Error::UnexpectedNumberOfDims {
273 expected: self.rank(),
274 got: src.rank(),
275 shape: self.shape().clone(),
276 }
277 .bt())?
278 }
279 for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
280 if dim_idx == dim && *v2 + offset > *v1 {
281 crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
282 }
283 if dim_idx != dim && v1 != v2 {
284 crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
285 }
286 }
287 let block_size: usize = src.dims().iter().skip(1 + dim).product();
288 let d1: usize = src.dims().iter().take(dim).product();
289 let d2 = block_size * src.dims()[dim];
290 let dst_o = self.layout().start_offset() + offset * block_size;
291 let src_o = src.layout().start_offset();
292 src.storage().copy2d(
293 &mut self.storage_mut(),
294 d1,
295 d2,
296 d2,
297 block_size * self.dims()[dim],
298 src_o,
299 dst_o,
300 )?;
301
302 Ok(())
303 }
304}