password_hash/
params.rs

1//! Algorithm parameters.
2
3use crate::errors::InvalidValue;
4use crate::{
5    value::{Decimal, Value},
6    Encoding, Error, Ident, Result,
7};
8use core::{
9    fmt::{self, Debug, Write},
10    iter::FromIterator,
11    str::{self, FromStr},
12};
13
14/// Individual parameter name/value pair.
15pub type Pair<'a> = (Ident<'a>, Value<'a>);
16
/// Delimiter character between a parameter's name and its value (`name=value`).
pub(crate) const PAIR_DELIMITER: char = '=';

/// Delimiter character between consecutive parameters (`a=1,b=2`).
pub(crate) const PARAMS_DELIMITER: char = ',';

/// Maximum length of a serialized [`ParamsString`], in bytes (ASCII characters).
///
/// This bounds the size of the whole `name=value,...` string, not the number of
/// parameters. It must not exceed `u8::MAX`, because the backing buffer records
/// its used length in a `u8`.
const MAX_LENGTH: usize = 127;

/// Error message used with `expect` for when internal invariants are violated
/// (i.e. the contents of a [`ParamsString`] should always be valid)
const INVARIANT_VIOLATED_MSG: &str = "PHC params invariant violated";
29
30/// Algorithm parameter string.
31///
32/// The [PHC string format specification][1] defines a set of optional
33/// algorithm-specific name/value pairs which can be encoded into a
34/// PHC-formatted parameter string as follows:
35///
36/// ```text
37/// $<param>=<value>(,<param>=<value>)*
38/// ```
39///
40/// This type represents that set of parameters.
41///
42/// [1]: https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md#specification
43#[derive(Clone, Default, Eq, PartialEq)]
44pub struct ParamsString(Buffer);
45
46impl ParamsString {
47    /// Create new empty [`ParamsString`].
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Add the given byte value to the [`ParamsString`], encoding it as "B64".
53    pub fn add_b64_bytes<'a>(&mut self, name: impl TryInto<Ident<'a>>, bytes: &[u8]) -> Result<()> {
54        if !self.is_empty() {
55            self.0
56                .write_char(PARAMS_DELIMITER)
57                .map_err(|_| Error::ParamsMaxExceeded)?
58        }
59
60        let name = name.try_into().map_err(|_| Error::ParamNameInvalid)?;
61
62        // Add param name
63        let offset = self.0.length;
64        if write!(self.0, "{}=", name).is_err() {
65            self.0.length = offset;
66            return Err(Error::ParamsMaxExceeded);
67        }
68
69        // Encode B64 value
70        let offset = self.0.length as usize;
71        let written = Encoding::B64
72            .encode(bytes, &mut self.0.bytes[offset..])?
73            .len();
74
75        self.0.length += written as u8;
76        Ok(())
77    }
78
79    /// Add a key/value pair with a decimal value to the [`ParamsString`].
80    pub fn add_decimal<'a>(&mut self, name: impl TryInto<Ident<'a>>, value: Decimal) -> Result<()> {
81        let name = name.try_into().map_err(|_| Error::ParamNameInvalid)?;
82        self.add(name, value)
83    }
84
85    /// Add a key/value pair with a string value to the [`ParamsString`].
86    pub fn add_str<'a>(
87        &mut self,
88        name: impl TryInto<Ident<'a>>,
89        value: impl TryInto<Value<'a>>,
90    ) -> Result<()> {
91        let name = name.try_into().map_err(|_| Error::ParamNameInvalid)?;
92
93        let value = value
94            .try_into()
95            .map_err(|_| Error::ParamValueInvalid(InvalidValue::InvalidFormat))?;
96
97        self.add(name, value)
98    }
99
100    /// Borrow the contents of this [`ParamsString`] as a byte slice.
101    pub fn as_bytes(&self) -> &[u8] {
102        self.as_str().as_bytes()
103    }
104
105    /// Borrow the contents of this [`ParamsString`] as a `str`.
106    pub fn as_str(&self) -> &str {
107        self.0.as_ref()
108    }
109
110    /// Get the count of the number ASCII characters in this [`ParamsString`].
111    pub fn len(&self) -> usize {
112        self.as_str().len()
113    }
114
115    /// Is this set of parameters empty?
116    pub fn is_empty(&self) -> bool {
117        self.len() == 0
118    }
119
120    /// Iterate over the parameters.
121    pub fn iter(&self) -> Iter<'_> {
122        Iter::new(self.as_str())
123    }
124
125    /// Get a parameter [`Value`] by name.
126    pub fn get<'a>(&self, name: impl TryInto<Ident<'a>>) -> Option<Value<'_>> {
127        let name = name.try_into().ok()?;
128
129        for (n, v) in self.iter() {
130            if name == n {
131                return Some(v);
132            }
133        }
134
135        None
136    }
137
138    /// Get a parameter as a `str`.
139    pub fn get_str<'a>(&self, name: impl TryInto<Ident<'a>>) -> Option<&str> {
140        self.get(name).map(|value| value.as_str())
141    }
142
143    /// Get a parameter as a [`Decimal`].
144    ///
145    /// See [`Value::decimal`] for format information.
146    pub fn get_decimal<'a>(&self, name: impl TryInto<Ident<'a>>) -> Option<Decimal> {
147        self.get(name).and_then(|value| value.decimal().ok())
148    }
149
150    /// Add a value to this [`ParamsString`] using the provided callback.
151    fn add(&mut self, name: Ident<'_>, value: impl fmt::Display) -> Result<()> {
152        if self.get(name).is_some() {
153            return Err(Error::ParamNameDuplicated);
154        }
155
156        let orig_len = self.0.length;
157
158        if !self.is_empty() {
159            self.0
160                .write_char(PARAMS_DELIMITER)
161                .map_err(|_| Error::ParamsMaxExceeded)?
162        }
163
164        if write!(self.0, "{}={}", name, value).is_err() {
165            self.0.length = orig_len;
166            return Err(Error::ParamsMaxExceeded);
167        }
168
169        Ok(())
170    }
171}
172
173impl FromStr for ParamsString {
174    type Err = Error;
175
176    fn from_str(s: &str) -> Result<Self> {
177        if s.as_bytes().len() > MAX_LENGTH {
178            return Err(Error::ParamsMaxExceeded);
179        }
180
181        if s.is_empty() {
182            return Ok(ParamsString::new());
183        }
184
185        // Validate the string is well-formed
186        for mut param in s.split(PARAMS_DELIMITER).map(|p| p.split(PAIR_DELIMITER)) {
187            // Validate name
188            param
189                .next()
190                .ok_or(Error::ParamNameInvalid)
191                .and_then(Ident::try_from)?;
192
193            // Validate value
194            param
195                .next()
196                .ok_or(Error::ParamValueInvalid(InvalidValue::Malformed))
197                .and_then(Value::try_from)?;
198
199            if param.next().is_some() {
200                return Err(Error::ParamValueInvalid(InvalidValue::Malformed));
201            }
202        }
203
204        let mut bytes = [0u8; MAX_LENGTH];
205        bytes[..s.as_bytes().len()].copy_from_slice(s.as_bytes());
206
207        Ok(Self(Buffer {
208            bytes,
209            length: s.as_bytes().len() as u8,
210        }))
211    }
212}
213
214impl<'a> FromIterator<Pair<'a>> for ParamsString {
215    fn from_iter<I>(iter: I) -> Self
216    where
217        I: IntoIterator<Item = Pair<'a>>,
218    {
219        let mut params = ParamsString::new();
220
221        for pair in iter {
222            params.add_str(pair.0, pair.1).expect("PHC params error");
223        }
224
225        params
226    }
227}
228
229impl fmt::Display for ParamsString {
230    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231        f.write_str(self.as_str())
232    }
233}
234
235impl fmt::Debug for ParamsString {
236    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237        f.debug_map().entries(self.iter()).finish()
238    }
239}
240
241/// Iterator over algorithm parameters stored in a [`ParamsString`] struct.
242pub struct Iter<'a> {
243    inner: Option<str::Split<'a, char>>,
244}
245
246impl<'a> Iter<'a> {
247    /// Create a new [`Iter`].
248    fn new(s: &'a str) -> Self {
249        if s.is_empty() {
250            Self { inner: None }
251        } else {
252            Self {
253                inner: Some(s.split(PARAMS_DELIMITER)),
254            }
255        }
256    }
257}
258
259impl<'a> Iterator for Iter<'a> {
260    type Item = Pair<'a>;
261
262    fn next(&mut self) -> Option<Pair<'a>> {
263        let mut param = self.inner.as_mut()?.next()?.split(PAIR_DELIMITER);
264
265        let name = param
266            .next()
267            .and_then(|id| Ident::try_from(id).ok())
268            .expect(INVARIANT_VIOLATED_MSG);
269
270        let value = param
271            .next()
272            .and_then(|value| Value::try_from(value).ok())
273            .expect(INVARIANT_VIOLATED_MSG);
274
275        debug_assert_eq!(param.next(), None);
276        Some((name, value))
277    }
278}
279
280/// Parameter buffer.
281#[derive(Clone, Debug, Eq)]
282struct Buffer {
283    /// Byte array containing an ASCII-encoded string.
284    bytes: [u8; MAX_LENGTH],
285
286    /// Length of the string in ASCII characters (i.e. bytes).
287    length: u8,
288}
289
290impl AsRef<str> for Buffer {
291    fn as_ref(&self) -> &str {
292        str::from_utf8(&self.bytes[..(self.length as usize)]).expect(INVARIANT_VIOLATED_MSG)
293    }
294}
295
296impl Default for Buffer {
297    fn default() -> Buffer {
298        Buffer {
299            bytes: [0u8; MAX_LENGTH],
300            length: 0,
301        }
302    }
303}
304
305impl PartialEq for Buffer {
306    fn eq(&self, other: &Self) -> bool {
307        // Ensure comparisons always honor the initialized portion of the buffer
308        self.as_ref().eq(other.as_ref())
309    }
310}
311
312impl Write for Buffer {
313    fn write_str(&mut self, input: &str) -> fmt::Result {
314        let bytes = input.as_bytes();
315        let length = self.length as usize;
316
317        if length + bytes.len() > MAX_LENGTH {
318            return Err(fmt::Error);
319        }
320
321        self.bytes[length..(length + bytes.len())].copy_from_slice(bytes);
322        self.length += bytes.len() as u8;
323
324        Ok(())
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::{Error, FromIterator, Ident, ParamsString, Value};

    #[cfg(feature = "alloc")]
    use alloc::string::ToString;
    use core::str::FromStr;

    /// Assert that `params` holds exactly `a=1`, `b=2` and `c=3`.
    fn assert_abc_123(params: &ParamsString) {
        assert_eq!(params.iter().count(), 3);
        for (name, expected) in [("a", 1), ("b", 2), ("c", 3)] {
            assert_eq!(params.get_decimal(name).unwrap(), expected);
        }
    }

    #[test]
    fn add() {
        let mut params = ParamsString::new();
        params.add_str("a", "1").unwrap();
        params.add_decimal("b", 2).unwrap();
        params.add_str("c", "3").unwrap();
        assert_abc_123(&params);
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn add_b64_bytes() {
        let mut params = ParamsString::new();
        let inputs: [(&str, &[u8]); 3] = [("a", &[1]), ("b", &[2, 3]), ("c", &[4, 5, 6])];
        for (name, bytes) in inputs {
            params.add_b64_bytes(name, bytes).unwrap();
        }
        assert_eq!(params.to_string(), "a=AQ,b=AgM,c=BAUG");
    }

    #[test]
    fn duplicate_names() {
        let name = Ident::new("a").unwrap();
        let mut params = ParamsString::new();
        params.add_decimal(name, 1).unwrap();

        let second = params.add_decimal(name, 2u32.into());
        assert_eq!(second.err(), Some(Error::ParamNameDuplicated));
    }

    #[test]
    fn from_iter() {
        let pairs = [("a", "1"), ("b", "2"), ("c", "3")]
            .iter()
            .map(|&(n, v)| (Ident::new(n).unwrap(), Value::try_from(v).unwrap()));
        let params = ParamsString::from_iter(pairs);
        assert_abc_123(&params);
    }

    #[test]
    fn iter() {
        let expected = [("a", "1"), ("b", "2"), ("c", "3")];

        let mut params = ParamsString::new();
        for &(name, value) in &expected {
            params.add_str(name, value).unwrap();
        }

        let mut actual = params.iter();
        for &(name, value) in &expected {
            let pair = (Ident::new(name).unwrap(), Value::try_from(value).unwrap());
            assert_eq!(actual.next(), Some(pair));
        }
        assert_eq!(actual.next(), None);
    }

    //
    // `FromStr` tests
    //

    #[test]
    fn parse_empty() {
        assert!(ParamsString::from_str("").unwrap().is_empty());
    }

    #[test]
    fn parse_one() {
        let params = ParamsString::from_str("a=1").unwrap();
        assert_eq!(params.iter().count(), 1);

        let value = params.get("a").unwrap();
        assert_eq!(value.decimal().unwrap(), 1);
    }

    #[test]
    fn parse_many() {
        let params = ParamsString::from_str("a=1,b=2,c=3").unwrap();
        assert_abc_123(&params);
    }

    //
    // `Display` tests
    //

    #[test]
    #[cfg(feature = "alloc")]
    fn display_empty() {
        assert_eq!(ParamsString::new().to_string(), "");
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn display_one() {
        let text = "a=1";
        assert_eq!(ParamsString::from_str(text).unwrap().to_string(), text);
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn display_many() {
        let text = "a=1,b=2,c=3";
        assert_eq!(ParamsString::from_str(text).unwrap().to_string(), text);
    }
}