use std::{
borrow::Cow,
error::Error,
fmt::{self, Debug},
time::Duration,
};
use bytes::{BufMut, Bytes, BytesMut};
/// Line written before the commands when a [`CommandList`] is rendered as a command list.
const COMMAND_LIST_BEGIN: &[u8] = b"command_list_ok_begin\n";
/// Line written after the commands when a [`CommandList`] is rendered as a command list.
const COMMAND_LIST_END: &[u8] = b"command_list_end\n";
/// A single command: a validated command name followed by any arguments added so far.
///
/// The wrapped buffer holds the rendered text of the command without a trailing newline;
/// each argument is appended after a single space.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Command(pub(crate) BytesMut);
impl Command {
    /// Creates a command from the given command name.
    ///
    /// # Panics
    ///
    /// Panics if the name is rejected by [`Command::build`]. The panic is attributed to the
    /// caller of this function.
    #[track_caller]
    pub fn new(command: &str) -> Command {
        // A `match` (not a closure) keeps the `panic!` inside the `#[track_caller]` function.
        let built = Command::build(command);
        match built {
            Err(e) => panic!("invalid command: {}", e),
            Ok(c) => c,
        }
    }

    /// Creates a command from the given command name, without panicking.
    ///
    /// # Errors
    ///
    /// Returns a [`CommandError`] carrying a copy of `command` if the name is empty, contains
    /// a character other than an ASCII letter or `_`, or starts with `command_list`.
    pub fn build(command: &str) -> Result<Command, CommandError> {
        if let Err(kind) = validate_command_part(command) {
            let data = Bytes::copy_from_slice(command.as_bytes());
            return Err(CommandError { data, kind });
        }

        Ok(Command(BytesMut::from(command)))
    }

    /// Appends an argument and returns the command, for use in builder-style chains.
    ///
    /// # Panics
    ///
    /// Panics if [`Command::add_argument`] rejects the argument. The panic is attributed to
    /// the caller of this function.
    #[track_caller]
    pub fn argument<A: Argument>(mut self, argument: A) -> Command {
        match self.add_argument(argument) {
            Ok(()) => self,
            Err(e) => panic!("invalid argument: {}", e),
        }
    }

    /// Appends an argument to the command in place.
    ///
    /// The argument is rendered directly into the command buffer after a separating space and
    /// validated afterwards.
    ///
    /// # Errors
    ///
    /// Returns a [`CommandError`] holding the rendered argument if it contains a newline. In
    /// that case the command is restored to its state before the call.
    pub fn add_argument<A: Argument>(&mut self, argument: A) -> Result<(), CommandError> {
        let previous_len = self.0.len();
        // The rendered argument begins right after the separating space.
        let argument_start = previous_len + 1;

        self.0.put_u8(b' ');
        argument.render(&mut self.0);

        let verdict = validate_argument(&self.0[argument_start..]);
        match verdict {
            Ok(()) => Ok(()),
            Err(kind) => {
                // Detach the rejected text for the error, then drop the separator as well.
                let data = self.0.split_off(argument_start).freeze();
                self.0.truncate(previous_len);
                Err(CommandError { data, kind })
            }
        }
    }
}
/// An ordered collection of [`Command`]s that are rendered together into one buffer.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CommandList(pub(crate) Vec<Command>);
#[allow(clippy::len_without_is_empty)]
impl CommandList {
pub fn new(first: Command) -> Self {
CommandList(vec![first])
}
pub fn command(mut self, command: Command) -> Self {
self.add(command);
self
}
pub fn add(&mut self, command: Command) {
self.0.push(command);
}
pub fn len(&self) -> usize {
self.0.len()
}
pub(crate) fn render(mut self) -> BytesMut {
if self.len() == 1 {
let mut buf = self.0.pop().unwrap().0;
buf.put_u8(b'\n');
return buf;
}
let required_length = COMMAND_LIST_BEGIN.len()
+ self.0.iter().map(|c| c.0.len() + 1).sum::<usize>()
+ COMMAND_LIST_END.len();
let mut buf = BytesMut::with_capacity(required_length);
buf.put_slice(COMMAND_LIST_BEGIN);
for command in self.0 {
buf.put_slice(&command.0);
buf.put_u8(b'\n');
}
buf.put_slice(COMMAND_LIST_END);
buf
}
}
impl Extend<Command> for CommandList {
fn extend<T: IntoIterator<Item = Command>>(&mut self, iter: T) {
self.0.extend(iter);
}
}
/// Escapes a single argument so it can be placed on a command line.
///
/// Backslashes, double quotes and single quotes are each prefixed with a backslash. If the
/// argument is empty or contains a space or a tab, the result is additionally wrapped in
/// double quotes.
///
/// Returns [`Cow::Borrowed`] when the argument can be used unchanged, so that case does not
/// allocate; otherwise returns a newly built [`Cow::Owned`] string.
pub fn escape_argument(argument: &str) -> Cow<'_, str> {
    // An empty argument has to be written as `""`: rendered as nothing, only the separating
    // space would be emitted and the argument would vanish from the command line.
    let needs_quotes = argument.is_empty() || argument.contains(&[' ', '\t'][..]);
    let escape_count = argument.chars().filter(|c| should_escape(*c)).count();

    if escape_count == 0 && !needs_quotes {
        return Cow::Borrowed(argument);
    }

    // One extra byte per escaped character, plus two for the surrounding quotes if needed.
    let quote_len = if needs_quotes { 2 } else { 0 };
    let mut out = String::with_capacity(argument.len() + escape_count + quote_len);

    if needs_quotes {
        out.push('"');
    }
    for c in argument.chars() {
        if should_escape(c) {
            out.push('\\');
        }
        out.push(c);
    }
    if needs_quotes {
        out.push('"');
    }

    Cow::Owned(out)
}

/// Returns `true` for characters that must be prefixed with a backslash inside an argument:
/// the backslash itself, the double quote and the single quote.
fn should_escape(c: char) -> bool {
    matches!(c, '\\' | '"' | '\'')
}
/// Checks that `command` is acceptable as a command name.
///
/// The checks run in this order, and the first failure is reported:
/// 1. the name must not be empty (`Empty`),
/// 2. every character must satisfy `is_valid_command_char`; the first offender is reported
///    with its byte offset (`InvalidCharacter`),
/// 3. the name must not be a command-list command (`CommandList`).
fn validate_command_part(command: &str) -> Result<(), CommandErrorKind> {
    if command.is_empty() {
        return Err(CommandErrorKind::Empty);
    }

    for (offset, character) in command.char_indices() {
        if !is_valid_command_char(character) {
            return Err(CommandErrorKind::InvalidCharacter(offset, character));
        }
    }

    if is_command_list_command(command) {
        return Err(CommandErrorKind::CommandList);
    }

    Ok(())
}
/// Checks that an already rendered argument contains no newline byte.
///
/// On failure, reports the index of the first newline within `argument`.
fn validate_argument(argument: &[u8]) -> Result<(), CommandErrorKind> {
    for (index, &byte) in argument.iter().enumerate() {
        if byte == b'\n' {
            return Err(CommandErrorKind::InvalidCharacter(index, '\n'));
        }
    }

    Ok(())
}
/// Returns `true` if `c` may appear in a command name: an ASCII letter or an underscore.
fn is_valid_command_char(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '_')
}
/// Returns `true` if `command` begins with `command_list`, i.e. it would open or close a
/// command list.
fn is_command_list_command(command: &str) -> bool {
    const PREFIX: &str = "command_list";

    // The prefix is pure ASCII, so comparing raw bytes is equivalent to a string prefix test.
    command.as_bytes().starts_with(PREFIX.as_bytes())
}
/// Error returned when a command name or an argument fails validation.
#[derive(Debug)]
pub struct CommandError {
/// The text that was rejected: a copy of the command name passed to `Command::build`, or
/// the rendered argument removed again by `Command::add_argument`.
data: Bytes,
/// The reason `data` was rejected.
kind: CommandErrorKind,
}
/// The specific reason a command name or argument was rejected.
#[derive(Debug)]
enum CommandErrorKind {
/// The command name was the empty string.
Empty,
/// A disallowed character was found. Holds the position of the character (a byte offset
/// into the command name, or an index into the rendered argument) and the character itself.
InvalidCharacter(usize, char),
/// The command name starts with `command_list`.
CommandList,
}
// No methods are overridden here; the trait's provided defaults are used.
impl Error for CommandError {}
impl fmt::Display for CommandError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.kind {
CommandErrorKind::Empty => write!(f, "empty command"),
CommandErrorKind::InvalidCharacter(i, c) => {
write!(
f,
"invalid character {:?} at position {} in {:?}",
c, i, self.data
)
}
CommandErrorKind::CommandList => write!(
f,
"attempted to open or close a command list: {:?}",
self.data
),
}
}
}
/// A value that can be used as an argument of a [`Command`].
pub trait Argument {
/// Appends the textual form of this value to `buf`.
///
/// `Command::add_argument` validates whatever is written here and rejects the argument if
/// the output contains a newline.
fn render(&self, buf: &mut BytesMut);
}
/// References to arguments are arguments themselves and render exactly like the referent.
impl<A> Argument for &A
where
    A: Argument + ?Sized,
{
    fn render(&self, buf: &mut BytesMut) {
        // Peel off one level of reference and dispatch to the referent's implementation.
        let inner: &A = *self;
        <A as Argument>::render(inner, buf);
    }
}
impl Argument for String {
    /// Writes the string to `buf` after passing it through `escape_argument`.
    fn render(&self, buf: &mut BytesMut) {
        buf.put_slice(escape_argument(self.as_str()).as_bytes());
    }
}
impl Argument for str {
    /// Writes the string slice to `buf` after passing it through `escape_argument`.
    fn render(&self, buf: &mut BytesMut) {
        let escaped = escape_argument(self);
        let bytes: &[u8] = escaped.as_bytes();
        buf.put_slice(bytes);
    }
}
impl Argument for Cow<'_, str> {
    /// Writes the contained string to `buf` after passing it through `escape_argument`,
    /// regardless of whether it is borrowed or owned.
    fn render(&self, buf: &mut BytesMut) {
        let text: &str = self;
        let escaped = escape_argument(text);
        buf.put_slice(escaped.as_bytes());
    }
}
impl Argument for bool {
    /// Writes `1` for `true` and `0` for `false`.
    fn render(&self, buf: &mut BytesMut) {
        let digit = match *self {
            true => b'1',
            false => b'0',
        };
        buf.put_u8(digit);
    }
}
impl Argument for Duration {
    /// Writes the duration as a number of seconds with exactly three decimal places,
    /// e.g. two seconds becomes `2.000`.
    fn render(&self, buf: &mut BytesMut) {
        use std::fmt::Write;

        let seconds = self.as_secs_f64();
        buf.write_fmt(format_args!("{:.3}", seconds)).unwrap();
    }
}
// Implements `Argument` for every listed type by writing the value's `Display` output into
// the buffer. The generated code uses absolute paths (`$crate::...`, `::bytes::...`,
// `::std::...`) so it does not depend on what is imported at the expansion site.
// NOTE(review): the trait is named as `$crate::command::Argument`, so this presumably lives
// in the crate's `command` module; confirm before moving the file.
macro_rules! implement_integer_arg {
($($type:ty),+) => {
$(
impl $crate::command::Argument for $type {
fn render(&self, buf: &mut ::bytes::BytesMut) {
use ::std::fmt::Write;
// Panics if writing the formatted value into the buffer reports an error.
::std::write!(buf, "{}", self).unwrap();
}
}
)+
}
}
// Only unsigned integer types are covered.
implement_integer_arg!(u8, u16, u32, u64, usize);
#[cfg(test)]
mod test {
use super::*;
// Adding arguments appends them after a space; a rejected argument leaves the command
// exactly as it was before the call.
#[test]
fn arguments() {
let mut command = Command::new("foo");
assert_eq!(command.0, "foo");
command.add_argument("bar").unwrap();
assert_eq!(command.0, "foo bar");
// An argument containing a newline is rejected and rolled back.
let _e = command.add_argument("foo\nbar").unwrap_err();
assert_eq!(command.0, "foo bar");
}
// Covers the four escaping cases: unchanged, single quote, backslash, and quoting due to
// a space.
#[test]
fn argument_escaping() {
assert_eq!(escape_argument("status"), "status");
assert_eq!(escape_argument("Joe's"), "Joe\\'s");
assert_eq!(escape_argument("hello\\world"), "hello\\\\world");
assert_eq!(escape_argument("foo bar"), r#""foo bar""#);
}
// Checks the rendered form of string, boolean and duration arguments.
#[test]
fn argument_rendering() {
let mut buf = BytesMut::new();
// String arguments are escaped while rendering.
"foo\"bar".render(&mut buf);
assert_eq!(buf, "foo\\\"bar");
buf.clear();
true.render(&mut buf);
assert_eq!(buf, "1");
buf.clear();
false.render(&mut buf);
assert_eq!(buf, "0");
buf.clear();
// Durations are rendered as seconds with three decimal places.
Duration::from_secs(2).render(&mut buf);
assert_eq!(buf, "2.000");
buf.clear();
// Extra precision is rounded to the third decimal place.
Duration::from_secs_f64(2.34567).render(&mut buf);
assert_eq!(buf, "2.346");
buf.clear();
}
}