use std::ffi::{CString, NulError};
use std::io::{self, Write};
use zeroize::Zeroize;
pub(crate) fn write_to_c_buf(
src: &str,
dst: &mut [u8],
) -> Result<usize, WriteBufError> {
let src = CString::new(src.as_bytes())?;
let src = src.as_bytes_with_nul();
check_len(src, dst)?;
let bytes = dst.zeroized_write(src)?;
Ok(bytes)
}
/// Failures that can occur while writing a Rust string into a caller-provided C buffer.
///
/// Variants are listed in the order `write_to_c_buf` can produce them.
#[derive(Debug, thiserror::Error)]
pub(crate) enum WriteBufError {
    /// The source contained an interior NUL byte, so it cannot be represented as a C string.
    #[error(transparent)]
    InteriorNul(#[from] NulError),
    /// The destination buffer is smaller than the data to be written.
    #[error("destination buffer too short (needs {src_len} bytes, has {dst_len} bytes)")]
    DstTooShort {
        /// Number of bytes required. When raised by `write_to_c_buf`, this includes the
        /// trailing NUL.
        src_len: usize,
        /// Number of bytes actually available in the destination.
        dst_len: usize,
    },
    /// The write into the destination reported an I/O error.
    #[error(transparent)]
    FailedWrite(#[from] io::Error),
}
/// Verifies that `dst` has room for every byte of `src`.
///
/// # Errors
///
/// Returns [`WriteBufError::DstTooShort`] with both lengths when `dst.len() < src.len()`.
fn check_len(src: &[u8], dst: &[u8]) -> Result<(), WriteBufError> {
    match (src.len(), dst.len()) {
        (src_len, dst_len) if src_len > dst_len => {
            Err(WriteBufError::DstTooShort { src_len, dst_len })
        }
        _ => Ok(()),
    }
}
/// A writer that clears whatever part of its destination it did not fill.
trait ZeroizedWrite: Write {
    /// Writes `data`, then zeroizes the remainder of the destination.
    ///
    /// Consumes the writer and returns the number of bytes of `data` written.
    fn zeroized_write(self, data: &[u8]) -> Result<usize, io::Error>;
}
impl ZeroizedWrite for &mut [u8] {
    /// Copies `buf` to the front of this slice, then zeroizes every byte after it.
    ///
    /// Returns the number of bytes copied, which is `min(buf.len(), self.len())`. A destination
    /// shorter than `buf` is not an error here, so callers that need the whole of `buf` written
    /// must check lengths first.
    fn zeroized_write(mut self, buf: &[u8]) -> io::Result<usize> {
        let bytes = self.write(buf)?;
        // `Write for &mut [u8]` advances `self` past the bytes it just wrote, so `self` is now
        // exactly the unused tail. Wipe all of it. Slicing off another `buf.len()` here would
        // leave stale bytes directly after the written data.
        self.iter_mut().zeroize();
        Ok(bytes)
    }
}