use std::{
fmt::{self, Display, Formatter},
num::ParseIntError,
str::FromStr,
};
use compose_spec_macros::SerializeDisplay;
use serde::{de, Deserialize, Deserializer};
use thiserror::Error;
use crate::serde::error_chain;
#[derive(SerializeDisplay, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ByteValue {
Bytes(u64),
Kilobytes(u64),
Megabytes(u64),
Gigabytes(u64),
}
impl ByteValue {
#[must_use]
pub const fn into_bytes(self) -> Option<u64> {
match self {
Self::Bytes(bytes) => Some(bytes),
Self::Kilobytes(kilobytes) => kilobytes.checked_mul(1_000),
Self::Megabytes(megabytes) => megabytes.checked_mul(1_000_000),
Self::Gigabytes(gigabytes) => gigabytes.checked_mul(1_000_000_000),
}
}
#[must_use]
pub const fn unit(&self) -> &'static str {
match self {
Self::Bytes(_) => "b",
Self::Kilobytes(_) => "kb",
Self::Megabytes(_) => "mb",
Self::Gigabytes(_) => "gb",
}
}
}
impl Default for ByteValue {
fn default() -> Self {
Self::Bytes(0)
}
}
impl FromStr for ByteValue {
type Err = ParseByteValueError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.is_empty() {
Err(ParseByteValueError::Empty)
} else if let Ok(bytes) = s.parse() {
Ok(Self::Bytes(bytes))
} else if let Some(gigabytes) = s.strip_suffix("gb").or_else(|| s.strip_suffix(['g', 'G']))
{
parse_u64(gigabytes, Self::Gigabytes)
} else if let Some(megabytes) = s.strip_suffix("mb").or_else(|| s.strip_suffix(['m', 'M']))
{
parse_u64(megabytes, Self::Megabytes)
} else if let Some(kilobytes) = s.strip_suffix("kb").or_else(|| s.strip_suffix(['k', 'K']))
{
parse_u64(kilobytes, Self::Kilobytes)
} else if let Some(bytes) = s.strip_suffix(['b', 'B']) {
parse_u64(bytes, Self::Bytes)
} else {
Err(ParseByteValueError::UnknownUnit(s.to_owned()))
}
}
}
fn parse_u64<T>(value: &str, f: impl FnOnce(u64) -> T) -> Result<T, ParseByteValueError> {
value
.parse()
.map(f)
.map_err(|source| ParseByteValueError::ParseInt {
source,
value: value.to_owned(),
})
}
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum ParseByteValueError {
#[error("value was an empty string")]
Empty,
#[error(
"value `{0}` contains an unknown unit, \
unit must be \"b\", \"k\", \"kb\", \"m\", \"mb\", \"g\", or \"gb\""
)]
UnknownUnit(String),
#[error("value `{value}` could not be parsed as an unsigned integer")]
ParseInt {
source: ParseIntError,
value: String,
},
}
impl Display for ByteValue {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let (Self::Bytes(bytes)
| Self::Kilobytes(bytes)
| Self::Megabytes(bytes)
| Self::Gigabytes(bytes)) = self;
Display::fmt(bytes, f)?;
f.write_str(self.unit())
}
}
impl<'de> Deserialize<'de> for ByteValue {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
deserializer.deserialize_any(Visitor)
}
}
struct Visitor;
impl<'de> de::Visitor<'de> for Visitor {
type Value = ByteValue;
fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
formatter
.write_str("an integer (representing bytes) or a string in the from \"{value}{unit}\"")
}
fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
Ok(ByteValue::Bytes(v))
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
v.parse().map_err(error_chain)
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use serde::de::value::U64Deserializer;
use super::*;
#[test]
fn bytes() {
assert_eq!(
ByteValue::deserialize(U64Deserializer::<de::value::Error>::new(1000)).unwrap(),
ByteValue::Bytes(1000),
);
assert_eq!(ByteValue::Bytes(1000), "1000".parse().unwrap());
let string = "1000b";
let value: ByteValue = string.parse().unwrap();
assert_eq!(value, ByteValue::Bytes(1000));
assert_eq!(value.to_string(), string);
}
#[test]
fn kilobytes() {
let string = "256kb";
let value: ByteValue = string.parse().unwrap();
assert_eq!(value, ByteValue::Kilobytes(256));
assert_eq!(value.to_string(), string);
assert_eq!(value, "256k".parse().unwrap());
}
#[test]
fn megabytes() {
let string = "256mb";
let value: ByteValue = string.parse().unwrap();
assert_eq!(value, ByteValue::Megabytes(256));
assert_eq!(value.to_string(), string);
assert_eq!(value, "256m".parse().unwrap());
}
#[test]
fn gigabytes() {
let string = "2gb";
let value: ByteValue = string.parse().unwrap();
assert_eq!(value, ByteValue::Gigabytes(2));
assert_eq!(value.to_string(), string);
assert_eq!(value, "2g".parse().unwrap());
}
}