1use crate::choices::{Choice, Choices, MoveCategory, MOVES};
2use crate::engine::evaluate::evaluate;
3use crate::engine::generate_instructions::{
4 calculate_both_damage_rolls, generate_instructions_from_move_pair,
5};
6use crate::engine::state::MoveChoice;
7use crate::instruction::{Instruction, StateInstructions};
8use crate::mcts::{perform_mcts, MctsResult};
9use crate::search::{expectiminimax_search, iterative_deepen_expectiminimax, pick_safest};
10use crate::state::State;
11use clap::Parser;
12use std::io;
13use std::io::Write;
14use std::process::exit;
15use std::str::FromStr;
16use std::sync::{Arc, Mutex};
17
18struct IOData {
19 state: State,
20 instruction_list: Vec<Vec<Instruction>>,
21 last_instructions_generated: Vec<StateInstructions>,
22}
23
24#[derive(Parser)]
25struct Cli {
26 #[clap(short, long, default_value = "")]
27 state: String,
28
29 #[clap(subcommand)]
30 subcmd: Option<SubCommand>,
31}
32
// One-shot operations selectable on the command line; `main` dispatches on the variant.
// Variant names determine the subcommand names clap accepts, so renaming one changes the CLI.
// (Plain `//` comments on purpose: clap turns `///` doc comments into `--help` text.)
#[derive(Parser)]
enum SubCommand {
    // Runs `expectiminimax_search` at a fixed depth and prints via `print_subcommand_result`.
    Expectiminimax(Expectiminimax),
    // Runs `iterative_deepen_expectiminimax` for a time budget and prints the matrix/choice.
    IterativeDeepening(IterativeDeepening),
    // Runs `perform_mcts` for a time budget and prints via `print_mcts_result`.
    MonteCarloTreeSearch(MonteCarloTreeSearch),
    // Prints the damage rolls for one move per side via `calculate_damage_io`.
    CalculateDamage(CalculateDamage),
    // Prints the instruction branches produced for one move choice per side.
    GenerateInstructions(GenerateInstructions),
}
41
42#[derive(Parser)]
43struct Expectiminimax {
44 #[clap(short, long, required = true)]
45 state: String,
46
47 #[clap(short, long, default_value_t = false)]
48 ab_prune: bool,
49
50 #[clap(short, long, default_value_t = 2)]
51 depth: i8,
52}
53
54#[derive(Parser)]
55struct IterativeDeepening {
56 #[clap(short, long, required = true)]
57 state: String,
58
59 #[clap(short, long, default_value_t = 5000)]
60 time_to_search_ms: u64,
61}
62
63#[derive(Parser)]
64struct MonteCarloTreeSearch {
65 #[clap(short, long, required = true)]
66 state: String,
67
68 #[clap(short, long, default_value_t = 5000)]
69 time_to_search_ms: u64,
70}
71
72#[derive(Parser)]
73struct CalculateDamage {
74 #[clap(short, long, required = true)]
75 state: String,
76
77 #[clap(short = 'o', long, required = true)]
78 side_one_move: String,
79
80 #[clap(short = 't', long, required = true)]
81 side_two_move: String,
82
83 #[clap(short = 'm', long, required = false, default_value_t = false)]
84 side_one_moves_first: bool,
85}
86
87#[derive(Parser)]
88struct GenerateInstructions {
89 #[clap(short, long, required = true)]
90 state: String,
91
92 #[clap(short = 'o', long, required = true)]
93 side_one_move: String,
94
95 #[clap(short = 't', long, required = true)]
96 side_two_move: String,
97}
98
99impl Default for IOData {
100 fn default() -> Self {
101 IOData {
102 state: State::default(),
103 instruction_list: Vec::new(),
104 last_instructions_generated: Vec::new(),
105 }
106 }
107}
108
109fn pprint_expectiminimax_result(
110 result: &Vec<f32>,
111 s1_options: &Vec<MoveChoice>,
112 s2_options: &Vec<MoveChoice>,
113 safest_choice: &(usize, f32),
114 state: &State,
115) {
116 let s1_len = s1_options.len();
117 let s2_len = s2_options.len();
118
119 print!("{: <12}", " ");
120
121 for s2_move in s2_options.iter() {
122 print!("{: >12}", s2_move.to_string(&state.side_two));
123 }
124 print!("\n");
125
126 for i in 0..s1_len {
127 let s1_move_str = s1_options[i];
128 print!("{:<12}", s1_move_str.to_string(&state.side_one));
129 for j in 0..s2_len {
130 let index = i * s2_len + j;
131 print!("{number:>11.2} ", number = result[index]);
132 }
133 print!("\n");
134 }
135 print!(
136 "{:<12}",
137 s1_options[safest_choice.0].to_string(&state.side_one)
138 );
139}
140
141fn print_mcts_result(state: &State, result: MctsResult) {
142 let s1_joined_options = result
143 .s1
144 .iter()
145 .map(|x| {
146 format!(
147 "{},{:.2},{}",
148 x.move_choice.to_string(&state.side_one),
149 x.total_score,
150 x.visits
151 )
152 })
153 .collect::<Vec<String>>()
154 .join("|");
155 let s2_joined_options = result
156 .s2
157 .iter()
158 .map(|x| {
159 format!(
160 "{},{:.2},{}",
161 x.move_choice.to_string(&state.side_two),
162 x.total_score,
163 x.visits
164 )
165 })
166 .collect::<Vec<String>>()
167 .join("|");
168
169 println!("Total Iterations: {}", result.iteration_count);
170 println!("side one: {}", s1_joined_options);
171 println!("side two: {}", s2_joined_options);
172}
173
174fn pprint_mcts_result(state: &State, result: MctsResult) {
175 println!("\nTotal Iterations: {}\n", result.iteration_count);
176 println!("Side One:");
177 println!(
178 "\t{:<25}{:>12}{:>12}{:>10}{:>10}",
179 "Move", "Total Score", "Avg Score", "Visits", "% Visits"
180 );
181 for x in result.s1.iter() {
182 println!(
183 "\t{:<25}{:>12.2}{:>12.2}{:>10}{:>10.2}",
184 x.move_choice.to_string(&state.side_one),
185 x.total_score,
186 x.total_score / x.visits as f32,
187 x.visits,
188 (x.visits as f32 / result.iteration_count as f32) * 100.0
189 );
190 }
191
192 println!("Side Two:");
193 println!(
194 "\t{:<25}{:>12}{:>12}{:>10}{:>10}",
195 "Move", "Total Score", "Avg Score", "Visits", "% Visits"
196 );
197 for x in result.s2.iter() {
198 println!(
199 "\t{:<25}{:>12.2}{:>12.2}{:>10}{:>10.2}",
200 x.move_choice.to_string(&state.side_two),
201 x.total_score,
202 x.total_score / x.visits as f32,
203 x.visits,
204 (x.visits as f32 / result.iteration_count as f32) * 100.0
205 );
206 }
207}
208
209fn pprint_state_instruction_vector(instructions: &Vec<StateInstructions>) {
210 for (i, instruction) in instructions.iter().enumerate() {
211 println!("Index: {}", i);
212 println!("StateInstruction: {:?}", instruction);
213 }
214}
215
216fn print_subcommand_result(
217 result: &Vec<f32>,
218 side_one_options: &Vec<MoveChoice>,
219 side_two_options: &Vec<MoveChoice>,
220 state: &State,
221) {
222 let safest = pick_safest(&result, side_one_options.len(), side_two_options.len());
223 let move_choice = side_one_options[safest.0];
224
225 let joined_side_one_options = side_one_options
226 .iter()
227 .map(|x| format!("{}", x.to_string(&state.side_one)))
228 .collect::<Vec<String>>()
229 .join(",");
230 println!("side one options: {}", joined_side_one_options);
231
232 let joined_side_two_options = side_two_options
233 .iter()
234 .map(|x| format!("{}", x.to_string(&state.side_two)))
235 .collect::<Vec<String>>()
236 .join(",");
237 println!("side two options: {}", joined_side_two_options);
238
239 let joined = result
240 .iter()
241 .map(|x| format!("{:.2}", x))
242 .collect::<Vec<String>>()
243 .join(",");
244 println!("matrix: {}", joined);
245 println!("choice: {}", move_choice.to_string(&state.side_one));
246 println!("evaluation: {}", safest.1);
247}
248
249pub fn main() {
250 let args = Cli::parse();
251 let mut io_data = IOData::default();
252
253 if args.state != "" {
254 let state = State::deserialize(args.state.as_str());
255 io_data.state = state;
256 }
257
258 let result;
259 let mut state;
260 let mut side_one_options;
261 let mut side_two_options;
262 match args.subcmd {
263 None => {
264 command_loop(io_data);
265 exit(0);
266 }
267 Some(subcmd) => match subcmd {
268 SubCommand::Expectiminimax(expectiminimax) => {
269 state = State::deserialize(expectiminimax.state.as_str());
270 (side_one_options, side_two_options) = state.root_get_all_options();
271 result = expectiminimax_search(
272 &mut state,
273 expectiminimax.depth,
274 side_one_options.clone(),
275 side_two_options.clone(),
276 expectiminimax.ab_prune,
277 &Arc::new(Mutex::new(true)),
278 );
279 print_subcommand_result(&result, &side_one_options, &side_two_options, &state);
280 }
281 SubCommand::IterativeDeepening(iterative_deepending) => {
282 state = State::deserialize(iterative_deepending.state.as_str());
283 (side_one_options, side_two_options) = state.root_get_all_options();
284 (side_one_options, side_two_options, result, _) = iterative_deepen_expectiminimax(
285 &mut state,
286 side_one_options.clone(),
287 side_two_options.clone(),
288 std::time::Duration::from_millis(iterative_deepending.time_to_search_ms),
289 );
290 print_subcommand_result(&result, &side_one_options, &side_two_options, &state);
291 }
292 SubCommand::MonteCarloTreeSearch(mcts) => {
293 state = State::deserialize(mcts.state.as_str());
294 (side_one_options, side_two_options) = state.root_get_all_options();
295 let result = perform_mcts(
296 &mut state,
297 side_one_options.clone(),
298 side_two_options.clone(),
299 std::time::Duration::from_millis(mcts.time_to_search_ms),
300 );
301 print_mcts_result(&state, result);
302 }
303 SubCommand::CalculateDamage(calculate_damage) => {
304 state = State::deserialize(calculate_damage.state.as_str());
305 let mut s1_choice = MOVES
306 .get(&Choices::from_str(calculate_damage.side_one_move.as_str()).unwrap())
307 .unwrap()
308 .to_owned();
309 let mut s2_choice = MOVES
310 .get(&Choices::from_str(calculate_damage.side_two_move.as_str()).unwrap())
311 .unwrap()
312 .to_owned();
313 let s1_moves_first = calculate_damage.side_one_moves_first;
314 if calculate_damage.side_one_move == "switch" {
315 s1_choice.category = MoveCategory::Switch
316 }
317 if calculate_damage.side_two_move == "switch" {
318 s2_choice.category = MoveCategory::Switch
319 }
320 calculate_damage_io(&state, s1_choice, s2_choice, s1_moves_first);
321 }
322 SubCommand::GenerateInstructions(generate_instructions) => {
323 state = State::deserialize(generate_instructions.state.as_str());
324 let (s1_movechoice, s2_movechoice);
325 match MoveChoice::from_string(
326 generate_instructions.side_one_move.as_str(),
327 &state.side_one,
328 ) {
329 None => {
330 println!(
331 "Invalid move choice for side one: {}",
332 generate_instructions.side_one_move
333 );
334 exit(1);
335 }
336 Some(v) => s1_movechoice = v,
337 }
338 match MoveChoice::from_string(
339 generate_instructions.side_two_move.as_str(),
340 &state.side_two,
341 ) {
342 None => {
343 println!(
344 "Invalid move choice for side two: {}",
345 generate_instructions.side_two_move
346 );
347 exit(1);
348 }
349 Some(v) => s2_movechoice = v,
350 }
351 let instructions = generate_instructions_from_move_pair(
352 &mut state,
353 &s1_movechoice,
354 &s2_movechoice,
355 true,
356 );
357 pprint_state_instruction_vector(&instructions);
358 }
359 },
360 }
361
362 exit(0);
363}
364
365fn calculate_damage_io(
366 state: &State,
367 s1_choice: Choice,
368 s2_choice: Choice,
369 side_one_moves_first: bool,
370) {
371 let (damages_dealt_s1, damages_dealt_s2) =
372 calculate_both_damage_rolls(state, s1_choice, s2_choice, side_one_moves_first);
373
374 for dmg in [damages_dealt_s1, damages_dealt_s2] {
375 match dmg {
376 Some(damages_vec) => {
377 let joined = damages_vec
378 .iter()
379 .map(|x| format!("{:?}", x))
380 .collect::<Vec<String>>()
381 .join(",");
382 println!("Damage Rolls: {}", joined);
383 }
384 None => {
385 println!("Damage Rolls: 0");
386 }
387 }
388 }
389}
390
391fn command_loop(mut io_data: IOData) {
392 loop {
393 print!("> ");
394 let _ = io::stdout().flush();
395
396 let mut input = String::new();
397 match io::stdin().read_line(&mut input) {
398 Ok(_) => {}
399 Err(error) => {
400 println!("Error reading input: {}", error);
401 continue;
402 }
403 }
404 let mut parts = input.trim().split_whitespace();
405 let command = parts.next().unwrap_or("");
406 let mut args = parts;
407
408 match command {
409 "state" | "s" => {
410 let state_string;
411 match args.next() {
412 Some(s) => {
413 state_string = s;
414 let state = State::deserialize(state_string);
415 io_data.state = state;
416 println!("state initialized");
417 }
418 None => {
419 println!("Expected state string");
420 }
421 }
422 println!("{:?}", io_data.state);
423 }
424 "serialize" | "ser" => {
425 println!("{}", io_data.state.serialize());
426 }
427 "matchup" | "m" => {
428 println!("{}", io_data.state.pprint());
429 }
430 "generate-instructions" | "g" => {
431 let (s1_move, s2_move);
432 match args.next() {
433 Some(s) => match MoveChoice::from_string(s, &io_data.state.side_one) {
434 Some(m) => {
435 s1_move = m;
436 }
437 None => {
438 println!("Invalid move choice for side one: {}", s);
439 continue;
440 }
441 },
442 None => {
443 println!("Usage: generate-instructions <side-1 move> <side-2 move>");
444 continue;
445 }
446 }
447 match args.next() {
448 Some(s) => match MoveChoice::from_string(s, &io_data.state.side_two) {
449 Some(m) => {
450 s2_move = m;
451 }
452 None => {
453 println!("Invalid move choice for side two: {}", s);
454 continue;
455 }
456 },
457 None => {
458 println!("Usage: generate-instructions <side-1 choice> <side-2 choice>");
459 continue;
460 }
461 }
462 let instructions = generate_instructions_from_move_pair(
463 &mut io_data.state,
464 &s1_move,
465 &s2_move,
466 true,
467 );
468 pprint_state_instruction_vector(&instructions);
469 io_data.last_instructions_generated = instructions;
470 }
471 "calculate-damage" | "d" => {
472 let (mut s1_choice, mut s2_choice);
473 match args.next() {
474 Some(s) => {
475 s1_choice = MOVES
476 .get(&Choices::from_str(s).unwrap())
477 .unwrap()
478 .to_owned();
479 if s == "switch" {
480 s1_choice.category = MoveCategory::Switch
481 }
482 }
483 None => {
484 println!("Usage: calculate-damage <side-1 move> <side-2 move> <side-1-moves-first>");
485 continue;
486 }
487 }
488 match args.next() {
489 Some(s) => {
490 s2_choice = MOVES
491 .get(&Choices::from_str(s).unwrap())
492 .unwrap()
493 .to_owned();
494 if s == "switch" {
495 s2_choice.category = MoveCategory::Switch
496 }
497 }
498 None => {
499 println!("Usage: calculate-damage <side-1 move> <side-2 move> <side-1-moves-first>");
500 continue;
501 }
502 }
503 let s1_moves_first: bool;
504 match args.next() {
505 Some(s) => {
506 s1_moves_first = s.parse::<bool>().unwrap();
507 }
508 None => {
509 println!("Usage: calculate-damage <side-1 move> <side-2 move> <side-1-moves-first>");
510 continue;
511 }
512 }
513 calculate_damage_io(&io_data.state, s1_choice, s2_choice, s1_moves_first);
514 }
515 "instructions" | "i" => {
516 println!("{:?}", io_data.last_instructions_generated);
517 }
518 "evaluate" | "ev" => {
519 println!("Evaluation: {}", evaluate(&io_data.state));
520 }
521 "iterative-deepening" | "id" => match args.next() {
522 Some(s) => {
523 let max_time_ms = s.parse::<u64>().unwrap();
524 let (side_one_options, side_two_options) = io_data.state.root_get_all_options();
525
526 let start_time = std::time::Instant::now();
527 let (s1_moves, s2_moves, result, depth_searched) =
528 iterative_deepen_expectiminimax(
529 &mut io_data.state,
530 side_one_options.clone(),
531 side_two_options.clone(),
532 std::time::Duration::from_millis(max_time_ms),
533 );
534 let elapsed = start_time.elapsed();
535
536 let safest_choice = pick_safest(&result, s1_moves.len(), s2_moves.len());
537
538 pprint_expectiminimax_result(
539 &result,
540 &s1_moves,
541 &s2_moves,
542 &safest_choice,
543 &io_data.state,
544 );
545 println!("Took: {:?}", elapsed);
546 println!("Depth Searched: {}", depth_searched);
547 }
548 None => {
549 println!("Usage: iterative-deepening <timeout_ms>");
550 continue;
551 }
552 },
553 "monte-carlo-tree-search" | "mcts" => match args.next() {
554 Some(s) => {
555 let max_time_ms = s.parse::<u64>().unwrap();
556 let (side_one_options, side_two_options) = io_data.state.root_get_all_options();
557
558 let start_time = std::time::Instant::now();
559 let result = perform_mcts(
560 &mut io_data.state,
561 side_one_options.clone(),
562 side_two_options.clone(),
563 std::time::Duration::from_millis(max_time_ms),
564 );
565 let elapsed = start_time.elapsed();
566 pprint_mcts_result(&io_data.state, result);
567
568 println!("\nTook: {:?}", elapsed);
569 }
570 None => {
571 println!("Usage: monte-carlo-tree-search <timeout_ms>");
572 continue;
573 }
574 },
575 "apply" | "a" => match args.next() {
576 Some(s) => {
577 let index = s.parse::<usize>().unwrap();
578 let instructions = io_data.last_instructions_generated.remove(index);
579 io_data
580 .state
581 .apply_instructions(&instructions.instruction_list);
582 io_data.instruction_list.push(instructions.instruction_list);
583 io_data.last_instructions_generated = Vec::new();
584 println!("Applied instructions at index {}", index)
585 }
586 None => {
587 println!("Usage: apply <instruction index>");
588 continue;
589 }
590 },
591 "pop" | "p" => {
592 if io_data.instruction_list.is_empty() {
593 println!("No instructions to pop");
594 continue;
595 }
596 let instructions = io_data.instruction_list.pop().unwrap();
597 io_data.state.reverse_instructions(&instructions);
598 println!("Popped last applied instructions");
599 }
600 "pop-all" | "pa" => {
601 for i in io_data.instruction_list.iter().rev() {
602 io_data.state.reverse_instructions(i);
603 }
604 io_data.instruction_list.clear();
605 println!("Popped all applied instructions");
606 }
607 "expectiminimax" | "e" => match args.next() {
608 Some(s) => {
609 let mut ab_prune = false;
610 match args.next() {
611 Some(s) => ab_prune = s.parse::<bool>().unwrap(),
612 None => {}
613 }
614 let depth = s.parse::<i8>().unwrap();
615 let (side_one_options, side_two_options) = io_data.state.root_get_all_options();
616 let start_time = std::time::Instant::now();
617 let result = expectiminimax_search(
618 &mut io_data.state,
619 depth,
620 side_one_options.clone(),
621 side_two_options.clone(),
622 ab_prune,
623 &Arc::new(Mutex::new(true)),
624 );
625 let elapsed = start_time.elapsed();
626
627 let safest_choice =
628 pick_safest(&result, side_one_options.len(), side_two_options.len());
629 pprint_expectiminimax_result(
630 &result,
631 &side_one_options,
632 &side_two_options,
633 &safest_choice,
634 &io_data.state,
635 );
636 println!("\nTook: {:?}", elapsed);
637 }
638 None => {
639 println!("Usage: expectiminimax <depth> <ab_prune=false>");
640 continue;
641 }
642 },
643 "" => {
644 continue;
645 }
646 "exit" | "quit" | "q" => {
647 break;
648 }
649 command => {
650 println!("Unknown command: {}", command);
651 }
652 }
653 }
654}