r2d2_cryptoki/
lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4pub use cryptoki;
5pub use r2d2;
6
7use cryptoki::{
8    context::{Function, Pkcs11},
9    error::RvError,
10    session::{Session, SessionState, UserType},
11    slot::{Limit, Slot},
12    types::AuthPin,
13};
14use r2d2::ManageConnection;
15
16/// Alias for this crate's instance of r2d2's Pool
17pub type Pool = r2d2::Pool<SessionManager>;
18/// Alias for this crate's instance of r2d2's PooledSession
19pub type PooledSession = r2d2::PooledConnection<SessionManager>;
20
21/// Manager holding all information necessary for opening new connections
22#[derive(Debug, Clone)]
23pub struct SessionManager {
24    pkcs11: Pkcs11,
25    slot: Slot,
26    session_type: SessionType,
27}
28
29/// Session types, holding the pin for the authenticated sessions
30#[derive(Debug, Clone)]
31pub enum SessionType {
32    /// [SessionState::RoPublic]
33    RoPublic,
34    /// [SessionState::RoUser]
35    RoUser(AuthPin),
36    /// [SessionState::RwPublic]
37    RwPublic,
38    /// [SessionState::RwUser]
39    RwUser(AuthPin),
40    /// [SessionState::RwSecurityOfficer]
41    RwSecurityOfficer(AuthPin),
42}
43
44impl SessionType {
45    fn as_state(&self) -> SessionState {
46        match self {
47            Self::RoPublic => SessionState::RoPublic,
48            Self::RoUser(_) => SessionState::RoUser,
49            Self::RwPublic => SessionState::RwPublic,
50            Self::RwUser(_) => SessionState::RwUser,
51            Self::RwSecurityOfficer(_) => SessionState::RwSecurityOfficer,
52        }
53    }
54}
55
56impl SessionManager {
57    /// # Example
58    /// ```no_run
59    ///  # use r2d2_cryptoki::{*, cryptoki::{context::*, types::AuthPin}};
60    ///  let pkcs11 = Pkcs11::new("libsofthsm2.so").unwrap();
61    ///  pkcs11 .initialize(CInitializeArgs::OsThreads).unwrap();
62    ///  let slots = pkcs11.get_slots_with_token().unwrap();
63    ///  let slot = slots.first().unwrap();
64    ///  let manager = SessionManager::new(pkcs11, *slot, SessionType::RwUser(AuthPin::new("abcd".to_string())));
65    /// ```
66    pub fn new(pkcs11: Pkcs11, slot: Slot, session_type: SessionType) -> Self {
67        Self {
68            pkcs11,
69            slot,
70            session_type,
71        }
72    }
73
74    /// Returns the maximum number of sessions supported by the HSM.
75    ///
76    /// Arguments:
77    /// * `maximum`: A maximum number of sessions as `max_size` can return u32::max_value() which is probably more than what your application should use.
78    ///
79    /// # Example
80    /// ```no_run
81    ///  # use r2d2_cryptoki::{*, cryptoki::{context::*, types::AuthPin}};
82    ///  # let pkcs11 = Pkcs11::new("libsofthsm2.so").unwrap();
83    ///  # pkcs11.initialize(CInitializeArgs::OsThreads);
84    ///  # let slots = pkcs11.get_slots_with_token().unwrap();
85    ///  # let slot = slots.first().unwrap();
86    ///  # let manager = SessionManager::new(pkcs11, *slot, SessionType::RwUser(AuthPin::new("fedcba".to_string())));
87    ///  let pool_builder = r2d2::Pool::builder();
88    ///  let pool_builder = if let Some(max_size) = manager.max_size(100).unwrap() {
89    ///     pool_builder.max_size(max_size)
90    ///  } else {
91    ///     pool_builder
92    ///  };
93    ///  let pool = pool_builder.build(manager).unwrap();
94    /// ```
95    pub fn max_size(&self, maximum: u32) -> Result<Option<u32>, cryptoki::error::Error> {
96        let token_info = self.pkcs11.get_token_info(self.slot)?;
97        let limit = match self.session_type {
98            SessionType::RoPublic | SessionType::RoUser(_) => token_info.max_session_count(),
99            SessionType::RwPublic | SessionType::RwUser(_) | SessionType::RwSecurityOfficer(_) => {
100                token_info.max_session_count()
101            }
102        };
103        let res = match limit {
104            Limit::Max(m) => Some(m.try_into().unwrap_or(u32::MAX)),
105            Limit::Unavailable => None,
106            Limit::Infinite => Some(u32::MAX),
107        };
108        Ok(if let Some(true) = res.map(|r| r > maximum) {
109            Some(maximum)
110        } else {
111            res
112        })
113    }
114}
115
116impl ManageConnection for SessionManager {
117    type Connection = Session;
118
119    type Error = cryptoki::error::Error;
120
121    // Login is global, once a session logs in, all sessions are logged in https://stackoverflow.com/a/40225885.
122    // TODO cryptoki automatically logs out on Drop, so this is ineficient and we will need to find a better way to check the login state when we start having a pool of sessions
123    fn connect(&self) -> Result<Self::Connection, Self::Error> {
124        let session = match self.session_type {
125            SessionType::RoPublic | SessionType::RoUser(_) => {
126                self.pkcs11.open_ro_session(self.slot)?
127            }
128            SessionType::RwPublic | SessionType::RwUser(_) | SessionType::RwSecurityOfficer(_) => {
129                self.pkcs11.open_rw_session(self.slot)?
130            }
131        };
132        let maybe_user_info = match &self.session_type {
133            SessionType::RoPublic | SessionType::RwPublic => None,
134            SessionType::RoUser(pin) | SessionType::RwUser(pin) => Some((UserType::User, pin)),
135            SessionType::RwSecurityOfficer(pin) => Some((UserType::So, pin)),
136        };
137        if let Some(user_type) = maybe_user_info {
138            match session.login(user_type.0, Some(user_type.1)) {
139                Err(Self::Error::Pkcs11(RvError::UserAlreadyLoggedIn, Function::Login)) => {}
140                res => res?,
141            };
142        }
143        Ok(session)
144    }
145
146    fn is_valid(&self, session: &mut Self::Connection) -> Result<(), Self::Error> {
147        let actual_state = session.get_session_info()?.session_state();
148        let expected_state = &self.session_type;
149        if actual_state != expected_state.as_state() {
150            Err(Self::Error::Pkcs11(
151                RvError::UserNotLoggedIn,
152                Function::GetSessionInfo,
153            ))
154        } else {
155            Ok(())
156        }
157    }
158
159    fn has_broken(&self, _session: &mut Self::Connection) -> bool {
160        // TODO find a way to check session state without reaching out to the HSM
161        false
162    }
163}
164
#[cfg(test)]
mod test {
    use std::{env, fs, path::Path, time::Duration};

    use cached::proc_macro::{cached, once};
    use cryptoki::{
        context::CInitializeArgs,
        mechanism::Mechanism,
        object::{Attribute, KeyType, ObjectClass},
    };
    use r2d2::PooledConnection;

    use super::*;

    /// Per-test parameters: optional pool size cap and the label of the test key pair.
    #[derive(Clone, Hash, PartialEq, Eq)]
    struct Config {
        max_sessions: Option<u32>,
        label: Vec<u8>,
    }

    /// DER-encoded OID 1.2.840.10045.3.1.7 (NIST P-256 / prime256v1) used as `EcParams`.
    const P256_PARAMS: [u8; 10] = [0x06, 0x08, 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07];

    /// Payload signed and verified by every test.
    const MESSAGE: &[u8] = b"test_data";

    // Only one Pkcs11 object may exist per process (creating several segfaults), so it is
    // created once and memoized. The SoftHSM token directory is wiped before first use.
    #[once(sync_writes = true)]
    fn default_pkcs11() -> Pkcs11 {
        env::set_var("SOFTHSM2_CONF", "./test/softhsm2.conf");
        let tokens_dir = Path::new("./test/softhsm/tokens");
        if tokens_dir.exists() {
            fs::remove_dir_all(tokens_dir).unwrap();
        }
        fs::create_dir_all(tokens_dir).unwrap();

        let pkcs11 = Pkcs11::new("libsofthsm2.so").expect("Could not use pkcs11 library");
        pkcs11
            .initialize(CInitializeArgs::OsThreads)
            .expect("Could not initialize pkcs11");
        pkcs11
    }

    // Initializes the token in the first slot once per PIN: the PIN is used as SO PIN and,
    // after an SO login, set as the user PIN.
    #[cached(sync_writes = true)]
    fn default_token(pin: String) -> (Pkcs11, Slot) {
        let pkcs11 = default_pkcs11();
        let slot = *pkcs11
            .get_slots_with_token()
            .expect("Could not get slots with token")
            .first()
            .expect("Could not find a slot");
        pkcs11
            .init_token(slot, &pin.clone().into(), "token")
            .expect("Could not initialize token");
        let so_session = pkcs11.open_rw_session(slot).unwrap();
        so_session
            .login(UserType::So, Some(&pin.clone().into()))
            .unwrap();
        so_session.init_pin(&pin.into()).unwrap();
        (pkcs11, slot)
    }

    // Builds a pool of RW user sessions and generates a P-256 key pair labelled `config.label`.
    fn default_setup(config: Config) -> Pool {
        let pin = "abcde".to_string();
        let auth_pin = AuthPin::new(pin.clone());
        let (pkcs11, slot) = default_token(pin);

        let manager = SessionManager::new(pkcs11, slot, SessionType::RwUser(auth_pin));
        let builder = match config.max_sessions {
            Some(max) => r2d2::Pool::builder().max_size(max),
            None => r2d2::Pool::builder(),
        };
        let pool = builder.build(manager).unwrap();

        let mechanism = Mechanism::EccKeyPairGen;
        let pub_key_template = vec![
            Attribute::Token(true),
            Attribute::Private(false),
            Attribute::Derive(true),
            Attribute::KeyType(KeyType::EC),
            Attribute::Verify(true),
            Attribute::EcParams(P256_PARAMS.to_vec()),
            Attribute::Label(config.label.clone()),
        ];
        let priv_key_template = vec![
            Attribute::Token(true),
            Attribute::Private(false),
            Attribute::Sensitive(true),
            Attribute::Extractable(false),
            Attribute::Derive(true),
            Attribute::Sign(true),
            Attribute::Label(config.label),
        ];

        // Key generation sometimes fails with a GeneralError, so retry it every 25 ms.
        backoff::retry(
            backoff::backoff::Constant::new(Duration::from_millis(25)),
            || {
                Ok(pool.get().unwrap().generate_key_pair(
                    &mechanism,
                    &pub_key_template,
                    &priv_key_template,
                )?)
            },
        )
        .unwrap();
        pool
    }

    // Signs MESSAGE with ECDSA using the private key labelled `config.label`.
    fn sign(config: &Config, session: &PooledConnection<SessionManager>) -> Vec<u8> {
        let private = *session
            .find_objects(&[
                Attribute::Class(ObjectClass::PRIVATE_KEY),
                Attribute::Label(config.label.clone()),
            ])
            .unwrap()
            .first()
            .unwrap();
        session.sign(&Mechanism::Ecdsa, private, MESSAGE).unwrap()
    }

    // Verifies `signature` over MESSAGE with the public key labelled `config.label`; panics on failure.
    fn verify(config: &Config, session: PooledConnection<SessionManager>, signature: &[u8]) {
        let public = *session
            .find_objects(&[
                Attribute::Class(ObjectClass::PUBLIC_KEY),
                Attribute::Label(config.label.clone()),
            ])
            .unwrap()
            .first()
            .unwrap();
        session
            .verify(&Mechanism::Ecdsa, public, MESSAGE, signature)
            .unwrap();
    }

    #[test]
    fn basic() {
        let config = Config {
            max_sessions: None,
            label: "basic".into(),
        };
        let pool = default_setup(config.clone());
        let signature = sign(&config, &pool.get().unwrap());
        verify(&config, pool.get().unwrap(), &signature);
    }

    // Runs a sign/verify round-trip on a spawned loom thread and on the current thread,
    // checking out a fresh pooled session for each operation.
    fn basic_test(config: &Config, pool1: Pool) {
        let pool2 = pool1.clone();
        let thread_config = config.clone();
        loom::thread::spawn(move || {
            let signature = sign(&thread_config, &pool1.get().unwrap());
            verify(&thread_config, pool1.get().unwrap(), &signature);
        });
        let signature = sign(config, &pool2.get().unwrap());
        verify(config, pool2.get().unwrap(), &signature);
    }

    #[test]
    fn basic_concurrency() {
        loom::model(|| {
            let config = Config {
                max_sessions: None,
                label: "basic_concurrency".into(),
            };
            let pool = default_setup(config.clone());
            basic_test(&config, pool);
        });
    }

    #[test]
    fn max_one_session() {
        loom::model(|| {
            let config = Config {
                max_sessions: Some(1),
                label: "max_one_session".into(),
            };
            let pool = default_setup(config.clone());
            basic_test(&config, pool);
        });
    }

    #[test]
    fn multiple_operations_per_session() {
        loom::model(|| {
            let config = Config {
                max_sessions: Some(1),
                label: "multiple_operations_per_session".into(),
            };
            let thread_config = config.clone();
            let pool = default_setup(config.clone());
            let thread_pool = pool.clone();
            // Each thread keeps a single pooled session for both sign and verify.
            loom::thread::spawn(move || {
                let session = thread_pool.get().unwrap();
                let signature = sign(&thread_config, &session);
                verify(&thread_config, session, &signature);
            });
            let session = pool.get().unwrap();
            let signature = sign(&config, &session);
            verify(&config, session, &signature);
        });
    }
}