use jni::{
jni_sig, jni_str,
objects::{JByteArray, JObject, JObjectArray, JString, JValue},
signature::MethodSignature,
Env,
};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerifier};
use rustls::crypto::{verify_tls12_signature, verify_tls13_signature, CryptoProvider};
use rustls::pki_types;
use rustls::Error::InvalidCertificate;
use rustls::{
CertificateError, DigitallySignedStruct, Error as TlsError, OtherError, SignatureScheme,
};
use std::sync::Arc;
use super::{log_server_cert, ALLOWED_EKUS};
use crate::android::{with_context, CachedClass};
/// Handle for the Java class `org.rustls.platformverifier.CertificateVerifier`.
///
/// The static methods invoked from this file (`verifyCertificateChain`, and in test
/// builds `addMockRoot` / `clearMockRoots`) are all looked up on this class.
// NOTE(review): presumably `CachedClass` resolves the class once and reuses it — confirm
// against its definition in `crate::android`.
static CERT_VERIFIER_CLASS: CachedClass =
CachedClass::new(jni_str!("org.rustls.platformverifier.CertificateVerifier"));
/// Outcome categories reported by the Java-side certificate verifier.
///
/// `extract_result_info` builds a value of this type from the integer `code` field of the
/// returned result object; the codes `0..=6` map to the variants in declaration order.
// NOTE(review): the numeric codes presumably mirror a definition on the Java/Kotlin side —
// keep both in sync when adding variants.
#[derive(Debug)]
enum VerifierStatus {
/// Code 0: the platform accepted the chain; the server name is still checked afterwards.
Ok,
/// Code 1: reported to the caller as "No system trust stores available".
Unavailable,
/// Code 2: mapped to `CertificateError::Expired`.
Expired,
/// Code 3: mapped to `CertificateError::UnknownIssuer`.
UnknownCert,
/// Code 4: mapped to `CertificateError::Revoked`.
Revoked,
/// Code 5: mapped to `CertificateError::BadEncoding`.
InvalidEncoding,
/// Code 6: mapped to `CertificateError::Other` wrapping `EkuError`.
InvalidExtension,
}
/// String passed as the third argument (after the context and the server name) of the
/// Java `verifyCertificateChain` call.
// NOTE(review): presumably the key-exchange "auth type" expected by Android's trust
// manager API — confirm against the Java implementation.
const AUTH_TYPE: &str = "RSA";
/// A TLS server certificate verifier that delegates chain validation to the Java class
/// `org.rustls.platformverifier.CertificateVerifier` through JNI.
#[derive(Debug)]
pub struct Verifier {
/// Test-only root certificate. When set, it is handed to the Java `addMockRoot` method
/// before each chain verification.
#[cfg(any(test, feature = "ffi-testing"))]
test_only_root_ca_override: Option<pki_types::CertificateDer<'static>>,
/// Supplies the signature verification algorithms used for TLS 1.2 / 1.3 handshake
/// signature checks and for reporting the supported signature schemes.
crypto_provider: Arc<CryptoProvider>,
}
#[cfg(any(test, feature = "ffi-testing"))]
impl Drop for Verifier {
    /// Calls the Java static method `clearMockRoots` when a verifier is dropped
    /// (test builds only).
    ///
    /// # Panics
    ///
    /// Panics with "failed to clear test roots" if the JNI call does not succeed.
    fn drop(&mut self) {
        let outcome = with_context::<_, ()>(|cx| {
            let verifier_class = CERT_VERIFIER_CLASS.get(cx)?;
            let returned = cx.env.call_static_method(
                verifier_class,
                jni_str!("clearMockRoots"),
                jni_sig!(() -> void),
                &[],
            )?;
            // The method is declared `void`; make sure that is what came back.
            returned.v()?;
            Ok(())
        });

        outcome.expect("failed to clear test roots")
    }
}
impl Verifier {
    /// Creates a verifier that validates server certificates through the Java
    /// `CertificateVerifier` class.
    ///
    /// `crypto_provider` supplies the signature verification algorithms used when
    /// checking TLS handshake signatures.
    ///
    /// # Errors
    ///
    /// This constructor currently always returns `Ok`.
    #[cfg_attr(docsrs, doc(cfg(all())))]
    pub fn new(crypto_provider: Arc<CryptoProvider>) -> Result<Self, TlsError> {
        Ok(Self {
            #[cfg(any(test, feature = "ffi-testing"))]
            test_only_root_ca_override: None,
            crypto_provider,
        })
    }

    /// Test-only constructor: the returned verifier registers `root` with the Java
    /// verifier (via `addMockRoot`) before every chain verification.
    #[cfg(any(test, feature = "ffi-testing"))]
    pub(crate) fn new_with_fake_root(
        root: pki_types::CertificateDer<'static>,
        crypto_provider: Arc<CryptoProvider>,
    ) -> Self {
        Self {
            test_only_root_ca_override: Some(root),
            crypto_provider,
        }
    }

    /// Validates `end_entity` (followed by `intermediates`) for `server_name` by calling
    /// the Java static method `verifyCertificateChain`, then checks the server name
    /// against the end-entity certificate if the platform accepted the chain.
    ///
    /// * `ocsp_response` - optional OCSP data; a null byte array is passed when `None`.
    /// * `now` - verification time; converted to milliseconds and passed as a `jlong`.
    ///
    /// # Errors
    ///
    /// * `TlsError::FailedToGetCurrentTime` if `now` in milliseconds does not fit in `i64`.
    /// * `TlsError::General` if the JNI call fails or no trust store is available.
    /// * `TlsError::InvalidCertificate` for expired, untrusted, revoked, badly encoded
    ///   certificates or ones with an unacceptable extension.
    /// * Any error from parsing the end-entity certificate or matching the server name.
    fn verify_certificate(
        &self,
        end_entity: &pki_types::CertificateDer<'_>,
        intermediates: &[pki_types::CertificateDer<'_>],
        server_name: &pki_types::ServerName<'_>,
        ocsp_response: Option<&[u8]>,
        now: pki_types::UnixTime,
    ) -> Result<(), TlsError> {
        // End-entity certificate first, then the intermediates, each paired with its
        // index in the Java array.
        let certificate_chain = std::iter::once(end_entity)
            .chain(intermediates)
            .map(|cert| cert.as_ref())
            .enumerate();

        // Seconds -> milliseconds. Use a checked multiplication so an out-of-range time
        // yields an error instead of an overflow panic (debug) or a wrapped value (release).
        let now: i64 = now
            .as_secs()
            .checked_mul(1000)
            .and_then(|millis| i64::try_from(millis).ok())
            .ok_or(TlsError::FailedToGetCurrentTime)?;

        let verification_result = with_context(|cx| {
            let cert_verifier_class = CERT_VERIFIER_CLASS.get(cx)?;

            // `byte[][]` holding the DER encoding of every certificate in the chain.
            let cert_list = {
                let array = JObjectArray::<JByteArray>::new(
                    cx.env,
                    intermediates.len() + 1,
                    &JByteArray::null(),
                )?;
                for (idx, cert) in certificate_chain {
                    let cert_buffer = cx.env.byte_array_from_slice(cert)?;
                    array.set_element(cx.env, idx, cert_buffer)?;
                }
                array
            };

            // `String[]` with the extended key usages the Java side should accept.
            let allowed_ekus = {
                let array =
                    JObjectArray::<JString>::new(cx.env, ALLOWED_EKUS.len(), &JString::null())?;
                for (idx, eku) in ALLOWED_EKUS.iter().enumerate() {
                    let eku = cx.env.new_string(eku.to_str().expect(
                        "ALLOWED_EKUS entries are ASCII constants -- always valid UTF-8",
                    ))?;
                    array.set_element(cx.env, idx, eku)?;
                }
                array
            };

            let ocsp_response = match ocsp_response {
                Some(b) => cx.env.byte_array_from_slice(b)?,
                None => JByteArray::null(),
            };

            // Test builds only: register the fake root before asking for a verdict.
            #[cfg(any(test, feature = "ffi-testing"))]
            {
                if let Some(mock_root) = &self.test_only_root_ca_override {
                    let mock_root = cx.env.byte_array_from_slice(mock_root)?;
                    cx.env
                        .call_static_method(
                            cert_verifier_class,
                            jni_str!("addMockRoot"),
                            jni_sig!((byte[]) -> void),
                            &[JValue::from(&mock_root)],
                        )?
                        .v()
                        .expect("failed to add test root")
                }
            }

            const VERIFIER_CALL: MethodSignature<'static, 'static> = jni_sig!(
                (
                    android.content.Context,
                    JString,
                    JString,
                    JString[],
                    byte[],
                    jlong,
                    byte[][]
                ) -> org.rustls.platformverifier.VerificationResult
            );

            let server_name = cx.env.new_string(server_name.to_str())?;
            let auth_type = cx.env.new_string(AUTH_TYPE)?;

            let result = cx
                .env
                .call_static_method(
                    cert_verifier_class,
                    jni_str!("verifyCertificateChain"),
                    VERIFIER_CALL,
                    &[
                        JValue::from(cx.global.context.as_ref()),
                        JValue::from(&server_name),
                        JValue::from(&auth_type),
                        JValue::from(&JObject::from(allowed_ekus)),
                        JValue::from(&ocsp_response),
                        JValue::Long(now),
                        JValue::from(&JObject::from(cert_list)),
                    ],
                )?
                .l()?;

            Ok(extract_result_info(cx.env, result))
        });

        match verification_result {
            Ok((status, maybe_msg)) => {
                // The message is only used for logging. A missing one must never turn a
                // certificate rejection into a panic, so fall back to a placeholder.
                let detail = maybe_msg.as_deref().unwrap_or("<no message provided>");

                match status {
                    // The platform accepted the chain; the hostname is checked here.
                    VerifierStatus::Ok => rustls::client::verify_server_name(
                        &rustls::server::ParsedCertificate::try_from(end_entity)?,
                        server_name,
                    ),
                    VerifierStatus::Unavailable => Err(TlsError::General(String::from(
                        "No system trust stores available",
                    ))),
                    VerifierStatus::Expired => Err(InvalidCertificate(CertificateError::Expired)),
                    VerifierStatus::UnknownCert => {
                        log::warn!("certificate was not trusted: {}", detail);
                        Err(InvalidCertificate(CertificateError::UnknownIssuer))
                    }
                    VerifierStatus::Revoked => {
                        log::warn!("certificate was revoked: {}", detail);
                        Err(InvalidCertificate(CertificateError::Revoked))
                    }
                    VerifierStatus::InvalidEncoding => {
                        Err(InvalidCertificate(CertificateError::BadEncoding))
                    }
                    VerifierStatus::InvalidExtension => Err(InvalidCertificate(
                        CertificateError::Other(OtherError(Arc::new(super::EkuError))),
                    )),
                }
            }
            Err(e) => Err(TlsError::General(format!(
                "failed to call native verifier: {e:?}",
            ))),
        }
    }
}
/// Reads the `code` and `message` fields of the Java verification result object and
/// converts them into a [`VerifierStatus`] plus an optional message.
///
/// The message is `None` when the Java field is null or when it cannot be converted
/// into a Rust `String`.
///
/// # Panics
///
/// Panics if either field cannot be read through JNI, or if `code` is not one of the
/// known status codes (`0..=6`).
fn extract_result_info(env: &mut Env<'_>, result: JObject<'_>) -> (VerifierStatus, Option<String>) {
    let raw_code = env
        .get_field(&result, jni_str!("code"), jni_sig!(jint))
        .and_then(|value| value.i())
        .unwrap();

    let status = match raw_code {
        0 => VerifierStatus::Ok,
        1 => VerifierStatus::Unavailable,
        2 => VerifierStatus::Expired,
        3 => VerifierStatus::UnknownCert,
        4 => VerifierStatus::Revoked,
        5 => VerifierStatus::InvalidEncoding,
        6 => VerifierStatus::InvalidExtension,
        i => unreachable!("unknown status code: {i}"),
    };

    // Fetch the message object first; only a non-null reference is converted.
    let message_ref = env
        .get_field(result, jni_str!("message"), jni_sig!(java.lang.String))
        .and_then(|value| value.l())
        .unwrap();

    let message = if message_ref.is_null() {
        None
    } else {
        // Conversion failures are deliberately folded into `None`.
        env.cast_local::<JString>(message_ref)
            .and_then(|text| text.try_to_string(env))
            .ok()
    };

    (status, message)
}
#[cfg_attr(docsrs, doc(cfg(all())))]
impl ServerCertVerifier for Verifier {
fn verify_server_cert(
&self,
end_entity: &pki_types::CertificateDer<'_>,
intermediates: &[pki_types::CertificateDer<'_>],
server_name: &pki_types::ServerName,
ocsp_response: &[u8],
now: pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, TlsError> {
log_server_cert(end_entity);
let ocsp_data = if !ocsp_response.is_empty() {
Some(ocsp_response)
} else {
None
};
match self.verify_certificate(end_entity, intermediates, server_name, ocsp_data, now) {
Ok(()) => Ok(rustls::client::danger::ServerCertVerified::assertion()),
Err(e) => {
log::error!("failed to verify TLS certificate: {}", e);
Err(e)
}
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &pki_types::CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TlsError> {
verify_tls12_signature(
message,
cert,
dss,
&self.crypto_provider.signature_verification_algorithms,
)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &pki_types::CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TlsError> {
verify_tls13_signature(
message,
cert,
dss,
&self.crypto_provider.signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.crypto_provider
.signature_verification_algorithms
.supported_schemes()
}
}