use rusqlite::Connection;
/// Returns the weighted arithmetic mean of `values`, weighting each element
/// by the weight at the same position in `weights`.
///
/// Elements are paired positionally. If the slices differ in length, the
/// unpaired tail of the longer slice is ignored. Both the weighted sum and the
/// weight total are taken over the same pairs, so the result is still a
/// proper weighted mean.
///
/// Returns `0.0` when the total of the paired weights is zero. This includes
/// empty input and all-zero weights, and avoids dividing by zero.
#[must_use]
pub fn weighted_average(values: &[f32], weights: &[f32]) -> f32 {
    // Accumulate numerator and denominator in one pass over the *same* pairs;
    // summing `weights` on its own would count weights that have no value.
    let (cross_product, weight_sum) = values
        .iter()
        .zip(weights)
        .fold((0.0_f32, 0.0_f32), |(acc, total), (&value, &weight)| {
            (acc + value * weight, total + weight)
        });

    if weight_sum == 0.0 {
        0.0
    } else {
        cross_product / weight_sum
    }
}
/// Opens the SQLite database at `db_path` and applies this crate's
/// connection-level pragmas before handing the connection back.
///
/// The pragmas are applied in order: `journal_mode = WAL`, then
/// `synchronous = OFF`.
///
/// NOTE(review): SQLite documents that `synchronous = OFF` skips syncing to
/// disk. Recent writes may be lost or damaged on an OS crash or power loss.
/// Confirm this speed-over-durability trade-off is intended.
///
/// # Errors
///
/// Returns the underlying [`rusqlite::Error`] if the database cannot be
/// opened or if either pragma fails to apply.
pub fn new_connection(db_path: &str) -> Result<Connection, rusqlite::Error> {
    const PRAGMAS: [(&str, &str); 2] = [("journal_mode", "WAL"), ("synchronous", "OFF")];

    let conn = Connection::open(db_path)?;
    for (name, value) in PRAGMAS {
        conn.pragma_update(None, name, value)?;
    }
    Ok(conn)
}
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that two `f32` values agree within a small absolute tolerance.
    /// The weighted average rounds at every multiply and add, so exact
    /// equality would depend on summation order.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_weighted_average() {
        let rewards = [1.0, 2.0, 3.0];
        let weights = [0.2, 0.3, 0.5];
        assert_close(weighted_average(&rewards, &weights), 2.3);
    }

    #[test]
    fn test_weighted_average_empty_input_is_zero() {
        let rewards: [f32; 0] = [];
        let weights: [f32; 0] = [];
        // `weighted_average` returns the literal 0.0 here, so exact comparison is sound.
        assert_eq!(weighted_average(&rewards, &weights), 0.0);
    }

    #[test]
    fn test_weighted_average_zero_weights_is_zero() {
        let rewards = [1.0, 2.0, 3.0];
        let weights = [0.0, 0.0, 0.0];
        assert_eq!(weighted_average(&rewards, &weights), 0.0);
    }
}