pgn_reader/
reader.rs

1use std::{
2    cmp::min,
3    io::{self, Chain, Cursor, Read},
4};
5
6use shakmaty::{
7    san::{San, SanPlus, Suffix},
8    CastlingSide, Color, Outcome,
9};
10
11// use slice_deque::SliceDeque;
12use crate::{
13    types::{Nag, RawComment, RawHeader, Skip},
14    visitor::{SkipVisitor, Visitor},
15};
16
/// Minimum number of bytes kept buffered ahead of the parser (8 KiB), unless
/// the source has fewer bytes remaining.
const MIN_BUFFER_SIZE: usize = 8 * 1024;
18
19trait ReadPgn {
20    type Err;
21
22    /// Fill the buffer. The buffer must then contain at least MIN_BUFFER_SIZE
23    /// bytes or all remaining bytes until the end of the source.
24    fn fill_buffer_and_peek(&mut self) -> Result<Option<u8>, Self::Err>;
25
26    /// Returns the current buffer.
27    fn buffer(&self) -> &[u8];
28
29    /// Consume n bytes from the buffer.
30    fn consume(&mut self, n: usize);
31
32    /// Constructs a parser error.
33    fn invalid_data() -> Self::Err;
34
35    fn peek(&self) -> Option<u8> {
36        self.buffer().first().cloned()
37    }
38
39    fn bump(&mut self) -> Option<u8> {
40        let head = self.peek();
41        if head.is_some() {
42            self.consume(1);
43        }
44        head
45    }
46
47    fn remaining(&self) -> usize {
48        self.buffer().len()
49    }
50
51    fn consume_all(&mut self) {
52        let remaining = self.remaining();
53        self.consume(remaining);
54    }
55
56    fn skip_bom(&mut self) -> Result<(), Self::Err> {
57        self.fill_buffer_and_peek()?;
58        if self.buffer().starts_with(b"\xef\xbb\xbf") {
59            self.consume(3);
60        }
61        Ok(())
62    }
63
64    fn skip_until(&mut self, needle: u8) -> Result<(), Self::Err> {
65        while self.fill_buffer_and_peek()?.is_some() {
66            if let Some(pos) = memchr::memchr(needle, self.buffer()) {
67                self.consume(pos);
68                return Ok(());
69            } else {
70                self.consume_all();
71            }
72        }
73
74        Ok(())
75    }
76
77    fn skip_line(&mut self) -> Result<(), Self::Err> {
78        self.skip_until(b'\n')?;
79        self.bump();
80        Ok(())
81    }
82
83    fn skip_whitespace(&mut self) -> Result<(), Self::Err> {
84        while let Some(ch) = self.fill_buffer_and_peek()? {
85            match ch {
86                b' ' | b'\t' | b'\r' | b'\n' => {
87                    self.bump();
88                }
89                b'%' => {
90                    self.bump();
91                    self.skip_line()?;
92                }
93                _ => return Ok(()),
94            }
95        }
96
97        Ok(())
98    }
99
100    fn skip_ket(&mut self) -> Result<(), Self::Err> {
101        while let Some(ch) = self.fill_buffer_and_peek()? {
102            match ch {
103                b' ' | b'\t' | b'\r' | b']' => {
104                    self.bump();
105                }
106                b'%' => {
107                    self.bump();
108                    self.skip_line()?;
109                    return Ok(());
110                }
111                b'\n' => {
112                    self.bump();
113                    return Ok(());
114                }
115                _ => {
116                    return Ok(());
117                }
118            }
119        }
120
121        Ok(())
122    }
123
124    fn read_headers<V: Visitor>(&mut self, visitor: &mut V) -> Result<(), Self::Err> {
125        while let Some(ch) = self.fill_buffer_and_peek()? {
126            match ch {
127                b'[' => {
128                    self.bump();
129
130                    let left_quote = match memchr::memchr3(b'"', b'\n', b']', self.buffer()) {
131                        Some(left_quote) if self.buffer()[left_quote] == b'"' => left_quote,
132                        Some(eol) => {
133                            self.consume(eol + 1);
134                            self.skip_ket()?;
135                            continue;
136                        }
137                        None => {
138                            self.consume_all();
139                            self.skip_line()?;
140                            return Err(Self::invalid_data());
141                        }
142                    };
143
144                    let space = if left_quote > 0 && self.buffer()[left_quote - 1] == b' ' {
145                        left_quote - 1
146                    } else {
147                        left_quote
148                    };
149
150                    let value_start = left_quote + 1;
151                    let mut right_quote = value_start;
152                    let consumed = loop {
153                        match memchr::memchr3(b'\\', b'"', b'\n', &self.buffer()[right_quote..]) {
154                            Some(delta) if self.buffer()[right_quote + delta] == b'"' => {
155                                right_quote += delta;
156                                break right_quote + 1;
157                            }
158                            Some(delta) if self.buffer()[right_quote + delta] == b'\n' => {
159                                right_quote += delta;
160                                break right_quote;
161                            }
162                            Some(delta) => {
163                                // Skip escaped character.
164                                right_quote = min(right_quote + delta + 2, self.remaining());
165                            }
166                            None => {
167                                self.consume_all();
168                                self.skip_line()?;
169                                return Err(Self::invalid_data());
170                            }
171                        }
172                    };
173
174                    visitor.header(
175                        &self.buffer()[..space],
176                        RawHeader(&self.buffer()[value_start..right_quote]),
177                    );
178                    self.consume(consumed);
179                    self.skip_ket()?;
180                }
181                b'%' => self.skip_line()?,
182                _ => return Ok(()),
183            }
184        }
185
186        Ok(())
187    }
188
189    fn skip_movetext(&mut self) -> Result<(), Self::Err> {
190        while let Some(ch) = self.fill_buffer_and_peek()? {
191            self.bump();
192
193            match ch {
194                b'{' => {
195                    self.skip_until(b'}')?;
196                    self.bump();
197                }
198                b';' => {
199                    self.skip_until(b'\n')?;
200                }
201                b'\n' => match self.peek() {
202                    Some(b'%') => self.skip_until(b'\n')?,
203                    Some(b'\n') | Some(b'[') => break,
204                    Some(b'\r') => {
205                        self.bump();
206                        if let Some(b'\n') = self.peek() {
207                            break;
208                        }
209                    }
210                    _ => continue,
211                },
212                _ => {
213                    if let Some(consumed) = memchr::memchr3(b'\n', b'{', b';', self.buffer()) {
214                        self.consume(consumed);
215                    } else {
216                        self.consume_all();
217                    }
218                }
219            }
220        }
221
222        Ok(())
223    }
224
225    fn find_token_end(&mut self, start: usize) -> usize {
226        let mut end = start;
227        for &ch in &self.buffer()[start..] {
228            match ch {
229                b' ' | b'\t' | b'\n' | b'\r' | b'{' | b'}' | b'(' | b')' | b'!' | b'?' | b'$'
230                | b';' | b'.' => break,
231                _ => end += 1,
232            }
233        }
234        end
235    }
236
237    fn read_movetext<V: Visitor>(&mut self, visitor: &mut V) -> Result<(), Self::Err> {
238        while let Some(ch) = self.fill_buffer_and_peek()? {
239            match ch {
240                b'{' => {
241                    self.bump();
242
243                    let right_brace = if let Some(right_brace) = memchr::memchr(b'}', self.buffer())
244                    {
245                        right_brace
246                    } else {
247                        self.consume_all();
248                        self.skip_until(b'}')?;
249                        self.bump();
250                        return Err(Self::invalid_data());
251                    };
252
253                    visitor.comment(RawComment(&self.buffer()[..right_brace]));
254                    self.consume(right_brace + 1);
255                }
256                b'\n' => {
257                    self.bump();
258
259                    match self.peek() {
260                        Some(b'%') => {
261                            self.bump();
262                            self.skip_line()?;
263                        }
264                        Some(b'[') | Some(b'\n') => {
265                            break;
266                        }
267                        Some(b'\r') => {
268                            self.bump();
269                            if self.peek() == Some(b'\n') {
270                                break;
271                            }
272                        }
273                        _ => continue,
274                    }
275                }
276                b';' => {
277                    self.bump();
278                    self.skip_until(b'\n')?;
279                }
280                b'1' => {
281                    self.bump();
282                    if self.buffer().starts_with(b"-0") {
283                        self.consume(2);
284                        visitor.outcome(Some(Outcome::Decisive {
285                            winner: Color::White,
286                        }));
287                    } else if self.buffer().starts_with(b"/2-1/2") {
288                        self.consume(6);
289                        visitor.outcome(Some(Outcome::Draw));
290                    } else {
291                        let token_end = self.find_token_end(0);
292                        self.consume(token_end);
293                    }
294                }
295                b'0' => {
296                    self.bump();
297                    if self.buffer().starts_with(b"-1") {
298                        self.consume(2);
299                        visitor.outcome(Some(Outcome::Decisive {
300                            winner: Color::Black,
301                        }));
302                    } else if self.buffer().starts_with(b"-0") {
303                        // Castling notation with zeros.
304                        self.consume(2);
305                        let side = if self.buffer().starts_with(b"-0") {
306                            self.consume(2);
307                            CastlingSide::QueenSide
308                        } else {
309                            CastlingSide::KingSide
310                        };
311                        let suffix = match self.peek() {
312                            Some(b'+') => Some(Suffix::Check),
313                            Some(b'#') => Some(Suffix::Checkmate),
314                            _ => None,
315                        };
316                        visitor.san(SanPlus {
317                            san: San::Castle(side),
318                            suffix,
319                        });
320                    } else {
321                        let token_end = self.find_token_end(0);
322                        self.consume(token_end);
323                    }
324                }
325                b'(' => {
326                    self.bump();
327                    if let Skip(true) = visitor.begin_variation() {
328                        self.skip_variation()?;
329                    }
330                }
331                b')' => {
332                    self.bump();
333                    visitor.end_variation();
334                }
335                b'$' => {
336                    self.bump();
337                    let token_end = self.find_token_end(0);
338                    if let Ok(nag) = btoi::btou(&self.buffer()[..token_end]) {
339                        visitor.nag(Nag(nag));
340                    }
341                    self.consume(token_end);
342                }
343                b'!' => {
344                    self.bump();
345                    match self.peek() {
346                        Some(b'!') => {
347                            self.bump();
348                            visitor.nag(Nag::BRILLIANT_MOVE);
349                        }
350                        Some(b'?') => {
351                            self.bump();
352                            visitor.nag(Nag::SPECULATIVE_MOVE);
353                        }
354                        _ => {
355                            visitor.nag(Nag::GOOD_MOVE);
356                        }
357                    }
358                }
359                b'?' => {
360                    self.bump();
361                    match self.peek() {
362                        Some(b'!') => {
363                            self.bump();
364                            visitor.nag(Nag::DUBIOUS_MOVE);
365                        }
366                        Some(b'?') => {
367                            self.bump();
368                            visitor.nag(Nag::BLUNDER);
369                        }
370                        _ => {
371                            visitor.nag(Nag::MISTAKE);
372                        }
373                    }
374                }
375                b'*' => {
376                    visitor.outcome(None);
377                    self.bump();
378                }
379                b' ' | b'\t' | b'\r' | b'P' | b'.' => {
380                    self.bump();
381                }
382                _ => {
383                    let token_end = self.find_token_end(1);
384                    if ch > b'9' || ch == b'-' {
385                        if let Ok(san) = SanPlus::from_ascii(&self.buffer()[..token_end]) {
386                            visitor.san(san);
387                        }
388                    }
389                    self.consume(token_end);
390                }
391            }
392        }
393
394        Ok(())
395    }
396
397    fn skip_variation(&mut self) -> Result<(), Self::Err> {
398        let mut depth = 0usize;
399
400        while let Some(ch) = self.fill_buffer_and_peek()? {
401            match ch {
402                b'(' => {
403                    depth += 1;
404                    self.bump();
405                }
406                b')' => {
407                    if let Some(d) = depth.checked_sub(1) {
408                        self.bump();
409                        depth = d;
410                    } else {
411                        break;
412                    }
413                }
414                b'{' => {
415                    self.bump();
416                    self.skip_until(b'}')?;
417                    self.bump();
418                }
419                b';' => {
420                    self.bump();
421                    self.skip_until(b'\n')?;
422                }
423                b'\n' => {
424                    match self.buffer().get(1).cloned() {
425                        Some(b'%') => {
426                            self.consume(2);
427                            self.skip_until(b'\n')?;
428                        }
429                        Some(b'[') | Some(b'\n') => {
430                            // Do not consume the first or second line break.
431                            break;
432                        }
433                        Some(b'\r') => {
434                            // Do not consume the first or second line break.
435                            if self.buffer().get(2).cloned() == Some(b'\n') {
436                                break;
437                            }
438                        }
439                        _ => {
440                            self.bump();
441                        }
442                    }
443                }
444                _ => {
445                    self.bump();
446                }
447            }
448        }
449
450        Ok(())
451    }
452
453    fn read_game<V: Visitor>(&mut self, visitor: &mut V) -> Result<Option<V::Result>, Self::Err> {
454        self.skip_bom()?;
455        self.skip_whitespace()?;
456
457        if self.fill_buffer_and_peek()?.is_none() {
458            return Ok(None);
459        }
460
461        visitor.begin_game();
462        visitor.begin_headers();
463        self.read_headers(visitor)?;
464        if let Skip(false) = visitor.end_headers() {
465            self.read_movetext(visitor)?;
466        } else {
467            self.skip_movetext()?;
468        }
469
470        self.skip_whitespace()?;
471        Ok(Some(visitor.end_game()))
472    }
473
474    fn skip_game(&mut self) -> Result<bool, Self::Err> {
475        self.read_game(&mut SkipVisitor).map(|r| r.is_some())
476    }
477}
478
479/// Internal read ahead buffer.
480#[derive(Debug, Clone)]
481pub struct Buffer {
482    inner: circular::Buffer,
483}
484
485impl Buffer {
486    fn new() -> Buffer {
487        Buffer {
488            inner: circular::Buffer::with_capacity(MIN_BUFFER_SIZE * 2),
489        }
490    }
491}
492
493impl AsRef<[u8]> for Buffer {
494    fn as_ref(&self) -> &[u8] {
495        self.inner.data()
496    }
497}
498
499/// A buffered PGN reader.
500#[derive(Debug)]
501pub struct BufferedReader<R> {
502    inner: R,
503    buffer: Buffer,
504}
505
506impl<T: AsRef<[u8]>> BufferedReader<Cursor<T>> {
507    /// Create a new reader by wrapping a byte slice in a [`Cursor`].
508    ///
509    /// ```
510    /// use pgn_reader::BufferedReader;
511    ///
512    /// let pgn = b"1. e4 e5 *";
513    /// let reader = BufferedReader::new_cursor(&pgn[..]);
514    /// ```
515    ///
516    /// [`Cursor`]: https://doc.rust-lang.org/std/io/struct.Cursor.html
517    pub fn new_cursor(inner: T) -> BufferedReader<Cursor<T>> {
518        BufferedReader::new(Cursor::new(inner))
519    }
520}
521
522impl<R: Read> BufferedReader<R> {
523    /// Create a new buffered PGN reader.
524    ///
525    /// ```
526    /// # use std::io;
527    /// # fn try_main() -> io::Result<()> {
528    /// use std::fs::File;
529    /// use pgn_reader::BufferedReader;
530    ///
531    /// let file = File::open("example.pgn")?;
532    /// let reader = BufferedReader::new(file);
533    /// # Ok(())
534    /// # }
535    /// ```
536    pub fn new(inner: R) -> BufferedReader<R> {
537        BufferedReader {
538            inner,
539            buffer: Buffer::new(),
540        }
541    }
542
543    /// Read a single game, if any, and returns the result produced by the
544    /// visitor. Returns Ok(None) if the underlying reader is empty.
545    ///
546    /// # Errors
547    ///
548    /// * I/O error from the underlying reader.
549    /// * Irrecoverable parser errors.
550    pub fn read_game<V: Visitor>(&mut self, visitor: &mut V) -> io::Result<Option<V::Result>> {
551        ReadPgn::read_game(self, visitor)
552    }
553
554    /// Skip a single game, if any.
555    ///
556    /// # Errors
557    ///
558    /// * I/O error from the underlying reader.
559    /// * Irrecoverable parser errors.
560    pub fn skip_game<V: Visitor>(&mut self) -> io::Result<bool> {
561        ReadPgn::skip_game(self)
562    }
563
564    /// Read all games.
565    ///
566    /// # Errors
567    ///
568    /// * I/O error from the underlying reader.
569    /// * Irrecoverable parser errors.
570    pub fn read_all<V: Visitor>(&mut self, visitor: &mut V) -> io::Result<()> {
571        while self.read_game(visitor)?.is_some() {}
572        Ok(())
573    }
574
575    /// Create an iterator over all games.
576    ///
577    /// # Errors
578    ///
579    /// * I/O error from the underlying reader.
580    /// * Irrecoverable parser errors.
581    pub fn into_iter<V: Visitor>(self, visitor: &mut V) -> IntoIter<'_, V, R> {
582        IntoIter {
583            reader: self,
584            visitor,
585        }
586    }
587
588    /// Gets the remaining bytes in the buffer and the underlying reader.
589    pub fn into_inner(self) -> Chain<Cursor<Buffer>, R> {
590        Cursor::new(self.buffer).chain(self.inner)
591    }
592
593    /// Returns whether the reader has another game to parse, but does not
594    /// actually parse it.
595    ///
596    /// # Errors
597    ///
598    /// * I/O error from the underlying reader.
599    pub fn has_more(&mut self) -> io::Result<bool> {
600        self.skip_bom()?;
601        self.skip_whitespace()?;
602        Ok(self.fill_buffer_and_peek()?.is_some())
603    }
604}
605
606impl<R: Read> ReadPgn for BufferedReader<R> {
607    type Err = io::Error;
608
609    fn fill_buffer_and_peek(&mut self) -> io::Result<Option<u8>> {
610        while self.buffer.inner.available_data() < MIN_BUFFER_SIZE {
611            let remainder = self.buffer.inner.space();
612            let size = self.inner.read(remainder)?;
613
614            if size == 0 {
615                break;
616            }
617
618            self.buffer.inner.fill(size);
619        }
620
621        Ok(self.buffer.inner.data().first().cloned())
622    }
623
624    fn invalid_data() -> io::Error {
625        io::Error::from(io::ErrorKind::InvalidData)
626    }
627
628    fn buffer(&self) -> &[u8] {
629        self.buffer.inner.data()
630    }
631
632    fn consume(&mut self, bytes: usize) {
633        self.buffer.inner.consume(bytes);
634    }
635
636    fn peek(&self) -> Option<u8> {
637        self.buffer.inner.data().first().cloned()
638    }
639}
640
641/// Iterator returned by
642/// [`BufferedReader::into_iter()`](struct.BufferedReader.html#method.into_iter).
643#[derive(Debug)]
644#[must_use]
645pub struct IntoIter<'a, V: 'a, R> {
646    visitor: &'a mut V,
647    reader: BufferedReader<R>,
648}
649
650impl<'a, V: Visitor, R: Read> Iterator for IntoIter<'a, V, R> {
651    type Item = Result<V::Result, io::Error>;
652
653    fn next(&mut self) -> Option<Self::Item> {
654        match self.reader.read_game(self.visitor) {
655            Ok(Some(result)) => Some(Ok(result)),
656            Ok(None) => None,
657            Err(err) => Some(Err(err)),
658        }
659    }
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    struct _AssertObjectSafe<R>(Box<BufferedReader<R>>);

    /// Visitor that only counts finished games.
    #[derive(Default)]
    struct GameCounter {
        count: usize,
    }

    impl Visitor for GameCounter {
        type Result = ();

        fn end_game(&mut self) {
            self.count += 1;
        }
    }

    #[test]
    fn test_empty_game() -> Result<(), io::Error> {
        let mut counter = GameCounter::default();
        let mut reader = BufferedReader::new(io::Cursor::new(b"  "));
        reader.read_game(&mut counter)?;
        assert_eq!(counter.count, 0);
        Ok(())
    }

    #[test]
    fn test_trailing_space() -> Result<(), io::Error> {
        let mut counter = GameCounter::default();
        let mut reader = BufferedReader::new(io::Cursor::new(b"1. e4 1-0\n\n\n\n\n  \n"));
        reader.read_game(&mut counter)?;
        assert_eq!(counter.count, 1);
        reader.read_game(&mut counter)?;
        assert_eq!(counter.count, 1);
        Ok(())
    }

    #[test]
    fn test_nag() -> Result<(), io::Error> {
        struct NagCollector {
            nags: Vec<Nag>,
        }

        impl Visitor for NagCollector {
            type Result = ();

            fn nag(&mut self, nag: Nag) {
                self.nags.push(nag);
            }

            fn end_game(&mut self) {}
        }

        let mut collector = NagCollector { nags: Vec::new() };
        let mut reader = BufferedReader::new(io::Cursor::new(b"1.f3! e5$71 2.g4?? Qh4#!?"));
        reader.read_game(&mut collector)?;
        assert_eq!(
            collector.nags,
            vec![Nag::GOOD_MOVE, Nag(71), Nag::BLUNDER, Nag::SPECULATIVE_MOVE]
        );
        Ok(())
    }

    #[test]
    fn test_null_moves() -> Result<(), io::Error> {
        struct SanCollector {
            sans: Vec<San>,
        }

        impl Visitor for SanCollector {
            type Result = ();

            fn san(&mut self, san: SanPlus) {
                self.sans.push(san.san);
            }

            fn end_game(&mut self) {}
        }

        let mut collector = SanCollector { sans: Vec::new() };
        let mut reader = BufferedReader::new(io::Cursor::new(b"1. e4 -- 2. Nf3 -- 3. -- e5"));
        reader.read_game(&mut collector)?;
        assert_eq!(collector.sans.len(), 6);
        assert_ne!(collector.sans[0], San::Null);
        assert_eq!(collector.sans[1], San::Null);
        assert_ne!(collector.sans[2], San::Null);
        assert_eq!(collector.sans[3], San::Null);
        assert_eq!(collector.sans[4], San::Null);
        assert_ne!(collector.sans[5], San::Null);
        Ok(())
    }
}