use std::sync::{Mutex, MutexGuard, OnceLock};
use ahash::random_state::RandomState;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyString};
use crate::string_decoder::StringOutput;
/// Selects which decoded strings go through the Python string cache.
///
/// Convertible from a Python `bool` or from one of the strings `"all"`,
/// `"keys"` or `"none"`.
#[derive(Debug, Clone, Copy, Default)]
pub enum StringCacheMode {
    /// Caching enabled for everything; this is the default mode and what
    /// `true` converts to.
    #[default]
    All,
    /// Selected by the string `"keys"`.
    Keys,
    /// Caching disabled; what `false` converts to.
    None,
}
impl<'py> FromPyObject<'_, 'py> for StringCacheMode {
type Error = PyErr;
fn extract(ob: Borrowed<'_, 'py, PyAny>) -> PyResult<StringCacheMode> {
if let Ok(bool_mode) = ob.cast::<PyBool>() {
Ok(bool_mode.is_true().into())
} else if let Ok(str_mode) = ob.extract::<&str>() {
match str_mode {
"all" => Ok(Self::All),
"keys" => Ok(Self::Keys),
"none" => Ok(Self::None),
_ => Err(PyValueError::new_err(
"Invalid string cache mode, should be `'all'`, '`keys`', `'none`' or a `bool`",
)),
}
} else {
Err(PyTypeError::new_err(
"Invalid string cache mode, should be `'all'`, '`keys`', `'none`' or a `bool`",
))
}
}
}
/// Boolean shorthand for the cache mode: `true` enables caching of
/// everything, `false` disables caching.
impl From<bool> for StringCacheMode {
    fn from(mode: bool) -> Self {
        match mode {
            true => Self::All,
            false => Self::None,
        }
    }
}
/// Strategy for turning a decoded string into a Python `str`, with or without
/// going through the shared string cache.
///
/// Implementors choose separately how "key" strings and "value" strings are
/// created.
pub trait StringMaybeCache {
    /// Build the Python string for a key.
    fn get_key<'py>(py: Python<'py>, string_output: StringOutput<'_, '_>) -> Bound<'py, PyString>;

    /// Build the Python string for a value.
    ///
    /// Unless overridden, values are handled exactly like keys.
    fn get_value<'py>(
        py: Python<'py>,
        string_output: StringOutput<'_, '_>,
    ) -> Bound<'py, PyString> {
        Self::get_key(py, string_output)
    }
}
/// Strategy that sends keys through the shared string cache; values use the
/// trait's default and are therefore treated the same way.
pub struct StringCacheAll;

impl StringMaybeCache for StringCacheAll {
    fn get_key<'py>(py: Python<'py>, string_output: StringOutput<'_, '_>) -> Bound<'py, PyString> {
        let text = string_output.as_str();
        let ascii_only = string_output.ascii_only();
        // SAFETY: the ASCII flag is taken from the same `StringOutput` as the
        // text — assumes `ascii_only()` is only true for pure-ASCII text.
        unsafe { cached_py_string_maybe_ascii(py, text, ascii_only) }
    }
}
/// Strategy that sends keys through the shared string cache but always
/// creates a fresh Python string for values.
pub struct StringCacheKeys;

impl StringMaybeCache for StringCacheKeys {
    fn get_key<'py>(py: Python<'py>, string_output: StringOutput<'_, '_>) -> Bound<'py, PyString> {
        let text = string_output.as_str();
        let ascii_only = string_output.ascii_only();
        // SAFETY: the ASCII flag is taken from the same `StringOutput` as the
        // text — assumes `ascii_only()` is only true for pure-ASCII text.
        unsafe { cached_py_string_maybe_ascii(py, text, ascii_only) }
    }

    fn get_value<'py>(
        py: Python<'py>,
        string_output: StringOutput<'_, '_>,
    ) -> Bound<'py, PyString> {
        let text = string_output.as_str();
        let ascii_only = string_output.ascii_only();
        // SAFETY: same reasoning as in `get_key`.
        unsafe { pystring_fast_new_maybe_ascii(py, text, ascii_only) }
    }
}
/// Strategy that never touches the cache: every string (key or value, via the
/// trait default) is created fresh.
pub struct StringNoCache;

impl StringMaybeCache for StringNoCache {
    fn get_key<'py>(py: Python<'py>, string_output: StringOutput<'_, '_>) -> Bound<'py, PyString> {
        let text = string_output.as_str();
        let ascii_only = string_output.ascii_only();
        // SAFETY: the ASCII flag is taken from the same `StringOutput` as the
        // text — assumes `ascii_only()` is only true for pure-ASCII text.
        unsafe { pystring_fast_new_maybe_ascii(py, text, ascii_only) }
    }
}
/// Process-wide string cache, created lazily on first access.
static STRING_CACHE: OnceLock<Mutex<PyStringCache>> = OnceLock::new();

/// Lock the global string cache and return the guard.
///
/// If a previous holder of the lock panicked (the mutex is poisoned), the
/// cache contents are discarded and the poison flag is cleared, so the cache
/// keeps working normally afterwards instead of being wiped on every
/// subsequent access.
#[inline]
fn get_string_cache() -> MutexGuard<'static, PyStringCache> {
    let mutex = STRING_CACHE.get_or_init(|| Mutex::new(PyStringCache::default()));
    match mutex.lock() {
        Ok(cache) => cache,
        Err(poisoned) => {
            let mut cache = poisoned.into_inner();
            // The panicking holder may have left the table in an arbitrary
            // state; start again from an empty cache.
            cache.clear();
            // Without this the mutex stays poisoned forever and every later
            // call would land in this branch and clear the cache again.
            mutex.clear_poison();
            cache
        }
    }
}
/// Number of slots in the global string cache that currently hold a string.
pub fn cache_usage() -> usize {
    let cache = get_string_cache();
    cache.usage()
}
/// Empty every slot of the global string cache.
pub fn cache_clear() {
    let mut cache = get_string_cache();
    cache.clear();
}
/// Get a Python string for `s`, going through the global cache when the
/// length of `s` makes it eligible.
///
/// Safe to call with arbitrary text: the ASCII fast path is never requested.
#[inline]
pub fn cached_py_string<'py>(py: Python<'py>, s: &str) -> Bound<'py, PyString> {
    const ASCII_ONLY: bool = false;
    // SAFETY: with `ascii_only == false` no ASCII assumption is made about `s`.
    unsafe { cached_py_string_maybe_ascii(py, s, ASCII_ONLY) }
}
/// Get a Python string for `s` using the ASCII fast path, going through the
/// global cache when the length of `s` makes it eligible.
///
/// # Safety
///
/// `s` must consist solely of ASCII bytes; the fast path copies the bytes
/// into a Python string declared as ASCII without checking them.
#[inline]
pub unsafe fn cached_py_string_ascii<'py>(py: Python<'py>, s: &str) -> Bound<'py, PyString> {
    const ASCII_ONLY: bool = true;
    // SAFETY: the caller guarantees `s` is ASCII-only.
    unsafe { cached_py_string_maybe_ascii(py, s, ASCII_ONLY) }
}
/// Get a Python string for `s`, consulting the global cache only for strings
/// whose byte length is at least 2 and below 64; anything else is created
/// fresh.
///
/// # Safety
///
/// If `ascii_only` is true, `s` must consist solely of ASCII bytes.
unsafe fn cached_py_string_maybe_ascii<'py>(
    py: Python<'py>,
    s: &str,
    ascii_only: bool,
) -> Bound<'py, PyString> {
    match s.len() {
        // SAFETY: the caller upholds the `ascii_only` contract, which is
        // forwarded unchanged.
        2..=63 => unsafe { get_string_cache().get_or_insert(py, s, ascii_only) },
        // SAFETY: as above.
        _ => unsafe { pystring_fast_new_maybe_ascii(py, s, ascii_only) },
    }
}
/// Number of slots in the cache's fixed-size entry table.
const CAPACITY: usize = 16_384;
/// One cache slot: empty, or the hash of a string together with the Python
/// string object created for it.
type Entry = Option<(u64, Py<PyString>)>;
/// Fixed-capacity cache of Python strings, indexed by the hash of the Rust
/// string they were created from.
#[derive(Debug)]
struct PyStringCache {
/// Heap-allocated table of `CAPACITY` slots.
entries: Box<[Entry; CAPACITY]>,
/// Hasher state used to hash strings into slot indices.
hash_builder: RandomState,
}
/// Value every slot starts with: empty.
const ARRAY_REPEAT_VALUE: Entry = None;

impl Default for PyStringCache {
    /// Create a cache in which every slot is empty.
    fn default() -> Self {
        // The table is assembled in a `Vec` and then converted into a boxed
        // array (presumably to avoid materialising the whole array on the
        // stack). The conversion cannot fail because exactly `CAPACITY`
        // elements are produced.
        let slots: Vec<Entry> = (0..CAPACITY).map(|_| ARRAY_REPEAT_VALUE).collect();
        let entries: Box<[Entry; CAPACITY]> = slots.into_boxed_slice().try_into().unwrap();
        Self {
            entries,
            hash_builder: RandomState::default(),
        }
    }
}
impl PyStringCache {
    /// Return the cached Python string equal to `s`, creating and storing one
    /// if it is not present.
    ///
    /// The string's hash selects a home slot. Up to five consecutive slots
    /// starting there (without wrapping past the end of the table) are
    /// examined:
    ///
    /// * a slot with the same hash and equal contents is a hit;
    /// * the first empty slot receives a newly created string;
    /// * if neither is found, the home slot's occupant is replaced.
    ///
    /// # Safety
    ///
    /// If `ascii_only` is true, `s` must consist solely of ASCII bytes.
    unsafe fn get_or_insert<'py>(
        &mut self,
        py: Python<'py>,
        s: &str,
        ascii_only: bool,
    ) -> Bound<'py, PyString> {
        // How many consecutive slots are examined before evicting.
        const PROBE_WINDOW: usize = 5;

        let hash = self.hash_builder.hash_one(s);
        let home = hash as usize % CAPACITY;

        // Create a new Python string, store it in `slot`, and hand it back.
        let store = |slot: &mut Entry| {
            // SAFETY: the caller's `ascii_only` contract is forwarded unchanged.
            let fresh = unsafe { pystring_fast_new_maybe_ascii(py, s, ascii_only) };
            let evicted = slot.replace((hash, fresh.clone().unbind()));
            if let Some((_, old)) = evicted {
                // Re-attach the evicted string to `py` before dropping it.
                drop(old.into_bound(py));
            }
            fresh
        };

        // `home < CAPACITY`, so the window is a valid (possibly shortened)
        // range that stops at the end of the table.
        let window_end = CAPACITY.min(home + PROBE_WINDOW);
        for slot in &mut self.entries[home..window_end] {
            if let Some((entry_hash, py_str_ob)) = slot {
                // Compare hashes first; only compare contents on a hash match.
                if *entry_hash == hash && py_str_ob.bind(py) == s {
                    return py_str_ob.bind(py).to_owned();
                }
            } else {
                return store(slot);
            }
        }

        // Every probed slot is occupied by something else: overwrite the home slot.
        store(&mut self.entries[home])
    }

    /// Count the slots that currently hold a string.
    fn usage(&self) -> usize {
        self.entries.iter().flatten().count()
    }

    /// Reset every slot to empty, dropping the strings they held.
    fn clear(&mut self) {
        for slot in self.entries.iter_mut() {
            *slot = None;
        }
    }
}
/// Create a new Python string from `s`, using the ASCII fast path when
/// `ascii_only` is set and the regular constructor otherwise.
///
/// # Safety
///
/// If `ascii_only` is true, `s` must consist solely of ASCII bytes.
unsafe fn pystring_fast_new_maybe_ascii<'py>(
    py: Python<'py>,
    s: &str,
    ascii_only: bool,
) -> Bound<'py, PyString> {
    if !ascii_only {
        return PyString::new(py, s);
    }
    // SAFETY: the caller guarantees `s` is ASCII-only when `ascii_only` is set.
    unsafe { pystring_ascii_new(py, s) }
}
/// Create a Python string from text that is known to be ASCII.
///
/// When not built for PyPy, GraalPy or the limited API, the string object is
/// allocated with a maximum code point of 127 and the bytes of `s` are copied
/// straight into its buffer. On the other builds this falls back to
/// `PyString::new`.
///
/// # Safety
///
/// `s` must consist solely of ASCII bytes: on the fast path the bytes are
/// copied unchecked into a string object declared to hold only code points
/// up to 127.
///
/// # Panics
///
/// Panics if Python fails to allocate the string object.
pub unsafe fn pystring_ascii_new<'py>(py: Python<'py>, s: &str) -> Bound<'py, PyString> {
    unsafe {
        #[cfg(not(any(PyPy, GraalPy, Py_LIMITED_API)))]
        {
            let ptr = pyo3::ffi::PyUnicode_New(s.len() as isize, 127);
            // Allocation failure returns null; bail out before the pointer is
            // dereferenced below rather than after the copy.
            if ptr.is_null() {
                pyo3::err::panic_after_error(py);
            }
            // A max code point of 127 should yield the one-byte representation.
            debug_assert_eq!(pyo3::ffi::PyUnicode_KIND(ptr), pyo3::ffi::PyUnicode_1BYTE_KIND);
            let data_ptr = pyo3::ffi::PyUnicode_DATA(ptr).cast();
            core::ptr::copy_nonoverlapping(s.as_ptr(), data_ptr, s.len());
            // Write the trailing NUL byte after the copied data.
            core::ptr::write(data_ptr.add(s.len()), 0);
            Bound::from_owned_ptr(py, ptr).cast_into_unchecked()
        }
        #[cfg(any(PyPy, GraalPy, Py_LIMITED_API))]
        {
            PyString::new(py, s)
        }
    }
}