use std::collections::HashMap;
use super::ID;
use csv::Reader;
use std::fs::File;
/// Read/write access to a sparse user–item rating matrix.
///
/// Users and items are both identified by `ID`, and every rating is an `f64`.
/// How missing users, items, or ratings are reported is left to the
/// implementation.
pub trait DataHandler {
    // ----- identifiers and counts ------------------------------------------

    /// Returns the ids of all known users; ordering is implementation-defined.
    fn get_user_ids(&self) -> Vec<ID>;

    /// Returns the ids of all known items; ordering is implementation-defined.
    fn get_item_ids(&self) -> Vec<ID>;

    // ----- rating lookups --------------------------------------------------

    /// Returns the ratings given by `user_id`, keyed by item id.
    fn get_user_ratings(&self, user_id: ID) -> HashMap<ID, f64>;

    /// Returns the ratings received by `item_id`, keyed by user id.
    fn get_item_ratings(&self, item_id: ID) -> HashMap<ID, f64>;

    /// Returns the rating `user_id` gave `item_id`.
    ///
    /// `BasicDataHandler` uses `-1.0` as a sentinel when the user has not
    /// rated the item.
    fn get_rating(&self, user_id: ID, item_id: ID) -> f64;

    /// Returns how many users are known.
    fn get_num_users(&self) -> usize;

    /// Returns how many items are known.
    fn get_num_items(&self) -> usize;

    // ----- mutation ----------------------------------------------------------

    /// Registers a user. `BasicDataHandler` returns `true` only when the user
    /// was not already present.
    fn add_user(&mut self, user_id: ID) -> bool;

    /// Registers an item. `BasicDataHandler` returns `true` only when the item
    /// was not already present.
    fn add_item(&mut self, item_id: ID) -> bool;

    /// Records a rating. `BasicDataHandler` returns `false` and stores nothing
    /// unless both the user and the item have already been registered.
    fn add_rating(&mut self, user_id: ID, item_id: ID, rating: f64) -> bool;
}
/// In-memory rating store backed by two nested hash maps.
///
/// Every rating is stored twice, once under its user and once under its item,
/// so lookups from either side are direct map accesses rather than scans.
pub struct BasicDataHandler {
    /// `user_id -> (item_id -> rating)`
    user_ratings: HashMap<ID, HashMap<ID, f64>>,
    /// `item_id -> (user_id -> rating)`
    item_ratings: HashMap<ID, HashMap<ID, f64>>,
}
impl BasicDataHandler {
    /// Creates a handler with no users, items, or ratings.
    pub fn new() -> BasicDataHandler {
        BasicDataHandler {
            user_ratings: HashMap::new(),
            item_ratings: HashMap::new(),
        }
    }

    /// Builds a handler from CSV rows decoded as `(user_id, item_id, rating)`.
    ///
    /// Users and items are created on first sight. Each rating is recorded in
    /// both the per-user and the per-item index. If the same
    /// `(user_id, item_id)` pair appears more than once, the last row wins.
    ///
    /// # Panics
    ///
    /// Panics if a row cannot be read or decoded as `(ID, ID, f64)`.
    pub fn from_reader(mut reader: Reader<File>) -> BasicDataHandler {
        let mut user_ratings: HashMap<ID, HashMap<ID, f64>> = HashMap::new();
        let mut item_ratings: HashMap<ID, HashMap<ID, f64>> = HashMap::new();
        for row in reader.decode() {
            let (user_id, item_id, rating): (ID, ID, f64) =
                row.expect("malformed CSV row: expected (user_id, item_id, rating)");
            // The entry API does one hash lookup per map instead of the
            // contains_key + get_mut pair, and creates the inner map on demand.
            user_ratings
                .entry(user_id)
                .or_insert_with(HashMap::new)
                .insert(item_id, rating);
            item_ratings
                .entry(item_id)
                .or_insert_with(HashMap::new)
                .insert(user_id, rating);
        }
        BasicDataHandler {
            user_ratings: user_ratings,
            item_ratings: item_ratings,
        }
    }
}
impl DataHandler for BasicDataHandler {
    /// Returns every user id that has an entry, in arbitrary hash-map order.
    fn get_user_ids(&self) -> Vec<ID> {
        self.user_ratings.keys().cloned().collect()
    }

    /// Returns every item id that has an entry, in arbitrary hash-map order.
    fn get_item_ids(&self) -> Vec<ID> {
        self.item_ratings.keys().cloned().collect()
    }

    /// Returns a copy of `user_id`'s ratings, keyed by item id.
    ///
    /// An unknown user is treated like a user with no ratings and yields an
    /// empty map.
    fn get_user_ratings(&self, user_id: ID) -> HashMap<ID, f64> {
        self.user_ratings
            .get(&user_id)
            .cloned()
            .unwrap_or_else(HashMap::new)
    }

    /// Returns a copy of `item_id`'s ratings, keyed by user id.
    ///
    /// An unknown item is treated like an item with no ratings and yields an
    /// empty map.
    fn get_item_ratings(&self, item_id: ID) -> HashMap<ID, f64> {
        self.item_ratings
            .get(&item_id)
            .cloned()
            .unwrap_or_else(HashMap::new)
    }

    /// Returns the rating `user_id` gave `item_id`, or `-1.0` if there is none.
    ///
    /// `-1.0` is returned both when the user is unknown and when a known user
    /// has not rated the item. A stored rating of exactly `-1.0` cannot be told
    /// apart from a missing one.
    fn get_rating(&self, user_id: ID, item_id: ID) -> f64 {
        self.user_ratings
            .get(&user_id)
            .and_then(|ratings| ratings.get(&item_id))
            .cloned()
            .unwrap_or(-1.0)
    }

    /// Returns the number of users with an entry.
    fn get_num_users(&self) -> usize {
        self.user_ratings.len()
    }

    /// Returns the number of items with an entry.
    fn get_num_items(&self) -> usize {
        self.item_ratings.len()
    }

    /// Registers `user_id` with no ratings; returns `false` if it already exists.
    fn add_user(&mut self, user_id: ID) -> bool {
        if self.user_ratings.contains_key(&user_id) {
            return false;
        }
        self.user_ratings.insert(user_id, HashMap::new());
        true
    }

    /// Registers `item_id` with no ratings; returns `false` if it already exists.
    fn add_item(&mut self, item_id: ID) -> bool {
        if self.item_ratings.contains_key(&item_id) {
            return false;
        }
        self.item_ratings.insert(item_id, HashMap::new());
        true
    }

    /// Records (or overwrites) a rating in both indexes.
    ///
    /// Returns `false` and stores nothing unless both the user and the item
    /// have already been registered.
    fn add_rating(&mut self, user_id: ID, item_id: ID, rating: f64) -> bool {
        // The two maps are separate fields, so both can be borrowed mutably at once.
        match (
            self.user_ratings.get_mut(&user_id),
            self.item_ratings.get_mut(&item_id),
        ) {
            (Some(by_user), Some(by_item)) => {
                by_user.insert(item_id, rating);
                by_item.insert(user_id, rating);
                true
            }
            _ => false,
        }
    }
}