use crate::{fapi_sys::TPM2_ALG_ID, memory::CStringHolder, HashAlgorithm};
use log::trace;
use std::{any::Any, borrow::Cow, ffi::CStr, fmt::Debug};
/// Wrapper around a user-supplied authorization closure.
///
/// The closure is called with an [`AuthCallbackParam`] and returns an optional string value.
/// The converted result is kept in this struct so that a reference to it can be handed out.
pub struct AuthCallback {
    /// The user closure; `Send` so the wrapper can be moved between threads.
    auth_fn: Box<dyn Fn(AuthCallbackParam) -> Option<Cow<'static, str>> + Send>,
    /// Value last stored by `invoke` (the closure's result converted into a `CStringHolder`);
    /// dropped again by `clear_buffer`.
    auth_value: Option<CStringHolder>,
}
/// Parameters handed to the closure of an [`AuthCallback`].
///
/// All strings borrow from the C strings the callback was invoked with; a string that is not
/// valid UTF-8 is exposed as `""`.
#[derive(Debug)]
pub struct AuthCallbackParam<'a> {
    /// Path of the object the callback is invoked for.
    pub object_path: &'a str,
    /// Optional description supplied together with the object path.
    pub description: Option<&'a str>,
}
impl AuthCallback {
pub fn new(
auth_fn: impl Fn(AuthCallbackParam) -> Option<Cow<'static, str>> + 'static + Send,
) -> Self {
Self {
auth_fn: Box::new(auth_fn),
auth_value: None,
}
}
pub fn with_data<T: 'static + Send>(
sign_fn: impl Fn(AuthCallbackParam, &T) -> Option<Cow<'static, str>> + 'static + Send,
extra_data: T,
) -> Self {
Self::new(move |callback_param| sign_fn(callback_param, &extra_data))
}
pub(crate) fn invoke(
&mut self,
object_path: &CStr,
description: Option<&CStr>,
) -> Option<&CStringHolder> {
let param = AuthCallbackParam::new(object_path, description);
trace!("AuthCallback::invoke({:?})", ¶m);
match (self.auth_fn)(param) {
Some(value) => {
self.auth_value = CStringHolder::try_from(value).ok();
self.auth_value.as_ref()
}
_ => None,
}
}
pub(crate) fn clear_buffer(&mut self) {
self.auth_value.take();
}
}
impl<'a> AuthCallbackParam<'a> {
    /// Builds the callback parameters from the raw C strings passed to `AuthCallback::invoke`.
    ///
    /// Any string that is not valid UTF-8 is replaced by `""` instead of producing an error.
    fn new(object_path: &'a CStr, description: Option<&'a CStr>) -> Self {
        let utf8_or_empty = |raw: &'a CStr| -> &'a str { raw.to_str().unwrap_or_default() };
        let object_path = utf8_or_empty(object_path);
        let description = description.map(utf8_or_empty);
        Self {
            object_path,
            description,
        }
    }
}
impl Debug for AuthCallback {
    /// Formats the callback for diagnostics.
    ///
    /// The boxed closure itself cannot be printed, so the result of `type_id()` is shown in its
    /// place.
    /// NOTE(review): `self.auth_fn.type_id()` resolves to the `TypeId` of the `Box<dyn Fn…>`
    /// itself, so this value is identical for every instance.
    /// NOTE(review): `auth_value` is printed via `CStringHolder`'s `Debug` impl. Confirm that
    /// impl does not reveal the secret.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fn_type_id = self.auth_fn.type_id();
        let mut out = f.debug_struct("AuthCallback");
        out.field("auth_fn", &fn_type_id);
        out.field("auth_value", &self.auth_value);
        out.finish()
    }
}
/// Wrapper around a user-supplied signing closure.
///
/// The closure is called with a [`SignCallbackParam`] and returns optional bytes. The bytes are
/// kept in this struct so that a slice of them can be handed out.
pub struct SignCallback {
    /// The user closure; `Send` so the wrapper can be moved between threads.
    sign_fn: Box<dyn Fn(SignCallbackParam) -> Option<Vec<u8>> + Send>,
    /// Data last stored by `invoke`; dropped again by `clear_buffer`.
    sign_data: Option<Vec<u8>>,
}
/// Parameters handed to the closure of a [`SignCallback`].
///
/// All strings borrow from the C strings the callback was invoked with; a string that is not
/// valid UTF-8 is exposed as `""`.
#[derive(Debug)]
pub struct SignCallbackParam<'a> {
    /// Path of the object the callback is invoked for.
    pub object_path: &'a str,
    /// Optional description supplied together with the object path.
    pub description: Option<&'a str>,
    /// Public key string supplied by the caller.
    pub public_key: &'a str,
    /// Optional key hint supplied by the caller.
    pub key_hint: Option<&'a str>,
    /// Hash algorithm, derived from the raw algorithm id passed by the caller.
    pub hash_algo: HashAlgorithm,
    /// Challenge bytes passed through unchanged (presumably the data to be signed).
    pub challenge: &'a [u8],
}
impl SignCallback {
    /// Creates a new signing callback from the given closure.
    ///
    /// The closure receives a [`SignCallbackParam`] and returns the bytes to hand back
    /// (presumably a signature over `challenge`), or `None`.
    pub fn new(sign_fn: impl Fn(SignCallbackParam) -> Option<Vec<u8>> + 'static + Send) -> Self {
        Self {
            sign_fn: Box::new(sign_fn),
            sign_data: None,
        }
    }

    /// Creates a new signing callback from a closure that additionally receives a reference to
    /// `extra_data` on every invocation.
    pub fn with_data<T: 'static + Send>(
        sign_fn: impl Fn(SignCallbackParam, &T) -> Option<Vec<u8>> + 'static + Send,
        extra_data: T,
    ) -> Self {
        Self::new(move |callback_param| sign_fn(callback_param, &extra_data))
    }

    /// Invokes the user closure, stores its result and returns a slice of it.
    ///
    /// The previously stored data is always replaced, so a `None` result also discards the
    /// output of an earlier invocation instead of keeping it around.
    pub(crate) fn invoke(
        &mut self,
        object_path: &CStr,
        description: Option<&CStr>,
        public_key: &CStr,
        key_hint: Option<&CStr>,
        hash_algo: u32,
        challenge: &[u8],
    ) -> Option<&[u8]> {
        let param = SignCallbackParam::new(
            object_path,
            description,
            public_key,
            key_hint,
            hash_algo,
            challenge,
        );
        trace!("SignCallback::invoke({:?})", &param);
        self.sign_data = (self.sign_fn)(param);
        self.sign_data.as_deref()
    }

    /// Drops the stored data (if any).
    pub(crate) fn clear_buffer(&mut self) {
        self.sign_data.take();
    }
}
impl<'a> SignCallbackParam<'a> {
    /// Builds the callback parameters from the raw values passed to `SignCallback::invoke`.
    ///
    /// Any string that is not valid UTF-8 is replaced by `""`. If `hash_algo` does not fit into
    /// a `TPM2_ALG_ID`, `TPM2_ALG_ID::default()` is used instead.
    fn new(
        object_path: &'a CStr,
        description: Option<&'a CStr>,
        public_key: &'a CStr,
        key_hint: Option<&'a CStr>,
        hash_algo: u32,
        challenge: &'a [u8],
    ) -> Self {
        let utf8_or_empty = |raw: &'a CStr| -> &'a str { raw.to_str().unwrap_or_default() };
        let alg_id = TPM2_ALG_ID::try_from(hash_algo).unwrap_or_default();
        Self {
            object_path: utf8_or_empty(object_path),
            description: description.map(utf8_or_empty),
            public_key: utf8_or_empty(public_key),
            key_hint: key_hint.map(utf8_or_empty),
            hash_algo: HashAlgorithm::from_id(alg_id),
            challenge,
        }
    }
}
impl Debug for SignCallback {
    /// Formats the callback for diagnostics.
    ///
    /// The boxed closure itself cannot be printed, so the result of `type_id()` is shown in its
    /// place.
    /// NOTE(review): `self.sign_fn.type_id()` resolves to the `TypeId` of the `Box<dyn Fn…>`
    /// itself, so this value is identical for every instance.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fn_type_id = self.sign_fn.type_id();
        let mut out = f.debug_struct("SignCallback");
        out.field("sign_fn", &fn_type_id);
        out.field("sign_data", &self.sign_data);
        out.finish()
    }
}
/// Wrapper around a user-supplied branch-selection closure.
///
/// The closure is called with a [`BranCallbackParam`] and returns an optional index into its
/// `branches` list.
pub struct BranCallback {
    /// The user closure; `Send` so the wrapper can be moved between threads.
    bran_fn: Box<dyn Fn(BranCallbackParam) -> Option<usize> + Send>,
}
/// Parameters handed to the closure of a [`BranCallback`].
///
/// All strings borrow from the C strings the callback was invoked with; a string that is not
/// valid UTF-8 is exposed as `""`.
#[derive(Debug)]
pub struct BranCallbackParam<'a> {
    /// Path of the object the callback is invoked for.
    pub object_path: &'a str,
    /// Optional description supplied together with the object path.
    pub description: Option<&'a str>,
    /// The available branches, in the order the closure's returned index refers to.
    pub branches: Vec<&'a str>,
}
impl BranCallback {
    /// Creates a new branch-selection callback from the given closure.
    ///
    /// The closure receives a [`BranCallbackParam`] and returns the index (into
    /// `BranCallbackParam::branches`) of the chosen branch, or `None`.
    pub fn new(bran_fn: impl Fn(BranCallbackParam) -> Option<usize> + 'static + Send) -> Self {
        Self {
            bran_fn: Box::new(bran_fn),
        }
    }

    /// Creates a new branch-selection callback from a closure that additionally receives a
    /// reference to `extra_data` on every invocation.
    pub fn with_data<T: 'static + Send>(
        bran_fn: impl Fn(BranCallbackParam, &T) -> Option<usize> + 'static + Send,
        extra_data: T,
    ) -> Self {
        Self::new(move |callback_param| bran_fn(callback_param, &extra_data))
    }

    /// Invokes the user closure and returns the branch index it selected.
    ///
    /// # Panics
    ///
    /// Panics if the closure returns an index that is not smaller than `branches.len()`.
    /// NOTE(review): if this is reached from an `extern "C"` callback, the panic cannot unwind
    /// across the FFI boundary. Confirm how the caller handles it.
    pub(crate) fn invoke(
        &mut self,
        object_path: &CStr,
        description: Option<&CStr>,
        branches: &[&CStr],
    ) -> Option<usize> {
        let param = BranCallbackParam::new(object_path, description, branches);
        trace!("BranCallback::invoke({:?})", &param);
        (self.bran_fn)(param).inspect(|&index| {
            // Report the valid indices as the half-open range `0..len`. Computing `len - 1` here
            // would underflow when `branches` is empty and would also be off by one.
            assert!(
                index < branches.len(),
                "The chosen branch index #{} is out of range! (must be in the 0..{} range)",
                index,
                branches.len()
            );
        })
    }
}
impl<'a> BranCallbackParam<'a> {
    /// Builds the callback parameters from the raw C strings passed to `BranCallback::invoke`.
    ///
    /// Any string that is not valid UTF-8 is replaced by `""`. The order of `branches` is
    /// preserved, so the index returned by the user closure refers to the same position.
    fn new(object_path: &'a CStr, description: Option<&'a CStr>, branches: &'a [&CStr]) -> Self {
        Self {
            object_path: object_path.to_str().unwrap_or_default(),
            description: description.map(|s| s.to_str().unwrap_or_default()),
            // `iter()` rather than `into_iter()` on a slice reference (clippy `into_iter_on_ref`).
            branches: branches
                .iter()
                .map(|s| s.to_str().unwrap_or_default())
                .collect(),
        }
    }
}
impl Debug for BranCallback {
    /// Formats the callback for diagnostics.
    ///
    /// The boxed closure itself cannot be printed, so the result of `type_id()` is shown in its
    /// place.
    /// NOTE(review): `self.bran_fn.type_id()` resolves to the `TypeId` of the `Box<dyn Fn…>`
    /// itself, so this value is identical for every instance.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fn_type_id = self.bran_fn.type_id();
        let mut out = f.debug_struct("BranCallback");
        out.field("bran_fn", &fn_type_id);
        out.finish()
    }
}
/// Wrapper around a user-supplied action closure.
///
/// The closure is called with an [`ActnCallbackParam`] and returns a `bool` that `invoke`
/// passes back unchanged.
pub struct ActnCallback {
    /// The user closure; `Send` so the wrapper can be moved between threads.
    actn_fn: Box<dyn Fn(ActnCallbackParam) -> bool + Send>,
}
/// Parameters handed to the closure of an [`ActnCallback`].
///
/// All strings borrow from the C strings the callback was invoked with; a string that is not
/// valid UTF-8 is exposed as `""`.
#[derive(Debug)]
pub struct ActnCallbackParam<'a> {
    /// Path of the object the callback is invoked for.
    pub object_path: &'a str,
    /// Optional action string supplied together with the object path.
    pub action: Option<&'a str>,
}
impl ActnCallback {
    /// Creates a new action callback from the given closure.
    ///
    /// The closure receives an [`ActnCallbackParam`] and returns a `bool`.
    /// NOTE(review): presumably `true` signals success. Confirm against the FFI caller.
    pub fn new(actn_fn: impl Fn(ActnCallbackParam) -> bool + 'static + Send) -> Self {
        Self {
            actn_fn: Box::new(actn_fn),
        }
    }

    /// Creates a new action callback from a closure that additionally receives a reference to
    /// `extra_data` on every invocation.
    pub fn with_data<T: 'static + Send>(
        actn_fn: impl Fn(ActnCallbackParam, &T) -> bool + 'static + Send,
        extra_data: T,
    ) -> Self {
        Self::new(move |callback_param| actn_fn(callback_param, &extra_data))
    }

    /// Invokes the user closure and returns its result unchanged.
    pub(crate) fn invoke(&mut self, object_path: &CStr, action: Option<&CStr>) -> bool {
        let param = ActnCallbackParam::new(object_path, action);
        trace!("ActnCallback::invoke({:?})", &param);
        (self.actn_fn)(param)
    }
}
impl<'a> ActnCallbackParam<'a> {
    /// Builds the callback parameters from the raw C strings passed to `ActnCallback::invoke`.
    ///
    /// Any string that is not valid UTF-8 is replaced by `""` instead of producing an error.
    fn new(object_path: &'a CStr, action: Option<&'a CStr>) -> Self {
        let utf8_or_empty = |raw: &'a CStr| -> &'a str { raw.to_str().unwrap_or_default() };
        let object_path = utf8_or_empty(object_path);
        let action = action.map(utf8_or_empty);
        Self {
            object_path,
            action,
        }
    }
}
impl Debug for ActnCallback {
    /// Formats the callback for diagnostics.
    ///
    /// The boxed closure itself cannot be printed, so the result of `type_id()` is shown in its
    /// place.
    /// NOTE(review): `self.actn_fn.type_id()` resolves to the `TypeId` of the `Box<dyn Fn…>`
    /// itself, so this value is identical for every instance.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fn_type_id = self.actn_fn.type_id();
        let mut out = f.debug_struct("ActnCallback");
        out.field("actn_fn", &fn_type_id);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::{ActnCallback, AuthCallback, BranCallback, SignCallback};
    use std::ffi::CString;

    /// Smoke-tests construction, `Debug` formatting and invocation of every callback type.
    #[test]
    fn test_callbacks() {
        let param = CString::new("my param").unwrap();

        let mut callback = AuthCallback::new(|_param| None);
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, None).is_none());
        let mut callback = AuthCallback::with_data(|_param, _data| None, "my data");
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, None).is_none());

        let mut callback = SignCallback::new(|_param| None);
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, None, &param, None, 0, &[0u8]).is_none());
        let mut callback = SignCallback::with_data(|_param, _data| None, "my data");
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, None, &param, None, 0, &[0u8]).is_none());

        let mut callback = BranCallback::new(|_param| None);
        let _ = format!("{:?}", callback);
        assert_eq!(callback.invoke(&param, None, &[&param]), None);
        let mut callback = BranCallback::with_data(|_param, _data| None, "my data");
        let _ = format!("{:?}", callback);
        assert_eq!(callback.invoke(&param, None, &[&param]), None);

        let mut callback = ActnCallback::new(|_param| true);
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, Some(&param)));
        let mut callback = ActnCallback::with_data(|_param, _data| true, "my data");
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, Some(&param)));
    }

    /// The bytes returned by a sign closure are handed back unchanged by `invoke`.
    #[test]
    fn test_sign_callback_returns_data() {
        let param = CString::new("my param").unwrap();
        let mut callback = SignCallback::new(|p| Some(p.challenge.to_vec()));
        assert_eq!(
            callback.invoke(&param, None, &param, None, 0, &[1u8, 2, 3]),
            Some(&[1u8, 2, 3][..])
        );
        callback.clear_buffer();
    }

    /// A valid branch index is passed through by the branch callback.
    #[test]
    fn test_bran_callback_valid_index() {
        let param = CString::new("my param").unwrap();
        let mut callback = BranCallback::new(|_param| Some(1));
        assert_eq!(callback.invoke(&param, None, &[&param, &param]), Some(1));
    }

    /// An out-of-range branch index is rejected with a panic.
    #[test]
    #[should_panic(expected = "is out of range")]
    fn test_bran_callback_invalid_index() {
        let param = CString::new("my param").unwrap();
        let mut callback = BranCallback::new(|_param| Some(5));
        let _ = callback.invoke(&param, None, &[&param]);
    }
}