#[cfg(not(feature = "tokio"))]
use async_io::Async;
use event_listener::Event;
use static_assertions::assert_impl_all;
#[cfg(not(feature = "tokio"))]
use std::net::TcpStream;
#[cfg(all(unix, not(feature = "tokio")))]
use std::os::unix::net::UnixStream;
use std::{
collections::{HashMap, HashSet, VecDeque},
convert::TryInto,
sync::Arc,
};
#[cfg(feature = "tokio")]
use tokio::net::TcpStream;
#[cfg(all(unix, feature = "tokio"))]
use tokio::net::UnixStream;
#[cfg(feature = "tokio-vsock")]
use tokio_vsock::VsockStream;
#[cfg(windows)]
use uds_windows::UnixStream;
#[cfg(all(feature = "vsock", not(feature = "tokio")))]
use vsock::VsockStream;
use zvariant::ObjectPath;
use crate::{
address::{self, Address},
async_lock::RwLock,
names::{InterfaceName, UniqueName, WellKnownName},
raw::Socket,
AuthMechanism, Authenticated, Connection, Error, Guid, Interface, Result,
};
/// Queue limit that [`ConnectionBuilder::build`] passes to `Connection::set_max_queued` when
/// [`ConnectionBuilder::max_queued`] was never called on the builder.
const DEFAULT_MAX_QUEUED: usize = 64;
/// The transport a [`ConnectionBuilder`] was created with.
///
/// [`ConnectionBuilder::build`] consumes it and turns it into a `Box<dyn Socket>`.
#[derive(Debug)]
enum Target {
/// An already-connected Unix domain socket stream.
UnixStream(UnixStream),
/// An already-connected TCP stream.
TcpStream(TcpStream),
/// An already-connected VSOCK stream (compiled only under the `cfg` below).
#[cfg(any(
all(feature = "vsock", not(feature = "tokio")),
feature = "tokio-vsock"
))]
VsockStream(VsockStream),
/// A bus address; `build` calls `Address::connect` on it.
Address(Address),
/// A caller-supplied socket implementation, used as-is.
Socket(Box<dyn Socket>),
}
/// Interfaces queued for registration during [`ConnectionBuilder::build`].
///
/// For each object path, this maps an interface name to the interface object behind a shared
/// [`RwLock`].
type Interfaces<'a> =
HashMap<ObjectPath<'a>, HashMap<InterfaceName<'static>, Arc<RwLock<dyn Interface>>>>;
/// A builder for [`Connection`].
///
/// Create one with a constructor such as [`ConnectionBuilder::session`],
/// [`ConnectionBuilder::system`] or [`ConnectionBuilder::address`], configure it with the chained
/// setter methods, and finish with [`ConnectionBuilder::build`].
#[derive(derivative::Derivative)]
#[derivative(Debug)]
pub struct ConnectionBuilder<'a> {
/// Transport to connect over or wrap.
target: Target,
/// Queue limit for the connection; `None` means `DEFAULT_MAX_QUEUED`.
max_queued: Option<usize>,
/// When set, `build` authenticates as the server side using this GUID.
guid: Option<&'a Guid>,
/// Whether this is a peer-to-peer connection rather than a bus connection.
p2p: bool,
/// Whether `build` spawns a thread to drive the connection's executor (non-tokio builds only).
internal_executor: bool,
// Left out of `Debug` output (presumably because `dyn Interface` is not `Debug`).
#[derivative(Debug = "ignore")]
interfaces: Interfaces<'a>,
/// Well-known names that `build` requests once the connection is up.
names: HashSet<WellKnownName<'a>>,
/// Authentication mechanisms to offer/accept; `None` leaves the choice to `Authenticated`.
auth_mechanisms: Option<VecDeque<AuthMechanism>>,
/// Unique name to assign to a peer-to-peer connection.
unique_name: Option<UniqueName<'a>>,
}
// Compile-time check that the builder can be moved across threads, shared, and is `Unpin`.
assert_impl_all!(ConnectionBuilder<'_>: Send, Sync, Unpin);
impl<'a> ConnectionBuilder<'a> {
    /// Create a builder for a connection to the session bus.
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`Address::session`].
    pub fn session() -> Result<Self> {
        Ok(Self::new(Target::Address(Address::session()?)))
    }

    /// Create a builder for a connection to the system bus.
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`Address::system`].
    pub fn system() -> Result<Self> {
        Ok(Self::new(Target::Address(Address::system()?)))
    }

    /// Create a builder for a connection to the given address.
    ///
    /// `address` may be anything convertible into an [`Address`] through [`TryInto`].
    ///
    /// # Errors
    ///
    /// Returns the conversion error (converted into [`Error`]) if `address` cannot be turned
    /// into an [`Address`].
    pub fn address<A>(address: A) -> Result<Self>
    where
        A: TryInto<Address>,
        A::Error: Into<Error>,
    {
        Ok(Self::new(Target::Address(
            address.try_into().map_err(Into::into)?,
        )))
    }

    /// Create a builder that uses an already-connected Unix domain socket stream.
    ///
    /// On non-Unix platforms with the `tokio` feature enabled, [`build`](Self::build) rejects
    /// this target with [`Error::Unsupported`].
    #[must_use]
    pub fn unix_stream(stream: UnixStream) -> Self {
        Self::new(Target::UnixStream(stream))
    }

    /// Create a builder that uses an already-connected TCP stream.
    #[must_use]
    pub fn tcp_stream(stream: TcpStream) -> Self {
        Self::new(Target::TcpStream(stream))
    }

    /// Create a builder that uses an already-connected VSOCK stream.
    #[cfg(any(
        all(feature = "vsock", not(feature = "tokio")),
        feature = "tokio-vsock"
    ))]
    #[must_use]
    pub fn vsock_stream(stream: VsockStream) -> Self {
        Self::new(Target::VsockStream(stream))
    }

    /// Create a builder that uses a caller-supplied [`Socket`] implementation as-is.
    #[must_use]
    pub fn socket<S: Socket + 'static>(socket: S) -> Self {
        Self::new(Target::Socket(Box::new(socket)))
    }

    /// Restrict authentication to the given mechanisms.
    ///
    /// The slice is copied into the builder and handed to the client or server handshake in
    /// [`build`](Self::build). Calling this again replaces the previous list.
    #[must_use]
    pub fn auth_mechanisms(mut self, auth_mechanisms: &[AuthMechanism]) -> Self {
        self.auth_mechanisms = Some(VecDeque::from(auth_mechanisms.to_vec()));
        self
    }

    /// Mark the connection as peer-to-peer rather than a bus connection.
    ///
    /// A peer-to-peer connection skips the `hello_bus` step in [`build`](Self::build). It is
    /// required for [`server`](Self::server) to succeed and must be set before calling
    /// [`unique_name`](Self::unique_name).
    #[must_use]
    pub fn p2p(mut self) -> Self {
        self.p2p = true;
        self
    }

    /// Authenticate as the server side of the connection, using `guid`.
    ///
    /// Only valid together with [`p2p`](Self::p2p). Otherwise [`build`](Self::build) fails
    /// with [`Error::Unsupported`].
    #[must_use]
    pub fn server(mut self, guid: &'a Guid) -> Self {
        self.guid = Some(guid);
        self
    }

    /// Set the queue limit passed to the connection.
    ///
    /// If this is never called, `DEFAULT_MAX_QUEUED` is used.
    #[must_use]
    pub fn max_queued(mut self, max: usize) -> Self {
        self.max_queued = Some(max);
        self
    }

    /// Enable or disable the internal executor thread (enabled by default).
    ///
    /// Without the `tokio` feature, [`build`](Self::build) spawns a thread that drives the
    /// connection's executor when this is enabled. With the `tokio` feature, the setting is not
    /// consulted by `build`.
    #[must_use]
    pub fn internal_executor(mut self, enabled: bool) -> Self {
        self.internal_executor = enabled;
        self
    }

    /// Register `iface` at `path`; registration happens during [`build`](Self::build).
    ///
    /// Serving an interface with the same name at the same path again replaces the earlier one.
    ///
    /// # Errors
    ///
    /// Returns the conversion error (converted into [`Error`]) if `path` is not a valid
    /// [`ObjectPath`].
    pub fn serve_at<P, I>(mut self, path: P, iface: I) -> Result<Self>
    where
        I: Interface,
        P: TryInto<ObjectPath<'a>>,
        P::Error: Into<Error>,
    {
        let path = path.try_into().map_err(Into::into)?;
        let entry = self.interfaces.entry(path).or_default();
        entry.insert(I::name(), Arc::new(RwLock::new(iface)));
        Ok(self)
    }

    /// Add a well-known name to request once the connection is established.
    ///
    /// Names are kept in a set, so adding the same name twice has no further effect. The order
    /// in which [`build`](Self::build) requests them is unspecified.
    ///
    /// # Errors
    ///
    /// Returns the conversion error (converted into [`Error`]) if `well_known_name` is not a
    /// valid [`WellKnownName`].
    pub fn name<W>(mut self, well_known_name: W) -> Result<Self>
    where
        W: TryInto<WellKnownName<'a>>,
        W::Error: Into<Error>,
    {
        let well_known_name = well_known_name.try_into().map_err(Into::into)?;
        self.names.insert(well_known_name);
        Ok(self)
    }

    /// Set the unique name of a peer-to-peer connection.
    ///
    /// # Errors
    ///
    /// Returns the conversion error (converted into [`Error`]) if `unique_name` is not a valid
    /// [`UniqueName`].
    ///
    /// # Panics
    ///
    /// Panics if [`p2p`](Self::p2p) has not already been called on this builder.
    pub fn unique_name<U>(mut self, unique_name: U) -> Result<Self>
    where
        U: TryInto<UniqueName<'a>>,
        U::Error: Into<Error>,
    {
        if !self.p2p {
            panic!("unique name can only be set for peer-to-peer connections");
        }
        let name = unique_name.try_into().map_err(Into::into)?;
        self.unique_name = Some(name);
        Ok(self)
    }

    /// Build the connection, consuming the builder.
    ///
    /// The steps run in this order:
    ///
    /// 1. Connect to or wrap the configured transport.
    /// 2. Authenticate as a client, or as a server if [`server`](Self::server) was called.
    /// 3. Apply the queue limit and any unique name.
    /// 4. Register interfaces from [`serve_at`](Self::serve_at) and start the object server.
    /// 5. Call `hello_bus` for non-peer-to-peer connections.
    /// 6. Request every name added via [`name`](Self::name).
    ///
    /// # Errors
    ///
    /// - [`Error::Unsupported`] if [`server`](Self::server) was used without
    ///   [`p2p`](Self::p2p), or if a Unix stream target is used on a non-Unix platform with the
    ///   `tokio` feature.
    /// - Any error from connecting, authenticating, registering interfaces, spawning the
    ///   internal executor thread, or requesting names.
    ///
    /// # Panics
    ///
    /// Panics if the object server reports that a queued interface was not added. This cannot
    /// normally happen, because interfaces are keyed by name per path.
    pub async fn build(self) -> Result<Connection> {
        // Normalize whichever transport was configured into a boxed `Socket`.
        let stream = match self.target {
            #[cfg(not(feature = "tokio"))]
            Target::UnixStream(stream) => Box::new(Async::new(stream)?) as Box<dyn Socket>,
            #[cfg(all(unix, feature = "tokio"))]
            Target::UnixStream(stream) => Box::new(stream) as Box<dyn Socket>,
            #[cfg(all(not(unix), feature = "tokio"))]
            Target::UnixStream(_) => return Err(Error::Unsupported),
            #[cfg(not(feature = "tokio"))]
            Target::TcpStream(stream) => Box::new(Async::new(stream)?) as Box<dyn Socket>,
            #[cfg(feature = "tokio")]
            Target::TcpStream(stream) => Box::new(stream) as Box<dyn Socket>,
            #[cfg(all(feature = "vsock", not(feature = "tokio")))]
            Target::VsockStream(stream) => Box::new(Async::new(stream)?) as Box<dyn Socket>,
            #[cfg(feature = "tokio-vsock")]
            Target::VsockStream(stream) => Box::new(stream) as Box<dyn Socket>,
            Target::Address(address) => match address.connect().await? {
                #[cfg(any(unix, not(feature = "tokio")))]
                address::Stream::Unix(stream) => Box::new(stream) as Box<dyn Socket>,
                address::Stream::Tcp(stream) => Box::new(stream) as Box<dyn Socket>,
                #[cfg(any(
                    all(feature = "vsock", not(feature = "tokio")),
                    feature = "tokio-vsock"
                ))]
                address::Stream::Vsock(stream) => Box::new(stream) as Box<dyn Socket>,
            },
            Target::Socket(stream) => stream,
        };

        let auth = match self.guid {
            None => {
                // No GUID configured: perform the client side of the handshake.
                Authenticated::client(stream, self.auth_mechanisms).await?
            }
            Some(guid) => {
                // Acting as the server side is only supported for peer-to-peer connections.
                if !self.p2p {
                    return Err(Error::Unsupported);
                }
                // The peer's credentials are read here because `stream` is moved into the
                // handshake below.
                #[cfg(unix)]
                let client_uid = stream.uid()?;
                #[cfg(windows)]
                let client_sid = stream.peer_sid();
                Authenticated::server(
                    stream,
                    guid.clone(),
                    #[cfg(unix)]
                    client_uid,
                    #[cfg(windows)]
                    client_sid,
                    self.auth_mechanisms,
                )
                .await?
            }
        };

        let mut conn = Connection::new(auth, !self.p2p).await?;
        conn.set_max_queued(self.max_queued.unwrap_or(DEFAULT_MAX_QUEUED));
        if let Some(unique_name) = self.unique_name {
            conn.set_unique_name(unique_name)?;
        }

        if !self.interfaces.is_empty() {
            let object_server = conn.sync_object_server(false, None);
            for (path, interfaces) in self.interfaces {
                for (name, iface) in interfaces {
                    let future = object_server.at_ready(path.to_owned(), name, || iface);
                    let added = conn.run_future_at_init(future).await?;
                    // Interfaces are keyed by name per path, so a rejected (duplicate)
                    // registration would mean a broken invariant.
                    assert!(added);
                }
            }

            // The socket reader is only initialized after the object server has signalled
            // `started_event`. Presumably this ensures incoming method calls are not read
            // before the server can dispatch them.
            let started_event = Event::new();
            let listener = started_event.listen();
            conn.start_object_server(Some(started_event));
            #[cfg(not(feature = "tokio"))]
            start_internal_executor(&conn, self.internal_executor)?;
            // NOTE(review): unlike the other init-time awaits, this is not routed through
            // `run_future_at_init`. With `internal_executor(false)` and no tokio, confirm that
            // something still drives the object-server task so this await can complete.
            listener.await;
            conn.init_socket_reader();
        } else {
            conn.init_socket_reader();
            #[cfg(not(feature = "tokio"))]
            start_internal_executor(&conn, self.internal_executor)?;
        }

        if !self.p2p {
            let future = conn.hello_bus();
            conn.run_future_at_init(future).await?;
        }

        for name in self.names {
            let future = conn.request_name(name);
            conn.run_future_at_init(future).await?;
        }

        Ok(conn)
    }

    /// Create a builder for `target` with every option at its default.
    fn new(target: Target) -> Self {
        Self {
            target,
            p2p: false,
            max_queued: None,
            guid: None,
            internal_executor: true,
            interfaces: HashMap::new(),
            names: HashSet::new(),
            auth_mechanisms: None,
            unique_name: None,
        }
    }
}
/// Optionally spawn a thread that drives the connection's executor.
///
/// If `internal_executor` is `false`, this does nothing. Otherwise it spawns a thread named
/// `zbus::Connection executor`. That thread blocks on a loop that keeps calling `tick()` on a
/// clone of the connection's executor until the executor's `is_empty()` returns `true`. The join
/// handle is dropped, so the thread is detached.
///
/// # Errors
///
/// Returns an error if the thread cannot be spawned.
#[cfg(not(feature = "tokio"))]
fn start_internal_executor(conn: &Connection, internal_executor: bool) -> Result<()> {
    if !internal_executor {
        return Ok(());
    }

    let exec = conn.executor().clone();
    let thread_builder = std::thread::Builder::new().name("zbus::Connection executor".into());
    thread_builder.spawn(move || {
        crate::utils::block_on(async move {
            while !exec.is_empty() {
                exec.tick().await;
            }
        })
    })?;

    Ok(())
}