use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};

impl Tensor {
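    /// Concatenates two or more tensors along a particular dimension.
    ///
    /// All tensors must have the same rank, and the output has the same rank: the size of
    /// the concatenated dimension is the sum of the input sizes along that dimension, and
    /// all other dimensions must match.
    ///
    /// A minimal usage sketch; the doctest assumes this crate builds as `candle_core`:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    ///
    /// let c = Tensor::cat(&[&a, &b], 0)?;
    /// assert_eq!(c.shape().dims(), &[4, 3]);
    ///
    /// let c = Tensor::cat(&[&a, &b], 1)?;
    /// assert_eq!(c.shape().dims(), &[2, 6]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```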
    pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let dim = dim.to_index(arg0.shape(), "cat")?;
        for arg in args {
            arg.as_ref().check_dim(dim, "cat")?;
        }
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg0.rank() != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: arg0.rank(),
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx != dim && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
        }
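        // All inputs validated; pick a strategy. Contiguous inputs take a fast path that
        // copies each tensor as one 2D block; otherwise concatenate along dim 0, moving the
        // target dim to the front (and back again) via transpose when needed.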
        let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
        if all_contiguous {
            Self::cat_contiguous(args, dim)
        } else if dim == 0 {
            Self::cat0(args)
        } else {
            let args: Vec<Tensor> = args
                .iter()
                .map(|a| a.as_ref().transpose(0, dim))
                .collect::<Result<Vec<_>>>()?;
            let cat = Self::cat0(&args)?;
            cat.transpose(0, dim)
        }
    }

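    /// Concatenates along dim 0 without requiring contiguous inputs: each input is written
    /// into the output with a single strided copy at its running element offset.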
    fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let rank = arg0.rank();
        let device = arg0.device();
        let dtype = arg0.dtype();
        let first_dims = arg0.shape().dims();
        let mut cat_dims = first_dims.to_vec();
        cat_dims[0] = 0;
        let mut offsets = vec![0usize];
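        // Single pass over the inputs: check dtype/device/rank/shape agreement, accumulate
        // the size of dim 0, and record each input's element offset into the output storage.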
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg.dtype() != dtype {
                Err(Error::DTypeMismatchBinaryOp {
                    lhs: dtype,
                    rhs: arg.dtype(),
                    op: "cat",
                }
                .bt())?
            }
            if arg.device().location() != device.location() {
                Err(Error::DeviceMismatchBinaryOp {
                    lhs: device.location(),
                    rhs: arg.device().location(),
                    op: "cat",
                }
                .bt())?
            }
            if rank != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: rank,
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx == 0 {
                    cat_dims[0] += v2;
                }
                if dim_idx != 0 && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
            let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
            offsets.push(next_offset);
        }
        let shape = Shape::from(cat_dims);
        let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
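        // SAFETY: every element of the uninitialized buffer is overwritten by the strided
        // copies below.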
        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
        for (arg, &offset) in args.iter().zip(offsets.iter()) {
            let arg = arg.as_ref();
            arg.storage()
                .copy_strided_src(&mut storage, offset, arg.layout())?;
        }
        Ok(crate::tensor::from_storage(storage, shape, op, false))
    }

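    /// Concatenates along an arbitrary `dim`, assuming every input is contiguous: each input
    /// is copied into its strided slot of the output with a single 2D copy.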
    fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let rank = arg0.rank();
        let device = arg0.device();
        let dtype = arg0.dtype();
        let first_dims = arg0.shape().dims();
        let mut cat_dims = first_dims.to_vec();
        cat_dims[dim] = 0;
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg.dtype() != dtype {
                Err(Error::DTypeMismatchBinaryOp {
                    lhs: dtype,
                    rhs: arg.dtype(),
                    op: "cat",
                }
                .bt())?
            }
            if arg.device().location() != device.location() {
                Err(Error::DeviceMismatchBinaryOp {
                    lhs: device.location(),
                    rhs: arg.device().location(),
                    op: "cat",
                }
                .bt())?
            }
            if rank != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: rank,
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx == dim {
                    cat_dims[dim] += v2;
                }
                if dim_idx != dim && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
        }
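        // Treat each input as `d1` rows of `d2` contiguous elements, where `d2` spans the
        // concatenated dim and everything after it. Destination rows are `dst_s` elements
        // apart, so successive inputs interleave correctly along `dim`.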
        let cat_target_dim_len = cat_dims[dim];
        let block_size: usize = cat_dims.iter().skip(1 + dim).product();
        let shape = Shape::from(cat_dims);
        let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
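        // SAFETY: every element of the uninitialized buffer is overwritten by the 2D copies
        // below.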
        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
        let mut dst_o = 0;
        for arg in args.iter() {
            let arg = arg.as_ref();
            let arg_dims = arg.shape().dims();
            let d1: usize = arg_dims.iter().take(dim).product();
            let d2 = block_size * arg_dims[dim];
            let dst_s = block_size * cat_target_dim_len;
            let src_o = arg.layout().start_offset();
            arg.storage().copy2d(
                &mut storage,
                d1,
                d2,
                /* src_s */ d2,
                /* dst_s */ dst_s,
                src_o,
                dst_o,
            )?;
            dst_o += d2;
        }
        Ok(crate::tensor::from_storage(storage, shape, op, false))
    }

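    /// Copies the values of `src` into a slice of `self` along dimension `dim`, starting at
    /// `offset` on that dimension. `self` and `src` must have the same shape except on `dim`,
    /// where `offset + src.dims()[dim]` must not exceed `self.dims()[dim]`; both tensors must
    /// be contiguous and must not share storage.
    ///
    /// Note that this modifies `self` in place and as such is not compatible with
    /// back-propagation.
    ///
    /// A minimal usage sketch; the doctest assumes this crate builds as `candle_core`:
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let dst = Tensor::zeros((4, 3), DType::F32, &Device::Cpu)?;
    /// let src = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    /// // Overwrite rows 1..3 of `dst` with the contents of `src`.
    /// dst.slice_set(&src, 0, 1)?;
    /// # Ok::<(), candle_core::Error>(())
    /// ```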
    pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
        let dim = dim.to_index(self.shape(), "slice-set")?;
        if !self.is_contiguous() || !src.is_contiguous() {
            Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
        }
        if self.same_storage(src) {
            crate::bail!("cannot use slice_set when self and src share their storage")
        }
        if self.dtype() != src.dtype() {
            Err(Error::DTypeMismatchBinaryOp {
                lhs: self.dtype(),
                rhs: src.dtype(),
                op: "slice-set",
            }
            .bt())?
        }
        if self.device().location() != src.device().location() {
            Err(Error::DeviceMismatchBinaryOp {
                lhs: self.device().location(),
                rhs: src.device().location(),
                op: "slice-set",
            }
            .bt())?
        }
        if self.rank() != src.rank() {
            Err(Error::UnexpectedNumberOfDims {
                expected: self.rank(),
                got: src.rank(),
                shape: self.shape().clone(),
            }
            .bt())?
        }
        for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
            if dim_idx == dim && *v2 + offset > *v1 {
                crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
            }
            if dim_idx != dim && v1 != v2 {
                crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
            }
        }
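        // Same 2D view as `cat_contiguous`: copy `d1` rows of `d2` contiguous elements from
        // `src`, stepping the destination by its own row length and starting `offset` blocks
        // into the target dim.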
        let block_size: usize = src.dims().iter().skip(1 + dim).product();
        let d1: usize = src.dims().iter().take(dim).product();
        let d2 = block_size * src.dims()[dim];
        let dst_o = self.layout().start_offset() + offset * block_size;
        let src_o = src.layout().start_offset();
        src.storage().copy2d(
            &mut self.storage_mut(),
            d1,
            d2,
            /* src_s */ d2,
            /* dst_s */ block_size * self.dims()[dim],
            src_o,
            dst_o,
        )?;

        Ok(())
    }
}