use std::cmp::Ordering;
use std::fmt;
use std::ops::Index;
use serde::{Deserialize, Serialize};
use crate::{error::LtrError, Feature};
/// A single labelled sample: a label, the id of the query it belongs to, a
/// list of feature values, and an optional free-text description.
///
/// Fields are private; use the accessors on `DataPoint`. Field declaration
/// order also fixes the order used by the derived `Debug` and serde output,
/// so reordering fields is not a purely cosmetic change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataPoint {
    /// Label value (presumably a graded relevance judgement — confirm with the data source).
    label: u8,
    /// Identifier of the query this sample belongs to.
    query_id: u32,
    /// Feature values; the public accessors address them with 1-based indices.
    features: Vec<Feature>,
    /// Optional human-readable note attached to the sample.
    description: Option<String>,
}
impl DataPoint {
    /// Creates a data point with label `0`, query id `0`, no features and no
    /// description.
    pub fn empty() -> DataPoint {
        DataPoint {
            label: 0,
            query_id: 0,
            features: Vec::new(),
            description: None,
        }
    }

    /// Creates a data point from its parts.
    ///
    /// When `description` is `Some`, it is copied into an owned `String`.
    pub fn new(
        label: u8,
        query_id: u32,
        features: Vec<Feature>,
        description: Option<&str>,
    ) -> DataPoint {
        DataPoint {
            label,
            query_id,
            features,
            description: description.map(String::from),
        }
    }

    /// Returns the label.
    pub fn get_label(&self) -> u8 {
        self.label
    }

    /// Returns the query id.
    pub fn get_query_id(&self) -> u32 {
        self.query_id
    }

    /// Returns all feature values, in storage order.
    pub fn get_features(&self) -> &Vec<Feature> {
        &self.features
    }

    /// Returns the feature at the **1-based** `index`.
    ///
    /// # Errors
    ///
    /// Returns `LtrError::FeatureIndexOutOfBounds(index)` when `index` is `0`
    /// or greater than the number of features.
    pub fn get_feature(&self, index: usize) -> Result<&Feature, LtrError> {
        // `checked_sub` maps the invalid index 0 to `None` instead of underflowing.
        index
            .checked_sub(1)
            .and_then(|slot| self.features.get(slot))
            .ok_or(LtrError::FeatureIndexOutOfBounds(index))
    }

    /// Returns the description, if one is set.
    pub fn get_description(&self) -> Option<&String> {
        self.description.as_ref()
    }

    /// Replaces the label.
    pub fn set_label(&mut self, label: u8) {
        self.label = label;
    }

    /// Replaces the query id.
    pub fn set_query_id(&mut self, query_id: u32) {
        self.query_id = query_id;
    }

    /// Appends `feature` after the current last feature, so it becomes
    /// reachable at 1-based index `len` (the new feature count).
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub fn add_feature(&mut self, feature: Feature) -> Result<(), LtrError> {
        self.features.push(feature);
        Ok(())
    }

    /// Overwrites the feature at the **1-based** `index`.
    ///
    /// # Errors
    ///
    /// Returns `LtrError::FeatureIndexOutOfBounds(index)` when `index` is `0`
    /// or greater than the number of features; the features are left
    /// unchanged in that case.
    pub fn set_feature(&mut self, index: usize, feature: Feature) -> Result<(), LtrError> {
        // Index 0 used to reach `index - 1` and underflow; `checked_sub`
        // routes it to the same out-of-bounds error `get_feature` returns.
        let target = index
            .checked_sub(1)
            .and_then(|slot| self.features.get_mut(slot))
            .ok_or(LtrError::FeatureIndexOutOfBounds(index))?;
        *target = feature;
        Ok(())
    }

    /// Replaces all feature values at once.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub fn set_features(&mut self, features: Vec<Feature>) -> Result<(), LtrError> {
        self.features = features;
        Ok(())
    }

    /// Sets (or replaces) the description with an owned copy of `description`.
    pub fn set_description(&mut self, description: &str) {
        self.description = Some(description.to_string());
    }
}
impl fmt::Display for DataPoint {
    /// Writes every field on a single line: label and query id via `Display`,
    /// features and description via `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            label,
            query_id,
            features,
            description,
        } = self;
        write!(
            f,
            "DataPoint: label={}, query_id={}, features={:?}, description={:?}",
            label, query_id, features, description
        )
    }
}
/// Two data points are equal when both their label and their query id match.
///
/// Feature values and the description are not compared, so two points with
/// different features but the same label and query id compare equal.
///
/// `ne` is intentionally not overridden: the trait's default (`!eq`) is exactly
/// what the previous hand-written version did (clippy `partialeq_ne_impl`).
impl PartialEq for DataPoint {
    fn eq(&self, other: &Self) -> bool {
        self.label == other.label && self.query_id == other.query_id
    }
}
/// Orders data points by label first, breaking ties on query id.
///
/// `PartialEq` compares label *and* query id, and the `PartialOrd` contract
/// requires `a == b` exactly when `partial_cmp` returns `Some(Equal)`. Comparing
/// only the label reported `Equal` for same-label points from different
/// queries, even though `==` said they differ. The query-id tie-break restores
/// that agreement and leaves ordering within a single query unchanged.
impl PartialOrd for DataPoint {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(
            self.label
                .cmp(&other.label)
                .then_with(|| self.query_id.cmp(&other.query_id)),
        )
    }
}
/// `data_point[i]` returns the feature at the **1-based** index `i`, using the
/// same bounds rules as `DataPoint::get_feature`.
///
/// # Panics
///
/// Panics if `index` is `0` or greater than the number of features. The
/// message names the offending index, the 1-based convention, and the
/// feature count.
impl Index<usize> for DataPoint {
    type Output = Feature;

    fn index(&self, index: usize) -> &Self::Output {
        match self.get_feature(index) {
            Ok(feature) => feature,
            Err(err) => panic!(
                "{:?}: feature indices are 1-based and this data point has {} feature(s)",
                err,
                self.features.len()
            ),
        }
    }
}
/// Shorthand for `DataPoint::new`.
///
/// * `dp!(label, query_id, features)` builds a data point with no description.
/// * `dp!(label, query_id, features, description)` passes `Some(description)`
///   (a `&str`).
///
/// Both forms accept an optional trailing comma.
///
/// NOTE(review): the expansion names `DataPoint` unqualified, and `macro_rules!`
/// resolves item paths at the call site, so `DataPoint` must be in scope
/// wherever `dp!` is used. A `$crate::...` path would lift that requirement
/// once the public re-export path is confirmed.
#[macro_export]
macro_rules! dp {
    ($label:expr, $query_id:expr, $features:expr $(,)?) => {
        DataPoint::new($label, $query_id, $features, None)
    };
    ($label:expr, $query_id:expr, $features:expr, $description:expr $(,)?) => {
        DataPoint::new($label, $query_id, $features, Some($description))
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_data_point_new() {
        let features: Vec<Feature> = vec![1.2, 3.4, 5.6];
        let mut data_point = dp!(1, 2, features.clone(), "This is a test");
        assert_eq!(data_point.get_label(), 1);
        assert_eq!(data_point.get_query_id(), 2);
        assert_eq!(data_point.get_features(), &features);
        assert_eq!(
            data_point.get_description(),
            Some(&"This is a test".to_string())
        );
        // `Index` uses the same 1-based convention as `get_feature`.
        assert_eq!(data_point[1], 1.2);

        let formatted_data_point = format!("{}", data_point);
        assert_eq!(formatted_data_point, "DataPoint: label=1, query_id=2, features=[1.2, 3.4, 5.6], description=Some(\"This is a test\")");

        let cloned_data_point = data_point.clone();
        assert_eq!(cloned_data_point, data_point);
        // Equality ignores feature values: only label and query id are compared.
        assert_eq!(
            cloned_data_point,
            DataPoint::new(1, 2, vec![0.0], Some("This is a test"))
        );
        assert_ne!(
            cloned_data_point,
            DataPoint::new(2, 4, vec![1.2, 3.4, 5.6], Some("This is a test"))
        );

        data_point.set_label(2);
        data_point.set_query_id(4);
        data_point.set_description("This is another test");
        assert_eq!(data_point.get_label(), 2);
        assert_eq!(data_point.get_query_id(), 4);
        assert_eq!(
            data_point.get_description(),
            Some(&"This is another test".to_string())
        );
    }

    #[test]
    fn test_update_features() {
        let mut mydp = dp!(1, 2, vec![1.2, 3.4, 5.6], "This is a test");
        assert_eq!(mydp.get_features(), &vec![1.2, 3.4, 5.6]);
        // Feature indices are 1-based, so 0 is always out of bounds.
        assert_eq!(
            mydp.get_feature(0),
            Err(LtrError::FeatureIndexOutOfBounds(0))
        );

        mydp.add_feature(20.0).unwrap();
        assert_eq!(mydp.get_feature(4), Ok(&20.0));
        assert_eq!(
            mydp.get_feature(5),
            Err(LtrError::FeatureIndexOutOfBounds(5))
        );

        let snapshot = mydp.clone();
        mydp.set_feature(4, 21.0).unwrap();
        assert_eq!(mydp.get_feature(4), Ok(&21.0));
        assert_eq!(
            mydp.set_feature(5, 0.0),
            Err(LtrError::FeatureIndexOutOfBounds(5))
        );
        assert_ne!(mydp.get_features(), snapshot.get_features());
        // Still equal: features are not part of equality.
        assert_eq!(mydp, snapshot);

        mydp.set_label(2);
        assert!(mydp > snapshot);
    }
}