opensrdk_kernel_method/
instant.rs

1use std::{
2    marker::PhantomData,
3    ops::{Add, Mul},
4};
5
6use crate::KernelError;
7
8use std::fmt::Debug;
9
10use super::{KernelAdd, KernelMul, PositiveDefiniteKernel};
11use opensrdk_symbolic_computation::Expression;
12
13#[derive(Clone)]
14pub struct InstantKernel<F>
15where
16    F: Fn(&Expression, &Expression, &[Expression]) -> Result<Expression, KernelError>
17        + Clone
18        + Send
19        + Sync,
20{
21    params_len: usize,
22    function: F,
23    phantom: PhantomData<Expression>,
24}
25
26impl<F> Debug for InstantKernel<F>
27where
28    F: Fn(&Expression, &Expression, &[Expression]) -> Result<Expression, KernelError>
29        + Clone
30        + Send
31        + Sync,
32{
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(f, "InstantKernel {{ params_len: {} }}", self.params_len)
35    }
36}
37
38impl<F> PositiveDefiniteKernel for InstantKernel<F>
39where
40    F: Fn(&Expression, &Expression, &[Expression]) -> Result<Expression, KernelError>
41        + Clone
42        + Send
43        + Sync,
44{
45    fn expression(
46        &self,
47        x: Expression,
48        x_prime: Expression,
49        params: &[Expression],
50    ) -> Result<Expression, KernelError> {
51        (self.function)(&x, &x_prime, params)
52    }
53
54    fn params_len(&self) -> usize {
55        self.params_len
56    }
57}
58
59impl<R, F> Add<R> for InstantKernel<F>
60where
61    R: PositiveDefiniteKernel,
62    F: Fn(&Expression, &Expression, &[Expression]) -> Result<Expression, KernelError>
63        + Clone
64        + Send
65        + Sync,
66{
67    type Output = KernelAdd<Self, R>;
68
69    fn add(self, rhs: R) -> Self::Output {
70        KernelAdd::new(self, rhs)
71    }
72}
73
74impl<R, F> Mul<R> for InstantKernel<F>
75where
76    F: Fn(&Expression, &Expression, &[Expression]) -> Result<Expression, KernelError>
77        + Clone
78        + Send
79        + Sync,
80    R: PositiveDefiniteKernel,
81{
82    type Output = KernelMul<Self, R>;
83
84    fn mul(self, rhs: R) -> Self::Output {
85        KernelMul::new(self, rhs)
86    }
87}
88
89// use super::PositiveDefiniteKernel;
90// use crate::KernelError;
91// use crate::Value;
92// use crate::{KernelAdd, KernelMul};
93// use std::marker::PhantomData;
94// use std::{fmt::Debug, ops::Add, ops::Mul};
95
96// #[derive(Clone)]
97// pub struct InstantKernel<T, F>
98// where
99//     T: Value,
100//     F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
101// {
102//     params_len: usize,
103//     value_function: F,
104//     phantom: PhantomData<T>,
105// }
106
107// impl<T, F> InstantKernel<T, F>
108// where
109//     T: Value,
110//     F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
111// {
112//     pub fn new(params_len: usize, value_function: F) -> Self {
113//         Self {
114//             params_len,
115//             value_function,
116//             phantom: PhantomData,
117//         }
118//     }
119// }
120
121// impl<T, F> Debug for InstantKernel<T, F>
122// where
123//     T: Value,
124//     F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
125// {
126//     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127//         write!(f, "InstantKernel {{ params_len: {} }}", self.params_len)
128//     }
129// }
130
131// impl<T, F> PositiveDefiniteKernel<T> for InstantKernel<T, F>
132// where
133//     T: Value,
134//     F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
135// {
136//     fn params_len(&self) -> usize {
137//         self.params_len
138//     }
139
140//     fn value(&self, params: &[f64], x: &T, xprime: &T) -> Result<f64, KernelError> {
141//         (self.value_function)(params, x, xprime)
142//     }
143// }
144
145// impl<T, R, F> Add<R> for InstantKernel<T, F>
146// where
147//     T: Value,
148//     R: PositiveDefiniteKernel<T>,
149//     F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
150// {
151//     type Output = KernelAdd<Self, R, T>;
152
153//     fn add(self, rhs: R) -> Self::Output {
154//         Self::Output::new(self, rhs)
155//     }
156// }
157
158// impl<T, R, F> Mul<R> for InstantKernel<T, F>
159// where
160//     T: Value,
161//     R: PositiveDefiniteKernel<T>,
162//     F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
163// {
164//     type Output = KernelMul<Self, R, T>;
165
166//     fn mul(self, rhs: R) -> Self::Output {
167//         Self::Output::new(self, rhs)
168//     }
169// }
170
171// #[cfg(test)]
172// mod tests {
173//     use crate::*;
174//     #[test]
175//     fn it_works() {
176//         let kernel = RBF + InstantKernel::new(0, |_, _, _| Ok(0.0));
177
178//         //let (func, grad) = kernel
179//         //    .value_with_grad(&[1.0, 1.0], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
180//         //    .unwrap();
181
182//         //println!("{}", func);
183//         //println!("{:#?}", grad);
184
185//         let test_value = kernel
186//             .value(&[1.0, 1.0], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
187//             .unwrap();
188
189//         println!("{}", test_value);
190//     }
191// }