apt_edsp/scenario/
version.rs

1use std::cmp::Ordering;
2use std::fmt::{Display, Formatter};
3use std::num::ParseIntError;
4use std::ops::Range;
5
6use serde::{Deserialize, Serialize};
7
8use crate::util::TryFromStringVisitor;
9
/// A Debian package version number of the form `[epoch:]upstream_version[-debian_revision]`.
///
/// The full textual form is stored once; the upstream version and the Debian revision are
/// kept as byte ranges into that string rather than as separate allocations.
///
/// [`Ord`] is implemented following the comparison rules laid out in the
/// [Debian Policy Manual][man]; see [the manual][man] for the details of the format.
///
/// [man]: https://www.debian.org/doc/debian-policy/ch-controlfields.html#version
#[derive(Debug, Clone, Default)]
pub struct Version {
    /// Numeric epoch; `0` when the string has no `epoch:` prefix.
    epoch: usize,
    /// Byte range of the upstream version inside `original`.
    version: Range<usize>,
    /// Byte range of the Debian revision inside `original`; `0..0` when there is none.
    revision: Range<usize>,
    /// The complete version string exactly as it was supplied.
    original: String,
}
24
25impl Version {
26    /// The epoch of the version number.
27    pub fn epoch(&self) -> usize {
28        self.epoch
29    }
30
31    /// The main part of the version number. Equivalent to the `upstream_version`.
32    pub fn version(&self) -> &str {
33        &self.original[self.version.clone()]
34    }
35
36    /// The version of the Debian package based on the upstream version. Equivalent to
37    /// the `debian_revision`.
38    pub fn revision(&self) -> &str {
39        &self.original[self.revision.clone()]
40    }
41
42    /// Returns the string representation of this version number.
43    pub fn as_str(&self) -> &str {
44        &self.original
45    }
46}
47
48impl AsRef<str> for Version {
49    fn as_ref(&self) -> &str {
50        self.as_str()
51    }
52}
53
54impl Display for Version {
55    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56        Display::fmt(&self.original, f)
57    }
58}
59
60impl Eq for Version {}
61
62impl PartialEq<Self> for Version {
63    fn eq(&self, other: &Self) -> bool {
64        self.epoch == other.epoch
65            && self.version() == other.version()
66            && self.revision() == other.revision()
67    }
68}
69
70impl std::hash::Hash for Version {
71    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
72        self.epoch.hash(state);
73        self.version().hash(state);
74        self.revision().hash(state);
75    }
76}
77
/// Compares the leading non-digit runs of `a` and `b`, advancing both cursors past every
/// byte that compared equal.
///
/// Ordering of a single position, lowest first: `~`, end of the non-digit run (a digit or
/// the end of input), ASCII letters (by byte value), then every other byte (by byte value).
///
/// Returns `Ordering::Equal` once both runs are exhausted at the same time; in that case
/// both cursors point at the first digit (or at the end of input). When a difference is
/// found, the cursors are left on the differing position.
fn cmp_non_digit(a: &mut &[u8], b: &mut &[u8]) -> Ordering {
    // Maps one position to a sortable key. `None` means "the non-digit run ended here".
    fn rank(byte: Option<u8>) -> (u8, u8) {
        match byte {
            Some(b'~') => (0, 0),
            None => (1, 0),
            Some(c) if c.is_ascii_alphabetic() => (2, c),
            Some(c) => (3, c),
        }
    }

    loop {
        let head_a = a.first().copied().filter(|c| !c.is_ascii_digit());
        let head_b = b.first().copied().filter(|c| !c.is_ascii_digit());

        if head_a.is_none() && head_b.is_none() {
            return Ordering::Equal;
        }

        match rank(head_a).cmp(&rank(head_b)) {
            // Equal keys with at least one side present means both sides hold the same
            // byte, so both slices are non-empty and can be advanced.
            Ordering::Equal => {
                *a = &a[1..];
                *b = &b[1..];
            }
            unequal => return unequal,
        }
    }
}
106
/// Consumes the run of ASCII digits at the front of `s` and returns its numeric value.
///
/// On return `s` points at the first non-digit byte (or is empty). If `s` does not start
/// with a digit, nothing is consumed and `0` is returned.
///
/// Digit runs whose value does not fit in a `u128` saturate at `u128::MAX` instead of
/// overflowing, so arbitrarily long digit runs never panic (debug builds) or wrap around
/// (release builds). As a consequence, all values above `u128::MAX` compare as equal.
fn get_next_num(s: &mut &[u8]) -> u128 {
    let digit_count = s.iter().take_while(|c| c.is_ascii_digit()).count();
    let (digits, rest) = s.split_at(digit_count);

    let value = digits.iter().fold(0u128, |num, &digit| {
        num.saturating_mul(10)
            .saturating_add(u128::from(digit - b'0'))
    });

    // Advance the caller's cursor past the digits that were just read.
    *s = rest;
    value
}
117
118fn cmp_num(a: &mut &[u8], b: &mut &[u8]) -> Ordering {
119    get_next_num(a).cmp(&get_next_num(b))
120}
121
122fn cmp_string(a: &str, b: &str) -> Ordering {
123    let (mut a, mut b) = (a.as_bytes(), b.as_bytes());
124    let mut compare_non_digit = true;
125
126    while !a.is_empty() || !b.is_empty() {
127        let res = if compare_non_digit {
128            cmp_non_digit(&mut a, &mut b)
129        } else {
130            cmp_num(&mut a, &mut b)
131        };
132
133        if res != Ordering::Equal {
134            return res;
135        }
136        compare_non_digit = !compare_non_digit;
137    }
138
139    Ordering::Equal
140}
141
142impl Ord for Version {
143    fn cmp(&self, other: &Self) -> Ordering {
144        if self.epoch > other.epoch {
145            return Ordering::Greater;
146        }
147
148        if self.epoch < other.epoch {
149            return Ordering::Less;
150        }
151
152        let version_cmp = cmp_string(self.version(), other.version());
153
154        if version_cmp != Ordering::Equal {
155            return version_cmp;
156        }
157
158        cmp_string(self.revision(), other.revision())
159    }
160}
161
162impl PartialOrd for Version {
163    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
164        Some(self.cmp(other))
165    }
166}
167
168impl TryFrom<String> for Version {
169    type Error = ParseIntError;
170
171    fn try_from(value: String) -> Result<Self, Self::Error> {
172        let (epoch, epoch_len, remainder) = match value.split_once(':') {
173            None => (0, 0, &*value),
174            Some((epoch_str, remainder)) => (epoch_str.parse()?, epoch_str.len() + 1, remainder),
175        };
176
177        let (revision, remainder) = match remainder.rsplit_once('-') {
178            None => (0..0, remainder),
179            Some((remainder, revision_str)) => {
180                ((value.len() - revision_str.len())..value.len(), remainder)
181            }
182        };
183
184        Ok(Version {
185            epoch,
186            version: epoch_len..(epoch_len + remainder.len()),
187            revision,
188            original: value,
189        })
190    }
191}
192
193impl TryFrom<&str> for Version {
194    type Error = ParseIntError;
195
196    fn try_from(value: &str) -> Result<Self, Self::Error> {
197        value.to_string().try_into()
198    }
199}
200
201impl Serialize for Version {
202    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
203        serializer.collect_str(self)
204    }
205}
206
207impl<'de> Deserialize<'de> for Version {
208    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
209        deserializer.deserialize_str(TryFromStringVisitor::new())
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    mod version {
        use std::cmp::Ordering::*;
        use std::num::IntErrorKind;

        use super::*;

        /// Parses `input`, panicking if it is rejected.
        fn parsed(input: &str) -> Version {
            Version::try_from(input).unwrap()
        }

        /// Splitting a version string into epoch, upstream version and revision.
        #[test]
        fn parse() {
            // (input, expected epoch, expected upstream version, expected revision)
            let cases = [
                // All three components present.
                ("1:foo:bar-baz-qux", 1, "foo:bar-baz", "qux"),
                // No epoch.
                ("foo.123+bar-baz-qux", 0, "foo.123+bar-baz", "qux"),
                // No revision.
                ("90:foo.123+bar", 90, "foo.123+bar", ""),
                // Neither epoch nor revision.
                ("foo.123+bar~baz", 0, "foo.123+bar~baz", ""),
            ];

            for (input, epoch, version, revision) in cases {
                let actual = parsed(input);
                assert_eq!(epoch, actual.epoch());
                assert_eq!(version, actual.version());
                assert_eq!(revision, actual.revision());
            }

            // A non-numeric epoch is rejected.
            let err = Version::try_from("foo:bar").unwrap_err();
            assert_eq!(&IntErrorKind::InvalidDigit, err.kind())
        }

        /// Spot checks of the non-digit comparison used by the string comparison.
        #[test]
        fn cmp_string() {
            let tilde_vs_plus = cmp_non_digit(&mut "~".as_bytes(), &mut "+".as_bytes());
            assert_eq!(Less, tilde_vs_plus);

            let r_vs_d = cmp_non_digit(&mut "~r".as_bytes(), &mut "~d".as_bytes());
            assert_eq!(Greater, r_vs_d);
        }

        /// Ordering between pairs of complete version strings.
        #[test]
        fn ord() {
            // (left, expected ordering of left relative to right, right)
            let cases = vec![
                ("1.1.1", Less, "1.1.2"),
                ("1b", Greater, "1a"),
                ("1~~", Less, "1~~a"),
                ("1~~a", Less, "1~"),
                ("1", Less, "1.1"),
                ("1.0", Less, "1.1"),
                ("1.2", Less, "1.11"),
                ("1.0-1", Less, "1.1"),
                ("1.0-1", Less, "1.0-12"),
                // Textually different inputs can still compare as Equal.
                ("1:1.0-0", Equal, "1:1.0"),
                ("1.0", Equal, "1.0"),
                ("1.0-1", Equal, "1.0-1"),
                ("1:1.0-1", Equal, "1:1.0-1"),
                ("1:1.0", Equal, "1:1.0"),
                ("1.0-1", Less, "1.0-2"),
                // Disabled in the original suite:
                // ("1.0final-5sarge1", Greater, "1.0final-5"),
                ("1.0final-5", Greater, "1.0a7-2"),
                ("0.9.2-5", Less, "0.9.2+cvs.1.0.dev.2004.07.28-1"),
                ("1:500", Less, "1:5000"),
                ("100:500", Greater, "11:5000"),
                ("1.0.4-2", Greater, "1.0pre7-2"),
                ("1.5~rc1", Less, "1.5"),
                ("1.5~rc1", Less, "1.5+1"),
                ("1.5~rc1", Less, "1.5~rc2"),
                ("1.5~rc1", Greater, "1.5~dev0"),
            ];

            for (left, expected, right) in cases {
                assert_eq!(
                    parsed(left).cmp(&parsed(right)),
                    expected,
                    "{:#?} vs {:#?}",
                    parsed(left),
                    parsed(right)
                );
            }
        }

        /// An explicit zero epoch is equal to an omitted one.
        #[test]
        fn eq() {
            let pairs = vec![("1.1+git2021", "0:1.1+git2021")];
            for (left, right) in &pairs {
                assert_eq!(parsed(left), parsed(right));
            }
        }
    }
}