use std::cell::RefCell;
use std::hash::{Hash, Hasher};
use rand::{Rng, thread_rng};
use crate::{Result};
use crate::error::{Error, ErrorKind};
use crate::prelude::*;
use crate::Symbol;
/// Records where a [`Chance`] value came from.
#[derive(Debug, Copy, Clone)]
pub enum ChanceKind {
    /// The probability was supplied explicitly.
    Set,
    /// No probability was supplied; the owner is expected to derive one.
    Derived
}

/// An optional selection probability.
///
/// A `Chance` built with [`Chance::new`] is user-set and always holds a value in
/// `(0.0, 1.0]`; one built with [`Chance::empty`] is derived and holds no value.
#[derive(Debug, Copy, Clone)]
pub struct Chance {
    kind: ChanceKind,
    chance: Option<f32>
}

impl Chance {
    /// Creates a user-set chance holding `chance`.
    ///
    /// # Panics
    ///
    /// Panics if `chance` is not strictly positive (NaN also fails this check)
    /// or if it is greater than `1.0`.
    pub fn new(chance: f32) -> Self {
        let positive = chance > 0_f32;
        assert!(positive, "chance should be positive");
        let at_most_one = chance <= 1.0_f32;
        assert!(at_most_one, "chance should be less than or equal to 1.0");
        Self { kind: ChanceKind::Set, chance: Some(chance) }
    }

    /// Creates a derived chance that holds no value.
    #[inline]
    pub fn empty() -> Self {
        Self { chance: None, kind: ChanceKind::Derived }
    }

    /// Returns `true` when this chance was created without a value.
    #[inline]
    pub fn is_derived(&self) -> bool {
        match self.kind {
            ChanceKind::Derived => true,
            ChanceKind::Set => false,
        }
    }

    /// Returns `true` when this chance was given an explicit value.
    #[inline]
    pub fn is_user_set(&self) -> bool {
        match self.kind {
            ChanceKind::Set => true,
            ChanceKind::Derived => false,
        }
    }

    /// Returns the stored probability.
    ///
    /// # Panics
    ///
    /// Panics with `message` if no value is stored.
    #[inline]
    pub fn expect(&self, message: &str) -> f32 {
        Option::expect(self.chance, message)
    }

    /// Returns the stored probability.
    ///
    /// # Panics
    ///
    /// Panics if no value is stored.
    #[inline]
    pub fn unwrap(&self) -> f32 {
        Option::unwrap(self.chance)
    }

    /// Returns the stored probability, or `default` when none is stored.
    #[inline]
    pub fn unwrap_or(&self, default: f32) -> f32 {
        Option::unwrap_or(self.chance, default)
    }
}
/// The left-hand side of a production: a target symbol with an optional
/// left context (`pre`) and an optional right context (`post`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ProductionHead {
    pre: Option<ProductionString>,
    target: Symbol,
    post: Option<ProductionString>
}

impl ProductionHead {
    /// Builds a production head from its optional contexts and target symbol.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn build(pre: Option<ProductionString>, target: Symbol, post: Option<ProductionString>) -> Result<Self> {
        Ok(ProductionHead { pre, target, post })
    }

    /// The symbol this head matches.
    #[inline]
    pub fn target(&self) -> &Symbol {
        &self.target
    }

    /// The left context, if any.
    #[inline]
    pub fn pre_context(&self) -> Option<&ProductionString> {
        self.pre.as_ref()
    }

    /// The right context, if any.
    #[inline]
    pub fn post_context(&self) -> Option<&ProductionString> {
        self.post.as_ref()
    }

    /// Returns `true` if this head applies at position `index` of `string`.
    ///
    /// The symbol there must equal the target, and both contexts (when present)
    /// must match around it. Returns `false` rather than panicking when `index`
    /// is out of bounds.
    pub fn matches(&self, string: &ProductionString, index: usize) -> bool {
        // Test the target first. It is the cheapest check, and it guarantees
        // that `index` is in bounds before the context checks run.
        string.symbols().get(index).map_or(false, |symbol| self.target == *symbol)
            && self.pre_matches(string, index)
            && self.post_matches(string, index)
    }

    /// Returns `true` if the left context matches the symbols immediately before `index`.
    ///
    /// The last symbol of the context must sit directly before `index`, the one
    /// before it directly before that, and so on. A head without a left context
    /// always matches. Returns `false` if `index` is past the end of `string`.
    pub fn pre_matches(&self, string: &ProductionString, index: usize) -> bool {
        let left = match self.pre.as_ref() {
            Some(left) => left,
            None => return true,
        };
        let preceding = match string.symbols().get(..index) {
            Some(preceding) => preceding,
            None => return false,
        };
        // `zip` stops at the shorter side, so the length check is what rejects a
        // context longer than the available prefix.
        left.len() <= preceding.len()
            && left
                .iter()
                .rev()
                .zip(preceding.iter().rev())
                .all(|(expected, actual)| expected == actual)
    }

    /// Returns `true` if the right context matches the symbols immediately after `index`.
    ///
    /// A head without a right context always matches. Returns `false` if
    /// `index` is out of bounds, including when `string` is empty.
    pub fn post_matches(&self, string: &ProductionString, index: usize) -> bool {
        let right = match self.post.as_ref() {
            Some(right) => right,
            None => return true,
        };
        // `checked_add` avoids overflow when `index == usize::MAX`. `get` avoids the
        // panic (and the `len() - 1` underflow on an empty string) that direct
        // slicing would cause.
        let following = match index.checked_add(1).and_then(|start| string.symbols().get(start..)) {
            Some(following) => following,
            None => return false,
        };
        right.len() <= following.len()
            && right
                .iter()
                .zip(following.iter())
                .all(|(expected, actual)| expected == actual)
    }
}
#[derive(Debug, Clone)]
pub struct ProductionBody {
string: ProductionString,
chance: Chance
}
impl ProductionBody {
pub fn new(string: ProductionString) -> Self {
ProductionBody {
string,
chance: Chance::empty()
}
}
pub fn try_with_chance(chance: f32, string: ProductionString) -> Result<Self> {
if !(0.0..=1.0).contains(&chance) {
return Err(Error::new(ErrorKind::Parse, "chance should be between 0.0 and 1.0 inclusive"));
}
Ok(ProductionBody {
string,
chance: Chance::new(chance),
})
}
pub fn empty() -> Self {
ProductionBody {
string: ProductionString::empty(),
chance: Chance::empty()
}
}
#[inline]
pub fn is_empty(&self) -> bool {
self.string.is_empty()
}
#[inline]
pub fn len(&self) -> usize {
self.string.len()
}
#[inline]
pub fn string(&self) -> &ProductionString {
&self.string
}
#[inline]
pub fn chance(&self) -> &Chance {
&self.chance
}
}
#[derive(Debug, Clone)]
pub struct Production {
head: ProductionHead,
body: Vec<ProductionBody>
}
impl Production {
pub fn new(head: ProductionHead, body: ProductionBody) -> Self {
Production {
head,
body: vec![body]
}
}
#[inline]
pub fn head(&self) -> &ProductionHead {
&self.head
}
pub fn body(&self) -> Result<&ProductionBody> {
if self.body.is_empty() {
return Err(Error::execution("Production has no bodies set"))
}
if self.body.len() == 1 {
return Ok(self.body.last().unwrap());
}
let total_chance : f32 = self.body.iter()
.map(|b| b.chance.unwrap_or(0.0))
.sum();
if total_chance < 0.0 {
return Err(Error::execution("chance should never be negative"));
}
if total_chance > 1.0 {
return Err(Error::execution("total chance of production bodies should not be greater than 1.0"));
}
let remaining = self.body.iter().filter(|b| b.chance.is_derived()).count();
let default_chance = if remaining == 0 {
0_f32
} else {
(1.0_f32 - total_chance) / (remaining as f32)
};
let mut current = 0_f32;
let random : f32 = thread_rng().gen_range(0.0..=1.0);
for body in &self.body {
current += body.chance.unwrap_or(default_chance);
if random < current {
return Ok(body);
}
}
return Ok(self.body.last().unwrap());
}
#[inline]
pub fn matches(&self, string: &ProductionString, index: usize) -> bool {
self.head().matches(string, index)
}
pub fn add_body(&mut self, body: ProductionBody) {
self.body.push(body);
}
pub fn merge(&mut self, other: Self) {
other.body.into_iter().for_each(|b| self.add_body(b));
}
pub fn all_bodies(&self) -> &Vec<ProductionBody> {
&self.body
}
}
/// Two productions are equal when their heads are equal; their bodies are ignored.
impl PartialEq for Production {
    fn eq(&self, other: &Self) -> bool {
        self.head == other.head
    }
}

/// `ProductionHead` implements `Eq`, so head-only equality is a full equivalence.
impl Eq for Production {}

/// Hashes only the head. This matches the head-only `PartialEq` above, so equal
/// productions always hash equally.
impl Hash for Production {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.head, state);
    }
}
pub trait ProductionStore {
fn add_production(&self, production: Production) -> Result<Production>;
}
impl ProductionStore for RefCell<Vec<Production>> {
fn add_production(&self, production: Production) -> Result<Production> {
let mut vec = self.borrow_mut();
vec.push(production);
vec.last().cloned().ok_or_else(|| Error::general("Unable to add production"))
}
}
#[cfg(test)]
mod tests {
use crate::parser::parse_prod_string;
use super::*;
/// Exercises `Production::matches` for context-free rules, left-context (`<`)
/// rules and right-context (`>`) rules against parsed production strings.
#[test]
fn production_matches() {
let system = System::default();
// Context-free rule: matches the lone `X` at position 0.
let production = system.parse_production("X -> F F").unwrap();
let string = parse_prod_string("X").unwrap();
assert!(production.matches(&string, 0));
// `X < X`: the target `X` must be immediately preceded by `X`.
// A lone `X` has no left neighbour, so it does not match.
let production = system.parse_production("X < X -> F F").unwrap();
assert!(!production.matches(&string, 0));
// In `X X`, only the second `X` has an `X` before it.
let string = parse_prod_string("X X").unwrap();
assert!(!production.matches(&string, 0));
assert!( production.matches(&string, 1));
// Multi-symbol left context: `X` must be preceded by `a b`.
let production = system.parse_production("a b < X -> F F").unwrap();
let string = parse_prod_string("a b X").unwrap();
assert!(!production.matches(&string, 0));
assert!(!production.matches(&string, 1));
assert!( production.matches(&string, 2));
// `X > X`: the target must be immediately followed by `X`.
// `string` is still `a b X` here, and position 0 holds `a`, not the target.
let production = system.parse_production("X > X -> F F").unwrap();
assert!(!production.matches(&string, 0));
// In `X X`, only the first `X` has an `X` after it. The last `X` has no right neighbour.
let string = parse_prod_string("X X").unwrap();
assert!( production.matches(&string, 0));
assert!(!production.matches(&string, 1));
// Multi-symbol right context: `X` must be followed by `a b`.
// Positions 0, 2 and 3 hold `a`/`b`, which are not the target.
let production = system.parse_production("X > a b -> F F").unwrap();
let string = parse_prod_string("a X a b").unwrap();
assert!(!production.matches(&string, 0));
assert!( production.matches(&string, 1));
assert!(!production.matches(&string, 2));
assert!(!production.matches(&string, 3));
// A fresh `System` for the final case.
// Only the `S` directly after `G` matches `G < S`; later `S`s are preceded by `S`.
let system = System::default();
let string = parse_prod_string("G S S S X").unwrap();
let production = system.parse_production("G < S -> S G").unwrap();
assert!(!production.matches(&string, 0));
assert!( production.matches(&string, 1));
assert!(!production.matches(&string, 2));
}
}