#![cfg_attr(not(feature = "blocking"), allow(dead_code))]
use std::io::{self, Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::error::{Error, Result};
use crate::xdr::{Decoder, Encode, Encoder};
/// RPC protocol version number written into every call header by `encode_call`.
pub const RPC_VERSION: u32 = 2;
/// Authentication flavor number emitted for [`Auth::None`].
pub const AUTH_NONE: u32 = 0;
/// Authentication flavor number emitted for [`Auth::Sys`].
pub const AUTH_SYS: u32 = 1;
/// Maximum number of auxiliary group ids allowed when encoding an [`AuthSys`] body.
pub const AUTH_SYS_MAX_GROUPS: usize = 16;
/// Message type written by `encode_call` to mark a request.
const MSG_CALL: u32 = 0;
/// Message type that `decode_reply` requires in every response.
const MSG_REPLY: u32 = 1;
/// Reply status that `decode_reply` routes to `decode_accepted_reply`.
const REPLY_ACCEPTED: u32 = 0;
/// Reply status that `decode_reply` routes to `decode_denied_reply`.
const REPLY_DENIED: u32 = 1;
/// Accept status for which the rest of the reply is returned as the result payload.
const ACCEPT_SUCCESS: u32 = 0;
/// Accept status that is followed by two `u32` values (low, high).
const ACCEPT_PROG_MISMATCH: u32 = 2;
/// Reject status that is followed by two `u32` values (low, high).
const REJECT_RPC_MISMATCH: u32 = 0;
/// Reject status that is followed by one `u32` authentication status.
const REJECT_AUTH_ERROR: u32 = 1;
/// Bit set in a record-marking header when the fragment is the last one of its record.
pub(crate) const LAST_FRAGMENT: u32 = 0x8000_0000;
/// Mask selecting the fragment length from a record-marking header.
pub(crate) const FRAGMENT_LEN_MASK: u32 = 0x7fff_ffff;
/// Default cap (64 MiB) on the size of a reassembled reply record.
pub(crate) const DEFAULT_MAX_RECORD_SIZE: usize = 64 * 1024 * 1024;
/// Slack (1 MiB) that `max_record_size_for_payloads` adds on top of the largest payload size.
const MAX_RECORD_HEADROOM: usize = 1024 * 1024;
/// Credential attached to every outgoing RPC call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Auth {
/// No credential; encoded as flavor `AUTH_NONE` with an empty body.
None,
/// `AUTH_SYS` credential whose body is the encoded [`AuthSys`] value.
Sys(AuthSys),
}
impl Auth {
pub fn none() -> Self {
Self::None
}
pub fn sys(auth: AuthSys) -> Self {
Self::Sys(auth)
}
pub(crate) fn encode_opaque_auth(&self, encoder: &mut Encoder) -> Result<()> {
match self {
Self::None => {
encoder.write_u32(AUTH_NONE);
encoder.write_opaque(&[], 400)?;
}
Self::Sys(auth) => {
let body = crate::xdr::to_bytes(auth)?;
encoder.write_u32(AUTH_SYS);
encoder.write_opaque(&body, 400)?;
}
}
Ok(())
}
}
/// Body of an `AUTH_SYS` credential; fields are encoded in declaration order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthSys {
/// Arbitrary stamp value; `AuthSys::new` fills it from `default_stamp()`.
pub stamp: u32,
/// Name of the calling machine; encoding limits it to 255 bytes.
pub machine_name: String,
/// User id presented to the server.
pub uid: u32,
/// Primary group id presented to the server.
pub gid: u32,
/// Auxiliary group ids; encoding allows at most `AUTH_SYS_MAX_GROUPS` entries.
pub gids: Vec<u32>,
}
impl AuthSys {
    /// Builds credentials from explicit identity values.
    ///
    /// The `stamp` field is taken from `default_stamp()` at construction time.
    pub fn new(machine_name: impl Into<String>, uid: u32, gid: u32, gids: Vec<u32>) -> Self {
        let stamp = default_stamp();
        let machine_name = machine_name.into();
        Self {
            stamp,
            machine_name,
            uid,
            gid,
            gids,
        }
    }

    /// Builds credentials describing the current process.
    ///
    /// The machine name comes from the `HOSTNAME` environment variable and
    /// falls back to `"localhost"` when the variable cannot be read. The ids
    /// come from `current_uid`, `current_gid` and `current_auxiliary_gids`.
    pub fn current() -> Self {
        let host = match std::env::var("HOSTNAME") {
            Ok(name) => name,
            Err(_) => String::from("localhost"),
        };
        let primary_gid = current_gid();
        let uid = current_uid();
        let auxiliary = current_auxiliary_gids(primary_gid);
        Self::new(host, uid, primary_gid, auxiliary)
    }
}
impl Default for AuthSys {
fn default() -> Self {
Self::current()
}
}
impl Encode for AuthSys {
    /// Writes the credential body: stamp, machine name (at most 255 bytes),
    /// uid, gid, then the auxiliary gids (at most `AUTH_SYS_MAX_GROUPS`).
    ///
    /// # Errors
    ///
    /// Propagates the encoder's error when the machine name or the group
    /// list is rejected.
    fn encode(&self, encoder: &mut Encoder) -> crate::xdr::Result<()> {
        let Self {
            stamp,
            machine_name,
            uid,
            gid,
            gids,
        } = self;

        encoder.write_u32(*stamp);
        encoder.write_string(machine_name, 255)?;
        encoder.write_u32(*uid);
        encoder.write_u32(*gid);
        encoder.write_array(gids, AUTH_SYS_MAX_GROUPS)?;
        Ok(())
    }
}
/// Blocking RPC client that exchanges record-marked messages over one `TcpStream`.
#[derive(Debug)]
pub struct RpcClient {
/// Connected transport used for every call.
stream: TcpStream,
/// Most recently issued transaction id; advanced by `next_xid` before each call.
xid: u32,
/// Credential encoded into every call.
auth: Auth,
/// Upper bound, in bytes, on a reassembled reply record.
max_record_size: usize,
}
impl RpcClient {
pub fn connect_with_timeout<A: ToSocketAddrs>(
addr: A,
auth: Auth,
timeout: Option<Duration>,
) -> Result<Self> {
let stream = connect_tcp_stream(addr, timeout)?;
let client = Self::new(stream, auth)?;
client.set_timeout(timeout)?;
Ok(client)
}
pub fn new(stream: TcpStream, auth: Auth) -> Result<Self> {
stream.set_nodelay(true)?;
Ok(Self {
stream,
xid: default_stamp(),
auth,
max_record_size: DEFAULT_MAX_RECORD_SIZE,
})
}
pub fn set_timeout(&self, timeout: Option<Duration>) -> Result<()> {
self.stream.set_read_timeout(timeout)?;
self.stream.set_write_timeout(timeout)?;
Ok(())
}
pub fn set_max_record_size(&mut self, max_record_size: usize) {
self.max_record_size = max_record_size;
}
pub fn call<T: Encode + ?Sized>(
&mut self,
program: u32,
version: u32,
procedure: u32,
args: &T,
) -> Result<Vec<u8>> {
let xid = self.next_xid();
let request = encode_call(xid, program, version, procedure, &self.auth, args)?;
self.write_record(&request)?;
let reply = self.read_record()?;
decode_reply(xid, &reply)
}
fn next_xid(&mut self) -> u32 {
self.xid = self.xid.wrapping_add(1);
if self.xid == 0 {
self.xid = 1;
}
self.xid
}
fn write_record(&mut self, payload: &[u8]) -> Result<()> {
if payload.len() > FRAGMENT_LEN_MASK as usize {
return Err(Error::RpcRecordTooLarge {
len: payload.len(),
max: FRAGMENT_LEN_MASK as usize,
});
}
let len = u32::try_from(payload.len()).map_err(|_| Error::RpcRecordTooLarge {
len: payload.len(),
max: FRAGMENT_LEN_MASK as usize,
})?;
let header = LAST_FRAGMENT | len;
self.stream.write_all(&header.to_be_bytes())?;
self.stream.write_all(payload)?;
self.stream.flush()?;
Ok(())
}
fn read_record(&mut self) -> Result<Vec<u8>> {
let mut record = Vec::new();
loop {
let mut header_bytes = [0; 4];
self.stream.read_exact(&mut header_bytes)?;
let header = u32::from_be_bytes(header_bytes);
let is_last = (header & LAST_FRAGMENT) != 0;
let fragment_len = (header & FRAGMENT_LEN_MASK) as usize;
if record.len().saturating_add(fragment_len) > self.max_record_size {
return Err(Error::RpcRecordTooLarge {
len: record.len().saturating_add(fragment_len),
max: self.max_record_size,
});
}
let start = record.len();
record.resize(start + fragment_len, 0);
self.stream.read_exact(&mut record[start..])?;
if is_last {
return Ok(record);
}
}
}
}
/// Opens a TCP connection to `addr`.
///
/// Without a timeout this defers to `TcpStream::connect`. With a timeout,
/// every resolved address is tried in order with `TcpStream::connect_timeout`
/// and the first success is returned.
///
/// # Errors
///
/// Returns the error of the last failed attempt, or an `InvalidInput` I/O
/// error when `addr` resolved to no address at all.
fn connect_tcp_stream<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> Result<TcpStream> {
    let limit = match timeout {
        Some(limit) => limit,
        None => return Ok(TcpStream::connect(addr)?),
    };

    let mut failure = None;
    for candidate in addr.to_socket_addrs()? {
        let attempt = TcpStream::connect_timeout(&candidate, limit);
        match attempt {
            Ok(stream) => return Ok(stream),
            // Remember only the most recent failure; it is what gets reported.
            Err(err) => failure = Some(err),
        }
    }

    let cause = match failure {
        Some(err) => err,
        None => io::Error::new(io::ErrorKind::InvalidInput, "no socket address resolved"),
    };
    Err(cause.into())
}
/// Serializes a complete call message.
///
/// The layout is: `xid`, the call message type, `RPC_VERSION`, `program`,
/// `version`, `procedure`, the credential from `auth`, an `AUTH_NONE`
/// verifier, and finally the encoded `args`.
///
/// # Errors
///
/// Propagates any failure from encoding the credential, the verifier or the
/// arguments.
pub(crate) fn encode_call<T: Encode + ?Sized>(
    xid: u32,
    program: u32,
    version: u32,
    procedure: u32,
    auth: &Auth,
    args: &T,
) -> Result<Vec<u8>> {
    let mut message = Encoder::new();

    // Fixed-size call header, one 32-bit word per field.
    for word in [xid, MSG_CALL, RPC_VERSION, program, version, procedure] {
        message.write_u32(word);
    }

    // Credential first, then the verifier, which is always AUTH_NONE here.
    auth.encode_opaque_auth(&mut message)?;
    Auth::None.encode_opaque_auth(&mut message)?;

    args.encode(&mut message)?;
    Ok(message.into_bytes())
}
pub(crate) fn decode_reply(expected_xid: u32, reply: &[u8]) -> Result<Vec<u8>> {
let mut decoder = Decoder::new(reply);
let actual_xid = decoder.read_u32()?;
if actual_xid != expected_xid {
return Err(Error::RpcMismatch {
expected: expected_xid,
actual: actual_xid,
});
}
let message_type = decoder.read_u32()?;
if message_type != MSG_REPLY {
return Err(Error::RpcUnexpectedMessageType(message_type));
}
match decoder.read_u32()? {
REPLY_ACCEPTED => decode_accepted_reply(&mut decoder),
REPLY_DENIED => decode_denied_reply(&mut decoder),
value => Err(Error::RpcDenied {
reject_stat: value,
detail: 0,
}),
}
}
fn decode_accepted_reply(decoder: &mut Decoder<'_>) -> Result<Vec<u8>> {
read_opaque_auth(decoder)?;
match decoder.read_u32()? {
ACCEPT_SUCCESS => {
let remaining = decoder.remaining();
let payload = decoder.read_fixed_opaque_unpadded(remaining)?;
Ok(payload.to_vec())
}
ACCEPT_PROG_MISMATCH => {
let low = decoder.read_u32()?;
let high = decoder.read_u32()?;
Err(Error::RpcProgramMismatch { low, high })
}
accept_stat => Err(Error::RpcAcceptedError { accept_stat }),
}
}
fn decode_denied_reply(decoder: &mut Decoder<'_>) -> Result<Vec<u8>> {
match decoder.read_u32()? {
REJECT_RPC_MISMATCH => {
let low = decoder.read_u32()?;
let high = decoder.read_u32()?;
Err(Error::RpcProgramMismatch { low, high })
}
REJECT_AUTH_ERROR => {
let auth_stat = decoder.read_u32()?;
Err(Error::RpcDenied {
reject_stat: REJECT_AUTH_ERROR,
detail: auth_stat,
})
}
reject_stat => Err(Error::RpcDenied {
reject_stat,
detail: 0,
}),
}
}
/// Reads one opaque-auth structure and returns `(flavor, body)`.
///
/// The body is read with a 400-byte limit and copied into an owned vector.
///
/// # Errors
///
/// Propagates any decoding failure.
pub(crate) fn read_opaque_auth(decoder: &mut Decoder<'_>) -> Result<(u32, Vec<u8>)> {
    let flavor = decoder.read_u32()?;
    let body = decoder.read_opaque(400)?;
    Ok((flavor, body.to_vec()))
}
/// Picks a record size limit able to hold the largest of `payload_sizes`.
///
/// The result is the largest payload size plus `MAX_RECORD_HEADROOM`
/// (saturating), but never less than `DEFAULT_MAX_RECORD_SIZE`. An empty
/// slice is treated as a largest payload of zero.
pub(crate) fn max_record_size_for_payloads(payload_sizes: &[u32]) -> usize {
    let largest = payload_sizes
        .iter()
        .fold(0usize, |largest, &size| largest.max(size as usize));
    let with_headroom = largest.saturating_add(MAX_RECORD_HEADROOM);

    if with_headroom > DEFAULT_MAX_RECORD_SIZE {
        with_headroom
    } else {
        DEFAULT_MAX_RECORD_SIZE
    }
}
/// Derives a 32-bit value from the current wall-clock time.
///
/// The whole seconds since the Unix epoch (truncated to 32 bits) are XORed
/// with the sub-second nanoseconds. A clock set before the epoch is treated
/// as a zero duration.
pub(crate) fn default_stamp() -> u32 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::default(),
    };
    // Truncation is intentional: only the low 32 bits of the seconds are mixed in.
    let seconds = since_epoch.as_secs() as u32;
    seconds ^ since_epoch.subsec_nanos()
}
/// Effective user id of the calling process.
#[cfg(unix)]
fn current_uid() -> u32 {
    // SAFETY: `geteuid` takes no arguments and touches no caller-owned memory.
    let euid = unsafe { libc::geteuid() };
    euid as u32
}

/// Fallback for non-Unix targets, where no user id is queried: always `0`.
#[cfg(not(unix))]
fn current_uid() -> u32 {
    0
}
/// Effective group id of the calling process.
#[cfg(unix)]
fn current_gid() -> u32 {
    // SAFETY: `getegid` takes no arguments and touches no caller-owned memory.
    let egid = unsafe { libc::getegid() };
    egid as u32
}

/// Fallback for non-Unix targets, where no group id is queried: always `0`.
#[cfg(not(unix))]
fn current_gid() -> u32 {
    0
}
/// Supplementary group ids of the calling process, normalized by
/// `normalize_auxiliary_gids` against `primary_gid`.
///
/// Returns an empty list when either `getgroups` call reports no groups or
/// fails.
#[cfg(unix)]
fn current_auxiliary_gids(primary_gid: u32) -> Vec<u32> {
    // SAFETY: with a size of zero `getgroups` only reports the group count
    // and does not write through the (null) pointer.
    let available = unsafe { libc::getgroups(0, std::ptr::null_mut()) };
    if available <= 0 {
        return Vec::new();
    }

    let mut buffer = vec![0 as libc::gid_t; available as usize];
    let capacity = buffer.len() as libc::c_int;
    // SAFETY: `buffer` owns exactly `capacity` writable `gid_t` slots.
    let filled = unsafe { libc::getgroups(capacity, buffer.as_mut_ptr()) };
    if filled <= 0 {
        return Vec::new();
    }

    // Keep only the entries the second call actually wrote.
    buffer.truncate(filled as usize);
    normalize_auxiliary_gids(primary_gid, buffer)
}

/// Fallback for non-Unix targets: no auxiliary groups are reported.
#[cfg(not(unix))]
fn current_auxiliary_gids(_primary_gid: u32) -> Vec<u32> {
    Vec::new()
}
/// Filters `groups` into a list suitable for an `AUTH_SYS` credential.
///
/// Drops `primary_gid` and repeated ids, keeps first-seen order, and stops
/// consuming input once `AUTH_SYS_MAX_GROUPS` ids have been kept.
fn normalize_auxiliary_gids(primary_gid: u32, groups: impl IntoIterator<Item = u32>) -> Vec<u32> {
    let mut kept = Vec::new();
    for candidate in groups {
        let is_primary = candidate == primary_gid;
        if !is_primary && !kept.contains(&candidate) {
            kept.push(candidate);
            if kept.len() == AUTH_SYS_MAX_GROUPS {
                break;
            }
        }
    }
    kept
}
// Unit tests for the record-size helper, connection setup and AUTH_SYS group handling.
#[cfg(test)]
mod tests {
use super::*;
use std::net::TcpListener;
// A payload well below the default limit must not change the limit.
#[test]
fn record_limit_keeps_default_for_small_payloads() {
assert_eq!(
max_record_size_for_payloads(&[128 * 1024]),
DEFAULT_MAX_RECORD_SIZE
);
}
// A payload as large as the default limit must get the headroom added on top.
#[test]
fn record_limit_adds_headroom_for_large_payloads() {
assert_eq!(
max_record_size_for_payloads(&[DEFAULT_MAX_RECORD_SIZE as u32]),
DEFAULT_MAX_RECORD_SIZE + MAX_RECORD_HEADROOM
);
}
// Connecting to a local listener with a one-second timeout must succeed.
#[test]
fn connects_with_configured_timeout() {
// Port 0 asks the OS for a free port; the real address is read back below.
let listener = TcpListener::bind(("127.0.0.1", 0)).unwrap();
let addr = listener.local_addr().unwrap();
// Accept on a separate thread so the client connect below can complete.
let accept = std::thread::spawn(move || {
let _ = listener.accept();
});
let client =
RpcClient::connect_with_timeout(addr, Auth::none(), Some(Duration::from_secs(1)))
.unwrap();
drop(client);
accept.join().unwrap();
}
// The primary gid (10) and the repeated 1 are dropped, and the list is cut at 16 entries.
#[test]
fn normalizes_auxiliary_groups_for_auth_sys() {
let groups = normalize_auxiliary_gids(
10,
[
10, 1, 2, 1, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18,
],
);
assert_eq!(
groups,
vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17]
);
assert_eq!(groups.len(), AUTH_SYS_MAX_GROUPS);
assert!(!groups.contains(&10));
}
// Credentials built from the running process must respect the same bounds.
#[test]
fn current_auth_sys_has_bounded_auxiliary_groups() {
let auth = AuthSys::current();
assert!(auth.gids.len() <= AUTH_SYS_MAX_GROUPS);
assert!(!auth.gids.contains(&auth.gid));
}
}