use crate::{HashAlgorithm, fapi_sys::TPM2_ALG_ID, memory::CStringHolder};
use log::trace;
use std::{any::Any, borrow::Cow, ffi::CStr, fmt::Debug};
/// Authorization callback: wraps a user-supplied closure that returns an auth value for an object.
///
/// The value returned by the closure is converted into a `CStringHolder` and stored in
/// `auth_value`, so that `invoke` can hand out a reference to it; `clear_buffer` drops it.
pub struct AuthCallback {
// The user closure; see `AuthCallbackFunction` for its signature.
auth_fn: Box<AuthCallbackFunction>,
// Converted auth value that `invoke` returns by reference; dropped by `clear_buffer`.
auth_value: Option<CStringHolder>,
}
/// Signature of the closure boxed inside [`AuthCallback`]: it receives the request parameters and
/// returns the auth value, or `None` when it provides no value.
type AuthCallbackFunction = dyn Fn(AuthCallbackParam) -> Option<Cow<'static, str>> + Send;
/// Parameters passed to an [`AuthCallback`] closure.
///
/// The strings borrow from the C strings handed to the callback. When built by
/// `AuthCallbackParam::new`, a C string that is not valid UTF-8 is exposed as `""`.
#[derive(Debug)]
pub struct AuthCallbackParam<'a> {
/// Object path, converted from the C string passed to the callback.
pub object_path: &'a str,
/// Optional description, converted from the C string passed to the callback (if any).
pub description: Option<&'a str>,
}
impl AuthCallback {
    /// Creates an authorization callback from `auth_fn`.
    ///
    /// The closure receives an [`AuthCallbackParam`] and returns the auth value to use, or `None`
    /// when it provides no value.
    pub fn new(auth_fn: impl Fn(AuthCallbackParam) -> Option<Cow<'static, str>> + 'static + Send) -> Self {
        Self { auth_fn: Box::new(auth_fn), auth_value: None }
    }

    /// Like [`AuthCallback::new`], but the closure also receives a reference to `extra_data`,
    /// which is moved into (and owned by) the callback.
    pub fn with_data<T: 'static + Send>(auth_fn: impl Fn(AuthCallbackParam, &T) -> Option<Cow<'static, str>> + 'static + Send, extra_data: T) -> Self {
        Self::new(move |callback_param| auth_fn(callback_param, &extra_data))
    }

    /// Invokes the user closure and stores its result, converted to a `CStringHolder`, in `self`.
    ///
    /// Returns a reference to the stored value. Returns `None` if the closure returned `None` or if
    /// `CStringHolder::try_from` rejected the value (the conversion error is discarded).
    pub(crate) fn invoke(&mut self, object_path: &CStr, description: Option<&CStr>) -> Option<&CStringHolder> {
        let param = AuthCallbackParam::new(object_path, description);
        trace!("AuthCallback::invoke({:?})", &param);
        // Always overwrite the stored value: when the closure declines to answer, a secret from an
        // earlier invocation must not stay cached after this call.
        self.auth_value = (self.auth_fn)(param).and_then(|value| CStringHolder::try_from(value).ok());
        self.auth_value.as_ref()
    }

    /// Drops the stored auth value, if any.
    pub(crate) fn clear_buffer(&mut self) {
        self.auth_value.take();
    }
}
impl<'a> AuthCallbackParam<'a> {
    /// Builds the parameter set from the raw C strings handed to the callback.
    ///
    /// Each C string is viewed as UTF-8 without copying; one that is not valid UTF-8 becomes `""`.
    fn new(object_path: &'a CStr, description: Option<&'a CStr>) -> Self {
        let as_utf8 = |raw: &'a CStr| raw.to_str().unwrap_or_default();
        Self { object_path: as_utf8(object_path), description: description.map(as_utf8) }
    }
}
impl Debug for AuthCallback {
    /// Formats the callback for diagnostics without revealing the stored auth value.
    ///
    /// `auth_fn` is shown via `type_id()` of the boxed `dyn Fn` value (this identifies the
    /// trait-object type, not the concrete closure). `auth_value` is reduced to
    /// `Some("<redacted>")` / `None`. It holds whatever the user closure returned, typically a
    /// password or other secret, and `Debug` output tends to end up in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let redacted_value = self.auth_value.as_ref().map(|_| "<redacted>");
        f.debug_struct("AuthCallback").field("auth_fn", &(*self.auth_fn).type_id()).field("auth_value", &redacted_value).finish()
    }
}
/// Signing callback: wraps a user-supplied closure that returns bytes for a sign request
/// (presumably a signature over `SignCallbackParam::challenge` — confirm against the caller).
///
/// The bytes returned by the closure are stored in `sign_data`, so that `invoke` can hand out a
/// slice of them; `clear_buffer` drops them.
pub struct SignCallback {
// The user closure; see `SignCallbackFunction` for its signature.
sign_fn: Box<SignCallbackFunction>,
// Bytes that `invoke` returns by reference; dropped by `clear_buffer`.
sign_data: Option<Vec<u8>>,
}
/// Signature of the closure boxed inside [`SignCallback`]: it receives the request parameters and
/// returns the resulting bytes, or `None` when it produces none.
type SignCallbackFunction = dyn Fn(SignCallbackParam) -> Option<Vec<u8>> + Send;
/// Parameters passed to a [`SignCallback`] closure.
///
/// When built by `SignCallbackParam::new`, a C string that is not valid UTF-8 is exposed as `""`.
/// A raw hash-algorithm id that does not fit `TPM2_ALG_ID` is replaced by `TPM2_ALG_ID::default()`
/// before being converted to a [`HashAlgorithm`].
#[derive(Debug)]
pub struct SignCallbackParam<'a> {
/// Object path, converted from the C string passed to the callback.
pub object_path: &'a str,
/// Optional description, converted from the C string passed to the callback (if any).
pub description: Option<&'a str>,
/// Public key, converted from the C string passed to the callback.
pub public_key: &'a str,
/// Optional key hint, converted from the C string passed to the callback (if any).
pub key_hint: Option<&'a str>,
/// Hash algorithm derived from the raw numeric id passed to the callback.
pub hash_algo: HashAlgorithm,
/// Challenge bytes, passed through unchanged.
pub challenge: &'a [u8],
}
impl SignCallback {
pub fn new(sign_fn: impl Fn(SignCallbackParam) -> Option<Vec<u8>> + 'static + Send) -> Self {
Self { sign_fn: Box::new(sign_fn), sign_data: None }
}
pub fn with_data<T: 'static + Send>(sign_fn: impl Fn(SignCallbackParam, &T) -> Option<Vec<u8>> + 'static + Send, extra_data: T) -> Self {
Self::new(move |callback_param| sign_fn(callback_param, &extra_data))
}
pub(crate) fn invoke(
&mut self,
object_path: &CStr,
description: Option<&CStr>,
public_key: &CStr,
key_hint: Option<&CStr>,
hash_algo: u32,
challenge: &[u8],
) -> Option<&[u8]> {
let param = SignCallbackParam::new(object_path, description, public_key, key_hint, hash_algo, challenge);
trace!("SignCallback::invoke({:?})", ¶m);
match (self.sign_fn)(param) {
Some(value) => {
self.sign_data = Some(value);
self.sign_data.as_deref()
}
_ => None,
}
}
pub(crate) fn clear_buffer(&mut self) {
self.sign_data.take();
}
}
impl<'a> SignCallbackParam<'a> {
    /// Builds the parameter set from the raw values handed to the callback.
    ///
    /// C strings are viewed as UTF-8 without copying; invalid UTF-8 becomes `""`. `hash_algo` is
    /// narrowed to `TPM2_ALG_ID` (falling back to its default when it does not fit) and then
    /// mapped through `HashAlgorithm::from_id`. `challenge` is passed through as-is.
    fn new(
        object_path: &'a CStr,
        description: Option<&'a CStr>,
        public_key: &'a CStr,
        key_hint: Option<&'a CStr>,
        hash_algo: u32,
        challenge: &'a [u8],
    ) -> Self {
        let utf8 = |raw: &'a CStr| raw.to_str().unwrap_or_default();
        let alg_id = TPM2_ALG_ID::try_from(hash_algo).unwrap_or_default();
        Self {
            object_path: utf8(object_path),
            description: description.map(utf8),
            public_key: utf8(public_key),
            key_hint: key_hint.map(utf8),
            hash_algo: HashAlgorithm::from_id(alg_id),
            challenge,
        }
    }
}
impl Debug for SignCallback {
    /// Formats the callback for diagnostics.
    ///
    /// `sign_fn` is shown via `type_id()` of the boxed `dyn Fn` value (this identifies the
    /// trait-object type, not the concrete closure); `sign_data` is shown as-is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fn_type = (*self.sign_fn).type_id();
        let mut out = f.debug_struct("SignCallback");
        out.field("sign_fn", &fn_type);
        out.field("sign_data", &self.sign_data);
        out.finish()
    }
}
/// Branch-selection callback: wraps a user-supplied closure that picks one entry from a list of
/// branches.
pub struct BranCallback {
// The user closure; see `BranCallbackFunction` for its signature.
bran_fn: Box<BranCallbackFunction>,
}
/// Signature of the closure boxed inside [`BranCallback`]: it returns the index of the chosen
/// entry in `BranCallbackParam::branches`, or `None`.
type BranCallbackFunction = dyn Fn(BranCallbackParam) -> Option<usize> + Send;
/// Parameters passed to a [`BranCallback`] closure.
///
/// When built by `BranCallbackParam::new`, a C string that is not valid UTF-8 is exposed as `""`.
#[derive(Debug)]
pub struct BranCallbackParam<'a> {
/// Object path, converted from the C string passed to the callback.
pub object_path: &'a str,
/// Optional description, converted from the C string passed to the callback (if any).
pub description: Option<&'a str>,
/// Branch names in the order received; the index returned by the closure refers to this list.
pub branches: Vec<&'a str>,
}
impl BranCallback {
    /// Creates a branch-selection callback from `bran_fn`.
    ///
    /// The closure receives a [`BranCallbackParam`] and returns the index of the chosen entry in
    /// its `branches` list, or `None`.
    pub fn new(bran_fn: impl Fn(BranCallbackParam) -> Option<usize> + 'static + Send) -> Self {
        Self { bran_fn: Box::new(bran_fn) }
    }

    /// Like [`BranCallback::new`], but the closure also receives a reference to `extra_data`,
    /// which is moved into (and owned by) the callback.
    pub fn with_data<T: 'static + Send>(bran_fn: impl Fn(BranCallbackParam, &T) -> Option<usize> + 'static + Send, extra_data: T) -> Self {
        Self::new(move |callback_param| bran_fn(callback_param, &extra_data))
    }

    /// Invokes the user closure and returns the index of the chosen branch, if any.
    ///
    /// # Panics
    ///
    /// Panics if the closure returns an index that is not smaller than `branches.len()`. When
    /// `branches` is empty, any returned index triggers the panic.
    pub(crate) fn invoke(&mut self, object_path: &CStr, description: Option<&CStr>, branches: &[&CStr]) -> Option<usize> {
        let param = BranCallbackParam::new(object_path, description, branches);
        trace!("BranCallback::invoke({:?})", &param);
        (self.bran_fn)(param).inspect(|&index| {
            // Report the valid range with an exclusive upper bound (`0..len`). This matches Rust
            // range notation and, unlike `len - 1`, cannot underflow when `branches` is empty.
            // NOTE(review): if this runs inside an `extern "C"` trampoline, confirm the caller
            // catches the panic before it reaches the FFI boundary.
            assert!(
                index < branches.len(),
                "The chosen branch index #{} is out of range! (must be in the 0..{} range)",
                index,
                branches.len()
            );
        })
    }
}
impl<'a> BranCallbackParam<'a> {
    /// Builds the parameter set from the raw C strings handed to the callback.
    ///
    /// Every C string (including each branch name) is viewed as UTF-8 without copying; one that is
    /// not valid UTF-8 becomes `""`. Branch order is preserved.
    fn new(object_path: &'a CStr, description: Option<&'a CStr>, branches: &'a [&CStr]) -> Self {
        let mut branch_names = Vec::with_capacity(branches.len());
        for branch in branches {
            branch_names.push(branch.to_str().unwrap_or_default());
        }
        Self {
            object_path: object_path.to_str().unwrap_or_default(),
            description: description.map(|text| text.to_str().unwrap_or_default()),
            branches: branch_names,
        }
    }
}
impl Debug for BranCallback {
    /// Formats the callback for diagnostics.
    ///
    /// `bran_fn` is shown via `type_id()` of the boxed `dyn Fn` value (this identifies the
    /// trait-object type, not the concrete closure).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fn_type = (*self.bran_fn).type_id();
        let mut out = f.debug_struct("BranCallback");
        out.field("bran_fn", &fn_type);
        out.finish()
    }
}
/// Action callback: wraps a user-supplied closure that receives an object path and an optional
/// action string and returns a `bool`.
// NOTE(review): the meaning of the returned `bool` is defined by the caller of `invoke` — confirm.
pub struct ActnCallback {
// The user closure; see `ActnCallbackFunction` for its signature.
actn_fn: Box<ActnCallbackFunction>,
}
/// Signature of the closure boxed inside [`ActnCallback`].
type ActnCallbackFunction = dyn Fn(ActnCallbackParam) -> bool + Send;
/// Parameters passed to an [`ActnCallback`] closure.
///
/// When built by `ActnCallbackParam::new`, a C string that is not valid UTF-8 is exposed as `""`.
#[derive(Debug)]
pub struct ActnCallbackParam<'a> {
/// Object path, converted from the C string passed to the callback.
pub object_path: &'a str,
/// Optional action string, converted from the C string passed to the callback (if any).
pub action: Option<&'a str>,
}
impl ActnCallback {
pub fn new(actn_fn: impl Fn(ActnCallbackParam) -> bool + 'static + Send) -> Self {
Self { actn_fn: Box::new(actn_fn) }
}
pub fn with_data<T: 'static + Send>(actn_fn: impl Fn(ActnCallbackParam, &T) -> bool + 'static + Send, extra_data: T) -> Self {
Self::new(move |callback_param| actn_fn(callback_param, &extra_data))
}
pub(crate) fn invoke(&mut self, object_path: &CStr, action: Option<&CStr>) -> bool {
let param = ActnCallbackParam::new(object_path, action);
trace!("ActnCallback::invoke({:?})", ¶m);
(self.actn_fn)(param)
}
}
impl<'a> ActnCallbackParam<'a> {
    /// Builds the parameter set from the raw C strings handed to the callback.
    ///
    /// Each C string is viewed as UTF-8 without copying; one that is not valid UTF-8 becomes `""`.
    fn new(object_path: &'a CStr, action: Option<&'a CStr>) -> Self {
        let object_path = object_path.to_str().unwrap_or_default();
        let action = action.map(CStr::to_str).map(|converted| converted.unwrap_or_default());
        Self { object_path, action }
    }
}
impl Debug for ActnCallback {
    /// Formats the callback for diagnostics.
    ///
    /// `actn_fn` is shown via `type_id()` of the boxed `dyn Fn` value (this identifies the
    /// trait-object type, not the concrete closure).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fn_type = (*self.actn_fn).type_id();
        let mut out = f.debug_struct("ActnCallback");
        out.field("actn_fn", &fn_type);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::{ActnCallback, AuthCallback, BranCallback, SignCallback};
    use std::{borrow::Cow, ffi::CString};

    /// Every callback type can be constructed (plain and `with_data`), formatted with `Debug` and
    /// invoked; closures that decline to answer yield `None` (or their `bool` verbatim).
    #[test]
    fn test_callbacks() {
        let param = CString::new("my param").unwrap();
        let mut callback = AuthCallback::new(|_param| None);
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, None).is_none());
        let mut callback = AuthCallback::with_data(|_param, _data| None, "my data");
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, None).is_none());
        let mut callback = SignCallback::new(|_param| None);
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, None, &param, None, 0, &[0u8]).is_none());
        let mut callback = SignCallback::with_data(|_param, _data| None, "my data");
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, None, &param, None, 0, &[0u8]).is_none());
        let mut callback = BranCallback::new(|_param| None);
        let _ = format!("{:?}", callback);
        assert_eq!(callback.invoke(&param, None, &[&param]), None);
        let mut callback = BranCallback::with_data(|_param, _data| None, "my data");
        let _ = format!("{:?}", callback);
        assert_eq!(callback.invoke(&param, None, &[&param]), None);
        let mut callback = ActnCallback::new(|_param| true);
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, Some(&param)));
        let mut callback = ActnCallback::with_data(|_param, _data| true, "my data");
        let _ = format!("{:?}", callback);
        assert!(callback.invoke(&param, Some(&param)));
    }

    /// Values produced by the user closures are passed back out of `invoke`.
    #[test]
    fn test_callback_results() {
        let param = CString::new("my param").unwrap();

        let mut callback = AuthCallback::new(|_param| Some(Cow::Borrowed("secret")));
        assert!(callback.invoke(&param, None).is_some());
        callback.clear_buffer();

        let mut callback = SignCallback::new(|param| Some(param.challenge.to_vec()));
        assert_eq!(callback.invoke(&param, None, &param, None, 0, &[1u8, 2, 3]), Some(&[1u8, 2, 3][..]));
        callback.clear_buffer();

        let mut callback = BranCallback::with_data(|_param, data: &usize| Some(*data), 1usize);
        assert_eq!(callback.invoke(&param, None, &[&param, &param]), Some(1));

        let mut callback = ActnCallback::new(|param| param.action == Some("my param"));
        assert!(callback.invoke(&param, Some(&param)));
        assert!(!callback.invoke(&param, None));
    }

    /// A branch index that is not a valid position in `branches` is rejected with a panic.
    #[test]
    #[should_panic(expected = "out of range")]
    fn test_bran_callback_index_out_of_range() {
        let param = CString::new("my param").unwrap();
        let mut callback = BranCallback::new(|_param| Some(1));
        callback.invoke(&param, None, &[&param]);
    }
}