// Documentation
/*
==--==--==--==--==--==--==--==--==--==--==--==--==--==--==--==--

Namaste

Copyright (C) 2019, 2021-2025  Anonymous

There are several releases over multiple years,
they are listed as ranges, such as: "2021-2025".

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.

::--::--::--::--::--::--::--::--::--::--::--::--::--::--::--::--
*/

//! # `UdsxUnixStream`

#![cfg(unix)]
#![doc(cfg(unix))]

use {
    core::mem,
    std::{
        io::{self, Error, ErrorKind, IoSlice, IoSliceMut},
        os::unix::{
            io::{AsRawFd, FromRawFd, RawFd},
            net::{AncillaryData, SocketAncillary, UnixStream},
        },
        process::Stdio,
    },
    crate::Result,
};

/// Size of one `RawFd`, in bytes; the base unit for sizing ancillary buffers in `make_size_of_streams()`.
const RAW_FD_SIZE: usize = mem::size_of::<RawFd>();

/// # Extensions for `UnixStream`
///
/// Lets one end of a Unix domain socket hand open file descriptors ("streams") to the other end. Every transfer carries a
/// caller-chosen ID, which the receiving side compares against the ID it expects.
///
/// ## Notes
///
/// - IDs are used for verification, they must not be empty.
/// - In tests, IDs with size from one byte up to 64 bytes work fine. So you can use something like SHA3-512 hashes as your IDs.
pub trait UdsxUnixStream {

    /// # Sends streams
    ///
    /// `id` is sent as the message payload; `streams` are the raw file descriptors to pass along with it.
    ///
    /// ## Notes
    ///
    /// Too many streams will cause an error (see `cmsg(3)`).
    ///
    /// ## Errors
    ///
    /// An error is returned if `id` is empty, if the descriptors cannot be packed into the ancillary buffer, or if sending fails.
    ///
    /// ## See also
    ///
    /// [`recv_streams()`][fn:recv_streams]
    ///
    /// [fn:recv_streams]: #tymethod.recv_streams
    fn send_streams<B, const N: usize>(&self, id: B, streams: [RawFd; N]) -> Result<()> where B: AsRef<[u8]>;

    /// # Sends all standard streams: input, output, error
    ///
    /// The descriptors are sent in that order: standard input, standard output, standard error.
    ///
    /// ## See also
    ///
    /// [`recv_ioe()`][fn:recv_ioe]
    ///
    /// [fn:recv_ioe]: #tymethod.recv_ioe
    fn send_ioe<B>(&self, id: B) -> Result<()> where B: AsRef<[u8]>;

    /// # Receives streams
    ///
    /// The number of streams you want to receive must match what client sends.
    ///
    /// ## Errors
    ///
    /// An error is returned if `id` is empty, if the received ID differs from `id`, if receiving fails, or if the number of
    /// received streams is not `N`.
    ///
    /// ## Safety
    ///
    /// The implementation for `UnixStream` builds every `T` from a received raw file descriptor via `T::from_raw_fd()`. The caller
    /// must choose a `T` for which that is sound, following the contract of `FromRawFd::from_raw_fd()`.
    ///
    /// ## See also
    ///
    /// [`send_streams()`][fn:send_streams]
    ///
    /// [fn:send_streams]: #tymethod.send_streams
    unsafe fn recv_streams<B, T, const N: usize>(&self, id: B) -> Result<[T; N]> where B: AsRef<[u8]>, T: FromRawFd;

    /// # Receives standard streams sent by [`send_ioe()`][fn:send_ioe]
    ///
    /// Results are: input, output and error streams.
    ///
    /// ## Safety
    ///
    /// Same requirements as [`recv_streams()`][fn:recv_streams], with `T` being `Stdio`.
    ///
    /// [fn:send_ioe]: #tymethod.send_ioe
    /// [fn:recv_streams]: #tymethod.recv_streams
    unsafe fn recv_ioe<B>(&self, id: B) -> Result<[Stdio; 3]> where B: AsRef<[u8]>;

}

impl UdsxUnixStream for UnixStream {

    fn send_streams<B, const N: usize>(&self, id: B, streams: [RawFd; N]) -> Result<()> where B: AsRef<[u8]> {
        verify_id(&id)?;

        let mut ancillary_buf = vec!(0; make_size_of_streams(streams.len())?);
        let mut ancillary = SocketAncillary::new(&mut ancillary_buf);
        if ancillary.add_fds(&streams) == false {
            return Err(Error::new(ErrorKind::Other, __!()));
        }
        self.send_vectored_with_ancillary(&[IoSlice::new(id.as_ref())], &mut ancillary)?;

        Ok(())
    }

    fn send_ioe<B>(&self, id: B) -> Result<()> where B: AsRef<[u8]> {
        self.send_streams(id, [io::stdin().as_raw_fd(), io::stdout().as_raw_fd(), io::stderr().as_raw_fd()])
    }

    unsafe fn recv_streams<B, T, const N: usize>(&self, id: B) -> Result<[T; N]> where B: AsRef<[u8]>, T: FromRawFd {
        verify_id(&id)?;

        let mut ancillary_buf = vec!(0; make_size_of_streams(N)?);
        let mut ancillary = SocketAncillary::new(&mut ancillary_buf);

        // Receive and verify
        let id = id.as_ref();
        if id != {
            let mut io_slices = vec!(0; id.len());
            {
                let mut io_slices = [IoSliceMut::new(&mut io_slices)];
                self.recv_vectored_with_ancillary(&mut io_slices, &mut ancillary)?;
            }
            io_slices
        } {
            return Err(Error::new(ErrorKind::InvalidData, __!("Invalid ID of streams")));
        }

        let mut result = Vec::with_capacity(N);
        for messages in ancillary.messages() {
            let data = messages.map_err(|e| Error::new(ErrorKind::Other, __!("{:?}", e)))?;
            match data {
                AncillaryData::ScmRights(scm_rights) => scm_rights.for_each(|fd| result.push(unsafe { T::from_raw_fd(fd) })),
                AncillaryData::ScmCredentials(_) => return Err(
                    Error::new(ErrorKind::InvalidData, __!("Expected ScmRights, got ScmCredentials"))
                ),
            };
        }

        Ok(<[T; N]>::try_from(result).map_err(|v| err!("Expected to receive {n} stream{s}, got: {len}", n=N, s=plural_s!(N), len=v.len()))?)
    }

    unsafe fn recv_ioe<B>(&self, id: B) -> Result<[Stdio; 3]> where B: AsRef<[u8]> {
        unsafe {
            self.recv_streams(id)
        }
    }

}

/// # Verifies ID
fn verify_id<B>(id: B) -> Result<()> where B: AsRef<[u8]> {
    if id.as_ref().is_empty() {
        Err(Error::new(ErrorKind::InvalidData, __!("ID must not be empty")))
    } else {
        Ok(())
    }
}

/// # Makes size of streams with `count` items
fn make_size_of_streams(count: usize) -> Result<usize> {
    // This was gotten by tests
    const FACTOR: usize = 6;

    match RAW_FD_SIZE.checked_mul(count).map(|x| x.checked_mul(FACTOR)) {
        Some(Some(result)) => Ok(result),
        _ => Err(Error::new(ErrorKind::InvalidData, __!("Stream has too much items: {count}", count=count))),
    }
}