1use std::cmp::Ordering;
2use std::convert::{Infallible, TryFrom};
3use std::fmt;
4use std::hash::{Hash, Hasher};
5use std::str::FromStr;
6use std::time::Duration;
7
8use getset::{CopyGetters, Getters, Setters};
9use serde::{Deserialize, Serialize};
10
11use crate::model::sample::{Sample, SampleIter};
12
13#[derive(
14 Serialize, Deserialize, Getters, CopyGetters, Setters, Debug, Clone, PartialEq, Eq, Hash,
15)]
16pub struct Problem {
17 #[get = "pub"]
18 id: ProblemId,
19 #[get = "pub"]
20 name: String,
21 #[get = "pub"]
22 url_name: String,
23 #[serde(with = "humantime_serde")]
24 #[get_copy = "pub"]
25 time_limit: Option<Duration>,
26 #[get_copy = "pub"]
27 memory_limit: Option<Byte>,
28 #[get_copy = "pub"]
29 compare: Compare,
30 #[set = "pub"]
31 samples: Vec<Sample>,
32}
33
34impl Problem {
35 pub fn new(
36 id: impl Into<ProblemId>,
37 name: impl Into<String>,
38 url_name: impl Into<String>,
39 time_limit: Option<Duration>,
40 memory_limit: Option<Byte>,
41 compare: Compare,
42 samples: Vec<Sample>,
43 ) -> Self {
44 Self {
45 id: id.into(),
46 name: name.into(),
47 url_name: url_name.into(),
48 time_limit,
49 memory_limit,
50 compare,
51 samples,
52 }
53 }
54
55 pub fn take_samples(self, sample_name: &Option<String>) -> SampleIter {
56 if let Some(sample_name) = sample_name {
57 self.samples
58 .into_iter()
59 .filter(|sample| sample.name() == sample_name)
60 .collect::<Vec<_>>()
61 .into()
62 } else {
63 self.samples.into()
64 }
65 }
66}
67
68impl Default for Problem {
69 fn default() -> Self {
70 Self::new(
71 "C",
72 "Linear Approximation",
73 "arc100_a",
74 Some(Duration::from_secs(2)),
75 Some("1024 MB".parse().unwrap()),
76 Compare::Default,
77 vec![],
78 )
79 }
80}
81
82#[derive(Serialize, Deserialize, Debug, Clone, Eq)]
83pub struct ProblemId(String);
84
85impl ProblemId {
86 pub fn normalize(&self) -> String {
87 self.0.to_uppercase()
88 }
89}
90
91impl PartialEq<ProblemId> for ProblemId {
92 fn eq(&self, other: &ProblemId) -> bool {
93 self.normalize() == other.normalize()
94 }
95}
96
97impl PartialOrd for ProblemId {
98 fn partial_cmp(&self, other: &ProblemId) -> Option<Ordering> {
99 Some(self.normalize().cmp(&other.normalize()))
100 }
101}
102
103impl Ord for ProblemId {
104 fn cmp(&self, other: &Self) -> Ordering {
105 self.normalize().cmp(&other.normalize())
106 }
107}
108
109impl Hash for ProblemId {
110 fn hash<H: Hasher>(&self, state: &mut H) {
111 self.normalize().hash(state);
112 }
113}
114
115impl<T: Into<String>> From<T> for ProblemId {
116 fn from(id: T) -> Self {
117 Self(id.into())
118 }
119}
120
121impl FromStr for ProblemId {
122 type Err = Infallible;
123
124 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
125 Ok(Self::from(s))
126 }
127}
128
129impl AsRef<str> for ProblemId {
130 fn as_ref(&self) -> &str {
131 &self.0
132 }
133}
134
135impl fmt::Display for ProblemId {
136 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
137 f.write_str(&self.normalize())
138 }
139}
140
141#[derive(
142 Serialize,
143 Deserialize,
144 EnumString,
145 EnumVariantNames,
146 IntoStaticStr,
147 Debug,
148 Copy,
149 Clone,
150 PartialEq,
151 Eq,
152 PartialOrd,
153 Ord,
154 Hash,
155)]
156#[serde(rename_all = "kebab-case")]
157#[strum(serialize_all = "kebab-case")]
158pub enum Compare {
159 Default,
160 }
166
167impl Compare {
168 pub fn compare(self, a: &str, b: &str) -> bool {
169 match self {
170 Self::Default => Self::compare_default(a, b),
171 }
172 }
173
174 fn compare_default(a: &str, b: &str) -> bool {
175 a.trim_end() == b.trim_end() }
177}
178
179#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq, Eq, Hash)]
180#[serde(try_from = "String", into = "String")]
181pub struct Byte(u64);
182
183impl FromStr for Byte {
184 type Err = &'static str;
185
186 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
187 Ok(Self(bytefmt::parse(s)?))
188 }
189}
190
191impl TryFrom<String> for Byte {
192 type Error = &'static str;
193
194 fn try_from(s: String) -> std::result::Result<Self, Self::Error> {
195 Self::from_str(&s)
196 }
197}
198
199impl From<Byte> for String {
200 fn from(byte: Byte) -> Self {
201 bytefmt::format_to(byte.0, bytefmt::Unit::MB)
202 }
203}
204
205impl fmt::Display for Byte {
206 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
207 f.write_str(&String::from(*self))
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_problem_take_samples() {
        let samples = vec![
            Sample::new("name 1", "5", "0"),
            Sample::new("name 2", "5", "0"),
        ];
        let problem = Problem {
            id: "A".into(),
            name: "Problem A".into(),
            url_name: "test_contest_a".into(),
            time_limit: Some(Duration::from_secs(2)),
            memory_limit: Some("1024 KB".parse().unwrap()),
            compare: Compare::Default,
            samples: samples.clone(),
        };
        let tests = &[
            (Some(String::from("name 2")), vec![&samples[1]]),
            (None, vec![&samples[0], &samples[1]]),
            // A name that matches nothing yields no samples rather than an error.
            (Some(String::from("no such sample")), vec![]),
        ];

        for (sample_name, expected) in tests {
            let actual = &problem
                .clone()
                .take_samples(sample_name)
                .collect::<Vec<_>>();
            assert_eq!(actual.len(), expected.len());
            let is_all_equal = actual
                .iter()
                .zip(expected)
                .all(|(a, b)| a.as_ref().unwrap() == *b);
            assert!(is_all_equal);
        }
    }

    #[test]
    fn test_problem_id_eq() {
        assert_eq!(ProblemId::from("A"), ProblemId::from("A"));
        assert_eq!(ProblemId::from("a"), ProblemId::from("A"));
        assert_ne!(ProblemId::from("a"), ProblemId::from("B"));
    }

    #[test]
    fn test_problem_id_ord() {
        assert_eq!(
            ProblemId::from("a").cmp(&ProblemId::from("A")),
            Ordering::Equal
        );
        // Raw bytes would put 'a' after 'B'; the case-insensitive order puts it first.
        assert_eq!(
            ProblemId::from("a").cmp(&ProblemId::from("B")),
            Ordering::Less
        );
        // `partial_cmp` must agree with `cmp`.
        assert_eq!(
            ProblemId::from("b").partial_cmp(&ProblemId::from("A")),
            Some(Ordering::Greater)
        );
    }

    #[test]
    fn test_problem_id_hash_is_case_insensitive() {
        use std::collections::hash_map::DefaultHasher;

        let hash_of = |id: &ProblemId| {
            let mut hasher = DefaultHasher::new();
            id.hash(&mut hasher);
            hasher.finish()
        };
        assert_eq!(
            hash_of(&ProblemId::from("a")),
            hash_of(&ProblemId::from("A"))
        );
    }

    #[test]
    fn test_problem_id_display() {
        assert_eq!(&ProblemId::from("A").to_string(), "A");
        assert_eq!(&ProblemId::from("a").to_string(), "A");
    }

    #[test]
    fn test_compare() {
        let tests = &[
            (Compare::Default, "hoge", "hoge", true),
            (Compare::Default, "hoge", "hoge ", true),
            (Compare::Default, "hoge", "hoge\n", true),
            (Compare::Default, "hoge", " hoge", false),
            (Compare::Default, "hoge", "\nhoge", false),
        ];

        for (compare, a, b, expected) in tests {
            let actual = compare.compare(a, b);
            assert_eq!(actual, *expected);
        }
    }

    #[test]
    fn test_byte_try_from() {
        assert_eq!(
            Byte::try_from(String::from("1024KB")),
            Ok(Byte(1024 * 1000))
        );
        assert_eq!(
            Byte::try_from(String::from("1.2MB")),
            Ok(Byte(1200 * 1000))
        );
    }

    #[test]
    fn test_byte_display() {
        assert_eq!(&Byte(1024 * 1000).to_string(), "1.02 MB");
        assert_eq!(&Byte(2000 * 1000).to_string(), "2 MB");
        assert_eq!(&Byte(10 * 1000 * 1000).to_string(), "10 MB");
    }

    #[test]
    fn test_byte_string_round_trip() {
        // Whole-megabyte sizes survive a trip through the serialized string form.
        let byte = Byte(2000 * 1000);
        assert_eq!(Byte::try_from(String::from(byte)), Ok(byte));
    }
}