acick_util/model/
problem.rs

1use std::cmp::Ordering;
2use std::convert::{Infallible, TryFrom};
3use std::fmt;
4use std::hash::{Hash, Hasher};
5use std::str::FromStr;
6use std::time::Duration;
7
8use getset::{CopyGetters, Getters, Setters};
9use serde::{Deserialize, Serialize};
10
11use crate::model::sample::{Sample, SampleIter};
12
13#[derive(
14    Serialize, Deserialize, Getters, CopyGetters, Setters, Debug, Clone, PartialEq, Eq, Hash,
15)]
16pub struct Problem {
17    #[get = "pub"]
18    id: ProblemId,
19    #[get = "pub"]
20    name: String,
21    #[get = "pub"]
22    url_name: String,
23    #[serde(with = "humantime_serde")]
24    #[get_copy = "pub"]
25    time_limit: Option<Duration>,
26    #[get_copy = "pub"]
27    memory_limit: Option<Byte>,
28    #[get_copy = "pub"]
29    compare: Compare,
30    #[set = "pub"]
31    samples: Vec<Sample>,
32}
33
34impl Problem {
35    pub fn new(
36        id: impl Into<ProblemId>,
37        name: impl Into<String>,
38        url_name: impl Into<String>,
39        time_limit: Option<Duration>,
40        memory_limit: Option<Byte>,
41        compare: Compare,
42        samples: Vec<Sample>,
43    ) -> Self {
44        Self {
45            id: id.into(),
46            name: name.into(),
47            url_name: url_name.into(),
48            time_limit,
49            memory_limit,
50            compare,
51            samples,
52        }
53    }
54
55    pub fn take_samples(self, sample_name: &Option<String>) -> SampleIter {
56        if let Some(sample_name) = sample_name {
57            self.samples
58                .into_iter()
59                .filter(|sample| sample.name() == sample_name)
60                .collect::<Vec<_>>()
61                .into()
62        } else {
63            self.samples.into()
64        }
65    }
66}
67
68impl Default for Problem {
69    fn default() -> Self {
70        Self::new(
71            "C",
72            "Linear Approximation",
73            "arc100_a",
74            Some(Duration::from_secs(2)),
75            Some("1024 MB".parse().unwrap()),
76            Compare::Default,
77            vec![],
78        )
79    }
80}
81
82#[derive(Serialize, Deserialize, Debug, Clone, Eq)]
83pub struct ProblemId(String);
84
85impl ProblemId {
86    pub fn normalize(&self) -> String {
87        self.0.to_uppercase()
88    }
89}
90
91impl PartialEq<ProblemId> for ProblemId {
92    fn eq(&self, other: &ProblemId) -> bool {
93        self.normalize() == other.normalize()
94    }
95}
96
97impl PartialOrd for ProblemId {
98    fn partial_cmp(&self, other: &ProblemId) -> Option<Ordering> {
99        Some(self.normalize().cmp(&other.normalize()))
100    }
101}
102
103impl Ord for ProblemId {
104    fn cmp(&self, other: &Self) -> Ordering {
105        self.normalize().cmp(&other.normalize())
106    }
107}
108
109impl Hash for ProblemId {
110    fn hash<H: Hasher>(&self, state: &mut H) {
111        self.normalize().hash(state);
112    }
113}
114
115impl<T: Into<String>> From<T> for ProblemId {
116    fn from(id: T) -> Self {
117        Self(id.into())
118    }
119}
120
121impl FromStr for ProblemId {
122    type Err = Infallible;
123
124    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
125        Ok(Self::from(s))
126    }
127}
128
129impl AsRef<str> for ProblemId {
130    fn as_ref(&self) -> &str {
131        &self.0
132    }
133}
134
135impl fmt::Display for ProblemId {
136    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
137        f.write_str(&self.normalize())
138    }
139}
140
141#[derive(
142    Serialize,
143    Deserialize,
144    EnumString,
145    EnumVariantNames,
146    IntoStaticStr,
147    Debug,
148    Copy,
149    Clone,
150    PartialEq,
151    Eq,
152    PartialOrd,
153    Ord,
154    Hash,
155)]
156#[serde(rename_all = "kebab-case")]
157#[strum(serialize_all = "kebab-case")]
158pub enum Compare {
159    Default,
160    // TODO: support float
161    // Float {
162    //     relative_error: Option<f64>,
163    //     absolute_error: Option<f64>,
164    // },
165}
166
167impl Compare {
168    pub fn compare(self, a: &str, b: &str) -> bool {
169        match self {
170            Self::Default => Self::compare_default(a, b),
171        }
172    }
173
174    fn compare_default(a: &str, b: &str) -> bool {
175        a.trim_end() == b.trim_end() // ignore spaces at the end of lines
176    }
177}
178
179#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq, Eq, Hash)]
180#[serde(try_from = "String", into = "String")]
181pub struct Byte(u64);
182
183impl FromStr for Byte {
184    type Err = &'static str;
185
186    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
187        Ok(Self(bytefmt::parse(s)?))
188    }
189}
190
191impl TryFrom<String> for Byte {
192    type Error = &'static str;
193
194    fn try_from(s: String) -> std::result::Result<Self, Self::Error> {
195        Self::from_str(&s)
196    }
197}
198
199impl From<Byte> for String {
200    fn from(byte: Byte) -> Self {
201        bytefmt::format_to(byte.0, bytefmt::Unit::MB)
202    }
203}
204
205impl fmt::Display for Byte {
206    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
207        f.write_str(&String::from(*self))
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// `take_samples` yields only the named sample, or every sample when no name is given.
    #[test]
    fn test_problem_take_samples() {
        let samples = vec![
            Sample::new("name 1", "5", "0"),
            Sample::new("name 2", "5", "0"),
        ];
        let problem = Problem {
            id: "A".into(),
            name: "Problem A".into(),
            url_name: "test_contest_a".into(),
            time_limit: Some(Duration::from_secs(2)),
            memory_limit: Some("1024 KB".parse().unwrap()),
            compare: Compare::Default,
            samples: samples.clone(),
        };
        let tests = &[
            (Some(String::from("name 2")), vec![&samples[1]]),
            (None, vec![&samples[0], &samples[1]]),
        ];

        for (sample_name, expected) in tests {
            // `take_samples` consumes the problem, so each case works on a clone.
            let actual = &problem
                .clone()
                .take_samples(sample_name)
                .collect::<Vec<_>>();
            assert_eq!(actual.len(), expected.len());
            let is_all_equal = actual
                .iter()
                .zip(expected)
                .all(|(a, b)| a.as_ref().unwrap() == *b);
            assert!(is_all_equal);
        }
    }

    /// Ids that differ only in letter case compare equal.
    #[test]
    fn test_problem_id_eq() {
        assert_eq!(ProblemId::from("A"), ProblemId::from("A"));
        assert_eq!(ProblemId::from("a"), ProblemId::from("A"));
    }

    /// `Display` prints the uppercased id.
    #[test]
    fn test_problem_id_display() {
        assert_eq!(&ProblemId::from("A").to_string(), "A");
        assert_eq!(&ProblemId::from("a").to_string(), "A");
    }

    /// Trailing whitespace is ignored; leading whitespace is not.
    #[test]
    fn test_compare() {
        let tests = &[
            (Compare::Default, "hoge", "hoge", true),
            (Compare::Default, "hoge", "hoge  ", true),
            (Compare::Default, "hoge", "hoge\n", true),
            (Compare::Default, "hoge", "  hoge", false),
            (Compare::Default, "hoge", "\nhoge", false),
        ];

        for (compare, a, b, expected) in tests {
            let actual = compare.compare(a, b);
            assert_eq!(actual, *expected);
        }
    }

    /// Sizes parse with decimal units.
    #[test]
    fn test_byte_try_from() {
        assert_eq!(
            Byte::try_from(String::from("1024KB")).unwrap(),
            Byte(1024 * 1000)
        );
        assert_eq!(
            Byte::try_from(String::from("1.2MB")).unwrap(),
            Byte(1200 * 1000)
        );
    }

    /// Sizes are displayed in megabytes.
    #[test]
    fn test_byte_display() {
        assert_eq!(&Byte(1024 * 1000).to_string(), "1.02 MB");
        assert_eq!(&Byte(2000 * 1000).to_string(), "2 MB");
        assert_eq!(&Byte(10 * 1000 * 1000).to_string(), "10 MB");
    }
}