poke_engine/
io.rs

1use crate::choices::{Choice, Choices, MoveCategory, MOVES};
2use crate::engine::evaluate::evaluate;
3use crate::engine::generate_instructions::{
4    calculate_both_damage_rolls, generate_instructions_from_move_pair,
5};
6use crate::engine::state::MoveChoice;
7use crate::instruction::{Instruction, StateInstructions};
8use crate::mcts::{perform_mcts, MctsResult};
9use crate::search::{expectiminimax_search, iterative_deepen_expectiminimax, pick_safest};
10use crate::state::State;
11use clap::Parser;
12use std::io;
13use std::io::Write;
14use std::process::exit;
15use std::str::FromStr;
16use std::sync::{Arc, Mutex};
17
/// Mutable session data for the interactive command loop.
struct IOData {
    /// Working battle state that REPL commands read and mutate.
    state: State,
    /// Stack of instruction lists applied via `apply`; `pop` / `pop-all`
    /// reverse them, most recently applied first.
    instruction_list: Vec<Vec<Instruction>>,
    /// Output of the most recent `generate-instructions` command. `apply`
    /// takes one entry by index and then clears this list.
    last_instructions_generated: Vec<StateInstructions>,
}
23
// Top-level command-line arguments.
// NOTE: plain `//` comments are used deliberately throughout the clap types:
// clap's derive turns `///` doc comments into `--help` text.
#[derive(Parser)]
struct Cli {
    // Optional serialized state. When non-empty, `main` deserializes it as the
    // starting state of the interactive loop; the subcommands do not use it
    // and read their own `--state` argument instead.
    #[clap(short, long, default_value = "")]
    state: String,

    // One-shot subcommand to run; when absent, `main` starts the interactive loop.
    #[clap(subcommand)]
    subcmd: Option<SubCommand>,
}
32
// One-shot operations selectable on the command line. Each variant wraps the
// argument struct for that subcommand; `main` dispatches on the variant.
#[derive(Parser)]
enum SubCommand {
    Expectiminimax(Expectiminimax),
    IterativeDeepening(IterativeDeepening),
    MonteCarloTreeSearch(MonteCarloTreeSearch),
    CalculateDamage(CalculateDamage),
    GenerateInstructions(GenerateInstructions),
}
41
// Arguments for a fixed-depth `expectiminimax_search` run.
#[derive(Parser)]
struct Expectiminimax {
    // Serialized state to search from (passed to `State::deserialize`).
    #[clap(short, long, required = true)]
    state: String,

    // Forwarded as the alpha-beta pruning flag of `expectiminimax_search`.
    #[clap(short, long, default_value_t = false)]
    ab_prune: bool,

    // Search depth forwarded to `expectiminimax_search`.
    #[clap(short, long, default_value_t = 2)]
    depth: i8,
}
53
// Arguments for a time-limited `iterative_deepen_expectiminimax` run.
#[derive(Parser)]
struct IterativeDeepening {
    // Serialized state to search from (passed to `State::deserialize`).
    #[clap(short, long, required = true)]
    state: String,

    // Search time budget in milliseconds (converted with `Duration::from_millis`).
    #[clap(short, long, default_value_t = 5000)]
    time_to_search_ms: u64,
}
62
// Arguments for a time-limited `perform_mcts` run.
#[derive(Parser)]
struct MonteCarloTreeSearch {
    // Serialized state to search from (passed to `State::deserialize`).
    #[clap(short, long, required = true)]
    state: String,

    // Search time budget in milliseconds (converted with `Duration::from_millis`).
    #[clap(short, long, default_value_t = 5000)]
    time_to_search_ms: u64,
}
71
// Arguments for printing both sides' damage rolls for a pair of moves.
#[derive(Parser)]
struct CalculateDamage {
    // Serialized state to evaluate (passed to `State::deserialize`).
    #[clap(short, long, required = true)]
    state: String,

    // Side one's move name, parsed with `Choices::from_str` and looked up in
    // `MOVES`; the literal "switch" is re-categorised as a switch.
    #[clap(short = 'o', long, required = true)]
    side_one_move: String,

    // Side two's move name; parsed the same way as `side_one_move`.
    #[clap(short = 't', long, required = true)]
    side_two_move: String,

    // Forwarded to `calculate_damage_io` as the turn-order flag.
    #[clap(short = 'm', long, required = false, default_value_t = false)]
    side_one_moves_first: bool,
}
86
// Arguments for printing the instructions generated by one pair of moves.
#[derive(Parser)]
struct GenerateInstructions {
    // Serialized state to generate from (passed to `State::deserialize`).
    #[clap(short, long, required = true)]
    state: String,

    // Side one's choice, parsed with `MoveChoice::from_string` against side one.
    #[clap(short = 'o', long, required = true)]
    side_one_move: String,

    // Side two's choice, parsed with `MoveChoice::from_string` against side two.
    #[clap(short = 't', long, required = true)]
    side_two_move: String,
}
98
99impl Default for IOData {
100    fn default() -> Self {
101        IOData {
102            state: State::default(),
103            instruction_list: Vec::new(),
104            last_instructions_generated: Vec::new(),
105        }
106    }
107}
108
109fn pprint_expectiminimax_result(
110    result: &Vec<f32>,
111    s1_options: &Vec<MoveChoice>,
112    s2_options: &Vec<MoveChoice>,
113    safest_choice: &(usize, f32),
114    state: &State,
115) {
116    let s1_len = s1_options.len();
117    let s2_len = s2_options.len();
118
119    print!("{: <12}", " ");
120
121    for s2_move in s2_options.iter() {
122        print!("{: >12}", s2_move.to_string(&state.side_two));
123    }
124    print!("\n");
125
126    for i in 0..s1_len {
127        let s1_move_str = s1_options[i];
128        print!("{:<12}", s1_move_str.to_string(&state.side_one));
129        for j in 0..s2_len {
130            let index = i * s2_len + j;
131            print!("{number:>11.2} ", number = result[index]);
132        }
133        print!("\n");
134    }
135    print!(
136        "{:<12}",
137        s1_options[safest_choice.0].to_string(&state.side_one)
138    );
139}
140
141fn print_mcts_result(state: &State, result: MctsResult) {
142    let s1_joined_options = result
143        .s1
144        .iter()
145        .map(|x| {
146            format!(
147                "{},{:.2},{}",
148                x.move_choice.to_string(&state.side_one),
149                x.total_score,
150                x.visits
151            )
152        })
153        .collect::<Vec<String>>()
154        .join("|");
155    let s2_joined_options = result
156        .s2
157        .iter()
158        .map(|x| {
159            format!(
160                "{},{:.2},{}",
161                x.move_choice.to_string(&state.side_two),
162                x.total_score,
163                x.visits
164            )
165        })
166        .collect::<Vec<String>>()
167        .join("|");
168
169    println!("Total Iterations: {}", result.iteration_count);
170    println!("side one: {}", s1_joined_options);
171    println!("side two: {}", s2_joined_options);
172}
173
174fn pprint_mcts_result(state: &State, result: MctsResult) {
175    println!("\nTotal Iterations: {}\n", result.iteration_count);
176    println!("Side One:");
177    println!(
178        "\t{:<25}{:>12}{:>12}{:>10}{:>10}",
179        "Move", "Total Score", "Avg Score", "Visits", "% Visits"
180    );
181    for x in result.s1.iter() {
182        println!(
183            "\t{:<25}{:>12.2}{:>12.2}{:>10}{:>10.2}",
184            x.move_choice.to_string(&state.side_one),
185            x.total_score,
186            x.total_score / x.visits as f32,
187            x.visits,
188            (x.visits as f32 / result.iteration_count as f32) * 100.0
189        );
190    }
191
192    println!("Side Two:");
193    println!(
194        "\t{:<25}{:>12}{:>12}{:>10}{:>10}",
195        "Move", "Total Score", "Avg Score", "Visits", "% Visits"
196    );
197    for x in result.s2.iter() {
198        println!(
199            "\t{:<25}{:>12.2}{:>12.2}{:>10}{:>10.2}",
200            x.move_choice.to_string(&state.side_two),
201            x.total_score,
202            x.total_score / x.visits as f32,
203            x.visits,
204            (x.visits as f32 / result.iteration_count as f32) * 100.0
205        );
206    }
207}
208
209fn pprint_state_instruction_vector(instructions: &Vec<StateInstructions>) {
210    for (i, instruction) in instructions.iter().enumerate() {
211        println!("Index: {}", i);
212        println!("StateInstruction: {:?}", instruction);
213    }
214}
215
216fn print_subcommand_result(
217    result: &Vec<f32>,
218    side_one_options: &Vec<MoveChoice>,
219    side_two_options: &Vec<MoveChoice>,
220    state: &State,
221) {
222    let safest = pick_safest(&result, side_one_options.len(), side_two_options.len());
223    let move_choice = side_one_options[safest.0];
224
225    let joined_side_one_options = side_one_options
226        .iter()
227        .map(|x| format!("{}", x.to_string(&state.side_one)))
228        .collect::<Vec<String>>()
229        .join(",");
230    println!("side one options: {}", joined_side_one_options);
231
232    let joined_side_two_options = side_two_options
233        .iter()
234        .map(|x| format!("{}", x.to_string(&state.side_two)))
235        .collect::<Vec<String>>()
236        .join(",");
237    println!("side two options: {}", joined_side_two_options);
238
239    let joined = result
240        .iter()
241        .map(|x| format!("{:.2}", x))
242        .collect::<Vec<String>>()
243        .join(",");
244    println!("matrix: {}", joined);
245    println!("choice: {}", move_choice.to_string(&state.side_one));
246    println!("evaluation: {}", safest.1);
247}
248
249pub fn main() {
250    let args = Cli::parse();
251    let mut io_data = IOData::default();
252
253    if args.state != "" {
254        let state = State::deserialize(args.state.as_str());
255        io_data.state = state;
256    }
257
258    let result;
259    let mut state;
260    let mut side_one_options;
261    let mut side_two_options;
262    match args.subcmd {
263        None => {
264            command_loop(io_data);
265            exit(0);
266        }
267        Some(subcmd) => match subcmd {
268            SubCommand::Expectiminimax(expectiminimax) => {
269                state = State::deserialize(expectiminimax.state.as_str());
270                (side_one_options, side_two_options) = state.root_get_all_options();
271                result = expectiminimax_search(
272                    &mut state,
273                    expectiminimax.depth,
274                    side_one_options.clone(),
275                    side_two_options.clone(),
276                    expectiminimax.ab_prune,
277                    &Arc::new(Mutex::new(true)),
278                );
279                print_subcommand_result(&result, &side_one_options, &side_two_options, &state);
280            }
281            SubCommand::IterativeDeepening(iterative_deepending) => {
282                state = State::deserialize(iterative_deepending.state.as_str());
283                (side_one_options, side_two_options) = state.root_get_all_options();
284                (side_one_options, side_two_options, result, _) = iterative_deepen_expectiminimax(
285                    &mut state,
286                    side_one_options.clone(),
287                    side_two_options.clone(),
288                    std::time::Duration::from_millis(iterative_deepending.time_to_search_ms),
289                );
290                print_subcommand_result(&result, &side_one_options, &side_two_options, &state);
291            }
292            SubCommand::MonteCarloTreeSearch(mcts) => {
293                state = State::deserialize(mcts.state.as_str());
294                (side_one_options, side_two_options) = state.root_get_all_options();
295                let result = perform_mcts(
296                    &mut state,
297                    side_one_options.clone(),
298                    side_two_options.clone(),
299                    std::time::Duration::from_millis(mcts.time_to_search_ms),
300                );
301                print_mcts_result(&state, result);
302            }
303            SubCommand::CalculateDamage(calculate_damage) => {
304                state = State::deserialize(calculate_damage.state.as_str());
305                let mut s1_choice = MOVES
306                    .get(&Choices::from_str(calculate_damage.side_one_move.as_str()).unwrap())
307                    .unwrap()
308                    .to_owned();
309                let mut s2_choice = MOVES
310                    .get(&Choices::from_str(calculate_damage.side_two_move.as_str()).unwrap())
311                    .unwrap()
312                    .to_owned();
313                let s1_moves_first = calculate_damage.side_one_moves_first;
314                if calculate_damage.side_one_move == "switch" {
315                    s1_choice.category = MoveCategory::Switch
316                }
317                if calculate_damage.side_two_move == "switch" {
318                    s2_choice.category = MoveCategory::Switch
319                }
320                calculate_damage_io(&state, s1_choice, s2_choice, s1_moves_first);
321            }
322            SubCommand::GenerateInstructions(generate_instructions) => {
323                state = State::deserialize(generate_instructions.state.as_str());
324                let (s1_movechoice, s2_movechoice);
325                match MoveChoice::from_string(
326                    generate_instructions.side_one_move.as_str(),
327                    &state.side_one,
328                ) {
329                    None => {
330                        println!(
331                            "Invalid move choice for side one: {}",
332                            generate_instructions.side_one_move
333                        );
334                        exit(1);
335                    }
336                    Some(v) => s1_movechoice = v,
337                }
338                match MoveChoice::from_string(
339                    generate_instructions.side_two_move.as_str(),
340                    &state.side_two,
341                ) {
342                    None => {
343                        println!(
344                            "Invalid move choice for side two: {}",
345                            generate_instructions.side_two_move
346                        );
347                        exit(1);
348                    }
349                    Some(v) => s2_movechoice = v,
350                }
351                let instructions = generate_instructions_from_move_pair(
352                    &mut state,
353                    &s1_movechoice,
354                    &s2_movechoice,
355                    true,
356                );
357                pprint_state_instruction_vector(&instructions);
358            }
359        },
360    }
361
362    exit(0);
363}
364
365fn calculate_damage_io(
366    state: &State,
367    s1_choice: Choice,
368    s2_choice: Choice,
369    side_one_moves_first: bool,
370) {
371    let (damages_dealt_s1, damages_dealt_s2) =
372        calculate_both_damage_rolls(state, s1_choice, s2_choice, side_one_moves_first);
373
374    for dmg in [damages_dealt_s1, damages_dealt_s2] {
375        match dmg {
376            Some(damages_vec) => {
377                let joined = damages_vec
378                    .iter()
379                    .map(|x| format!("{:?}", x))
380                    .collect::<Vec<String>>()
381                    .join(",");
382                println!("Damage Rolls: {}", joined);
383            }
384            None => {
385                println!("Damage Rolls: 0");
386            }
387        }
388    }
389}
390
391fn command_loop(mut io_data: IOData) {
392    loop {
393        print!("> ");
394        let _ = io::stdout().flush();
395
396        let mut input = String::new();
397        match io::stdin().read_line(&mut input) {
398            Ok(_) => {}
399            Err(error) => {
400                println!("Error reading input: {}", error);
401                continue;
402            }
403        }
404        let mut parts = input.trim().split_whitespace();
405        let command = parts.next().unwrap_or("");
406        let mut args = parts;
407
408        match command {
409            "state" | "s" => {
410                let state_string;
411                match args.next() {
412                    Some(s) => {
413                        state_string = s;
414                        let state = State::deserialize(state_string);
415                        io_data.state = state;
416                        println!("state initialized");
417                    }
418                    None => {
419                        println!("Expected state string");
420                    }
421                }
422                println!("{:?}", io_data.state);
423            }
424            "serialize" | "ser" => {
425                println!("{}", io_data.state.serialize());
426            }
427            "matchup" | "m" => {
428                println!("{}", io_data.state.pprint());
429            }
430            "generate-instructions" | "g" => {
431                let (s1_move, s2_move);
432                match args.next() {
433                    Some(s) => match MoveChoice::from_string(s, &io_data.state.side_one) {
434                        Some(m) => {
435                            s1_move = m;
436                        }
437                        None => {
438                            println!("Invalid move choice for side one: {}", s);
439                            continue;
440                        }
441                    },
442                    None => {
443                        println!("Usage: generate-instructions <side-1 move> <side-2 move>");
444                        continue;
445                    }
446                }
447                match args.next() {
448                    Some(s) => match MoveChoice::from_string(s, &io_data.state.side_two) {
449                        Some(m) => {
450                            s2_move = m;
451                        }
452                        None => {
453                            println!("Invalid move choice for side two: {}", s);
454                            continue;
455                        }
456                    },
457                    None => {
458                        println!("Usage: generate-instructions <side-1 choice> <side-2 choice>");
459                        continue;
460                    }
461                }
462                let instructions = generate_instructions_from_move_pair(
463                    &mut io_data.state,
464                    &s1_move,
465                    &s2_move,
466                    true,
467                );
468                pprint_state_instruction_vector(&instructions);
469                io_data.last_instructions_generated = instructions;
470            }
471            "calculate-damage" | "d" => {
472                let (mut s1_choice, mut s2_choice);
473                match args.next() {
474                    Some(s) => {
475                        s1_choice = MOVES
476                            .get(&Choices::from_str(s).unwrap())
477                            .unwrap()
478                            .to_owned();
479                        if s == "switch" {
480                            s1_choice.category = MoveCategory::Switch
481                        }
482                    }
483                    None => {
484                        println!("Usage: calculate-damage <side-1 move> <side-2 move> <side-1-moves-first>");
485                        continue;
486                    }
487                }
488                match args.next() {
489                    Some(s) => {
490                        s2_choice = MOVES
491                            .get(&Choices::from_str(s).unwrap())
492                            .unwrap()
493                            .to_owned();
494                        if s == "switch" {
495                            s2_choice.category = MoveCategory::Switch
496                        }
497                    }
498                    None => {
499                        println!("Usage: calculate-damage <side-1 move> <side-2 move> <side-1-moves-first>");
500                        continue;
501                    }
502                }
503                let s1_moves_first: bool;
504                match args.next() {
505                    Some(s) => {
506                        s1_moves_first = s.parse::<bool>().unwrap();
507                    }
508                    None => {
509                        println!("Usage: calculate-damage <side-1 move> <side-2 move> <side-1-moves-first>");
510                        continue;
511                    }
512                }
513                calculate_damage_io(&io_data.state, s1_choice, s2_choice, s1_moves_first);
514            }
515            "instructions" | "i" => {
516                println!("{:?}", io_data.last_instructions_generated);
517            }
518            "evaluate" | "ev" => {
519                println!("Evaluation: {}", evaluate(&io_data.state));
520            }
521            "iterative-deepening" | "id" => match args.next() {
522                Some(s) => {
523                    let max_time_ms = s.parse::<u64>().unwrap();
524                    let (side_one_options, side_two_options) = io_data.state.root_get_all_options();
525
526                    let start_time = std::time::Instant::now();
527                    let (s1_moves, s2_moves, result, depth_searched) =
528                        iterative_deepen_expectiminimax(
529                            &mut io_data.state,
530                            side_one_options.clone(),
531                            side_two_options.clone(),
532                            std::time::Duration::from_millis(max_time_ms),
533                        );
534                    let elapsed = start_time.elapsed();
535
536                    let safest_choice = pick_safest(&result, s1_moves.len(), s2_moves.len());
537
538                    pprint_expectiminimax_result(
539                        &result,
540                        &s1_moves,
541                        &s2_moves,
542                        &safest_choice,
543                        &io_data.state,
544                    );
545                    println!("Took: {:?}", elapsed);
546                    println!("Depth Searched: {}", depth_searched);
547                }
548                None => {
549                    println!("Usage: iterative-deepening <timeout_ms>");
550                    continue;
551                }
552            },
553            "monte-carlo-tree-search" | "mcts" => match args.next() {
554                Some(s) => {
555                    let max_time_ms = s.parse::<u64>().unwrap();
556                    let (side_one_options, side_two_options) = io_data.state.root_get_all_options();
557
558                    let start_time = std::time::Instant::now();
559                    let result = perform_mcts(
560                        &mut io_data.state,
561                        side_one_options.clone(),
562                        side_two_options.clone(),
563                        std::time::Duration::from_millis(max_time_ms),
564                    );
565                    let elapsed = start_time.elapsed();
566                    pprint_mcts_result(&io_data.state, result);
567
568                    println!("\nTook: {:?}", elapsed);
569                }
570                None => {
571                    println!("Usage: monte-carlo-tree-search <timeout_ms>");
572                    continue;
573                }
574            },
575            "apply" | "a" => match args.next() {
576                Some(s) => {
577                    let index = s.parse::<usize>().unwrap();
578                    let instructions = io_data.last_instructions_generated.remove(index);
579                    io_data
580                        .state
581                        .apply_instructions(&instructions.instruction_list);
582                    io_data.instruction_list.push(instructions.instruction_list);
583                    io_data.last_instructions_generated = Vec::new();
584                    println!("Applied instructions at index {}", index)
585                }
586                None => {
587                    println!("Usage: apply <instruction index>");
588                    continue;
589                }
590            },
591            "pop" | "p" => {
592                if io_data.instruction_list.is_empty() {
593                    println!("No instructions to pop");
594                    continue;
595                }
596                let instructions = io_data.instruction_list.pop().unwrap();
597                io_data.state.reverse_instructions(&instructions);
598                println!("Popped last applied instructions");
599            }
600            "pop-all" | "pa" => {
601                for i in io_data.instruction_list.iter().rev() {
602                    io_data.state.reverse_instructions(i);
603                }
604                io_data.instruction_list.clear();
605                println!("Popped all applied instructions");
606            }
607            "expectiminimax" | "e" => match args.next() {
608                Some(s) => {
609                    let mut ab_prune = false;
610                    match args.next() {
611                        Some(s) => ab_prune = s.parse::<bool>().unwrap(),
612                        None => {}
613                    }
614                    let depth = s.parse::<i8>().unwrap();
615                    let (side_one_options, side_two_options) = io_data.state.root_get_all_options();
616                    let start_time = std::time::Instant::now();
617                    let result = expectiminimax_search(
618                        &mut io_data.state,
619                        depth,
620                        side_one_options.clone(),
621                        side_two_options.clone(),
622                        ab_prune,
623                        &Arc::new(Mutex::new(true)),
624                    );
625                    let elapsed = start_time.elapsed();
626
627                    let safest_choice =
628                        pick_safest(&result, side_one_options.len(), side_two_options.len());
629                    pprint_expectiminimax_result(
630                        &result,
631                        &side_one_options,
632                        &side_two_options,
633                        &safest_choice,
634                        &io_data.state,
635                    );
636                    println!("\nTook: {:?}", elapsed);
637                }
638                None => {
639                    println!("Usage: expectiminimax <depth> <ab_prune=false>");
640                    continue;
641                }
642            },
643            "" => {
644                continue;
645            }
646            "exit" | "quit" | "q" => {
647                break;
648            }
649            command => {
650                println!("Unknown command: {}", command);
651            }
652        }
653    }
654}