use scirs2_core::ndarray::Array2;
use scirs2_core::numeric::Float;
use std::fmt::Debug;
#[derive(Debug, Clone)]
pub struct TaskData<F: Float + Debug> {
pub support_x: Array2<F>,
pub support_y: Array2<F>,
pub query_x: Array2<F>,
pub query_y: Array2<F>,
}
impl<F: Float + Debug> TaskData<F> {
pub fn new(
support_x: Array2<F>,
support_y: Array2<F>,
query_x: Array2<F>,
query_y: Array2<F>,
) -> Self {
Self {
support_x,
support_y,
query_x,
query_y,
}
}
pub fn support_size(&self) -> usize {
self.support_x.nrows()
}
pub fn query_size(&self) -> usize {
self.query_x.nrows()
}
pub fn input_dim(&self) -> usize {
self.support_x.ncols()
}
pub fn output_dim(&self) -> usize {
self.support_y.ncols()
}
}