use std::str::{self, FromStr, Utf8Error};
use binrw::{BinRead, BinResult, binrw};
use thiserror::Error;
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[cfg(test)]
mod tests;
/// A [`FixedString`] holding up to 11 string bytes plus the trailing NUL terminator.
pub type FixedString12 = FixedString<12>;
/// A NUL-terminated byte string stored inline in a fixed buffer of `N` bytes.
///
/// `N` counts the terminator, so at most `N - 1` string bytes fit
/// (see `CAPACITY_WITHOUT_NUL`). When read with `binrw`, the field is parsed by
/// `read_bytes_until_nul`: a NUL must appear within the first `N` bytes, and
/// exactly `N` bytes of the stream are consumed.
#[binrw]
#[derive(Debug)]
#[repr(transparent)]
pub struct FixedString<const N: usize> {
    // Invariant upheld by `new`, `from_str` and `read_bytes_until_nul`: for
    // N >= 1 the buffer contains a NUL, and every byte after the first NUL is
    // zero (stream bytes after the terminator are skipped, not copied).
    #[br(parse_with = read_bytes_until_nul)]
    inner: [u8; N],
}
impl<const N: usize> FixedString<N> {
    /// Size of the backing buffer in bytes, including the NUL terminator.
    pub const CAPACITY: usize = N;
    /// Maximum number of string bytes that fit while still leaving room for the
    /// NUL terminator. Requires `N >= 1`; for `N == 0` this constant fails to
    /// evaluate at compile time when used.
    pub const CAPACITY_WITHOUT_NUL: usize = N - 1;

    /// Creates an empty string whose buffer is entirely zero.
    pub const fn new() -> Self {
        Self { inner: [0; N] }
    }

    /// Returns the number of bytes before the first NUL.
    ///
    /// The scan is bounded by `N`. If the buffer contains no NUL at all (for
    /// example when `N == 0`), `N` is returned instead of indexing out of bounds.
    pub const fn len(&self) -> usize {
        let mut len = 0;
        // Check the bound first so an unterminated buffer cannot panic.
        while len < N && self.inner[len] != 0 {
            len += 1;
        }
        len
    }

    /// Returns `true` when no bytes precede the first NUL.
    pub const fn is_empty(&self) -> bool {
        // A zero-capacity buffer has no byte 0 to inspect and is always empty.
        N == 0 || self.inner[0] == 0
    }

    /// Returns the string bytes up to (not including) the first NUL.
    pub fn as_bytes(&self) -> &[u8] {
        &self.inner[..self.len()]
    }

    /// Borrows the contents as `&str`.
    ///
    /// # Errors
    /// Returns [`Utf8Error`] if the bytes before the first NUL are not valid UTF-8.
    pub fn to_str(&self) -> Result<&str, Utf8Error> {
        str::from_utf8(self.as_bytes())
    }

    /// Copies the contents into an owned [`String`].
    ///
    /// # Errors
    /// Returns [`Utf8Error`] if the bytes before the first NUL are not valid UTF-8.
    pub fn to_string(&self) -> Result<String, Utf8Error> {
        self.to_str().map(|s| s.to_string())
    }
}
impl<const N: usize> Default for FixedString<N> {
fn default() -> Self {
Self::new()
}
}
impl<const N: usize> FromStr for FixedString<N> {
type Err = ParseFixedStringError<N>;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.len() > Self::CAPACITY_WITHOUT_NUL {
return Err(Self::Err::BufferOverflow);
}
let mut buf = [0; N];
for (index, byte) in s.as_bytes().iter().copied().enumerate() {
buf[index] = byte;
}
Ok(Self { inner: buf })
}
}
impl<const N: usize> TryFrom<&String> for FixedString<N> {
    type Error = ParseFixedStringError<N>;

    /// Parses a borrowed `String` using the same rules as the `FromStr` impl.
    fn try_from(value: &String) -> Result<Self, Self::Error> {
        value.as_str().parse::<Self>()
    }
}
impl<const N: usize> TryFrom<&str> for FixedString<N> {
    type Error = ParseFixedStringError<N>;

    /// Parses a string slice using the same rules as the `FromStr` impl.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        value.parse::<Self>()
    }
}
impl<const N: usize> TryFrom<String> for FixedString<N> {
    type Error = ParseFixedStringError<N>;

    /// Parses an owned `String` using the same rules as the `FromStr` impl.
    /// The `String` is dropped afterwards; its bytes are copied, not reused.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.as_str().parse::<Self>()
    }
}
impl<const N: usize> PartialEq for FixedString<N> {
    /// Two values are equal when their bytes before the first NUL match.
    /// Anything after the terminator is not compared.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs == rhs
    }
}
impl<const N: usize> PartialEq<&String> for FixedString<N> {
fn eq(&self, other: &&String) -> bool {
self.as_bytes() == other.as_bytes()
}
}
impl<const N: usize> PartialEq<&str> for FixedString<N> {
fn eq(&self, other: &&str) -> bool {
self.as_bytes() == other.as_bytes()
}
}
impl<const N: usize> PartialEq<String> for FixedString<N> {
fn eq(&self, other: &String) -> bool {
self.as_bytes() == other.as_bytes()
}
}
#[cfg(feature = "serde")]
impl<const N: usize> Serialize for FixedString<N> {
    /// Serializes the bytes before the first NUL as a string.
    ///
    /// If those bytes are not valid UTF-8, a custom serializer error carrying
    /// the `Utf8Error` message is returned instead.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self.to_str() {
            Ok(text) => serializer.serialize_str(text),
            Err(err) => Err(<S::Error as serde::ser::Error>::custom(err)),
        }
    }
}
#[cfg(feature = "serde")]
impl<'de, const N: usize> Deserialize<'de> for FixedString<N> {
    /// Deserializes an owned string and stores it using the `FromStr` rules.
    ///
    /// A string that does not fit (including its NUL terminator) becomes a
    /// custom deserializer error carrying the `ParseFixedStringError` message.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let text: String = Deserialize::deserialize(deserializer)?;
        text.parse::<Self>()
            .map_err(<D::Error as serde::de::Error>::custom)
    }
}
/// Error returned when a string cannot be stored in a `FixedString<N>`.
#[derive(Debug, PartialEq, Error)]
pub enum ParseFixedStringError<const N: usize> {
    /// The input plus its NUL terminator does not fit in the `N`-byte buffer,
    /// i.e. the input is longer than `N - 1` bytes.
    #[error("nul-terminated string exceeds buffer capacity of {} bytes", N)]
    BufferOverflow,
}
/// `binrw` parser for a `FixedString`'s buffer: reads a NUL-terminated string
/// into a zero-filled `[u8; N]`.
///
/// On success the reader is left exactly `N` bytes past where it started,
/// wherever the NUL was found, so the field always occupies `N` bytes in the
/// stream. Bytes after the terminator are skipped and stay zero in the result.
/// If none of the `N` bytes is NUL, a `binrw::Error::AssertFail` is returned.
#[binrw::parser(reader)]
fn read_bytes_until_nul<const N: usize>() -> BinResult<[u8; N]> {
    use std::io::SeekFrom;

    let start = reader.stream_position()?;
    let mut bytes = [0; N];
    for slot in bytes.iter_mut() {
        match u8::read(reader)? {
            // Terminator found: jump past the unused tail so exactly N bytes
            // are consumed.
            0 => {
                reader.seek(SeekFrom::Start(start + N as u64))?;
                return Ok(bytes);
            }
            byte => *slot = byte,
        }
    }
    // All N bytes were non-NUL, leaving no room for a terminator.
    Err(binrw::Error::AssertFail {
        pos: reader.stream_position()?,
        message: "unable to read beyond the end of the buffer".to_string(),
    })
}