use crate::env_list::EnvList;
use crate::error::{Error, ErrorCode};
use crate::ffi::{from_pam_conv, into_pam_conv};
use crate::session::{Session, SessionToken};
use crate::{char_ptr_to_str, ConversationHandler};
extern crate libc;
extern crate pam_sys2 as pam_sys;
use crate::{ExtResult, Flag, Result, PAM_SUCCESS};
use libc::{c_char, c_int, c_void};
use pam_sys::pam_conv as PamConversation;
use pam_sys::pam_handle_t as RawPamHandle;
use pam_sys::{
pam_acct_mgmt, pam_authenticate, pam_chauthtok, pam_close_session, pam_end, pam_get_item,
pam_getenv, pam_getenvlist, pam_open_session, pam_putenv, pam_set_item, pam_setcred, pam_start,
};
use std::cell::Cell;
use std::ffi::{CStr, CString, OsStr};
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::os::unix::ffi::OsStrExt;
use std::ptr::NonNull;
use std::{ptr, slice};
// Generates a getter/setter pair for a string-valued PAM item.
//
// * `$name`      - name of the generated getter, `fn $name(&self) -> Result<String>`.
// * `$set_name`  - name of the generated setter, `fn $set_name(&mut self, Option<&str>)`.
// * `$item_type` - the PAM item constant handed to `pam_get_item` / `pam_set_item`.
// * `$doc`       - optional description, rendered as "Returns ..." / "Sets ...".
// * `$extdoc`    - optional extra paragraph appended to the getter documentation only.
macro_rules! impl_pam_str_item {
    ($name:ident, $set_name:ident, $item_type:expr$(, $doc:literal$(, $extdoc:literal)?)?$(,)?) => {
        $(#[doc = "Returns "]#[doc = $doc]$(#[doc = "\n\n"]#[doc = $extdoc])?)?
        ///
        /// # Errors
        ///
        /// Returns the PAM error if the item cannot be read, and
        /// `ErrorCode::PERM_DENIED` if the item is not set (null).
        /// Invalid UTF-8 is replaced lossily instead of being reported.
        pub fn $name(&self) -> Result<String> {
            let ptr = self.get_item($item_type as c_int)?;
            if ptr.is_null() {
                // Unset item (the setter below stores a null pointer for `None`).
                return Err(Error::new(self.handle(), ErrorCode::PERM_DENIED));
            }
            // SAFETY: string items are NUL-terminated C strings owned by PAM; the
            // contents are copied out immediately while `self` is still borrowed.
            let value = unsafe { CStr::from_ptr(ptr.cast()) };
            Ok(value.to_string_lossy().into_owned())
        }
        $(#[doc = "Sets "]#[doc = $doc])?
        ///
        /// Passing `None` sets the item to a null pointer.
        ///
        /// # Errors
        ///
        /// Returns `ErrorCode::BUF_ERR` if `value` contains an interior NUL byte,
        /// or the PAM error reported by `pam_set_item`.
        pub fn $set_name(&mut self, value: Option<&str>) -> Result<()> {
            match value {
                // SAFETY: a null pointer is a valid value for every string item.
                None => unsafe { self.set_item($item_type as c_int, ptr::null()) },
                Some(string) => {
                    let cstring = CString::new(string)
                        .map_err(|_| Error::new(self.handle(), ErrorCode::BUF_ERR))?;
                    // SAFETY: `cstring` outlives the call. It is dropped right after, so
                    // this relies on `pam_set_item` copying string items.
                    unsafe { self.set_item($item_type as c_int, cstring.as_ptr().cast()) }
                }
            }
        }
    };
}
/// Non-null wrapper around a raw `pam_handle_t` pointer.
///
/// This is a plain, copyable pointer wrapper: nothing in it calls `pam_end`,
/// so it does not own the handle it points to.
#[derive(Debug, Clone, Copy)]
pub(crate) struct PamHandle(NonNull<RawPamHandle>);
impl PamHandle {
    /// Wraps `ptr`, or returns `None` when it is null.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body itself performs no unsafe operation; presumably
    /// callers must pass a live handle obtained from `pam_start` — confirm.
    #[inline]
    pub unsafe fn new(ptr: *mut RawPamHandle) -> Option<Self> {
        let non_null = NonNull::new(ptr)?;
        Some(Self(non_null))
    }
    /// Returns the wrapped raw pointer, which is never null.
    #[inline]
    pub const fn as_ptr(self) -> *mut RawPamHandle {
        self.0.as_ptr()
    }
}
impl From<PamHandle> for *mut RawPamHandle {
#[inline]
fn from(handle: PamHandle) -> Self {
handle.as_ptr()
}
}
impl From<PamHandle> for *const RawPamHandle {
#[inline]
fn from(handle: PamHandle) -> Self {
handle.as_ptr()
}
}
/// C-layout value of the `PAM_XAUTHDATA` item (Linux only).
///
/// NOTE(review): presumably mirrors Linux-PAM's `struct pam_xauth_data`; the
/// field order and types are part of that C layout and must not change.
#[cfg(any(target_os = "linux", doc))]
#[repr(C)]
#[derive(Debug)]
struct XAuthData {
    /// Length of `name` in bytes, excluding the trailing NUL.
    pub namelen: c_int,
    /// Name of the X authentication method, followed by a NUL byte.
    pub name: *const c_char,
    /// Length of `data` in bytes.
    pub datalen: c_int,
    /// Authentication data, read as `datalen` raw bytes (no terminator expected).
    pub data: *const c_char,
}
/// An application-side PAM transaction, created with `pam_start`.
///
/// `ConvT` is the type of the conversation handler. The handler is not stored
/// in a field: it is boxed and reached through the `PAM_CONV` item of the
/// handle, and `Drop` reclaims it after calling `pam_end`.
pub struct Context<ConvT> {
    /// Non-null handle returned by `pam_start`.
    handle: PamHandle,
    /// Status of the most recent PAM call recorded by `wrap_pam_return`;
    /// passed to `pam_end` when the context is dropped.
    last_status: Cell<c_int>,
    /// Records the handler type stored behind the `PAM_CONV` item.
    _conversation: PhantomData<ConvT>,
}
impl<ConvT> Context<ConvT>
where
    ConvT: ConversationHandler,
{
    /// Starts a PAM transaction for `service`, optionally for `username`, using
    /// `conversation` as the conversation handler.
    ///
    /// Equivalent to [`Context::from_boxed_conv`] with a freshly boxed handler.
    ///
    /// # Errors
    ///
    /// Same as [`Context::from_boxed_conv`].
    #[rustversion::attr(since(1.48), doc(alias = "pam_start"))]
    pub fn new(service: &str, username: Option<&str>, conversation: ConvT) -> Result<Self> {
        Self::from_boxed_conv(service, username, Box::new(conversation))
    }

    /// Starts a PAM transaction with an already boxed conversation handler.
    ///
    /// On success the handler's `init` is called with `username`.
    ///
    /// # Errors
    ///
    /// * `ErrorCode::BUF_ERR` if `service` or `username` contains a NUL byte.
    /// * The code returned by `pam_start` on failure, or `ErrorCode::ABORT` if
    ///   that code is unknown or if `pam_start` reports success without
    ///   producing a handle.
    ///
    /// The conversation handler is dropped on every error path.
    pub fn from_boxed_conv(
        service: &str,
        username: Option<&str>,
        boxed_conv: Box<ConvT>,
    ) -> Result<Self> {
        let mut handle: *mut RawPamHandle = ptr::null_mut();
        let c_service = CString::new(service).map_err(|_| Error::from(ErrorCode::BUF_ERR))?;
        let c_username = match username {
            None => None,
            Some(name) => Some(CString::new(name).map_err(|_| Error::from(ErrorCode::BUF_ERR))?),
        };
        // From here on the handler is owned through `pam_conv`: every exit path must
        // either hand it to a live PAM handle or reclaim it.
        let pam_conv = into_pam_conv(boxed_conv);
        // SAFETY: the C strings and `pam_conv` are locals that outlive the call, and
        // `handle` is a valid out-parameter.
        let status = unsafe {
            pam_start(
                c_service.as_ptr(),
                c_username.as_ref().map_or(ptr::null(), |s| s.as_ptr()),
                &pam_conv,
                &mut handle,
            )
        };
        let code = match status {
            // SAFETY: `handle` was just produced by a successful `pam_start`.
            PAM_SUCCESS => match unsafe { PamHandle::new(handle) } {
                Some(handle) => {
                    let mut result = Self {
                        handle,
                        last_status: Cell::new(PAM_SUCCESS),
                        _conversation: PhantomData,
                    };
                    result.conversation_mut().init(username);
                    return Ok(result);
                }
                // Success without a handle breaks the PAM contract; nothing can own
                // the handler, so treat it as fatal.
                None => ErrorCode::ABORT,
            },
            code => ErrorCode::from_repr(code).unwrap_or(ErrorCode::ABORT),
        };
        // No PAM handle took ownership of the handler: reclaim it instead of leaking it.
        // SAFETY: `pam_conv` was built by `into_pam_conv` from `boxed_conv` above and no
        // live PAM handle refers to it, so this is the only owner of the allocation.
        drop(unsafe { Box::<ConvT>::from_raw(from_pam_conv(&pam_conv)) });
        Err(code.into())
    }

    /// Authenticates the user with `pam_authenticate`.
    ///
    /// # Errors
    ///
    /// Returns the PAM error reported by `pam_authenticate`.
    #[rustversion::attr(since(1.48), doc(alias = "pam_authenticate"))]
    pub fn authenticate(&mut self, flags: Flag) -> Result<()> {
        // SAFETY: the handle is live for the lifetime of `self`.
        self.wrap_pam_return(unsafe { pam_authenticate(self.handle().into(), flags.bits()) })
    }

    /// Runs account management (`pam_acct_mgmt`).
    ///
    /// # Errors
    ///
    /// Returns the PAM error reported by `pam_acct_mgmt`.
    #[rustversion::attr(since(1.48), doc(alias = "pam_acct_mgmt"))]
    pub fn acct_mgmt(&mut self, flags: Flag) -> Result<()> {
        // SAFETY: the handle is live for the lifetime of `self`.
        self.wrap_pam_return(unsafe { pam_acct_mgmt(self.handle().into(), flags.bits()) })
    }

    /// Calls `pam_setcred` with `Flag::REINITIALIZE_CRED` added to `flags`.
    ///
    /// # Errors
    ///
    /// Returns the PAM error reported by `pam_setcred`.
    pub fn reinitialize_credentials(&mut self, flags: Flag) -> Result<()> {
        // SAFETY: the handle is live for the lifetime of `self`.
        self.wrap_pam_return(unsafe {
            pam_setcred(
                self.handle().into(),
                (Flag::REINITIALIZE_CRED | flags).bits(),
            )
        })
    }

    /// Changes the authentication token (`pam_chauthtok`).
    ///
    /// # Errors
    ///
    /// Returns the PAM error reported by `pam_chauthtok`.
    #[rustversion::attr(since(1.48), doc(alias = "pam_chauthtok"))]
    pub fn chauthtok(&mut self, flags: Flag) -> Result<()> {
        // SAFETY: the handle is live for the lifetime of `self`.
        self.wrap_pam_return(unsafe { pam_chauthtok(self.handle().into(), flags.bits()) })
    }

    /// Opens a full session.
    ///
    /// The steps are: establish credentials (`pam_setcred` with
    /// `Flag::ESTABLISH_CRED`), `pam_open_session`, then reinitialize
    /// credentials (`pam_setcred` with `Flag::REINITIALIZE_CRED`). If a later
    /// step fails, earlier steps are undone on a best-effort basis (errors from
    /// the rollback are ignored).
    ///
    /// # Errors
    ///
    /// Returns the error of the first step that failed.
    #[rustversion::attr(since(1.48), doc(alias = "pam_open_session"))]
    pub fn open_session(&mut self, flags: Flag) -> Result<Session<'_, ConvT>> {
        let bits = flags.bits();
        let handle = self.handle().as_ptr();
        // SAFETY (all calls below): `handle` is the live handle owned by `self`.
        self.wrap_pam_return(unsafe {
            pam_setcred(handle, (Flag::ESTABLISH_CRED | flags).bits())
        })?;
        if let Err(e) = self.wrap_pam_return(unsafe { pam_open_session(handle, bits) }) {
            // Roll back the credentials established above.
            let _ = self.wrap_pam_return(unsafe {
                pam_setcred(handle, (Flag::DELETE_CRED | flags).bits())
            });
            return Err(e);
        }
        if let Err(e) = self.wrap_pam_return(unsafe {
            pam_setcred(handle, (Flag::REINITIALIZE_CRED | flags).bits())
        }) {
            // Roll back both the session and the credentials.
            let _ = self.wrap_pam_return(unsafe { pam_close_session(handle, bits) });
            let _ = self.wrap_pam_return(unsafe {
                pam_setcred(handle, (Flag::DELETE_CRED | flags).bits())
            });
            return Err(e);
        }
        Ok(Session::new(self, true))
    }

    /// Establishes credentials (`pam_setcred` with `Flag::ESTABLISH_CRED`)
    /// without calling `pam_open_session`.
    ///
    /// # Errors
    ///
    /// Returns the PAM error reported by `pam_setcred`.
    pub fn open_pseudo_session(&mut self, flags: Flag) -> Result<Session<'_, ConvT>> {
        // SAFETY: the handle is live for the lifetime of `self`.
        self.wrap_pam_return(unsafe {
            pam_setcred(self.handle().into(), (Flag::ESTABLISH_CRED | flags).bits())
        })?;
        Ok(Session::new(self, false))
    }

    /// Rebuilds a [`Session`] from a token previously obtained for this context.
    ///
    /// A `SessionToken::FullSession` token yields the same kind of session as
    /// [`Context::open_session`]; any other token yields the kind returned by
    /// [`Context::open_pseudo_session`].
    pub fn unleak_session(&mut self, token: SessionToken) -> Session<'_, ConvT> {
        Session::new(self, matches!(token, SessionToken::FullSession))
    }
}
impl<ConvT> Context<ConvT> {
#[inline]
pub(crate) fn handle(&self) -> PamHandle {
self.handle
}
#[inline]
pub(crate) fn wrap_pam_return(&self, status: c_int) -> Result<()> {
self.last_status.set(status);
match status {
PAM_SUCCESS => Ok(()),
code => Err(Error::new(
self.handle(),
ErrorCode::from_repr(code).unwrap_or(ErrorCode::ABORT),
)),
}
}
#[rustversion::attr(since(1.48), doc(alias = "pam_get_item"))]
pub fn get_item(&self, item_type: c_int) -> Result<*const c_void> {
let mut result: *const c_void = ptr::null();
self.wrap_pam_return(unsafe {
pam_get_item(self.handle().into(), item_type, &mut result)
})?;
Ok(result)
}
#[rustversion::attr(since(1.48), doc(alias = "pam_set_item"))]
pub unsafe fn set_item(&mut self, item_type: c_int, value: *const c_void) -> Result<()> {
self.wrap_pam_return(pam_set_item(self.handle().into(), item_type, value))
}
#[inline]
fn conversation_raw(&self) -> *mut ConvT {
let ptr = self
.get_item(pam_sys::PAM_CONV as c_int)
.expect("Extracting the conversation handler should never fail")
.cast::<PamConversation>();
unsafe {
from_pam_conv(
ptr.as_ref()
.expect("Invalid state: conversation handler should never be null"),
)
}
}
pub fn conversation(&self) -> &ConvT {
let ptr: *const ConvT = self.conversation_raw();
unsafe { &*ptr }
}
pub fn conversation_mut(&mut self) -> &mut ConvT {
let ptr = self.conversation_raw();
unsafe { &mut *ptr }
}
impl_pam_str_item!(
service,
set_service,
pam_sys::PAM_SERVICE,
"the service name"
);
impl_pam_str_item!(user, set_user, pam_sys::PAM_USER, "the username of the entity under whose identity service will be given",
"This value can be mapped by any module in the PAM stack, so don't assume it stays unchanged after calling other methods on `Self`.");
impl_pam_str_item!(
user_prompt,
set_user_prompt,
pam_sys::PAM_USER_PROMPT,
"the string used when prompting for a user's name"
);
impl_pam_str_item!(tty, set_tty, pam_sys::PAM_TTY, "the terminal name");
impl_pam_str_item!(
ruser,
set_ruser,
pam_sys::PAM_RUSER,
"the requesting user name"
);
impl_pam_str_item!(
rhost,
set_rhost,
pam_sys::PAM_RHOST,
"the requesting hostname"
);
#[cfg(any(target_os = "linux", doc))]
impl_pam_str_item!(
authtok_type,
set_authtok_type,
pam_sys::PAM_AUTHTOK_TYPE,
"the default password type in the prompt (Linux specific)",
"E.g. \"UNIX\" for \"Enter UNIX password:\""
);
#[cfg(any(target_os = "linux", doc))]
impl_pam_str_item!(
xdisplay,
set_xdisplay,
pam_sys::PAM_XDISPLAY,
"the name of the X display (Linux specific)"
);
#[cfg(any(target_os = "linux", doc))]
pub fn xauthdata(&self) -> Result<(&CStr, &[u8])> {
let handle = self.handle();
let ptr = self
.get_item(pam_sys::PAM_XAUTHDATA as c_int)?
.cast::<XAuthData>();
if ptr.is_null() {
return Err(Error::new(handle, ErrorCode::PERM_DENIED));
}
let data = unsafe { &*ptr };
if data.namelen < 0 || data.datalen < 0 || data.name.is_null() || data.data.is_null() {
return Err(Error::new(handle, ErrorCode::BUF_ERR));
}
#[allow(clippy::cast_sign_loss)]
Ok((
CStr::from_bytes_with_nul(unsafe {
slice::from_raw_parts(data.name.cast(), data.namelen as usize + 1)
})
.map_err(|_| Error::new(handle, ErrorCode::BUF_ERR))?,
unsafe { slice::from_raw_parts(data.data.cast(), data.datalen as usize) },
))
}
#[cfg(any(target_os = "linux", doc))]
pub fn set_xauthdata(&mut self, value: Option<(&CStr, &[u8])>) -> Result<()> {
match value {
None => unsafe { self.set_item(pam_sys::PAM_XAUTHDATA as c_int, ptr::null()) },
Some((name, data)) => {
let name_bytes = name.to_bytes_with_nul();
if name_bytes.len() > i32::MAX as usize || data.len() > i32::MAX as usize {
return Err(Error::new(self.handle(), ErrorCode::BUF_ERR));
}
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
let xauthdata = XAuthData {
namelen: name_bytes.len() as i32 - 1,
name: name_bytes.as_ptr().cast(),
datalen: data.len() as i32,
data: data.as_ptr().cast(),
};
unsafe {
self.set_item(
pam_sys::PAM_XAUTHDATA as c_int,
&xauthdata as *const _ as *const c_void,
)
}
}
}
}
#[must_use]
#[rustversion::attr(since(1.48), doc(alias = "pam_getenv"))]
pub fn getenv(&self, name: impl AsRef<OsStr>) -> Option<&str> {
let c_name = match CString::new(name.as_ref().as_bytes()) {
Err(_) => return None,
Ok(s) => s,
};
char_ptr_to_str(unsafe { pam_getenv(self.handle().into(), c_name.as_ptr()) })
}
#[rustversion::attr(since(1.48), doc(alias = "pam_putenv"))]
pub fn putenv(&mut self, name_value: impl AsRef<OsStr>) -> Result<()> {
let c_name_value = CString::new(name_value.as_ref().as_bytes())
.map_err(|_| Error::from(ErrorCode::BUF_ERR))?;
self.wrap_pam_return(unsafe { pam_putenv(self.handle().into(), c_name_value.as_ptr()) })
}
#[must_use]
#[rustversion::attr(since(1.48), doc(alias = "pam_getenvlist"))]
pub fn envlist(&self) -> EnvList {
unsafe { EnvList::new(pam_getenvlist(self.handle().into()).cast()) }
}
pub fn replace_conversation<T: ConversationHandler>(
self,
new_handler: T,
) -> ExtResult<(Context<T>, ConvT), (Self, T)> {
match self.replace_conversation_boxed(new_handler.into()) {
Ok((context, boxed_old_conv)) => Ok((context, *boxed_old_conv)),
Err(error) => Err(error.map(|(ctx, b_conv)| (ctx, *b_conv))),
}
}
pub fn replace_conversation_boxed<T: ConversationHandler>(
mut self,
new_handler: Box<T>,
) -> ExtResult<(Context<T>, Box<ConvT>), (Self, Box<T>)> {
let username = match self.user() {
Ok(u) => Some(u),
Err(e) => {
if e.code() != ErrorCode::PERM_DENIED {
return Err(e.into_with_payload((self, new_handler)));
}
None
}
};
let old_handler_ptr = self.conversation_raw();
let pam_conv = into_pam_conv(new_handler);
if let Err(e) = unsafe {
self.set_item(
pam_sys::PAM_CONV as c_int,
&pam_conv as *const _ as *const _,
)
} {
let new_handler = unsafe { Box::from_raw(from_pam_conv(&pam_conv)) };
Err(e.into_with_payload((self, new_handler)))
} else {
let old = ManuallyDrop::new(self);
let old_handler = unsafe { Box::from_raw(old_handler_ptr) };
let mut context = Context::<T> {
handle: old.handle,
last_status: Cell::new(old.last_status.replace(PAM_SUCCESS)),
_conversation: PhantomData,
};
context.conversation_mut().init(username.as_deref());
Ok((context, old_handler))
}
}
}
impl<ConvT> Drop for Context<ConvT> {
    /// Ends the PAM transaction with `pam_end`, passing the status of the last
    /// recorded PAM operation, and then drops the conversation handler.
    #[rustversion::attr(since(1.48), doc(alias = "pam_end"))]
    fn drop(&mut self) {
        // Read the status before any bookkeeping PAM call (such as the PAM_CONV
        // lookup below) can overwrite it; `pam_end` must see the last real status.
        let status = self.last_status.get();
        let conv = self.conversation_raw();
        // SAFETY: the handle is live and is never used again after this call.
        unsafe { pam_end(self.handle.into(), status) };
        // Drop the handler only after `pam_end`, so PAM can no longer call into it.
        // SAFETY: the handler was boxed by `into_pam_conv` and is owned solely by
        // this context.
        drop(unsafe { Box::from_raw(conv) });
    }
}
// SAFETY: a `Context` exclusively owns its PAM handle and its boxed conversation
// handler, so sending it moves that ownership along; `ConvT: Send` covers the handler.
// `Context` is not `Sync` (it holds a `Cell` and a `NonNull`), so the handle is never
// used from two threads at once.
// NOTE(review): this assumes the PAM implementation allows a handle to be used from a
// thread other than the one that created it — confirm for the target PAM library.
unsafe impl<ConvT: Send> Send for Context<ConvT> {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};

    /// Exercises item getters/setters, the conversation accessors and the PAM
    /// environment API on a freshly started context, without authenticating.
    #[test]
    fn test_basic() {
        let mut context =
            Context::new("test", Some("user"), crate::conv_null::Conversation::new()).unwrap();
        assert_eq!(context.service().unwrap(), "test");
        assert_eq!(context.user().unwrap(), "user");
        // The handle is a copyable pointer wrapper; both raw conversions must
        // yield the wrapped pointer.
        let h = context.handle();
        let h_copy = h;
        assert_eq!(h.as_ptr(), h_copy.as_ptr());
        assert_eq!(<*mut RawPamHandle>::from(h), h.as_ptr());
        assert_eq!(
            <*const RawPamHandle>::from(h),
            h.as_ptr() as *const RawPamHandle
        );
        assert!(format!("{:?}", h).contains(&format!("{:?}", h.as_ptr())));
        context.set_user_prompt(Some("Who art thou? ")).unwrap();
        assert_eq!(context.user_prompt().unwrap(), "Who art thou? ");
        context.set_tty(Some("/dev/tty")).unwrap();
        assert_eq!(context.tty().unwrap(), "/dev/tty");
        context.set_ruser(Some("nobody")).unwrap();
        assert_eq!(context.ruser().unwrap(), "nobody");
        context.set_rhost(Some("nowhere")).unwrap();
        assert_eq!(context.rhost().unwrap(), "nowhere");
        #[cfg(target_os = "linux")]
        {
            context.set_authtok_type(Some("TEST")).unwrap();
            assert_eq!(context.authtok_type().unwrap(), "TEST");
            context.set_xdisplay(Some(":0")).unwrap();
            assert_eq!(context.xdisplay().unwrap(), ":0");
            let xauthname = CString::new("TEST_DATA").unwrap();
            let xauthdata = [];
            let _ = context.xauthdata();
            context
                .set_xauthdata(Some((&xauthname, &xauthdata)))
                .unwrap();
            let (resultname, resultdata) = context.xauthdata().unwrap();
            assert_eq!(resultname, xauthname.as_c_str());
            assert_eq!(resultdata, &xauthdata);
        };
        assert_eq!(
            context.conversation_mut() as *mut _ as *const _,
            context.conversation() as *const _
        );
        context
            .conversation_mut()
            .text_info(&CString::new("").unwrap());
        assert!(context.get_item(pam_sys::PAM_AUTHTOK as c_int).is_err());
        context.putenv("TEST=1").unwrap();
        context.putenv("TEST2=2").unwrap();
        let _ = context.putenv("\0=\0").unwrap_err();
        assert_eq!(context.getenv("TEST").unwrap(), "1");
        assert!(context.getenv("TESTNONEXIST").is_none());
        let env = context.envlist();
        assert!(!env.is_empty());
        let _ = env.get("TEST").unwrap();
        assert!(env.get("TESTNONEXIST").is_none());
        for (key, value) in env.iter_tuples() {
            if key.to_string_lossy() == "TEST" {
                assert_eq!(value.to_string_lossy(), "1");
            }
        }
        assert!(format!("{:?}", &env.iter_tuples()).contains("EnvItem"));
        for item in &env {
            let string = item.to_string();
            if string.starts_with("TEST=") {
                assert_eq!(string, "TEST=1");
                assert!(format!("{:?}", &item).contains("EnvItem"));
            } else if string.starts_with("TEST2=") {
                let (_, v): (&OsStr, &OsStr) = item.into();
                assert_eq!(v.to_string_lossy(), "2");
            }
            let _ = item.as_ref();
        }
        let _ = format!("{:?}", &env);
        assert!(!env.is_empty());
        assert_eq!(env.len(), env.as_ref().len());
        assert_eq!(env.as_ref(), context.envlist().as_ref());
        assert_eq!(
            env.as_ref().partial_cmp(context.envlist().as_ref()),
            Some(std::cmp::Ordering::Equal)
        );
        assert_eq!(
            env.as_ref().cmp(context.envlist().as_ref()),
            std::cmp::Ordering::Equal
        );
        assert_eq!(&env["TEST"], "1");
        assert_eq!(env.len(), env.iter_tuples().size_hint().0);
        let list: std::vec::Vec<&CStr> = (&env).into();
        assert_eq!(list.len(), env.len());
        let list: std::vec::Vec<(&OsStr, _)> = (&env).into();
        assert_eq!(list.len(), env.len());
        let map: std::collections::HashMap<&OsStr, _> = (&env).into();
        assert_eq!(map.len(), env.len());
        assert_eq!(
            map.get(&OsString::from("TEST".to_string()).as_ref()),
            Some(&OsString::from("1".to_string()).as_ref())
        );
        assert!(env.to_string().contains("TEST=1"));
        let list: std::vec::Vec<(std::ffi::OsString, _)> = context.envlist().into();
        assert_eq!(list.len(), env.len());
        let list: std::vec::Vec<CString> = context.envlist().into();
        assert_eq!(list.len(), env.len());
        let map: std::collections::HashMap<_, _> = context.envlist().into();
        assert_eq!(map.len(), env.len());
        assert_eq!(
            map.get(&OsString::from("TEST".to_string())),
            Some(&OsString::from("1".to_string()))
        );
        drop(context)
    }

    /// Replacing the conversation passes the current user to the new handler,
    /// and an unset user stays unset across a second replacement.
    #[test]
    fn test_conv_replace() {
        let mut context =
            Context::new("test", Some("user"), crate::conv_null::Conversation::new()).unwrap();
        context.set_user(Some("anybody")).unwrap();
        let (mut context, old_conv) = context
            .replace_conversation(crate::conv_mock::Conversation::default())
            .unwrap();
        assert_eq!(context.conversation().username, "anybody");
        context.set_user(None).unwrap();
        let (context, _) = context.replace_conversation(old_conv).unwrap();
        assert!(context.user().is_err());
    }

    /// Borrowed trait-object handlers: the context only borrows them, so they can be
    /// inspected once the context is gone.
    #[test]
    fn test_dyn_ref() {
        let mut handler_a = crate::conv_null::Conversation::new();
        let mut handler_b = crate::conv_mock::Conversation::new();
        let mut context = Context::new(
            "test",
            Some("user"),
            &mut handler_a as &mut dyn ConversationHandler,
        )
        .unwrap();
        context.set_user(Some("anybody")).unwrap();
        let (context, _) = context
            .replace_conversation(&mut handler_b as &mut dyn ConversationHandler)
            .unwrap();
        drop(handler_a);
        drop(context);
        assert_eq!(handler_b.username, "anybody");
    }

    /// Boxed trait-object handlers can be replaced as well.
    #[test]
    fn test_dyn() {
        let mut context = Context::new(
            "test",
            Some("user"),
            Box::new(crate::conv_null::Conversation::new()) as Box<dyn ConversationHandler>,
        )
        .unwrap();
        context.set_user(Some("anybody")).unwrap();
        let (context, _) = context
            .replace_conversation(
                Box::new(crate::conv_mock::Conversation::new()) as Box<dyn ConversationHandler>
            )
            .unwrap();
        let mock_handler: &crate::conv_mock::Conversation =
            unsafe { &*(&**context.conversation() as *const _ as *const _) };
        assert_eq!(mock_handler.username, "anybody");
    }

    /// Runs the full PAM flow; needs the `test_rust_pam_client` service installed.
    #[test]
    #[cfg_attr(not(feature = "full_test"), ignore)]
    fn test_full() {
        let mut context = Context::new(
            "test_rust_pam_client",
            Some("nobody"),
            crate::conv_null::Conversation::new(),
        )
        .unwrap();
        let _ = context.authenticate(Flag::SILENT);
        let _ = context.acct_mgmt(Flag::SILENT);
        let _ = context.chauthtok(Flag::CHANGE_EXPIRED_AUTHTOK);
        let _ = context.reinitialize_credentials(Flag::SILENT | Flag::NONE);
        drop(context.open_session(Flag::SILENT));
        drop(context.open_pseudo_session(Flag::SILENT));
    }

    /// Same as `test_full`, but without authenticating first.
    #[test]
    #[cfg_attr(not(feature = "full_test"), ignore)]
    fn test_full_unauth() {
        let mut context = Context::new(
            "test_rust_pam_client",
            Some("nobody"),
            crate::conv_null::Conversation::new(),
        )
        .unwrap();
        let _ = context.acct_mgmt(Flag::SILENT);
        let _ = context.chauthtok(Flag::CHANGE_EXPIRED_AUTHTOK);
        let _ = context.reinitialize_credentials(Flag::SILENT | Flag::NONE);
        if let Ok(mut session) = context.open_session(Flag::SILENT) {
            let _ = session.refresh_credentials(Flag::SILENT);
            let _ = session.reinitialize_credentials(Flag::SILENT);
            let _ = session.envlist();
            let _ = session.close(Flag::SILENT);
        };
        if let Ok(mut session) = context.open_pseudo_session(Flag::SILENT) {
            let _ = session.refresh_credentials(Flag::SILENT);
        };
    }
}