#[cfg(feature = "std")]
mod credentials;
mod read_connection;
#[cfg(feature = "std")]
pub use credentials::Credentials;
pub use read_connection::ReadConnection;
#[cfg(feature = "std")]
pub use rustix::{process::Gid, process::Pid, process::Uid};
pub mod chain;
pub mod socket;
#[cfg(test)]
mod tests;
mod write_connection;
use crate::{
Call, Result,
reply::{self, Reply},
};
#[cfg(feature = "std")]
use alloc::vec;
pub use chain::Chain;
use core::{fmt::Debug, sync::atomic::AtomicUsize};
#[cfg(feature = "std")]
use socket::FetchPeerCredentials;
pub use write_connection::WriteConnection;
use serde::{Deserialize, Serialize};
pub use socket::Socket;
/// Output of a receive operation: the decoded value paired with the file descriptors that
/// came in with it (`std` feature).
#[cfg(feature = "std")]
type RecvResult<T> = (T, alloc::vec::Vec<std::os::fd::OwnedFd>);
/// Without the `std` feature no descriptors are carried, so a receive yields just the value.
#[cfg(not(feature = "std"))]
type RecvResult<T> = T;
/// A connection over a [`Socket`], stored as its separate read and write halves.
///
/// With the `std` feature the connection can also cache the peer's credentials once they have
/// been fetched.
#[derive(Debug)]
pub struct Connection<S: Socket> {
    /// Receiving side, built from the socket's read half.
    read: ReadConnection<<S as Socket>::ReadHalf>,
    /// Sending side, built from the socket's write half.
    write: WriteConnection<<S as Socket>::WriteHalf>,
    /// Peer credentials, filled in lazily and shared through an `Arc` (`std` only).
    #[cfg(feature = "std")]
    credentials: Option<std::sync::Arc<Credentials>>,
}
impl<S> Connection<S>
where
    S: Socket,
{
    /// Creates a connection by splitting `socket` into its read and write halves.
    ///
    /// Both halves are tagged with the same id, drawn from a process-wide counter. Peer
    /// credentials (`std` feature only) are not fetched here; they start out uncached.
    pub fn new(socket: S) -> Self {
        let (read, write) = socket.split();
        // Relaxed ordering suffices: the counter only has to hand out distinct values.
        let id = NEXT_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
        Self {
            read: ReadConnection::new(read, id),
            write: WriteConnection::new(write, id),
            #[cfg(feature = "std")]
            credentials: None,
        }
    }

    /// Returns a shared reference to the read half.
    pub fn read(&self) -> &ReadConnection<S::ReadHalf> {
        &self.read
    }

    /// Returns a mutable reference to the read half.
    pub fn read_mut(&mut self) -> &mut ReadConnection<S::ReadHalf> {
        &mut self.read
    }

    /// Returns a shared reference to the write half.
    pub fn write(&self) -> &WriteConnection<S::WriteHalf> {
        &self.write
    }

    /// Returns a mutable reference to the write half.
    pub fn write_mut(&mut self) -> &mut WriteConnection<S::WriteHalf> {
        &mut self.write
    }

    /// Consumes the connection and returns its read and write halves.
    ///
    /// Any peer credentials cached by this connection (`std` feature) are dropped.
    pub fn split(self) -> (ReadConnection<S::ReadHalf>, WriteConnection<S::WriteHalf>) {
        (self.read, self.write)
    }

    /// Reassembles a connection from a read half and a write half.
    ///
    /// The halves are not checked against each other here. If they carry different ids,
    /// [`Connection::id`] will panic. Cached peer credentials (`std` feature) start out empty.
    pub fn join(read: ReadConnection<S::ReadHalf>, write: WriteConnection<S::WriteHalf>) -> Self {
        Self {
            read,
            write,
            #[cfg(feature = "std")]
            credentials: None,
        }
    }

    /// Returns the id shared by both halves of this connection.
    ///
    /// # Panics
    ///
    /// Panics if the read and write halves report different ids. For example, this happens
    /// when [`Connection::join`] was given halves that came from two different connections.
    pub fn id(&self) -> usize {
        assert_eq!(self.read.id(), self.write.id());
        self.read.id()
    }

    /// Sends a method call through the write half.
    ///
    /// With the `std` feature, `fds` are handed to the write half together with the call.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the write half.
    pub async fn send_call<Method>(
        &mut self,
        call: &Call<Method>,
        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
    ) -> Result<()>
    where
        Method: Serialize + Debug,
    {
        #[cfg(feature = "std")]
        {
            self.write.send_call(call, fds).await
        }
        #[cfg(not(feature = "std"))]
        {
            self.write.send_call(call).await
        }
    }

    /// Receives the next reply from the read half.
    ///
    /// The reply may borrow from the connection for `'r`. With the `std` feature the result
    /// also carries the file descriptors received with it.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the read half.
    pub async fn receive_reply<'r, ReplyParams, ReplyError>(
        &'r mut self,
    ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
    where
        ReplyParams: Deserialize<'r> + Debug,
        ReplyError: Deserialize<'r> + Debug,
    {
        self.read.receive_reply().await
    }

    /// Sends `call` and then waits for the next reply.
    ///
    /// This is [`Connection::send_call`] followed by [`Connection::receive_reply`]. If sending
    /// fails, the error is returned and no reply is read.
    pub async fn call_method<'r, Method, ReplyParams, ReplyError>(
        &'r mut self,
        call: &Call<Method>,
        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
    ) -> Result<RecvResult<reply::Result<ReplyParams, ReplyError>>>
    where
        Method: Serialize + Debug,
        ReplyParams: Deserialize<'r> + Debug,
        ReplyError: Deserialize<'r> + Debug,
    {
        #[cfg(feature = "std")]
        self.send_call(call, fds).await?;
        #[cfg(not(feature = "std"))]
        self.send_call(call).await?;
        self.receive_reply().await
    }

    /// Receives the next method call from the read half.
    ///
    /// With the `std` feature, the received file descriptors are returned alongside the call.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the read half.
    pub async fn receive_call<'m, Method>(&'m mut self) -> Result<RecvResult<Call<Method>>>
    where
        Method: Deserialize<'m> + Debug,
    {
        self.read.receive_call().await
    }

    /// Sends a successful reply through the write half.
    ///
    /// With the `std` feature, `fds` are handed to the write half together with the reply.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the write half.
    pub async fn send_reply<ReplyParams>(
        &mut self,
        reply: &Reply<ReplyParams>,
        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
    ) -> Result<()>
    where
        ReplyParams: Serialize + Debug,
    {
        #[cfg(feature = "std")]
        {
            self.write.send_reply(reply, fds).await
        }
        #[cfg(not(feature = "std"))]
        {
            self.write.send_reply(reply).await
        }
    }

    /// Sends an error reply through the write half.
    ///
    /// With the `std` feature, `fds` are handed to the write half together with the error.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the write half.
    pub async fn send_error<ReplyError>(
        &mut self,
        error: &ReplyError,
        #[cfg(feature = "std")] fds: Vec<std::os::fd::OwnedFd>,
    ) -> Result<()>
    where
        ReplyError: Serialize + Debug,
    {
        #[cfg(feature = "std")]
        {
            self.write.send_error(error, fds).await
        }
        #[cfg(not(feature = "std"))]
        {
            self.write.send_error(error).await
        }
    }

    /// Enqueues a method call on the write half without awaiting any I/O.
    ///
    /// No file descriptors are attached, even with the `std` feature. Queued calls are
    /// presumably transmitted by [`Connection::flush`].
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the write half while enqueueing.
    pub fn enqueue_call<Method>(&mut self, method: &Call<Method>) -> Result<()>
    where
        Method: Serialize + Debug,
    {
        #[cfg(feature = "std")]
        {
            self.write.enqueue_call(method, vec![])
        }
        #[cfg(not(feature = "std"))]
        {
            self.write.enqueue_call(method)
        }
    }

    /// Flushes the write half.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the write half.
    pub async fn flush(&mut self) -> Result<()> {
        self.write.flush().await
    }

    /// Starts a [`Chain`] whose first call is `call`.
    ///
    /// With the `std` feature, `fds` accompany that first call.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the chain.
    pub fn chain_call<'c, Method>(
        &'c mut self,
        call: &Call<Method>,
        #[cfg(feature = "std")] fds: alloc::vec::Vec<std::os::fd::OwnedFd>,
    ) -> Result<Chain<'c, S>>
    where
        Method: Serialize + Debug,
    {
        Chain::new(
            self,
            call,
            #[cfg(feature = "std")]
            fds,
        )
    }

    /// Builds a [`Chain`] from an iterator of calls, none of which carry file descriptors.
    ///
    /// # Errors
    ///
    /// Returns [`crate::Error::EmptyChain`] if `calls` yields nothing. Any error from creating
    /// or extending the chain is also propagated.
    pub fn chain_from_iter<'c, Method, MethodCall, MethodCalls>(
        &'c mut self,
        calls: MethodCalls,
    ) -> Result<Chain<'c, S>>
    where
        Method: Serialize + Debug,
        MethodCall: Into<Call<Method>>,
        MethodCalls: IntoIterator<Item = MethodCall>,
    {
        let mut iter = calls.into_iter();
        let first: Call<Method> = iter.next().ok_or(crate::Error::EmptyChain)?.into();
        #[cfg(feature = "std")]
        let mut chain = Chain::new(self, &first, alloc::vec::Vec::new())?;
        #[cfg(not(feature = "std"))]
        let mut chain = Chain::new(self, &first)?;
        for call in iter {
            let call: Call<Method> = call.into();
            #[cfg(feature = "std")]
            {
                chain = chain.append(&call, alloc::vec::Vec::new())?;
            }
            #[cfg(not(feature = "std"))]
            {
                chain = chain.append(&call)?;
            }
        }
        Ok(chain)
    }

    /// Builds a [`Chain`] from an iterator of `(call, fds)` pairs.
    ///
    /// # Errors
    ///
    /// Returns [`crate::Error::EmptyChain`] if `calls` yields nothing. Any error from creating
    /// or extending the chain is also propagated.
    #[cfg(feature = "std")]
    pub fn chain_from_iter_with_fds<'c, Method, MethodCall, MethodCalls>(
        &'c mut self,
        calls: MethodCalls,
    ) -> Result<Chain<'c, S>>
    where
        Method: Serialize + Debug,
        MethodCall: Into<Call<Method>>,
        MethodCalls: IntoIterator<Item = (MethodCall, alloc::vec::Vec<std::os::fd::OwnedFd>)>,
    {
        let mut iter = calls.into_iter();
        let (first, first_fds) = iter.next().ok_or(crate::Error::EmptyChain)?;
        let first: Call<Method> = first.into();
        let mut chain = Chain::new(self, &first, first_fds)?;
        for (call, fds) in iter {
            let call: Call<Method> = call.into();
            chain = chain.append(&call, fds)?;
        }
        Ok(chain)
    }

    /// Returns the peer's credentials, fetching them from the read half on first use.
    ///
    /// The first successful call caches the result in an `Arc`. Later calls return the cached
    /// value without querying the socket again.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from the fetch. Nothing is cached in that case, so a later call
    /// will try again.
    #[cfg(feature = "std")]
    pub async fn peer_credentials(&mut self) -> std::io::Result<&std::sync::Arc<Credentials>>
    where
        S::ReadHalf: socket::FetchPeerCredentials,
    {
        // Matching on the place itself (not on `&self.credentials`) limits the shared borrow
        // to the `Some` arm. That lets the `None` arm fill the cache and hand back a reference
        // without a follow-up `unwrap`.
        match self.credentials {
            Some(ref credentials) => Ok(credentials),
            None => {
                let fetched = self.read.read_half().fetch_peer_credentials().await?;
                let cached: &std::sync::Arc<Credentials> =
                    self.credentials.insert(std::sync::Arc::new(fetched));
                Ok(cached)
            }
        }
    }
}
impl<S> From<S> for Connection<S>
where
S: Socket,
{
fn from(socket: S) -> Self {
Self::new(socket)
}
}
/// Crate-wide buffer size of 256 bytes.
/// NOTE(review): presumably the initial buffer allocation in the read/write halves; confirm.
pub(crate) const BUFFER_SIZE: usize = 0x100;
/// Buffer ceiling of 100 MiB (100 × 2^20 bytes).
/// NOTE(review): presumably caps how far message buffers may grow; confirm at use sites.
const MAX_BUFFER_SIZE: usize = 100 << 20;
/// Counter from which `Connection::new` draws each connection's id, starting at zero.
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);