Skip to main content

dicexp/
lib.rs

1#![deny(unused_must_use)]
2use std::error::Error;
3use core::fmt::{Debug, Formatter};
4use std::collections::HashSet;
5use std::hash::Hasher;
6use std::num::{ParseFloatError, ParseIntError};
7use rand::prelude::*;
8#[cfg(feature = "serde_support")]
9use serde::{Deserialize, Serialize};
10
/// The DiceBag struct is used to evaluate RPG dice notation expressions (eg "2d6+3")
///
/// If the provided RNG implements any of `Debug`, `Clone`, `PartialEq`, `Eq`, `Hash`, or `Default`,
/// then `DiceBag` will implement the same.
///
/// # Example
/// ```
/// use dicexp::{DiceBag, new_simple_rng};
/// let mut dice_roller = DiceBag::new(new_simple_rng());
/// let dice_exp = "3d6-4";
/// let dice_roll = dice_roller.eval(dice_exp).expect("Error");
/// println!("Rolled {}: {}", dice_exp, dice_roll);
/// println!("The average result is {:.1}", dice_roll.average);
/// ```
pub struct DiceBag<R> {
	// The `RngExt` requirement is stated on the impl blocks that need it rather than on the
	// type itself, so merely naming `DiceBag<R>` does not drag the bound along.
	rng: R,
}
28impl <R>Clone for DiceBag<R> where R: RngExt+Clone{
29	fn clone(&self) -> Self {
30		DiceBag{rng: self.rng.clone()}
31	}
32}
33impl <R>Debug for DiceBag<R> where R: RngExt+Debug{
34	fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35		write!(f, "DiceBag{{")?;
36		self.rng.fmt(f)?;
37		write!(f, "}}")
38	}
39}
40
41impl <R>PartialEq for DiceBag<R> where R: RngExt+PartialEq{
42	fn eq(&self, other: &Self) -> bool {
43		self.rng.eq(&other.rng)
44	}
45}
46
47impl <R>Eq for DiceBag<R> where R: RngExt+Eq{}
48
49impl <R> std::hash::Hash for DiceBag<R> where R: RngExt+std::hash::Hash{
50	fn hash<H: Hasher>(&self, state: &mut H) {
51		self.rng.hash(state)
52	}
53}
54
55impl <R>Default for DiceBag<R> where R: RngExt+Default{
56	fn default() -> Self {
57		DiceBag{rng: R::default()}
58	}
59}
60
61impl <R> DiceBag<R> where R: RngExt {
62	/// Constructs a new `DiceBag` instance
63	/// # Parameters
64	/// * `rng`: A random number generator to use for rolling dice
65	pub fn new(rng: R) -> Self { DiceBag{rng} }
66
67	/// Rolls a number of dice and returns the result
68	/// # Parameters
69	/// * `n`: number of dice to roll
70	/// * `d`: number of sides per die
71	/// * `m`: number to add to the total
72	pub fn roll(&mut self, n: u32, d: u32, m: i64) -> i64 {
73		let mut total = 0i64;
74		for _ in 0..n {
75			let roll: u32 = self.rng.random_range(1..=d);
76			total += roll as i64;
77		}
78		return total + m;
79	}
80
81	/// Evaluates the given RPG dice notation expression
82	/// # Parameters
83	/// * `dice_expression`: An RPG dice notation expressions (eg "2d6+3")
84	pub fn eval(&mut self, dice_expression: &str) -> Result<DiceRoll,SyntaxError>{
85		Ok(DiceRoll{
86			total: self.eval_total(dice_expression)?,
87			min: self.eval_min(dice_expression)?,
88			max: self.eval_max(dice_expression)?,
89			average: self.eval_ave(dice_expression)?,
90		})
91	}
92
93
94	/// Evaluates the given RPG dice notation expression and returns the total dice roll
95	/// # Parameters
96	/// * `dice_expression`: An RPG dice notation expressions (eg "2d6+3")
97	pub fn eval_total(&mut self, dice_expression: &str) -> Result<i64,SyntaxError>{
98		self.eval_as(dice_expression, EvalMode::Roll)?.parse::<i64>().map_err(|e| SyntaxError::from(e))
99	}
100
101	/// Evaluates the given RPG dice notation expression and returns the minimum dice roll
102	/// # Parameters
103	/// * `dice_expression`: An RPG dice notation expressions (eg "2d6+3")
104	pub fn eval_min(&mut self, dice_expression: &str) -> Result<i64,SyntaxError>{
105		self.eval_as(dice_expression, EvalMode::Minimum)?.parse::<i64>().map_err(|e| SyntaxError::from(e))
106	}
107
108	/// Evaluates the given RPG dice notation expression and returns the maximum dice roll
109	/// # Parameters
110	/// * `dice_expression`: An RPG dice notation expressions (eg "2d6+3")
111	pub fn eval_max(&mut self, dice_expression: &str) -> Result<i64,SyntaxError>{
112		self.eval_as(dice_expression, EvalMode::Maximum)?.parse::<i64>().map_err(|e| SyntaxError::from(e))
113	}
114
115	/// Evaluates the given RPG dice notation expression and returns the average dice roll
116	/// # Parameters
117	/// * `dice_expression`: An RPG dice notation expressions (eg "2d6+3")
118	pub fn eval_ave(&mut self, dice_expression: &str) -> Result<f64,SyntaxError>{
119		self.eval_as(dice_expression, EvalMode::Average)?.parse::<f64>().map_err(|e| SyntaxError::from(e))
120	}
121
122	fn eval_as(&mut self, dice_expression: &str, mode: EvalMode) -> Result<String, SyntaxError> {
123		if dice_expression.starts_with("-") || dice_expression.starts_with("+"){
124			// must start with a number or there will be problems
125			let mut new_exp = String::from("0");
126			new_exp.push_str(dice_expression);
127			return self.eval_as(new_exp.as_str(), mode);
128		}
129		let mut x = String::new();
130		// need to remove all whitespace, also using this opportunity to throw common exceptions
131		let mut line = 1;
132		let mut col = 0;
133		let mut last_c = ' ';
134		for c in dice_expression.chars() {
135			if c == '\n' {
136				line += 1;
137				col = 0;
138			}
139			col += 1;
140			if c.is_whitespace() {continue;}
141			match mode{
142				// decimals allowed in average mode, but otherwise it is ints-only
143				EvalMode::Average => {},
144				_ => {
145					if c == '.' {return Err(SyntaxError{
146						msg: Some("Found '.', but decimal numbers are not supported (integer math only)".into()),
147						line: Some(line), col: Some(col), cause: None
148					});}
149				}
150			}
151			if c == '%' {
152				// d% means d100
153				x.push_str("100")
154			} else if c == 'x' || c == 'X' {
155				// multiplication old-school notation
156				x.push('*');
157			} else if c == '-' && last_c != '+' && last_c != '/' && last_c != '*' {
158				// turn - into +- to avoid confusion over subtraction vs negative numbers
159				x.push_str("+-")
160			} else if c == '(' && (last_c.is_digit(10) || last_c == '.') {
161				// number right before ( means multiply
162				x.push_str("*(")
163			} else {
164				x.push(c);
165			}
166			last_c = c;
167		}
168		#[cfg(test)]
169		eprintln!(">> {}", x);
170		// Parentheses
171		while match x.find("(") {
172			None => false,
173			Some(i) => {
174				let cpy =  x.clone();
175				let x_str = cpy.as_str();
176				let (open, close) = find_enclosure_from(x_str, i, '(', ')')?
177					.ok_or_else(|| SyntaxError::from("Error: unmatched parentheses"))?;
178				let middle = self.eval_as(&x_str[open+1 .. close-1], mode)?;
179				let front = &x_str[0..open];
180				let back = &x_str[close..];
181				x.clear();
182				x.push_str(front);
183				x.push_str(middle.as_str());
184				x.push_str(back);
185				true
186			}
187		}{}
188		// Dice
189		while match x.find("d") {
190			None => false,
191			Some(i) => {
192				let cpy =  x.clone();
193				let x_str = cpy.as_str();
194				let (start, end) = find_operator_params(x_str, i)?;
195				let n = &x_str[start..i].parse::<u32>().map_err(|e| SyntaxError::from(e.clone()))?;
196				let d = &x_str[i+1..end].parse::<u32>().map_err(|e| SyntaxError::from(e.clone()))?;
197				let middle: String;
198				match mode {
199					EvalMode::Roll => middle = format!("{}", self.roll(*n, *d, 0)),
200					EvalMode::Average => middle = format!("{:.1}", *n as f64 * 0.5 * (1f64 + *d as f64)),
201					EvalMode::Minimum => middle = format!("{}", n),
202					EvalMode::Maximum => middle = format!("{}", n * d),
203				}
204				let front = &x_str[0..start];
205				let back = &x_str[end..];
206				x.clear();
207				x.push_str(front);
208				x.push_str(middle.as_str());
209				x.push_str(back);
210				true
211			}
212		}{}
213		// multiply and divide
214		while match find_one_of(x.as_str(), &['*', '/']) {
215			None => false,
216			Some(i) => {
217				let cpy =  x.clone();
218				let x_str = cpy.as_str();
219				let op = &x_str[i..i+1];
220				let (start, end) = find_operator_params(x_str, i)?;
221				let middle: String;
222				match mode {
223					EvalMode::Average => {
224						let left = &x_str[start..i].parse::<f64>().map_err(|e| SyntaxError::from(e.clone()))?;
225						let right = &x_str[i+1..end].parse::<f64>().map_err(|e| SyntaxError::from(e.clone()))?;
226						if op == "/" {
227							middle = format!("{:.}", *left / *right);
228						} else {
229							middle = format!("{:.}", *left * *right);
230						}
231					}
232					_ => {
233						let left = &x_str[start..i].parse::<i64>().map_err(|e| SyntaxError::from(e.clone()))?;
234						let right = &x_str[i+1..end].parse::<i64>().map_err(|e| SyntaxError::from(e.clone()))?;
235						if op == "/" {
236							middle = format!("{}", *left / *right);
237						} else {
238							middle = format!("{}", *left * *right);
239						}
240					}
241				}
242				let front = &x_str[0..start];
243				let back = &x_str[end..];
244				x.clear();
245				x.push_str(front);
246				x.push_str(middle.as_str());
247				x.push_str(back);
248				true
249			}
250		}{}
251
252		// add and subtract (subtraction already replaced with +-)
253		while match x.find('+') { // start at 1 in case of negative number on left side
254			None => false,
255			Some(i) => {
256				let cpy =  x.clone();
257				let x_str = cpy.as_str();
258				let (start, end) = find_operator_params(x_str, i)?;
259				let mut left_str = &x_str[start..i];
260				let mut right_str = &x_str[i+1..end];
261				if left_str.starts_with("--") {
262					// double negative equals a positive
263					left_str = &left_str[2..];
264				}
265				if right_str.starts_with("--") {
266					right_str = &right_str[2..];
267				}
268				let middle: String;
269				match mode {
270					EvalMode::Average => {
271						let left = left_str.parse::<f64>().map_err(|e| SyntaxError::from(e.clone()))?;
272						let right = right_str.parse::<f64>().map_err(|e| SyntaxError::from(e.clone()))?;
273						middle = format!("{:.}", left + right);
274					}
275					_ => {
276						let left = left_str.parse::<i64>().map_err(|e| SyntaxError::from(e.clone()))?;
277						let right = right_str.parse::<i64>().map_err(|e| SyntaxError::from(e.clone()))?;
278						middle = format!("{}", left + right);
279					}
280				}
281				let front = &x_str[0..start];
282				let back = &x_str[end..];
283				x.clear();
284				x.push_str(front);
285				x.push_str(middle.as_str());
286				x.push_str(back);
287				true
288			}
289		}{}
290		// DONE!
291		Ok(x)
292	}
293
294}
295
/// Which quantity the expression evaluator computes.
#[derive(Clone, Copy, PartialEq, Eq)]
enum EvalMode {
	/// Roll the dice with the bag's RNG
	Roll,
	/// Replace each dice term with its mean value (floating-point arithmetic)
	Average,
	/// Replace each dice term with its lowest total
	Minimum,
	/// Replace each dice term with its highest total
	Maximum,
}
300
/// Outcome of evaluating a dice expression: the amount actually rolled, together with the
/// minimum, maximum and average results the expression could have produced.
#[derive(Clone, Copy, PartialEq, Default, Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct DiceRoll {
	/// Result actually rolled
	pub total: i64,
	/// Lowest possible result
	pub min: i64,
	/// Highest possible result
	pub max: i64,
	/// Average (expected) result
	pub average: f64,
}

impl core::fmt::Display for DiceRoll {
	/// Shows only the rolled `total`, honouring any width/alignment/sign flags given.
	fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
		<i64 as core::fmt::Display>::fmt(&self.total, f)
	}
}
322
/// Error returned when a `DiceBag` fails to interpret or evaluate a dice expression
pub struct SyntaxError {
	/// Description of what went wrong; when `None`, a generic message is printed instead
	pub msg: Option<String>,
	/// Line of the expression (counting from 1) where the problem was detected, if known
	pub line: Option<u64>,
	/// Column within `line` where the problem was detected, if known (only printed
	/// together with `line`)
	pub col: Option<u64>,
	/// Lower-level error that caused this one, also exposed through `Error::source`
	pub cause: Option<Box<dyn Error>>
}

impl SyntaxError {
	/// Writes the rendering shared by `Debug` and `Display`:
	/// `SyntaxError: <msg>[; error on line L[, column C]][\n\tCaused by: <cause>]`.
	fn print(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
		f.write_str("SyntaxError: ")?;
		f.write_str(self.msg.as_deref().unwrap_or("Failed to parse string"))?;
		if let Some(line) = self.line {
			write!(f, "; error on line {}", line)?;
			if let Some(col) = self.col {
				write!(f, ", column {}", col)?;
			}
		}
		if let Some(cause) = &self.cause {
			write!(f, "\n\tCaused by: {}", cause)?;
		}
		Ok(())
	}

	/// Builds an error that carries only a message (no position, no cause).
	fn from_string<T>(msg: T) -> Self where T: Into<String> {
		SyntaxError { msg: Some(msg.into()), line: None, col: None, cause: None }
	}
}

impl From<&str> for SyntaxError {
	fn from(msg: &str) -> Self {
		SyntaxError::from_string(msg)
	}
}

impl From<ParseIntError> for SyntaxError {
	fn from(value: ParseIntError) -> Self {
		SyntaxError {
			msg: Some("Failed to parse string as integer".into()),
			line: None,
			col: None,
			cause: Some(Box::new(value)),
		}
	}
}

impl From<ParseFloatError> for SyntaxError {
	fn from(value: ParseFloatError) -> Self {
		SyntaxError {
			msg: Some("Failed to parse string as decimal number".into()),
			line: None,
			col: None,
			cause: Some(Box::new(value)),
		}
	}
}

impl Debug for SyntaxError {
	fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
		self.print(f)
	}
}

impl core::fmt::Display for SyntaxError {
	fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
		self.print(f)
	}
}

impl Error for SyntaxError {
	/// Exposes the wrapped `cause` (if any) so error-chain walkers can reach it.
	fn source(&self) -> Option<&(dyn Error + 'static)> {
		self.cause.as_deref()
	}
}
397
398
399/// Creates a new random number generator (RNG) from the provided seed using the default
400/// [rand crate](https://crates.io/crates/rand) `rand::rngs::StdRng` RNG
401/// # Parameters
402/// * `seed`: A 64-bit number to use as a seed
403pub fn simple_rng(seed: u64) -> StdRng {
404	StdRng::seed_from_u64(seed)
405}
406
407/// Creates a new random number generator (RNG) from the provided seed using the default
408/// [rand crate](https://crates.io/crates/rand) `rand::rngs::StdRng` RNG, using the current system
409/// millisecond timestamp as the RNG seed
410pub fn new_simple_rng() -> StdRng {
411	use std::time::{SystemTime, UNIX_EPOCH};
412	let time_seed = SystemTime::now().duration_since(UNIX_EPOCH)
413		.expect("Invalid system time").as_millis() as u64;
414	simple_rng(time_seed)
415}
416
417
418fn find_enclosure_from(text: &str, pos: usize, open: char, close: char) -> Result<Option<(usize, usize)>, SyntaxError> {
419	let mut depth = 0;
420	let slice = &text[pos..];
421	let mut start_index = 0;
422	for (i, c) in slice.char_indices() {
423		if c == open {
424			if depth == 0 {
425				start_index = i + pos;
426			}
427			depth += 1;
428		} else if c == close {
429			depth -= 1;
430			if depth == 0 {
431				return Ok(Some((start_index, pos+i+1)))
432			}
433		}
434	}
435	if depth > 0 {
436		return Err(SyntaxError::from("Found '(' without matching ')'"));
437	}
438	return Ok(None);
439}
440
441fn find_operator_params(text: &str, op_pos: usize) -> Result<(usize, usize), SyntaxError> {
442	#[cfg(test)]
443	eprintln!("'{}' '{}' '{}'", &text[0..op_pos], &text[op_pos..op_pos+1], &text[op_pos+1..]);
444	let front_slice = &text[0..op_pos];
445	let back_slice = &text[op_pos+1..];
446	let mut end = text.len();
447	for (i, c) in back_slice.char_indices() {
448		if !(c.is_digit(10) || c == '.' || c == '-') {end = op_pos+1+i; break;}
449	}
450	let mut start = 0;
451	for (i, c) in front_slice.char_indices().rev() {
452		if !(c.is_digit(10) || c == '.' || c == '-') {start = i+1; break;}
453	}
454	if start == op_pos || end == op_pos+1 {
455		return Err(SyntaxError::from_string(format!("Missing numbers before or after operator {}", &text[op_pos..op_pos+1])));
456	}
457	Ok((start, end))
458}
459
/// Returns the byte index of the first character in `text` that is one of `chars`,
/// or `None` if there is no such character.
fn find_one_of(text: &str, chars: &[char]) -> Option<usize> {
	// A `&[char]` is a `str` pattern matching any of its characters, so no per-call
	// set allocation is needed.
	text.find(chars)
}
470
#[cfg(test)]
mod unit_tests {
	use crate::{DiceBag, new_simple_rng, simple_rng};

	#[test]
	fn arithmetic_checks() {
		let mut dice = DiceBag::new(simple_rng(42));
		assert_eq!(dice.eval_total("1+2").unwrap(), 3);
		assert_eq!(dice.eval_total("-1+2").unwrap(), 1);
		assert_eq!(dice.eval_total("(1+2)x3").unwrap(), 9);
		assert_eq!(dice.eval_total("-3*(1+2)").unwrap(), -9);
		assert_eq!(dice.eval_total("7-(2-5)").unwrap(), 10);
		assert_eq!(dice.eval_total("-1+10").unwrap(), 9);
		assert_eq!(dice.eval_total("7/2").unwrap(), 3); // integer math
		assert_eq!(dice.eval_total("-7/2").unwrap(), -3); // integer math
		assert_eq!(dice.eval_total("7/-2").unwrap(), -3); // integer math
		assert_eq!(dice.eval_total("-7/-2").unwrap(), 3); // integer math
		assert_eq!(dice.eval_total("7*2").unwrap(), 14);
		assert_eq!(dice.eval_total("7*-2").unwrap(), -14);
		assert_eq!(dice.eval_total("-7*2").unwrap(), -14);
		assert_eq!(dice.eval_total("-7*-2").unwrap(), 14);
		assert_eq!(dice.eval_total("8+5-9-9+5+8").unwrap(), 8);
		assert_eq!(dice.eval_total("15/5*5/-3*2/2*6-10/5").unwrap(), -32);
		assert_eq!(dice.eval_total("(15/5*5/-3*2/2*6-10/5)").unwrap(), -32);
		assert_eq!(dice.eval_total("4(9(10/2-6-3*8+1*4/2)*8/2*5+4)*5+4(7+7-3*8)*3-10*(10)-1").unwrap(), -82941);
	}

	#[test]
	fn dice_checks() {
		let mut dice = DiceBag::new(simple_rng(42));
		assert!(dice.eval_total("1d20-30").unwrap() <= -10);
		assert!(dice.eval_total("1d20-30").unwrap() > -30);
		assert!(dice.eval_total("1d20+30").unwrap() <= 50);
		assert!(dice.eval_total("1d20+30").unwrap() > 30);
		let roll = dice.eval("3d6").unwrap();
		assert_eq!(roll.max, 18);
		assert_eq!(roll.min, 3);
		assert_eq!(roll.average, 3.5 * 3.);
	}

	#[test]
	fn totals_stay_within_bounds() {
		let mut dice = DiceBag::new(simple_rng(7));
		for _ in 0..100 {
			let roll = dice.eval("2d10+6").unwrap();
			assert_eq!((roll.min, roll.max), (8, 26));
			assert_eq!(roll.average, 17.0);
			assert!(roll.min <= roll.total && roll.total <= roll.max);
		}
	}

	#[test]
	#[cfg(feature = "serde_support")]
	fn serde_test() {
		use crate::DiceRoll;
		let mut dice = DiceBag::new(simple_rng(42));
		let roll = dice.eval("2d10+6").unwrap();
		// DiceRoll is Copy, so serializing a reference leaves `roll` usable for comparison
		let json_str = serde_json::to_string(&roll).unwrap();
		let serde_roll: DiceRoll = serde_json::from_str(&json_str).unwrap();
		assert_eq!(roll, serde_roll);
	}

	#[test]
	fn example1() {
		let mut dice_bag = DiceBag::new(new_simple_rng());
		println!("What would you like to roll? ");
		// Stands in for a line read from stdin, trailing newline included
		let input = String::from("3d6\n");
		let dice_roll = dice_bag.eval(input.as_str()).expect("invalid dice expression");
		println!("You rolled a {}", dice_roll);
		if dice_roll.total >= dice_roll.average as i64 {
			println!("That's a good roll!");
		} else {
			println!("That's not a good roll :(");
		}
	}

	#[test]
	fn example2() {
		let mut dice_bag = DiceBag::new(new_simple_rng());
		let armory = [
			("great axe", "1d12"),
			("great sword", "2d6"),
			("heavy crossbow", "1d10+2"),
			("firebolt", "1d10"),
			("magic missile", "3d4+3"),
		];
		println!("Average Damage:");
		for (name, dmg) in armory {
			println!("{}\t{}", dice_bag.eval_ave(dmg).unwrap(), name)
		}
	}
}