use std::cell::RefCell;
use std::env;
use std::sync::OnceLock;
use std::thread;

use cryptoki::context::{CInitializeArgs, CInitializeFlags, Pkcs11};
use cryptoki::error::{Error, RvError};
use cryptoki::mechanism::Mechanism;
use cryptoki::object::Attribute;
use cryptoki::session::{Session, UserType};
use cryptoki::types::AuthPin;
use testresult::TestResult;
/// User PIN assigned to the test token by `init_pkcs11_context`; worker sessions log in with it.
const USER_PIN: &str = "fedcba";
/// Security Officer PIN used to initialize the test token and to set the user PIN.
const SO_PIN: &str = "abcdef";
/// Process-wide PKCS#11 context: set once by `init_pkcs11_context`, then read by every worker
/// thread through `with_session`.
static PKCS11_CTX: OnceLock<Pkcs11> = OnceLock::new();
thread_local! {
/// This thread's cached session; `None` until `with_session` opens one on first use.
static PKCS11_SESSION: RefCell<Option<Session>> = const { RefCell::new(None) };
}
/// Loads the PKCS#11 module, initializes a fresh test token and stores the context globally.
///
/// The module path is taken from the `TEST_PKCS11_MODULE` environment variable, falling back
/// to the default SoftHSM2 install location. The first slot that has a token is initialized
/// with the label "Test Token" and `SO_PIN`. The Security Officer then sets the user PIN to
/// `USER_PIN`, and the context is stored in `PKCS11_CTX`.
///
/// # Errors
/// Returns (via `TestResult`) any error from loading the module, `C_Initialize`, slot
/// enumeration, token initialization, SO login, PIN setup or logout.
///
/// # Panics
/// Panics if no slot with a token is present, or if `PKCS11_CTX` was already set.
fn init_pkcs11_context() -> TestResult {
    let lib_path = env::var("TEST_PKCS11_MODULE")
        .unwrap_or_else(|_| "/usr/local/lib/softhsm/libsofthsm2.so".to_string());
    let pkcs11 = Pkcs11::new(lib_path)?;
    pkcs11.initialize(CInitializeArgs::new(CInitializeFlags::OS_LOCKING_OK))?;

    // A missing token is a setup problem; fail with a clear message instead of an
    // index-out-of-bounds panic.
    let slot = *pkcs11
        .get_slots_with_token()?
        .first()
        .expect("no PKCS#11 slot with a token found");

    let so_pin = AuthPin::new(SO_PIN.into());
    pkcs11.init_token(slot, &so_pin, "Test Token")?;

    let user_pin = AuthPin::new(USER_PIN.into());
    {
        let session = pkcs11.open_rw_session(slot)?;
        session.login(UserType::So, Some(&so_pin))?;
        session.init_pin(&user_pin)?;
        // PKCS#11 login state is per token, not per session. End the SO login explicitly so
        // that it cannot linger and conflict with the workers' later User logins. The session
        // itself is closed when it is dropped at the end of this scope.
        session.logout()?;
    }

    PKCS11_CTX
        .set(pkcs11)
        .expect("PKCS11 context already initialized");
    println!("PKCS11 context initialized successfully");
    Ok(())
}
fn with_session<F, R>(f: F) -> TestResult<R>
where
F: FnOnce(&Session) -> TestResult<R>,
{
PKCS11_SESSION.with(|session_cell| {
let mut session_opt = session_cell.borrow_mut();
let needs_reopen = session_opt
.as_ref()
.map(|s| {
let session_info = s.get_session_info();
println!(
"Thread {:?}: Session info check: {:?}",
thread::current().id(),
session_info
);
session_info.is_err()
})
.unwrap_or(true);
if needs_reopen {
*session_opt = None;
let ctx = PKCS11_CTX
.get()
.expect("PKCS11 context should be initialized");
let slot = ctx.get_slots_with_token()?[0];
let new_session = ctx.open_rw_session(slot)?;
let user_pin = AuthPin::new(USER_PIN.into());
new_session.login(UserType::User, Some(&user_pin))?;
println!("Thread {:?}: Opened new RW session", thread::current().id());
*session_opt = Some(new_session);
} else {
println!(
"Thread {:?}: Reusing existing session",
thread::current().id()
);
}
let session_ref = session_opt.as_ref().expect("Session should exist");
f(session_ref)
})
}
/// Worker body: generates an RSA key pair, then signs two messages with the private key.
///
/// Each step goes through `with_session`, so all operations use this thread's cached
/// session. Both keys are created with `Token(false)`, i.e. as session objects. `thread_id`
/// is the worker index chosen by the caller, not the OS thread id.
///
/// # Errors
/// Returns (via `TestResult`) any error from key generation or signing.
fn generate_and_sign(thread_id: usize) -> TestResult {
    let me = thread::current().id();
    println!(
        "Thread {:?} (worker {}): Starting operations",
        me, thread_id
    );

    let (_public, private) = with_session(|session| {
        println!(
            "Thread {:?} (worker {}): Generating RSA key pair",
            me, thread_id
        );
        let public_attrs = [
            Attribute::Token(false),
            Attribute::Private(false),
            Attribute::PublicExponent(vec![0x01, 0x00, 0x01]),
            Attribute::ModulusBits(1024.into()),
        ];
        let private_attrs = [Attribute::Token(false), Attribute::Sign(true)];
        let (public, private) = session.generate_key_pair(
            &Mechanism::RsaPkcsKeyPairGen,
            &public_attrs,
            &private_attrs,
        )?;
        println!(
            "Thread {:?} (worker {}): Keys generated (pub: {}, priv: {})",
            me,
            thread_id,
            public.handle(),
            private.handle()
        );
        Ok((public, private))
    })?;

    // Signs `message` on this thread's session and yields the signature length in bytes.
    let sign = |message: String| {
        with_session(|session| {
            Ok(session
                .sign(&Mechanism::RsaPkcs, private, message.as_bytes())?
                .len())
        })
    };

    let first_len = sign(format!("Message 1 from thread {}", thread_id))?;
    println!(
        "Thread {:?} (worker {}): First signature: {} bytes",
        me, thread_id, first_len
    );

    let second_len = sign(format!("Message 2 from thread {}", thread_id))?;
    println!(
        "Thread {:?} (worker {}): Second signature: {} bytes",
        me, thread_id, second_len
    );

    println!(
        "Thread {:?} (worker {}): All operations completed",
        me, thread_id
    );
    Ok(())
}
/// Example entry point: initializes the shared PKCS#11 context, runs `generate_and_sign` on
/// several worker threads, and waits for all of them to finish.
///
/// # Errors
/// Returns (via `TestResult`) any error from `init_pkcs11_context` or from a worker.
///
/// # Panics
/// Panics if a worker thread panicked.
fn main() -> TestResult {
    println!("Thread-Local Session Pattern Example");
    println!("====================================\n");
    println!("This example demonstrates:");
    // The context is shared through the `PKCS11_CTX` static (a `OnceLock`), not an `Arc`.
    println!("- Sharing Pkcs11 context across threads (via a static OnceLock)");
    println!("- Per-thread Sessions (via thread_local!)");
    println!("- Automatic session lifecycle management");
    println!("- Session reuse within the same thread\n");

    println!("Initializing PKCS11 context...");
    init_pkcs11_context()?;
    println!();

    let max_threads: usize = 3;
    println!("Spawning {max_threads} worker threads...\n");
    let handles: Vec<_> = (0..max_threads)
        .map(|i| thread::spawn(move || generate_and_sign(i)))
        .collect();
    println!();

    for (i, handle) in handles.into_iter().enumerate() {
        // A worker panic is re-raised here; a worker's `Err` is propagated with `?`.
        handle
            .join()
            .unwrap_or_else(|_| panic!("Thread {} panicked", i))?;
    }

    println!("\nAll threads completed successfully!");
    println!("Note: Each thread had its own Session instance, reused across multiple operations.");
    Ok(())
}