use jni::errors::Error as JNIError;
use jni::objects::{Global, JClass, JClassLoader, JObject};
use jni::strings::JNIStr;
use jni::{jni_sig, jni_str, Env, JavaVM};
use once_cell::sync::OnceCell;
/// Process-wide JVM state. It is written at most once, by one of the
/// `init_with_*` functions, and read through `global()`.
static GLOBAL: OnceCell<GlobalStorage> = OnceCell::new();
/// A caller-provided source of the JVM handles this module needs.
///
/// An implementation is registered with `init_with_runtime` and is then
/// queried each time the module needs the VM, the context or the class loader.
pub trait Runtime: Send + Sync {
/// Returns the Java VM handle.
fn java_vm(&self) -> &JavaVM;
/// Returns a global reference to the context object.
fn context(&self) -> &Global<JObject<'static>>;
/// Returns a global reference to the class loader used to load classes.
fn class_loader(&self) -> &Global<JClassLoader<'static>>;
}
/// The value stored in `GLOBAL`: either handles owned by this module or a
/// reference to a caller-provided `Runtime`.
enum GlobalStorage {
/// Handles owned directly by this module (built by `init_with_env` or
/// `init_with_refs`).
Internal {
/// Handle to the Java VM.
java_vm: JavaVM,
/// Global reference to the context object.
context: Global<JObject<'static>>,
/// Global reference to the class loader.
loader: Global<JClassLoader<'static>>,
},
/// Handles fetched on demand from a `Runtime` (registered by
/// `init_with_runtime`).
External(&'static dyn Runtime),
}
impl GlobalStorage {
    /// Borrows the Java VM handle, regardless of which variant holds it.
    fn vm(&self) -> &JavaVM {
        match self {
            Self::External(runtime) => runtime.java_vm(),
            Self::Internal { java_vm, .. } => java_vm,
        }
    }

    /// Builds a `GlobalContext` that owns newly created global references to
    /// the stored context object and class loader.
    ///
    /// # Errors
    ///
    /// Returns `Error` if creating either global reference fails.
    fn context(&self, env: &mut Env) -> Result<GlobalContext, Error> {
        // Resolve both handles in one pass over the variant.
        let (stored_context, stored_loader) = match self {
            Self::Internal {
                context, loader, ..
            } => (context, loader),
            Self::External(runtime) => (runtime.context(), runtime.class_loader()),
        };

        let context = env.new_global_ref(stored_context)?;
        let loader = env.new_global_ref(stored_loader)?;
        Ok(GlobalContext { context, loader })
    }
}
/// Global references to the context object and the class loader, as produced
/// by `GlobalStorage::context`.
pub(super) struct GlobalContext {
/// Global reference to the context object.
pub(super) context: Global<JObject<'static>>,
/// Global reference to the class loader; used by `LocalContext::load_class`.
loader: Global<JClassLoader<'static>>,
}
/// Returns the initialized process-wide storage.
///
/// # Panics
///
/// Panics if none of the `init_with_*` functions has populated `GLOBAL` yet.
fn global() -> &'static GlobalStorage {
    match GLOBAL.get() {
        Some(storage) => storage,
        None => panic!("Expect rustls-platform-verifier to be initialized"),
    }
}
pub fn init_with_env(env: &mut Env, context: JObject) -> Result<(), JNIError> {
GLOBAL.get_or_try_init(|| -> Result<_, JNIError> {
let loader = env
.call_method(
&context,
jni_str!("getClassLoader"),
jni_sig!(() -> JClassLoader),
&[],
)?
.l()?;
let loader = env.cast_local::<JClassLoader>(loader)?;
Ok(GlobalStorage::Internal {
java_vm: env.get_java_vm()?,
context: env.new_global_ref(context)?,
loader: env.new_global_ref(loader)?,
})
})?;
Ok(())
}
/// Initializes the global storage with a caller-provided `Runtime`.
///
/// If the storage is already initialized, the existing value is kept and
/// `runtime` is ignored.
pub fn init_with_runtime(runtime: &'static dyn Runtime) {
    // The returned reference to the stored value is not needed here.
    let _ = GLOBAL.get_or_init(move || GlobalStorage::External(runtime));
}
/// Initializes the global storage from an already obtained VM handle and
/// global references.
///
/// If the storage is already initialized, the existing value is kept and the
/// arguments are dropped unused.
pub fn init_with_refs(
    java_vm: JavaVM,
    context: Global<JObject<'static>>,
    loader: Global<JClassLoader<'static>>,
) {
    let make_storage = move || GlobalStorage::Internal {
        java_vm,
        context,
        loader,
    };
    // The returned reference to the stored value is not needed here.
    let _ = GLOBAL.get_or_init(make_storage);
}
/// Opaque error type for this module; it carries no detail about the cause.
///
/// Values are produced from `JNIError` through the `From` implementation.
#[derive(Debug)]
pub(super) struct Error;
impl From<JNIError> for Error {
    /// Converts a `JNIError` into the opaque module `Error`.
    ///
    /// When the cause is `JNIError::JavaException`, the pending exception is
    /// first described and then cleared through the global VM; any failure
    /// while doing so is ignored.
    ///
    /// # Panics
    ///
    /// Panics (via `global()`) if the cause is a Java exception and the global
    /// storage has not been initialized.
    #[track_caller]
    fn from(cause: JNIError) -> Self {
        if !matches!(cause, JNIError::JavaException) {
            return Self;
        }

        // Best effort: the outcome of describing/clearing is deliberately discarded.
        let _ = global()
            .vm()
            .with_top_local_frame(|env| -> Result<(), JNIError> {
                env.exception_describe();
                env.exception_clear();
                Ok(())
            });

        Self
    }
}
/// State handed to the closure given to `with_context`: a JNI environment
/// together with the global references built for that call.
pub(super) struct LocalContext<'a, 'env> {
/// Mutable borrow of the JNI environment.
pub(super) env: &'a mut Env<'env>,
/// Global references to the context object and class loader.
pub(super) global: GlobalContext,
}
impl<'env> LocalContext<'_, 'env> {
fn load_class(&mut self, name: &'static JNIStr) -> Result<JClass<'env>, Error> {
let name = self.env.new_string(name.to_str())?;
self.global
.loader
.load_class(self.env, name)
.map_err(Error::from)
}
}
/// Runs `f` with a `LocalContext` built from the global storage.
///
/// The closure is executed inside `JavaVM::attach_current_thread_for_scope`
/// on the stored VM, with new global references to the context object and
/// class loader created for this call.
///
/// # Errors
///
/// Returns `Error` if building the `GlobalContext` fails, or whatever error
/// `f` itself returns.
///
/// # Panics
///
/// Panics (via `global()`) if the global storage has not been initialized.
pub(super) fn with_context<F, T: 'static>(f: F) -> Result<T, Error>
where
    F: FnOnce(&mut LocalContext) -> Result<T, Error>,
{
    let storage = global();
    storage.vm().attach_current_thread_for_scope(|env| {
        // `global` is listed first so that `env` is only reborrowed for the
        // `context` call before being moved into the struct.
        let mut local = LocalContext {
            global: storage.context(env)?,
            env,
        };
        f(&mut local)
    })
}
/// A Java class that is loaded by name on first use and then kept as a global
/// reference for later lookups.
pub(super) struct CachedClass {
/// Name handed to `LocalContext::load_class` on first use.
name: &'static JNIStr,
/// Global reference to the loaded class; empty until the first successful `get`.
class: OnceCell<Global<JClass<'static>>>,
}
impl CachedClass {
pub(super) const fn new(name: &'static JNIStr) -> Self {
Self {
name,
class: OnceCell::new(),
}
}
pub(super) fn get(&self, cx: &mut LocalContext) -> Result<&JClass<'static>, Error> {
let class = self.class.get_or_try_init(|| -> Result<_, Error> {
let class = cx.load_class(self.name)?;
Ok(cx.env.new_global_ref(class)?)
})?;
Ok(class)
}
}