use async_trait::async_trait;
use tracing::{instrument, trace, warn};
use crate::{Message, conn::socket::ReadHalf, is_flatpak, names::OwnedUniqueName};
use super::{
AuthMechanism, Authenticated, BoxedSplit, Command, Common, Error, Handshake, OwnedGuid, Result,
sasl_auth_id,
};
/// Client-side state of an in-progress connection handshake.
///
/// A value is built with `Client::new` and consumed by `Handshake::perform`.
#[derive(Debug)]
pub struct Client {
/// Handshake state created from the socket and the selected auth mechanism; all command
/// reads and writes go through it.
common: Common,
/// GUID the server is expected to report. When `None`, it is filled in from the GUID carried
/// by the server's `OK` reply.
server_guid: Option<OwnedGuid>,
/// Whether a `Hello` method call is sent with the final commands and its reply awaited.
bus: bool,
/// Identity sent with `AuthMechanism::External`, or the error hit while determining it. The
/// error is only surfaced if that mechanism is actually used.
user_id: Result<String>,
}
impl Client {
/// Builds a client-side handshake over `socket`.
///
/// If `mechanism` is `None`, the mechanism reported by the socket's read half is used.
/// If `user_id` is `None`, the identity comes from `sasl_auth_id()`; a failure there is stored
/// rather than returned, and only surfaces if `AuthMechanism::External` is attempted.
pub fn new(
    socket: BoxedSplit,
    mechanism: Option<AuthMechanism>,
    server_guid: Option<OwnedGuid>,
    bus: bool,
    user_id: Option<u32>,
) -> Client {
    let mechanism = match mechanism {
        Some(requested) => requested,
        None => socket.read().auth_mechanism(),
    };
    let common = Common::new(socket, mechanism);
    let user_id = if let Some(uid) = user_id {
        Ok(uid.to_string())
    } else {
        sasl_auth_id()
    };

    Client {
        common,
        server_guid,
        bus,
        user_id,
    }
}
/// Records `guid` as the server GUID, or checks it against the one already known.
///
/// # Errors
///
/// Returns `Error::Handshake` when a GUID is already set and differs from `guid`.
fn set_guid(&mut self, guid: OwnedGuid) -> Result<()> {
    if let Some(server_guid) = &self.server_guid {
        if *server_guid != guid {
            return Err(Error::Handshake(format!(
                "Server GUID mismatch: expected {server_guid}, got {guid}",
            )));
        }
    } else {
        // Nothing to compare against yet: adopt what the server reported.
        self.server_guid = Some(guid);
    }

    Ok(())
}
/// Sends the initial zero byte on FreeBSD and DragonFly.
///
/// The write half's own `send_zero_byte` is tried first. If it reports `None`, a single zero
/// byte is sent through `sendmsg` instead.
///
/// # Errors
///
/// Returns `Error::Handshake` if the dedicated send fails or if the number of bytes written is
/// not exactly one. Errors from the `sendmsg` fallback are propagated as-is.
#[instrument(skip(self), level = "trace")]
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
async fn send_zero_byte(&mut self) -> Result<()> {
    let write = self.common.socket_mut().write_mut();

    let dedicated = write.send_zero_byte().await.map_err(|e| {
        Error::Handshake(format!("Could not send zero byte with credentials: {e}"))
    })?;
    let written = match dedicated {
        Some(n) => n,
        None => write.sendmsg(&[0], &[]).await?,
    };

    if written == 1 {
        Ok(())
    } else {
        Err(Error::Handshake(
            "Could not send zero byte with credentials".to_string(),
        ))
    }
}
#[instrument(skip(self), level = "trace")]
async fn authenticate(&mut self) -> Result<()> {
let mechanism = self.common.mechanism();
trace!("Trying {mechanism} mechanism");
let user_id = self.user_id.clone();
let auth_cmd = match mechanism {
AuthMechanism::Anonymous => Command::Auth(Some(mechanism), Some("zbus".into())),
AuthMechanism::External => Command::Auth(Some(mechanism), Some(user_id?.into_bytes())),
};
self.common.write_command(auth_cmd).await?;
match self.common.read_command().await? {
Command::Ok(guid) => {
trace!("Received OK from server");
self.set_guid(guid)?;
Ok(())
}
Command::Rejected(accepted) => {
let list = accepted.replace(" ", ", ");
Err(Error::Handshake(format!(
"{mechanism} rejected by the server. Accepted mechanisms: [{list}]"
)))
}
Command::Error(e) => Err(Error::Handshake(format!("Received error from server: {e}"))),
cmd => Err(Error::Handshake(format!(
"Unexpected command from server: {cmd}"
))),
}
}
#[instrument(skip(self), level = "trace")]
async fn send_secondary_commands(&mut self) -> Result<usize> {
let mut commands = Vec::with_capacity(4);
let can_pass_fd = self.common.socket_mut().read_mut().can_pass_unix_fd();
if can_pass_fd {
if is_flatpak() {
self.common.write_command(Command::NegotiateUnixFD).await?;
match self.common.read_command().await? {
Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
cmd => {
return Err(Error::Handshake(format!(
"Unexpected command from server: {cmd}"
)));
}
}
} else {
commands.push(Command::NegotiateUnixFD);
}
};
commands.push(Command::Begin);
let hello_method = if self.bus {
Some(create_hello_method_call())
} else {
None
};
self.common
.write_commands(&commands, hello_method.as_ref().map(|m| &**m.data()))
.await?;
Ok(commands.len() - 1)
}
#[instrument(skip(self), level = "trace")]
async fn receive_secondary_responses(&mut self, expected_n_responses: usize) -> Result<()> {
for response in self.common.read_commands(expected_n_responses).await? {
match response {
Command::Ok(guid) => {
trace!("Received OK from server");
self.set_guid(guid)?;
}
Command::AgreeUnixFD => self.common.set_cap_unix_fd(true),
Command::Error(e) => warn!("UNIX file descriptor passing rejected: {e}"),
cmd => {
return Err(Error::Handshake(format!(
"Unexpected command from server: {cmd}"
)));
}
}
}
Ok(())
}
}
#[async_trait]
impl Handshake for Client {
    /// Runs the whole client handshake and returns the authenticated connection parts.
    ///
    /// Steps, in order: the zero byte (FreeBSD and DragonFly only), authentication, the
    /// secondary commands and their replies, and, for bus connections, reading the reply to
    /// the `Hello` call to obtain the unique name.
    ///
    /// # Errors
    ///
    /// Propagates the first error from any of those steps.
    #[instrument(skip(self), level = "trace")]
    async fn perform(mut self) -> Result<Authenticated> {
        trace!("Initializing");

        #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
        self.send_zero_byte().await?;

        self.authenticate().await?;

        let pending_replies = self.send_secondary_commands().await?;
        if pending_replies != 0 {
            self.receive_secondary_responses(pending_replies).await?;
        }
        trace!("Handshake done");

        #[cfg(unix)]
        let (socket, mut recv_buffer, received_fds, cap_unix_fd, _) = self.common.into_components();
        #[cfg(not(unix))]
        let (socket, mut recv_buffer, _, _) = self.common.into_components();
        let (mut read, write) = socket.take();

        let unique_name = if self.bus {
            Some(receive_hello_response(&mut read, &mut recv_buffer).await?)
        } else {
            None
        };

        // `authenticate` only succeeds after `set_guid` returned `Ok`, which always leaves
        // `server_guid` populated, so this cannot panic.
        let server_guid = self.server_guid.unwrap();

        Ok(Authenticated {
            socket_write: write,
            socket_read: Some(read),
            server_guid,
            #[cfg(unix)]
            cap_unix_fd,
            already_received_bytes: recv_buffer,
            #[cfg(unix)]
            already_received_fds: received_fds,
            unique_name,
        })
    }
}
/// Builds the `Hello` method call, addressed to `org.freedesktop.DBus` at
/// `/org/freedesktop/DBus`, with an empty body.
///
/// # Panics
///
/// Panics if any builder step fails; every argument is a fixed literal.
fn create_hello_method_call() -> Message {
    const DBUS_NAME: &str = "org.freedesktop.DBus";

    let builder = Message::method_call("/org/freedesktop/DBus", "Hello").unwrap();
    let builder = builder.destination(DBUS_NAME).unwrap();
    let builder = builder.interface(DBUS_NAME).unwrap();

    builder.build(&()).unwrap()
}
/// Reads one message from `read` and interprets it as the reply to the `Hello` call.
///
/// `recv_buffer` is handed to `receive_message` as the buffer of bytes already received. On
/// unix a fresh, empty fd list is passed alongside and dropped afterwards.
///
/// # Errors
///
/// An error reply is converted into `Error`. Any message that is neither a method return nor
/// an error yields `Error::Handshake`. Receive and deserialization failures are propagated.
async fn receive_hello_response(
    read: &mut Box<dyn ReadHalf>,
    recv_buffer: &mut Vec<u8>,
) -> Result<OwnedUniqueName> {
    use crate::message::Type;

    #[cfg(unix)]
    let mut fds = vec![];
    let pending = read.receive_message(
        0,
        recv_buffer,
        #[cfg(unix)]
        &mut fds,
    );
    let reply = pending.await?;

    let kind = reply.message_type();
    match kind {
        Type::Error => Err(Error::from(reply)),
        Type::MethodReturn => reply.body().deserialize(),
        m => Err(Error::Handshake(format!("Unexpected message `{m:?}`"))),
    }
}