use std::cell::{Ref, RefCell};
use std::fmt;
use std::fmt::Formatter;
use serde::{Deserialize, Serialize};
use crate::datapoint::DataPoint;
use crate::error::LtrError;
/// An ordered collection of [`DataPoint`]s that can be re-ordered in place.
///
/// The points are kept inside a `RefCell`, which is what lets methods such as
/// `set`, `rank` and `permute` modify the list through a shared `&self`
/// reference. The usual `RefCell` rules apply: borrows are checked at run time.
#[derive(Clone, Serialize, Deserialize)]
pub struct RankList {
// Backing storage; the vector order is the current order of the list.
data_points: RefCell<Vec<DataPoint>>,
}
impl RankList {
    /// Creates a rank list that owns `data_points`, keeping their order.
    pub fn new(data_points: Vec<DataPoint>) -> RankList {
        RankList {
            data_points: RefCell::new(data_points),
        }
    }

    /// Returns the number of data points in the list.
    pub fn len(&self) -> usize {
        self.data_points.borrow().len()
    }

    /// Returns `true` when the list holds no data points.
    pub fn is_empty(&self) -> bool {
        self.data_points.borrow().is_empty()
    }

    /// Borrows the data point stored at `index`.
    ///
    /// The returned [`Ref`] keeps the list immutably borrowed for as long as
    /// it is alive.
    ///
    /// # Errors
    ///
    /// Returns [`LtrError::RankListIndexOutOfBounds`] when `index` is not
    /// smaller than [`RankList::len`].
    ///
    /// # Panics
    ///
    /// Panics if the list is mutably borrowed at the time of the call.
    pub fn get(&self, index: usize) -> Result<Ref<'_, DataPoint>, LtrError> {
        // One borrow serves both the bounds check and the projection.
        let points = self.data_points.borrow();
        if index < points.len() {
            Ok(Ref::map(points, |all| &all[index]))
        } else {
            Err(LtrError::RankListIndexOutOfBounds(index))
        }
    }

    /// Replaces the data point stored at `index` with `data_point`.
    ///
    /// # Errors
    ///
    /// Returns [`LtrError::RankListIndexOutOfBounds`] when `index` is not
    /// smaller than [`RankList::len`]; the list is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the list is mutably borrowed, or if `index` is in range
    /// while any borrow of the list (for example a [`Ref`] obtained from
    /// [`RankList::get`]) is still alive.
    pub fn set(&self, index: usize, data_point: DataPoint) -> Result<(), LtrError> {
        // The range check only needs a shared borrow, so an out-of-range index
        // is reported without ever requesting the mutable borrow.
        if index >= self.len() {
            return Err(LtrError::RankListIndexOutOfBounds(index));
        }
        self.data_points.borrow_mut()[index] = data_point;
        Ok(())
    }

    /// Sorts the list in descending order using `DataPoint`'s `PartialOrd`.
    ///
    /// The sort is stable: data points that compare equal keep their relative
    /// order.
    ///
    /// # Panics
    ///
    /// Panics if two data points cannot be compared (`partial_cmp` returns
    /// `None`), or if the list is currently borrowed.
    pub fn rank(&self) -> Result<(), LtrError> {
        self.data_points.borrow_mut().sort_by(|a, b| {
            b.partial_cmp(a)
                .expect("data points must be comparable to be ranked")
        });
        Ok(())
    }

    /// Sorts the list in descending order of the feature at `feature_index`.
    ///
    /// The sort is stable: data points with equal feature values keep their
    /// relative order.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `DataPoint::get_feature` when any data
    /// point in the list does not have a feature at `feature_index`; the list
    /// is left untouched in that case.
    ///
    /// # Panics
    ///
    /// Panics if two feature values cannot be compared (for instance when one
    /// of them is NaN), or if the list is currently borrowed.
    pub fn rank_by_feature(&self, feature_index: usize) -> Result<(), LtrError> {
        let mut points = self.data_points.borrow_mut();

        // Validate every point up front so an invalid index is reported as an
        // error instead of panicking inside the comparator half-way through.
        for point in points.iter() {
            point.get_feature(feature_index)?;
        }

        points.sort_by(|a, b| {
            let left = a
                .get_feature(feature_index)
                .expect("feature index was validated before sorting");
            let right = b
                .get_feature(feature_index)
                .expect("feature index was validated before sorting");
            right
                .partial_cmp(left)
                .expect("feature values must be comparable to be ranked")
        });
        Ok(())
    }

    /// Rebuilds the list so that position `k` holds a clone of the data point
    /// that was at index `permutation[k]`.
    ///
    /// `permutation` is not required to be a true permutation: indices may
    /// repeat or be omitted, and the resulting list has `permutation.len()`
    /// elements.
    ///
    /// # Errors
    ///
    /// Returns [`LtrError::RankListIndexOutOfBounds`] for the first index that
    /// is not smaller than [`RankList::len`]; the list is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the list is mutably borrowed, or if every index is valid
    /// while another borrow of the list is still alive.
    pub fn permute(&self, permutation: Vec<usize>) -> Result<(), LtrError> {
        let current = self.data_points.borrow();
        // The output has one entry per requested index, so size it from the
        // permutation rather than from the current list.
        let mut reordered = Vec::with_capacity(permutation.len());
        for index in permutation {
            let point = current
                .get(index)
                .ok_or(LtrError::RankListIndexOutOfBounds(index))?;
            reordered.push(point.clone());
        }
        // Release the shared borrow before swapping in the new contents.
        drop(current);
        self.data_points.replace(reordered);
        Ok(())
    }
}
/// Iterator over the data points of a borrowed [`RankList`], in list order.
pub struct RankListIter<'a> {
// The list being walked.
rank_list: &'a RankList,
// Position of the next data point to yield.
index: usize,
}
impl<'a> Iterator for RankListIter<'a> {
type Item = Ref<'a, DataPoint>;
fn next(&mut self) -> Option<Ref<'a, DataPoint>> {
if self.index < self.rank_list.len() {
self.index += 1;
Some(Ref::map(self.rank_list.data_points.borrow(), |dp| {
&dp[self.index - 1]
}))
} else {
None
}
}
}
impl<'a> IntoIterator for &'a RankList {
    type Item = Ref<'a, DataPoint>;
    type IntoIter = RankListIter<'a>;

    /// Starts an iteration over the list from its first data point.
    fn into_iter(self) -> RankListIter<'a> {
        let start = 0;
        RankListIter {
            index: start,
            rank_list: self,
        }
    }
}
impl From<Vec<DataPoint>> for RankList {
fn from(data_points: Vec<DataPoint>) -> RankList {
RankList {
data_points: RefCell::new(data_points),
}
}
}
impl fmt::Display for RankList {
    /// Writes a one-line summary that reports how many data points the list
    /// currently holds.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let count = self.len();
        write!(f, "RankList object with {} data points", count)
    }
}
/// Builds a `RankList` from a comma-separated sequence of tuples.
///
/// Each tuple is forwarded to the `dp!` macro and must have one of two shapes,
/// used consistently across the whole invocation:
///
/// * `(label, query_id, features)`
/// * `(label, query_id, features, description)`
///
/// A trailing comma after the last tuple is accepted.
///
/// `dp!` is reached through `$crate`, so it resolves against the crate that
/// defines this macro even when the macro is invoked from another crate.
///
/// NOTE(review): `RankList` is still named without a path, so it has to be in
/// scope at the call site; consider qualifying it with `$crate::` and its
/// module path.
#[macro_export]
macro_rules! rl {
    ($(($label:expr, $query_id:expr, $features:expr)),* $(,)?) => {
        RankList::new(vec![
            $($crate::dp!($label, $query_id, $features)),*
        ])
    };
    ($(($label:expr, $query_id:expr, $features:expr, $description:expr)),* $(,)?) => {
        RankList::new(vec![
            $($crate::dp!($label, $query_id, $features, $description)),*
        ])
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{dp, loader::svmlight::*};

    /// Asserts that `list` holds exactly the documents whose descriptions are
    /// given in `expected`, in that order.
    fn assert_descriptions(list: &RankList, expected: &[&str]) {
        assert_eq!(list.len(), expected.len());
        for (position, description) in expected.iter().enumerate() {
            assert_eq!(
                list.get(position).unwrap().get_description().unwrap(),
                *description
            );
        }
    }

    /// Covers construction, indexed access, formatting, both ranking modes,
    /// permutation and replacement on a three-document list.
    #[test]
    fn test_ranklist() {
        let rank_list = rl!(
            (0, 9, vec![10.0, 1.2, 4.3, 5.4], "doc1"),
            (1, 9, vec![11.0, 2.2, 4.5, 5.6], "doc2"),
            (0, 9, vec![12.0, 2.5, 4.7, 5.2], "doc3")
        );
        assert_eq!(rank_list.len(), 3);

        // A clone has the same size as the list it was made from.
        let another_rank_list = rank_list.clone();
        assert_eq!(another_rank_list.len(), 3);

        // Indexed access succeeds inside the bounds and fails just past them.
        assert!(rank_list.get(0).is_ok());
        assert!(rank_list.get(1).is_ok());
        assert!(rank_list.get(2).is_ok());
        assert!(rank_list.get(3).is_err());

        // `get_feature(1)` addresses the first value of the feature vector.
        let first_data_point = rank_list.get(0).unwrap();
        assert_eq!(first_data_point.get_label(), 0);
        assert_eq!(first_data_point.get_query_id(), 9);
        assert_eq!(*first_data_point.get_feature(1).unwrap(), 10.0f32);

        let second_data_point = rank_list.get(1).unwrap();
        assert_eq!(second_data_point.get_label(), 1);
        assert_eq!(second_data_point.get_query_id(), 9);
        assert_eq!(*second_data_point.get_feature(2).unwrap(), 2.2f32);

        let third_data_point = rank_list.get(2).unwrap();
        assert_eq!(third_data_point.get_label(), 0);
        assert_eq!(third_data_point.get_query_id(), 9);
        assert_eq!(*third_data_point.get_feature(3).unwrap(), 4.7f32);

        // Display output.
        let string_representation = format!("{}", rank_list);
        assert_eq!(string_representation, "RankList object with 3 data points");

        // Ranking by feature 1 (values 10, 11, 12) orders the list descending.
        let partial_rank_list = rank_list.clone();
        partial_rank_list.rank_by_feature(1).unwrap();
        assert_descriptions(&partial_rank_list, &["doc3", "doc2", "doc1"]);

        // Full ranking puts doc2 first; doc1 and doc3 keep their relative order.
        let full_rank_list = rank_list.clone();
        full_rank_list.rank().unwrap();
        assert_descriptions(&full_rank_list, &["doc2", "doc1", "doc3"]);

        // An out-of-range index is rejected and leaves the list untouched, so
        // the valid permutation that follows still sees the original order.
        let permuted_rank_list = rank_list.clone();
        let permutation = vec![1, 2, 0];
        let invalid_permutation = vec![1, 2, 3];
        assert!(permuted_rank_list.permute(invalid_permutation).is_err());
        permuted_rank_list.permute(permutation).unwrap();
        assert_descriptions(&permuted_rank_list, &["doc2", "doc3", "doc1"]);

        // Replacing an element in range works; out of range reports the index.
        let set_rank_list = rank_list.clone();
        let new_dp = SVMLight::load_datapoint("2 qid:9 1:10 2:1.2 3:4.3 4:5.4 # doc23").unwrap();
        set_rank_list.set(0, new_dp.clone()).unwrap();
        assert_eq!(
            set_rank_list.get(0).unwrap().get_description().unwrap(),
            "doc23"
        );
        assert_eq!(
            set_rank_list.set(100, new_dp).unwrap_err(),
            LtrError::RankListIndexOutOfBounds(100)
        );
    }

    /// Iterating a borrowed list yields every data point once, in index order.
    #[test]
    fn test_ranklist_iterator() {
        let rank_list: RankList = RankList::from(vec![
            dp!(0, 9, vec![10.0, 1.2, 4.3, 5.4], "doc1"),
            dp!(1, 9, vec![11.0, 2.2, 4.5, 5.6], "doc2"),
            dp!(0, 9, vec![12.0, 2.5, 4.7, 5.2], "doc3"),
        ]);
        assert_eq!(rank_list.len(), 3);

        let mut visited = 0;
        for (i, data_point) in (&rank_list).into_iter().enumerate() {
            assert_eq!(
                data_point.get_label(),
                rank_list.get(i).unwrap().get_label()
            );
            assert_eq!(
                data_point.get_query_id(),
                rank_list.get(i).unwrap().get_query_id()
            );
            assert_eq!(
                *data_point.get_feature(1).unwrap(),
                *rank_list.get(i).unwrap().get_feature(1).unwrap()
            );
            visited += 1;
        }
        // Without this check the loop would pass vacuously on an empty iterator.
        assert_eq!(visited, rank_list.len());
    }
}