use anyhow::{Context, Ok, Result};
use parking_lot::Mutex;
use rusqlite::{Connection, params};
use rusqlite_migration::{M, Migrations};
use std::collections::VecDeque;
use ustr::{Ustr, UstrMap};
use crate::{data::UnitReward, error::PracticeRewardsError, utils};
/// Stores and retrieves the rewards recorded for units.
pub trait PracticeRewards {
    /// Returns up to `num_rewards` rewards recorded for the unit with the given ID.
    ///
    /// NOTE(review): the SQLite-backed `LocalPracticeRewards` returns them most recent first;
    /// confirm whether every implementation is expected to use that order.
    fn get_rewards(
        &self,
        unit_id: Ustr,
        num_rewards: u32,
    ) -> Result<Vec<UnitReward>, PracticeRewardsError>;

    /// Records the given rewards and returns the IDs of the units whose rewards were stored.
    ///
    /// The returned list can be shorter than `rewards`: `LocalPracticeRewards` skips rewards
    /// that are near-duplicates of recently recorded ones.
    fn record_unit_rewards(
        &mut self,
        rewards: &[UnitReward],
    ) -> Result<Vec<Ustr>, PracticeRewardsError>;

    /// Deletes older rewards so that each unit keeps at most its `num_rewards` most recent ones.
    fn trim_rewards(&mut self, num_rewards: u32) -> Result<(), PracticeRewardsError>;

    /// Deletes all rewards of the units whose IDs start with `prefix`.
    fn remove_rewards_with_prefix(&mut self, prefix: &str) -> Result<(), PracticeRewardsError>;
}
/// The number of seconds in a day. Two rewards are only considered similar when their
/// timestamps are less than this far apart.
const SECONDS_IN_DAY: i64 = 24 * 60 * 60;

/// The largest difference between two reward weights for the rewards to count as similar.
const WEIGHT_EPSILON: f32 = 0.1;

/// The maximum number of recent rewards kept per unit in the duplicate-detection cache.
const MAX_CACHE_SIZE: usize = 10;
/// An in-memory window of the rewards most recently recorded for each unit, used to skip
/// recording rewards that are near-duplicates of one already recorded.
struct RewardCache {
    /// The cached rewards of each unit, in insertion order, at most `MAX_CACHE_SIZE` per unit.
    cache: UstrMap<VecDeque<UnitReward>>,
}

impl RewardCache {
    /// Returns whether the cached rewards of `unit_id` include one with exactly the same value
    /// as `reward`, a timestamp less than a day away, and a weight within `WEIGHT_EPSILON`.
    fn has_similar_reward(&self, unit_id: Ustr, reward: &UnitReward) -> bool {
        let Some(cached) = self.cache.get(&unit_id) else {
            return false;
        };
        cached.iter().any(|other| {
            let same_value = other.value == reward.value;
            let close_in_time = (other.timestamp - reward.timestamp).abs() < SECONDS_IN_DAY;
            let close_in_weight = (other.weight - reward.weight).abs() < WEIGHT_EPSILON;
            same_value && close_in_time && close_in_weight
        })
    }

    /// Appends `reward` to the entries of `unit_id`. If the unit already holds
    /// `MAX_CACHE_SIZE` entries, the earliest-inserted one is evicted first.
    fn add_new_reward(&mut self, unit_id: Ustr, reward: UnitReward) {
        let entries = self.cache.entry(unit_id).or_default();
        let is_full = entries.len() >= MAX_CACHE_SIZE;
        if is_full {
            entries.pop_front();
        }
        entries.push_back(reward);
    }
}
/// A [`PracticeRewards`] implementation backed by a SQLite database.
pub struct LocalPracticeRewards {
    /// The connection to the SQLite database, behind a mutex so `&self` methods can use it.
    connection: Mutex<Connection>,

    /// Rewards recorded recently, used to skip recording near-duplicate rewards.
    cache: RewardCache,
}

impl LocalPracticeRewards {
    /// The number of most recent rewards kept in the database for each unit. Older rewards of a
    /// unit are deleted whenever a new reward for that unit is recorded.
    const MAX_STORED_REWARDS: u32 = 20;

    /// Returns the migrations that create the practice rewards schema.
    ///
    /// `rusqlite_migration` identifies applied migrations by their position in this list, so
    /// existing entries must not be edited or reordered; append new migrations instead.
    fn migrations() -> Migrations<'static> {
        Migrations::new(vec![
            // Maps each unit ID to a compact integer key.
            M::up("CREATE TABLE uids(unit_uid INTEGER PRIMARY KEY, unit_id TEXT NOT NULL UNIQUE);")
                .down("DROP TABLE uids;"),
            // One row per recorded reward.
            M::up(
                "CREATE TABLE practice_rewards(
id INTEGER PRIMARY KEY,
unit_uid INTEGER NOT NULL REFERENCES uids(unit_uid),
reward REAL,
weight REAL,
timestamp INTEGER);",
            )
            .down("DROP TABLE practice_rewards"),
            // Redundant with the implicit index SQLite creates for the UNIQUE constraint above,
            // but kept because released migrations must not change.
            M::up("CREATE INDEX unit_ids ON uids (unit_id);").down("DROP INDEX unit_ids"),
            // Serves the per-unit, newest-first queries below.
            M::up("CREATE INDEX rewards ON practice_rewards (unit_uid, timestamp);")
                .down("DROP INDEX rewards"),
        ])
    }

    /// Brings the database schema up to date by applying any pending migrations.
    fn init(&mut self) -> Result<()> {
        let migrations = Self::migrations();
        let mut connection = self.connection.lock();
        migrations
            .to_latest(&mut connection)
            .context("failed to initialize practice rewards DB")
    }

    /// Wraps `connection` in a new store with an empty cache and migrates its schema.
    fn new(connection: Connection) -> Result<LocalPracticeRewards> {
        let mut rewards = LocalPracticeRewards {
            connection: Mutex::new(connection),
            cache: RewardCache {
                cache: UstrMap::default(),
            },
        };
        rewards.init()?;
        Ok(rewards)
    }

    /// Opens the database at `db_path` (via `utils::new_connection`) and returns a store backed
    /// by it.
    pub fn new_from_disk(db_path: &str) -> Result<LocalPracticeRewards> {
        Self::new(utils::new_connection(db_path)?)
    }

    /// Returns up to `num_rewards` rewards of `unit_id`, most recent first.
    ///
    /// Rewards with equal timestamps are ordered newest-inserted first (by row ID), so the
    /// result is deterministic and agrees with the rows kept when trimming.
    fn get_rewards_helper(&self, unit_id: Ustr, num_rewards: u32) -> Result<Vec<UnitReward>> {
        let connection = self.connection.lock();
        let mut stmt = connection.prepare_cached(
            "SELECT reward, weight, timestamp FROM practice_rewards WHERE unit_uid = (
                SELECT unit_uid FROM uids WHERE unit_id = ?1)
            ORDER BY timestamp DESC, id DESC LIMIT ?2;",
        )?;
        // Bound to a local so the row iterator, which borrows `stmt`, is dropped before `stmt`.
        let rewards = stmt
            .query_map(params![unit_id.as_str(), num_rewards], |row| {
                // `Ok` in this module is `anyhow::Ok`; the row mapper must return a rusqlite result.
                rusqlite::Result::Ok(UnitReward {
                    unit_id,
                    value: row.get(0)?,
                    weight: row.get(1)?,
                    timestamp: row.get(2)?,
                })
            })?
            .map(|r| r.context("failed to retrieve rewards from practice rewards DB"))
            .collect::<Result<Vec<UnitReward>>>()?;
        Ok(rewards)
    }

    /// Records `rewards` in a single transaction. Returns the unit ID of every reward that was
    /// stored, in input order.
    ///
    /// A reward is skipped when it is similar to a cached reward, or to one of the most recent
    /// rewards accepted earlier in this call. After each insert, the unit's rewards beyond the
    /// `MAX_STORED_REWARDS` most recent are deleted.
    ///
    /// The cache is only updated after the transaction commits. Updating it earlier would leave
    /// cache entries for rolled-back rewards when the call fails, and a retry of the same
    /// rewards would then be skipped and the rewards lost.
    fn record_unit_rewards_helper(&mut self, rewards: &[UnitReward]) -> Result<Vec<Ustr>> {
        let mut updated = Vec::new();
        // Rewards accepted by this call, kept apart from `self.cache` until the commit succeeds.
        let mut staged = RewardCache {
            cache: UstrMap::default(),
        };

        let mut connection = self.connection.lock();
        let tx = connection.transaction()?;
        {
            // The statements borrow `tx`, so they live in this scope and are dropped before commit.
            let mut uid_stmt =
                tx.prepare_cached("INSERT OR IGNORE INTO uids(unit_id) VALUES (?1);")?;
            let mut insert_stmt = tx.prepare_cached(
                "INSERT INTO practice_rewards (unit_uid, reward, weight, timestamp) VALUES (
                    (SELECT unit_uid FROM uids WHERE unit_id = ?1), ?2, ?3, ?4);",
            )?;
            let mut trim_stmt = tx.prepare_cached(
                "DELETE FROM practice_rewards WHERE id IN (
                    SELECT id FROM practice_rewards WHERE unit_uid = (
                        SELECT unit_uid FROM uids WHERE unit_id = ?1)
                    ORDER BY timestamp DESC, id DESC LIMIT -1 OFFSET ?2
                );",
            )?;

            for reward in rewards {
                if self.cache.has_similar_reward(reward.unit_id, reward)
                    || staged.has_similar_reward(reward.unit_id, reward)
                {
                    continue;
                }
                uid_stmt.execute(params![reward.unit_id.as_str()])?;
                insert_stmt.execute(params![
                    reward.unit_id.as_str(),
                    reward.value,
                    reward.weight,
                    reward.timestamp
                ])?;
                trim_stmt.execute(params![reward.unit_id.as_str(), Self::MAX_STORED_REWARDS])?;
                staged.add_new_reward(reward.unit_id, reward.clone());
                updated.push(reward.unit_id);
            }
        }
        tx.commit()?;

        // The rewards are now durable, so they can safely take part in duplicate detection.
        for (unit_id, accepted) in staged.cache {
            for reward in accepted {
                self.cache.add_new_reward(unit_id, reward);
            }
        }
        Ok(updated)
    }

    /// Deletes all but the `num_rewards` most recent rewards of every unit. This runs in a
    /// single transaction, so a failure leaves the database unchanged.
    fn trim_rewards_helper(&mut self, num_rewards: u32) -> Result<()> {
        let mut connection = self.connection.lock();
        let tx = connection.transaction()?;
        {
            let mut uid_stmt = tx.prepare_cached("SELECT unit_uid FROM uids;")?;
            let unit_uids = uid_stmt
                .query_map([], |row| row.get(0))?
                .collect::<rusqlite::Result<Vec<i64>>>()?;
            let mut trim_stmt = tx.prepare_cached(
                "DELETE FROM practice_rewards WHERE id IN (
                    SELECT id FROM practice_rewards WHERE unit_uid = ?1
                    ORDER BY timestamp DESC, id DESC LIMIT -1 OFFSET ?2
                );",
            )?;
            for unit_uid in unit_uids {
                trim_stmt.execute(params![unit_uid, num_rewards])?;
            }
        }
        tx.commit()?;
        Ok(())
    }

    /// Deletes every reward of the units whose IDs start with `prefix`, then removes those units
    /// from the cache and vacuums the database. The prefix match is literal and case-sensitive.
    ///
    /// `substr` is used instead of `LIKE` because `LIKE` treats `%` and `_` in the prefix as
    /// wildcards and compares ASCII letters case-insensitively. With `LIKE`, rewards of
    /// unrelated units would be removed (e.g. prefix `a_b` would also match `aXb` and `A_B`).
    fn remove_rewards_with_prefix_helper(&mut self, prefix: &str) -> Result<()> {
        let connection = self.connection.lock();
        connection.execute(
            "DELETE FROM practice_rewards WHERE unit_uid IN (
                SELECT unit_uid FROM uids WHERE substr(unit_id, 1, length(?1)) = ?1
            );",
            params![prefix],
        )?;
        // Forget cached rewards of the removed units. Otherwise, recording the same rewards
        // again would be skipped as duplicates of rewards that no longer exist.
        self.cache
            .cache
            .retain(|unit_id, _| !unit_id.as_str().starts_with(prefix));
        connection.execute_batch("VACUUM;")?;
        Ok(())
    }
}
impl PracticeRewards for LocalPracticeRewards {
    /// Delegates to `get_rewards_helper` and wraps any failure, together with the unit ID, in
    /// `PracticeRewardsError::GetRewards`.
    fn get_rewards(
        &self,
        unit_id: Ustr,
        num_rewards: u32,
    ) -> Result<Vec<UnitReward>, PracticeRewardsError> {
        let outcome = self.get_rewards_helper(unit_id, num_rewards);
        outcome.map_err(|source| PracticeRewardsError::GetRewards(unit_id, source))
    }

    /// Delegates to `record_unit_rewards_helper` and wraps any failure in
    /// `PracticeRewardsError::RecordRewards`.
    fn record_unit_rewards(
        &mut self,
        rewards: &[UnitReward],
    ) -> Result<Vec<Ustr>, PracticeRewardsError> {
        let outcome = self.record_unit_rewards_helper(rewards);
        outcome.map_err(PracticeRewardsError::RecordRewards)
    }

    /// Delegates to `trim_rewards_helper` and wraps any failure in
    /// `PracticeRewardsError::TrimReward`.
    fn trim_rewards(&mut self, num_rewards: u32) -> Result<(), PracticeRewardsError> {
        let outcome = self.trim_rewards_helper(num_rewards);
        outcome.map_err(PracticeRewardsError::TrimReward)
    }

    /// Delegates to `remove_rewards_with_prefix_helper` and wraps any failure, together with an
    /// owned copy of the prefix, in `PracticeRewardsError::RemovePrefix`.
    fn remove_rewards_with_prefix(&mut self, prefix: &str) -> Result<(), PracticeRewardsError> {
        let outcome = self.remove_rewards_with_prefix_helper(prefix);
        outcome.map_err(|source| PracticeRewardsError::RemovePrefix(prefix.to_owned(), source))
    }
}
#[cfg(test)]
#[cfg_attr(coverage, coverage(off))]
mod test {
    use anyhow::{Ok, Result};
    use rusqlite::Connection;
    use ustr::Ustr;

    use crate::{
        data::UnitReward,
        practice_rewards::{LocalPracticeRewards, PracticeRewards},
    };

    /// Builds a practice rewards store backed by a fresh in-memory database.
    fn new_tests_rewards() -> Result<Box<dyn PracticeRewards>> {
        let connection = Connection::open_in_memory()?;
        Ok(Box::new(LocalPracticeRewards::new(connection)?))
    }

    /// Records one reward for `unit_id` per `(value, weight, timestamp)` entry, in order, using
    /// a separate `record_unit_rewards` call for each entry.
    fn record_each(
        practice_rewards: &mut dyn PracticeRewards,
        unit_id: Ustr,
        entries: &[(f32, f32, i64)],
    ) -> Result<()> {
        for &(value, weight, timestamp) in entries {
            practice_rewards.record_unit_rewards(&[UnitReward {
                unit_id,
                value,
                weight,
                timestamp,
            }])?;
        }
        Ok(())
    }

    /// Asserts that `actual` holds exactly the expected values and weights, in order, and that
    /// its timestamps never increase from one reward to the next.
    fn assert_rewards(expected_rewards: &[f32], expected_weights: &[f32], actual: &[UnitReward]) {
        let actual_values: Vec<f32> = actual.iter().map(|reward| reward.value).collect();
        assert_eq!(expected_rewards, actual_values);
        let actual_weights: Vec<f32> = actual.iter().map(|reward| reward.weight).collect();
        assert_eq!(expected_weights, actual_weights);
        let newest_first = actual
            .windows(2)
            .all(|pair| pair[0].timestamp >= pair[1].timestamp);
        assert!(newest_first);
    }

    /// A single recorded reward can be read back.
    #[test]
    fn basic() -> Result<()> {
        let mut practice_rewards = new_tests_rewards()?;
        let unit_id = Ustr::from("unit_123");
        record_each(&mut *practice_rewards, unit_id, &[(3.0, 1.0, 1)])?;
        let rewards = practice_rewards.get_rewards(unit_id, 1)?;
        assert_rewards(&[3.0], &[1.0], &rewards);
        Ok(())
    }

    /// Rewards come back newest first, limited to the requested count.
    #[test]
    fn multiple_rewards() -> Result<()> {
        let mut practice_rewards = new_tests_rewards()?;
        let unit_id = Ustr::from("unit_123");
        record_each(
            &mut *practice_rewards,
            unit_id,
            &[(3.0, 1.0, 1), (2.0, 1.0, 2), (-1.0, 0.05, 3)],
        )?;

        let one_reward = practice_rewards.get_rewards(unit_id, 1)?;
        assert_rewards(&[-1.0], &[0.05], &one_reward);
        let three_rewards = practice_rewards.get_rewards(unit_id, 3)?;
        assert_rewards(&[-1.0, 2.0, 3.0], &[0.05, 1.0, 1.0], &three_rewards);
        let more_rewards = practice_rewards.get_rewards(unit_id, 10)?;
        assert_rewards(&[-1.0, 2.0, 3.0], &[0.05, 1.0, 1.0], &more_rewards);
        Ok(())
    }

    /// With many rewards recorded, only the most recent requested ones are returned.
    #[test]
    fn many_rewards() -> Result<()> {
        let mut practice_rewards = new_tests_rewards()?;
        let unit_id = Ustr::from("unit_123");
        let entries: Vec<(f32, f32, i64)> = (0..20).map(|i| (i as f32, 1.0, i as i64)).collect();
        record_each(&mut *practice_rewards, unit_id, &entries)?;

        let rewards = practice_rewards.get_rewards(unit_id, 10)?;
        let expected_rewards: Vec<f32> = (10..20).rev().map(|i| i as f32).collect();
        assert_rewards(&expected_rewards, &[1.0; 10], &rewards);
        Ok(())
    }

    /// A unit without rewards yields an empty list.
    #[test]
    fn no_records() -> Result<()> {
        let practice_rewards = new_tests_rewards()?;
        let rewards = practice_rewards.get_rewards(Ustr::from("unit_123"), 10)?;
        assert_rewards(&[], &[], &rewards);
        Ok(())
    }

    /// Trimming to fewer rewards than stored keeps only the most recent ones of every unit.
    #[test]
    fn trim_rewards_some_rewards_removed() -> Result<()> {
        let mut practice_rewards = new_tests_rewards()?;
        let unit1_id = Ustr::from("unit1");
        record_each(
            &mut *practice_rewards,
            unit1_id,
            &[(3.0, 1.0, 1), (4.0, 1.0, 2), (5.0, 1.0, 3)],
        )?;
        assert_eq!(3, practice_rewards.get_rewards(unit1_id, 10)?.len());

        let unit2_id = Ustr::from("unit2");
        record_each(
            &mut *practice_rewards,
            unit2_id,
            &[(1.0, 1.0, 1), (2.0, 1.0, 2), (3.0, 1.0, 3)],
        )?;
        assert_eq!(3, practice_rewards.get_rewards(unit2_id, 10)?.len());

        practice_rewards.trim_rewards(2)?;
        let rewards = practice_rewards.get_rewards(unit1_id, 10)?;
        assert_rewards(&[5.0, 4.0], &[1.0, 1.0], &rewards);
        let rewards = practice_rewards.get_rewards(unit2_id, 10)?;
        assert_rewards(&[3.0, 2.0], &[1.0, 1.0], &rewards);
        Ok(())
    }

    /// Trimming to more rewards than stored leaves every unit untouched.
    #[test]
    fn trim_rewards_no_rewards_removed() -> Result<()> {
        let mut practice_rewards = new_tests_rewards()?;
        let unit1_id = Ustr::from("unit1");
        record_each(
            &mut *practice_rewards,
            unit1_id,
            &[(3.0, 1.0, 1), (4.0, 1.0, 2), (5.0, 1.0, 3)],
        )?;
        let unit2_id = Ustr::from("unit2");
        record_each(
            &mut *practice_rewards,
            unit2_id,
            &[(1.0, 1.0, 1), (2.0, 1.0, 2), (3.0, 1.0, 3)],
        )?;

        practice_rewards.trim_rewards(10)?;
        let rewards = practice_rewards.get_rewards(unit1_id, 10)?;
        assert_rewards(&[5.0, 4.0, 3.0], &[1.0, 1.0, 1.0], &rewards);
        let rewards = practice_rewards.get_rewards(unit2_id, 10)?;
        assert_rewards(&[3.0, 2.0, 1.0], &[1.0, 1.0, 1.0], &rewards);
        Ok(())
    }

    /// Removing by prefix clears only the matching units; a shorter prefix clears them all.
    #[test]
    fn remove_rewards_with_prefix() -> Result<()> {
        let mut practice_rewards = new_tests_rewards()?;
        let unit1_id = Ustr::from("unit1");
        let unit2_id = Ustr::from("unit2");
        let unit3_id = Ustr::from("unit3");
        record_each(
            &mut *practice_rewards,
            unit1_id,
            &[(3.0, 1.0, 1), (4.0, 1.0, 2), (5.0, 1.0, 3)],
        )?;
        for unit_id in [unit2_id, unit3_id] {
            record_each(
                &mut *practice_rewards,
                unit_id,
                &[(1.0, 1.0, 1), (2.0, 1.0, 2), (3.0, 1.0, 3)],
            )?;
        }

        practice_rewards.remove_rewards_with_prefix("unit1")?;
        let rewards = practice_rewards.get_rewards(unit1_id, 10)?;
        assert_rewards(&[], &[], &rewards);
        for unit_id in [unit2_id, unit3_id] {
            let rewards = practice_rewards.get_rewards(unit_id, 10)?;
            assert_rewards(&[3.0, 2.0, 1.0], &[1.0, 1.0, 1.0], &rewards);
        }

        practice_rewards.remove_rewards_with_prefix("unit")?;
        for unit_id in [unit1_id, unit2_id, unit3_id] {
            let rewards = practice_rewards.get_rewards(unit_id, 10)?;
            assert_rewards(&[], &[], &rewards);
        }
        Ok(())
    }
}