1use std::io::{self, BufRead, BufReader, Lines, Read};
5
6use crate::encodings::{encode, FeatureEncoding};
7use crate::errors::NrpsError;
8use crate::predictors::predictions::PredictionCategory;
9use crate::svm::kernels::{Kernel, LinearKernel, RBFKernel};
10use crate::svm::vectors::{FeatureVector, SupportVector};
11
/// Kernel families a model can be tagged with.
///
/// Only `Linear` and `RBF` are backed by a kernel implementation in this
/// file; the remaining variants are accepted by the type but rejected when a
/// model is constructed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelType {
    /// Linear kernel (kernel id `0` in the model header).
    Linear,
    /// Polynomial kernel (not implemented by the model constructor).
    Polynomial,
    /// Radial basis function kernel (kernel id `2` in the model header).
    RBF,
    /// Sigmoid kernel (not implemented by the model constructor).
    Sigmoid,
    /// User-defined kernel (not implemented by the model constructor).
    Custom,
}
20
21#[derive(Debug)]
22pub struct SVMlightModel {
23 pub name: String,
24 pub category: PredictionCategory,
25 pub vectors: Vec<SupportVector>,
26 pub bias: f64,
27 pub encoding: FeatureEncoding,
28 pub kernel_type: KernelType,
29 pub kernel: Box<dyn Kernel>,
30}
31
32impl SVMlightModel {
33 pub fn new(
34 name: String,
35 category: PredictionCategory,
36 vectors: Vec<SupportVector>,
37 bias: f64,
38 encoding: FeatureEncoding,
39 kernel_type: KernelType,
40 gamma: f64,
41 ) -> Self {
42 let kernel: Box<dyn Kernel>;
43 match kernel_type {
44 KernelType::Linear => kernel = Box::new(LinearKernel {}),
45 KernelType::RBF => kernel = Box::new(RBFKernel::new(gamma)),
46 _ => unimplemented!(),
47 }
48 SVMlightModel {
49 name,
50 category,
51 vectors,
52 bias,
53 encoding,
54 kernel_type,
55 kernel,
56 }
57 }
58
59 pub fn predict(&self, vec: &FeatureVector) -> Result<f64, NrpsError> {
60 let res: Result<f64, NrpsError> = self.vectors.iter().try_fold(0.0, |sum, svec| {
61 Ok(sum + svec.yalpha * self.kernel.compute(svec, vec)?)
62 });
63 Ok(res? - self.bias)
64 }
65
66 pub fn encode(&self, sequence: &String) -> Vec<f64> {
67 encode(sequence, &self.encoding, &self.category)
68 }
69
70 pub fn predict_seq(&self, sequence: &String) -> Result<f64, NrpsError> {
71 let fvec = FeatureVector::new(self.encode(sequence));
72 self.predict(&fvec)
73 }
74
75 pub fn from_handle<R>(
76 handle: R,
77 name: String,
78 category: PredictionCategory,
79 ) -> Result<Self, NrpsError>
80 where
81 R: Read,
82 {
83 let mut line_iter = io::BufReader::new(handle).lines();
84 line_iter.next(); let kernel_type = match parse_int(&mut line_iter)? {
87 0 => KernelType::Linear,
88 2 => KernelType::RBF,
89 _ => {
90 return Err(NrpsError::InvalidFeatureLine(
91 "Failed to match kernel type".to_string(),
92 ))
93 }
94 };
95
96 line_iter.next(); let gamma: f64 = parse_float(&mut line_iter)?;
99
100 line_iter.next(); line_iter.next(); line_iter.next(); let dimensions = parse_int(&mut line_iter)?;
105
106 let encoding = match dimensions {
107 102 => FeatureEncoding::Wold,
108 408 => FeatureEncoding::Rausch,
109 510 => FeatureEncoding::Blin,
110 _ => {
111 return Err(NrpsError::InvalidFeatureLine(format!(
112 "Can't determine encoding type from {} features",
113 dimensions
114 )));
115 }
116 };
117
118 line_iter.next(); let num_vecs = parse_int(&mut line_iter)?;
120
121 let bias = parse_float(&mut line_iter)?;
122
123 let mut vectors = Vec::with_capacity(num_vecs);
124
125 while let Some(line_res) = line_iter.next() {
126 let svec = SupportVector::from_line(line_res?, dimensions)?;
127 vectors.push(svec);
128 }
129
130 Ok(SVMlightModel::new(
131 name,
132 category,
133 vectors,
134 bias,
135 encoding,
136 kernel_type,
137 gamma,
138 ))
139 }
140}
141
142fn parse_float(line_iter: &mut Lines<BufReader<impl Read>>) -> Result<f64, NrpsError> {
143 if let Some(line_result) = line_iter.next() {
144 if let Some(raw_value) = line_result?.trim_end().splitn(2, "#").next() {
145 return Ok(raw_value.trim().parse::<f64>()?);
146 }
147 }
148 Err(NrpsError::InvalidFeatureLine(
149 "Failed to read line".to_string(),
150 ))
151}
152
153fn parse_int(line_iter: &mut Lines<BufReader<impl Read>>) -> Result<usize, NrpsError> {
154 if let Some(line_result) = line_iter.next() {
155 if let Some(raw_value) = line_result?.trim_end().splitn(2, "#").next() {
156 return Ok(raw_value.trim().parse::<usize>()?);
157 }
158 }
159 Err(NrpsError::InvalidFeatureLine(
160 "Failed to read line".to_string(),
161 ))
162}