use std::iter::repeat_n;
use std::sync::{Arc, LazyLock, Mutex};
use arrow::array::{Array, ArrayRef, PrimitiveArray, new_null_array};
use arrow::datatypes::{
ArrowDictionaryKeyType, DataType, Int8Type, Int16Type, Int32Type, Int64Type,
UInt8Type, UInt16Type, UInt32Type, UInt64Type,
};
/// Largest `num_rows` for which an array is kept in a cache.
///
/// This is a row count, not a byte size: the cache methods compare it
/// directly against the requested `num_rows` and build (but do not store)
/// any array with more rows than this.
const MAX_CACHE_SIZE: usize = 1024 * 1024;
/// Single-slot cache for a constant dictionary key array of key type `K`.
///
/// At most one array is remembered at a time, together with the parameters
/// it was built from.
#[derive(Debug)]
struct KeyArrayCache<K: ArrowDictionaryKeyType> {
// `(num_rows, is_null, array)` of the most recently cached array, or `None`
// if nothing has been cached yet.
cache: Option<(usize, bool, PrimitiveArray<K>)>, }
// Written by hand: `#[derive(Default)]` would add a `K: Default` bound to the
// generated impl.
impl<K: ArrowDictionaryKeyType> Default for KeyArrayCache<K> {
    /// Creates a cache whose slot is empty.
    fn default() -> Self {
        let cache = None;
        Self { cache }
    }
}
impl<K: ArrowDictionaryKeyType> KeyArrayCache<K> {
    /// Returns a key array with `num_rows` entries, all null when `is_null`
    /// is set and all `K::default_value()` otherwise.
    ///
    /// A request that matches the cached `(num_rows, is_null)` pair gets a
    /// clone of the cached array. Any other request of at most
    /// `MAX_CACHE_SIZE` rows builds a new array and replaces the cached one.
    /// Larger requests are built on demand and never stored.
    fn get_or_create(&mut self, num_rows: usize, is_null: bool) -> PrimitiveArray<K> {
        let cacheable = num_rows <= MAX_CACHE_SIZE;

        if cacheable {
            if let Some((rows, null_flag, keys)) = &self.cache {
                if (*rows, *null_flag) == (num_rows, is_null) {
                    return keys.clone();
                }
            }
        }

        let fresh = self.create_key_array(num_rows, is_null);
        if cacheable {
            // Overwrite whatever was in the slot; only one array is kept.
            self.cache = Some((num_rows, is_null, fresh.clone()));
        }
        fresh
    }

    /// Builds a new key array of `num_rows` identical entries: `None` when
    /// `is_null` is set, `Some(K::default_value())` otherwise.
    fn create_key_array(&self, num_rows: usize, is_null: bool) -> PrimitiveArray<K> {
        let fill = (!is_null).then(K::default_value);
        repeat_n(fill, num_rows).collect::<PrimitiveArray<K>>()
    }
}
/// Single-slot cache for an array of `DataType::Null`.
#[derive(Debug, Default)]
struct NullArrayCache {
// `(num_rows, array)` of the most recently cached null array, or `None` if
// nothing has been cached yet.
cache: Option<(usize, ArrayRef)>, }
impl NullArrayCache {
    /// Returns an array of `DataType::Null` with `num_rows` rows.
    ///
    /// A request whose `num_rows` matches the cached entry gets another
    /// handle to the cached array. Any other request of at most
    /// `MAX_CACHE_SIZE` rows builds a new array and replaces the cached one.
    /// Larger requests are built on demand and never stored.
    fn get_or_create(&mut self, num_rows: usize) -> ArrayRef {
        if num_rows > MAX_CACHE_SIZE {
            return new_null_array(&DataType::Null, num_rows);
        }

        if let Some((rows, array)) = &self.cache {
            if *rows == num_rows {
                return Arc::clone(array);
            }
        }

        let fresh = new_null_array(&DataType::Null, num_rows);
        self.cache = Some((num_rows, Arc::clone(&fresh)));
        fresh
    }
}
/// All array caches bundled together: one key-array cache per integer key
/// type listed below, plus one cache for null arrays.
#[derive(Debug, Default)]
struct ArrayCaches {
// Key-array caches for the signed integer key types.
cache_i8: KeyArrayCache<Int8Type>,
cache_i16: KeyArrayCache<Int16Type>,
cache_i32: KeyArrayCache<Int32Type>,
cache_i64: KeyArrayCache<Int64Type>,
// Key-array caches for the unsigned integer key types.
cache_u8: KeyArrayCache<UInt8Type>,
cache_u16: KeyArrayCache<UInt16Type>,
cache_u32: KeyArrayCache<UInt32Type>,
cache_u64: KeyArrayCache<UInt64Type>,
// Cache for arrays of `DataType::Null`.
null_cache: NullArrayCache,
}
/// Global `ArrayCaches` instance, lazily initialized to the default (empty)
/// caches and protected by a mutex.
static ARRAY_CACHES: LazyLock<Mutex<ArrayCaches>> =
LazyLock::new(|| Mutex::new(ArrayCaches::default()));
fn get_array_caches() -> &'static Mutex<ArrayCaches> {
&ARRAY_CACHES
}
/// Returns an array of `DataType::Null` with `num_rows` rows, reusing the
/// globally cached array when its length matches.
///
/// Requests larger than `MAX_CACHE_SIZE` rows are built on demand and are
/// not cached.
pub(crate) fn get_or_create_cached_null_array(num_rows: usize) -> ArrayRef {
    // A cache slot is only ever replaced by a fully built array in a single
    // assignment, so the state behind a poisoned mutex is still consistent.
    // Recover the guard instead of making every later call panic.
    let mut caches = get_array_caches()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    caches.null_cache.get_or_create(num_rows)
}
pub(crate) fn get_or_create_cached_key_array<K: ArrowDictionaryKeyType>(
num_rows: usize,
is_null: bool,
) -> PrimitiveArray<K> {
let cache = get_array_caches();
let mut caches = cache.lock().unwrap();
match K::DATA_TYPE {
DataType::Int8 => {
let array = caches.cache_i8.get_or_create(num_rows, is_null);
let array_data = array.to_data();
PrimitiveArray::<K>::from(array_data)
}
DataType::Int16 => {
let array = caches.cache_i16.get_or_create(num_rows, is_null);
let array_data = array.to_data();
PrimitiveArray::<K>::from(array_data)
}
DataType::Int32 => {
let array = caches.cache_i32.get_or_create(num_rows, is_null);
let array_data = array.to_data();
PrimitiveArray::<K>::from(array_data)
}
DataType::Int64 => {
let array = caches.cache_i64.get_or_create(num_rows, is_null);
let array_data = array.to_data();
PrimitiveArray::<K>::from(array_data)
}
DataType::UInt8 => {
let array = caches.cache_u8.get_or_create(num_rows, is_null);
let array_data = array.to_data();
PrimitiveArray::<K>::from(array_data)
}
DataType::UInt16 => {
let array = caches.cache_u16.get_or_create(num_rows, is_null);
let array_data = array.to_data();
PrimitiveArray::<K>::from(array_data)
}
DataType::UInt32 => {
let array = caches.cache_u32.get_or_create(num_rows, is_null);
let array_data = array.to_data();
PrimitiveArray::<K>::from(array_data)
}
DataType::UInt64 => {
let array = caches.cache_u64.get_or_create(num_rows, is_null);
let array_data = array.to_data();
PrimitiveArray::<K>::from(array_data)
}
_ => {
let key_array: PrimitiveArray<K> = repeat_n(
if is_null {
None
} else {
Some(K::default_value())
},
num_rows,
)
.collect();
key_array
}
}
}