use crate::error::{KernelError, Result};
use crate::types::Kernel;
/// A feature vector tagged with the index of the task it belongs to.
#[derive(Debug, Clone)]
pub struct TaskInput {
    /// Feature values handed to the feature-space kernel.
    pub features: Vec<f64>,
    /// Zero-based task index used to look up task covariance entries.
    pub task: usize,
}

impl TaskInput {
    /// Creates an input that takes ownership of `features`.
    pub fn new(features: Vec<f64>, task: usize) -> Self {
        TaskInput { task, features }
    }

    /// Creates an input by copying `features` out of a borrowed slice.
    pub fn from_slice(features: &[f64], task: usize) -> Self {
        Self::new(Vec::from(features), task)
    }
}
/// Configuration describing a multi-task kernel setup.
///
/// NOTE(review): `normalize` is stored but not read anywhere in this file;
/// confirm which component (if any) consumes it.
#[derive(Debug, Clone)]
pub struct MultiTaskConfig {
    /// Number of tasks.
    pub num_tasks: usize,
    /// Normalization flag; `false` unless `with_normalization` is called.
    pub normalize: bool,
}

impl MultiTaskConfig {
    /// Creates a configuration for `num_tasks` tasks with normalization disabled.
    pub fn new(num_tasks: usize) -> Self {
        let normalize = false;
        MultiTaskConfig {
            num_tasks,
            normalize,
        }
    }

    /// Builder-style setter that enables normalization and returns the config.
    pub fn with_normalization(self) -> Self {
        MultiTaskConfig {
            normalize: true,
            ..self
        }
    }
}
/// Kernel over task indices only: `k(i, j) = B[i][j]`, where `B` is a square
/// task covariance matrix.
#[derive(Debug, Clone)]
pub struct IndexKernel {
    /// Square `num_tasks x num_tasks` task covariance matrix `B`.
    task_covariance: Vec<Vec<f64>>,
    /// Side length of `task_covariance`, cached at construction.
    num_tasks: usize,
}

impl IndexKernel {
    /// Builds an index kernel from a square task covariance matrix.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` when the matrix has no rows, or for
    /// the first row whose length differs from the number of rows. Symmetry and
    /// positive semi-definiteness are not checked.
    pub fn new(task_covariance: Vec<Vec<f64>>) -> Result<Self> {
        if task_covariance.is_empty() {
            return Err(KernelError::InvalidParameter {
                parameter: "task_covariance".to_string(),
                value: "empty".to_string(),
                reason: "must have at least one task".to_string(),
            });
        }
        let num_tasks = task_covariance.len();
        if let Some((i, row)) = task_covariance
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != num_tasks)
        {
            return Err(KernelError::InvalidParameter {
                parameter: "task_covariance".to_string(),
                value: format!("row {} has {} elements", i, row.len()),
                reason: format!("expected {} elements (square matrix)", num_tasks),
            });
        }
        Ok(IndexKernel {
            task_covariance,
            num_tasks,
        })
    }

    /// Identity task covariance: every task is uncorrelated with every other task.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` when `num_tasks` is zero.
    pub fn identity(num_tasks: usize) -> Result<Self> {
        // A unit diagonal with zero off-diagonal correlation is exactly the identity.
        Self::uniform(num_tasks, 0.0)
    }

    /// Task covariance with a unit diagonal and the same `correlation` between
    /// every pair of distinct tasks.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` when `correlation` lies outside
    /// `[0, 1]` (NaN included), or when `num_tasks` is zero.
    pub fn uniform(num_tasks: usize, correlation: f64) -> Result<Self> {
        let in_range = (0.0..=1.0).contains(&correlation);
        if !in_range {
            return Err(KernelError::InvalidParameter {
                parameter: "correlation".to_string(),
                value: correlation.to_string(),
                reason: "must be in [0, 1]".to_string(),
            });
        }
        let cov: Vec<Vec<f64>> = (0..num_tasks)
            .map(|r| {
                (0..num_tasks)
                    .map(|c| if r == c { 1.0 } else { correlation })
                    .collect::<Vec<f64>>()
            })
            .collect();
        Self::new(cov)
    }

    /// Returns `B[task_i][task_j]`.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::ComputationError` when either index is `>= num_tasks`.
    pub fn get_task_covariance(&self, task_i: usize, task_j: usize) -> Result<f64> {
        if task_i < self.num_tasks && task_j < self.num_tasks {
            Ok(self.task_covariance[task_i][task_j])
        } else {
            Err(KernelError::ComputationError(format!(
                "Task index out of bounds: ({}, {}) for {} tasks",
                task_i, task_j, self.num_tasks
            )))
        }
    }

    /// Number of tasks (side length of the covariance matrix).
    pub fn num_tasks(&self) -> usize {
        self.num_tasks
    }

    /// Borrows the full task covariance matrix.
    pub fn covariance_matrix(&self) -> &Vec<Vec<f64>> {
        &self.task_covariance
    }
}
/// Intrinsic Coregionalization Model (ICM) kernel.
///
/// Evaluates `k((x, i), (y, j)) = B[i][j] * k_base(x, y)`, where `B` is a square
/// task covariance matrix and `k_base` is any kernel over the feature vectors.
pub struct ICMKernel {
    /// Kernel applied to the feature part of each input.
    base_kernel: Box<dyn Kernel>,
    /// Square `num_tasks x num_tasks` task covariance matrix `B`.
    task_covariance: Vec<Vec<f64>>,
    /// Side length of `task_covariance`, cached at construction.
    num_tasks: usize,
}

impl ICMKernel {
    /// Creates an ICM kernel from a base kernel and a square task covariance matrix.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` if `task_covariance` is empty or not
    /// square. Symmetry and positive semi-definiteness are not checked.
    pub fn new(base_kernel: Box<dyn Kernel>, task_covariance: Vec<Vec<f64>>) -> Result<Self> {
        let num_tasks = task_covariance.len();
        if num_tasks == 0 {
            return Err(KernelError::InvalidParameter {
                parameter: "task_covariance".to_string(),
                value: "empty".to_string(),
                reason: "must have at least one task".to_string(),
            });
        }
        for (i, row) in task_covariance.iter().enumerate() {
            if row.len() != num_tasks {
                return Err(KernelError::InvalidParameter {
                    parameter: "task_covariance".to_string(),
                    value: format!("row {} has {} elements", i, row.len()),
                    reason: format!("expected {} elements", num_tasks),
                });
            }
        }
        Ok(Self {
            base_kernel,
            task_covariance,
            num_tasks,
        })
    }

    /// Creates an ICM kernel whose tasks are mutually uncorrelated (`B = I`).
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` if `num_tasks` is zero.
    pub fn independent(base_kernel: Box<dyn Kernel>, num_tasks: usize) -> Result<Self> {
        // A unit diagonal with zero off-diagonal correlation is exactly the identity.
        Self::uniform(base_kernel, num_tasks, 0.0)
    }

    /// Creates an ICM kernel with a unit diagonal and the same `correlation`
    /// between every pair of distinct tasks.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` if `correlation` is outside `[0, 1]`
    /// (NaN included) or `num_tasks` is zero.
    pub fn uniform(
        base_kernel: Box<dyn Kernel>,
        num_tasks: usize,
        correlation: f64,
    ) -> Result<Self> {
        if !(0.0..=1.0).contains(&correlation) {
            return Err(KernelError::InvalidParameter {
                parameter: "correlation".to_string(),
                value: correlation.to_string(),
                reason: "must be in [0, 1]".to_string(),
            });
        }
        let mut cov = vec![vec![correlation; num_tasks]; num_tasks];
        for (i, row) in cov.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        Self::new(base_kernel, cov)
    }

    /// Creates a rank-one task covariance `B[i][j] = sqrt(v_i) * sqrt(v_j)` from
    /// per-task variances `v`.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` if any variance is negative or not
    /// finite. Its square root would otherwise silently fill `B` with NaN.
    /// Also returns it if `task_variances` is empty.
    pub fn from_rank1(base_kernel: Box<dyn Kernel>, task_variances: Vec<f64>) -> Result<Self> {
        let invalid = task_variances
            .iter()
            .enumerate()
            .find(|&(_, &v)| !v.is_finite() || v < 0.0);
        if let Some((i, &v)) = invalid {
            return Err(KernelError::InvalidParameter {
                parameter: "task_variances".to_string(),
                value: format!("task {} has variance {}", i, v),
                reason: "must be finite and non-negative".to_string(),
            });
        }
        // Take each square root once rather than twice per matrix entry.
        let std_devs: Vec<f64> = task_variances.iter().map(|v| v.sqrt()).collect();
        let cov: Vec<Vec<f64>> = std_devs
            .iter()
            .map(|&si| std_devs.iter().map(|&sj| si * sj).collect::<Vec<f64>>())
            .collect();
        Self::new(base_kernel, cov)
    }

    /// Evaluates `B[x.task][y.task] * k_base(x.features, y.features)`.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::ComputationError` if either task index is
    /// `>= num_tasks`. Otherwise propagates any error from the base kernel.
    pub fn compute_tasks(&self, x: &TaskInput, y: &TaskInput) -> Result<f64> {
        if x.task >= self.num_tasks || y.task >= self.num_tasks {
            return Err(KernelError::ComputationError(format!(
                "Task index out of bounds: ({}, {}) for {} tasks",
                x.task, y.task, self.num_tasks
            )));
        }
        let k_features = self.base_kernel.compute(&x.features, &y.features)?;
        let b_tasks = self.task_covariance[x.task][y.task];
        Ok(b_tasks * k_features)
    }

    /// Number of tasks (side length of `B`).
    pub fn num_tasks(&self) -> usize {
        self.num_tasks
    }

    /// Borrows the task covariance matrix `B`.
    pub fn task_covariance(&self) -> &Vec<Vec<f64>> {
        &self.task_covariance
    }

    /// Computes the full Gram matrix over `inputs`.
    ///
    /// Only the upper triangle is evaluated and then mirrored. This assumes
    /// `compute_tasks` is symmetric, which holds when `B` and the base kernel are.
    ///
    /// # Errors
    ///
    /// Propagates the first error from `compute_tasks`.
    pub fn compute_task_matrix(&self, inputs: &[TaskInput]) -> Result<Vec<Vec<f64>>> {
        let n = inputs.len();
        let mut matrix = vec![vec![0.0; n]; n];
        for i in 0..n {
            for j in i..n {
                let k = self.compute_tasks(&inputs[i], &inputs[j])?;
                matrix[i][j] = k;
                matrix[j][i] = k;
            }
        }
        Ok(matrix)
    }
}
/// One term of an [`LMCKernel`]: a feature kernel paired with its own task
/// covariance matrix.
struct LMCComponent {
    /// Kernel evaluated on the feature part of both inputs.
    kernel: Box<dyn Kernel>,
    /// `num_tasks x num_tasks` matrix; its shape is checked by `LMCKernel::add_component`.
    task_covariance: Vec<Vec<f64>>,
}

/// Linear Model of Coregionalization (LMC) kernel:
/// `k((x, i), (y, j)) = sum_q B_q[i][j] * k_q(x, y)` over all registered components.
pub struct LMCKernel {
    /// Registered terms, summed in insertion order.
    components: Vec<LMCComponent>,
    /// Task count every component's covariance must match.
    num_tasks: usize,
}

impl LMCKernel {
    /// Creates an LMC kernel for `num_tasks` tasks with no components yet.
    pub fn new(num_tasks: usize) -> Self {
        LMCKernel {
            components: Vec::new(),
            num_tasks,
        }
    }

    /// Registers one `(kernel, task covariance)` term.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` when `task_covariance` does not have
    /// exactly `num_tasks` rows, or for the first row whose length is not
    /// `num_tasks`. Nothing is added on error.
    pub fn add_component(
        &mut self,
        kernel: Box<dyn Kernel>,
        task_covariance: Vec<Vec<f64>>,
    ) -> Result<()> {
        let expected = self.num_tasks;
        if task_covariance.len() != expected {
            return Err(KernelError::InvalidParameter {
                parameter: "task_covariance".to_string(),
                value: format!("{} rows", task_covariance.len()),
                reason: format!("expected {} rows", expected),
            });
        }
        if let Some((i, row)) = task_covariance
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != expected)
        {
            return Err(KernelError::InvalidParameter {
                parameter: "task_covariance".to_string(),
                value: format!("row {} has {} elements", i, row.len()),
                reason: format!("expected {} elements", expected),
            });
        }
        self.components.push(LMCComponent {
            kernel,
            task_covariance,
        });
        Ok(())
    }

    /// Evaluates `sum_q B_q[x.task][y.task] * k_q(x.features, y.features)`.
    ///
    /// Returns `Ok(0.0)` when no components are registered.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::ComputationError` if either task index is
    /// `>= num_tasks`. Otherwise propagates the first component-kernel error.
    pub fn compute_tasks(&self, x: &TaskInput, y: &TaskInput) -> Result<f64> {
        let (ti, tj) = (x.task, y.task);
        if ti >= self.num_tasks || tj >= self.num_tasks {
            return Err(KernelError::ComputationError(format!(
                "Task index out of bounds: ({}, {}) for {} tasks",
                ti, tj, self.num_tasks
            )));
        }
        self.components.iter().try_fold(0.0, |acc, term| {
            term.kernel
                .compute(&x.features, &y.features)
                .map(|k| acc + term.task_covariance[ti][tj] * k)
        })
    }

    /// Number of registered components.
    pub fn num_components(&self) -> usize {
        self.components.len()
    }

    /// Number of tasks this kernel was created for.
    pub fn num_tasks(&self) -> usize {
        self.num_tasks
    }

    /// Computes the Gram matrix over `inputs`, evaluating the upper triangle
    /// and mirroring it into the lower triangle.
    ///
    /// # Errors
    ///
    /// Propagates the first error from `compute_tasks`.
    pub fn compute_task_matrix(&self, inputs: &[TaskInput]) -> Result<Vec<Vec<f64>>> {
        let size = inputs.len();
        let mut gram = vec![vec![0.0; size]; size];
        for (row, left) in inputs.iter().enumerate() {
            for (col, right) in inputs.iter().enumerate().skip(row) {
                let value = self.compute_tasks(left, right)?;
                gram[row][col] = value;
                gram[col][row] = value;
            }
        }
        Ok(gram)
    }
}
pub struct ICMKernelWrapper {
inner: ICMKernel,
}
impl ICMKernelWrapper {
pub fn new(inner: ICMKernel) -> Self {
Self { inner }
}
pub fn inner(&self) -> &ICMKernel {
&self.inner
}
}
impl Kernel for ICMKernelWrapper {
fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
if x.is_empty() || y.is_empty() {
return Err(KernelError::ComputationError(
"Input must have at least task index".to_string(),
));
}
let task_x = x[0] as usize;
let task_y = y[0] as usize;
let features_x = &x[1..];
let features_y = &y[1..];
let input_x = TaskInput::from_slice(features_x, task_x);
let input_y = TaskInput::from_slice(features_y, task_y);
self.inner.compute_tasks(&input_x, &input_y)
}
fn name(&self) -> &str {
"ICM"
}
}
pub struct LMCKernelWrapper {
inner: LMCKernel,
}
impl LMCKernelWrapper {
pub fn new(inner: LMCKernel) -> Self {
Self { inner }
}
pub fn inner(&self) -> &LMCKernel {
&self.inner
}
}
impl Kernel for LMCKernelWrapper {
fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
if x.is_empty() || y.is_empty() {
return Err(KernelError::ComputationError(
"Input must have at least task index".to_string(),
));
}
let task_x = x[0] as usize;
let task_y = y[0] as usize;
let features_x = &x[1..];
let features_y = &y[1..];
let input_x = TaskInput::from_slice(features_x, task_x);
let input_y = TaskInput::from_slice(features_y, task_y);
self.inner.compute_tasks(&input_x, &input_y)
}
fn name(&self) -> &str {
"LMC"
}
}
pub struct HadamardTaskKernel {
kernels: Vec<ICMKernel>,
}
impl HadamardTaskKernel {
pub fn new() -> Self {
Self {
kernels: Vec::new(),
}
}
pub fn add_kernel(&mut self, kernel: ICMKernel) -> Result<()> {
if !self.kernels.is_empty() && kernel.num_tasks() != self.kernels[0].num_tasks() {
return Err(KernelError::InvalidParameter {
parameter: "num_tasks".to_string(),
value: kernel.num_tasks().to_string(),
reason: format!("expected {}", self.kernels[0].num_tasks()),
});
}
self.kernels.push(kernel);
Ok(())
}
pub fn compute_tasks(&self, x: &TaskInput, y: &TaskInput) -> Result<f64> {
if self.kernels.is_empty() {
return Err(KernelError::ComputationError(
"No component kernels added".to_string(),
));
}
let mut result = 1.0;
for kernel in &self.kernels {
result *= kernel.compute_tasks(x, y)?;
}
Ok(result)
}
pub fn num_tasks(&self) -> Option<usize> {
self.kernels.first().map(|k| k.num_tasks())
}
}
impl Default for HadamardTaskKernel {
fn default() -> Self {
Self::new()
}
}
/// Fluent builder that collects `(base kernel, task covariance)` components and
/// assembles them into an [`ICMKernel`] or an [`LMCKernel`].
pub struct MultiTaskKernelBuilder {
    /// Number of tasks every component's task covariance must cover.
    num_tasks: usize,
    /// Base kernels, index-aligned with `task_covariances`.
    base_kernels: Vec<Box<dyn Kernel>>,
    /// Task covariance matrices, index-aligned with `base_kernels`.
    task_covariances: Vec<Vec<Vec<f64>>>,
}

impl MultiTaskKernelBuilder {
    /// Starts an empty builder for `num_tasks` tasks.
    pub fn new(num_tasks: usize) -> Self {
        Self {
            num_tasks,
            base_kernels: Vec::new(),
            task_covariances: Vec::new(),
        }
    }

    /// Appends one component. Shape validation is deferred to the `build_*` methods.
    pub fn add_component(
        mut self,
        kernel: Box<dyn Kernel>,
        task_covariance: Vec<Vec<f64>>,
    ) -> Self {
        self.base_kernels.push(kernel);
        self.task_covariances.push(task_covariance);
        self
    }

    /// Builds an [`ICMKernel`] from the single registered component.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::InvalidParameter` when:
    /// - the builder holds anything other than exactly one component;
    /// - that component's task covariance does not have `num_tasks` rows;
    /// - `ICMKernel::new` rejects the matrix.
    pub fn build_icm(self) -> Result<ICMKernel> {
        let Self {
            num_tasks,
            base_kernels,
            task_covariances,
        } = self;
        let count = base_kernels.len();
        // `add_component` always pushes to both vectors together, so zipping
        // pairs each kernel with its own covariance.
        let mut components = base_kernels.into_iter().zip(task_covariances);
        let (kernel, cov) = match (components.next(), components.next()) {
            (Some(only), None) => only,
            _ => {
                return Err(KernelError::InvalidParameter {
                    parameter: "components".to_string(),
                    value: count.to_string(),
                    reason: "ICM requires exactly one component".to_string(),
                });
            }
        };
        // `build_lmc` enforces the configured task count through
        // `LMCKernel::add_component`. Previously `build_icm` ignored `num_tasks`
        // entirely; check it here so both build paths agree.
        if cov.len() != num_tasks {
            return Err(KernelError::InvalidParameter {
                parameter: "task_covariance".to_string(),
                value: format!("{} rows", cov.len()),
                reason: format!("expected {} rows", num_tasks),
            });
        }
        ICMKernel::new(kernel, cov)
    }

    /// Builds an [`LMCKernel`] with one term per registered component, in order.
    ///
    /// # Errors
    ///
    /// Propagates the first error from `LMCKernel::add_component`.
    pub fn build_lmc(self) -> Result<LMCKernel> {
        let mut lmc = LMCKernel::new(self.num_tasks);
        for (kernel, cov) in self.base_kernels.into_iter().zip(self.task_covariances) {
            lmc.add_component(kernel, cov)?;
        }
        Ok(lmc)
    }
}
// Unit tests for the multi-task kernels in this module. Expected values are
// hand-computed. Several tests assume `LinearKernel` is the plain dot product x·y,
// which matches the 11.0 = 1*3 + 2*4 expected in `test_icm_kernel_same_task`.
#[cfg(test)]
// Allowed so the symmetry checks can index `matrix[i][j]` and `matrix[j][i]` directly.
#[allow(clippy::needless_range_loop)]
mod tests {
use super::*;
use crate::{LinearKernel, RbfKernel, RbfKernelConfig};
// An explicit 2x2 covariance: lookups return the stored entries.
#[test]
fn test_index_kernel_basic() {
let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let kernel = IndexKernel::new(cov).expect("unwrap");
assert_eq!(kernel.num_tasks(), 2);
assert!((kernel.get_task_covariance(0, 1).expect("unwrap") - 0.5).abs() < 1e-10);
assert!((kernel.get_task_covariance(1, 1).expect("unwrap") - 1.0).abs() < 1e-10);
}
// Identity covariance: 1 on the diagonal, 0 between distinct tasks.
#[test]
fn test_index_kernel_identity() {
let kernel = IndexKernel::identity(3).expect("unwrap");
assert!((kernel.get_task_covariance(0, 0).expect("unwrap") - 1.0).abs() < 1e-10);
assert!((kernel.get_task_covariance(0, 1).expect("unwrap")).abs() < 1e-10);
assert!((kernel.get_task_covariance(1, 2).expect("unwrap")).abs() < 1e-10);
}
// Uniform(0.5): unit diagonal, 0.5 for every pair of distinct tasks.
#[test]
fn test_index_kernel_uniform() {
let kernel = IndexKernel::uniform(3, 0.5).expect("unwrap");
assert!((kernel.get_task_covariance(0, 0).expect("unwrap") - 1.0).abs() < 1e-10);
assert!((kernel.get_task_covariance(0, 1).expect("unwrap") - 0.5).abs() < 1e-10);
assert!((kernel.get_task_covariance(1, 2).expect("unwrap") - 0.5).abs() < 1e-10);
}
// Rejected inputs: empty matrix, non-square matrix (1 row of 2), correlation outside [0, 1].
#[test]
fn test_index_kernel_invalid() {
let result = IndexKernel::new(vec![]);
assert!(result.is_err());
let result = IndexKernel::new(vec![vec![1.0, 0.5]]);
assert!(result.is_err());
let result = IndexKernel::uniform(3, 1.5);
assert!(result.is_err());
}
// Task index 2 does not exist in a 2-task kernel.
#[test]
fn test_index_kernel_out_of_bounds() {
let kernel = IndexKernel::identity(2).expect("unwrap");
assert!(kernel.get_task_covariance(2, 0).is_err());
}
// The task count is taken from the covariance size.
#[test]
fn test_icm_kernel_basic() {
let base = LinearKernel::new();
let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let icm = ICMKernel::new(Box::new(base), cov).expect("unwrap");
assert_eq!(icm.num_tasks(), 2);
}
// Cross-task value: B[0][1] * (x·y) = 0.5 * (1*3 + 2*4) = 5.5.
#[test]
fn test_icm_kernel_compute() {
let base = LinearKernel::new();
let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let icm = ICMKernel::new(Box::new(base), cov).expect("unwrap");
let x = TaskInput::new(vec![1.0, 2.0], 0);
let y = TaskInput::new(vec![3.0, 4.0], 1);
let k = icm.compute_tasks(&x, &y).expect("unwrap");
assert!((k - 5.5).abs() < 1e-10);
}
// Same-task value: B[0][0] * (x·y) = 1.0 * 11 = 11.
#[test]
fn test_icm_kernel_same_task() {
let base = LinearKernel::new();
let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let icm = ICMKernel::new(Box::new(base), cov).expect("unwrap");
let x = TaskInput::new(vec![1.0, 2.0], 0);
let y = TaskInput::new(vec![3.0, 4.0], 0);
let k = icm.compute_tasks(&x, &y).expect("unwrap");
assert!((k - 11.0).abs() < 1e-10);
}
// Identity task covariance: zero across tasks, the plain base-kernel value within a task.
#[test]
fn test_icm_kernel_independent() {
let base = LinearKernel::new();
let icm = ICMKernel::independent(Box::new(base), 3).expect("unwrap");
let x = TaskInput::new(vec![1.0], 0);
let y = TaskInput::new(vec![1.0], 1);
let k = icm.compute_tasks(&x, &y).expect("unwrap");
assert!(k.abs() < 1e-10);
let z = TaskInput::new(vec![1.0], 0);
let k = icm.compute_tasks(&x, &z).expect("unwrap");
assert!((k - 1.0).abs() < 1e-10);
}
// Uniform(0.8): cross-task value is 0.8 * (1*1).
#[test]
fn test_icm_kernel_uniform() {
let base = LinearKernel::new();
let icm = ICMKernel::uniform(Box::new(base), 2, 0.8).expect("unwrap");
let x = TaskInput::new(vec![1.0], 0);
let y = TaskInput::new(vec![1.0], 1);
let k = icm.compute_tasks(&x, &y).expect("unwrap");
assert!((k - 0.8).abs() < 1e-10);
}
// Rank-1 from variances [1, 4]: B[0][1] = sqrt(1) * sqrt(4) = 2, times x·y = 1.
#[test]
fn test_icm_kernel_rank1() {
let base = LinearKernel::new();
let variances = vec![1.0, 4.0]; let icm = ICMKernel::from_rank1(Box::new(base), variances).expect("unwrap");
let x = TaskInput::new(vec![1.0], 0);
let y = TaskInput::new(vec![1.0], 1);
let k = icm.compute_tasks(&x, &y).expect("unwrap");
assert!((k - 2.0).abs() < 1e-10);
}
// The Gram matrix over mixed-task inputs must come out symmetric.
#[test]
fn test_icm_kernel_matrix() {
let base = LinearKernel::new();
let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let icm = ICMKernel::new(Box::new(base), cov).expect("unwrap");
let inputs = vec![
TaskInput::new(vec![1.0], 0),
TaskInput::new(vec![1.0], 1),
TaskInput::new(vec![2.0], 0),
];
let matrix = icm.compute_task_matrix(&inputs).expect("unwrap");
assert_eq!(matrix.len(), 3);
for i in 0..3 {
for j in 0..3 {
assert!(
(matrix[i][j] - matrix[j][i]).abs() < 1e-10,
"Matrix not symmetric at ({}, {})",
i,
j
);
}
}
}
// Task 5 does not exist in a 2-task kernel.
#[test]
fn test_icm_kernel_invalid_task() {
let base = LinearKernel::new();
let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let icm = ICMKernel::new(Box::new(base), cov).expect("unwrap");
let x = TaskInput::new(vec![1.0], 0);
let y = TaskInput::new(vec![1.0], 5);
assert!(icm.compute_tasks(&x, &y).is_err());
}
// One component registered on a 2-task LMC.
#[test]
fn test_lmc_kernel_basic() {
let mut lmc = LMCKernel::new(2);
let base1 = LinearKernel::new();
let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
lmc.add_component(Box::new(base1), cov1).expect("unwrap");
assert_eq!(lmc.num_tasks(), 2);
assert_eq!(lmc.num_components(), 1);
}
// Sum over components for identical features [1, 0] on tasks (0, 1):
// linear term 0.5 * 1 plus RBF term 1.0 * 1 (identical points) = 1.5.
#[test]
fn test_lmc_kernel_compute() {
let mut lmc = LMCKernel::new(2);
let base1 = LinearKernel::new();
let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
lmc.add_component(Box::new(base1), cov1).expect("unwrap");
let base2 = RbfKernel::new(RbfKernelConfig::new(1.0)).expect("unwrap");
let cov2 = vec![vec![2.0, 1.0], vec![1.0, 2.0]];
lmc.add_component(Box::new(base2), cov2).expect("unwrap");
let x = TaskInput::new(vec![1.0, 0.0], 0);
let y = TaskInput::new(vec![1.0, 0.0], 1);
let k = lmc.compute_tasks(&x, &y).expect("unwrap");
assert!((k - 1.5).abs() < 1e-10);
}
// Off-diagonal Gram entries must match.
#[test]
fn test_lmc_kernel_matrix() {
let mut lmc = LMCKernel::new(2);
let base = LinearKernel::new();
let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
lmc.add_component(Box::new(base), cov).expect("unwrap");
let inputs = vec![TaskInput::new(vec![1.0], 0), TaskInput::new(vec![1.0], 1)];
let matrix = lmc.compute_task_matrix(&inputs).expect("unwrap");
assert_eq!(matrix.len(), 2);
assert!((matrix[0][1] - matrix[1][0]).abs() < 1e-10);
}
// A 2-task LMC rejects a 3x3 covariance.
#[test]
fn test_lmc_kernel_invalid_dimensions() {
let mut lmc = LMCKernel::new(2);
let base = LinearKernel::new();
let cov = vec![
vec![1.0, 0.5, 0.3],
vec![0.5, 1.0, 0.4],
vec![0.3, 0.4, 1.0],
];
assert!(lmc.add_component(Box::new(base), cov).is_err());
}
// Wrapper inputs are laid out as [task, features...].
#[test]
fn test_icm_wrapper() {
let base = LinearKernel::new();
let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let icm = ICMKernel::new(Box::new(base), cov).expect("unwrap");
let wrapper = ICMKernelWrapper::new(icm);
// Tasks (0, 1), features [1, 2]·[3, 4] = 11, times B[0][1] = 0.5 -> 5.5.
let x = vec![0.0, 1.0, 2.0]; let y = vec![1.0, 3.0, 4.0];
let k = wrapper.compute(&x, &y).expect("unwrap");
assert!((k - 5.5).abs() < 1e-10);
assert_eq!(wrapper.name(), "ICM");
}
// Same [task, features...] layout through the LMC wrapper.
#[test]
fn test_lmc_wrapper() {
let mut lmc = LMCKernel::new(2);
let base = LinearKernel::new();
let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
lmc.add_component(Box::new(base), cov).expect("unwrap");
let wrapper = LMCKernelWrapper::new(lmc);
// Tasks (0, 1), features [1]·[1] = 1, times B[0][1] = 0.5 -> 0.5.
let x = vec![0.0, 1.0]; let y = vec![1.0, 1.0];
let k = wrapper.compute(&x, &y).expect("unwrap");
assert!((k - 0.5).abs() < 1e-10);
assert_eq!(wrapper.name(), "LMC");
}
// An empty slice carries no task index, so the wrapper must error.
#[test]
fn test_wrapper_empty_input() {
let base = LinearKernel::new();
let cov = vec![vec![1.0]];
let icm = ICMKernel::new(Box::new(base), cov).expect("unwrap");
let wrapper = ICMKernelWrapper::new(icm);
assert!(wrapper.compute(&[], &[0.0, 1.0]).is_err());
}
// Product of the two ICM values at tasks (0, 1): (0.5 * 1) * (1.0 * 1) = 0.5.
#[test]
fn test_hadamard_task_kernel() {
let mut hadamard = HadamardTaskKernel::new();
let base1 = LinearKernel::new();
let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let icm1 = ICMKernel::new(Box::new(base1), cov1).expect("unwrap");
hadamard.add_kernel(icm1).expect("unwrap");
let base2 = LinearKernel::new();
let cov2 = vec![vec![2.0, 1.0], vec![1.0, 2.0]];
let icm2 = ICMKernel::new(Box::new(base2), cov2).expect("unwrap");
hadamard.add_kernel(icm2).expect("unwrap");
let x = TaskInput::new(vec![1.0], 0);
let y = TaskInput::new(vec![1.0], 1);
let k = hadamard.compute_tasks(&x, &y).expect("unwrap");
assert!((k - 0.5).abs() < 1e-10);
}
// With no factors, compute_tasks errors instead of returning the empty product.
#[test]
fn test_hadamard_task_kernel_empty() {
let hadamard = HadamardTaskKernel::new();
let x = TaskInput::new(vec![1.0], 0);
let y = TaskInput::new(vec![1.0], 0);
assert!(hadamard.compute_tasks(&x, &y).is_err());
}
// All factors must agree on the task count (2 vs 3 here).
#[test]
fn test_hadamard_mismatched_tasks() {
let mut hadamard = HadamardTaskKernel::new();
let base1 = LinearKernel::new();
let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let icm1 = ICMKernel::new(Box::new(base1), cov1).expect("unwrap");
hadamard.add_kernel(icm1).expect("unwrap");
let base2 = LinearKernel::new();
let cov2 = vec![
vec![1.0, 0.5, 0.3],
vec![0.5, 1.0, 0.4],
vec![0.3, 0.4, 1.0],
];
let icm2 = ICMKernel::new(Box::new(base2), cov2).expect("unwrap");
assert!(hadamard.add_kernel(icm2).is_err());
}
// Exactly one component builds an ICM kernel.
#[test]
fn test_builder_icm() {
let base = LinearKernel::new();
let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let icm = MultiTaskKernelBuilder::new(2)
.add_component(Box::new(base), cov)
.build_icm()
.expect("unwrap");
assert_eq!(icm.num_tasks(), 2);
}
// Two components build an LMC that keeps both.
#[test]
fn test_builder_lmc() {
let base1 = LinearKernel::new();
let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let base2 = RbfKernel::new(RbfKernelConfig::new(1.0)).expect("unwrap");
let cov2 = vec![vec![2.0, 1.0], vec![1.0, 2.0]];
let lmc = MultiTaskKernelBuilder::new(2)
.add_component(Box::new(base1), cov1)
.add_component(Box::new(base2), cov2)
.build_lmc()
.expect("unwrap");
assert_eq!(lmc.num_tasks(), 2);
assert_eq!(lmc.num_components(), 2);
}
// build_icm rejects more than one component.
#[test]
fn test_builder_icm_wrong_components() {
let base1 = LinearKernel::new();
let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let base2 = LinearKernel::new();
let cov2 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
let result = MultiTaskKernelBuilder::new(2)
.add_component(Box::new(base1), cov1)
.add_component(Box::new(base2), cov2)
.build_icm();
assert!(result.is_err());
}
// ICM over an RBF base kernel with a 3-task covariance.
#[test]
fn test_multitask_with_rbf() {
let base = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("unwrap");
let cov = vec![
vec![1.0, 0.8, 0.6],
vec![0.8, 1.0, 0.7],
vec![0.6, 0.7, 1.0],
];
let icm = ICMKernel::new(Box::new(base), cov).expect("unwrap");
// Identical input on the same task: RBF(x, x) = 1 and B[0][0] = 1.
let x = TaskInput::new(vec![1.0, 2.0], 0);
let k = icm.compute_tasks(&x, &x).expect("unwrap");
assert!((k - 1.0).abs() < 1e-10);
// Same features on tasks (0, 1): 1 * B[0][1] = 0.8.
let y = TaskInput::new(vec![1.0, 2.0], 1);
let k = icm.compute_tasks(&x, &y).expect("unwrap");
assert!((k - 0.8).abs() < 1e-10);
// Squared distance between (1, 2) and (1, 3) is 1. The (0.5, 0.7) window is
// consistent with exp(-0.5 * 1) ≈ 0.607 if `RbfKernelConfig::new` takes gamma.
// TODO confirm against RbfKernel.
let z = TaskInput::new(vec![1.0, 3.0], 0);
let k = icm.compute_tasks(&x, &z).expect("unwrap");
assert!(k > 0.5 && k < 0.7);
}
// Both constructors store the features and task index verbatim.
#[test]
fn test_task_input_creation() {
let input = TaskInput::new(vec![1.0, 2.0, 3.0], 0);
assert_eq!(input.features, vec![1.0, 2.0, 3.0]);
assert_eq!(input.task, 0);
let input = TaskInput::from_slice(&[4.0, 5.0], 2);
assert_eq!(input.features, vec![4.0, 5.0]);
assert_eq!(input.task, 2);
}
}