use crate::bio::{MemBio, MemBioBuf};
use crate::error::ErrorStack;
use native_ossl_sys as sys;
use std::ffi::CString;
/// Owned handle to an OpenSSL `PKCS12` structure.
///
/// The constructors in this file (`from_der`, `create`) only build a `Pkcs12`
/// from a non-null pointer returned by OpenSSL; the handle frees it on drop.
pub struct Pkcs12 {
    // Exclusively owned pointer; released with `PKCS12_free` in `Drop`.
    ptr: *mut sys::PKCS12,
}
// SAFETY: `Pkcs12` is the sole owner of its `PKCS12` allocation, so sending it to
// another thread transfers that sole ownership.
// NOTE(review): assumes OpenSSL places no thread-affinity on PKCS12 objects — confirm.
unsafe impl Send for Pkcs12 {}
// SAFETY: the `&self` methods in this file (`to_der`, `parse`) only hand the pointer
// to OpenSSL encode/parse routines.
// NOTE(review): this assumes those routines do not mutate the structure; confirm
// before relying on concurrent shared use.
unsafe impl Sync for Pkcs12 {}
impl Drop for Pkcs12 {
    fn drop(&mut self) {
        // SAFETY: `ptr` came from OpenSSL and is owned exclusively by `self`; it is
        // freed exactly once, here.
        unsafe { sys::PKCS12_free(self.ptr) };
    }
}
impl Pkcs12 {
    /// Decodes a DER-encoded PKCS#12 structure.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorStack`] if the memory BIO cannot be created or
    /// `d2i_PKCS12_bio` fails to decode `der`.
    pub fn from_der(der: &[u8]) -> Result<Self, ErrorStack> {
        let bio = MemBioBuf::new(der)?;
        let ptr = unsafe { sys::d2i_PKCS12_bio(bio.as_ptr(), std::ptr::null_mut()) };
        if ptr.is_null() {
            return Err(ErrorStack::drain());
        }
        Ok(Pkcs12 { ptr })
    }

    /// Encodes this PKCS#12 structure as DER.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorStack`] if the memory BIO cannot be created or
    /// `i2d_PKCS12_bio` does not report success.
    pub fn to_der(&self) -> Result<Vec<u8>, ErrorStack> {
        let mut bio = MemBio::new()?;
        let rc = unsafe { sys::i2d_PKCS12_bio(bio.as_ptr(), self.ptr) };
        if rc != 1 {
            return Err(ErrorStack::drain());
        }
        Ok(bio.into_vec())
    }

    /// Decrypts this PKCS#12 structure with `password` and splits it into its
    /// private key, its certificate, and any additional certificates.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorStack`] in these cases:
    /// - `password` contains an interior NUL byte. The stack then holds whatever
    ///   was already queued and may be empty.
    /// - `PKCS12_parse` fails, e.g. because of a wrong password.
    /// - The structure yields no private key or no certificate.
    pub fn parse(
        &self,
        password: &str,
    ) -> Result<
        (
            crate::pkey::Pkey<crate::pkey::Private>,
            crate::x509::X509,
            Vec<crate::x509::X509>,
        ),
        ErrorStack,
    > {
        let pass = CString::new(password).map_err(|_| ErrorStack::drain())?;
        let mut pkey_ptr: *mut sys::EVP_PKEY = std::ptr::null_mut();
        let mut cert_ptr: *mut sys::X509 = std::ptr::null_mut();
        // Starts null, so any stack stored here by OpenSSL was allocated for us and is
        // ours to free on every path.
        let mut ca_ptr: *mut sys::stack_st_X509 = std::ptr::null_mut();
        let rc = unsafe {
            sys::PKCS12_parse(
                self.ptr,
                pass.as_ptr(),
                std::ptr::addr_of_mut!(pkey_ptr),
                std::ptr::addr_of_mut!(cert_ptr),
                std::ptr::addr_of_mut!(ca_ptr),
            )
        };
        if rc != 1 {
            // Capture the OpenSSL errors first, then release any (possibly partially
            // filled) CA stack left behind; `free_x509_stack` ignores null.
            let err = ErrorStack::drain();
            free_x509_stack(ca_ptr);
            return Err(err);
        }
        if pkey_ptr.is_null() || cert_ptr.is_null() {
            if !pkey_ptr.is_null() {
                unsafe { sys::EVP_PKEY_free(pkey_ptr) };
            }
            if !cert_ptr.is_null() {
                unsafe { sys::X509_free(cert_ptr) };
            }
            free_x509_stack(ca_ptr);
            return Err(ErrorStack::drain());
        }
        let key = unsafe { crate::pkey::Pkey::from_ptr(pkey_ptr) };
        let cert = unsafe { crate::x509::X509::from_ptr(cert_ptr) };
        let ca = drain_x509_stack(ca_ptr);
        Ok((key, cert, ca))
    }

    /// Builds a new PKCS#12 structure from `key` and `cert`, plus the optional
    /// extra certificates in `ca`, protected by `password` and labelled with
    /// the friendly name `name`.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorStack`] in these cases:
    /// - `password` or `name` contains an interior NUL byte. The stack may then
    ///   be empty.
    /// - The CA stack cannot be built.
    /// - `PKCS12_create_ex` fails.
    pub fn create(
        password: &str,
        name: &str,
        key: &crate::pkey::Pkey<crate::pkey::Private>,
        cert: &crate::x509::X509,
        ca: &[crate::x509::X509],
    ) -> Result<Self, ErrorStack> {
        // Do the fallible string conversions before taking any certificate
        // references, so an early return here cannot leak the CA stack.
        let pass = CString::new(password).map_err(|_| ErrorStack::drain())?;
        let name = CString::new(name).map_err(|_| ErrorStack::drain())?;
        let ca_stack = if ca.is_empty() {
            std::ptr::null_mut()
        } else {
            build_x509_stack(ca)?
        };
        let ptr = unsafe {
            sys::PKCS12_create_ex(
                pass.as_ptr(),
                name.as_ptr(),
                key.as_ptr(),
                cert.as_ptr(),
                ca_stack,
                // nid_key, nid_cert, iter, mac_iter, keytype: zero asks OpenSSL for
                // its defaults (see PKCS12_create(3)).
                0,
                0,
                0,
                0,
                0,
                std::ptr::null_mut(), // libctx: default library context
                std::ptr::null(),     // propq: no property query
            )
        };
        let result = if ptr.is_null() {
            Err(ErrorStack::drain())
        } else {
            Ok(Pkcs12 { ptr })
        };
        // `build_x509_stack` took one reference per certificate, and
        // PKCS12_create_ex does not take ownership of `ca`. So the references
        // must be dropped together with the container. A bare `OPENSSL_sk_free`
        // would leak them. `free_x509_stack` ignores null.
        free_x509_stack(ca_stack);
        result
    }
}
/// Moves the entries of an OpenSSL `STACK_OF(X509)` into a `Vec` of owned
/// certificates, then frees the stack container.
///
/// Each non-null element is wrapped with `X509::from_ptr`. This code relies on
/// that call adopting the element's reference, because only the container is
/// released (`OPENSSL_sk_free`, not a pop-free). Null elements are skipped. A
/// null `stack` yields an empty vector.
fn drain_x509_stack(stack: *mut sys::stack_st_X509) -> Vec<crate::x509::X509> {
    if stack.is_null() {
        return Vec::new();
    }
    let sk = stack.cast::<sys::OPENSSL_STACK>();
    let n = unsafe { sys::OPENSSL_sk_num(sk) };
    // A negative count is treated as an empty stack.
    let mut out = Vec::with_capacity(usize::try_from(n).unwrap_or(0));
    // Iterate in the C index type itself, so no narrowing cast (or lint
    // suppression) is needed.
    for i in 0..n {
        let raw = unsafe { sys::OPENSSL_sk_value(sk, i) };
        if !raw.is_null() {
            out.push(unsafe { crate::x509::X509::from_ptr(raw.cast::<sys::X509>()) });
        }
    }
    unsafe { sys::OPENSSL_sk_free(sk) };
    out
}
/// Destroys a `STACK_OF(X509)`. First, `X509_free` is called once on every
/// non-null entry, in index order, dropping the reference each entry holds.
/// Then the container itself is freed. Passing null is a no-op.
fn free_x509_stack(stack: *mut sys::stack_st_X509) {
    if !stack.is_null() {
        let sk = stack.cast::<sys::OPENSSL_STACK>();
        let count = unsafe { sys::OPENSSL_sk_num(sk) };
        (0..count)
            .map(|idx| unsafe { sys::OPENSSL_sk_value(sk, idx) })
            .filter(|entry| !entry.is_null())
            .for_each(|entry| unsafe { sys::X509_free(entry.cast::<sys::X509>()) });
        unsafe { sys::OPENSSL_sk_free(sk) };
    }
}
fn build_x509_stack(certs: &[crate::x509::X509]) -> Result<*mut sys::stack_st_X509, ErrorStack> {
let raw = unsafe { sys::OPENSSL_sk_new_null() };
if raw.is_null() {
return Err(ErrorStack::drain());
}
for cert in certs {
unsafe { sys::X509_up_ref(cert.as_ptr()) };
let rc = unsafe { sys::OPENSSL_sk_push(raw, cert.as_ptr().cast::<std::ffi::c_void>()) };
if rc == 0 {
unsafe { sys::X509_free(cert.as_ptr()) };
let n = unsafe { sys::OPENSSL_sk_num(raw) };
for i in 0..n {
let p = unsafe { sys::OPENSSL_sk_value(raw, i) };
if !p.is_null() {
unsafe { sys::X509_free(p.cast::<sys::X509>()) };
}
}
unsafe { sys::OPENSSL_sk_free(raw) };
return Err(ErrorStack::drain());
}
}
Ok(raw.cast::<sys::stack_st_X509>())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pkey::{KeygenCtx, Pkey, Private, Public};
    use crate::x509::{X509Builder, X509NameOwned};

    /// Generates a fresh Ed25519 key and a self-signed certificate for it with
    /// subject and issuer `CN=PKCS12 Test`, valid for 365 days from now.
    fn make_self_signed() -> (crate::x509::X509, Pkey<Private>) {
        let mut kgen = KeygenCtx::new(c"ED25519").unwrap();
        let priv_key = kgen.generate().unwrap();
        let pub_key = Pkey::<Public>::from(priv_key.clone());
        let mut name = X509NameOwned::new().unwrap();
        name.add_entry_by_txt(c"CN", b"PKCS12 Test").unwrap();
        let cert = X509Builder::new()
            .unwrap()
            .set_version(2)
            .unwrap()
            .set_serial_number(1)
            .unwrap()
            .set_not_before_offset(0)
            .unwrap()
            .set_not_after_offset(365 * 86400)
            .unwrap()
            .set_subject_name(&name)
            .unwrap()
            .set_issuer_name(&name)
            .unwrap()
            .set_public_key(&pub_key)
            .unwrap()
            .sign(&priv_key, None)
            .unwrap()
            .build();
        (cert, priv_key)
    }

    #[test]
    fn pkcs12_create_and_parse_roundtrip() {
        let (cert, priv_key) = make_self_signed();
        let p12 = Pkcs12::create("testpass", "test", &priv_key, &cert, &[]).unwrap();
        let der = p12.to_der().unwrap();
        assert!(!der.is_empty());
        let p12b = Pkcs12::from_der(&der).unwrap();
        let (key2, cert2, ca2) = p12b.parse("testpass").unwrap();
        assert!(key2.is_a(c"ED25519"));
        let subj = cert2.subject_name().to_string().unwrap();
        assert!(subj.contains("PKCS12 Test"));
        assert!(ca2.is_empty());
    }

    #[test]
    fn pkcs12_der_roundtrip() {
        let (cert, priv_key) = make_self_signed();
        let p12 = Pkcs12::create("pass", "n", &priv_key, &cert, &[]).unwrap();
        let der1 = p12.to_der().unwrap();
        let p12b = Pkcs12::from_der(&der1).unwrap();
        let der2 = p12b.to_der().unwrap();
        assert_eq!(der1, der2);
    }

    #[test]
    fn pkcs12_wrong_password_fails() {
        let (cert, priv_key) = make_self_signed();
        let p12 = Pkcs12::create("rightpass", "n", &priv_key, &cert, &[]).unwrap();
        let der = p12.to_der().unwrap();
        let p12b = Pkcs12::from_der(&der).unwrap();
        assert!(p12b.parse("wrongpass").is_err());
    }

    /// Exercises the non-empty CA path through both `build_x509_stack` (create)
    /// and `drain_x509_stack` (parse).
    #[test]
    fn pkcs12_create_and_parse_with_ca_chain() {
        let (cert, priv_key) = make_self_signed();
        // An independently keyed certificate stands in for a CA. Its key does not
        // match `priv_key`, so it is expected to come back in the CA list rather
        // than as the leaf.
        let (ca_cert, _ca_key) = make_self_signed();
        let p12 = Pkcs12::create("chainpass", "leaf", &priv_key, &cert, &[ca_cert]).unwrap();
        let der = p12.to_der().unwrap();
        let p12b = Pkcs12::from_der(&der).unwrap();
        let (key2, cert2, ca2) = p12b.parse("chainpass").unwrap();
        assert!(key2.is_a(c"ED25519"));
        assert!(cert2.subject_name().to_string().unwrap().contains("PKCS12 Test"));
        assert_eq!(ca2.len(), 1);
        assert!(ca2[0].subject_name().to_string().unwrap().contains("PKCS12 Test"));
    }

    /// Passwords with an interior NUL cannot be passed to OpenSSL as C strings,
    /// so both `create` and `parse` must reject them with an error, not a panic.
    #[test]
    fn pkcs12_interior_nul_is_rejected() {
        let (cert, priv_key) = make_self_signed();
        let (ca_cert, _ca_key) = make_self_signed();
        assert!(Pkcs12::create("bad\0pass", "n", &priv_key, &cert, &[ca_cert]).is_err());
        let p12 = Pkcs12::create("pass", "n", &priv_key, &cert, &[]).unwrap();
        assert!(p12.parse("bad\0pass").is_err());
    }
}