use crate::{shapes::*, tensor::*};
mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;
/// Per-device kernels backing the slice op ([`Tensor::slice`] / [`Tensor::try_slice`]).
///
/// NOTE(review): the concrete implementations presumably live in the `cpu_kernel` and
/// (with the `cuda` feature) `cuda_kernel` submodules declared in this file; confirm there.
pub trait SliceKernel<E: Unit>: Storage<E> {
/// Returns the sub-tensor of `inp` selected by `slice`, with shape `Src::Sliced`.
///
/// # Errors
/// Returns `Self::Err` if the device fails to produce the output.
fn forward<Src: Shape + SliceShape<Slice>, Slice>(
&self,
inp: &Tensor<Src, E, Self>,
slice: &Slice,
) -> Result<Tensor<Src::Sliced, E, Self>, Self::Err>;
/// Backward pass of [`SliceKernel::forward`]: maps `grad_out` (the gradient buffer of the
/// sliced output) back onto `grad_inp` (the gradient buffer of `inp`), using the same
/// `slice` that was used in the forward pass.
///
/// `Tensor::try_slice` allocates both gradient buffers before calling this. `inp` is passed
/// in addition to the raw buffers, presumably so the kernel can read the input's
/// shape/strides (e.g. for broadcast inputs); TODO confirm in the kernel impls.
/// NOTE(review): whether elements of `grad_inp` are overwritten or accumulated into is
/// decided by the per-device impls, which are not visible here.
///
/// # Errors
/// Returns `Self::Err` if the device operation fails.
fn backward<Src: Shape + SliceShape<Slice>, Slice>(
&self,
inp: &Tensor<Src, E, Self>,
grad_inp: &mut Self::Vec,
grad_out: &Self::Vec,
slice: &Slice,
) -> Result<(), Self::Err>;
}
pub fn slice<S: SliceShape<Slice>, E: Unit, D: SliceKernel<E>, T: Tape<E, D>, Slice: 'static>(
tensor: Tensor<S, E, D, T>,
slice: Slice,
) -> Tensor<S::Sliced, E, D, T> {
tensor.slice(slice)
}
impl<S: Shape, E: Unit, D: SliceKernel<E>, T: Tape<E, D>> Tensor<S, E, D, T> {
    /// Fallible version of [`Tensor::slice`].
    ///
    /// Runs [`SliceKernel::forward`] on the device, then records a backward op on the tape
    /// that routes the output's gradient into the input's gradient buffer via
    /// [`SliceKernel::backward`].
    ///
    /// `Slice: 'static` is required because the slice spec is moved into the recorded
    /// backward closure.
    ///
    /// # Errors
    /// Returns the device error if the forward kernel fails. Failures of gradient allocation
    /// or of the backward kernel are returned from the recorded backward op instead.
    pub fn try_slice<Slice>(self, slice: Slice) -> Result<Tensor<S::Sliced, E, D, T>, D::Err>
    where
        S: SliceShape<Slice>,
        Slice: 'static,
    {
        // Detach the tape so the tape-free input can be moved into the backward closure;
        // the tape is re-attached to the output at the end.
        let (src, mut tape) = self.split_tape();
        let dst = src.device.forward(&src, &slice)?;

        // Handles used inside the closure to locate the gradient buffers of `src` and `dst`.
        let (src_ghost, dst_ghost) = (src.ghost(), dst.ghost());

        tape.add_backward_op(move |grads| {
            grads.try_alloc_for(&src_ghost)?;
            grads.try_alloc_for(&dst_ghost)?;
            let (grad_src, grad_dst) = grads.mut_and_ref(&src_ghost, &dst_ghost);
            src.device.backward(&src, grad_src, grad_dst, &slice)
        });

        Ok(dst.put_tape(tape))
    }

    /// Slices every dimension of this tensor using a tuple of ranges, one range per dimension.
    ///
    /// Ranges may be written with `..` or `..=` (for example `(1..3, 2..)` or `(0..=1, ..)`);
    /// the result has shape `S::Sliced`.
    ///
    /// # Panics
    /// Panics if [`Tensor::try_slice`] returns an error.
    pub fn slice<Slice>(self, slice: Slice) -> Tensor<S::Sliced, E, D, T>
    where
        S: SliceShape<Slice>,
        Slice: 'static,
    {
        Self::try_slice(self, slice).unwrap()
    }
}
// Tests for the slice op: forward results on dense and broadcast inputs, plus gradients.
// In each case `realize()` converts the sliced output to the concrete shape annotated on `b`.
#[cfg(test)]
mod tests {
use super::*;
use crate::{tensor_ops::*, tests::*};
// Forward pass on a dense 4x4 tensor holding 1..=16 in row-major order, covering
// `a..b`, `a..`, `..b`, `..`-style and inclusive `a..=b` ranges.
#[test]
fn test_slice() {
let dev = TestDevice::default();
let a = dev
.tensor([
[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 10., 11., 12.],
[13., 14., 15., 16.],
])
.to_dtype::<TestDtype>();
// Bottom-right 2x2 corner.
let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((2.., 2..)).realize();
assert_close_to_literal!(b, [[11., 12.], [15., 16.]]);
// Centre 2x2 block.
let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((1..3, 1..3)).realize();
assert_close_to_literal!(b, [[6., 7.], [10., 11.]]);
// Non-square output: first row, last three columns.
let b: Tensor<Rank2<1, 3>, _, _> = a.clone().slice((..1, 1..4)).realize();
assert_close_to_literal!(b, [[2., 3., 4.]]);
let b: Tensor<Rank2<2, 3>, _, _> = a.clone().slice((1..3, ..3)).realize();
assert_close_to_literal!(b, [[5., 6., 7.], [9., 10., 11.]]);
// Inclusive ranges include their end index.
let b: Tensor<Rank2<2, 3>, _, _> = a.clone().slice((1..=2, 1..=3)).realize();
assert_close_to_literal!(b, [[6., 7., 8.], [10., 11., 12.]]);
let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((0..=1, 2..=3)).realize();
assert_close_to_literal!(b, [[3., 4.], [7., 8.]]);
let b: Tensor<Rank2<3, 2>, _, _> = a.clone().slice((1.., ..2)).realize();
assert_close_to_literal!(b, [[5., 6.], [9., 10.], [13., 14.]]);
// Last use consumes `a` (no clone needed).
let b: Tensor<Rank2<2, 2>, _, _> = a.slice((..2, 2..)).realize();
assert_close_to_literal!(b, [[3., 4.], [7., 8.]]);
}
// Slicing an input broadcast along a new leading axis: every row of `a` is [1, 2, 3, 4].
#[test]
fn test_slice_broadcast_top() {
let dev = TestDevice::default();
let a = dev
.tensor([1., 2., 3., 4.])
.to_dtype::<TestDtype>()
.broadcast::<Rank2<5, 4>, _>();
let b: Tensor<Rank2<3, 4>, _, _> = a.clone().slice((..3, ..)).realize();
assert_close_to_literal!(b, [[1., 2., 3., 4.]; 3]);
let b: Tensor<Rank2<5, 2>, _, _> = a.clone().slice((.., 1..3)).realize();
assert_close_to_literal!(b, [[2., 3.]; 5]);
let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((1..3, 1..3)).realize();
assert_close_to_literal!(b, [[2., 3.], [2., 3.]]);
let b: Tensor<Rank2<3, 3>, _, _> = a.slice((1..4, 1..4)).realize();
assert_close_to_literal!(b, [[2., 3., 4.]; 3]);
}
// Slicing an input broadcast along a new trailing axis: `a[i][j] == i + 1` for all `j`.
#[test]
fn test_slice_broadcast_bottom() {
let dev = TestDevice::default();
let a: Tensor<Rank2<4, 5>, TestDtype, _> = dev
.tensor([1., 2., 3., 4.])
.to_dtype::<TestDtype>()
.broadcast();
let b: Tensor<Rank2<2, 5>, _, _> = a.clone().slice((1..3, ..)).realize();
assert_close_to_literal!(b, [[2.; 5], [3.; 5]]);
let b: Tensor<Rank2<4, 2>, _, _> = a.clone().slice((.., 1..3)).realize();
assert_close_to_literal!(b, [[1., 1.], [2., 2.], [3., 3.], [4., 4.]]);
let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((1..3, 3..)).realize();
assert_close_to_literal!(b, [[2., 2.], [3., 3.]]);
let b: Tensor<Rank2<2, 2>, _, _> = a.slice((..2, 1..3)).realize();
assert_close_to_literal!(b, [[1., 1.], [2., 2.]]);
}
// Gradient check: for loss = sum(b^2) the gradient w.r.t. `b` is 2*b. It should land on
// exactly the sliced positions of `a`, with zeros everywhere else.
#[test]
fn test_slice_backward() {
let dev = TestDevice::default();
let a = dev
.tensor([
[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 10., 11., 12.],
[13., 14., 15., 16.],
])
.to_dtype::<TestDtype>();
let b: Tensor<Rank2<2, 2>, _, _, _> = a.leaky_trace().slice((2.., 2..)).realize();
assert_close_to_literal!(b, [[11., 12.], [15., 16.]]);
let g = b.square().sum().backward();
assert_close_to_literal!(
g.get(&a),
[[0.; 4], [0.; 4], [0., 0., 22., 24.], [0., 0., 30., 32.]]
);
}
}