hangman_solver_lib/solver/
char_collection.rs

1// SPDX-License-Identifier: EUPL-1.2
2use crate::solver::infallible_char_collection::InfallibleCharCollection;
3use std::convert::Infallible;
4
5#[cfg(feature = "pyo3")]
6use pyo3::prelude::*;
7
/// A sequence of `char`s whose access may fail with [`CharCollection::Error`].
///
/// Every type implementing `InfallibleCharCollection` receives this trait
/// through a blanket impl with `Error = Infallible`; types whose reads can fail
/// implement it directly with their own error type.
pub trait CharCollection {
    /// Error produced when the characters cannot be read.
    type Error;

    /// Returns an iterator over the `char`s, borrowing `self`.
    ///
    /// Creating the iterator can fail (outer `Result`), and so can producing
    /// each individual item (inner `Result`).
    fn try_iter_chars(
        &self,
    ) -> Result<impl Iterator<Item = Result<char, Self::Error>> + '_, Self::Error>;

    /// Returns the number of `char`s in the collection.
    fn try_count_chars(&self) -> Result<usize, Self::Error>;

    /// Returns the first `char`, or `None` when the collection is empty.
    // Allowed to be unused: in this file only the unit tests call it.
    #[allow(dead_code)]
    fn try_get_first_char(&self) -> Result<Option<char>, Self::Error>;
}
20
21impl<CC: InfallibleCharCollection + ?Sized> CharCollection for CC {
22    type Error = Infallible;
23
24    #[inline]
25    fn try_iter_chars(
26        &self,
27    ) -> Result<impl Iterator<Item = Result<char, Self::Error>> + '_, Self::Error>
28    {
29        Ok(self.iter_chars().map(Result::Ok))
30    }
31
32    #[inline]
33    fn try_count_chars(&self) -> Result<usize, Self::Error> {
34        Ok(CC::char_count(self))
35    }
36
37    #[inline]
38    fn try_get_first_char(&self) -> Result<Option<char>, Self::Error> {
39        Ok(CC::first_char(self))
40    }
41}
42
/// Python `str` objects viewed as a [`CharCollection`].
///
/// Every operation goes through the Python object API and may raise, so the
/// error type is [`PyErr`].
#[cfg(feature = "pyo3")]
impl CharCollection for pyo3::Bound<'_, pyo3::types::PyString> {
    type Error = PyErr;

    /// Delegates to the length of the Python object.
    #[inline]
    fn try_count_chars(&self) -> PyResult<usize> {
        self.len()
    }

    /// Returns `None` for an empty string; otherwise indexes position 0 and
    /// extracts that one-character item as a `char`.
    #[inline]
    fn try_get_first_char(&self) -> PyResult<Option<char>> {
        // Check the length first so an empty string yields `None` instead of
        // an out-of-range indexing error.
        match self.try_count_chars()? {
            0 => Ok(None),
            _ => self.get_item(0)?.extract::<char>().map(Some),
        }
    }

    /// Iterates the Python string, extracting each yielded item as a `char`.
    ///
    /// Failure to obtain the iterator is reported by the outer `PyResult`.
    /// Per-item iteration or extraction failures are yielded as `Err` items.
    #[inline]
    fn try_iter_chars(
        &self,
    ) -> PyResult<impl Iterator<Item = Result<char, Self::Error>> + '_> {
        let items = self.try_iter()?;
        Ok(items.map(|item| item.and_then(|obj| obj.extract::<char>())))
    }
}
68
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use unwrap_infallible::UnwrapInfallible;

    use crate::solver::char_collection::CharCollection;

    /// Asserts that every `CharCollection` method on `$collection` agrees with
    /// `str::chars` applied to `$expected`.
    ///
    /// This is a macro rather than a generic function so that each call goes
    /// through ordinary method resolution (auto-ref/auto-deref included). That
    /// matches how the methods are invoked on both `&str` and `String` values
    /// below. A generic helper would instead require the argument's exact type
    /// to implement `CharCollection`.
    macro_rules! assert_matches_str_chars {
        ($collection:expr, $expected:expr) => {{
            let expected: &str = $expected;
            assert_eq!(
                $collection.try_count_chars().unwrap_infallible(),
                expected.chars().count(),
                "char count of {:?}",
                expected
            );
            assert_eq!(
                $collection.try_get_first_char().unwrap_infallible(),
                expected.chars().next(),
                "first char of {:?}",
                expected
            );
            assert_eq!(
                $collection
                    .try_iter_chars()
                    .unwrap_infallible()
                    .collect::<Result<Vec<_>, _>>()
                    .unwrap_infallible(),
                expected.chars().collect_vec(),
                "chars of {:?}",
                expected
            );
        }};
    }

    /// ASCII input: one byte per char, so the char count must also equal the
    /// byte length. All methods are checked for `&str` and `String`.
    #[test]
    fn test_iter_ascii() {
        let ascii_strings = ["Hello, world!", "abcde", "test"];

        for string in ascii_strings {
            assert!(string.is_ascii());
            assert_eq!(
                string.try_count_chars().unwrap_infallible(),
                string.len()
            );
            assert_matches_str_chars!(string, string);

            let owned = String::from(string);
            assert_matches_str_chars!(owned, &owned);
        }
    }

    /// Non-ASCII input (multi-byte UTF-8, including 4-byte emoji). Byte length
    /// and char count differ here, so this catches byte/char confusion.
    #[test]
    fn test_iter_ascii_chars() {
        let strings = ["µ ASCII TEXT", "äöüßÄÖÜẞ", "🤓🦈"];

        for string in strings {
            assert!(!string.is_ascii());
            assert!(string.try_get_first_char().unwrap_infallible().is_some());
            assert_matches_str_chars!(string, string);

            let owned = String::from(string);
            assert!(owned.try_get_first_char().unwrap_infallible().is_some());
            assert_matches_str_chars!(owned, &owned);
        }
    }

    /// Empty input: zero chars, no first char, empty iteration.
    #[test]
    fn test_empty() {
        assert_matches_str_chars!("", "");

        let owned = String::new();
        assert_matches_str_chars!(owned, &owned);
    }
}