// nrps_rs/svm/models.rs

// License: GNU Affero General Public License v3 or later
// A copy of GNU AGPL v3 should have been included in this software package in LICENSE.txt.

use std::io::{self, BufRead, BufReader, Lines, Read};
use std::str::FromStr;

use crate::encodings::{encode, FeatureEncoding};
use crate::errors::NrpsError;
use crate::predictors::predictions::PredictionCategory;
use crate::svm::kernels::{Kernel, LinearKernel, RBFKernel};
use crate::svm::vectors::{FeatureVector, SupportVector};

/// Kernel function families an SVMlight model can declare.
///
/// NOTE(review): the variant order appears to mirror SVMlight's numeric kernel codes
/// (0 linear, 1 polynomial, 2 RBF, 3 sigmoid, 4 custom) — confirm against the SVMlight docs.
/// Only `Linear` (code 0) and `RBF` (code 2) are accepted by `SVMlightModel::from_handle`
/// and have a kernel implementation in `SVMlightModel::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelType {
    Linear,
    Polynomial,
    RBF,
    Sigmoid,
    Custom,
}

21#[derive(Debug)]
22pub struct SVMlightModel {
23    pub name: String,
24    pub category: PredictionCategory,
25    pub vectors: Vec<SupportVector>,
26    pub bias: f64,
27    pub encoding: FeatureEncoding,
28    pub kernel_type: KernelType,
29    pub kernel: Box<dyn Kernel>,
30}
31
32impl SVMlightModel {
33    pub fn new(
34        name: String,
35        category: PredictionCategory,
36        vectors: Vec<SupportVector>,
37        bias: f64,
38        encoding: FeatureEncoding,
39        kernel_type: KernelType,
40        gamma: f64,
41    ) -> Self {
42        let kernel: Box<dyn Kernel>;
43        match kernel_type {
44            KernelType::Linear => kernel = Box::new(LinearKernel {}),
45            KernelType::RBF => kernel = Box::new(RBFKernel::new(gamma)),
46            _ => unimplemented!(),
47        }
48        SVMlightModel {
49            name,
50            category,
51            vectors,
52            bias,
53            encoding,
54            kernel_type,
55            kernel,
56        }
57    }
58
59    pub fn predict(&self, vec: &FeatureVector) -> Result<f64, NrpsError> {
60        let res: Result<f64, NrpsError> = self.vectors.iter().try_fold(0.0, |sum, svec| {
61            Ok(sum + svec.yalpha * self.kernel.compute(svec, vec)?)
62        });
63        Ok(res? - self.bias)
64    }
65
66    pub fn encode(&self, sequence: &String) -> Vec<f64> {
67        encode(sequence, &self.encoding, &self.category)
68    }
69
70    pub fn predict_seq(&self, sequence: &String) -> Result<f64, NrpsError> {
71        let fvec = FeatureVector::new(self.encode(sequence));
72        self.predict(&fvec)
73    }
74
75    pub fn from_handle<R>(
76        handle: R,
77        name: String,
78        category: PredictionCategory,
79    ) -> Result<Self, NrpsError>
80    where
81        R: Read,
82    {
83        let mut line_iter = io::BufReader::new(handle).lines();
84        line_iter.next(); // skip
85
86        let kernel_type = match parse_int(&mut line_iter)? {
87            0 => KernelType::Linear,
88            2 => KernelType::RBF,
89            _ => {
90                return Err(NrpsError::InvalidFeatureLine(
91                    "Failed to match kernel type".to_string(),
92                ))
93            }
94        };
95
96        line_iter.next(); // skip
97
98        let gamma: f64 = parse_float(&mut line_iter)?;
99
100        line_iter.next(); // skip
101        line_iter.next(); // skip
102        line_iter.next(); // skip
103
104        let dimensions = parse_int(&mut line_iter)?;
105
106        let encoding = match dimensions {
107            102 => FeatureEncoding::Wold,
108            408 => FeatureEncoding::Rausch,
109            510 => FeatureEncoding::Blin,
110            _ => {
111                return Err(NrpsError::InvalidFeatureLine(format!(
112                    "Can't determine encoding type from {} features",
113                    dimensions
114                )));
115            }
116        };
117
118        line_iter.next(); // skip
119        let num_vecs = parse_int(&mut line_iter)?;
120
121        let bias = parse_float(&mut line_iter)?;
122
123        let mut vectors = Vec::with_capacity(num_vecs);
124
125        while let Some(line_res) = line_iter.next() {
126            let svec = SupportVector::from_line(line_res?, dimensions)?;
127            vectors.push(svec);
128        }
129
130        Ok(SVMlightModel::new(
131            name,
132            category,
133            vectors,
134            bias,
135            encoding,
136            kernel_type,
137            gamma,
138        ))
139    }
140}
141
142fn parse_float(line_iter: &mut Lines<BufReader<impl Read>>) -> Result<f64, NrpsError> {
143    if let Some(line_result) = line_iter.next() {
144        if let Some(raw_value) = line_result?.trim_end().splitn(2, "#").next() {
145            return Ok(raw_value.trim().parse::<f64>()?);
146        }
147    }
148    Err(NrpsError::InvalidFeatureLine(
149        "Failed to read line".to_string(),
150    ))
151}
152
153fn parse_int(line_iter: &mut Lines<BufReader<impl Read>>) -> Result<usize, NrpsError> {
154    if let Some(line_result) = line_iter.next() {
155        if let Some(raw_value) = line_result?.trim_end().splitn(2, "#").next() {
156            return Ok(raw_value.trim().parse::<usize>()?);
157        }
158    }
159    Err(NrpsError::InvalidFeatureLine(
160        "Failed to read line".to_string(),
161    ))
162}