use crate::Tensor;
use block2::{Block, IntoBlock, RcBlock};
use objc2::rc::Retained;
use objc2_foundation::{NSArray, NSMutableArray};
use std::{ops::Deref, ptr::NonNull};
/// A reference-counted Objective-C block object that wraps a Rust closure.
///
/// The block takes two non-null arrays of [`Tensor`]s (an immutable `NSArray`
/// and an `NSMutableArray`) and returns a non-null [`Tensor`].
/// Construct one with [`WhileBeforeBlock::new`]; borrow the underlying block
/// through `Deref`.
///
/// NOTE(review): the name suggests this is the "before" (condition) block of a
/// graph while-loop construct; where it is consumed is not visible here —
/// confirm against the caller.
pub struct WhileBeforeBlock(
    RcBlock<
        dyn Fn(NonNull<NSArray<Tensor>>, NonNull<NSMutableArray<Tensor>>) -> NonNull<Tensor>,
    >,
);
impl WhileBeforeBlock {
    /// Wraps `while_before_ops` in a reference-counted Objective-C block.
    ///
    /// On each invocation, both incoming arrays are borrowed element-wise into
    /// `Vec<&Tensor>`s and passed to `while_before_ops` as slices. The tensor
    /// the closure returns is moved into the current autorelease pool, and its
    /// raw pointer becomes the block's return value.
    ///
    /// # Notes
    ///
    /// The second (`&mut`) argument is a *copy* of the mutable array's contents,
    /// taken before the closure runs. Changes the closure makes to that slice are
    /// **not** written back to the underlying `NSMutableArray`, so the
    /// Objective-C caller never observes them.
    ///
    /// NOTE(review): if the caller of this block expects it to populate the
    /// mutable array (e.g. with loop-carried result tensors), this constructor
    /// cannot do that. Confirm against the consuming API.
    pub fn new<F>(while_before_ops: F) -> Self
    where
        F: Fn(&[&Tensor], &mut [&Tensor]) -> Retained<Tensor> + 'static,
    {
        Self(RcBlock::new(
            move |input_tensors: NonNull<NSArray<Tensor>>,
                  result_tensors: NonNull<NSMutableArray<Tensor>>| {
                // SAFETY: the block signature types the argument as `NonNull`.
                // We assume the Objective-C caller passes an array that stays
                // alive and unmutated for the duration of this call. Nothing in
                // this closure mutates it while the borrowed element references
                // in `inputs` are alive.
                let inputs = unsafe { input_tensors.as_ref().to_vec_unchecked() };
                // SAFETY: same assumptions as above. `results` is a detached Vec
                // of borrowed references; the array itself is never mutated here.
                let mut results = unsafe { result_tensors.as_ref().to_vec_unchecked() };

                let tensor = while_before_ops(&inputs, &mut results);

                // Move the `Retained` into the current autorelease pool and take
                // the raw pointer to return across the block boundary.
                let raw = Retained::autorelease_return(tensor);
                // SAFETY: `raw` was obtained from a `Retained`, which always
                // holds a non-null object pointer.
                unsafe { NonNull::new_unchecked(raw) }
            },
        ))
    }
}
impl Deref for WhileBeforeBlock {
type Target =
Block<dyn Fn(NonNull<NSArray<Tensor>>, NonNull<NSMutableArray<Tensor>>) -> NonNull<Tensor>>;
fn deref(&self) -> &Self::Target {
&*self.0
}
}