opensrdk_kernel_method/
lib.rs

1pub extern crate opensrdk_linear_algebra;
2extern crate rayon;
3extern crate thiserror;
4
5pub use add::*;
6pub use ard::*;
7pub use constant::*;
8pub use convolutional::*;
9pub use exponential::*;
10pub use instant::*;
11pub use linear::*;
12pub use mul::*;
13pub use neural_network::{deep_neural_network::*, relu::*};
14use opensrdk_symbolic_computation::Expression;
15pub use periodic::*;
16pub use rbf::*;
17pub use spectral_mixture::*;
18
19use std::fmt::Debug;
20
21pub mod add;
22pub mod ard;
23pub mod constant;
24pub mod convolutional;
25pub mod exponential;
26pub mod instant;
27pub mod linear;
28pub mod mul;
29pub mod neural_network;
30pub mod periodic;
31pub mod rbf;
32pub mod spectral_mixture;
33
/// Marker trait bundling the `Clone + Debug + Send + Sync` bounds under a
/// single name.
///
/// It declares no methods of its own; it only acts as shorthand for the
/// combined bound.
pub trait Value
where
    Self: Clone + Debug + Send + Sync,
{
}

/// Blanket implementation: every type that already satisfies the four bounds
/// is automatically a [`Value`], so no type needs a manual impl.
impl<T: Clone + Debug + Send + Sync> Value for T {}
36
37pub trait PositiveDefiniteKernel: Clone + Debug + Send + Sync {
38    fn params_len(&self) -> usize;
39
40    fn expression(
41        &self,
42        x: Expression,
43        x_prime: Expression,
44        params: &[Expression],
45    ) -> Result<Expression, KernelError>;
46}
47
48#[derive(thiserror::Error, Debug)]
49pub enum KernelError {
50    #[error("parameters length mismatch")]
51    ParametersLengthMismatch,
52    #[error("invalid parameter")]
53    InvalidParameter,
54    #[error("invalid argument")]
55    InvalidArgument,
56}
57
58// #[cfg(test)]
59// mod tests {
60//     use crate::*;
61//     #[test]
62//     fn it_works() {
63//         let kernel = RBF + Constant * Linear + Constant * Periodic + Constant * ARD(3);
64//         let test_value = kernel
65//             .value(
66//                 &vec![1.0; kernel.params_len()],
67//                 &vec![1.0, 2.0, 3.0],
68//                 &vec![30.0, 20.0, 10.0],
69//             )
70//             .unwrap();
71
72//         println!("{}", test_value);
73//     }
74// }