namaste 0.18.0

Simple locks between processes
Documentation
// License: see LICENSE file at root directory of main branch

//! # Namaste

use std::{
    ffi::CStr,
    io::{Error, ErrorKind},
};

#[cfg(any(target_os = "linux", target_os = "l4re"))]
use core::{
    convert::TryFrom,
    mem,
};

#[cfg(not(any(target_os = "linux", target_os = "l4re")))]
use std::{
    env,
    fs::OpenOptions,
};

use {
    crate::Result,
};

/// Type of `sockaddr_un.sun_family`, taken from `libc` so it always matches the platform's own definition.
#[cfg(any(target_os = "linux", target_os = "l4re"))]
type SunFamily = libc::sa_family_t;

/// `libc::AF_UNIX` narrowed to [`SunFamily`].
///
/// This cast has tests covered: `test_namaste()` checks that the value fits into `SunFamily`.
#[cfg(any(target_os = "linux", target_os = "l4re"))]
const AF_UNIX: SunFamily = libc::AF_UNIX as SunFamily;

/// Length of `sockaddr_un.sun_path`: everything in `sockaddr_un` except the family field.
///
/// `test_namaste()` checks that `size_of::<SunFamily>() + ID_LEN == size_of::<sockaddr_un>()`, i.e. that there is no padding.
#[cfg(any(target_os = "linux", target_os = "l4re"))]
const ID_LEN: usize = mem::size_of::<libc::sockaddr_un>() - mem::size_of::<SunFamily>();

/// Longest ID accepted by `Namaste::make()`.
///
/// One byte of `sun_path` is reserved: `make()` leaves `sun_path[0]` as zero and writes the ID starting at index 1.
#[cfg(any(target_os = "linux", target_os = "l4re"))]
const MAX_USER_ID_LEN: usize = ID_LEN - 1;

/// # Namaste
///
/// A lock between processes, held for as long as this value lives.
///
/// On Linux and L4Re it is backed by an abstract Unix socket. On other Unix systems it is backed by an exclusive lock on a file
/// in the temporary directory.
///
/// ## Usage
///
/// Pass your ID to [`make()`][fn:make] or [`make_wait()`][fn:make_wait]. To release the lock, drop the value, e.g. via `drop()`.
///
/// [fn:make]: fn.make.html
/// [fn:make_wait]: fn.make_wait.html
#[derive(Debug, Eq, PartialEq, Hash)]
pub struct Namaste {
    /// # File descriptor of the bound socket or of the locked file
    fd: i32,
}

impl Namaste {

    /// # Makes new instance using abstract Linux socket
    ///
    /// An error is returned if calling to functions in `libc` fails. Currently, error kind is simply [`Other`][r:ErrorKind#Other]. Error
    /// message will contain the original `libc` function name and _possible_ `errno`.
    ///
    /// Your ID will be filled into [`sockaddr_un.sun_path`][libc:sockaddr_un#sun_path] from `#1` index. If your ID length is small, all other
    /// bytes are zeros. The whole array will be used as address.
    ///
    /// [r:ErrorKind#Other]: https://doc.rust-lang.org/std/io/enum.ErrorKind.html#variant.Other
    /// [libc:sockaddr_un#sun_path]: https://docs.rs/libc/0.2.*/libc/struct.sockaddr_un.html#structfield.sun_path
    #[cfg(any(target_os = "linux", target_os = "l4re"))]
    pub (crate) unsafe fn make<B>(id: B) -> Result<Self> where B: AsRef<[u8]> {
        let id = id.as_ref();
        if id.is_empty() {
            return Err(Error::new(ErrorKind::InvalidData, "ID must not be empty"));
        }
        if id.len() > MAX_USER_ID_LEN {
            return Err(Error::new(ErrorKind::InvalidData, format!("ID length must be smaller than {}", ID_LEN)));
        }

        let fd = {
            const DOMAIN: i32 = libc::AF_UNIX;
            const TYPE: i32 = libc::SOCK_STREAM | libc::SOCK_NONBLOCK | libc::SOCK_CLOEXEC;
            const PROTOCOL: i32 = 0;
            match libc::socket(DOMAIN, TYPE, PROTOCOL) {
                -1 => return Err(Error::new(ErrorKind::Other, format_errno(format!("socket({}, {}, {})", DOMAIN, TYPE, PROTOCOL), None))),
                other => other,
            }
        };

        let mut addr = libc::sockaddr_un {
            sun_family: AF_UNIX,
            sun_path: [0; ID_LEN],
        };
        // Do NOT use std::ptr::copy_nonoverlapping()! First, it's unsafe. Second, tests showed that it had same performance as below loop.
        for (i, b) in id.iter().enumerate() {
            // Array bounds and cast have tests covered
            addr.sun_path[i + 1] = *b as i8;
        }

        match libc::bind(
            fd,
            mem::transmute(&addr),
            u32::try_from(mem::size_of::<libc::sockaddr_un>()).map_err(|e| Error::new(ErrorKind::Other, e))?,
        ) {
            0 => Ok(Self { fd }),
            _ => {
                close_fd(fd)?;
                Err(Error::new(ErrorKind::Other, format_errno(format!("bind({}, ...)", fd), None)))
            },
        }
    }

    /// # Makes new instance using file lock
    ///
    /// An error is returned if calling to functions in `libc` fails. Currently, error kind is simply [`Other`][r:ErrorKind#Other]. Error
    /// message will contain the original `libc` function name and _possible_ `errno`.
    ///
    /// Your ID will be used to generate a temporary file path. Then the function will try to open and lock that file.
    ///
    /// [r:ErrorKind#Other]: https://doc.rust-lang.org/std/io/enum.ErrorKind.html#variant.Other
    #[cfg(not(any(target_os = "linux", target_os = "l4re")))]
    pub (crate) unsafe fn make<B>(id: B) -> Result<Self> where B: AsRef<[u8]> {
        let file = env::temp_dir().join(format!("{hash}.{code_name}", hash=bytes_to_hex(id), code_name=crate::CODE_NAME));
        if file.exists() {
            if file.is_dir() {
                return Err(Error::new(ErrorKind::InvalidData, format!("Output file is a directory: {:?}", file)));
            }
            if file.metadata()?.len() > 0 {
                return Err(Error::new(ErrorKind::InvalidData, format!("Output file has unexpected content: {:?}", file)));
            }
        }

        let fd = {
            use {
                std::os::unix::{
                    fs::OpenOptionsExt,
                    io::IntoRawFd,
                },
            };
            OpenOptions::new().create(true).write(true).truncate(true).custom_flags(libc::O_CLOEXEC).open(file)?.into_raw_fd()
        };
        let lock_flags = libc::LOCK_NB | libc::LOCK_EX;
        match libc::flock(fd, lock_flags) {
            0 => Ok(Self { fd }),
            _ => {
                close_fd(fd)?;
                Err(Error::new(ErrorKind::Other, format_errno(format!("flock({fd}, {lock_flags})", fd=fd, lock_flags=lock_flags), None)))
            },
        }
    }

}

impl Drop for Namaste {

    /// Closes the owned file descriptor.
    ///
    /// `drop()` cannot return an error, so a failed close is only reported on stderr, with the message built by the crate's
    /// `__!` macro.
    fn drop(&mut self) {
        let closed = unsafe { close_fd(self.fd) };
        match closed {
            Ok(_) => {},
            Err(err) => eprintln!("{}", __!("{}", err)),
        }
    }

}

/// Lower-case, zero-padded hexadecimal form of every byte value, indexed by that value.
///
/// There is one row per high nibble, so `HEX_STRS[0xXY]` is found on row `X`, column `Y`.
#[cfg(not(any(target_os = "linux", target_os = "l4re")))]
const HEX_STRS: &[&str] = &[
    "00", "01", "02", "03", "04", "05", "06", "07", "08", "09", "0a", "0b", "0c", "0d", "0e", "0f",
    "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "1a", "1b", "1c", "1d", "1e", "1f",
    "20", "21", "22", "23", "24", "25", "26", "27", "28", "29", "2a", "2b", "2c", "2d", "2e", "2f",
    "30", "31", "32", "33", "34", "35", "36", "37", "38", "39", "3a", "3b", "3c", "3d", "3e", "3f",
    "40", "41", "42", "43", "44", "45", "46", "47", "48", "49", "4a", "4b", "4c", "4d", "4e", "4f",
    "50", "51", "52", "53", "54", "55", "56", "57", "58", "59", "5a", "5b", "5c", "5d", "5e", "5f",
    "60", "61", "62", "63", "64", "65", "66", "67", "68", "69", "6a", "6b", "6c", "6d", "6e", "6f",
    "70", "71", "72", "73", "74", "75", "76", "77", "78", "79", "7a", "7b", "7c", "7d", "7e", "7f",
    "80", "81", "82", "83", "84", "85", "86", "87", "88", "89", "8a", "8b", "8c", "8d", "8e", "8f",
    "90", "91", "92", "93", "94", "95", "96", "97", "98", "99", "9a", "9b", "9c", "9d", "9e", "9f",
    "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "aa", "ab", "ac", "ad", "ae", "af",
    "b0", "b1", "b2", "b3", "b4", "b5", "b6", "b7", "b8", "b9", "ba", "bb", "bc", "bd", "be", "bf",
    "c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "ca", "cb", "cc", "cd", "ce", "cf",
    "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "da", "db", "dc", "dd", "de", "df",
    "e0", "e1", "e2", "e3", "e4", "e5", "e6", "e7", "e8", "e9", "ea", "eb", "ec", "ed", "ee", "ef",
    "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9", "fa", "fb", "fc", "fd", "fe", "ff",
];

#[cfg(not(any(target_os = "linux", target_os = "l4re")))]
#[test]
fn test_hex_strs() {
    // The table must have an entry for every possible byte value.
    assert!(HEX_STRS.len() > usize::from(u8::max_value()));
    // Each entry must equal the standard lower-case, zero-padded formatting of its index.
    for (index, expected) in HEX_STRS.iter().enumerate() {
        assert_eq!(format!("{:02x}", index), *expected);
    }
}

/// # Formats a byte slice as a lower-case hexadecimal string
///
/// Each byte is looked up in [`HEX_STRS`][::HEX_STRS] instead of being formatted with `format!("{:02x}", ...)`. The lookup was
/// the faster option at the time it was written. The output has exactly two characters per input byte.
///
/// [::HEX_STRS]: constant.HEX_STRS.html
#[cfg(not(any(target_os = "linux", target_os = "l4re")))]
fn bytes_to_hex<B>(bytes: B) -> String where B: AsRef<[u8]> {
    let input = bytes.as_ref();
    let mut hex = String::with_capacity(input.len().saturating_mul(2));
    hex.extend(input.iter().map(|&byte| HEX_STRS[usize::from(byte)]));
    hex
}

/// # Formats errno
///
/// Builds the message used for failed `libc` calls. `libc_fn_name` is printed after `libc::`, and `()` is appended unless it
/// already ends with `)`. When `errno` is `None`, the current `errno` is read via `libc_errno()`. The code and its
/// `strerror()` description are both included.
///
/// NOTE(review): POSIX does not require `strerror()` to be thread-safe; concurrent failures could interleave descriptions.
unsafe fn format_errno<S>(libc_fn_name: S, errno: Option<i32>) -> String where S: AsRef<str> {
    let name = libc_fn_name.as_ref();
    let parens = if name.ends_with(')') { "" } else { "()" };
    let code = match errno {
        Some(code) => code,
        None => libc_errno(),
    };
    let description = CStr::from_ptr(libc::strerror(code));
    format!("Failed calling libc::{}{} -- *possible* errno: {:?} -> {:?}", name, parens, code, description)
}

/// # Reads the calling thread's current `errno`
///
/// This delegates to `std::io::Error::last_os_error()`, which the standard library implements for every target it supports. That
/// replaces per-OS `libc` accessors (`__errno_location()`, `__errno()`, `__error()`), which are not all available on every Unix.
///
/// The function is kept as an `unsafe fn` so that existing call sites stay unchanged; it has no safety precondition.
unsafe fn libc_errno() -> i32 {
    // `last_os_error()` always wraps a raw OS code, so the fallback is never taken in practice.
    std::io::Error::last_os_error().raw_os_error().unwrap_or(0)
}

/// # Closes a file descriptor
unsafe fn close_fd(fd: i32) -> Result<()> {
    match libc::close(fd) {
        0 => Ok(()),
        _ => Err(Error::new(ErrorKind::Other, format_errno(format!("close({})", fd), None))),
    }
}

#[test]
fn test_namaste() {
    #[cfg(any(target_os = "linux", target_os = "l4re"))]
    use core::mem;
    use zeros::Hash;

    #[cfg(any(target_os = "linux", target_os = "l4re"))]
    {
        // self::AF_UNIX must fit into SunFamily without truncation.
        let af_unix: i32 = libc::AF_UNIX;
        assert!(af_unix >= 0 && af_unix as u32 <= u32::from(SunFamily::max_value()));

        // Size relationships that Namaste::make() relies on.
        assert_eq!(mem::size_of::<SunFamily>().checked_add(ID_LEN).unwrap(), mem::size_of::<libc::sockaddr_un>());
        assert_eq!(MAX_USER_ID_LEN.checked_add(1).unwrap(), ID_LEN);
        assert!(mem::size_of::<libc::sockaddr_un>() < usize::from(u8::max_value()));
    }

    let mut id = Hash::Sha3_512.hash(crate::ID);
    for _ in 0..1000 {
        id = Hash::Sha3_512.hash(id);

        // The u8 -> i8 -> u8 round trip used by the cast in Namaste::make() must be lossless.
        let signed: Vec<i8> = id.iter().map(|&byte| byte as i8).collect();
        let round_trip: Vec<u8> = signed.iter().map(|&value| value as u8).collect();
        assert_eq!(id, round_trip);

        unsafe {
            let held = Namaste::make(&id).unwrap();

            // While `held` is alive, every further attempt must fail.
            (0..9).for_each(|_| assert!(Namaste::make(&id).is_err()));

            drop(held);

            // Once released, each attempt succeeds (every temporary is dropped right away).
            (0..9).for_each(|_| assert!(Namaste::make(&id).is_ok()));
        }
    }
}