use std::collections::BTreeMap;
use crate::audit::AuditOption;
use crate::connector::{ConnectionError, ConnectionErrorKind};
use crate::object::{ObjectId, ObjectType};
use crate::session::{
securechannel::{Challenge, SecureChannel},
SessionId,
};
use super::{audit::CommandAuditOptions, object::Objects, session::HsmSession};
/// In-memory state backing the MockHsm: its stored objects, its open
/// sessions, and its audit settings.
#[derive(Debug)]
pub(crate) struct State {
/// Command audit options; set to `CommandAuditOptions::default()` by `new`
/// and again by `reset`.
pub(super) command_audit_options: CommandAuditOptions,
/// Force-audit option; `new` initializes it to `AuditOption::Off`.
pub(super) force_audit: AuditOption,
/// Currently open sessions, keyed by session ID. A `BTreeMap` keeps the IDs
/// ordered, which `create_session` relies on when it takes the maximum key.
sessions: BTreeMap<SessionId, HsmSession>,
/// Objects stored in the mock HSM; `create_session` looks up authentication
/// keys here by ID and `ObjectType`.
pub(super) objects: Objects,
}
impl State {
    /// Create a fresh mock HSM state: default command audit options, force
    /// audit turned off, no open sessions, and a default object store.
    pub fn new() -> Self {
        Self {
            command_audit_options: CommandAuditOptions::default(),
            force_audit: AuditOption::Off,
            sessions: BTreeMap::new(),
            objects: Objects::default(),
        }
    }

    /// Open a new session authenticated by the `AuthenticationKey` object
    /// stored at `authentication_key_id`.
    ///
    /// A card challenge is generated with `Challenge::random()`. It is combined
    /// with `host_challenge` and the authentication key to build the session's
    /// `SecureChannel`. The new session is stored in the session table, and a
    /// reference to it is returned.
    ///
    /// # Panics
    ///
    /// - if no `AuthenticationKey` object exists at `authentication_key_id`
    /// - if that object's payload is not an authentication key
    /// - if every session ID is already in use
    pub fn create_session(
        &mut self,
        authentication_key_id: ObjectId,
        host_challenge: Challenge,
    ) -> &HsmSession {
        let card_challenge = Challenge::random();
        let session_id = self.next_session_id();

        let channel = {
            let authentication_key_obj = self
                .objects
                .get(authentication_key_id, ObjectType::AuthenticationKey)
                .unwrap_or_else(|| {
                    panic!(
                        "MockHsm has no AuthenticationKey in slot {:?}",
                        authentication_key_id
                    )
                });

            SecureChannel::new(
                session_id,
                authentication_key_obj
                    .payload
                    .authentication_key()
                    .expect("auth key payload"),
                host_challenge,
                card_challenge,
            )
        };

        let session = HsmSession::new(session_id, card_challenge, channel);
        let previous = self.sessions.insert(session_id, session);

        // `next_session_id` only returns IDs absent from the table, so this
        // can only fire if that invariant is broken.
        assert!(
            previous.is_none(),
            "session ID {:?} already in use",
            session_id
        );

        // The session was inserted just above, so indexing cannot fail. This
        // avoids going through the fallible `get_session` lookup.
        &self.sessions[&session_id]
    }

    /// Choose the ID for a new session.
    ///
    /// The ID after the highest one currently in use is preferred, which is
    /// the original allocation order. If that would run past the largest valid
    /// session ID, the lowest ID not currently in use is taken instead. Closed
    /// sessions therefore free their IDs, rather than permanently exhausting
    /// the ID space.
    ///
    /// # Panics
    ///
    /// If every valid session ID is in use.
    fn next_session_id(&self) -> SessionId {
        let first = SessionId::from_u8(0).expect("session ID 0 is always valid");

        match self.sessions.keys().max() {
            None => first,
            Some(highest) => match highest.succ() {
                Ok(next) => next,
                // Walk every valid ID from 0 upward (`succ` fails past the
                // last one) and take the first that is not open.
                Err(_) => std::iter::successors(Some(first), |id| id.succ().ok())
                    .find(|id| !self.sessions.contains_key(id))
                    .expect("session count exceeded"),
            },
        }
    }

    /// Look up an open session by ID.
    ///
    /// # Errors
    ///
    /// Returns a `ConnectionError` of kind `ConnectionErrorKind::RequestError`
    /// if no session with `id` is open.
    pub fn get_session(&mut self, id: SessionId) -> Result<&mut HsmSession, ConnectionError> {
        self.sessions.get_mut(&id).ok_or_else(|| {
            ConnectionError::new(
                ConnectionErrorKind::RequestError,
                Some(format!("invalid session ID: {:?}", id)),
            )
        })
    }

    /// Close (remove) the session with the given ID.
    ///
    /// # Panics
    ///
    /// If no session with `id` is open.
    pub fn close_session(&mut self, id: SessionId) {
        let removed = self.sessions.remove(&id);
        assert!(removed.is_some(), "no open session with ID {:?}", id);
    }

    /// Reset the mock HSM to exactly the state produced by [`State::new`].
    ///
    /// This restores default command audit options, turns force audit off,
    /// closes all sessions, and replaces the object store with a default one.
    /// Rebuilding the whole state keeps every field in step with `new`. The
    /// previous field-by-field reset forgot `force_audit`.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}