#[cfg(not(feature = "tokio"))]
use async_io::Async;
use enumflags2::BitFlags;
use event_listener::Event;
#[cfg(not(feature = "tokio"))]
use std::net::TcpStream;
#[cfg(all(unix, not(feature = "tokio")))]
use std::os::unix::net::UnixStream;
use std::{
collections::{HashMap, HashSet},
vec,
};
#[cfg(feature = "tokio")]
use tokio::net::TcpStream;
#[cfg(all(unix, feature = "tokio"))]
use tokio::net::UnixStream;
#[cfg(feature = "tokio-vsock")]
use tokio_vsock::VsockStream;
#[cfg(all(windows, not(feature = "tokio")))]
use uds_windows::UnixStream;
#[cfg(all(feature = "vsock", not(feature = "tokio")))]
use vsock::VsockStream;
use zvariant::ObjectPath;
use crate::{
address::{self, Address},
fdo::RequestNameFlags,
names::{InterfaceName, WellKnownName},
object_server::{ArcInterface, Interface},
Connection, Error, Executor, Guid, OwnedGuid, Result,
};
use super::{
handshake::{AuthMechanism, Authenticated},
socket::{BoxedSplit, ReadHalf, Split, WriteHalf},
};
/// Value passed to `Connection::set_max_queued` when the builder was not given an explicit
/// limit through `Builder::max_queued`.
const DEFAULT_MAX_QUEUED: usize = 64;
#[derive(Debug)]
enum Target {
#[cfg(any(unix, not(feature = "tokio")))]
UnixStream(UnixStream),
TcpStream(TcpStream),
#[cfg(any(
all(feature = "vsock", not(feature = "tokio")),
feature = "tokio-vsock"
))]
VsockStream(VsockStream),
Address(Address),
Socket(Split<Box<dyn ReadHalf>, Box<dyn WriteHalf>>),
AuthenticatedSocket(Split<Box<dyn ReadHalf>, Box<dyn WriteHalf>>),
}
/// Interfaces queued on the builder: outer key is the object path, inner key is the interface
/// name, and the value is the interface instance to serve there.
type Interfaces<'a> = HashMap<ObjectPath<'a>, HashMap<InterfaceName<'static>, ArcInterface>>;
/// A builder that collects connection settings and produces a [`Connection`] via
/// [`Builder::build`].
#[derive(Debug)]
#[must_use]
pub struct Builder<'a> {
/// Endpoint to connect to. Always `Some` after construction; taken out when building.
target: Option<Target>,
/// Explicit queue limit; `DEFAULT_MAX_QUEUED` is used at build time when this is `None`.
max_queued: Option<usize>,
/// Only set by `authenticated_socket` and `server`: the server GUID for a pre-authenticated
/// socket, or the GUID to use when acting as the server side of the handshake.
guid: Option<Guid<'a>>,
/// When `true`, the connection is not treated as a bus connection at build time.
#[cfg(feature = "p2p")]
p2p: bool,
/// In builds without the `tokio` feature, whether `build` spawns the executor-ticking thread.
internal_executor: bool,
/// Interfaces to add to the object server while building.
interfaces: Interfaces<'a>,
/// Well-known names to request once the connection is set up.
names: HashSet<WellKnownName<'a>>,
/// Authentication mechanism forwarded to the handshake; `None` leaves the choice to it.
auth_mechanism: Option<AuthMechanism>,
/// Unique name forwarded to the authentication step.
#[cfg(feature = "bus-impl")]
unique_name: Option<crate::names::UniqueName<'a>>,
/// Flags used for every name request issued while building.
request_name_flags: BitFlags<RequestNameFlags>,
}
impl<'a> Builder<'a> {
/// Creates a builder targeting the address returned by `Address::session`.
///
/// # Errors
///
/// Returns any error reported by `Address::session`.
pub fn session() -> Result<Self> {
    let session_address = Address::session()?;
    Ok(Self::new(Target::Address(session_address)))
}
/// Creates a builder targeting the address returned by `Address::system`.
///
/// # Errors
///
/// Returns any error reported by `Address::system`.
pub fn system() -> Result<Self> {
    let system_address = Address::system()?;
    Ok(Self::new(Target::Address(system_address)))
}
/// Creates a builder targeting the given address.
///
/// `address` may be anything convertible into an [`Address`].
///
/// # Errors
///
/// Returns the conversion error (converted into [`Error`]) if `address` cannot be turned
/// into an [`Address`].
pub fn address<A>(address: A) -> Result<Self>
where
    A: TryInto<Address>,
    A::Error: Into<Error>,
{
    let address: Address = address.try_into().map_err(Into::into)?;
    Ok(Self::new(Target::Address(address)))
}
/// Creates a builder that will use the given Unix-domain socket stream as its transport.
#[cfg(any(unix, not(feature = "tokio")))]
pub fn unix_stream(stream: UnixStream) -> Self {
    let target = Target::UnixStream(stream);
    Self::new(target)
}
/// Creates a builder that will use the given TCP stream as its transport.
pub fn tcp_stream(stream: TcpStream) -> Self {
    let target = Target::TcpStream(stream);
    Self::new(target)
}
/// Creates a builder that will use the given VSOCK stream as its transport.
#[cfg(any(
    all(feature = "vsock", not(feature = "tokio")),
    feature = "tokio-vsock"
))]
pub fn vsock_stream(stream: VsockStream) -> Self {
    let target = Target::VsockStream(stream);
    Self::new(target)
}
/// Creates a builder that will use the given socket as its transport.
///
/// The socket is converted into a [`BoxedSplit`] immediately; authentication still happens
/// when the connection is built.
pub fn socket<S: Into<BoxedSplit>>(socket: S) -> Self {
    let split: BoxedSplit = socket.into();
    Self::new(Target::Socket(split))
}
pub fn authenticated_socket<S, G>(socket: S, guid: G) -> Result<Self>
where
S: Into<BoxedSplit>,
G: TryInto<Guid<'a>>,
G::Error: Into<Error>,
{
let mut builder = Self::new(Target::AuthenticatedSocket(socket.into()));
builder.guid = Some(guid.try_into().map_err(Into::into)?);
Ok(builder)
}
pub fn auth_mechanism(mut self, auth_mechanism: AuthMechanism) -> Self {
self.auth_mechanism = Some(auth_mechanism);
self
}
/// Marks the connection as peer-to-peer.
///
/// A peer-to-peer connection is not treated as a bus connection when it is built.
#[cfg(feature = "p2p")]
pub fn p2p(self) -> Self {
    Self { p2p: true, ..self }
}
/// Makes the builder act as the server side of the handshake, using `guid` as its GUID.
///
/// This only works together with [`Builder::p2p`]: building a non-peer-to-peer connection
/// with a server GUID set fails with `Error::Unsupported`.
///
/// # Errors
///
/// Returns the conversion error (converted into [`Error`]) if `guid` cannot be turned into a
/// [`Guid`].
#[cfg(feature = "p2p")]
pub fn server<G>(mut self, guid: G) -> Result<Self>
where
    G: TryInto<Guid<'a>>,
    G::Error: Into<Error>,
{
    let server_guid: Guid<'a> = guid.try_into().map_err(Into::into)?;
    self.guid = Some(server_guid);
    Ok(self)
}
pub fn max_queued(mut self, max: usize) -> Self {
self.max_queued = Some(max);
self
}
pub fn internal_executor(mut self, enabled: bool) -> Self {
self.internal_executor = enabled;
self
}
/// Queues `iface` to be served at `path` once the connection is built.
///
/// Interfaces are keyed by object path and interface name (`I::name()`), so queuing a second
/// interface with the same name at the same path replaces the first one.
///
/// # Errors
///
/// Returns the conversion error (converted into [`Error`]) if `path` cannot be turned into an
/// [`ObjectPath`].
pub fn serve_at<P, I>(mut self, path: P, iface: I) -> Result<Self>
where
    I: Interface,
    P: TryInto<ObjectPath<'a>>,
    P::Error: Into<Error>,
{
    let object_path: ObjectPath<'a> = path.try_into().map_err(Into::into)?;
    self.interfaces
        .entry(object_path)
        .or_default()
        .insert(I::name(), ArcInterface::new(iface));
    Ok(self)
}
/// Queues a well-known name to be requested while the connection is built.
///
/// Names are kept in a set, so queuing the same name twice has no additional effect.
///
/// # Errors
///
/// Returns the conversion error (converted into [`Error`]) if `well_known_name` cannot be
/// turned into a [`WellKnownName`].
pub fn name<W>(mut self, well_known_name: W) -> Result<Self>
where
    W: TryInto<WellKnownName<'a>>,
    W::Error: Into<Error>,
{
    let requested: WellKnownName<'a> = well_known_name.try_into().map_err(Into::into)?;
    self.names.insert(requested);
    Ok(self)
}
/// Sets or clears `RequestNameFlags::AllowReplacement` in the flags used for the name
/// requests issued while building.
pub fn allow_name_replacements(mut self, allow_replacement: bool) -> Self {
    let flag = RequestNameFlags::AllowReplacement;
    self.request_name_flags.set(flag, allow_replacement);
    self
}
/// Sets or clears `RequestNameFlags::ReplaceExisting` in the flags used for the name
/// requests issued while building.
pub fn replace_existing_names(mut self, replace_existing: bool) -> Self {
    let flag = RequestNameFlags::ReplaceExisting;
    self.request_name_flags.set(flag, replace_existing);
    self
}
/// Sets the unique name that is forwarded to the authentication step at build time.
///
/// # Panics
///
/// Panics if the builder has not been marked as peer-to-peer. This check happens before the
/// name is converted.
///
/// # Errors
///
/// Returns the conversion error (converted into [`Error`]) if `unique_name` cannot be turned
/// into a `UniqueName`.
#[cfg(feature = "bus-impl")]
pub fn unique_name<U>(mut self, unique_name: U) -> Result<Self>
where
    U: TryInto<crate::names::UniqueName<'a>>,
    U::Error: Into<Error>,
{
    assert!(
        self.p2p,
        "unique name can only be set for peer-to-peer connections"
    );
    self.unique_name = Some(unique_name.try_into().map_err(Into::into)?);
    Ok(self)
}
/// Builds the [`Connection`] described by this builder.
///
/// The setup future is driven through a freshly created executor. In builds without the
/// `tokio` feature, the internal executor thread is started afterwards if it is enabled.
///
/// # Errors
///
/// Returns any error from the connection setup, or from spawning the executor thread.
pub async fn build(self) -> Result<Connection> {
    let executor = Executor::new();
    // Read the flag up front: `self` is consumed by `build_` below.
    #[cfg(not(feature = "tokio"))]
    let spawn_ticker = self.internal_executor;

    // The setup future is boxed so it lives on the heap rather than inline in this future.
    let connection = Box::pin(executor.run(self.build_(executor.clone()))).await?;

    #[cfg(not(feature = "tokio"))]
    start_internal_executor(&executor, spawn_ticker)?;

    Ok(connection)
}
async fn build_(mut self, executor: Executor<'static>) -> Result<Connection> {
#[cfg(feature = "p2p")]
let is_bus_conn = !self.p2p;
#[cfg(not(feature = "p2p"))]
let is_bus_conn = true;
let mut auth = self.connect(is_bus_conn).await?;
let socket_read = auth.socket_read.take().unwrap();
let already_received_bytes = auth.already_received_bytes.drain(..).collect();
#[cfg(unix)]
let already_received_fds = auth.already_received_fds.drain(..).collect();
let mut conn = Connection::new(auth, is_bus_conn, executor).await?;
conn.set_max_queued(self.max_queued.unwrap_or(DEFAULT_MAX_QUEUED));
if !self.interfaces.is_empty() {
let object_server = conn.ensure_object_server(false);
for (path, interfaces) in self.interfaces {
for (name, iface) in interfaces {
let added = object_server
.add_arc_interface(path.clone(), name.clone(), iface.clone())
.await?;
if !added {
return Err(Error::InterfaceExists(name.clone(), path.to_owned()));
}
}
}
let started_event = Event::new();
let listener = started_event.listen();
conn.start_object_server(Some(started_event));
listener.await;
}
conn.init_socket_reader(
socket_read,
already_received_bytes,
#[cfg(unix)]
already_received_fds,
);
for name in self.names {
conn.request_name_with_flags(name, self.request_name_flags)
.await?;
}
Ok(conn)
}
/// Creates a builder for `target` with default settings: no queue limit override, no GUID,
/// not peer-to-peer, internal executor enabled, no interfaces or names queued, no explicit
/// authentication mechanism and default name-request flags.
fn new(target: Target) -> Self {
    let interfaces = HashMap::new();
    let names = HashSet::new();
    let request_name_flags = BitFlags::default();

    Self {
        target: Some(target),
        #[cfg(feature = "p2p")]
        p2p: false,
        max_queued: None,
        guid: None,
        internal_executor: true,
        interfaces,
        names,
        auth_mechanism: None,
        #[cfg(feature = "bus-impl")]
        unique_name: None,
        request_name_flags,
    }
}
/// Connects to the target and produces the [`Authenticated`] state for the connection.
///
/// * For a pre-authenticated socket, no handshake is run and the state is assembled directly
///   from the socket halves and the stored GUID.
/// * Otherwise, if a GUID was set (`p2p` feature only), the server side of the handshake is
///   run; this requires the builder to be peer-to-peer.
/// * Otherwise the client side of the handshake is run.
///
/// # Errors
///
/// Returns any error from connecting or from the handshake, or `Error::Unsupported` if a
/// server GUID was set on a builder that is not peer-to-peer.
async fn connect(&mut self, is_bus_conn: bool) -> Result<Authenticated> {
    #[cfg(not(feature = "bus-impl"))]
    let unique_name = None;
    #[cfg(feature = "bus-impl")]
    let unique_name = self.unique_name.take().map(Into::into);

    // `mut` is only needed by the server branch, which does not exist without `p2p`.
    #[allow(unused_mut)]
    let (mut stream, server_guid, authenticated) = self.target_connect().await?;

    if authenticated {
        let (read_half, write_half) = stream.take();
        return Ok(Authenticated {
            #[cfg(unix)]
            cap_unix_fd: read_half.can_pass_unix_fd(),
            socket_read: Some(read_half),
            socket_write: write_half,
            // `authenticated_socket` is the only constructor of this target and it always
            // stores a GUID, so this cannot be `None`.
            server_guid: server_guid.unwrap(),
            already_received_bytes: vec![],
            unique_name,
            #[cfg(unix)]
            already_received_fds: vec![],
        });
    }

    #[cfg(feature = "p2p")]
    match self.guid.take() {
        // No GUID of our own: we are the client side of the handshake.
        None => {
            Authenticated::client(stream, server_guid, self.auth_mechanism, is_bus_conn).await
        }
        // A GUID was set through `server`: we are the server side of the handshake.
        Some(own_guid) => {
            if !self.p2p {
                return Err(Error::Unsupported);
            }
            let peer_creds = stream.read_mut().peer_credentials().await?;
            #[cfg(unix)]
            let client_uid = peer_creds.unix_user_id();
            #[cfg(windows)]
            let client_sid = peer_creds.into_windows_sid();
            Authenticated::server(
                stream,
                own_guid.to_owned().into(),
                #[cfg(unix)]
                client_uid,
                #[cfg(windows)]
                client_sid,
                self.auth_mechanism,
                unique_name,
            )
            .await
        }
    }
    #[cfg(not(feature = "p2p"))]
    Authenticated::client(stream, server_guid, self.auth_mechanism, is_bus_conn).await
}
/// Consumes the stored target and turns it into a split socket.
///
/// Returns a tuple of:
/// * the split socket,
/// * a GUID if one is known (taken from the address, or from the builder for a
///   pre-authenticated socket),
/// * `true` only for a pre-authenticated socket.
///
/// # Panics
///
/// Panics if the target has already been taken.
///
/// # Errors
///
/// Returns any error from wrapping a stream or from connecting to an address.
async fn target_connect(&mut self) -> Result<(BoxedSplit, Option<OwnedGuid>, bool)> {
    let target = self.target.take().unwrap();

    let outcome: (BoxedSplit, Option<OwnedGuid>, bool) = match target {
        #[cfg(not(feature = "tokio"))]
        Target::UnixStream(stream) => (Async::new(stream)?.into(), None, false),
        #[cfg(all(unix, feature = "tokio"))]
        Target::UnixStream(stream) => (stream.into(), None, false),
        #[cfg(not(feature = "tokio"))]
        Target::TcpStream(stream) => (Async::new(stream)?.into(), None, false),
        #[cfg(feature = "tokio")]
        Target::TcpStream(stream) => (stream.into(), None, false),
        #[cfg(all(feature = "vsock", not(feature = "tokio")))]
        Target::VsockStream(stream) => (Async::new(stream)?.into(), None, false),
        #[cfg(feature = "tokio-vsock")]
        Target::VsockStream(stream) => (stream.into(), None, false),
        Target::Address(address) => {
            // Read the GUID off the address before connecting.
            let address_guid: Option<OwnedGuid> = address.guid().map(|g| g.to_owned().into());
            let split: BoxedSplit = match address.connect().await? {
                #[cfg(any(unix, not(feature = "tokio")))]
                address::transport::Stream::Unix(stream) => stream.into(),
                #[cfg(unix)]
                address::transport::Stream::Unixexec(stream) => stream.into(),
                address::transport::Stream::Tcp(stream) => stream.into(),
                #[cfg(any(
                    all(feature = "vsock", not(feature = "tokio")),
                    feature = "tokio-vsock"
                ))]
                address::transport::Stream::Vsock(stream) => stream.into(),
            };
            (split, address_guid, false)
        }
        Target::Socket(split) => (split, None, false),
        Target::AuthenticatedSocket(split) => (split, self.guid.take().map(Into::into), true),
    };

    Ok(outcome)
}
}
/// Spawns the thread that drives `executor`, if `internal_executor` is `true`.
///
/// The thread is named `zbus::Connection executor` and keeps ticking the executor until
/// `Executor::is_empty` returns `true`. When `internal_executor` is `false` nothing is
/// spawned.
///
/// # Errors
///
/// Returns an error if the thread cannot be spawned.
#[cfg(not(feature = "tokio"))]
fn start_internal_executor(executor: &Executor<'static>, internal_executor: bool) -> Result<()> {
    if !internal_executor {
        return Ok(());
    }

    let ticker = executor.clone();
    let run_until_idle = move || {
        crate::utils::block_on(async move {
            while !ticker.is_empty() {
                ticker.tick().await;
            }
        })
    };

    std::thread::Builder::new()
        .name("zbus::Connection executor".into())
        .spawn(run_until_idle)?;

    Ok(())
}