use crate::core::Expression;
use std::cell::RefCell;
use std::collections::HashMap;
use std::thread_local;
thread_local! {
    /// Per-thread memo of joined function names: each key is a
    /// `"{base}_{suffix}"` string and maps to an identical copy of itself.
    /// The map grows without bound; only `clear_caches` empties it.
    static FUNCTION_NAME_CACHE: RefCell<HashMap<String, String>> = RefCell::new(HashMap::new());

    /// Per-thread scratch vector used by `build_expr_list`; it starts with
    /// room for 16 expressions.
    static EXPR_LIST_BUFFER: RefCell<Vec<Expression>> = RefCell::new(Vec::with_capacity(16));

    /// Per-thread table of frequently used constant expressions, looked up by
    /// a short textual name (e.g. `"0"`, `"pi"`, `"infinity"`).
    static COMMON_EXPRESSIONS: RefCell<HashMap<&'static str, Expression>> =
        RefCell::new(HashMap::from([
            ("0", Expression::integer(0)),
            ("1", Expression::integer(1)),
            ("-1", Expression::integer(-1)),
            ("2", Expression::integer(2)),
            ("pi", Expression::pi()),
            ("e", Expression::e()),
            ("i", Expression::i()),
            ("infinity", Expression::infinity()),
        ]));
}
/// Returns the function name `"{base}_{suffix}"`, recording it in this
/// thread's `FUNCTION_NAME_CACHE`.
///
/// The cache maps each joined name to an identical copy of itself, so the
/// result is always the freshly formatted name. On a cache hit that string is
/// returned directly; only the first request for a name pays for the two
/// owned copies the map stores.
///
/// NOTE(review): the joined name has to be formatted on every call regardless,
/// so the cache saves no work. It only records which names were requested
/// (visible through `get_cache_stats`) and grows until `clear_caches` runs.
/// Consider whether it is needed at all.
pub fn get_cached_function_name(base: &str, suffix: &str) -> String {
    let name = format!("{}_{}", base, suffix);
    FUNCTION_NAME_CACHE.with(|cache| {
        let mut cache = cache.borrow_mut();
        // `entry()` would require an owned key up front, costing a clone even
        // on a hit; check first so hits allocate nothing extra.
        if !cache.contains_key(&name) {
            cache.insert(name.clone(), name.clone());
        }
    });
    name
}
/// Collects `exprs` into a new `Vec<Expression>`, staging the elements in this
/// thread's `EXPR_LIST_BUFFER`.
///
/// The scratch buffer absorbs the growth reallocations that happen when
/// `exprs` has a poor size hint (e.g. after a `filter`). The returned vector is
/// then allocated once at its exact length, and the elements are *moved* into
/// it, so no `Expression` is cloned. The buffer keeps its capacity for the next
/// call but holds no elements afterwards.
///
/// The buffer is taken out of the thread-local for the duration of the call,
/// so no `RefCell` borrow is held while `exprs` runs. This makes it safe for
/// `exprs` itself to call `build_expr_list` (for example, a lazy `map` that
/// builds nested lists); holding the borrow across `extend` would panic with
/// `BorrowMutError` in that case.
pub fn build_expr_list(exprs: impl IntoIterator<Item = Expression>) -> Vec<Expression> {
    let mut buffer = EXPR_LIST_BUFFER.with(|cell| std::mem::take(&mut *cell.borrow_mut()));
    buffer.clear();
    buffer.extend(exprs);
    // `Drain` is exact-size, so `collect` allocates exactly `len` slots and
    // moves the elements out, leaving the buffer empty with its capacity intact.
    let result: Vec<Expression> = buffer.drain(..).collect();
    EXPR_LIST_BUFFER.with(|cell| *cell.borrow_mut() = buffer);
    result
}
/// Returns a clone of the pre-built expression stored under `key` in this
/// thread's `COMMON_EXPRESSIONS` table (e.g. `"0"`, `"pi"`), or `None` if the
/// table has no such entry.
///
/// `key` may be any borrowed string: the lookup goes through
/// `Borrow<str>`, so callers are not limited to `'static` literals.
pub fn get_cached_expression(key: &str) -> Option<Expression> {
    COMMON_EXPRESSIONS.with(|cache| cache.borrow().get(key).cloned())
}
/// Builds a function expression named `"{base}_{suffix}"` applied to `args`,
/// obtaining the name through `get_cached_function_name`.
pub fn build_cached_function(base: &str, suffix: &str, args: Vec<Expression>) -> Expression {
    Expression::function(get_cached_function_name(base, suffix), args)
}
/// Empties this thread's function-name cache and expression-list scratch
/// buffer.
///
/// `COMMON_EXPRESSIONS` is deliberately left intact. It is a fixed table of
/// constants, seeded once when the thread-local is first initialised, and
/// nothing in this module repopulates it. Clearing it would make every later
/// `get_cached_expression` lookup on this thread return `None`.
pub fn clear_caches() {
    FUNCTION_NAME_CACHE.with(|cache| cache.borrow_mut().clear());
    EXPR_LIST_BUFFER.with(|buffer| buffer.borrow_mut().clear());
}
/// Point-in-time sizes of this thread's caches, as reported by
/// `get_cache_stats`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly, and
/// `Default` for an all-zero baseline.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CacheStats {
    /// Number of entries in the function-name cache.
    pub function_name_cache_size: usize,
    /// Allocated capacity (not length) of the expression-list scratch buffer.
    pub expr_list_buffer_capacity: usize,
    /// Number of entries in the common-expression table.
    pub common_expressions_size: usize,
}
/// Takes a snapshot of this thread's cache sizes: function-name cache entry
/// count, scratch-buffer capacity, and common-expression table entry count.
pub fn get_cache_stats() -> CacheStats {
    // Struct-literal fields are evaluated in the order written, matching the
    // order the three thread-locals are read.
    CacheStats {
        function_name_cache_size: FUNCTION_NAME_CACHE.with(|names| names.borrow().len()),
        expr_list_buffer_capacity: EXPR_LIST_BUFFER.with(|scratch| scratch.borrow().capacity()),
        common_expressions_size: COMMON_EXPRESSIONS.with(|table| table.borrow().len()),
    }
}
// Unit tests for the thread-local helpers in this module. The tests read and
// mutate thread-local state and assume each test starts with freshly
// initialised thread-locals (libtest normally runs each test on its own
// thread). NOTE(review): confirm this still holds if the suite is run with
// `--test-threads=1`.
#[cfg(test)]
mod tests {
use super::*;
// Two requests for the same (base, suffix) must yield the same joined name
// and leave exactly one entry in the name cache. `clear_caches` runs first so
// the size assertion counts only the entry inserted here.
#[test]
fn test_function_name_caching() {
clear_caches();
let name1 = get_cached_function_name("test", "function");
let name2 = get_cached_function_name("test", "function");
assert_eq!(name1, name2);
assert_eq!(name1, "test_function");
let stats = get_cache_stats();
assert_eq!(stats.function_name_cache_size, 1);
}
// `build_expr_list` must return the input elements unchanged and in order.
#[test]
fn test_expr_list_building() {
let exprs = vec![
Expression::integer(1),
Expression::integer(2),
Expression::integer(3),
];
let result = build_expr_list(exprs.clone());
assert_eq!(result.len(), 3);
assert_eq!(result, exprs);
}
// Seeded names ("0", "pi") resolve and unknown names yield `None`. This test
// does not call `clear_caches`, so it relies on this thread's
// COMMON_EXPRESSIONS table still holding its initial entries.
#[test]
fn test_cached_expressions() {
let zero = get_cached_expression("0");
let pi = get_cached_expression("pi");
let unknown = get_cached_expression("unknown");
assert!(zero.is_some());
assert!(pi.is_some());
assert!(unknown.is_none());
}
// The built expression should be an `Expression::Function` named "bessel_j"
// carrying the given arguments. The exact field types (compared via
// `as_ref()` and a dereference) are defined in `crate::core`.
#[test]
fn test_cached_function_building() {
let args = vec![Expression::integer(1), Expression::symbol("x")];
let func = build_cached_function("bessel", "j", args.clone());
match func {
Expression::Function {
name,
args: func_args,
} => {
assert_eq!(name.as_ref(), "bessel_j");
assert_eq!(*func_args, args);
}
_ => panic!("Expected function expression"),
}
}
// Populates the name cache (and the list buffer), then checks that
// `clear_caches` empties the name cache. Only `function_name_cache_size` is
// asserted after clearing; the buffer and constant table are not checked.
#[test]
fn test_cache_clearing() {
let _name = get_cached_function_name("test", "clear");
let _expr = build_expr_list(vec![Expression::integer(1)]);
let stats_before = get_cache_stats();
assert!(stats_before.function_name_cache_size > 0);
clear_caches();
let stats_after = get_cache_stats();
assert_eq!(stats_after.function_name_cache_size, 0);
}
}