mod sizes_or_tensor;
pub use sizes_or_tensor::SizesOrTensor;
use crate::{Graph, Tensor, ns_number_array_from_slice};
use objc2::{msg_send, rc::Retained};
use objc2_foundation::{NSArray, NSString};
impl Graph {
pub fn split<'a>(
&self,
tensor: &Tensor,
split_sizes: SizesOrTensor<'a>,
axis: i64,
name: Option<&str>,
) -> Box<[Retained<Tensor>]> {
match split_sizes {
SizesOrTensor::Sizes(split_sizes) => {
let result: Retained<NSArray<Tensor>> = unsafe {
msg_send![
self,
splitTensor: tensor,
splitSizes: &*ns_number_array_from_slice(split_sizes),
axis: axis,
name: name.map(NSString::from_str).as_deref(),
]
};
result.to_vec().into_boxed_slice()
}
SizesOrTensor::Tensor(split_sizes_tensor) => unsafe {
let result: Retained<NSArray<Tensor>> = unsafe {
msg_send![
self,
splitTensor: tensor,
splitSizesTensor: split_sizes_tensor,
axis: axis,
name: name.map(NSString::from_str).as_deref(),
]
};
result.to_vec().into_boxed_slice()
},
}
}
pub fn split_num_splits(
&self,
tensor: &Tensor,
num_splits: u64,
axis: i64,
name: Option<&str>,
) -> Box<[Retained<Tensor>]> {
let result: Retained<NSArray<Tensor>> = unsafe {
msg_send![
self,
splitTensor: tensor,
numSplits: num_splits,
axis: axis,
name: name.map(NSString::from_str).as_deref(),
]
};
result.to_vec().into_boxed_slice()
}
}