use std::borrow::Cow;
use std::io;
use bytes::{BufMut, BytesMut};
/// Characters that may never appear inside a single protocol line: line feed,
/// carriage return and NUL.
pub const WIRE_UNSAFE: &[char] = &['\n', '\r', '\u{0}'];
/// How the encoding helpers react to input that cannot be sent safely as-is.
///
/// `Sanitize` is the default.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EncoderMode {
    /// Repair the input. Characters listed in `WIRE_UNSAFE` are stripped, and a
    /// block body line equal to `END` is written as ` END`.
    #[default]
    Sanitize,
    /// Reject the input with an error instead of modifying it.
    Strict,
}
/// Failures reported by the encoding helpers when running in `EncoderMode::Strict`.
///
/// The helpers in this module hand these back wrapped in an `io::Error`
/// (via `io::Error::other`).
#[derive(Debug, thiserror::Error)]
pub enum EncodeError {
    /// A value contained `\n`, `\r` or `\0`. The payload names the offending
    /// field (for example `"block body line"`).
    #[error("{0} contains characters unsafe for the management protocol (\\n, \\r, or \\0)")]
    UnsafeCharacters(&'static str),
    /// A block body line was exactly `END`, which would close the block prematurely.
    #[error("block body line equals \"END\", which would terminate the block early")]
    EndInBlockBody,
}
/// Makes `s` safe to embed in a single protocol line.
///
/// If `s` contains none of the `WIRE_UNSAFE` characters, it is returned
/// borrowed and nothing is allocated. Otherwise the result depends on `mode`:
/// `Sanitize` returns an owned copy with those characters removed, and `Strict`
/// fails.
///
/// # Errors
///
/// In `Strict` mode, returns an `io::Error` wrapping
/// `EncodeError::UnsafeCharacters(field)` when `s` contains an unsafe character.
pub fn wire_safe<'a>(
    s: &'a str,
    field: &'static str,
    mode: EncoderMode,
) -> Result<Cow<'a, str>, io::Error> {
    if s.contains(WIRE_UNSAFE) {
        return match mode {
            // `replace` with a `&[char]` pattern drops every listed character.
            EncoderMode::Sanitize => Ok(Cow::Owned(s.replace(WIRE_UNSAFE, ""))),
            EncoderMode::Strict => Err(io::Error::other(EncodeError::UnsafeCharacters(field))),
        };
    }
    Ok(Cow::Borrowed(s))
}
/// Backslash-escapes `\` and `"` in `s`.
///
/// Every other character, including control characters such as `\n`, is copied
/// unchanged.
pub fn escape(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut escaped, ch| {
        if matches!(ch, '\\' | '"') {
            escaped.push('\\');
        }
        escaped.push(ch);
        escaped
    })
}
/// Wraps `s` in double quotes.
///
/// `s` is inserted verbatim: embedded `"` or `\` are NOT escaped here
/// (`escape` does that).
pub fn quote(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    quoted.push_str(s);
    quoted.push('"');
    quoted
}
/// Appends `s` followed by a single `\n` to `dst`.
///
/// `s` is written verbatim and is not passed through `wire_safe`. An embedded
/// line terminator in `s` would therefore split it into several protocol lines.
pub fn write_line(dst: &mut BytesMut, s: &str) {
    let payload = s.as_bytes();
    // Reserve once for the payload plus the trailing newline.
    dst.reserve(payload.len() + 1);
    dst.put_slice(payload);
    dst.put_u8(b'\n');
}
/// Appends a multi-line block to `dst`: the `header` line, each entry of
/// `lines`, and a terminating `END` line.
///
/// The header and every body line are passed through `wire_safe`. A body line
/// that is exactly `END` after sanitizing would close the block early. In
/// `Sanitize` mode such a line is written as ` END`; in `Strict` mode it is
/// rejected.
///
/// # Errors
///
/// In `Strict` mode, returns an `io::Error` wrapping
/// `EncodeError::UnsafeCharacters` if the header or a body line contains
/// `\n`, `\r` or `\0`. It returns `EncodeError::EndInBlockBody` if a body line
/// equals `END`. On any error, `dst` is truncated back to its length on entry,
/// so no partial, unterminated block is left in the buffer.
pub fn write_block(
    dst: &mut BytesMut,
    header: &str,
    lines: &[String],
    mode: EncoderMode,
) -> Result<(), io::Error> {
    // Without this rollback, a Strict-mode failure midway would leave the header
    // and earlier lines in `dst` with no `END`. A later flush would then send a
    // half-written block and desynchronise the protocol stream.
    let start = dst.len();
    let result = append_block(dst, header, lines, mode);
    if result.is_err() {
        dst.truncate(start);
    }
    result
}

/// Writes the block for `write_block`. On error it may leave a partial block in
/// `dst`; the caller is responsible for rolling back.
fn append_block(
    dst: &mut BytesMut,
    header: &str,
    lines: &[String],
    mode: EncoderMode,
) -> Result<(), io::Error> {
    // The header is a protocol line too. An embedded newline would inject extra lines.
    let header = wire_safe(header, "block header", mode)?;

    // Upper bound: header + '\n', each line + 2 (covers the 3-byte "END" becoming
    // the 5-byte " END\n"), plus the 4-byte "END\n" terminator.
    let total = header.len() + 1 + lines.iter().map(|line| line.len() + 2).sum::<usize>() + 4;
    dst.reserve(total);

    dst.put_slice(header.as_bytes());
    dst.put_u8(b'\n');
    for line in lines {
        let clean = wire_safe(line, "block body line", mode)?;
        // Checked after sanitizing, so a line that only becomes "END" once
        // unsafe characters are stripped is caught as well.
        if clean == "END" {
            match mode {
                EncoderMode::Sanitize => {
                    dst.put_slice(b" END\n");
                    continue;
                }
                EncoderMode::Strict => {
                    return Err(io::Error::other(EncodeError::EndInBlockBody));
                }
            }
        }
        dst.put_slice(clean.as_bytes());
        dst.put_u8(b'\n');
    }
    dst.put_slice(b"END\n");
    Ok(())
}
/// An optional upper bound.
///
/// NOTE(review): the limited quantity is not visible here. Presumably it is the
/// number of items or lines accumulated before acting; confirm against callers.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccumulationLimit {
    /// No bound is applied.
    Unlimited,
    /// Bounded by the contained value.
    Max(usize),
}