1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4pub use cryptoki;
5pub use r2d2;
6
7use cryptoki::{
8 context::{Function, Pkcs11},
9 error::RvError,
10 session::{Session, SessionState, UserType},
11 slot::{Limit, Slot},
12 types::AuthPin,
13};
14use r2d2::ManageConnection;
15
16pub type Pool = r2d2::Pool<SessionManager>;
18pub type PooledSession = r2d2::PooledConnection<SessionManager>;
20
21#[derive(Debug, Clone)]
23pub struct SessionManager {
24 pkcs11: Pkcs11,
25 slot: Slot,
26 session_type: SessionType,
27}
28
29#[derive(Debug, Clone)]
31pub enum SessionType {
32 RoPublic,
34 RoUser(AuthPin),
36 RwPublic,
38 RwUser(AuthPin),
40 RwSecurityOfficer(AuthPin),
42}
43
44impl SessionType {
45 fn as_state(&self) -> SessionState {
46 match self {
47 Self::RoPublic => SessionState::RoPublic,
48 Self::RoUser(_) => SessionState::RoUser,
49 Self::RwPublic => SessionState::RwPublic,
50 Self::RwUser(_) => SessionState::RwUser,
51 Self::RwSecurityOfficer(_) => SessionState::RwSecurityOfficer,
52 }
53 }
54}
55
56impl SessionManager {
57 pub fn new(pkcs11: Pkcs11, slot: Slot, session_type: SessionType) -> Self {
67 Self {
68 pkcs11,
69 slot,
70 session_type,
71 }
72 }
73
74 pub fn max_size(&self, maximum: u32) -> Result<Option<u32>, cryptoki::error::Error> {
96 let token_info = self.pkcs11.get_token_info(self.slot)?;
97 let limit = match self.session_type {
98 SessionType::RoPublic | SessionType::RoUser(_) => token_info.max_session_count(),
99 SessionType::RwPublic | SessionType::RwUser(_) | SessionType::RwSecurityOfficer(_) => {
100 token_info.max_session_count()
101 }
102 };
103 let res = match limit {
104 Limit::Max(m) => Some(m.try_into().unwrap_or(u32::MAX)),
105 Limit::Unavailable => None,
106 Limit::Infinite => Some(u32::MAX),
107 };
108 Ok(if let Some(true) = res.map(|r| r > maximum) {
109 Some(maximum)
110 } else {
111 res
112 })
113 }
114}
115
116impl ManageConnection for SessionManager {
117 type Connection = Session;
118
119 type Error = cryptoki::error::Error;
120
121 fn connect(&self) -> Result<Self::Connection, Self::Error> {
124 let session = match self.session_type {
125 SessionType::RoPublic | SessionType::RoUser(_) => {
126 self.pkcs11.open_ro_session(self.slot)?
127 }
128 SessionType::RwPublic | SessionType::RwUser(_) | SessionType::RwSecurityOfficer(_) => {
129 self.pkcs11.open_rw_session(self.slot)?
130 }
131 };
132 let maybe_user_info = match &self.session_type {
133 SessionType::RoPublic | SessionType::RwPublic => None,
134 SessionType::RoUser(pin) | SessionType::RwUser(pin) => Some((UserType::User, pin)),
135 SessionType::RwSecurityOfficer(pin) => Some((UserType::So, pin)),
136 };
137 if let Some(user_type) = maybe_user_info {
138 match session.login(user_type.0, Some(user_type.1)) {
139 Err(Self::Error::Pkcs11(RvError::UserAlreadyLoggedIn, Function::Login)) => {}
140 res => res?,
141 };
142 }
143 Ok(session)
144 }
145
146 fn is_valid(&self, session: &mut Self::Connection) -> Result<(), Self::Error> {
147 let actual_state = session.get_session_info()?.session_state();
148 let expected_state = &self.session_type;
149 if actual_state != expected_state.as_state() {
150 Err(Self::Error::Pkcs11(
151 RvError::UserNotLoggedIn,
152 Function::GetSessionInfo,
153 ))
154 } else {
155 Ok(())
156 }
157 }
158
159 fn has_broken(&self, _session: &mut Self::Connection) -> bool {
160 false
162 }
163}
164
#[cfg(test)]
mod test {
    use std::{env, fs, path::Path, time::Duration};

    use cached::proc_macro::{cached, once};
    use cryptoki::{
        context::CInitializeArgs,
        mechanism::Mechanism,
        object::{Attribute, KeyType, ObjectClass},
    };
    use r2d2::PooledConnection;

    use super::*;

    /// Payload that every test signs and then verifies.
    const TEST_DATA: &[u8] = b"test_data";

    /// Per-test settings: optional pool size cap and the label of the key pair to use.
    #[derive(Clone, Hash, PartialEq, Eq)]
    struct Config {
        max_sessions: Option<u32>,
        label: Vec<u8>,
    }

    /// Resets the SoftHSM token directory and returns an initialized PKCS#11 context.
    ///
    /// Memoized with `#[once]`, so the reset and library initialization are meant to happen
    /// a single time for the whole test binary.
    #[once(sync_writes = true)]
    fn default_pkcs11() -> Pkcs11 {
        env::set_var("SOFTHSM2_CONF", "./test/softhsm2.conf");

        // Start from an empty token store.
        let tokens_dir = Path::new("./test/softhsm/tokens");
        if tokens_dir.exists() {
            fs::remove_dir_all(tokens_dir).unwrap();
        }
        fs::create_dir_all(tokens_dir).unwrap();

        let context = Pkcs11::new("libsofthsm2.so").expect("Could not use pkcs11 library");
        context
            .initialize(CInitializeArgs::OsThreads)
            .expect("Could not initialize pkcs11");
        context
    }

    /// Initializes a token on the first available slot, using `pin` both as the security
    /// officer PIN and as the user PIN. Memoized per `pin` with `#[cached]`.
    #[cached(sync_writes = true)]
    fn default_token(pin: String) -> (Pkcs11, Slot) {
        let pkcs11 = default_pkcs11();
        let slot = pkcs11
            .get_slots_with_token()
            .expect("Could not get slots with token")
            .into_iter()
            .next()
            .expect("Could not find a slot");

        let auth_pin: AuthPin = pin.into();
        pkcs11
            .init_token(slot, &auth_pin, "token")
            .expect("Could not initialize token");

        // The user PIN is set from a security officer session.
        let so_session = pkcs11.open_rw_session(slot).unwrap();
        so_session.login(UserType::So, Some(&auth_pin)).unwrap();
        so_session.init_pin(&auth_pin).unwrap();

        (pkcs11, slot)
    }

    /// Builds a pool of read/write user sessions per `config` and generates an EC key pair
    /// labelled `config.label` on the token.
    fn default_setup(config: Config) -> Pool {
        let pin_string = "abcde".to_string();
        let pin = AuthPin::new(pin_string.clone());
        let (pkcs11, slot) = default_token(pin_string);

        let manager = SessionManager::new(pkcs11, slot, SessionType::RwUser(pin));
        let mut builder = r2d2::Pool::builder();
        if let Some(limit) = config.max_sessions {
            builder = builder.max_size(limit);
        }
        let pool = builder.build(manager).unwrap();

        let mechanism = Mechanism::EccKeyPairGen;
        let pub_key_template = vec![
            Attribute::Token(true),
            Attribute::Private(false),
            Attribute::Derive(true),
            Attribute::KeyType(KeyType::EC),
            Attribute::Verify(true),
            // DER-encoded OID 1.2.840.10045.3.1.7 (prime256v1 / NIST P-256).
            Attribute::EcParams(vec![
                0x06, 0x08, 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07,
            ]),
            Attribute::Label(config.label.clone()),
        ];
        let priv_key_template = vec![
            Attribute::Token(true),
            Attribute::Private(false),
            Attribute::Sensitive(true),
            Attribute::Extractable(false),
            Attribute::Derive(true),
            Attribute::Sign(true),
            Attribute::Label(config.label),
        ];

        // Key generation is retried every 25 ms until it succeeds.
        backoff::retry(
            backoff::backoff::Constant::new(Duration::from_millis(25)),
            || {
                let session = pool.get().unwrap();
                let key_pair =
                    session.generate_key_pair(&mechanism, &pub_key_template, &priv_key_template)?;
                Ok(key_pair)
            },
        )
        .unwrap();

        pool
    }

    /// Signs the test payload with the private key labelled `config.label`.
    fn sign(config: &Config, session: &PooledConnection<SessionManager>) -> Vec<u8> {
        let key_filter = [
            Attribute::Class(ObjectClass::PRIVATE_KEY),
            Attribute::Label(config.label.clone()),
        ];
        let private_key = session
            .find_objects(&key_filter)
            .unwrap()
            .into_iter()
            .next()
            .unwrap();
        session
            .sign(&Mechanism::Ecdsa, private_key, TEST_DATA)
            .unwrap()
    }

    /// Verifies `signature` over the test payload with the public key labelled
    /// `config.label`; panics if verification fails.
    fn verify(config: &Config, session: PooledConnection<SessionManager>, signature: &[u8]) {
        let key_filter = [
            Attribute::Class(ObjectClass::PUBLIC_KEY),
            Attribute::Label(config.label.clone()),
        ];
        let public_key = session
            .find_objects(&key_filter)
            .unwrap()
            .into_iter()
            .next()
            .unwrap();
        session
            .verify(&Mechanism::Ecdsa, public_key, TEST_DATA, signature)
            .unwrap();
    }

    #[test]
    fn basic() {
        let config = Config {
            max_sessions: None,
            label: "basic".into(),
        };
        let pool = default_setup(config.clone());
        let signature = sign(&config, &pool.get().unwrap());
        verify(&config, pool.get().unwrap(), &signature);
    }

    /// Runs a sign/verify round trip on a spawned thread and on the calling thread, each
    /// checking out a fresh session for every operation.
    fn basic_test(config: &Config, pool1: Pool) {
        let spawned_pool = pool1.clone();
        let spawned_config = config.clone();
        loom::thread::spawn(move || {
            let signature = sign(&spawned_config, &spawned_pool.get().unwrap());
            verify(&spawned_config, spawned_pool.get().unwrap(), &signature);
        });

        let signature = sign(config, &pool1.get().unwrap());
        verify(config, pool1.get().unwrap(), &signature);
    }

    #[test]
    fn basic_concurrency() {
        loom::model(|| {
            let config = Config {
                max_sessions: None,
                label: "basic_concurrency".into(),
            };
            let pool = default_setup(config.clone());
            basic_test(&config, pool);
        });
    }

    #[test]
    fn max_one_session() {
        loom::model(|| {
            let config = Config {
                max_sessions: Some(1),
                label: "max_one_session".into(),
            };
            let pool = default_setup(config.clone());
            basic_test(&config, pool);
        });
    }

    #[test]
    fn multiple_operations_per_session() {
        loom::model(|| {
            let config = Config {
                max_sessions: Some(1),
                label: "multiple_operations_per_session".into(),
            };
            let pool = default_setup(config.clone());

            let spawned_pool = pool.clone();
            let spawned_config = config.clone();
            loom::thread::spawn(move || {
                // Both operations share one checked-out session.
                let session = spawned_pool.get().unwrap();
                let signature = sign(&spawned_config, &session);
                verify(&spawned_config, session, &signature);
            });

            let session = pool.get().unwrap();
            let signature = sign(&config, &session);
            verify(&config, session, &signature);
        });
    }
}