use crate::datapoint::DataPoint;
use crate::error::LtrError;
use crate::ranklist::RankList;
use crate::DataSet;
use super::LtrFormat;
/// Namespace type for reading ranking data stored in the SVMLight text format.
///
/// Each data line has the shape
/// `<label> qid:<query id> <index>:<value> ... # <description>`; the
/// associated functions on this type parse a single line, a rank list, or a
/// whole data set from an in-memory string.
pub struct SVMLight;
impl SVMLight {
/// Parses one SVMLight-formatted line into a [`DataPoint`].
///
/// The expected layout is
/// `<label> qid:<query id> <index>:<value> ... [# <description>]`:
///
/// * tokens are separated by any amount of whitespace (spaces or tabs);
/// * the label is parsed as `u8` and the query id as `u32`;
/// * feature indices are 1-based `usize` values and feature values are `f32`;
///   indices that are not listed are filled with `0.0` up to the highest
///   index seen, and a repeated index keeps the last value;
/// * everything after the first `#` is trimmed and stored as the
///   description, so the description itself may contain `#`.
///
/// # Errors
///
/// Returns [`LtrError::InvalidDataPoint`] when the label, the qid token or a
/// feature component is missing or cannot be parsed, or when a feature index
/// is `0`. Returns [`LtrError::ParseError`] when the second token is not of
/// the form `qid:<id>`. Any error reported by `DataPoint::set_features` is
/// propagated unchanged.
pub fn load_datapoint(buffer: &str) -> Result<DataPoint, LtrError> {
    let mut data_point = DataPoint::empty();

    // Split on the first `#` only: the left part holds the data tokens, the
    // right part (if any) is the free-form description.
    let mut sections = buffer.splitn(2, '#');
    let data = sections.next().unwrap_or("");
    if let Some(description) = sections.next() {
        data_point.set_description(description.trim());
    }

    // `split_whitespace` tolerates tabs, repeated spaces and a trailing `\r`.
    let mut tokens = data.split_whitespace();

    let label = tokens
        .next()
        .ok_or(LtrError::InvalidDataPoint("Missing the label parameter."))?
        .parse::<u8>()
        .map_err(|_| LtrError::InvalidDataPoint("Invalid label parameter."))?;
    data_point.set_label(label);

    let mut qid_parts = tokens
        .next()
        .ok_or(LtrError::InvalidDataPoint("Missing the qid parameter."))?
        .split(':');
    // Without this check a line that omits `qid:` would have its first
    // feature silently interpreted as the query id.
    if qid_parts.next() != Some("qid") {
        return Err(LtrError::ParseError(
            "Error in SVMLight::load_datapoint: Query ID processing failure.",
        ));
    }
    let qid = qid_parts
        .next()
        .ok_or(LtrError::ParseError(
            "Error in SVMLight::load_datapoint: Query ID processing failure",
        ))?
        .parse::<u32>()
        .map_err(|_| LtrError::InvalidDataPoint("Invalid qid parameter."))?;
    data_point.set_query_id(qid);

    let mut feature_values: Vec<f32> = Vec::new();
    for token in tokens {
        let mut parts = token.split(':');
        let index = parts
            .next()
            .ok_or(LtrError::InvalidDataPoint("Missing feature index."))?
            .parse::<usize>()
            .map_err(|_| LtrError::InvalidDataPoint("Invalid feature index."))?;
        let value = parts
            .next()
            .ok_or(LtrError::InvalidDataPoint("Missing feature value."))?
            .parse::<f32>()
            .map_err(|_| LtrError::InvalidDataPoint("Invalid feature value."))?;

        // Indices are 1-based; `0 - 1` would underflow and panic, so reject
        // it as malformed input instead.
        let slot = index.checked_sub(1).ok_or(LtrError::InvalidDataPoint(
            "Invalid feature index: indices are 1-based.",
        ))?;

        // Grow the dense vector so that `slot` is addressable; gaps stay 0.0.
        if index > feature_values.len() {
            feature_values.resize(index, 0.0);
        }
        feature_values[slot] = value;
    }
    data_point.set_features(feature_values)?;

    Ok(data_point)
}
/// Parses every data line of `buffer` into a single [`RankList`].
///
/// Lines are split with [`str::lines`], so both `\n` and `\r\n` line endings
/// are accepted. Empty and whitespace-only lines are skipped; all remaining
/// lines are handed to [`SVMLight::load_datapoint`] in order, regardless of
/// their query id.
///
/// # Errors
///
/// Returns the first error produced by `SVMLight::load_datapoint`; parsing
/// stops at that line.
pub fn load_ranklist(buffer: &str) -> Result<RankList, LtrError> {
    let data_points = buffer
        .lines()
        // A blank line in a CRLF file would otherwise reach the parser as
        // "\r" and fail the whole load.
        .filter(|line| !line.trim().is_empty())
        .map(SVMLight::load_datapoint)
        .collect::<Result<Vec<_>, _>>()?;
    Ok(RankList::new(data_points))
}
pub fn load_dataset(buffer: &str) -> Result<DataSet, LtrError> {
let mut buffer_iter = buffer.split('\n');
let mut dataset: DataSet = DataSet::new();
let mut current_query_id = 0;
let mut current_rank_list = Vec::new();
while let Some(line) = buffer_iter.next() {
if line.is_empty() {
continue;
}
let dp = SVMLight::load_datapoint(line)?;
if dp.get_query_id() == current_query_id || current_query_id == 0 {
current_query_id = dp.get_query_id();
current_rank_list.push(dp);
} else {
let ranklist = RankList::new(current_rank_list.clone());
dataset.push(ranklist);
current_rank_list.clear();
current_query_id = dp.get_query_id();
current_rank_list.push(dp);
}
}
let ranklist = RankList::new(current_rank_list);
dataset.push(ranklist);
Ok(dataset)
}
}
impl LtrFormat for SVMLight {
    /// Reads the whole file at `path` into memory and parses it with
    /// `SVMLight::load_dataset`.
    ///
    /// A failure to read the file is reported as `LtrError::IOError` carrying
    /// the I/O error's display text; parse errors are returned as produced by
    /// `SVMLight::load_dataset`.
    fn load(path: &str) -> Result<DataSet, LtrError> {
        let contents = std::fs::read_to_string(path)
            .map_err(|io_err| LtrError::IOError(io_err.to_string()))?;
        SVMLight::load_dataset(&contents)
    }

    /// Writing data sets in SVMLight format is not implemented yet; calling
    /// this function panics.
    fn save(_path: &str, _dataset: &DataSet) -> Result<(), LtrError> {
        unimplemented!()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses one line with a trailing description and one without, and
    /// checks label, query id, description and the three feature values.
    #[test]
    fn test_svm_light_parser() {
        // Line that carries a `# desc` suffix.
        let described = SVMLight::load_datapoint("1 qid:10 1:21.00 2:2.30 3:4.50 # desc").unwrap();
        assert_eq!(described.get_label(), 1);
        assert_eq!(described.get_query_id(), 10);
        assert_eq!(described.get_description(), Some(&"desc".to_string()));
        assert_eq!(described.get_features().len(), 3);
        for &(index, expected) in [(1, 21.0f32), (2, 2.3f32), (3, 4.5f32)].iter() {
            assert_eq!(*described.get_feature(index).unwrap(), expected);
        }

        // Same layout without any description.
        let plain = SVMLight::load_datapoint("20 qid:9 1:1.00 2:222.30 3:444.50").unwrap();
        assert_eq!(plain.get_description(), None);
        assert_eq!(plain.get_label(), 20);
        assert_eq!(plain.get_query_id(), 9);
        assert_eq!(plain.get_features().len(), 3);
        for &(index, expected) in [(1, 1.0f32), (2, 222.3f32), (3, 444.5f32)].iter() {
            assert_eq!(*plain.get_feature(index).unwrap(), expected);
        }
    }
}