use rand::{RngExt, distr::Alphanumeric};
use sha2::Digest;
use std::{
io::{Error, ErrorKind, Read, Result},
os::{
fd::{AsFd, AsRawFd, OwnedFd},
unix::ffi::OsStrExt,
},
path::Path,
};
use rustix::{
fs::{AtFlags, readlinkat, renameat, symlinkat, unlinkat},
io::{Errno, Result as ErrnoResult},
};
use tokio::io::{AsyncRead, AsyncReadExt};
/// Adapter exposing a [`Digest`] hasher as a [`std::io::Write`] sink.
///
/// Every byte written is fed to the wrapped hasher. Call [`DigestWrite::finalize`]
/// to obtain the digest.
#[derive(Debug)]
pub struct DigestWrite<D: Digest>(pub D);

impl<D: Digest> std::io::Write for DigestWrite<D> {
    /// Feeds all of `buf` into the hasher. This never fails and never writes partially.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let DigestWrite(hasher) = self;
        hasher.update(buf);
        Ok(buf.len())
    }

    /// No-op: nothing is buffered outside the hasher itself.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}

impl<D: Digest> DigestWrite<D> {
    /// Consumes the adapter and returns the digest of everything written so far.
    pub fn finalize(self) -> sha2::digest::Output<D> {
        let DigestWrite(hasher) = self;
        hasher.finalize()
    }
}
/// Returns the `/proc/self/fd/<N>` path naming the raw descriptor number of `fd`.
///
/// The path refers to the descriptor number, so it is only meaningful while `fd`
/// remains open in this process.
pub(crate) fn proc_self_fd(fd: impl AsFd) -> String {
    let raw = fd.as_fd().as_raw_fd();
    format!("/proc/self/fd/{raw}")
}
/// Reopens `file` read-only (and close-on-exec) via its `/proc/self/fd` entry,
/// then closes the original handle.
///
/// This relies on `/proc/self/fd` being available. The returned descriptor is a
/// fresh open of the same file, not a duplicate of `file`'s descriptor.
///
/// # Errors
/// Returns the error from `open(2)` on the `/proc` path. In that case `file` is
/// still dropped when the function returns.
pub(crate) fn reopen_tmpfile_ro(file: std::fs::File) -> std::io::Result<rustix::fd::OwnedFd> {
    let flags = rustix::fs::OFlags::RDONLY | rustix::fs::OFlags::CLOEXEC;
    let reopened = rustix::fs::open(
        proc_self_fd(&file).as_str(),
        flags,
        rustix::fs::Mode::empty(),
    )?;
    // The /proc path names `file`'s descriptor number, so `file` must stay open
    // until the new descriptor exists. Release it explicitly only afterwards.
    drop(file);
    Ok(reopened)
}
/// Fills all of `buf` from `reader`, like [`Read::read_exact`], but reports a clean
/// end-of-stream instead of treating it as an error.
///
/// Returns:
/// - `Ok(true)` once `buf` has been completely filled. An empty `buf` returns this
///   immediately, without calling `read`.
/// - `Ok(false)` if the reader hit EOF before any byte was read.
///
/// # Errors
/// - [`ErrorKind::UnexpectedEof`] if EOF arrives after `buf` was partially filled.
/// - Any other error returned by `reader`. [`ErrorKind::Interrupted`] is retried
///   rather than returned.
pub fn read_exactish(reader: &mut impl Read, buf: &mut [u8]) -> Result<bool> {
    let wanted = buf.len();
    let mut remaining: &mut [u8] = buf;
    loop {
        if remaining.is_empty() {
            return Ok(true);
        }
        let count = match reader.read(remaining) {
            Err(err) if err.kind() == ErrorKind::Interrupted => continue,
            other => other?,
        };
        if count == 0 {
            // No bytes consumed at all means a clean EOF. Otherwise the record was truncated.
            let nothing_read = remaining.len() == wanted;
            return if nothing_read {
                Ok(false)
            } else {
                Err(Error::from(ErrorKind::UnexpectedEof))
            };
        }
        remaining = &mut remaining[count..];
    }
}
/// Async counterpart of `read_exactish`: fills all of `buf` from `reader`,
/// reporting a clean end-of-stream as `Ok(false)`.
///
/// Returns:
/// - `Ok(true)` once `buf` is full. An empty `buf` returns this without reading.
/// - `Ok(false)` if EOF is reached before any byte was read.
///
/// # Errors
/// - [`ErrorKind::UnexpectedEof`] if EOF arrives after a partial fill.
/// - Any other error from `reader`. [`ErrorKind::Interrupted`] is retried.
pub async fn read_exactish_async(
    reader: &mut (impl AsyncRead + Unpin),
    buf: &mut [u8],
) -> Result<bool> {
    let wanted = buf.len();
    let mut remaining: &mut [u8] = buf;
    loop {
        if remaining.is_empty() {
            return Ok(true);
        }
        let count = match reader.read(remaining).await {
            Err(err) if err.kind() == ErrorKind::Interrupted => continue,
            other => other?,
        };
        if count == 0 {
            // Clean EOF only if nothing at all has been consumed into `buf`.
            let nothing_read = remaining.len() == wanted;
            return if nothing_read {
                Ok(false)
            } else {
                Err(ErrorKind::UnexpectedEof.into())
            };
        }
        remaining = &mut remaining[count..];
    }
}
/// A raw 32-byte SHA-256 digest.
pub type Sha256Digest = [u8; 32];

/// Decodes a 64-character hexadecimal string (either letter case) into a [`Sha256Digest`].
///
/// # Errors
/// Returns an [`ErrorKind::InvalidInput`] error, wrapping the `hex` decoder's error,
/// if the string has the wrong length or contains a non-hex character.
pub fn parse_sha256(string: impl AsRef<str>) -> Result<Sha256Digest> {
    let mut digest: Sha256Digest = [0; 32];
    match hex::decode_to_slice(string.as_ref(), &mut digest) {
        Ok(()) => Ok(digest),
        Err(source) => Err(Error::new(ErrorKind::InvalidInput, source)),
    }
}
/// Extension for rustix results that treats one specific errno as "no result".
pub(crate) trait ErrnoFilter<T> {
    /// Maps `Ok(v)` to `Ok(Some(v))` and `Err(ignored)` to `Ok(None)`.
    /// Any other errno is passed through unchanged.
    fn filter_errno(self, ignored: Errno) -> ErrnoResult<Option<T>>;
}

impl<T> ErrnoFilter<T> for ErrnoResult<T> {
    fn filter_errno(self, ignored: Errno) -> ErrnoResult<Option<T>> {
        match self {
            Err(errno) if errno == ignored => Ok(None),
            other => other.map(Some),
        }
    }
}
/// Returns `prefix` followed by 12 randomly sampled ASCII alphanumeric characters,
/// for use as a temporary directory-entry name.
fn generate_tmpname(prefix: &str) -> String {
    let mut name = String::from(prefix);
    name.extend(
        rand::rng()
            .sample_iter(&Alphanumeric)
            .take(12)
            .map(char::from),
    );
    name
}
/// Makes `name` (relative to `dirfd`) a symbolic link pointing at `target`.
///
/// Strategy:
/// 1. Try to create the symlink directly. This succeeds if `name` does not exist.
/// 2. If `name` exists and is already a symlink whose contents equal `target`,
///    do nothing.
/// 3. Otherwise create the symlink under a random `.symlink-*` name in the same
///    directory and `renameat` it over `name`. The old entry is therefore replaced
///    by a single rename rather than being removed and re-created.
///
/// # Errors
/// - Any errno from the underlying syscalls. This includes the `readlinkat` failure
///   when `name` exists but is not a symlink; in that case `name` is left untouched.
/// - If the final `renameat` fails, the temporary symlink is removed (best effort)
///   and the rename errno is returned.
/// - `EEXIST` if 16 randomly generated temporary names were all already taken.
pub(crate) fn replace_symlinkat(
    target: impl AsRef<Path>,
    dirfd: &OwnedFd,
    name: impl AsRef<Path>,
) -> ErrnoResult<()> {
    let target = target.as_ref();
    let name = name.as_ref();

    // Common case: nothing exists at `name` yet.
    let created = symlinkat(target, dirfd, name).filter_errno(Errno::EXIST)?;
    if created.is_some() {
        return Ok(());
    }

    // Something is already there. Leave it alone if it already points at `target`.
    // ENOENT means the entry vanished meanwhile; fall through and install ours.
    let existing = readlinkat(dirfd, name, []).filter_errno(Errno::NOENT)?;
    let wanted = target.as_os_str().as_bytes();
    if existing.is_some_and(|link| link.as_bytes() == wanted) {
        return Ok(());
    }

    // Build the link under a throwaway name, then rename it over `name`.
    for _attempt in 0..16 {
        let tmp_name = generate_tmpname(".symlink-");
        let placed = symlinkat(target, dirfd, &tmp_name).filter_errno(Errno::EXIST)?;
        if placed.is_none() {
            // Collided with an existing entry: try a different random name.
            continue;
        }
        return renameat(dirfd, &tmp_name, dirfd, name).inspect_err(|_| {
            // Best-effort cleanup; the rename error is what the caller sees.
            let _ = unlinkat(dirfd, &tmp_name, AtFlags::empty());
        });
    }

    // Every candidate temporary name was already taken.
    Err(Errno::EXIST)
}
#[cfg(test)]
mod test {
    use similar_asserts::assert_eq;

    use super::*;

    /// Shared scenarios for the sync and async `read_exactish` variants, each
    /// reading into a 9-byte buffer.
    fn read_exactish_common(read9: fn(&mut &[u8]) -> Result<bool>) {
        let mut r = b"" as &[u8];
        assert_eq!(read9(&mut r).unwrap(), false);
        assert_eq!(read9(&mut r).unwrap(), false);

        r = b"ninebytes";
        assert_eq!(read9(&mut r).unwrap(), true);
        assert_eq!(read9(&mut r).unwrap(), false);

        r = b"twelve bytes";
        assert_eq!(read9(&mut r).unwrap(), true);
        assert_eq!(read9(&mut r).unwrap_err().kind(), ErrorKind::UnexpectedEof);

        r = b"eighteen(18) bytes";
        assert_eq!(read9(&mut r).unwrap(), true);
        assert_eq!(read9(&mut r).unwrap(), true);
        assert_eq!(read9(&mut r).unwrap(), false);
    }

    #[test]
    fn test_read_exactish() {
        read_exactish_common(|r| read_exactish(r, &mut [0; 9]));
    }

    #[test]
    fn test_read_exactish_broken_reader() {
        struct BrokenReader;
        impl Read for BrokenReader {
            fn read(&mut self, _buffer: &mut [u8]) -> Result<usize> {
                Err(ErrorKind::NetworkDown.into())
            }
        }
        assert_eq!(
            read_exactish(&mut BrokenReader, &mut [0; 9])
                .unwrap_err()
                .kind(),
            ErrorKind::NetworkDown
        );
    }

    /// A reader that fails once with `Interrupted` and then hands out at most
    /// 4 bytes per call. It exercises both the retry branch and the multi-read
    /// accumulation path, which the slice-based scenarios above never reach.
    #[test]
    fn test_read_exactish_interrupted_and_short_reads() {
        struct ChunkyReader {
            interrupted: bool,
            data: &'static [u8],
        }
        impl Read for ChunkyReader {
            fn read(&mut self, out: &mut [u8]) -> Result<usize> {
                if !self.interrupted {
                    self.interrupted = true;
                    return Err(ErrorKind::Interrupted.into());
                }
                let n = out.len().min(self.data.len()).min(4);
                out[..n].copy_from_slice(&self.data[..n]);
                self.data = &self.data[n..];
                Ok(n)
            }
        }
        let mut reader = ChunkyReader {
            interrupted: false,
            data: b"abcdefghijklm",
        };
        let mut buf = [0u8; 9];
        assert_eq!(read_exactish(&mut reader, &mut buf).unwrap(), true);
        assert_eq!(buf, *b"abcdefghi");
        assert_eq!(
            read_exactish(&mut reader, &mut buf).unwrap_err().kind(),
            ErrorKind::UnexpectedEof
        );
    }

    #[test]
    fn test_read_exactish_async() {
        read_exactish_common(|r| {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(read_exactish_async(r, &mut [0; 9]))
        });
    }

    #[tokio::test]
    async fn test_read_exactish_broken_reader_async() {
        let mut reader = tokio_test::io::Builder::new()
            .read_error(Error::from(ErrorKind::NetworkDown))
            .build();
        assert_eq!(
            read_exactish_async(&mut reader, &mut [0; 9])
                .await
                .unwrap_err()
                .kind(),
            ErrorKind::NetworkDown
        );
    }

    #[test]
    fn test_parse_sha256() {
        let valid = "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff";
        let valid_caps = "00112233445566778899AABBCCDDEEFF00112233445566778899AABBCCDDEEFf";
        let valid_weird = "00112233445566778899aABbcCDdeEFf00112233445566778899AaBbCcDdEeFf";
        assert_eq!(hex::encode(parse_sha256(valid).unwrap()), valid);
        assert_eq!(hex::encode(parse_sha256(valid_caps).unwrap()), valid);
        assert_eq!(hex::encode(parse_sha256(valid_weird).unwrap()), valid);

        fn assert_invalid(x: &str) {
            assert_eq!(parse_sha256(x).unwrap_err().kind(), ErrorKind::InvalidInput);
        }
        assert_invalid("");
        assert_invalid("/etc/shadow");
        // One character short, one character long, and a non-hex final character.
        assert_invalid("00112233445566778899aabbccddeeff00112233445566778899aabbccddeef");
        assert_invalid("00112233445566778899aabbccddeeff00112233445566778899aabbccddeefff");
        assert_invalid("00112233445566778899aabbccddeeff00112233445566778899aabbccddeefg");
    }
}