use ahash::AHashMap;
use hashbrown::Equivalent;
use std::borrow::Cow;
use std::hash::{Hash, Hasher};
use std::sync::RwLock;
use crate::FieldType;
use crate::java_string::{JavaStr, JavaString};
/// Verification outcome stored per method in the result cache.
#[derive(Debug, Clone)]
pub enum CachedResult {
/// The cached outcome is a success; carries no data.
Success,
/// The cached outcome is a failure; the `String` is presumably the error
/// message (confirm with the code that stores it).
Failed(String),
}
/// Key identifying a method by class name, method name and descriptor.
///
/// Every component is a `Cow<'a, JavaStr>`, so a key may either borrow its
/// strings (see `MethodKey::borrowed`) or own them (see `MethodKey::new` and
/// `MethodKey::into_owned`).
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct MethodKey<'a> {
/// Name of the class the method belongs to.
pub class_name: Cow<'a, JavaStr>,
/// Name of the method.
pub method_name: Cow<'a, JavaStr>,
/// Descriptor string of the method.
pub descriptor: Cow<'a, JavaStr>,
}
/// Wrapper around a borrowed `MethodKey` of any lifetime, used as the lookup
/// key when querying a map keyed by `MethodKey<'static>`.
struct MethodKeyLookup<'a, 'b>(&'a MethodKey<'b>);
impl Hash for MethodKeyLookup<'_, '_> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.hash(state);
}
}
impl Equivalent<MethodKey<'static>> for MethodKeyLookup<'_, '_> {
fn equivalent(&self, key: &MethodKey<'static>) -> bool {
self.0.class_name == key.class_name
&& self.0.method_name == key.method_name
&& self.0.descriptor == key.descriptor
}
}
impl MethodKey<'static> {
pub fn new(
class_name: impl Into<JavaString>,
method_name: impl Into<JavaString>,
descriptor: impl Into<JavaString>,
) -> Self {
Self {
class_name: Cow::Owned(class_name.into()),
method_name: Cow::Owned(method_name.into()),
descriptor: Cow::Owned(descriptor.into()),
}
}
}
impl<'a> MethodKey<'a> {
    /// Builds a key that borrows all three components for `'a` without
    /// copying any string data.
    pub fn borrowed(
        class_name: &'a JavaStr,
        method_name: &'a JavaStr,
        descriptor: &'a JavaStr,
    ) -> Self {
        let class_name = Cow::Borrowed(class_name);
        let method_name = Cow::Borrowed(method_name);
        let descriptor = Cow::Borrowed(descriptor);

        Self {
            class_name,
            method_name,
            descriptor,
        }
    }

    /// Consumes the key and returns one whose components are all owned.
    ///
    /// Components that are already owned are moved; borrowed components are
    /// converted through `Cow::into_owned`.
    pub fn into_owned(self) -> MethodKey<'static> {
        let Self {
            class_name,
            method_name,
            descriptor,
        } = self;

        MethodKey {
            class_name: Cow::Owned(class_name.into_owned()),
            method_name: Cow::Owned(method_name.into_owned()),
            descriptor: Cow::Owned(descriptor.into_owned()),
        }
    }
}
/// Parameter and return types obtained from parsing a method descriptor
/// (see `VerificationCache::parse_descriptor`).
#[derive(Debug, Clone)]
pub struct ParsedDescriptor {
/// Parameter types as returned by `FieldType::parse_method_descriptor`.
pub parameters: Vec<FieldType>,
/// Return type, or `None` when the parser reports none (as it does for
/// `(II)V` in this file's tests).
pub return_type: Option<FieldType>,
}
/// Cache of per-method verification results and parsed method descriptors,
/// together with hit/miss counters.
///
/// Each of the three tables sits behind its own `RwLock`. The derived
/// `Default` leaves `enabled` as `false`, i.e. a disabled cache.
#[derive(Debug, Default)]
pub struct VerificationCache {
// When `false`, lookups return `None` and inserts are skipped.
enabled: bool,
// Verification outcomes keyed by fully owned method keys.
results: RwLock<hashbrown::HashMap<MethodKey<'static>, CachedResult, ahash::RandomState>>,
// Parsed descriptors keyed by the descriptor string they were parsed from.
descriptors: RwLock<AHashMap<String, ParsedDescriptor>>,
// Hit/miss counters updated by `get_result` and `get_descriptor`.
stats: RwLock<CacheStats>,
}
/// Hit/miss counters reported by `VerificationCache::stats`.
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
/// Incremented by `VerificationCache::get_result` when an entry is found.
pub result_hits: u64,
/// Incremented by `VerificationCache::get_result` when no entry is found.
pub result_misses: u64,
/// Incremented by `VerificationCache::get_descriptor` when an entry is found.
pub descriptor_hits: u64,
/// Incremented by `VerificationCache::get_descriptor` when no entry is found.
pub descriptor_misses: u64,
}
impl VerificationCache {
#[must_use]
pub fn new(enabled: bool) -> Self {
Self {
enabled,
results: RwLock::new(hashbrown::HashMap::with_hasher(
ahash::RandomState::default(),
)),
descriptors: RwLock::new(AHashMap::default()),
stats: RwLock::new(CacheStats::default()),
}
}
#[must_use]
pub fn disabled() -> Self {
Self::new(false)
}
#[must_use]
pub fn is_enabled(&self) -> bool {
self.enabled
}
#[must_use]
pub fn get_result(&self, key: &MethodKey<'_>) -> Option<CachedResult> {
if !self.enabled {
return None;
}
let guard = self.results.read().ok()?;
let result = guard.get(&MethodKeyLookup(key)).cloned();
drop(guard);
if let Ok(mut stats) = self.stats.write() {
if result.is_some() {
stats.result_hits += 1;
} else {
stats.result_misses += 1;
}
}
result
}
pub fn put_result(&self, key: &MethodKey<'_>, result: CachedResult) {
if !self.enabled {
return;
}
let owned_key = key.clone().into_owned();
if let Ok(mut guard) = self.results.write() {
guard.insert(owned_key, result);
}
}
#[must_use]
pub fn get_descriptor(&self, descriptor: &str) -> Option<ParsedDescriptor> {
if !self.enabled {
return None;
}
let guard = self.descriptors.read().ok()?;
let result = guard.get(descriptor).cloned();
drop(guard);
if let Ok(mut stats) = self.stats.write() {
if result.is_some() {
stats.descriptor_hits += 1;
} else {
stats.descriptor_misses += 1;
}
}
result
}
pub fn parse_descriptor(&self, descriptor: &str) -> Option<ParsedDescriptor> {
if let Some(cached) = self.get_descriptor(descriptor) {
return Some(cached);
}
let java_descriptor = JavaStr::cow_from_str(descriptor);
let (parameters, return_type) =
FieldType::parse_method_descriptor(&java_descriptor).ok()?;
let parsed = ParsedDescriptor {
parameters,
return_type,
};
if self.enabled
&& let Ok(mut guard) = self.descriptors.write()
{
guard.insert(descriptor.to_string(), parsed.clone());
}
Some(parsed)
}
#[must_use]
pub fn stats(&self) -> CacheStats {
self.stats.read().map(|s| s.clone()).unwrap_or_default()
}
pub fn clear(&self) {
if let Ok(mut guard) = self.results.write() {
guard.clear();
}
if let Ok(mut guard) = self.descriptors.write() {
guard.clear();
}
if let Ok(mut stats) = self.stats.write() {
*stats = CacheStats::default();
}
}
#[must_use]
pub fn result_count(&self) -> usize {
self.results.read().map_or(0, |g| g.len())
}
#[must_use]
pub fn descriptor_count(&self) -> usize {
self.descriptors.read().map_or(0, |g| g.len())
}
}
/// Pool of reusable `Vec<VerificationType>` buffers, kept separately for
/// locals and for stacks.
#[derive(Debug)]
pub struct FramePool {
// Idle buffers handed out by `acquire_locals`.
locals_pool: Vec<Vec<crate::verifiers::bytecode::type_system::VerificationType>>,
// Idle buffers handed out by `acquire_stack`.
stack_pool: Vec<Vec<crate::verifiers::bytecode::type_system::VerificationType>>,
// Upper bound on the number of idle buffers kept in each of the two pools.
max_size: usize,
}
impl FramePool {
    /// Creates a pool that keeps at most `max_size` idle buffers of each kind.
    ///
    /// Both internal pool vectors are pre-allocated with room for `max_size`
    /// entries.
    #[must_use]
    pub fn new(max_size: usize) -> Self {
        Self {
            locals_pool: Vec::with_capacity(max_size),
            stack_pool: Vec::with_capacity(max_size),
            max_size,
        }
    }

    /// Returns an empty locals buffer with capacity for at least `capacity`
    /// elements, reusing a pooled buffer when one is available.
    pub fn acquire_locals(
        &mut self,
        capacity: usize,
    ) -> Vec<crate::verifiers::bytecode::type_system::VerificationType> {
        Self::take_buffer(&mut self.locals_pool, capacity)
    }

    /// Hands a locals buffer back to the pool.
    ///
    /// The buffer is emptied and kept for reuse unless the pool already holds
    /// `max_size` locals buffers, in which case it is simply dropped.
    pub fn return_locals(
        &mut self,
        buffer: Vec<crate::verifiers::bytecode::type_system::VerificationType>,
    ) {
        Self::store_buffer(&mut self.locals_pool, self.max_size, buffer);
    }

    /// Returns an empty stack buffer with capacity for at least `capacity`
    /// elements, reusing a pooled buffer when one is available.
    pub fn acquire_stack(
        &mut self,
        capacity: usize,
    ) -> Vec<crate::verifiers::bytecode::type_system::VerificationType> {
        Self::take_buffer(&mut self.stack_pool, capacity)
    }

    /// Hands a stack buffer back to the pool.
    ///
    /// The buffer is emptied and kept for reuse unless the pool already holds
    /// `max_size` stack buffers, in which case it is simply dropped.
    pub fn return_stack(
        &mut self,
        buffer: Vec<crate::verifiers::bytecode::type_system::VerificationType>,
    ) {
        Self::store_buffer(&mut self.stack_pool, self.max_size, buffer);
    }

    /// Drops every idle buffer in both pools.
    pub fn clear(&mut self) {
        self.locals_pool.clear();
        self.stack_pool.clear();
    }

    /// Pops a buffer from `pool` (or allocates a new one) and makes sure it is
    /// empty with room for at least `capacity` elements.
    fn take_buffer<T>(pool: &mut Vec<Vec<T>>, capacity: usize) -> Vec<T> {
        match pool.pop() {
            Some(mut buffer) => {
                // Pooled buffers are stored empty; clearing again is a cheap
                // safeguard that keeps the "returned buffer is empty" guarantee.
                buffer.clear();
                // `reserve` is a no-op when the capacity already suffices, so
                // no explicit capacity comparison is needed.
                buffer.reserve(capacity);
                buffer
            }
            None => Vec::with_capacity(capacity),
        }
    }

    /// Empties `buffer` and stores it in `pool` when fewer than `max_size`
    /// buffers are pooled; otherwise lets it drop.
    fn store_buffer<T>(pool: &mut Vec<Vec<T>>, max_size: usize, mut buffer: Vec<T>) {
        if pool.len() < max_size {
            // Clear now so leftover elements are dropped immediately instead
            // of staying alive inside an idle pooled buffer until the next
            // acquire. The allocation itself is kept for reuse.
            buffer.clear();
            pool.push(buffer);
        }
    }
}
impl Default for FramePool {
fn default() -> Self {
Self::new(32)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Descriptor used by the descriptor-cache tests; it parses to two
    /// parameters and no return type.
    const TWO_PARAM_DESCRIPTOR: &str = "(II)V";

    /// Key shared by the result-cache tests.
    fn sample_key() -> MethodKey<'static> {
        MethodKey::new("Test", "foo", "()V")
    }

    #[test]
    fn test_cache_disabled() {
        let cache = VerificationCache::disabled();
        assert!(!cache.is_enabled());

        // Inserts into a disabled cache are dropped, so the lookup finds nothing.
        let key = sample_key();
        cache.put_result(&key, CachedResult::Success);
        assert!(cache.get_result(&key).is_none());
    }

    #[test]
    fn test_cache_enabled() {
        let cache = VerificationCache::new(true);
        assert!(cache.is_enabled());

        let key = sample_key();
        cache.put_result(&key, CachedResult::Success);
        assert!(matches!(
            cache.get_result(&key),
            Some(CachedResult::Success)
        ));
    }

    #[test]
    fn test_descriptor_cache() {
        let cache = VerificationCache::new(true);

        let first = cache.parse_descriptor(TWO_PARAM_DESCRIPTOR);
        assert!(first.is_some());
        let first = first.unwrap();
        assert_eq!(first.parameters.len(), 2);
        assert!(first.return_type.is_none());
        assert_eq!(cache.descriptor_count(), 1);

        // The second parse of the same descriptor is answered from the cache.
        let _ = cache.parse_descriptor(TWO_PARAM_DESCRIPTOR);
        assert_eq!(cache.stats().descriptor_hits, 1);
    }

    #[test]
    fn test_cache_clear() {
        let cache = VerificationCache::new(true);
        cache.put_result(&sample_key(), CachedResult::Success);
        let _ = cache.parse_descriptor(TWO_PARAM_DESCRIPTOR);
        assert_eq!(cache.result_count(), 1);
        assert_eq!(cache.descriptor_count(), 1);

        cache.clear();

        assert_eq!(cache.result_count(), 0);
        assert_eq!(cache.descriptor_count(), 0);
    }

    #[test]
    fn test_frame_pool() {
        let mut pool = FramePool::new(4);

        let locals = pool.acquire_locals(10);
        let stack = pool.acquire_stack(5);
        assert!(locals.capacity() >= 10);
        assert!(stack.capacity() >= 5);

        pool.return_locals(locals);
        pool.return_stack(stack);

        // The pooled locals buffer is reused, so it keeps its larger capacity.
        let reused = pool.acquire_locals(5);
        assert!(reused.capacity() >= 10);
    }

    #[test]
    fn test_method_key_traits() {
        let original = MethodKey::new("Class", "method", "()V");
        let same = MethodKey::new("Class", "method", "()V");
        let different = MethodKey::new("Class", "other", "()V");
        assert_eq!(original, same);
        assert_ne!(original, different);

        let rendered = format!("{original:?}");
        assert!(rendered.contains("MethodKey"));
        assert!(rendered.contains("Class"));

        let copy = original.clone();
        assert_eq!(original, copy);
    }

    #[test]
    fn test_cached_result_traits() {
        let success = CachedResult::Success;
        let failed = CachedResult::Failed("error".to_string());

        let success_text = format!("{success:?}");
        let failed_text = format!("{failed:?}");
        assert!(success_text.contains("Success"));
        assert!(failed_text.contains("Failed"));

        let copy = success.clone();
        assert!(matches!(copy, CachedResult::Success));
    }

    #[test]
    fn test_parsed_descriptor_traits() {
        let empty = ParsedDescriptor {
            parameters: vec![],
            return_type: None,
        };
        assert!(format!("{empty:?}").contains("ParsedDescriptor"));

        let copy = empty.clone();
        assert!(copy.parameters.is_empty());
        assert!(copy.return_type.is_none());
    }

    #[test]
    fn test_cache_default() {
        // The derived `Default` leaves the `enabled` flag at `false`.
        let cache = VerificationCache::default();
        assert!(!cache.is_enabled());
    }

    #[test]
    fn test_cache_stats_misses() {
        let cache = VerificationCache::new(true);
        let key = sample_key();
        assert!(cache.get_result(&key).is_none());
        assert!(cache.get_descriptor("()V").is_none());

        let snapshot = cache.stats();
        assert_eq!(snapshot.result_misses, 1);
        assert_eq!(snapshot.descriptor_misses, 1);
        assert_eq!(snapshot.result_hits, 0);
        assert_eq!(snapshot.descriptor_hits, 0);
    }

    #[test]
    fn test_parse_descriptor_invalid() {
        let cache = VerificationCache::new(true);
        assert!(cache.parse_descriptor("invalid").is_none());
    }

    #[test]
    fn test_parse_descriptor_disabled() {
        // Parsing still works on a disabled cache, but nothing is stored.
        let cache = VerificationCache::disabled();
        let parsed = cache.parse_descriptor(TWO_PARAM_DESCRIPTOR);
        assert!(parsed.is_some());
        assert_eq!(cache.descriptor_count(), 0);
    }

    #[test]
    fn test_frame_pool_default() {
        let pool = FramePool::default();
        assert_eq!(pool.max_size, 32);
    }

    #[test]
    fn test_frame_pool_resize() {
        let mut pool = FramePool::new(1);

        // A pooled buffer that is too small must be grown on acquire.
        let small_locals = pool.acquire_locals(5);
        pool.return_locals(small_locals);
        let grown_locals = pool.acquire_locals(10);
        assert!(grown_locals.capacity() >= 10);
        pool.return_locals(grown_locals);

        let small_stack = pool.acquire_stack(5);
        pool.return_stack(small_stack);
        let grown_stack = pool.acquire_stack(10);
        assert!(grown_stack.capacity() >= 10);
    }

    #[test]
    fn test_frame_pool_limit() {
        let mut pool = FramePool::new(1);

        // Only one buffer of each kind fits; the second return is dropped.
        let first_locals = pool.acquire_locals(1);
        let second_locals = pool.acquire_locals(1);
        pool.return_locals(first_locals);
        pool.return_locals(second_locals);
        assert_eq!(pool.locals_pool.len(), 1);

        let first_stack = pool.acquire_stack(1);
        let second_stack = pool.acquire_stack(1);
        pool.return_stack(first_stack);
        pool.return_stack(second_stack);
        assert_eq!(pool.stack_pool.len(), 1);
    }

    #[test]
    fn test_frame_pool_clear() {
        let mut pool = FramePool::new(5);

        let locals = pool.acquire_locals(1);
        pool.return_locals(locals);
        let stack = pool.acquire_stack(1);
        pool.return_stack(stack);
        assert_eq!(pool.locals_pool.len(), 1);
        assert_eq!(pool.stack_pool.len(), 1);

        pool.clear();

        assert_eq!(pool.locals_pool.len(), 0);
        assert_eq!(pool.stack_pool.len(), 0);
    }

    #[test]
    fn test_cache_stats_traits() {
        let stats = CacheStats::default();
        let rendered = format!("{stats:?}");
        assert!(rendered.contains("CacheStats"));

        let copy = stats.clone();
        assert_eq!(stats.result_hits, copy.result_hits);
    }

    #[test]
    fn test_verification_cache_debug() {
        let cache = VerificationCache::new(true);
        let rendered = format!("{cache:?}");
        assert!(rendered.contains("VerificationCache"));
    }
}