opensrdk_kernel_method/instant.rs
1use std::{
2 marker::PhantomData,
3 ops::{Add, Mul},
4};
5
6use crate::KernelError;
7
8use std::fmt::Debug;
9
10use super::{KernelAdd, KernelMul, PositiveDefiniteKernel};
11use opensrdk_symbolic_computation::Expression;
12
13#[derive(Clone)]
14pub struct InstantKernel<F>
15where
16 F: Fn(&Expression, &Expression, &[Expression]) -> Result<Expression, KernelError>
17 + Clone
18 + Send
19 + Sync,
20{
21 params_len: usize,
22 function: F,
23 phantom: PhantomData<Expression>,
24}
25
26impl<F> Debug for InstantKernel<F>
27where
28 F: Fn(&Expression, &Expression, &[Expression]) -> Result<Expression, KernelError>
29 + Clone
30 + Send
31 + Sync,
32{
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 write!(f, "InstantKernel {{ params_len: {} }}", self.params_len)
35 }
36}
37
38impl<F> PositiveDefiniteKernel for InstantKernel<F>
39where
40 F: Fn(&Expression, &Expression, &[Expression]) -> Result<Expression, KernelError>
41 + Clone
42 + Send
43 + Sync,
44{
45 fn expression(
46 &self,
47 x: Expression,
48 x_prime: Expression,
49 params: &[Expression],
50 ) -> Result<Expression, KernelError> {
51 (self.function)(&x, &x_prime, params)
52 }
53
54 fn params_len(&self) -> usize {
55 self.params_len
56 }
57}
58
59impl<R, F> Add<R> for InstantKernel<F>
60where
61 R: PositiveDefiniteKernel,
62 F: Fn(&Expression, &Expression, &[Expression]) -> Result<Expression, KernelError>
63 + Clone
64 + Send
65 + Sync,
66{
67 type Output = KernelAdd<Self, R>;
68
69 fn add(self, rhs: R) -> Self::Output {
70 KernelAdd::new(self, rhs)
71 }
72}
73
74impl<R, F> Mul<R> for InstantKernel<F>
75where
76 F: Fn(&Expression, &Expression, &[Expression]) -> Result<Expression, KernelError>
77 + Clone
78 + Send
79 + Sync,
80 R: PositiveDefiniteKernel,
81{
82 type Output = KernelMul<Self, R>;
83
84 fn mul(self, rhs: R) -> Self::Output {
85 KernelMul::new(self, rhs)
86 }
87}
88
89// use super::PositiveDefiniteKernel;
90// use crate::KernelError;
91// use crate::Value;
92// use crate::{KernelAdd, KernelMul};
93// use std::marker::PhantomData;
94// use std::{fmt::Debug, ops::Add, ops::Mul};
95
96// #[derive(Clone)]
97// pub struct InstantKernel<T, F>
98// where
99// T: Value,
100// F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
101// {
102// params_len: usize,
103// value_function: F,
104// phantom: PhantomData<T>,
105// }
106
107// impl<T, F> InstantKernel<T, F>
108// where
109// T: Value,
110// F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
111// {
112// pub fn new(params_len: usize, value_function: F) -> Self {
113// Self {
114// params_len,
115// value_function,
116// phantom: PhantomData,
117// }
118// }
119// }
120
121// impl<T, F> Debug for InstantKernel<T, F>
122// where
123// T: Value,
124// F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
125// {
126// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127// write!(f, "InstantKernel {{ params_len: {} }}", self.params_len)
128// }
129// }
130
131// impl<T, F> PositiveDefiniteKernel<T> for InstantKernel<T, F>
132// where
133// T: Value,
134// F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
135// {
136// fn params_len(&self) -> usize {
137// self.params_len
138// }
139
140// fn value(&self, params: &[f64], x: &T, xprime: &T) -> Result<f64, KernelError> {
141// (self.value_function)(params, x, xprime)
142// }
143// }
144
145// impl<T, R, F> Add<R> for InstantKernel<T, F>
146// where
147// T: Value,
148// R: PositiveDefiniteKernel<T>,
149// F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
150// {
151// type Output = KernelAdd<Self, R, T>;
152
153// fn add(self, rhs: R) -> Self::Output {
154// Self::Output::new(self, rhs)
155// }
156// }
157
158// impl<T, R, F> Mul<R> for InstantKernel<T, F>
159// where
160// T: Value,
161// R: PositiveDefiniteKernel<T>,
162// F: Fn(&[f64], &T, &T) -> Result<f64, KernelError> + Clone + Send + Sync,
163// {
164// type Output = KernelMul<Self, R, T>;
165
166// fn mul(self, rhs: R) -> Self::Output {
167// Self::Output::new(self, rhs)
168// }
169// }
170
171// #[cfg(test)]
172// mod tests {
173// use crate::*;
174// #[test]
175// fn it_works() {
176// let kernel = RBF + InstantKernel::new(0, |_, _, _| Ok(0.0));
177
178// //let (func, grad) = kernel
179// // .value_with_grad(&[1.0, 1.0], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
180// // .unwrap();
181
182// //println!("{}", func);
183// //println!("{:#?}", grad);
184
185// let test_value = kernel
186// .value(&[1.0, 1.0], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
187// .unwrap();
188
189// println!("{}", test_value);
190// }
191// }