use std::collections::BTreeMap;
use std::{error, fmt};
use std::fs::{File, OpenOptions};
use std::io::{BufRead, BufReader, Write};
use rand::distributions::{Distribution,Normal};
use rand::seq::SliceRandom;
use serde::{Serialize, Deserialize};
/// Interactive picker: it proposes candidates on `output`, reads the user's
/// answers from `input`, and draws candidates with `rng`.
pub struct Engine<I, O, R> {
    /// Where the user's answers are read from, one line per prompt.
    input: I,
    /// Where prompts are written.
    output: O,
    /// Random number generator used to draw candidates.
    rng: R,
    /// Choices the user has turned down; they are not offered again until
    /// this list is reset.
    rejected_choices: Vec<String>,
}
impl<I, O, R> Engine<I, O, R>
where
    I: BufRead,
    O: Write,
    R: rand::RngCore,
{
    /// Creates an engine that prompts on `output`, reads answers from `input`
    /// and draws candidates with `rng`.
    pub fn new(input: I, output: O, rng: R) -> Engine<I, O, R> {
        Engine { input, output, rng, rejected_choices: Vec::new() }
    }

    /// Picks a choice from the category named `category`, asking the user to
    /// confirm each drawn candidate until one is accepted.
    ///
    /// `config` is mutable because the gaussian and lottery models update the
    /// category in place (choice order and ticket counts) after a pick.
    ///
    /// # Errors
    /// Returns a `ValueError` if the category does not exist, has nothing that
    /// can be drawn (no choices, or every weight / ticket count is zero), or,
    /// for the gaussian model, has a `stddev_scaling_factor` that is not a
    /// positive finite number.
    ///
    /// # Panics
    /// Panics if a prompt cannot be written, or if the input cannot be read or
    /// ends before a choice is accepted.
    pub fn pick(&mut self, config: &mut BTreeMap<String, ConfigCategory>, category: String)
    -> Result<String, Box<dyn error::Error>> {
        let config_category = match config.get_mut(&category[..]) {
            Some(found) => found,
            None => {
                return Err(Box::new(ValueError::new(
                    format!("Category {} not found in config.", category))));
            }
        };
        // Rejections only make sense within a single pick; stale names from an
        // earlier pick could otherwise hide every choice of this category.
        self.rejected_choices.clear();
        match config_category {
            ConfigCategory::Even { choices } => {
                if choices.is_empty() {
                    return Err(Self::no_choices_error(&category));
                }
                Ok(self.pick_even(choices))
            }
            ConfigCategory::Gaussian { choices, stddev_scaling_factor } => {
                if choices.is_empty() {
                    return Err(Self::no_choices_error(&category));
                }
                let factor = *stddev_scaling_factor;
                // Zero, negative, or non-finite factors give a standard deviation
                // that is infinite or invalid for `Normal::new`.
                if !(factor.is_finite() && factor > 0.0) {
                    return Err(Box::new(ValueError::new(format!(
                        "Category {} has invalid stddev_scaling_factor {}; \
                         it must be a positive number.",
                        category, factor))));
                }
                Ok(self.pick_gaussian(choices, factor))
            }
            ConfigCategory::Lottery { choices } => {
                if choices.iter().all(|choice| choice.tickets == 0) {
                    return Err(Self::no_choices_error(&category));
                }
                Ok(self.pick_lottery(choices))
            }
            ConfigCategory::Weighted { choices } => {
                if choices.iter().all(|choice| choice.weight == 0) {
                    return Err(Self::no_choices_error(&category));
                }
                Ok(self.pick_weighted(choices))
            }
        }
    }

    /// Builds the error returned when a category has nothing that can be drawn.
    fn no_choices_error(category: &str) -> Box<dyn error::Error> {
        Box::new(ValueError::new(format!(
            "Category {} has no choices that can be picked.", category)))
    }

    /// Offers `choice` to the user and returns whether they accepted it.
    ///
    /// A choice rejected earlier is declined silently, without prompting. An
    /// empty answer, `y`, or `Y` (surrounding whitespace ignored) accepts.
    ///
    /// `num_choices` is the number of distinct choices that can be offered.
    /// When rejecting this choice would exhaust them, the rejection list is
    /// reset (and a 🤨 is printed) so the others are offered again.
    ///
    /// # Panics
    /// Panics if the prompt cannot be written, or the input cannot be read or
    /// has ended.
    fn get_consent(&mut self, choice: &str, num_choices: usize) -> bool {
        if self.rejected_choices.iter().any(|rejected| rejected == choice) {
            return false;
        }
        write!(self.output, "Choice is {}. Accept? (Y/n) ", choice)
            .expect("Could not write to output");
        self.output.flush().expect("Could not flush output");
        let answer = self
            .input
            .by_ref()
            .lines()
            .next()
            .expect("Input ended before a choice was accepted")
            .expect("Could not read from input");
        if ["", "y", "Y"].contains(&answer.trim()) {
            return true;
        }
        if self.rejected_choices.len() + 1 >= num_choices {
            // Every offerable choice has now been rejected: start over.
            self.rejected_choices.clear();
            writeln!(self.output, "🤨").expect("Could not write to output");
        }
        // Skip this choice for the next offers. With a single offerable choice
        // there is nothing else to offer, and excluding it would make the
        // calling pick loop forever without prompting.
        if num_choices > 1 {
            self.rejected_choices.push(choice.to_string());
        }
        false
    }

    /// Counts the distinct names yielded by `names`. Rejections are tracked by
    /// name, so duplicate entries must not inflate the number of offerable
    /// choices passed to `get_consent`.
    fn count_distinct<'a>(names: impl Iterator<Item = &'a String>) -> usize {
        let mut names: Vec<&String> = names.collect();
        names.sort_unstable();
        names.dedup();
        names.len()
    }

    /// Picks uniformly at random among `choices` (every entry has weight 1).
    fn pick_even(&mut self, choices: &[String]) -> String {
        let num_choices = Self::count_distinct(choices.iter());
        let candidates = choices.iter().map(|name| (name, 1)).collect::<Vec<_>>();
        loop {
            let choice = candidates
                .choose_weighted(&mut self.rng, |candidate| candidate.1)
                .expect("pick() only calls this with at least one choice")
                .0;
            if self.get_consent(choice, num_choices) {
                return choice.clone();
            }
        }
    }

    /// Picks a position in `choices` from a half-normal distribution, so
    /// entries near the front are favoured. The accepted choice is moved to
    /// the end of the list, which makes it less likely to be drawn next time.
    fn pick_gaussian(&mut self, choices: &mut Vec<String>, stddev_scaling_factor: f64) -> String {
        let num_choices = Self::count_distinct(choices.iter());
        // A larger scaling factor narrows the distribution towards the front.
        let stddev = (choices.len() as f64) / stddev_scaling_factor;
        let normal = Normal::new(0.0, stddev);
        let index = loop {
            // Folding the sample (abs) yields a non-negative position; draws past
            // the end of the list are simply redrawn.
            let position = normal.sample(&mut self.rng).abs() as usize;
            if let Some(value) = choices.get(position) {
                if self.get_consent(value, num_choices) {
                    break position;
                }
            }
        };
        let value = choices.remove(index);
        choices.push(value.clone());
        value
    }

    /// Picks from `choices` with probability proportional to each entry's
    /// tickets. Afterwards every entry gains `weight` tickets and the accepted
    /// entry's tickets are reset to zero.
    fn pick_lottery(&mut self, choices: &mut [LotteryChoice]) -> String {
        // Entries without tickets can never be drawn, so they are not offerable.
        let num_choices = Self::count_distinct(
            choices.iter().filter(|choice| choice.tickets > 0).map(|choice| &choice.name));
        let candidates = choices
            .iter()
            .enumerate()
            .map(|(index, choice)| ((index, &choice.name), choice.tickets))
            .collect::<Vec<_>>();
        let winner = loop {
            let (index, name) = candidates
                .choose_weighted(&mut self.rng, |candidate| candidate.1)
                .expect("pick() only calls this when some choice has tickets")
                .0;
            if self.get_consent(name, num_choices) {
                break index;
            }
        };
        for choice in choices.iter_mut() {
            // Saturate rather than overflow when large weights keep accumulating.
            choice.tickets = choice.tickets.saturating_add(choice.weight);
        }
        choices[winner].tickets = 0;
        choices[winner].name.clone()
    }

    /// Picks from `choices` with probability proportional to each entry's weight.
    fn pick_weighted(&mut self, choices: &[WeightedChoice]) -> String {
        // Zero-weight entries can never be drawn, so they are not offerable.
        let num_choices = Self::count_distinct(
            choices.iter().filter(|choice| choice.weight > 0).map(|choice| &choice.name));
        let candidates = choices
            .iter()
            .map(|choice| (&choice.name, choice.weight))
            .collect::<Vec<_>>();
        loop {
            let name = candidates
                .choose_weighted(&mut self.rng, |candidate| candidate.1)
                .expect("pick() only calls this when some choice has a nonzero weight")
                .0;
            if self.get_consent(name, num_choices) {
                return name.clone();
            }
        }
    }
}
/// Error returned when a requested value, such as a config category name, is
/// not valid.
#[derive(Debug)]
struct ValueError {
    /// Description shown verbatim by `Display`.
    message: String,
}

impl ValueError {
    /// Wraps `message` in a new `ValueError`.
    fn new(message: String) -> Self {
        Self { message }
    }
}

impl fmt::Display for ValueError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.message)
    }
}

impl error::Error for ValueError {}
/// Reads the YAML config at `config_file_path` into a map from category name
/// to [`ConfigCategory`].
///
/// # Errors
/// Returns an error if the file cannot be opened or its contents cannot be
/// deserialized as a config.
pub fn read_config(config_file_path: &str)
-> Result<BTreeMap<String, ConfigCategory>, Box<dyn error::Error>> {
    let reader = BufReader::new(File::open(config_file_path)?);
    let config: BTreeMap<String, ConfigCategory> = serde_yaml::from_reader(reader)?;
    Ok(config)
}
/// Serializes `config` as YAML and writes it to `config_file_path`, creating
/// the file or replacing its previous contents.
///
/// The config is serialized before the file is opened, so a serialization
/// failure leaves any existing file untouched. The write itself is not atomic.
///
/// # Errors
/// Returns an error if serialization fails or the file cannot be opened or
/// written.
pub fn write_config(config_file_path: &str, config: BTreeMap<String, ConfigCategory>)
-> Result<(), Box<dyn error::Error>> {
    let yaml = serde_yaml::to_string(&config)?;
    let mut f = OpenOptions::new()
        .write(true)
        .create(true)
        .truncate(true)
        .open(config_file_path)?;
    f.write_all(yaml.as_bytes())?;
    Ok(())
}
/// One category of the config file. The `model` key selects how a choice is
/// drawn from the category's `choices`.
#[derive(PartialEq, Serialize, Deserialize)]
#[serde(tag = "model", rename_all = "snake_case", deny_unknown_fields)]
pub enum ConfigCategory {
    /// Every entry is equally likely.
    Even {
        choices: Vec<String>,
    },
    /// Entries near the front of the list are favoured.
    Gaussian {
        /// The standard deviation is the number of choices divided by this
        /// value; defaults to `default_stddev_scaling_factor()`.
        #[serde(default = "default_stddev_scaling_factor")]
        stddev_scaling_factor: f64,
        choices: Vec<String>,
    },
    /// Entries are drawn in proportion to their tickets, which accumulate
    /// between picks.
    Lottery {
        choices: Vec<LotteryChoice>,
    },
    /// Entries are drawn in proportion to their weights.
    Weighted {
        choices: Vec<WeightedChoice>,
    },
}
/// An entry of a lottery category.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct LotteryChoice {
    /// Text offered to the user.
    name: String,
    /// Chance of being drawn, relative to the other entries; set to 0 when
    /// this entry is picked. Defaults to `default_weight()`.
    #[serde(default = "default_weight")]
    tickets: u64,
    /// Tickets this entry gains after every pick of its category. Defaults to
    /// `default_weight()`.
    #[serde(default = "default_weight")]
    weight: u64,
}
/// An entry of a weighted category.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct WeightedChoice {
    /// Text offered to the user.
    name: String,
    /// Chance of being drawn, relative to the other entries. Defaults to
    /// `default_weight()`.
    #[serde(default = "default_weight")]
    weight: u64,
}
/// Default `stddev_scaling_factor` for gaussian categories. The gaussian
/// picker divides the number of choices by this value to get the standard
/// deviation.
fn default_stddev_scaling_factor() -> f64 {
    3.0
}
/// Default `weight` (and lottery `tickets`) for a choice that omits it.
fn default_weight() -> u64 {
    1
}
#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use super::*;

    #[test]
    fn test_defaults() {
        assert_eq!(default_stddev_scaling_factor(), 3.0);
        assert_eq!(default_weight(), 1);
    }

    #[test]
    fn test_get_consent() {
        let tests = [
            (String::from("y"), true), (String::from("Y"), true), (String::from("\n"), true),
            (String::from("f"), false), (String::from("F"), false),
            (String::from("anything else"), false)];
        for (input, expected_output) in tests.iter() {
            let output = Vec::new();
            let mut engine = Engine::new(input.as_bytes(), output,
                rand::rngs::SmallRng::seed_from_u64(42));
            assert_eq!(engine.get_consent("do you want this", 2), *expected_output);
            let output = String::from_utf8(engine.output).expect("Not UTF-8");
            assert_eq!(output, "Choice is do you want this. Accept? (Y/n) ");
            let mut expected_rejected_choices: Vec<String> = Vec::new();
            if !expected_output {
                expected_rejected_choices = vec![String::from("do you want this")];
            }
            assert_eq!(engine.rejected_choices, expected_rejected_choices);
        }
    }

    /// A choice that was already rejected is declined without prompting again.
    #[test]
    fn test_get_consent_skips_rejected_choice() {
        let mut engine = Engine::new("n\n".as_bytes(), Vec::new(),
            rand::rngs::SmallRng::seed_from_u64(42));
        assert!(!engine.get_consent("this", 3));
        assert!(!engine.get_consent("this", 3));
        let output = String::from_utf8(engine.output).expect("Not UTF-8");
        assert_eq!(output, "Choice is this. Accept? (Y/n) ");
    }

    /// Rejecting every choice resets the rejection list, keeping only the
    /// latest rejection, and prints a 🤨.
    #[test]
    fn test_get_consent_resets_after_every_choice_rejected() {
        let mut engine = Engine::new("n\nn\n".as_bytes(), Vec::new(),
            rand::rngs::SmallRng::seed_from_u64(42));
        assert!(!engine.get_consent("this", 2));
        assert_eq!(engine.rejected_choices, vec![String::from("this")]);
        assert!(!engine.get_consent("that", 2));
        assert_eq!(engine.rejected_choices, vec![String::from("that")]);
        let output = String::from_utf8(engine.output).expect("Not UTF-8");
        assert_eq!(output, "Choice is this. Accept? (Y/n) Choice is that. Accept? (Y/n) 🤨\n");
    }

    #[test]
    fn test_pick() {
        let input = String::from("N\ny");
        let output = Vec::new();
        let mut engine = Engine::new(input.as_bytes(), output,
            rand::rngs::SmallRng::seed_from_u64(42));
        let choices = vec![String::from("this"), String::from("that"), String::from("the other")];
        let category = ConfigCategory::Even{choices: choices};
        let mut config = BTreeMap::new();
        config.insert("things".to_string(), category);
        let choice = engine.pick(&mut config, "things".to_string()).expect("unexpected");
        assert_eq!(choice, "the other");
        let output = String::from_utf8(engine.output).expect("Not UTF-8");
        assert_eq!(output, "Choice is this. Accept? (Y/n) Choice is the other. Accept? (Y/n) ");
    }

    #[test]
    fn test_pick_nonexistent_category() {
        let input = String::from("N\ny");
        let output = Vec::new();
        let mut engine = Engine::new(input.as_bytes(), output,
            rand::rngs::SmallRng::seed_from_u64(42));
        let choices = vec![String::from("this"), String::from("that"), String::from("the other")];
        let category = ConfigCategory::Even{choices: choices};
        let mut config = BTreeMap::new();
        config.insert("things".to_string(), category);
        match engine.pick(&mut config, "does not exist".to_string()) {
            Ok(_) => {
                panic!("The non-existent category should have returned an error.");
            },
            Err(error) => {
                assert_eq!(format!("{}", error), "Category does not exist not found in config.");
            }
        }
    }

    #[test]
    fn test_pick_even() {
        let input = String::from("y");
        let output = Vec::new();
        let mut engine = Engine::new(input.as_bytes(), output,
            rand::rngs::SmallRng::seed_from_u64(1));
        let choices = vec![String::from("this"), String::from("that"), String::from("the other")];
        let result = engine.pick_even(&choices);
        let output = String::from_utf8(engine.output).expect("Not UTF-8");
        assert_eq!(output, "Choice is the other. Accept? (Y/n) ");
        assert_eq!(result, "the other");
    }

    #[test]
    fn test_pick_gaussian() {
        let input = String::from("y");
        let output = Vec::new();
        let mut engine = Engine::new(input.as_bytes(), output,
            rand::rngs::SmallRng::seed_from_u64(1));
        let mut choices = vec![
            String::from("this"), String::from("that"), String::from("the other")];
        let result = engine.pick_gaussian(&mut choices, 3.0);
        let output = String::from_utf8(engine.output).expect("Not UTF-8");
        assert_eq!(output, "Choice is that. Accept? (Y/n) ");
        assert_eq!(result, "that");
        assert_eq!(choices,
            vec![String::from("this"), String::from("the other"), String::from("that")]);
    }

    #[test]
    fn test_pick_lottery() {
        let input = String::from("y");
        let output = Vec::new();
        let mut engine = Engine::new(input.as_bytes(), output,
            rand::rngs::SmallRng::seed_from_u64(2));
        let mut choices = vec![
            LotteryChoice{name: "this".to_string(), tickets: 1, weight: 1},
            LotteryChoice{name: "that".to_string(), tickets: 2, weight: 4},
            LotteryChoice{name: "the other".to_string(), tickets: 3, weight: 9}];
        let result = engine.pick_lottery(&mut choices);
        let output = String::from_utf8(engine.output).expect("Not UTF-8");
        assert_eq!(output, "Choice is the other. Accept? (Y/n) ");
        assert_eq!(result, "the other");
        assert_eq!(
            choices,
            vec![
                LotteryChoice{name: "this".to_string(), tickets: 2, weight: 1},
                LotteryChoice{name: "that".to_string(), tickets: 6, weight: 4},
                LotteryChoice{name: "the other".to_string(), tickets: 0, weight: 9}]);
    }

    #[test]
    fn test_pick_weighted() {
        let input = String::from("y");
        let output = Vec::new();
        let mut engine = Engine::new(input.as_bytes(), output,
            rand::rngs::SmallRng::seed_from_u64(3));
        let choices = vec![
            WeightedChoice{name: "this".to_string(), weight: 1},
            WeightedChoice{name: "that".to_string(), weight: 4},
            WeightedChoice{name: "the other".to_string(), weight: 9}];
        let result = engine.pick_weighted(&choices);
        let output = String::from_utf8(engine.output).expect("Not UTF-8");
        assert_eq!(output, "Choice is that. Accept? (Y/n) ");
        assert_eq!(result, "that");
    }
}