use std::ptr;
use flodl_sys::{self as ffi, FlodlTensor};
use super::{Tensor, TensorError, check_err, Result};
impl Tensor {
pub fn reshape(&self, shape: &[i64]) -> Result<Tensor> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_reshape(self.handle, shape.as_mut_ptr(), shape.len() as i32, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn transpose(&self, dim0: i32, dim1: i32) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_transpose(self.handle, dim0, dim1, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn expand(&self, shape: &[i64]) -> Result<Tensor> {
let mut shape = shape.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_expand(self.handle, shape.as_mut_ptr(), shape.len() as i32, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn permute(&self, dims: &[i64]) -> Result<Tensor> {
let mut dims = dims.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_permute(self.handle, dims.as_mut_ptr(), dims.len() as i32, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn narrow(&self, dim: i32, start: i64, length: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_narrow(self.handle, dim, start, length, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn narrow_scatter(&self, src: &Tensor, dim: i32, start: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_narrow_scatter(self.handle, src.handle, dim, start, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn select(&self, dim: i32, index: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_select(self.handle, dim, index, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn squeeze(&self, dim: i32) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_squeeze(self.handle, dim, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn unsqueeze(&self, dim: i32) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_unsqueeze(self.handle, dim, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn unsqueeze_many(&self, dims: &[i32]) -> Result<Tensor> {
let mut sorted = dims.to_vec();
sorted.sort();
let mut t = self.unsqueeze(sorted[0])?;
for &d in &sorted[1..] {
t = t.unsqueeze(d)?;
}
Ok(t)
}
pub fn flatten(&self, start_dim: i32, end_dim: i32) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_flatten(self.handle, start_dim, end_dim, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn cat(&self, other: &Tensor, dim: i32) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_cat2(self.handle, other.handle, dim, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn cat_many(tensors: &[&Tensor], dim: i32) -> Result<Tensor> {
if tensors.is_empty() {
return Err(TensorError::new("cat_many: empty tensor list"));
}
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let mut result: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_cat(handles.as_mut_ptr(), handles.len() as i32, dim, &mut result)
};
check_err(err)?;
Ok(Tensor::from_raw(result))
}
pub fn stack(tensors: &[&Tensor], dim: i32) -> Result<Tensor> {
if tensors.is_empty() {
return Err(TensorError::new("stack: empty tensor list"));
}
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let mut result: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_stack(handles.as_mut_ptr(), handles.len() as i32, dim, &mut result)
};
check_err(err)?;
Ok(Tensor::from_raw(result))
}
pub fn repeat(&self, repeats: &[i64]) -> Result<Tensor> {
let mut repeats = repeats.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_repeat(self.handle, repeats.as_mut_ptr(), repeats.len() as i32, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn pad(&self, padding: &[i64], value: f64) -> Result<Tensor> {
let mut padding = padding.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_pad(
self.handle, padding.as_mut_ptr(), padding.len() as i32,
value, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn pad_mode(&self, padding: &[i64], mode: i32, value: f64) -> Result<Tensor> {
let mut padding = padding.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_pad_mode(
self.handle, padding.as_mut_ptr(), padding.len() as i32,
mode, value, &mut handle,
)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn flip(&self, dims: &[i64]) -> Result<Tensor> {
let mut dims = dims.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe {
ffi::flodl_flip(self.handle, dims.as_mut_ptr(), dims.len() as i32, &mut handle)
};
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn roll(&self, shift: i64, dim: i32) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_roll(self.handle, shift, dim, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn diagonal(&self, offset: i64, dim1: i32, dim2: i32) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_diagonal(self.handle, offset, dim1, dim2, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn movedim(&self, src: i64, dst: i64) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_movedim(self.handle, src, dst, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn tile(&self, reps: &[i64]) -> Result<Tensor> {
let mut reps_buf = reps.to_vec();
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_tile(self.handle, reps_buf.as_mut_ptr(), reps.len() as i32, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn split(&self, split_size: i64, dim: i32) -> Result<Vec<Tensor>> {
let mut results_ptr: *mut FlodlTensor = ptr::null_mut();
let mut count: i32 = 0;
let err = unsafe {
ffi::flodl_split(self.handle, split_size, dim, &mut results_ptr, &mut count)
};
check_err(err)?;
let mut tensors = Vec::with_capacity(count as usize);
for i in 0..count as usize {
let handle = unsafe { *results_ptr.add(i) };
tensors.push(Tensor::from_raw(handle));
}
if !results_ptr.is_null() {
unsafe { ffi::flodl_free_string(results_ptr as *mut i8) };
}
Ok(tensors)
}
pub fn unbind(&self, dim: i32) -> Result<Vec<Tensor>> {
let mut results_ptr: *mut FlodlTensor = ptr::null_mut();
let mut count: i32 = 0;
let err = unsafe {
ffi::flodl_unbind(self.handle, dim, &mut results_ptr, &mut count)
};
check_err(err)?;
let mut tensors = Vec::with_capacity(count as usize);
for i in 0..count as usize {
let handle = unsafe { *results_ptr.add(i) };
tensors.push(Tensor::from_raw(handle));
}
if !results_ptr.is_null() {
unsafe { ffi::flodl_free_string(results_ptr as *mut i8) };
}
Ok(tensors)
}
pub fn contiguous(&self) -> Result<Tensor> {
let mut handle: FlodlTensor = ptr::null_mut();
let err = unsafe { ffi::flodl_contiguous(self.handle, &mut handle) };
check_err(err)?;
Ok(Tensor::from_raw(handle))
}
pub fn chunk(&self, chunks: i32, dim: i32) -> Result<Vec<Tensor>> {
let mut results_ptr: *mut FlodlTensor = ptr::null_mut();
let mut count: i32 = 0;
let err = unsafe {
ffi::flodl_chunk(self.handle, chunks, dim, &mut results_ptr, &mut count)
};
check_err(err)?;
let mut tensors = Vec::with_capacity(count as usize);
for i in 0..count as usize {
let handle = unsafe { *results_ptr.add(i) };
tensors.push(Tensor::from_raw(handle));
}
if !results_ptr.is_null() {
unsafe { ffi::flodl_free_string(results_ptr as *mut i8) };
}
Ok(tensors)
}
pub fn batches(&self, batch_size: i64) -> Result<Vec<Tensor>> {
let n = self.shape()[0];
let mut result = Vec::new();
let mut start = 0i64;
while start < n {
let len = (batch_size).min(n - start);
result.push(self.narrow(0, start, len)?);
start += len;
}
Ok(result)
}
pub fn meshgrid(tensors: &[&Tensor]) -> Result<Vec<Tensor>> {
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let mut results_ptr: *mut FlodlTensor = ptr::null_mut();
let mut count: i32 = 0;
let err = unsafe {
ffi::flodl_meshgrid(
handles.as_mut_ptr(), handles.len() as i32,
&mut results_ptr, &mut count,
)
};
check_err(err)?;
let mut out = Vec::with_capacity(count as usize);
for i in 0..count as usize {
let handle = unsafe { *results_ptr.add(i) };
out.push(Tensor::from_raw(handle));
}
if !results_ptr.is_null() {
unsafe { ffi::flodl_free_string(results_ptr as *mut i8) };
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
use super::super::*;
#[test]
fn test_flatten() {
let t = Tensor::ones(&[2, 3, 4], test_opts()).unwrap();
let f = t.flatten(1, 2).unwrap();
assert_eq!(f.shape(), vec![2, 12]);
}
#[test]
fn test_stack() {
let a = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
let b = Tensor::from_f32(&[3.0, 4.0], &[2], test_device()).unwrap();
let c = Tensor::from_f32(&[5.0, 6.0], &[2], test_device()).unwrap();
let s = Tensor::stack(&[&a, &b, &c], 0).unwrap();
assert_eq!(s.shape(), vec![3, 2]);
let data = s.to_f32_vec().unwrap();
assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let s1 = Tensor::stack(&[&a, &b, &c], 1).unwrap();
assert_eq!(s1.shape(), vec![2, 3]);
let data1 = s1.to_f32_vec().unwrap();
assert_eq!(data1, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
}
#[test]
fn test_reshape_transpose_narrow_select() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], test_device()).unwrap();
let r = t.reshape(&[3, 2]).unwrap();
assert_eq!(r.shape(), vec![3, 2]);
assert_eq!(r.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let tr = t.transpose(0, 1).unwrap();
assert_eq!(tr.shape(), vec![3, 2]);
assert_eq!(tr.to_f32_vec().unwrap(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
let n = t.narrow(1, 0, 2).unwrap();
assert_eq!(n.shape(), vec![2, 2]);
assert_eq!(n.to_f32_vec().unwrap(), vec![1.0, 2.0, 4.0, 5.0]);
let s = t.select(0, 1).unwrap();
assert_eq!(s.shape(), vec![3]);
assert_eq!(s.to_f32_vec().unwrap(), vec![4.0, 5.0, 6.0]);
}
#[test]
fn test_permute_expand() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], test_device()).unwrap();
let p = t.permute(&[1, 0]).unwrap();
assert_eq!(p.shape(), vec![3, 2]);
let s = Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], test_device()).unwrap();
let e = s.expand(&[4, 3]).unwrap();
assert_eq!(e.shape(), vec![4, 3]);
let data = e.to_f32_vec().unwrap();
assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
}
#[test]
fn test_cat_many() {
let a = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
let b = Tensor::from_f32(&[3.0, 4.0, 5.0], &[3], test_device()).unwrap();
let c = Tensor::from_f32(&[6.0], &[1], test_device()).unwrap();
let result = Tensor::cat_many(&[&a, &b, &c], 0).unwrap();
assert_eq!(result.shape(), vec![6]);
assert_eq!(result.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let x = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
let y = Tensor::from_f32(&[5.0, 6.0], &[2, 1], test_device()).unwrap();
let z = Tensor::from_f32(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[2, 3], test_device()).unwrap();
let result2 = Tensor::cat_many(&[&x, &y, &z], 1).unwrap();
assert_eq!(result2.shape(), vec![2, 6]);
assert_eq!(
result2.to_f32_vec().unwrap(),
vec![1.0, 2.0, 5.0, 7.0, 8.0, 9.0, 3.0, 4.0, 6.0, 10.0, 11.0, 12.0]
);
let single = Tensor::cat_many(&[&a], 0).unwrap();
assert_eq!(single.to_f32_vec().unwrap(), vec![1.0, 2.0]);
let empty: Vec<&Tensor> = vec![];
assert!(Tensor::cat_many(&empty, 0).is_err());
}
#[test]
fn test_cat_index_select_index_add() {
let a = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
let b = Tensor::from_f32(&[3.0, 4.0, 5.0], &[3], test_device()).unwrap();
let c = a.cat(&b, 0).unwrap();
assert_eq!(c.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
let t = Tensor::from_f32(&[10.0, 20.0, 30.0, 40.0, 50.0], &[5], test_device()).unwrap();
let idx = Tensor::from_i64(&[0, 2, 4], &[3], test_device()).unwrap();
let sel = t.index_select(0, &idx).unwrap();
assert_eq!(sel.to_f32_vec().unwrap(), vec![10.0, 30.0, 50.0]);
let base = Tensor::zeros(&[5], test_opts()).unwrap();
let src = Tensor::from_f32(&[1.0, 1.0, 1.0], &[3], test_device()).unwrap();
let r = base.index_add(0, &idx, &src).unwrap();
let data = r.to_f32_vec().unwrap();
assert!((data[0] - 1.0).abs() < 1e-5);
assert!((data[2] - 1.0).abs() < 1e-5);
assert!((data[4] - 1.0).abs() < 1e-5);
}
#[test]
fn test_narrow_scatter_select_scatter() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[4], test_device()).unwrap();
let src = Tensor::from_f32(&[10.0, 20.0], &[2], test_device()).unwrap();
let ns = t.narrow_scatter(&src, 0, 1).unwrap();
assert_eq!(ns.to_f32_vec().unwrap(), vec![1.0, 10.0, 20.0, 4.0]);
let t2 = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], test_device()).unwrap();
let row = Tensor::from_f32(&[10.0, 20.0, 30.0], &[3], test_device()).unwrap();
let ss = t2.select_scatter(&row, 0, 0).unwrap();
assert_eq!(ss.to_f32_vec().unwrap(), vec![10.0, 20.0, 30.0, 4.0, 5.0, 6.0]);
}
#[test]
fn test_chunk_repeat_pad() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], test_device()).unwrap();
let chunks = t.chunk(3, 0).unwrap();
assert_eq!(chunks.len(), 3);
assert_eq!(chunks[0].to_f32_vec().unwrap(), vec![1.0, 2.0]);
assert_eq!(chunks[1].to_f32_vec().unwrap(), vec![3.0, 4.0]);
assert_eq!(chunks[2].to_f32_vec().unwrap(), vec![5.0, 6.0]);
let s = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
let rep = s.repeat(&[3]).unwrap();
assert_eq!(rep.to_f32_vec().unwrap(), vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
let pad = s.pad(&[1, 2], 0.0).unwrap();
assert_eq!(pad.shape(), vec![5]);
assert_eq!(pad.to_f32_vec().unwrap(), vec![0.0, 1.0, 2.0, 0.0, 0.0]);
}
#[test]
fn test_unsqueeze_many() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
let u = t.unsqueeze_many(&[1, 2]).unwrap();
assert_eq!(u.shape(), vec![3, 1, 1]);
let u2 = t.unsqueeze(1).unwrap().unsqueeze(2).unwrap();
assert_eq!(u.shape(), u2.shape());
assert_eq!(u.to_f32_vec().unwrap(), u2.to_f32_vec().unwrap());
}
#[test]
fn test_flip() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
let f = t.flip(&[0]).unwrap().to_f32_vec().unwrap();
assert_eq!(f, vec![3.0, 4.0, 1.0, 2.0]);
let f1 = t.flip(&[1]).unwrap().to_f32_vec().unwrap();
assert_eq!(f1, vec![2.0, 1.0, 4.0, 3.0]);
}
#[test]
fn test_roll() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[4], test_device()).unwrap();
let r = t.roll(1, 0).unwrap().to_f32_vec().unwrap();
assert_eq!(r, vec![4.0, 1.0, 2.0, 3.0]);
let r2 = t.roll(-1, 0).unwrap().to_f32_vec().unwrap();
assert_eq!(r2, vec![2.0, 3.0, 4.0, 1.0]);
}
#[test]
fn test_split() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0], &[5], test_device()).unwrap();
let parts = t.split(2, 0).unwrap();
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].to_f32_vec().unwrap(), vec![1.0, 2.0]);
assert_eq!(parts[1].to_f32_vec().unwrap(), vec![3.0, 4.0]);
assert_eq!(parts[2].to_f32_vec().unwrap(), vec![5.0]);
}
#[test]
fn test_unbind() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], test_device()).unwrap();
let rows = t.unbind(0).unwrap();
assert_eq!(rows.len(), 3);
assert_eq!(rows[0].to_f32_vec().unwrap(), vec![1.0, 2.0]);
assert_eq!(rows[1].to_f32_vec().unwrap(), vec![3.0, 4.0]);
assert_eq!(rows[2].to_f32_vec().unwrap(), vec![5.0, 6.0]);
}
#[test]
fn test_contiguous() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
let c = t.contiguous().unwrap();
assert_eq!(c.to_f32_vec().unwrap(), t.to_f32_vec().unwrap());
}
#[test]
fn test_meshgrid() {
let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
let b = Tensor::from_f32(&[4.0, 5.0], &[2], test_device()).unwrap();
let grids = Tensor::meshgrid(&[&a, &b]).unwrap();
assert_eq!(grids.len(), 2);
assert_eq!(grids[0].shape(), vec![3, 2]);
assert_eq!(grids[1].shape(), vec![3, 2]);
assert_eq!(grids[0].to_f32_vec().unwrap(), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
assert_eq!(grids[1].to_f32_vec().unwrap(), vec![4.0, 5.0, 4.0, 5.0, 4.0, 5.0]);
}
#[test]
fn test_diagonal() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], &[3, 3], test_device()).unwrap();
let d = t.diagonal(0, 0, 1).unwrap().to_f32_vec().unwrap();
assert_eq!(d, vec![1.0, 5.0, 9.0]);
let d1 = t.diagonal(1, 0, 1).unwrap().to_f32_vec().unwrap();
assert_eq!(d1, vec![2.0, 6.0]);
}
#[test]
fn test_movedim() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], test_device()).unwrap();
let m = t.movedim(0, 1).unwrap();
assert_eq!(m.shape(), vec![3, 2]);
}
#[test]
fn test_tile() {
let t = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
let r = t.tile(&[3]).unwrap();
assert_eq!(r.to_f32_vec().unwrap(), vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
let t2 = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
let r2 = t2.tile(&[2, 3]).unwrap();
assert_eq!(r2.shape(), vec![4, 6]);
}
}