use core::fmt::Display;
use const_str::{compare, ends_with, parse, strip_suffix};
use crate::{
prefixes::PREFIX_MAP,
units::{DIMENSIONLESS, UNITS_MAP},
SI,
};
/// Error returned when a unit expression cannot be parsed.
///
/// `message` describes what went wrong and `span` is the piece of the input it refers to
/// (possibly empty). [`Display`] prints the two back to back.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParseError<'a> {
    /// Static description of the problem, e.g. `"Unknown prefix: "`.
    pub message: &'static str,
    /// Slice of the parsed input the message refers to; empty when there is nothing to show.
    pub span: &'a str,
}

impl Display for ParseError<'_> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Message first, then the offending span, with no separator in between.
        write!(f, "{}{}", self.message, self.span)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for ParseError<'_> {}
/// Looks up `abbreviation` in `map` using binary search.
///
/// `map` must be sorted by key in ascending order as defined by `compare!`. Returns the value
/// paired with the key equal to `abbreviation`, or `None` when no key matches.
const fn search<T: Copy, const L: usize>(abbreviation: &str, map: [(&str, T); L]) -> Option<T> {
    use core::cmp::Ordering;

    let (mut lo, mut hi) = (0, L);
    // Invariant: if a matching key exists, it lies within `map[lo..hi]`.
    while lo < hi {
        let mid = lo + (hi - lo) / 2;
        let (key, value) = map[mid];
        match compare!(key, abbreviation) {
            Ordering::Less => lo = mid + 1,
            Ordering::Greater => hi = mid,
            Ordering::Equal => return Some(value),
        }
    }
    None
}
/// Resolves a single unit abbreviation (e.g. `"N"`, `"mN"`, `"μΩ"`) to an [`SI`] value.
///
/// The abbreviation is first looked up verbatim in `UNITS_MAP`. Failing that, it is split as
/// `<prefix><unit>`, where `<unit>` is a `UNITS_MAP` key and `<prefix>` is a `PREFIX_MAP`
/// key, and the unit is scaled by the prefix. When several unit keys are suffixes of the
/// input, the longest one whose leftover prefix is known is used. For example, `"mcd"` is
/// tried as `m` + `cd` before `mc` + `d`.
///
/// # Errors
///
/// - `"Unknown abbreviation: "` (span = the whole input) if no unit key is a suffix of it.
/// - `"Unknown prefix: "` if unit suffixes exist but none leaves a known prefix; the span is
///   the prefix left over by the longest matching unit.
const fn from_abbreviation_checked(abbreviation: &str) -> Result<SI, ParseError<'_>> {
    if let Some(unit) = search(abbreviation, UNITS_MAP) {
        return Ok(unit);
    }
    let units = UNITS_MAP;
    // Best resolution so far: the longest unit suffix whose leftover prefix is known.
    let mut best: Option<SI> = None;
    let mut best_len = 0;
    // Leftover prefix of the longest unit suffix overall; reported if nothing resolves.
    let mut unknown_prefix: Option<&str> = None;
    let mut unknown_len = 0;
    let mut i = 0;
    while i < units.len() {
        let (unit_abbr, unit) = units[i];
        i += 1;
        if !ends_with!(abbreviation, unit_abbr) {
            continue;
        }
        let prefix = match strip_suffix!(abbreviation, unit_abbr) {
            Some(v) => v,
            None => panic!("Expected the suffix to be checked"),
        };
        let len = unit_abbr.len();
        if unknown_prefix.is_none() || len > unknown_len {
            unknown_prefix = Some(prefix);
            unknown_len = len;
        }
        // Only consult PREFIX_MAP when this candidate could beat the current best.
        if best.is_none() || len > best_len {
            if let Some(scale) = search(prefix, PREFIX_MAP) {
                best = Some(unit.scale_by(scale));
                best_len = len;
            }
        }
    }
    if let Some(resolved) = best {
        return Ok(resolved);
    }
    match unknown_prefix {
        Some(prefix) => Err(ParseError {
            message: "Unknown prefix: ",
            span: prefix,
        }),
        None => Err(ParseError {
            message: "Unknown abbreviation: ",
            span: abbreviation,
        }),
    }
}
/// Parses a unit expression into an [`SI`] value, panicking on malformed input.
///
/// This is the panicking counterpart of [`si_checked`]. When evaluated in a `const` context,
/// the panic becomes a compile-time error.
///
/// # Panics
///
/// Panics when parsing fails. The panic text is the [`ParseError`] message followed by its
/// span, truncated to 512 bytes at a UTF-8 character boundary.
pub const fn si(units: &str) -> SI {
    let (message, span) = match si_checked(units) {
        Ok(v) => return v,
        Err(ParseError { message, span }) => (message.as_bytes(), span.as_bytes()),
    };
    // A const `panic!` can only display a single `&str`, so the message and span are first
    // concatenated into a fixed-size buffer.
    let mut buf = [0u8; 512];
    let mut len = 0;
    let mut k = 0;
    while k < message.len() && len < buf.len() {
        buf[len] = message[k];
        len += 1;
        k += 1;
    }
    k = 0;
    while k < span.len() && len < buf.len() {
        buf[len] = span[k];
        len += 1;
        k += 1;
    }
    // Truncation may cut a multi-byte character in half. Keep only the valid UTF-8 prefix
    // rather than replacing the real parse error with an internal one.
    let text = match core::str::from_utf8(buf.split_at(len).0) {
        Ok(v) => v,
        Err(e) => match core::str::from_utf8(buf.split_at(e.valid_up_to()).0) {
            Ok(v) => v,
            Err(_) => panic!("Expected converting the buffer to utf8 to work"),
        },
    };
    panic!("{}", text)
}
/// Parses a unit expression into an [`SI`] value, reporting malformed input as an error.
///
/// An empty string is dimensionless. Otherwise the input is an optional leading constant
/// factor (handled by `front_const`) followed by unit terms (handled by `si_recursive`). The
/// terms are multiplied together, starting from the dimensionless value scaled by that factor.
///
/// # Errors
///
/// Returns the first [`ParseError`] produced while parsing.
pub const fn si_checked(units: &str) -> Result<SI, ParseError<'_>> {
    let bytes = units.as_bytes();
    if bytes.is_empty() {
        return Ok(DIMENSIONLESS);
    }
    match front_const(bytes) {
        Ok((scale, rest)) => si_recursive(rest, DIMENSIONLESS.scale_by(scale)),
        Err(err) => Err(err),
    }
}
const fn front_const(bytes: &[u8]) -> Result<((i128, u128), &[u8]), ParseError<'_>> {
let bytes = take_spaces(bytes);
let (bytes, expect_paren) = if bytes[0] == 40 {
(bytes.split_at(1).1, true)
} else {
(bytes, false)
};
let bytes = take_spaces(bytes);
let (num, bytes) = match parse_num(bytes) {
Ok(v) => v,
Err(_) => (1, bytes),
};
let bytes = take_spaces(bytes);
if bytes.is_empty() || bytes[0] != 47 {
let bytes = match maybe_expect_paren(expect_paren, bytes) {
Ok(v) => v,
Err(e) => return Err(e),
};
return Ok(((num, 1), bytes));
}
let (_, bytes) = bytes.split_at(1); let bytes = take_spaces(bytes);
let (den, bytes) = match parse_num(bytes) {
Ok(v) => v,
Err(e) => return Err(e),
};
let bytes = take_spaces(bytes);
let bytes = match maybe_expect_paren(expect_paren, bytes) {
Ok(v) => v,
Err(e) => return Err(e),
};
let bytes = take_spaces(bytes);
Ok(((num * den.signum(), den.unsigned_abs()), bytes))
}
/// Consumes a closing `)` from the front of `bytes` when `should_expect` is set.
///
/// Returns `bytes` unchanged when `should_expect` is `false`. Otherwise returns the input
/// following the `)`.
///
/// # Errors
///
/// `"Unmatched parenthesis"` when a `)` is expected but `bytes` is empty or starts with
/// something else.
const fn maybe_expect_paren(should_expect: bool, bytes: &[u8]) -> Result<&[u8], ParseError<'_>> {
    if !should_expect {
        return Ok(bytes);
    }
    // Check emptiness first so input such as `"(1"` yields an error instead of an
    // out-of-bounds panic.
    if bytes.is_empty() || bytes[0] != b')' {
        return Err(ParseError {
            message: "Unmatched parenthesis",
            span: "",
        });
    }
    Ok(bytes.split_at(1).1)
}
const fn si_recursive(bytes: &[u8], si: SI) -> Result<SI, ParseError<'_>> {
let bytes = take_spaces(bytes);
if bytes.is_empty() {
return Ok(si);
}
let bytes = if bytes[0] == 42 {
bytes.split_at(1).1
} else if bytes.len() > 2 && bytes[0] == 226 && bytes[1] == 139 && bytes[2] == 133 {
bytes.split_at(3).1
} else {
bytes
};
let bytes = take_spaces(bytes);
let (si2, bytes) = match parse_abbr(bytes) {
Ok(v) => v,
Err(e) => return Err(e),
};
let bytes = take_spaces(bytes);
if bytes.is_empty() || bytes[0] != 94 {
return si_recursive(bytes, si.mul(si2));
}
let (_, bytes) = bytes.split_at(1); let bytes = take_spaces(bytes);
let (num, bytes) = match parse_num(bytes) {
Ok(v) => v,
Err(e) => return Err(e),
};
let bytes = take_spaces(bytes);
if bytes.is_empty() || bytes[0] != 47 {
return si_recursive(bytes, si.mul(si2.powi(num as i32)));
}
let (_, bytes) = bytes.split_at(1); let bytes = take_spaces(bytes);
let (den, bytes) = match parse_num(bytes) {
Ok(v) => v,
Err(e) => return Err(e),
};
let bytes = take_spaces(bytes);
si_recursive(
bytes,
si.mul(si2.powf(((num * den.signum()) as i32, den.unsigned_abs() as u32))),
)
}
/// Splits a leading unit abbreviation off `bytes` and resolves it.
///
/// The abbreviation is the longest leading run of ASCII letters, `Ω` (U+03A9) and `μ`
/// (U+03BC). Returns the resolved [`SI`] value and the unconsumed remainder of `bytes`.
///
/// # Errors
///
/// Propagates the error from `from_abbreviation_checked`. An empty run is passed through
/// as `""`.
const fn parse_abbr(bytes: &[u8]) -> Result<(SI, &[u8]), ParseError<'_>> {
    let mut i = 0;
    while i < bytes.len() {
        // `A`-`Z` / `a`-`z` only; the previous 65..=91 range also accepted `[`.
        if bytes[i].is_ascii_alphabetic() {
            i += 1;
            continue;
        }
        // Two-byte UTF-8 sequences: CE A9 is `Ω` (U+03A9), CE BC is `μ` (U+03BC).
        if i + 1 < bytes.len()
            && bytes[i] == 0xCE
            && (bytes[i + 1] == 0xA9 || bytes[i + 1] == 0xBC)
        {
            i += 2;
            continue;
        }
        break;
    }
    let (chars, tail) = bytes.split_at(i);
    // `chars` holds only ASCII letters and complete two-byte sequences, so it is valid UTF-8.
    let abbreviation = match core::str::from_utf8(chars) {
        Ok(v) => v,
        Err(_) => panic!("The previous code should have only allowed utf8"),
    };
    match from_abbreviation_checked(abbreviation) {
        Ok(parsed) => Ok((parsed, tail)),
        Err(e) => Err(e),
    }
}
/// Returns `bytes` with its leading ASCII space characters (`b' '`) removed.
///
/// Only the space byte is skipped; other whitespace such as tabs is left in place.
const fn take_spaces(bytes: &[u8]) -> &[u8] {
    let mut rest = bytes;
    while let [b' ', tail @ ..] = rest {
        rest = tail;
    }
    rest
}
/// Parses an optionally negative decimal integer from the front of `bytes`.
///
/// Accepts at most one leading `-` followed by one or more ASCII digits. Returns the value
/// together with the unconsumed remainder of `bytes`.
///
/// # Errors
///
/// - `"Expected a number"` if no digit follows the optional `-`.
/// - `"Number too large"` if there are more than 38 significant digits (leading zeros are
///   not counted), which could overflow `i128`.
const fn parse_num(bytes: &[u8]) -> Result<(i128, &[u8]), ParseError<'_>> {
    // Any 38-digit decimal fits in an `i128` (whose maximum is about 1.7e38).
    const MAX_SIGNIFICANT_DIGITS: usize = 38;
    let mut i = 0;
    // A `-` is only meaningful as a leading sign. Previously `-` was accepted anywhere in the
    // run, so inputs like "1-2" or "-" reached `parse!` and panicked.
    if i < bytes.len() && bytes[i] == b'-' {
        i += 1;
    }
    let digits_start = i;
    let mut significant = 0;
    while i < bytes.len() && bytes[i].is_ascii_digit() {
        if significant > 0 || bytes[i] != b'0' {
            significant += 1;
        }
        i += 1;
    }
    if i == digits_start {
        return Err(ParseError {
            message: "Expected a number",
            span: "",
        });
    }
    if significant > MAX_SIGNIFICANT_DIGITS {
        return Err(ParseError {
            message: "Number too large",
            span: "",
        });
    }
    let (digits, tail) = bytes.split_at(i);
    Ok((
        parse!(
            match core::str::from_utf8(digits) {
                Ok(v) => v,
                Err(_) => panic!("Parsing utf-8 shouldn't have failed"),
            },
            i128
        ),
        tail,
    ))
}
#[cfg(test)]
mod tests {
    use static_assertions::const_assert;
    use crate::{
        prefixes::micro,
        si, si_checked,
        units::{hertz, joule, meter, minute, mole, ohm, DIMENSIONLESS},
    };
    // Checks the `Display` output of `SI` values. Base units are joined by `⋅` with `^n` or
    // `^n/d` exponents, and a scale factor is shown in parentheses, e.g. `(1/60)`.
    #[test]
    fn test_display() {
        assert_eq!(format!("{}", mole), "mol");
        assert_eq!(format!("{}", joule), "m^2⋅kg⋅s^-2");
        assert_eq!(format!("{}", joule.powf((1, 2))), "m⋅kg^1/2⋅s^-1");
        assert_eq!(format!("{}", DIMENSIONLESS), "(1)");
        assert_eq!(format!("{}", DIMENSIONLESS.scale_by((1, 100))), "(1/100)");
        assert_eq!(format!("{}", meter.div(minute)), "(1/60)⋅m⋅s^-1");
    }
    // These cases use `const_assert!`, so they are checked while compiling; a parse failure
    // inside `si` would surface as a build error.
    #[test]
    const fn test_si_macro() {
        // Terms separated only by whitespace are multiplied.
        const_assert!(joule.const_eq(si("N m")));
        // Prefixed unit (`m` + `N`) and the `⋅` separator.
        const_assert!(joule.scale_by((1, 1000)).const_eq(si("mN⋅m")));
        // Leading factor `-1/-1` is 1, and `s^-4/2` times `s^2/2` gives `s^-1`.
        const_assert!(hertz.const_eq(si(" -1/-1 ⋅ s ^ -4 / 2 * s ^ 2 / 2 ")));
        // Parenthesised factor, plus a second term without an explicit separator.
        const_assert!(hertz.const_eq(si(" ( 1 ) ⋅ s ^ -4 / 2 s ^ 2 / 2 ")));
        // Non-ASCII prefix and unit symbols.
        const_assert!(micro(ohm).const_eq(si("μΩ")));
        // A separator with no following unit is rejected.
        const_assert!(si_checked("⋅").is_err());
    }
}