reversi_client/
lib.rs

1// reversi_client/src/lib.rs
2
3use rand::Rng;
4use std::io::{Read, Write};
5use std::net::{SocketAddr, TcpStream};
6use std::str::FromStr;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10pub enum ReversiError {
11    #[error("Connection error: {0}")]
12    ConnectionError(String),
13    #[error("IO error: {0}")]
14    IoError(#[from] std::io::Error),
15    #[error("Protocol error: {0}")]
16    ProtocolError(String),
17}
18
/// A square on the 8x8 board, addressed by zero-based row and column.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Position {
    /// Zero-based row index; on-board squares use `0..8`.
    pub row: i8,
    /// Zero-based column index; on-board squares use `0..8`.
    pub col: i8,
}
24
/// The eight compass directions as `(row delta, column delta)` pairs.
///
/// The first component is added to a row index and the second to a column index.
const DIRECTIONS: [(i8, i8); 8] = [
    (-1, -1), (-1, 0), (-1, 1),
    (0, -1),           (0, 1),
    (1, -1),  (1, 0),  (1, 1),
];
36
/// A bare 8x8 grid of cell values.
///
/// NOTE(review): `GameState` stores its own `[[i8; 8]; 8]` grid rather than this
/// type — confirm whether `Board` is still needed.
#[derive(Clone, Debug)]
pub struct Board {
    /// Row-major cell values.
    pub board: [[i8; 8]; 8],
}
41
/// A snapshot of the game as reported by the server (or simulated locally).
#[derive(Clone, Debug)]
pub struct GameState {
    /// Player number (1 or 2) whose move it is.
    pub turn: i8,
    /// Round counter; the first four rounds are the centre-square opening.
    pub round: i32,
    /// Clock value for player 1, as sent by the server.
    pub t1: f32,
    /// Clock value for player 2, as sent by the server.
    pub t2: f32,
    /// Row-major cells: 0 = empty, 1 = player 1, 2 = player 2.
    pub board: [[i8; 8]; 8],
}
50
51impl GameState {
52    pub fn get_initial_state() -> Self {
53        GameState {
54            turn: 0,
55            round: 0,
56            t1: 0.0,
57            t2: 0.0,
58            board: [[0; 8]; 8],
59        }
60    }
61
62    pub fn print_board(&self) {
63        for row in &self.board {
64            for &cell in row {
65                print!("{} ", cell);
66            }
67            println!();
68        }
69    }
70
71    fn get_valid_moves_first_4(&self) -> Vec<Position> {
72        let mut moves = Vec::with_capacity(4);
73
74        if self.board[3][3] == 0 {
75            moves.push(Position { row: 3, col: 3 });
76        }
77
78        if self.board[3][4] == 0 {
79            moves.push(Position { row: 3, col: 4 });
80        }
81
82        if self.board[4][3] == 0 {
83            moves.push(Position { row: 4, col: 3 });
84        }
85
86        if self.board[4][4] == 0 {
87            moves.push(Position { row: 4, col: 4 });
88        }
89
90        moves
91    }
92
93    pub fn get_valid_moves(&self, player_num: i8) -> Vec<Position> {
94        if self.round < 4 { // TODO check if this should be a <= (aka, the server starts the game at round 1 vs 0)
95            return self.get_valid_moves_first_4();
96        }
97
98        let mut moves: Vec<Position> = Vec::with_capacity(24);
99        let opponent = 3 - player_num; // Player numbers are 1 or 2
100
101
102        for (i, row) in self.board.iter().enumerate() {
103            for (j, &cell) in row.iter().enumerate() {
104                if cell != 0 {
105                    continue;
106                }
107
108                'directions: for &(dx, dy) in &DIRECTIONS {
109                    let mut x = i as i8 + dx;
110                    let mut y = j as i8 + dy;
111                    let mut found_opponent = false;
112
113                    while x >= 0 && x < 8 && y >= 0 && y < 8 {
114                        let current = self.board[x as usize][y as usize];
115
116                        match current {
117                            // Empty space, can't flank
118                            0 => break,
119                            // Found player's piece after opponent's
120                            p if p == player_num => {
121                                if found_opponent {
122                                    moves.push(Position { row: i as i8, col: j as i8 });
123                                    break 'directions;
124                                }
125                                break;
126                            }
127                            // Found opponent's piece
128                            p if p == opponent => {
129                                found_opponent = true;
130                            }
131                            _ => break,
132                        }
133
134                        x += dx;
135                        y += dy;
136                    }
137                }
138            }
139        }
140
141        moves
142    }
143
144    pub fn get_state_after_move(&self, position: Position, player_num: i8) -> GameState {
145        #[cfg(debug_assertions)]
146        {
147            let valid_moves = self.get_valid_moves(player_num);
148            debug_assert!(valid_moves.contains(&position), "Invalid move attempted: ({}, {}) is not valid", position.row, position.col);
149        }
150        let mut new_board = self.board.clone();
151        new_board[position.row as usize][position.col as usize] = player_num;
152
153        for &(dx, dy) in &DIRECTIONS {
154            let mut x = position.row + dx;
155            let mut y = position.col + dy;
156            let mut found_opponent = false;
157
158            'line: while x >= 0 && x < 8 && y >= 0 && y < 8 {
159                let current = new_board[x as usize][y as usize];
160
161                match current {
162                    // Empty space, can't flank
163                    0 => break 'line,
164                    // Found player's piece after opponent's
165                    p if p == player_num => {
166                        if found_opponent {
167                            // Flip the pieces
168                            let mut flip_x = position.row + dx;
169                            let mut flip_y = position.col + dy;
170
171                            while flip_x < x || flip_y < y {
172                                new_board[flip_x as usize][flip_y as usize] = player_num;
173
174                                flip_x += dx;
175                                flip_y += dy;
176                            }
177                        }
178                        found_opponent = false;
179                    }
180                    // Found opponent's piece
181                    p if p == 3 - player_num => {
182                        found_opponent = true;
183                    }
184                    _ => break,
185                }
186
187                x += dx;
188                y += dy;
189            }
190        }
191
192        GameState {
193            turn: 3 - self.turn, // Turn should be 1 or 2, so this switches them
194            round: self.round + 1, // increment round
195            t1: self.t1, // We'll ignore this for now
196            t2: self.t2,
197            board: new_board, // attach the new board
198        }
199    }
200}
201
202pub trait ReversiStrategy {
203    fn choose_move(&self, game_state: &GameState, player_num: i8) -> Position;
204}
205
206pub struct RandomStrategy;
207
208impl ReversiStrategy for RandomStrategy {
209    fn choose_move(&self, game_state: &GameState, player_num: i8) -> Position {
210        let mut rng = rand::thread_rng();
211        let valid_moves = game_state.get_valid_moves(player_num);
212
213        valid_moves[rng.gen_range(0..valid_moves.len())]
214    }
215}
216
217#[derive(Debug)]
218pub struct ReversiClient<S: ReversiStrategy> {
219    stream: TcpStream,
220    player_number: i8,
221    strategy: S,
222    game_state: GameState,
223}
224
225impl<S: ReversiStrategy> ReversiClient<S> {
226    pub fn connect(
227        server_addr: &str,
228        player_number: i8,
229        strategy: S,
230    ) -> Result<Self, ReversiError> {
231        let port = 3333 + player_number as i32;
232        let addr = SocketAddr::from_str(&format!("{}:{}", server_addr, port))
233            .map_err(|e| ReversiError::ConnectionError(e.to_string()))?;
234
235        let mut stream = TcpStream::connect(addr)?;
236
237        let mut buffer = [0u8; 1024];
238        let mut bytes_read = stream.read(&mut buffer)?;
239
240        while bytes_read == 0 {
241            bytes_read = stream.read(&mut buffer)?;
242        }
243
244        let message = String::from_utf8_lossy(&buffer[..bytes_read]);
245
246        let parts: Vec<&str> = message.split(' ').collect();
247        println!("{:?}", parts);
248        let server_player_number = parts[0].parse().unwrap_or(-1);
249        let game_minutes = parts[1].trim().parse::<f32>().unwrap_or(0.0);
250        println!("Playing a {} minute game", game_minutes);
251
252        if player_number != server_player_number {
253            return Err(ReversiError::ProtocolError(format!(
254                "Player number mismatch: expected {}, got {}",
255                player_number, server_player_number
256            )));
257        }
258
259        Ok(ReversiClient {
260            stream,
261            player_number,
262            strategy,
263            game_state: GameState::get_initial_state(),
264        })
265    }
266
267    pub fn run(&mut self) -> Result<(), ReversiError> {
268        let mut buffer = [0u8; 1024];
269
270        loop {
271            let bytes_read = self.stream.read(&mut buffer)?;
272            if bytes_read == 0 {
273                return Err(ReversiError::ConnectionError("Connection closed".into()));
274            }
275
276            let message = String::from_utf8_lossy(&buffer[..bytes_read]);
277            match self.parse_message(&message) {
278                Ok(()) => {
279                    println!("Parsed state!");
280                    if self.game_state.turn == self.player_number {
281                        let chosen_move = self.strategy.choose_move(&self.game_state, self.player_number);
282                        self.send_move(&chosen_move)?;
283                    }
284                }
285                Err(e) => {
286                    println!("Failed to parse message: {}", e);
287                }
288            }
289        }
290    }
291
292    fn parse_message(&mut self, message: &str) -> Result<(), ReversiError> {
293        let parts: Vec<&str> = message.split('\n').collect();
294        println!("{:?}", parts);
295        let turn=  parts[0].parse::<i32>().unwrap_or(-1);
296        if turn == -999 {
297            println!("Game over");
298            return Err(ReversiError::ProtocolError("Game over".into()));
299        }
300
301        if parts.len() < 69 {
302            return Err(ReversiError::ProtocolError("Message was too short to contain the board".into()));
303        }
304
305        let mut board: [[i8; 8]; 8] = [[0; 8]; 8];
306        let mut index = 4;
307        for i in 0..8 {
308            for j in 0..8 {
309                board[i][j] = parts[index].parse().unwrap_or(0);
310                index += 1;
311            }
312        }
313
314        self.game_state = GameState {
315            turn: turn as i8,
316            round: parts[1].parse().unwrap_or(0),
317            t1: parts[2].parse::<f32>().unwrap_or(0.0),
318            t2: parts[3].parse::<f32>().unwrap_or(0.0),
319            board: board
320        };
321
322        Ok(())
323    }
324
325    // fn get_valid_moves_first_4(&self, board: &[[i8; 8]; 8]) -> Vec<(i8, i8)> {
326    //     let mut moves = Vec::with_capacity(4);
327
328    //     if board[3][3] == 0 {
329    //         moves.push((3, 3));
330    //     }
331
332    //     if board[3][4] == 0 {
333    //         moves.push((3, 4));
334    //     }
335
336    //     if board[4][3] == 0 {
337    //         moves.push((4, 3));
338    //     }
339
340    //     if board[4][4] == 0 {
341    //         moves.push((4, 4));
342    //     }
343
344    //     moves
345    // }
346
347    // pub fn get_valid_moves(&self, board: &[[i8; 8]; 8], player: i8) -> Vec<(i8, i8)> {
348    //     let mut moves: Vec<(i8, i8)> = Vec::with_capacity(24);
349    //     let opponent = 3 - player; // Player numbers are 1 or 2
350
351    //     // Directions: (dx, dy) for all 8 possible directions
352    //     const DIRECTIONS: [(i8, i8); 8] = [
353    //         (-1, -1),
354    //         (-1, 0),
355    //         (-1, 1),
356    //         (0, -1),
357    //         (0, 1),
358    //         (1, -1),
359    //         (1, 0),
360    //         (1, 1),
361    //     ];
362
363    //     for (i, row) in board.iter().enumerate() {
364    //         for (j, &cell) in row.iter().enumerate() {
365    //             if cell != 0 {
366    //                 continue;
367    //             }
368
369    //             'directions: for &(dx, dy) in &DIRECTIONS {
370    //                 let mut x = i as i8 + dx;
371    //                 let mut y = j as i8 + dy;
372    //                 let mut found_opponent = false;
373
374    //                 while x >= 0 && x < 8 && y >= 0 && y < 8 {
375    //                     let current = board[x as usize][y as usize];
376
377    //                     match current {
378    //                         // Empty space, can't flank
379    //                         0 => break,
380    //                         // Found player's piece after opponent's
381    //                         p if p == player => {
382    //                             if found_opponent {
383    //                                 moves.push((i as i8, j as i8));
384    //                                 break 'directions;
385    //                             }
386    //                             break;
387    //                         }
388    //                         // Found opponent's piece
389    //                         p if p == opponent => {
390    //                             found_opponent = true;
391    //                         }
392    //                         _ => break,
393    //                     }
394
395    //                     x += dx;
396    //                     y += dy;
397    //                 }
398    //             }
399    //         }
400    //     }
401
402    //     moves
403    // }
404
405    fn send_move(&mut self, pos: &Position) -> Result<(), ReversiError> {
406        let move_str = format!("{}\n{}\n", pos.row, pos.col);
407        self.stream.write_all(move_str.as_bytes())?;
408        Ok(())
409    }
410}
411
#[cfg(test)]
mod tests {
    use std::{io, net::TcpListener, thread};

    use super::*;

    /// Returns a client-side stream connected to a throwaway loopback listener.
    ///
    /// The server side stays open and drains everything written to it until the
    /// client end is dropped. Previously the peer was closed immediately, so writes
    /// in tests could race a closed socket.
    fn create_mock_stream() -> TcpStream {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        thread::spawn(move || {
            if let Ok((mut socket, _)) = listener.accept() {
                let _ = io::copy(&mut socket, &mut io::sink());
            }
        });

        TcpStream::connect(addr).unwrap()
    }

    /// A player-1 `RandomStrategy` client over a mock stream, in the initial state.
    fn new_client() -> ReversiClient<RandomStrategy> {
        ReversiClient {
            stream: create_mock_stream(),
            player_number: 1,
            strategy: RandomStrategy,
            game_state: GameState::get_initial_state(),
        }
    }

    #[test]
    fn test_parse_message_valid() {
        let mut client = new_client();

        // Header (turn, round, t1, t2) followed by 64 empty cells, each newline-terminated.
        let message = format!("1\n0\n0.0\n0.0\n{}", "0\n".repeat(64));
        client.parse_message(&message).unwrap();

        assert_eq!(client.game_state.turn, 1);
        assert_eq!(client.game_state.round, 0);
        assert_eq!(client.game_state.t1, 0.0);
        assert_eq!(client.game_state.t2, 0.0);
        assert_eq!(client.game_state.board, [[0; 8]; 8]);
    }

    #[test]
    fn test_parse_message_invalid() {
        let mut client = new_client();

        let result = client.parse_message("invalid\nmessage\n");

        assert!(result.is_err());
    }

    #[test]
    fn test_parse_message_game_over() {
        let mut client = new_client();

        match client.parse_message("-999\n") {
            Err(ReversiError::ProtocolError(ref e)) => assert_eq!(e, "Game over"),
            _ => panic!("Expected ProtocolError with 'Game over'"),
        }
    }

    #[test]
    fn test_get_valid_moves_first_4() {
        let client = new_client();

        let valid_moves = client.game_state.get_valid_moves(1);

        assert_eq!(
            valid_moves,
            vec![
                Position { row: 3, col: 3 },
                Position { row: 3, col: 4 },
                Position { row: 4, col: 3 },
                Position { row: 4, col: 4 },
            ]
        );
    }

    #[test]
    fn test_get_valid_moves() {
        let mut client = new_client();

        let mut board = [[0; 8]; 8];
        board[3][3] = 2;
        board[3][4] = 1;
        board[4][3] = 1;
        board[4][4] = 2;

        client.game_state.board = board;
        client.game_state.round = 5;

        let valid_moves = client.game_state.get_valid_moves(1);

        assert_eq!(
            valid_moves,
            vec![
                Position { row: 2, col: 3 },
                Position { row: 3, col: 2 },
                Position { row: 4, col: 5 },
                Position { row: 5, col: 4 },
            ]
        );
    }

    #[test]
    fn test_send_move() {
        let mut client = new_client();

        let result = client.send_move(&Position { row: 3, col: 4 });

        assert!(result.is_ok());
    }
}