use std::{
collections::VecDeque,
fmt,
fs::File,
io::{BufRead, BufReader},
path::PathBuf,
str::FromStr,
};
use nix::{poll::PollFlags, unistd::Uid};
use crate::{
guid::Guid,
raw::{Connection, Socket},
utils::wait_on,
Error, Result,
};
/// The states of the client-side handshake state machine.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ClientHandshakeStep {
// Nothing has been sent yet; the first command sent from this step is
// prefixed with a NUL byte.
Init,
// A mechanism was rejected; the next available one is tried, without the
// leading NUL byte.
MechanismInit,
// An `AUTH` command was sent and a `DATA` challenge is expected.
WaitingForData,
// Waiting for the server's `OK` reply to the authentication.
WaitingForOK,
// `NEGOTIATE_UNIX_FD` was sent; waiting for the server's answer.
WaitingForAgreeUnixFD,
// `BEGIN` has been queued; the handshake finishes once it is flushed.
Done,
}
/// The kind of socket I/O a handshake needs to perform next in order to make
/// progress.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IoOperation {
    /// No I/O is pending.
    None,
    /// The handshake is waiting for data from the peer.
    Read,
    /// The handshake has data queued that must be written to the peer.
    Write,
}
/// Authentication mechanisms known to this module.
#[derive(Clone, Copy, Debug)]
enum Mechanism {
// Written on the wire as `EXTERNAL`.
External,
// Written on the wire as `DBUS_COOKIE_SHA1`.
Cookie,
// Written on the wire as `ANONYMOUS`; recognised when parsing, but the
// client handshake reports it as unimplemented.
Anonymous,
}
/// A single line of the authentication protocol, as sent or received during
/// the handshake.
#[derive(Debug)]
enum Command {
// `AUTH [mechanism [initial-response]]`
Auth(Option<Mechanism>, Option<String>),
// `CANCEL`
Cancel,
// `BEGIN`
Begin,
// `DATA <hex>`; the payload is held decoded and hex-encoded when displayed.
Data(Vec<u8>),
// `ERROR <explanation>`
Error(String),
// `NEGOTIATE_UNIX_FD`
NegotiateUnixFD,
// `REJECTED <space-separated mechanisms>`
Rejected(Vec<Mechanism>),
// `OK <server guid>`
Ok(Guid),
// `AGREE_UNIX_FD`
AgreeUnixFD,
}
/// State of an in-progress client-side handshake over a socket of type `S`.
#[derive(Debug)]
pub struct ClientHandshake<S> {
// The socket the handshake is performed on.
socket: S,
// Bytes of the reply line received so far.
recv_buffer: Vec<u8>,
// Bytes of the current command that still have to be written.
send_buffer: Vec<u8>,
// Current position in the state machine.
step: ClientHandshakeStep,
// GUID reported by the server in its `OK` reply, once received.
server_guid: Option<Guid>,
// Whether the server agreed to unix FD passing.
cap_unix_fd: bool,
// Mechanisms still to be tried, in order; the front one is the current one.
mechanisms: VecDeque<Mechanism>,
}
/// The result of a finished handshake.
#[derive(Debug)]
pub struct Authenticated<S> {
// Connection wrapping the socket the handshake was performed on.
pub(crate) conn: Connection<S>,
// GUID of the server side of the connection.
pub(crate) server_guid: Guid,
// Whether unix FD passing was negotiated during the handshake.
pub(crate) cap_unix_fd: bool,
}
/// Common interface of the client and server handshake state machines.
pub trait Handshake<S> {
/// Drive the handshake to completion, waiting on the socket whenever it
/// would block, and return the authenticated connection data.
fn blocking_finish(self) -> Result<Authenticated<S>>;
/// Report which kind of socket I/O the handshake needs next.
fn next_io_operation(&self) -> IoOperation;
/// Make as much progress as possible. Returns `Ok(())` once the handshake
/// is done; on a non-blocking socket a `WouldBlock` I/O error means the
/// call should be repeated when the socket is ready.
fn advance_handshake(&mut self) -> Result<()>;
/// Convert a finished handshake into an `Authenticated`; if the handshake
/// is not done yet, `self` is handed back in the `Err` variant.
fn try_finish(self) -> std::result::Result<Authenticated<S>, Self>
where
Self: Sized;
/// Access the socket the handshake is performed on.
fn socket(&self) -> &S;
}
impl<S: Socket> ClientHandshake<S> {
    /// Start a client-side handshake on `socket`.
    ///
    /// The `EXTERNAL` mechanism is tried first, then `DBUS_COOKIE_SHA1`.
    pub fn new(socket: S) -> ClientHandshake<S> {
        let mut mechanisms = VecDeque::new();
        mechanisms.push_back(Mechanism::External);
        mechanisms.push_back(Mechanism::Cookie);
        ClientHandshake {
            socket,
            recv_buffer: Vec::new(),
            send_buffer: Vec::new(),
            step: ClientHandshakeStep::Init,
            server_guid: None,
            cap_unix_fd: false,
            mechanisms,
        }
    }

    /// Write out everything queued in `send_buffer`.
    ///
    /// Bytes already written are removed from the buffer, so the call can be
    /// repeated after an error such as `WouldBlock`.
    fn flush_buffer(&mut self) -> Result<()> {
        while !self.send_buffer.is_empty() {
            let written = self.socket.sendmsg(&self.send_buffer, &[])?;
            self.send_buffer.drain(..written);
        }
        Ok(())
    }

    /// Read one `\r\n`-terminated line from the socket and parse it.
    ///
    /// # Errors
    ///
    /// Fails if the socket returns an error, if file descriptors are received,
    /// if the peer closes the connection, or if the line is not a valid
    /// command.
    fn read_command(&mut self) -> Result<Command> {
        // The buffer is deliberately NOT cleared on entry: when `recvmsg`
        // fails with `WouldBlock` halfway through a line, the bytes received
        // so far must survive until the call is retried.
        while !self.recv_buffer.ends_with(b"\r\n") {
            let mut buf = [0; 40];
            let (read, fds) = self.socket.recvmsg(&mut buf)?;
            if !fds.is_empty() {
                return Err(Error::Handshake("Unexpected FDs during handshake".into()));
            }
            if read == 0 {
                // Without this check a closed connection would spin forever.
                return Err(Error::Handshake("Unexpected EOF during handshake".into()));
            }
            self.recv_buffer.extend(&buf[..read]);
        }
        let command: Result<Command> = String::from_utf8_lossy(&self.recv_buffer).parse();
        // The line is consumed whether or not it parsed.
        self.recv_buffer.clear();
        command
    }

    /// The mechanism currently being tried.
    fn mechanism(&self) -> Result<&Mechanism> {
        self.mechanisms
            .front()
            .ok_or_else(|| Error::Handshake("Exhausted available AUTH mechanisms".into()))
    }

    /// Build the `AUTH` command for the current mechanism, together with the
    /// step the state machine moves to once it is queued.
    fn mechanism_init(&mut self) -> Result<(ClientHandshakeStep, Command)> {
        use ClientHandshakeStep::*;
        let mech = self.mechanism()?;
        match mech {
            Mechanism::External => Ok((
                WaitingForOK,
                Command::Auth(Some(*mech), Some(sasl_auth_id())),
            )),
            Mechanism::Cookie => Ok((
                WaitingForData,
                Command::Auth(Some(*mech), Some(sasl_auth_id())),
            )),
            Mechanism::Anonymous => Err(Error::Handshake("Unimplemented AUTH mechanisms".into())),
        }
    }

    /// Answer a `DATA` challenge from the server.
    ///
    /// Only the cookie mechanism expects a challenge. Its payload is
    /// `<context> <cookie id> <server challenge>`, and the response is
    /// `<client challenge> <sha1 of "server:client:cookie">`.
    fn mechanism_data(&mut self, data: Vec<u8>) -> Result<(ClientHandshakeStep, Command)> {
        use ClientHandshakeStep::*;
        let mech = self.mechanism()?;
        match mech {
            Mechanism::Cookie => {
                let context = String::from_utf8_lossy(&data);
                let mut split = context.split_ascii_whitespace();
                let name = split
                    .next()
                    .ok_or_else(|| Error::Handshake("Missing cookie context name".into()))?;
                let id = split
                    .next()
                    .ok_or_else(|| Error::Handshake("Missing cookie ID".into()))?;
                let server_chall = split
                    .next()
                    .ok_or_else(|| Error::Handshake("Missing cookie challenge".into()))?;
                let cookie = Cookie::lookup(name, id)?;
                let client_chall = random_ascii(16);
                let sec = format!("{}:{}:{}", server_chall, client_chall, cookie);
                let sha1 = sha1::Sha1::from(sec).hexdigest();
                let data = format!("{} {}", client_chall, sha1);
                Ok((WaitingForOK, Command::Data(data.into())))
            }
            Mechanism::External | Mechanism::Anonymous => {
                Err(Error::Handshake("Unexpected mechanism DATA".into()))
            }
        }
    }
}
/// Produce a string of `len` random alphanumeric characters, drawn from the
/// thread-local random number generator.
fn random_ascii(len: usize) -> String {
    use rand::{distributions::Alphanumeric, thread_rng, Rng};

    let mut rng = thread_rng();
    (0..len)
        .map(|_| char::from(rng.sample(Alphanumeric)))
        .collect()
}
/// The identity sent with `AUTH`: the decimal UID of the current process with
/// every character replaced by the lowercase hex of its code point.
fn sasl_auth_id() -> String {
    let uid = Uid::current().to_string();
    let mut encoded = String::with_capacity(uid.len() * 2);
    for ch in uid.chars() {
        encoded.push_str(&format!("{:x}", ch as u32));
    }
    encoded
}
/// One entry of a cookie keyring file; each line of the file holds these
/// three whitespace-separated fields, in this order.
#[derive(Debug)]
struct Cookie {
// Identifier used to look the cookie up.
id: String,
// Second field of the line, kept as the raw text.
creation_time: String,
// The cookie value itself.
cookie: String,
}
impl Cookie {
fn keyring_path() -> Result<PathBuf> {
let home = std::env::var("HOME")
.map_err(|e| Error::Handshake(format!("Failed to read $HOME: {}", e)))?;
let mut path = PathBuf::new();
path.push(home);
path.push(".dbus-keyrings");
Ok(path)
}
fn read_keyring(name: &str) -> Result<Vec<Cookie>> {
use std::os::unix::fs::PermissionsExt;
let mut path = Cookie::keyring_path()?;
let perms = std::fs::metadata(&path)?.permissions().mode();
if perms & 0o066 != 0 {
return Err(Error::Handshake(
"DBus keyring has invalid permissions".into(),
));
}
path.push(name);
let file = File::open(&path)?;
let mut cookies = vec![];
for (n, line) in BufReader::new(file).lines().enumerate() {
let line = line?;
let mut split = line.split_whitespace();
let id = split
.next()
.ok_or_else(|| {
Error::Handshake(format!(
"DBus cookie `{}` missing ID at line {}",
path.to_str().unwrap(),
n
))
})?
.to_string();
let creation_time = split
.next()
.ok_or_else(|| {
Error::Handshake(format!(
"DBus cookie `{}` missing creation time at line {}",
path.to_str().unwrap(),
n
))
})?
.to_string();
let cookie = split
.next()
.ok_or_else(|| {
Error::Handshake(format!(
"DBus cookie `{}` missing cookie data at line {}",
path.to_str().unwrap(),
n
))
})?
.to_string();
cookies.push(Cookie {
id,
creation_time,
cookie,
})
}
Ok(cookies)
}
fn lookup(name: &str, id: &str) -> Result<String> {
let keyring = Self::read_keyring(name)?;
let c = keyring
.iter()
.find(|c| c.id == id)
.ok_or_else(|| Error::Handshake(format!("DBus cookie ID {} not found", id)))?;
Ok(c.cookie.to_string())
}
}
impl<S: Socket> Handshake<S> for ClientHandshake<S> {
fn blocking_finish(mut self) -> Result<Authenticated<S>> {
loop {
match self.advance_handshake() {
Ok(()) => return Ok(self.try_finish().unwrap_or_else(|_| unreachable!())),
Err(Error::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
let flags = match self.next_io_operation() {
IoOperation::Write => PollFlags::POLLOUT,
IoOperation::Read => PollFlags::POLLIN,
_ => unreachable!(),
};
wait_on(self.socket.as_raw_fd(), flags)?;
}
Err(e) => return Err(e),
}
}
}
fn next_io_operation(&self) -> IoOperation {
use ClientHandshakeStep::*;
if self.send_buffer.is_empty() {
match self.step {
WaitingForOK | WaitingForData | WaitingForAgreeUnixFD => IoOperation::Read,
Init | MechanismInit | Done => IoOperation::None,
}
} else {
IoOperation::Write
}
}
fn advance_handshake(&mut self) -> Result<()> {
use ClientHandshakeStep::*;
loop {
self.flush_buffer()?;
let (next_step, cmd) = match self.step {
Init | MechanismInit => self.mechanism_init()?,
WaitingForData | WaitingForOK => {
let reply = self.read_command()?;
match (self.step, reply) {
(_, Command::Data(data)) => self.mechanism_data(data)?,
(_, Command::Rejected(_)) => {
self.mechanisms.pop_front();
self.step = MechanismInit;
continue;
}
(WaitingForOK, Command::Ok(guid)) => {
self.server_guid = Some(guid);
(WaitingForAgreeUnixFD, Command::NegotiateUnixFD)
}
(_, reply) => {
return Err(Error::Handshake(format!(
"Unexpected server AUTH OK reply: {}",
reply
)))
}
}
}
WaitingForAgreeUnixFD => {
let reply = self.read_command()?;
match reply {
Command::AgreeUnixFD => self.cap_unix_fd = true,
Command::Error(_) => self.cap_unix_fd = false,
_ => {
return Err(Error::Handshake(format!(
"Unexpected server UNIX_FD reply: {}",
reply
)));
}
}
(Done, Command::Begin)
}
Done => return Ok(()),
};
self.send_buffer = if self.step == Init {
format!("\0{}", cmd).into()
} else {
cmd.into()
};
self.step = next_step;
}
}
fn try_finish(self) -> std::result::Result<Authenticated<S>, Self> {
if let ClientHandshakeStep::Done = self.step {
Ok(Authenticated {
conn: Connection::wrap(self.socket),
server_guid: self.server_guid.unwrap(),
cap_unix_fd: self.cap_unix_fd,
})
} else {
Err(self)
}
}
fn socket(&self) -> &S {
&self.socket
}
}
/// The states of the server-side handshake state machine.
#[derive(Debug)]
enum ServerHandshakeStep {
// Waiting for the initial NUL byte from the client.
WaitingForNull,
// Waiting for an `AUTH` command.
WaitingForAuth,
// Writing the `OK <guid>` reply; `WaitingForBegin` follows.
SendingAuthOK,
// Writing a `REJECTED`/`ERROR` reply; `WaitingForAuth` follows.
SendingAuthError,
// Authenticated; waiting for `BEGIN` (or `NEGOTIATE_UNIX_FD`, `CANCEL`, `ERROR`).
WaitingForBegin,
// Writing a reply to a command received in `WaitingForBegin`, which is
// returned to afterwards.
SendingBeginMessage,
// `BEGIN` was received; the handshake is complete.
Done,
}
/// State of an in-progress server-side handshake over a socket of type `S`.
#[derive(Debug)]
pub struct ServerHandshake<S> {
// The socket the handshake is performed on.
socket: S,
// Shared by both directions: holds the command being received, then the
// reply being sent.
buffer: Vec<u8>,
// Current position in the state machine.
step: ServerHandshakeStep,
// GUID announced to the client in the `OK` reply.
server_guid: Guid,
// Set once the client asked for unix FD passing.
cap_unix_fd: bool,
// UID the client must present with `AUTH EXTERNAL` to be accepted.
client_uid: u32,
}
impl<S: Socket> ServerHandshake<S> {
pub fn new(socket: S, guid: Guid, client_uid: u32) -> ServerHandshake<S> {
ServerHandshake {
socket,
buffer: Vec::new(),
step: ServerHandshakeStep::WaitingForNull,
server_guid: guid,
cap_unix_fd: false,
client_uid,
}
}
fn flush_buffer(&mut self) -> Result<()> {
while !self.buffer.is_empty() {
let written = self.socket.sendmsg(&self.buffer, &[])?;
self.buffer.drain(..written);
}
Ok(())
}
fn read_command(&mut self) -> Result<()> {
while !self.buffer.ends_with(b"\r\n") {
let mut buf = [0; 40];
let (read, _) = self.socket.recvmsg(&mut buf)?;
self.buffer.extend(&buf[..read]);
}
Ok(())
}
}
impl<S: Socket> Handshake<S> for ServerHandshake<S> {
fn blocking_finish(mut self) -> Result<Authenticated<S>> {
loop {
match self.advance_handshake() {
Ok(()) => return Ok(self.try_finish().unwrap_or_else(|_| unreachable!())),
Err(Error::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
let flags = match self.step {
ServerHandshakeStep::SendingAuthError
| ServerHandshakeStep::SendingAuthOK
| ServerHandshakeStep::SendingBeginMessage => PollFlags::POLLOUT,
ServerHandshakeStep::WaitingForNull
| ServerHandshakeStep::WaitingForBegin
| ServerHandshakeStep::WaitingForAuth => PollFlags::POLLIN,
ServerHandshakeStep::Done => unreachable!(),
};
wait_on(self.socket.as_raw_fd(), flags)?;
}
Err(e) => return Err(e),
}
}
}
fn next_io_operation(&self) -> IoOperation {
match self.step {
ServerHandshakeStep::Done => IoOperation::None,
ServerHandshakeStep::WaitingForNull
| ServerHandshakeStep::WaitingForAuth
| ServerHandshakeStep::WaitingForBegin => IoOperation::Read,
ServerHandshakeStep::SendingAuthOK
| ServerHandshakeStep::SendingAuthError
| ServerHandshakeStep::SendingBeginMessage => IoOperation::Write,
}
}
fn advance_handshake(&mut self) -> Result<()> {
loop {
match self.step {
ServerHandshakeStep::WaitingForNull => {
let mut buffer = [0; 1];
let (read, _) = self.socket.recvmsg(&mut buffer)?;
debug_assert!(read == 1);
if buffer[0] != 0 {
return Err(Error::Handshake(
"First client byte is not NUL!".to_string(),
));
}
self.step = ServerHandshakeStep::WaitingForAuth;
}
ServerHandshakeStep::WaitingForAuth => {
self.read_command()?;
let mut reply = String::new();
(&self.buffer[..]).read_line(&mut reply)?;
let mut words = reply.split_whitespace();
match (words.next(), words.next(), words.next(), words.next()) {
(Some("AUTH"), Some("EXTERNAL"), Some(uid), None) => {
let uid = id_from_str(uid)
.map_err(|e| Error::Handshake(format!("Invalid UID: {}", e)))?;
if uid == self.client_uid {
self.buffer = format!("OK {}\r\n", self.server_guid).into();
self.step = ServerHandshakeStep::SendingAuthOK;
} else {
self.buffer = Vec::from(&b"REJECTED EXTERNAL\r\n"[..]);
self.step = ServerHandshakeStep::SendingAuthError;
}
}
(Some("AUTH"), _, _, _) | (Some("ERROR"), _, _, _) => {
self.buffer = Vec::from(&b"REJECTED EXTERNAL\r\n"[..]);
self.step = ServerHandshakeStep::SendingAuthError;
}
(Some("BEGIN"), None, None, None) => {
return Err(Error::Handshake(
"Received BEGIN while not authenticated".to_string(),
));
}
_ => {
self.buffer = Vec::from(&b"ERROR Unsupported command\r\n"[..]);
self.step = ServerHandshakeStep::SendingAuthError;
}
}
}
ServerHandshakeStep::SendingAuthError => {
self.flush_buffer()?;
self.step = ServerHandshakeStep::WaitingForAuth;
}
ServerHandshakeStep::SendingAuthOK => {
self.flush_buffer()?;
self.step = ServerHandshakeStep::WaitingForBegin;
}
ServerHandshakeStep::WaitingForBegin => {
self.read_command()?;
let mut reply = String::new();
(&self.buffer[..]).read_line(&mut reply)?;
let mut words = reply.split_whitespace();
match (words.next(), words.next()) {
(Some("BEGIN"), None) => {
self.step = ServerHandshakeStep::Done;
}
(Some("CANCEL"), None) => {
self.buffer = Vec::from(&b"REJECTED EXTERNAL\r\n"[..]);
self.step = ServerHandshakeStep::SendingAuthError;
}
(Some("ERROR"), _) => {
self.buffer = Vec::from(&b"REJECTED EXTERNAL\r\n"[..]);
self.step = ServerHandshakeStep::SendingAuthError;
}
(Some("NEGOTIATE_UNIX_FD"), None) => {
self.cap_unix_fd = true;
self.buffer = Vec::from(&b"AGREE_UNIX_FD\r\n"[..]);
self.step = ServerHandshakeStep::SendingBeginMessage;
}
_ => {
self.buffer = Vec::from(&b"ERROR Unsupported command\r\n"[..]);
self.step = ServerHandshakeStep::SendingBeginMessage;
}
}
}
ServerHandshakeStep::SendingBeginMessage => {
self.flush_buffer()?;
self.step = ServerHandshakeStep::WaitingForBegin;
}
ServerHandshakeStep::Done => return Ok(()),
}
}
}
fn try_finish(self) -> std::result::Result<Authenticated<S>, Self> {
if let ServerHandshakeStep::Done = self.step {
Ok(Authenticated {
conn: Connection::wrap(self.socket),
server_guid: self.server_guid,
cap_unix_fd: self.cap_unix_fd,
})
} else {
Err(self)
}
}
fn socket(&self) -> &S {
&self.socket
}
}
/// Decode a hex-encoded decimal number, e.g. `"31303030"` (the text `1000`)
/// becomes `1000`.
///
/// The input is consumed two bytes at a time; each pair is read as one
/// hexadecimal byte and turned into a character, and the resulting text is
/// parsed as a `u32`.
///
/// # Errors
///
/// Fails if a pair is not valid UTF-8 or not valid hexadecimal, or if the
/// decoded text is not a valid `u32`.
fn id_from_str(s: &str) -> std::result::Result<u32, Box<dyn std::error::Error>> {
    let decoded = s
        .as_bytes()
        .chunks(2)
        .map(
            |pair| -> std::result::Result<char, Box<dyn std::error::Error>> {
                let digits = std::str::from_utf8(pair)?;
                let byte = u8::from_str_radix(digits, 16)?;
                Ok(char::from(byte))
            },
        )
        .collect::<std::result::Result<String, Box<dyn std::error::Error>>>()?;
    let id = decoded.parse::<u32>()?;
    Ok(id)
}
/// Writes the mechanism's name as used on the wire.
impl fmt::Display for Mechanism {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Mechanism::Anonymous => "ANONYMOUS",
            Mechanism::Cookie => "DBUS_COOKIE_SHA1",
            Mechanism::External => "EXTERNAL",
        })
    }
}
impl FromStr for Mechanism {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
match s {
"EXTERNAL" => Ok(Mechanism::External),
"DBUS_COOKIE_SHA1" => Ok(Mechanism::Cookie),
"ANONYMOUS" => Ok(Mechanism::Anonymous),
_ => Err(Error::Handshake(format!("Unknown mechanism: {}", s))),
}
}
}
/// A command's bytes on the wire are the bytes of its `Display` output.
impl From<Command> for Vec<u8> {
    fn from(c: Command) -> Self {
        let line = c.to_string();
        line.into_bytes()
    }
}
impl fmt::Display for Command {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let cmd = match self {
Command::Auth(mech, resp) => match (mech, resp) {
(Some(mech), Some(resp)) => format!("AUTH {} {}", mech, resp),
(Some(mech), None) => format!("AUTH {}", mech),
_ => "AUTH".into(),
},
Command::Cancel => "CANCEL".into(),
Command::Begin => "BEGIN".into(),
Command::Data(data) => {
format!("DATA {}", hex::encode(data))
}
Command::Error(expl) => {
format!("ERROR {}", expl)
}
Command::NegotiateUnixFD => "NEGOTIATE_UNIX_FD".into(),
Command::Rejected(mechs) => {
format!(
"REJECTED {}",
mechs
.iter()
.map(|m| m.to_string())
.collect::<Vec<_>>()
.join(" ")
)
}
Command::Ok(guid) => {
format!("OK {}", guid)
}
Command::AgreeUnixFD => "AGREE_UNIX_FD".into(),
};
write!(f, "{}\r\n", cmd)
}
}
impl From<hex::FromHexError> for Error {
fn from(e: hex::FromHexError) -> Self {
Error::Handshake(format!("Invalid hexcode: {}", e))
}
}
impl FromStr for Command {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
let mut words = s.split_ascii_whitespace();
let cmd = match words.next() {
Some("AUTH") => {
let mech = if let Some(m) = words.next() {
Some(m.parse()?)
} else {
None
};
let resp = words.next().map(|s| s.into());
Command::Auth(mech, resp)
}
Some("CANCEL") => Command::Cancel,
Some("BEGIN") => Command::Begin,
Some("DATA") => {
let data = words
.next()
.ok_or_else(|| Error::Handshake("Missing DATA data".into()))?;
Command::Data(hex::decode(data)?)
}
Some("ERROR") => Command::Error(s.into()),
Some("NEGOTIATE_UNIX_FD") => Command::NegotiateUnixFD,
Some("REJECTED") => {
let mechs = words.map(|m| m.parse()).collect::<Result<_>>()?;
Command::Rejected(mechs)
}
Some("OK") => {
let guid = words
.next()
.ok_or_else(|| Error::Handshake("Missing OK server GUID!".into()))?;
Command::Ok(guid.parse()?)
}
Some("AGREE_UNIX_FD") => Command::AgreeUnixFD,
_ => return Err(Error::Handshake(format!("Unknown command: {}", s))),
};
Ok(cmd)
}
}
#[cfg(test)]
mod tests {
    use std::os::unix::net::UnixStream;

    use super::*;
    use crate::Guid;

    /// Runs a client and a server handshake against each other over a
    /// non-blocking socket pair, stepping them alternately until both report
    /// completion, then checks that they agree on the negotiated values.
    #[test]
    fn handshake() {
        let (client_sock, server_sock) = UnixStream::pair().unwrap();
        client_sock.set_nonblocking(true).unwrap();
        server_sock.set_nonblocking(true).unwrap();

        let mut client = ClientHandshake::new(client_sock);
        let mut server =
            ServerHandshake::new(server_sock, Guid::generate(), Uid::current().into());

        let (mut client_done, mut server_done) = (false, false);
        loop {
            if client_done && server_done {
                break;
            }

            // The only acceptable failure is the socket not being ready yet.
            if let Err(err) = client.advance_handshake() {
                match err {
                    Error::Io(e) => assert!(e.kind() == std::io::ErrorKind::WouldBlock),
                    e => panic!("Unexpected error: {:?}", e),
                }
            } else {
                client_done = true;
            }

            if let Err(err) = server.advance_handshake() {
                match err {
                    Error::Io(e) => assert!(e.kind() == std::io::ErrorKind::WouldBlock),
                    e => panic!("Unexpected error: {:?}", e),
                }
            } else {
                server_done = true;
            }
        }

        let client = client.try_finish().unwrap();
        let server = server.try_finish().unwrap();

        assert_eq!(client.server_guid, server.server_guid);
        assert_eq!(client.cap_unix_fd, server.cap_unix_fd);
    }
}