use crate::error::ErrorCode;
use crate::Result;
use crate::c_box::CBox;
use libc::{free, strdup};
use pam_sys::pam_response as PamResponse;
use std::ffi::CString;
use std::{mem, ptr, slice};
/// Fixed-size array of PAM conversation responses owned by this crate.
///
/// Every non-null `resp` pointer stored in a slot is owned by the buffer and is released with
/// `free` when the buffer is dropped, unless the buffer is first converted into a raw
/// `*mut PamResponse`.
#[derive(Debug)]
pub(crate) struct ResponseBuffer {
    /// Response slots; the number of slots is chosen once, at construction time.
    items: CBox<[PamResponse]>,
}
impl ResponseBuffer {
pub fn new(len: isize) -> Result<Self> {
match len {
1..=isize::MAX => {
#[allow(clippy::cast_sign_loss)]
let buffer = CBox::<PamResponse>::try_new_zeroed_slice(len as usize)?;
Ok(Self {
items: unsafe { buffer.assume_all_init() },
})
}
_ => Err(ErrorCode::BUF_ERR.into()),
}
}
#[allow(unused)]
pub fn len(&self) -> usize {
self.items.len()
}
#[allow(unused)]
pub fn is_empty(&self) -> bool {
self.items.len() == 0
}
#[inline]
#[allow(unused)]
pub fn iter(&self) -> slice::Iter<'_, PamResponse> {
self.into_iter()
}
#[inline]
pub fn put(&mut self, index: usize, response: Option<CString>) {
assert!(index < self.items.len());
let dest = &mut self.items[index];
if !dest.resp.is_null() {
unsafe { free(dest.resp.cast()) };
}
*dest = match response {
Some(text) => PamResponse {
resp: unsafe { strdup(text.as_ptr()) },
resp_retcode: 0,
},
None => PamResponse {
resp: ptr::null_mut(),
resp_retcode: 0,
},
}
}
#[inline]
#[allow(clippy::cast_possible_truncation)]
pub fn put_binary(&mut self, index: usize, response_type: u8, response: &[u8]) {
assert!(index < self.items.len());
let len = response.len() + 5;
assert!(len <= u32::MAX as usize);
let dest = &mut self.items[index];
if !dest.resp.is_null() {
unsafe { free(dest.resp.cast()) };
}
let mut buffer = unsafe { CBox::<u8>::new_zeroed_slice(len).assume_all_init() };
buffer[0..4].copy_from_slice(&(len as u32).to_be_bytes());
buffer[4] = response_type;
buffer[5..].copy_from_slice(response);
*dest = PamResponse {
resp: CBox::into_raw_unsized(buffer).cast(),
resp_retcode: 0,
};
}
}
impl std::ops::Index<usize> for ResponseBuffer {
    type Output = PamResponse;

    /// Borrows the response stored in slot `index`.
    ///
    /// # Panics
    /// Panics if `index` is not smaller than the number of slots.
    #[inline]
    fn index(&self, index: usize) -> &Self::Output {
        assert!(index < self.items.len());
        let slots: &[PamResponse] = &self[..];
        &slots[index]
    }
}
impl std::ops::Index<std::ops::RangeFull> for ResponseBuffer {
    type Output = [PamResponse];

    /// Borrows every response slot as one contiguous slice (`buffer[..]`).
    #[inline]
    fn index(&self, full: std::ops::RangeFull) -> &Self::Output {
        &self.items[full]
    }
}
impl<'a> IntoIterator for &'a ResponseBuffer {
type Item = &'a PamResponse;
type IntoIter = std::slice::Iter<'a, PamResponse>;
fn into_iter(self) -> Self::IntoIter {
self[..].iter()
}
}
impl From<ResponseBuffer> for *mut PamResponse {
    /// Surrenders the response array as a raw pointer to its first slot.
    ///
    /// The buffer's `Drop` is suppressed, so neither the array nor the `resp` pointers inside
    /// it are freed here. Whoever receives the pointer becomes responsible for releasing them.
    /// Presumably that is PAM, which frees conversation replies with `free`; confirm at the
    /// call site.
    fn from(buf: ResponseBuffer) -> Self {
        let mut owned = mem::ManuallyDrop::new(buf);
        let whole = owned.items.as_mut() as *mut [PamResponse];
        whole.cast()
    }
}
impl Drop for ResponseBuffer {
    /// Frees every non-null `resp` pointer still owned by the buffer, nulling each one
    /// afterwards.
    fn drop(&mut self) {
        self.items
            .iter_mut()
            .filter(|slot| !slot.resp.is_null())
            .for_each(|slot| {
                // SAFETY: non-null `resp` pointers are exclusively owned by this buffer and were
                // allocated with a `free`-compatible allocator (`strdup` or `CBox`).
                unsafe { free(slot.resp.cast()) };
                slot.resp = ptr::null_mut();
            });
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::CStr;

    /// Builds a 4-slot buffer that exercises every `put*` path, including overwrites:
    /// - slot 0: text;
    /// - slot 1: `None`;
    /// - slot 2: overwritten text;
    /// - slot 3: overwritten binary.
    fn prepare_test_buffer() -> ResponseBuffer {
        let mut buffer = ResponseBuffer::new(4).unwrap();
        buffer.put(0, Some(CString::new("some response").unwrap()));
        buffer.put(1, None);
        buffer.put(2, Some(CString::new("some response").unwrap()));
        buffer.put(2, Some(CString::new("another response").unwrap()));
        buffer.put_binary(3, 1, &[]);
        buffer.put_binary(3, 1, &[0, 1, 2]);
        buffer
    }

    #[test]
    fn test_len() {
        assert_eq!(ResponseBuffer::new(1).unwrap().len(), 1);
        assert_eq!(ResponseBuffer::new(3).unwrap().len(), 3);
        assert!(!ResponseBuffer::new(3).unwrap().is_empty());
        assert_eq!(ResponseBuffer::new(65535).unwrap().len(), 65535);
        assert_eq!(ResponseBuffer::new(65535).unwrap()[..].len(), 65535);
        assert!(ResponseBuffer::new(0).is_err());
        assert!(ResponseBuffer::new(-1).is_err());
        assert!(ResponseBuffer::new(isize::MAX).is_err());
    }

    #[test]
    fn test_iter() {
        let buffer = prepare_test_buffer();
        for (i, item) in buffer.iter().enumerate() {
            assert_eq!(item.resp_retcode, 0);
            assert_eq!(item.resp.is_null(), i == 1);
        }
    }

    #[test]
    fn test_index() {
        let buffer = prepare_test_buffer();
        assert_eq!(buffer[0].resp_retcode, 0);
        assert!(buffer[1].resp.is_null());
        assert!(!buffer[2].resp.is_null());
    }

    #[test]
    fn test_text_contents() {
        let buffer = prepare_test_buffer();
        // SAFETY: slots 0 and 2 hold NUL-terminated copies made by `put`.
        let first = unsafe { CStr::from_ptr(buffer[0].resp) };
        let second = unsafe { CStr::from_ptr(buffer[2].resp) };
        assert_eq!(first.to_str().unwrap(), "some response");
        assert_eq!(second.to_str().unwrap(), "another response");
    }

    #[test]
    fn test_binary_layout() {
        let buffer = prepare_test_buffer();
        // SAFETY: slot 3 holds an 8-byte buffer: a 5-byte header plus 3 payload bytes.
        let bytes = unsafe { slice::from_raw_parts(buffer[3].resp.cast::<u8>(), 8) };
        // Big-endian total length 8, type byte 1, then the payload.
        assert_eq!(bytes, &[0u8, 0, 0, 8, 1, 0, 1, 2][..]);
    }
}