#![forbid(unsafe_code)]
use super::ConversationHandler;
use crate::error::ErrorCode;
use std::ffi::{CStr, CString};
use std::io::{self, BufRead, Write};
/// Removes a single trailing line terminator (`"\n"` or `"\r\n"`) from `s` in place.
///
/// A trailing `'\r'` that is not followed by `'\n'` is kept, and at most one
/// terminator is stripped, so `"a\n\n"` becomes `"a\n"`.
fn trim_newline(s: &mut String) {
    if let Some(body) = s.strip_suffix('\n') {
        // Drop the optional carriage return that precedes the line feed.
        let keep = body.strip_suffix('\r').unwrap_or(body).len();
        s.truncate(keep);
    }
}
/// PAM conversation handler that interacts with the user via stdin and stderr.
///
/// Its only configurable state is the pair of prefixes prepended to
/// informational and error messages.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Conversation {
    /// Text prepended to every informational message.
    info_prefix: String,
    /// Text prepended to every error message.
    error_prefix: String,
}

impl Conversation {
    /// Creates a handler using the prefixes `"[PAM INFO] "` and `"[PAM ERROR] "`.
    #[must_use]
    pub fn new() -> Self {
        let info_prefix = String::from("[PAM INFO] ");
        let error_prefix = String::from("[PAM ERROR] ");
        Conversation {
            info_prefix,
            error_prefix,
        }
    }

    /// Returns the prefix currently used for informational messages.
    #[inline]
    #[must_use]
    pub fn info_prefix(&self) -> &str {
        self.info_prefix.as_str()
    }

    /// Replaces the prefix used for informational messages.
    pub fn set_info_prefix(&mut self, prefix: impl Into<String>) {
        self.info_prefix = prefix.into();
    }

    /// Returns the prefix currently used for error messages.
    #[inline]
    #[must_use]
    pub fn error_prefix(&self) -> &str {
        self.error_prefix.as_str()
    }

    /// Replaces the prefix used for error messages.
    pub fn set_error_prefix(&mut self, prefix: impl Into<String>) {
        self.error_prefix = prefix.into();
    }
}

impl Default for Conversation {
    /// Equivalent to [`Conversation::new`].
    fn default() -> Self {
        Conversation::new()
    }
}
/// Terminal implementation of the PAM conversation callbacks.
///
/// Prompts and messages are written to stderr. Visible answers are read from
/// stdin, and hidden answers are delegated to `rpassword::prompt_password`.
impl ConversationHandler for Conversation {
    /// Writes `msg` to stderr verbatim and reads one line of input from stdin.
    ///
    /// A single trailing `"\n"` or `"\r\n"` is stripped from the answer.
    ///
    /// # Errors
    ///
    /// Returns `ErrorCode::CONV_ERR` in any of these cases:
    /// - the prompt cannot be written;
    /// - reading stdin fails or stdin is already at end-of-file;
    /// - the answer contains a NUL byte and so cannot become a `CString`.
    fn prompt_echo_on(&mut self, msg: &CStr) -> Result<CString, ErrorCode> {
        let mut line = String::new();
        if io::stderr().lock().write_all(msg.to_bytes()).is_err() {
            return Err(ErrorCode::CONV_ERR);
        }
        match io::stdin().lock().read_line(&mut line) {
            // `Ok(0)` means EOF: there is no answer to hand back.
            Err(_) | Ok(0) => Err(ErrorCode::CONV_ERR),
            Ok(_) => {
                trim_newline(&mut line);
                CString::new(line).map_err(|_| ErrorCode::CONV_ERR)
            }
        }
    }

    /// Shows `msg` (decoded lossily as UTF-8) as the prompt and reads a hidden
    /// answer via `rpassword::prompt_password`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorCode::CONV_ERR` if `rpassword` reports an error, or if the
    /// answer contains a NUL byte.
    fn prompt_echo_off(&mut self, msg: &CStr) -> Result<CString, ErrorCode> {
        let prompt = msg.to_string_lossy();
        match rpassword::prompt_password(&prompt) {
            Err(_) => Err(ErrorCode::CONV_ERR),
            Ok(password) => CString::new(password).map_err(|_| ErrorCode::CONV_ERR),
        }
    }

    /// Prints `msg` (decoded lossily as UTF-8) to stderr after the info prefix.
    fn text_info(&mut self, msg: &CStr) {
        // `eprintln!` panics if stderr cannot be written (e.g. a closed pipe).
        // This callback has no way to report failure, so the write is
        // best-effort and its result is intentionally ignored.
        let _ = writeln!(io::stderr(), "{}{}", self.info_prefix, msg.to_string_lossy());
    }

    /// Prints `msg` (decoded lossily as UTF-8) to stderr after the error prefix.
    fn error_msg(&mut self, msg: &CStr) {
        // Best-effort for the same reason as `text_info`: never panic on a
        // failed diagnostic write.
        let _ = writeln!(io::stderr(), "{}{}", self.error_prefix, msg.to_string_lossy());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `trim_newline` strips exactly one trailing `"\n"` / `"\r\n"` and
    /// leaves everything else alone.
    #[test]
    fn test_trim() {
        // Each pair is (input, expected value after trimming).
        let cases = [
            ("Test\r\n", "Test"),
            ("Test\n", "Test"),
            // Nothing to strip without a terminator.
            ("Test", "Test"),
            ("", ""),
            // A bare carriage return is not treated as a line terminator.
            ("Test\r", "Test\r"),
            // Only one trailing terminator is removed.
            ("Test\n\n", "Test\n"),
            ("\r\n", ""),
        ];
        for (input, expected) in cases.iter() {
            let mut value = String::from(*input);
            trim_newline(&mut value);
            assert_eq!(value, *expected, "input: {:?}", input);
        }
    }

    /// Prefix setters/getters round-trip. The message callbacks are exercised
    /// only to check that they run; their stderr output is not asserted.
    #[test]
    fn test_output() {
        let mut c = Conversation::default();
        c.set_info_prefix("INFO: ");
        c.set_error_prefix("ERROR: ");
        assert_eq!(c.info_prefix(), "INFO: ");
        assert_eq!(c.error_prefix(), "ERROR: ");
        c.text_info(&CString::new("test").unwrap());
        c.error_msg(&CString::new("test2").unwrap());
        assert!(format!("{:?}", &c).contains("ERROR: "));
    }
}