use crate::env_list::EnvList;
use crate::error::{Error, ErrorCode};
use crate::ffi::to_pam_conv;
use crate::session::{Session, SessionToken};
use crate::{char_ptr_to_str, ConversationHandler};
extern crate libc;
extern crate pam_sys;
use crate::{ExtResult, Flag, Result, PAM_SUCCESS};
use libc::{c_char, c_int, c_void};
use pam_sys::pam_handle_t as RawPamHandle;
use pam_sys::{
pam_acct_mgmt, pam_authenticate, pam_chauthtok, pam_close_session, pam_end, pam_get_item,
pam_getenv, pam_getenvlist, pam_open_session, pam_putenv, pam_set_item, pam_setcred, pam_start,
};
use std::cell::Cell;
use std::ffi::{CStr, CString, OsStr};
use std::mem::take;
use std::os::unix::ffi::OsStrExt;
use std::ptr::NonNull;
use std::{ptr, slice};
/// Generates a getter/setter pair for a string-valued PAM item.
///
/// * `$name` / `$set_name` - names of the generated getter and setter.
/// * `$item_type` - the PAM item constant passed to `pam_get_item`/`pam_set_item`.
/// * `$doc` / `$extdoc` - optional doc fragments: the getter is documented as
///   "Returns <doc>" (followed by `$extdoc`), the setter as "Sets <doc>".
macro_rules! impl_pam_str_item {
    ($name:ident, $set_name:ident, $item_type:expr$(, $doc:literal$(, $extdoc:literal)?)?$(,)?) => {
        $(#[doc = "Returns "]#[doc = $doc]$(#[doc = "\n\n"]#[doc = $extdoc])?)?
        ///
        /// # Errors
        ///
        /// Fails with `PERM_DENIED` if the item is not set (PAM returned a null
        /// pointer), or with the error reported by `pam_get_item`.
        pub fn $name(&self) -> Result<String> {
            let ptr = self.get_item($item_type as c_int)?;
            if ptr.is_null() {
                return Err(Error::new(self.handle(), ErrorCode::PERM_DENIED));
            }
            // SAFETY: string items are stored by PAM as NUL-terminated C strings
            // and the pointer was checked to be non-null above. The contents are
            // copied into an owned `String` immediately.
            let string = unsafe { CStr::from_ptr(ptr.cast()) }
                .to_string_lossy()
                .into_owned();
            Ok(string)
        }
        $(#[doc = "Sets "]#[doc = $doc])?
        ///
        /// Passing `None` sets the item to a null pointer.
        ///
        /// # Errors
        ///
        /// Fails with `BUF_ERR` if `value` contains an interior NUL byte, or with
        /// the error reported by `pam_set_item`.
        pub fn $set_name(&mut self, value: Option<&str>) -> Result<()> {
            match value {
                // SAFETY: the null pointer is passed straight through to PAM.
                None => unsafe { self.set_item($item_type as c_int, ptr::null()) },
                Some(string) => {
                    let cstring = CString::new(string)
                        .map_err(|_| Error::new(self.handle(), ErrorCode::BUF_ERR))?;
                    // SAFETY: `cstring` is a valid NUL-terminated string that
                    // outlives the call; pam_set_item(3) copies the item.
                    unsafe { self.set_item($item_type as c_int, cstring.as_ptr().cast()) }
                }
            }
        }
    };
}
/// Non-null wrapper around a raw PAM handle pointer.
///
/// The wrapper is `Copy` and has no `Drop` impl: copying it only copies the
/// pointer, it does not own or end the PAM transaction.
#[derive(Debug, Clone, Copy)]
pub(crate) struct PamHandle(NonNull<RawPamHandle>);

impl PamHandle {
    /// Wraps `ptr`, yielding `None` when it is null.
    ///
    /// # Safety
    ///
    /// A non-null `ptr` must be a handle obtained from `pam_start` that has not
    /// yet been passed to `pam_end`; every copy of the returned wrapper will be
    /// used as such a handle.
    #[inline]
    pub unsafe fn new(ptr: *mut RawPamHandle) -> Option<Self> {
        NonNull::new(ptr).map(PamHandle)
    }

    /// Returns the wrapped raw pointer (never null).
    #[inline]
    pub const fn as_ptr(self) -> *mut RawPamHandle {
        let Self(inner) = self;
        inner.as_ptr()
    }
}

impl From<PamHandle> for *mut RawPamHandle {
    /// Unwraps the handle into the mutable raw pointer expected by most PAM calls.
    #[inline]
    fn from(handle: PamHandle) -> Self {
        handle.0.as_ptr()
    }
}

impl From<PamHandle> for *const RawPamHandle {
    /// Unwraps the handle into a const raw pointer.
    #[inline]
    fn from(handle: PamHandle) -> Self {
        let raw: *mut RawPamHandle = handle.0.as_ptr();
        raw
    }
}
/// FFI mirror of Linux-PAM's `struct pam_xauth_data`, the value type of the
/// `PAM_XAUTHDATA` item.
///
/// `#[repr(C)]` and the field order must match the C definition exactly, so do
/// not reorder or retype these fields.
#[cfg(any(target_os = "linux", doc))]
#[repr(C)]
#[derive(Debug)]
struct XAuthData {
/// Length of `name` in bytes, excluding the trailing NUL.
pub namelen: c_int,
/// Pointer to the NUL-terminated authorization name.
pub name: *const c_char,
/// Length of `data` in bytes.
pub datalen: c_int,
/// Pointer to the raw authorization data (not NUL-terminated).
pub data: *const c_char,
}
/// A PAM transaction: owns the PAM handle together with the conversation
/// handler that was registered with it.
pub struct Context<ConvT>
where
ConvT: ConversationHandler,
{
/// The PAM handle; `None` once it has been moved out, in which case dropping
/// this context does not call `pam_end`.
handle: Option<PamHandle>,
/// Boxed conversation handler registered with PAM (via `to_pam_conv`).
conversation: Box<ConvT>,
/// Status of the most recent PAM call, passed to `pam_end` on drop.
last_status: Cell<c_int>,
}
impl<ConvT> Context<ConvT>
where
ConvT: ConversationHandler,
{
#[rustversion::attr(since(1.48), doc(alias = "pam_start"))]
pub fn new(service: &str, username: Option<&str>, conversation: ConvT) -> Result<Self> {
Self::from_boxed_conv(service, username, Box::new(conversation))
}
pub fn from_boxed_conv(
service: &str,
username: Option<&str>,
mut boxed_conv: Box<ConvT>,
) -> Result<Self> {
let mut handle: *mut RawPamHandle = ptr::null_mut();
let c_service = CString::new(service).map_err(|_| Error::from(ErrorCode::BUF_ERR))?;
let c_username = match username {
None => None,
Some(name) => Some(CString::new(name).map_err(|_| Error::from(ErrorCode::BUF_ERR))?),
};
let pam_conv = to_pam_conv(&mut boxed_conv);
match unsafe {
pam_start(
c_service.as_ptr(),
c_username.as_ref().map_or(ptr::null(), |s| s.as_ptr()),
&pam_conv,
&mut handle,
)
} {
PAM_SUCCESS => {
let handle = unsafe { PamHandle::new(handle) }
.ok_or_else(|| Error::from(ErrorCode::ABORT))?;
boxed_conv.init(username);
Ok(Self {
handle: Some(handle),
conversation: boxed_conv,
last_status: Cell::new(PAM_SUCCESS),
})
}
code => Err(ErrorCode::from_repr(code)
.unwrap_or(ErrorCode::ABORT)
.into()),
}
}
#[inline]
pub(crate) fn handle(&self) -> PamHandle {
self.handle.unwrap()
}
#[inline]
pub(crate) fn wrap_pam_return(&self, status: c_int) -> Result<()> {
self.last_status.set(status);
match status {
PAM_SUCCESS => Ok(()),
code => Err(Error::new(
self.handle(),
ErrorCode::from_repr(code).unwrap_or(ErrorCode::ABORT),
)),
}
}
pub fn conversation(&self) -> &ConvT {
&self.conversation
}
pub fn conversation_mut(&mut self) -> &mut ConvT {
&mut self.conversation
}
#[rustversion::attr(since(1.48), doc(alias = "pam_get_item"))]
pub fn get_item(&self, item_type: c_int) -> Result<*const c_void> {
let mut result: *const c_void = ptr::null();
self.wrap_pam_return(unsafe {
pam_get_item(self.handle().into(), item_type, &mut result)
})?;
Ok(result)
}
#[rustversion::attr(since(1.48), doc(alias = "pam_set_item"))]
pub unsafe fn set_item(&mut self, item_type: c_int, value: *const c_void) -> Result<()> {
self.wrap_pam_return(pam_set_item(self.handle().into(), item_type, &*value))
}
impl_pam_str_item!(
service,
set_service,
pam_sys::PAM_SERVICE,
"the service name"
);
impl_pam_str_item!(user, set_user, pam_sys::PAM_USER, "the username of the entity under whose identity service will be given",
"This value can be mapped by any module in the PAM stack, so don't assume it stays unchanged after calling other methods on `Self`.");
impl_pam_str_item!(
user_prompt,
set_user_prompt,
pam_sys::PAM_USER_PROMPT,
"the string used when prompting for a user's name"
);
impl_pam_str_item!(tty, set_tty, pam_sys::PAM_TTY, "the terminal name");
impl_pam_str_item!(
ruser,
set_ruser,
pam_sys::PAM_RUSER,
"the requesting user name"
);
impl_pam_str_item!(
rhost,
set_rhost,
pam_sys::PAM_RHOST,
"the requesting hostname"
);
#[cfg(any(target_os = "linux", doc))]
impl_pam_str_item!(
authtok_type,
set_authtok_type,
pam_sys::PAM_AUTHTOK_TYPE,
"the default password type in the prompt (Linux specific)",
"E.g. \"UNIX\" for \"Enter UNIX password:\""
);
#[cfg(any(target_os = "linux", doc))]
impl_pam_str_item!(
xdisplay,
set_xdisplay,
pam_sys::PAM_XDISPLAY,
"the name of the X display (Linux specific)"
);
#[cfg(any(target_os = "linux", doc))]
pub fn xauthdata(&self) -> Result<(&CStr, &[u8])> {
let handle = self.handle();
let ptr = self
.get_item(pam_sys::PAM_XAUTHDATA as c_int)?
.cast::<XAuthData>();
if ptr.is_null() {
return Err(Error::new(handle, ErrorCode::PERM_DENIED));
}
let data = unsafe { &*ptr };
if data.namelen < 0 || data.datalen < 0 || data.name.is_null() || data.data.is_null() {
return Err(Error::new(handle, ErrorCode::BUF_ERR));
}
#[allow(clippy::cast_sign_loss)]
Ok((
CStr::from_bytes_with_nul(unsafe {
slice::from_raw_parts(data.name.cast(), data.namelen as usize + 1)
})
.map_err(|_| Error::new(handle, ErrorCode::BUF_ERR))?,
unsafe { slice::from_raw_parts(data.data.cast(), data.datalen as usize) },
))
}
#[cfg(any(target_os = "linux", doc))]
pub fn set_xauthdata(&mut self, value: Option<(&CStr, &[u8])>) -> Result<()> {
match value {
None => unsafe { self.set_item(pam_sys::PAM_XAUTHDATA as c_int, ptr::null()) },
Some((name, data)) => {
let name_bytes = name.to_bytes_with_nul();
if name_bytes.len() > i32::MAX as usize || data.len() > i32::MAX as usize {
return Err(Error::new(self.handle(), ErrorCode::BUF_ERR));
}
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
let xauthdata = XAuthData {
namelen: name_bytes.len() as i32 - 1,
name: name_bytes.as_ptr().cast(),
datalen: data.len() as i32,
data: data.as_ptr().cast(),
};
unsafe {
self.set_item(
pam_sys::PAM_XAUTHDATA as c_int,
&xauthdata as *const _ as *const c_void,
)
}
}
}
}
#[must_use]
#[rustversion::attr(since(1.48), doc(alias = "pam_getenv"))]
pub fn getenv(&self, name: impl AsRef<OsStr>) -> Option<&str> {
let c_name = match CString::new(name.as_ref().as_bytes()) {
Err(_) => return None,
Ok(s) => s,
};
char_ptr_to_str(unsafe { pam_getenv(self.handle().into(), c_name.as_ptr()) })
}
#[rustversion::attr(since(1.48), doc(alias = "pam_putenv"))]
pub fn putenv(&mut self, name_value: impl AsRef<OsStr>) -> Result<()> {
let c_name_value = CString::new(name_value.as_ref().as_bytes())
.map_err(|_| Error::from(ErrorCode::BUF_ERR))?;
self.wrap_pam_return(unsafe { pam_putenv(self.handle().into(), c_name_value.as_ptr()) })
}
#[must_use]
#[rustversion::attr(since(1.48), doc(alias = "pam_getenvlist"))]
pub fn envlist(&self) -> EnvList {
unsafe { EnvList::new(pam_getenvlist(self.handle().into()).cast()) }
}
#[rustversion::attr(since(1.48), doc(alias = "pam_authenticate"))]
pub fn authenticate(&mut self, flags: Flag) -> Result<()> {
self.wrap_pam_return(unsafe { pam_authenticate(self.handle().into(), flags.bits()) })
}
#[rustversion::attr(since(1.48), doc(alias = "pam_acct_mgmt"))]
pub fn acct_mgmt(&mut self, flags: Flag) -> Result<()> {
self.wrap_pam_return(unsafe { pam_acct_mgmt(self.handle().into(), flags.bits()) })
}
pub fn reinitialize_credentials(&mut self, flags: Flag) -> Result<()> {
self.wrap_pam_return(unsafe {
pam_setcred(
self.handle().into(),
(Flag::REINITIALIZE_CRED | flags).bits(),
)
})
}
#[rustversion::attr(since(1.48), doc(alias = "pam_chauthtok"))]
pub fn chauthtok(&mut self, flags: Flag) -> Result<()> {
self.wrap_pam_return(unsafe { pam_chauthtok(self.handle().into(), flags.bits()) })
}
#[rustversion::attr(since(1.48), doc(alias = "pam_open_session"))]
pub fn open_session(&mut self, flags: Flag) -> Result<Session<ConvT>> {
let handle = self.handle().as_ptr();
self.wrap_pam_return(unsafe {
pam_setcred(handle, (Flag::ESTABLISH_CRED | flags).bits())
})?;
if let Err(e) = self.wrap_pam_return(unsafe { pam_open_session(handle, flags.bits()) }) {
let _ = self.wrap_pam_return(unsafe {
pam_setcred(handle, (Flag::DELETE_CRED | flags).bits())
});
return Err(e);
}
if let Err(e) = self.wrap_pam_return(unsafe {
pam_setcred(handle, (Flag::REINITIALIZE_CRED | flags).bits())
}) {
let _ = self.wrap_pam_return(unsafe { pam_close_session(handle, flags.bits()) });
let _ = self.wrap_pam_return(unsafe {
pam_setcred(handle, (Flag::DELETE_CRED | flags).bits())
});
return Err(e);
}
Ok(Session::new(self, true))
}
pub fn open_pseudo_session(&mut self, flags: Flag) -> Result<Session<ConvT>> {
self.wrap_pam_return(unsafe {
pam_setcred(self.handle().into(), (Flag::ESTABLISH_CRED | flags).bits())
})?;
Ok(Session::new(self, false))
}
pub fn unleak_session(&mut self, token: SessionToken) -> Session<ConvT> {
Session::new(self, matches!(token, SessionToken::FullSession))
}
}
impl<ConvT> Context<ConvT>
where
    ConvT: ConversationHandler + Default,
{
    /// Swaps in `new_handler` as the conversation handler.
    ///
    /// On success returns the context re-typed for the new handler together
    /// with the previous handler. On failure the unchanged context and
    /// `new_handler` are handed back inside the error.
    pub fn replace_conversation<T: ConversationHandler>(
        self,
        new_handler: T,
    ) -> ExtResult<(Context<T>, ConvT), (Self, T)> {
        self.replace_conversation_boxed(Box::new(new_handler))
            .map(|(context, boxed_old_conv)| (context, *boxed_old_conv))
            .map_err(|error| error.map(|(ctx, b_conv)| (ctx, *b_conv)))
    }

    /// Boxed variant of [`Context::replace_conversation`].
    ///
    /// The new handler is registered with PAM via `PAM_CONV` and then
    /// initialised with the current user name (or `None` if no user is set).
    pub fn replace_conversation_boxed<T: ConversationHandler>(
        mut self,
        mut new_handler: Box<T>,
    ) -> ExtResult<(Context<T>, Box<ConvT>), (Self, Box<T>)> {
        // An unset user (reported as `PERM_DENIED`) is acceptable; any other
        // failure aborts the swap.
        let username = match self.user() {
            Ok(name) => Some(name),
            Err(e) if e.code() == ErrorCode::PERM_DENIED => None,
            Err(e) => return Err(e.into_with_payload((self, new_handler))),
        };
        let pam_conv = to_pam_conv(&mut new_handler);
        // SAFETY: `pam_conv` is valid for the duration of the call and
        // pam_set_item(3) copies the conversation structure.
        let registered = unsafe {
            self.set_item(
                pam_sys::PAM_CONV as c_int,
                &pam_conv as *const _ as *const _,
            )
        };
        match registered {
            Err(e) => Err(e.into_with_payload((self, new_handler))),
            Ok(()) => {
                new_handler.init(username);
                // Move the handle and status into the new context; with the
                // handle gone, dropping `self` will not call `pam_end`.
                let handle = self.handle.take();
                let previous_status = self.last_status.replace(PAM_SUCCESS);
                let old_conv = take(&mut self.conversation);
                let context = Context::<T> {
                    handle,
                    conversation: new_handler,
                    last_status: Cell::new(previous_status),
                };
                Ok((context, old_conv))
            }
        }
    }
}
impl<ConvT> Drop for Context<ConvT>
where
ConvT: ConversationHandler,
{
#[rustversion::attr(since(1.48), doc(alias = "pam_end"))]
fn drop(&mut self) {
if let Some(handle) = self.handle {
unsafe { pam_end(handle.into(), self.last_status.get()) };
}
}
}
// SAFETY: the only field that is not automatically `Send` is the raw PAM
// handle (`PamHandle` wraps a `NonNull`). The context owns that handle
// exclusively, and `Context` is not `Sync` (it contains a `Cell`), so the
// handle is never used from two threads at once.
// NOTE(review): this assumes the PAM library tolerates a handle being used
// from a thread other than the one that created it, as long as access is not
// concurrent - confirm for the supported PAM implementations.
unsafe impl<ConvT: ConversationHandler + Send> Send for Context<ConvT> {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};

    /// Exercises item accessors, the conversation accessors and the PAM
    /// environment API on a context using the null conversation.
    #[test]
    fn test_basic() {
        let mut context =
            Context::new("test", Some("user"), crate::conv_null::Conversation::new()).unwrap();
        assert_eq!(context.service().unwrap(), "test");
        assert_eq!(context.user().unwrap(), "user");
        let h = context.handle();
        assert_eq!(&h.clone().0, &h.0);
        assert!(format!("{:?}", h).contains(&format!("{:?}", h.as_ptr())));
        context.set_user_prompt(Some("Who art thou? ")).unwrap();
        assert_eq!(context.user_prompt().unwrap(), "Who art thou? ");
        context.set_tty(Some("/dev/tty")).unwrap();
        assert_eq!(context.tty().unwrap(), "/dev/tty");
        context.set_ruser(Some("nobody")).unwrap();
        assert_eq!(context.ruser().unwrap(), "nobody");
        context.set_rhost(Some("nowhere")).unwrap();
        assert_eq!(context.rhost().unwrap(), "nowhere");
        #[cfg(target_os = "linux")]
        {
            context.set_authtok_type(Some("TEST")).unwrap();
            assert_eq!(context.authtok_type().unwrap(), "TEST");
            context.set_xdisplay(Some(":0")).unwrap();
            assert_eq!(context.xdisplay().unwrap(), ":0");
            let xauthname = CString::new("TEST_DATA").unwrap();
            let xauthdata = [];
            // Reading before anything is set must not panic.
            let _ = context.xauthdata();
            context
                .set_xauthdata(Some((&xauthname, &xauthdata)))
                .unwrap();
            let (resultname, resultdata) = context.xauthdata().unwrap();
            assert_eq!(resultname, xauthname.as_c_str());
            assert_eq!(resultdata, &xauthdata);
        }
        assert_eq!(
            context.conversation_mut() as *mut _ as *const _,
            context.conversation() as *const _
        );
        context
            .conversation_mut()
            .text_info(&CString::new("").unwrap());
        assert!(context.get_item(pam_sys::PAM_AUTHTOK as c_int).is_err());
        context.putenv("TEST=1").unwrap();
        context.putenv("TEST2=2").unwrap();
        let _ = context.putenv("\0=\0").unwrap_err();
        assert_eq!(context.getenv("TEST").unwrap(), "1");
        assert!(context.getenv("TESTNONEXIST").is_none());
        let env = context.envlist();
        assert!(env.len() > 0);
        let _ = env.get("TEST").unwrap();
        assert!(env.get("TESTNONEXIST").is_none());
        for (key, value) in env.iter_tuples() {
            if key.to_string_lossy() == "TEST" {
                assert_eq!(value.to_string_lossy(), "1");
            }
        }
        assert!(format!("{:?}", &env.iter_tuples()).contains("EnvItem"));
        for item in &env {
            let string = item.to_string();
            if string.starts_with("TEST=") {
                assert_eq!(string, "TEST=1");
                assert!(format!("{:?}", &item).contains("EnvItem"));
            } else if string.starts_with("TEST2=") {
                let (_, v): (&OsStr, &OsStr) = item.into();
                assert_eq!(v.to_string_lossy(), "2");
            }
            let _ = item.as_ref();
        }
        let _ = format!("{:?}", &env);
        assert!(!env.is_empty());
        assert_eq!(env.len(), env.as_ref().len());
        assert_eq!(env.as_ref(), context.envlist().as_ref());
        assert_eq!(
            env.as_ref().partial_cmp(context.envlist().as_ref()),
            Some(std::cmp::Ordering::Equal)
        );
        assert_eq!(
            env.as_ref().cmp(context.envlist().as_ref()),
            std::cmp::Ordering::Equal
        );
        assert_eq!(&env["TEST"], "1");
        assert_eq!(env.len(), env.iter_tuples().size_hint().0);
        let list: std::vec::Vec<&CStr> = (&env).into();
        assert_eq!(list.len(), env.len());
        let list: std::vec::Vec<(&OsStr, _)> = (&env).into();
        assert_eq!(list.len(), env.len());
        let map: std::collections::HashMap<&OsStr, _> = (&env).into();
        // Compare against the source list (the original compared `map` with itself).
        assert_eq!(map.len(), env.len());
        assert_eq!(
            map.get(&OsString::from("TEST".to_string()).as_ref()),
            Some(&OsString::from("1".to_string()).as_ref())
        );
        assert!(env.to_string().contains("TEST=1"));
        let list: std::vec::Vec<(std::ffi::OsString, _)> = context.envlist().into();
        assert_eq!(list.len(), env.len());
        let list: std::vec::Vec<CString> = context.envlist().into();
        assert_eq!(list.len(), env.len());
        let map: std::collections::HashMap<_, _> = context.envlist().into();
        assert_eq!(map.len(), env.len());
        assert_eq!(
            map.get(&OsString::from("TEST".to_string())),
            Some(&OsString::from("1".to_string()))
        );
        drop(context)
    }

    /// Replacing the conversation handler forwards the current user name and
    /// hands back the previous handler.
    #[test]
    fn test_conv_replace() {
        let mut context =
            Context::new("test", Some("user"), crate::conv_null::Conversation::new()).unwrap();
        context.set_user(Some("anybody")).unwrap();
        let (mut context, old_conv) = context
            .replace_conversation(crate::conv_mock::Conversation::default())
            .unwrap();
        assert_eq!(context.conversation().username, "anybody");
        context.set_user(None).unwrap();
        let (context, _) = context.replace_conversation(old_conv).unwrap();
        assert!(context.user().is_err());
    }

    /// Runs the full PAM flow against the `test_rust_pam_client` service
    /// (needs a system PAM configuration; enabled by the `full_test` feature).
    #[test]
    #[cfg_attr(not(feature = "full_test"), ignore)]
    fn test_full() {
        let mut context = Context::new(
            "test_rust_pam_client",
            Some("nobody"),
            crate::conv_null::Conversation::new(),
        )
        .unwrap();
        let _ = context.authenticate(Flag::SILENT);
        let _ = context.acct_mgmt(Flag::SILENT);
        let _ = context.chauthtok(Flag::CHANGE_EXPIRED_AUTHTOK);
        let _ = context.reinitialize_credentials(Flag::SILENT | Flag::NONE);
        drop(context.open_session(Flag::SILENT));
        drop(context.open_pseudo_session(Flag::SILENT));
    }

    /// Like `test_full`, but without authenticating first, and exercising the
    /// session methods when a session can be opened.
    #[test]
    #[cfg_attr(not(feature = "full_test"), ignore)]
    fn test_full_unauth() {
        let mut context = Context::new(
            "test_rust_pam_client",
            Some("nobody"),
            crate::conv_null::Conversation::new(),
        )
        .unwrap();
        let _ = context.acct_mgmt(Flag::SILENT);
        let _ = context.chauthtok(Flag::CHANGE_EXPIRED_AUTHTOK);
        let _ = context.reinitialize_credentials(Flag::SILENT | Flag::NONE);
        if let Ok(mut session) = context.open_session(Flag::SILENT) {
            let _ = session.refresh_credentials(Flag::SILENT);
            let _ = session.reinitialize_credentials(Flag::SILENT);
            let _ = session.envlist();
            let _ = session.close(Flag::SILENT);
        }
        if let Ok(mut session) = context.open_pseudo_session(Flag::SILENT) {
            let _ = session.refresh_credentials(Flag::SILENT);
        }
    }
}