#![allow(dead_code)]
use crate::constants::{ErrorCode, Result};
use std::cell::Cell;
use std::ffi::{OsStr, OsString};
use std::fmt;
use std::fmt::Debug;
use std::result::Result as StdResult;
/// One message in a conversation, borrowed from the message object that
/// holds its question and its answer (which defaults to
/// `Err(ErrorCode::ConversationError)` until set).
///
/// A [`Conversation`] answers each variant by calling `set_answer` on the
/// wrapped message, or fails it with [`Exchange::set_error`].
#[derive(Debug)]
#[non_exhaustive]
pub enum Exchange<'a> {
/// Text prompt; answered with an `OsString`.
Prompt(&'a QAndA<'a>),
/// Masked text prompt; answered with an `OsString`.
MaskedPrompt(&'a MaskedQAndA<'a>),
/// Error message for the user; answered with `()`.
Error(&'a ErrorMsg<'a>),
/// Informational message for the user; answered with `()`.
Info(&'a InfoMsg<'a>),
/// Radio prompt; answered with an `OsString`.
RadioPrompt(&'a RadioQAndA<'a>),
/// Prompt carrying bytes plus a type byte; answered with [`BinaryData`].
BinaryPrompt(&'a BinaryQAndA<'a>),
}
impl Exchange<'_> {
pub fn set_error(&self, err: ErrorCode) {
match *self {
Exchange::Prompt(m) => m.set_answer(Err(err)),
Exchange::MaskedPrompt(m) => m.set_answer(Err(err)),
Exchange::Error(m) => m.set_answer(Err(err)),
Exchange::Info(m) => m.set_answer(Err(err)),
Exchange::RadioPrompt(m) => m.set_answer(Err(err)),
Exchange::BinaryPrompt(m) => m.set_answer(Err(err)),
}
}
}
// Generates a conversation message type `$name` holding a question of type
// `Q` and an answer slot of type `Result<A>`, together with its constructor,
// accessors, and a `Debug` impl. `$val` is the `Exchange` variant constructor
// that wraps a reference to the new type.
//
// Any attributes given before the name (doc comments, `#[cfg]`, ...) are
// applied to the struct and to each generated `impl` block.
macro_rules! q_and_a {
    ($(#[$m:meta])* $name:ident<'a, Q=$qt:ty, A=$at:ty>, $val:path) => {
        $(#[$m])*
        pub struct $name<'a> {
            /// The question or payload sent to the user.
            q: $qt,
            /// The response; `Err(ErrorCode::ConversationError)` until set.
            a: Cell<Result<$at>>,
        }

        $(#[$m])*
        impl<'a> $name<'a> {
            #[doc = concat!("Creates a `", stringify!($name), "` to be sent to the user.")]
            ///
            /// Its answer starts out as `Err(ErrorCode::ConversationError)`
            /// until [`Self::set_answer`] is called.
            pub fn new(question: $qt) -> Self {
                Self {
                    q: question,
                    a: Cell::new(Err(ErrorCode::ConversationError)),
                }
            }

            /// Wraps a reference to this message in its [`Exchange`] variant
            /// so it can be handed to a conversation.
            pub fn exchange(&self) -> Exchange<'_> {
                $val(self)
            }

            /// Returns the question (the payload sent to the user).
            pub fn question(&self) -> $qt {
                self.q
            }

            /// Stores the response, replacing any previously stored one.
            pub fn set_answer(&self, answer: Result<$at>) {
                self.a.set(answer)
            }

            /// Consumes the message and returns its response, or
            /// `Err(ErrorCode::ConversationError)` if none was ever set.
            pub fn answer(self) -> Result<$at> {
                self.a.into_inner()
            }
        }

        $(#[$m])*
        impl fmt::Debug for $name<'_> {
            // Only the question is printed; the answer is elided as `..`.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> StdResult<(), fmt::Error> {
                f.debug_struct(stringify!($name))
                    .field("q", &self.q)
                    .finish_non_exhaustive()
            }
        }
    };
}
q_and_a!(
    /// A prompt for a masked text answer, such as a password.
    MaskedQAndA<'a, Q = &'a OsStr, A = OsString>,
    Exchange::MaskedPrompt
);
q_and_a!(
    /// A prompt for a plain text answer.
    QAndA<'a, Q = &'a OsStr, A = OsString>,
    Exchange::Prompt
);
q_and_a!(
    /// A radio prompt with a text answer.
    ///
    /// When routed through [`Demux`], an adapter that does not override
    /// `radio_prompt` answers it with `ErrorCode::ConversationError`.
    RadioQAndA<'a, Q = &'a OsStr, A = OsString>,
    Exchange::RadioPrompt
);
q_and_a!(
    /// A prompt that sends bytes with a type byte and expects [`BinaryData`]
    /// back.
    ///
    /// When routed through [`Demux`], an adapter that does not override
    /// `binary_prompt` answers it with `ErrorCode::ConversationError`.
    BinaryQAndA<'a, Q = (&'a [u8], u8), A = BinaryData>,
    Exchange::BinaryPrompt
);
/// Bytes exchanged in a binary prompt, tagged with a one-byte type.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct BinaryData {
    /// The raw bytes.
    pub data: Vec<u8>,
    /// A one-byte tag describing `data`; this module does not interpret it.
    pub data_type: u8,
}

impl BinaryData {
    /// Creates a `BinaryData` from anything convertible into a byte vector
    /// plus its type tag.
    pub fn new(data: impl Into<Vec<u8>>, data_type: u8) -> Self {
        Self {
            data: data.into(),
            data_type,
        }
    }
}

impl<IV: Into<Vec<u8>>> From<(IV, u8)> for BinaryData {
    /// Builds a `BinaryData` from a `(data, data_type)` pair; equivalent to
    /// [`BinaryData::new`].
    fn from((data, data_type): (IV, u8)) -> Self {
        Self::new(data, data_type)
    }
}

impl From<BinaryData> for (Vec<u8>, u8) {
    /// Splits a `BinaryData` into its owned `(data, data_type)` parts.
    fn from(value: BinaryData) -> Self {
        let BinaryData { data, data_type } = value;
        (data, data_type)
    }
}

impl<'a> From<&'a BinaryData> for (&'a [u8], u8) {
    /// Borrows a `BinaryData` as a `(bytes, data_type)` pair.
    fn from(value: &'a BinaryData) -> Self {
        (value.data.as_slice(), value.data_type)
    }
}
q_and_a!(
    /// An informational message for the user; answered with `()`.
    InfoMsg<'a, Q = &'a OsStr, A = ()>,
    Exchange::Info
);
q_and_a!(
    /// An error message for the user; answered with `()`.
    ErrorMsg<'a, Q = &'a OsStr, A = ()>,
    Exchange::Error
);
/// A handler that receives batches of messages and answers them.
///
/// Implementations answer a message by calling its `set_answer` method (as
/// [`Demux`] does) or by calling [`Exchange::set_error`]. A message left
/// unanswered keeps its default answer, `Err(ErrorCode::ConversationError)`.
pub trait Conversation {
/// Handles `messages`, recording an answer on each one it responds to.
fn communicate(&self, messages: &[Exchange]);
}
/// Creates a [`Conversation`] that handles every batch of messages by
/// calling `func` with it.
pub fn conversation_func(func: impl Fn(&[Exchange])) -> impl Conversation {
    FunctionConvo(func)
}

/// Adapter that turns a closure into a [`Conversation`].
///
/// The `Fn` bound lives on the impl rather than the struct, so the bound is
/// only required where the conversation behavior is actually used.
struct FunctionConvo<C>(C);

impl<C: Fn(&[Exchange])> Conversation for FunctionConvo<C> {
    fn communicate(&self, messages: &[Exchange]) {
        (self.0)(messages)
    }
}
// NOTE(review): this struct is not referenced anywhere in this file and has
// no `Conversation`/`ConversationAdapter` impl here; its dead-code warning is
// silenced by the file-level `#![allow(dead_code)]`. Confirm it is still
// needed before keeping it.
/// Holds a username and a password.
struct UsernamePasswordConvo {
username: String,
password: String,
}
pub trait ConversationAdapter {
fn into_conversation(self) -> Demux<Self>
where
Self: Sized,
{
Demux(self)
}
fn prompt(&self, request: impl AsRef<OsStr>) -> Result<OsString>;
fn masked_prompt(&self, request: impl AsRef<OsStr>) -> Result<OsString>;
fn error_msg(&self, message: impl AsRef<OsStr>);
fn info_msg(&self, message: impl AsRef<OsStr>);
fn radio_prompt(&self, request: impl AsRef<OsStr>) -> Result<OsString> {
let _ = request;
Err(ErrorCode::ConversationError)
}
fn binary_prompt(&self, data_and_type: (&[u8], u8)) -> Result<BinaryData> {
let _ = data_and_type;
Err(ErrorCode::ConversationError)
}
}
/// Wraps a [`ConversationAdapter`] so it can be used as a [`Conversation`];
/// the same as [`ConversationAdapter::into_conversation`].
impl<CA: ConversationAdapter> From<CA> for Demux<CA> {
    fn from(adapter: CA) -> Self {
        Self(adapter)
    }
}
// Generates a `ConversationAdapter` method that builds one message of type
// `$msg` from its argument and sends it through `self.communicate`.
//
// With `-> $resp_type` the generated method returns the message's answer;
// without it the method returns `()` and the answer is discarded.
macro_rules! conv_fn {
    ($(#[$m:meta])* $fn_name:ident($param:tt: $pt:ty) -> $resp_type:ty { $msg:ty }) => {
        $(#[$m])*
        fn $fn_name(&self, $param: impl AsRef<$pt>) -> Result<$resp_type> {
            let outgoing = <$msg>::new($param.as_ref());
            self.communicate(&[outgoing.exchange()]);
            outgoing.answer()
        }
    };
    ($(#[$m:meta])*$fn_name:ident($param:tt: $pt:ty) { $msg:ty }) => {
        $(#[$m])*
        fn $fn_name(&self, $param: impl AsRef<$pt>) {
            let outgoing = <$msg>::new($param.as_ref());
            self.communicate(&[outgoing.exchange()]);
        }
    };
}
/// Every [`Conversation`] is also a [`ConversationAdapter`]: each method
/// sends exactly one message through [`Conversation::communicate`] and
/// returns that message's answer.
impl<C: Conversation + ?Sized> ConversationAdapter for C {
    conv_fn!(prompt(message: OsStr) -> OsString { QAndA });
    conv_fn!(masked_prompt(message: OsStr) -> OsString { MaskedQAndA });
    conv_fn!(radio_prompt(message: OsStr) -> OsString { RadioQAndA });
    conv_fn!(error_msg(message: OsStr) { ErrorMsg });
    conv_fn!(info_msg(message: OsStr) { InfoMsg });

    fn binary_prompt(&self, data_and_type: (&[u8], u8)) -> Result<BinaryData> {
        let outgoing = BinaryQAndA::new(data_and_type);
        self.communicate(&[outgoing.exchange()]);
        outgoing.answer()
    }
}
/// Turns a [`ConversationAdapter`] into a [`Conversation`] by routing each
/// [`Exchange`] to the matching adapter method.
pub struct Demux<CA: ConversationAdapter>(CA);

impl<CA: ConversationAdapter> Demux<CA> {
    /// Unwraps the `Demux`, returning the adapter it was built from.
    fn into_inner(self) -> CA {
        let Demux(adapter) = self;
        adapter
    }
}
impl<CA: ConversationAdapter> Conversation for Demux<CA> {
fn communicate(&self, messages: &[Exchange]) {
for msg in messages {
match msg {
Exchange::Prompt(prompt) => prompt.set_answer(self.0.prompt(prompt.question())),
Exchange::MaskedPrompt(prompt) => {
prompt.set_answer(self.0.masked_prompt(prompt.question()))
}
Exchange::RadioPrompt(prompt) => {
prompt.set_answer(self.0.radio_prompt(prompt.question()))
}
Exchange::Info(prompt) => {
self.0.info_msg(prompt.question());
prompt.set_answer(Ok(()))
}
Exchange::Error(prompt) => {
self.0.error_msg(prompt.question());
prompt.set_answer(Ok(()))
}
Exchange::BinaryPrompt(prompt) => {
let q = prompt.question();
prompt.set_answer(self.0.binary_prompt(q))
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Routes every message kind through `Demux` to a `ConversationAdapter`
    /// and checks that each answer lands on the right message.
    #[test]
    fn test_demux() {
        #[derive(Default)]
        struct DemuxTester {
            error_ran: Cell<bool>,
            info_ran: Cell<bool>,
        }

        impl ConversationAdapter for DemuxTester {
            fn prompt(&self, request: impl AsRef<OsStr>) -> Result<OsString> {
                match request.as_ref().to_str().unwrap() {
                    "what" => Ok("whatwhat".into()),
                    "give_err" => Err(ErrorCode::PermissionDenied),
                    _ => panic!("unexpected prompt!"),
                }
            }
            fn masked_prompt(&self, request: impl AsRef<OsStr>) -> Result<OsString> {
                assert_eq!("reveal", request.as_ref());
                Ok("my secrets".into())
            }
            fn error_msg(&self, message: impl AsRef<OsStr>) {
                self.error_ran.set(true);
                assert_eq!("whoopsie", message.as_ref());
            }
            fn info_msg(&self, message: impl AsRef<OsStr>) {
                self.info_ran.set(true);
                assert_eq!("did you know", message.as_ref());
            }
            fn radio_prompt(&self, request: impl AsRef<OsStr>) -> Result<OsString> {
                assert_eq!("channel?", request.as_ref());
                Ok("zero".into())
            }
            fn binary_prompt(&self, data_and_type: (&[u8], u8)) -> Result<BinaryData> {
                assert_eq!((&[10, 9, 8][..], 66), data_and_type);
                Ok(BinaryData::new(vec![5, 5, 5], 5))
            }
        }

        let tester = DemuxTester::default();

        let what = QAndA::new("what".as_ref());
        let pass = MaskedQAndA::new("reveal".as_ref());
        let err = ErrorMsg::new("whoopsie".as_ref());
        let info = InfoMsg::new("did you know".as_ref());
        let has_err = QAndA::new("give_err".as_ref());

        let conv = tester.into_conversation();
        conv.communicate(&[
            what.exchange(),
            pass.exchange(),
            err.exchange(),
            info.exchange(),
            has_err.exchange(),
        ]);

        assert_eq!("whatwhat", what.answer().unwrap());
        assert_eq!("my secrets", pass.answer().unwrap());
        assert_eq!(Ok(()), err.answer());
        assert_eq!(Ok(()), info.answer());
        assert_eq!(ErrorCode::PermissionDenied, has_err.answer().unwrap_err());

        let tester = conv.into_inner();
        assert!(tester.error_ran.get());
        assert!(tester.info_ran.get());

        // Radio and binary prompts have error-returning defaults in
        // `ConversationAdapter`; this exercises the tester's overrides.
        {
            let conv = tester.into_conversation();
            let radio = RadioQAndA::new("channel?".as_ref());
            let bin = BinaryQAndA::new((&[10, 9, 8], 66));
            conv.communicate(&[radio.exchange(), bin.exchange()]);
            assert_eq!("zero", radio.answer().unwrap());
            assert_eq!(BinaryData::from(([5, 5, 5], 5)), bin.answer().unwrap());
        }
    }

    /// Drives a `Conversation` through the `ConversationAdapter` helper
    /// methods, each of which sends exactly one message.
    #[test]
    fn test_mux() {
        struct MuxTester;

        impl Conversation for MuxTester {
            fn communicate(&self, messages: &[Exchange]) {
                if let [msg] = messages {
                    match *msg {
                        Exchange::Info(info) => {
                            assert_eq!("let me tell you", info.question());
                            info.set_answer(Ok(()))
                        }
                        Exchange::Error(error) => {
                            assert_eq!("oh no", error.question());
                            error.set_answer(Ok(()))
                        }
                        Exchange::Prompt(prompt) => {
                            prompt.set_answer(match prompt.question().to_str().unwrap() {
                                "should_err" => Err(ErrorCode::PermissionDenied),
                                "question" => Ok("answer".into()),
                                other => panic!("unexpected question {other:?}"),
                            })
                        }
                        Exchange::MaskedPrompt(ask) => {
                            assert_eq!("password!", ask.question());
                            ask.set_answer(Ok("open sesame".into()))
                        }
                        Exchange::BinaryPrompt(prompt) => {
                            assert_eq!((&[1, 2, 3][..], 69), prompt.question());
                            prompt.set_answer(Ok(BinaryData::from((&[3, 2, 1], 42))))
                        }
                        Exchange::RadioPrompt(ask) => {
                            assert_eq!("radio?", ask.question());
                            ask.set_answer(Ok("yes".into()))
                        }
                    }
                } else {
                    panic!(
                        "there should only be one message, not {len}",
                        len = messages.len()
                    )
                }
            }
        }

        let tester = MuxTester;

        assert_eq!("answer", tester.prompt("question").unwrap());
        assert_eq!("open sesame", tester.masked_prompt("password!").unwrap());
        tester.error_msg("oh no");
        tester.info_msg("let me tell you");
        // Radio and binary prompts.
        {
            assert_eq!("yes", tester.radio_prompt("radio?").unwrap());
            assert_eq!(
                BinaryData::new(vec![3, 2, 1], 42),
                tester.binary_prompt((&[1, 2, 3], 69)).unwrap(),
            )
        }
        // An error set by the conversation is returned to the caller as-is.
        assert_eq!(
            ErrorCode::PermissionDenied,
            tester.prompt("should_err").unwrap_err(),
        );
    }

    /// A conversation that never answers leaves each message with its
    /// default `ConversationError` answer.
    #[test]
    fn test_unanswered_message_is_conversation_error() {
        let silent = conversation_func(|_| {});
        assert_eq!(
            ErrorCode::ConversationError,
            silent.prompt("anyone?").unwrap_err()
        );
        assert_eq!(
            ErrorCode::ConversationError,
            silent.masked_prompt("return_wrong_type").unwrap_err()
        );
    }
}