use std::future::Future;
use std::sync::Arc;
use std::sync::OnceLock;
use async_lock::Mutex as AsyncMutex;
use vortex_error::SharedVortexResult;
use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::Canonical;
use crate::IntoArray;
use crate::dtype::DType;
use crate::stats::ArrayStats;
/// An array wrapper that pairs a source array with a write-once cache for a
/// result computed from that source.
///
/// `cached` and `async_compute_lock` are held behind `Arc`, so values produced
/// by the derived `Clone` share the same cache slot and the same async lock.
#[derive(Debug, Clone)]
pub struct SharedArray {
/// The wrapped array; passed to the compute closures and returned by
/// `current_array_ref` while no successful result is cached.
source: ArrayRef,
/// Write-once slot holding the outcome of the first computation: either the
/// resulting array or the error it produced (stored behind `Arc`).
cached: Arc<OnceLock<SharedVortexResult<ArrayRef>>>,
/// Async mutex held by `get_or_compute_async` while its computation is awaited.
async_compute_lock: Arc<AsyncMutex<()>>,
/// Copy of the source's dtype, captured in `new` and refreshed in `set_source`.
pub(super) dtype: DType,
/// Statistics container; starts as `ArrayStats::default()` in `new` and is
/// not modified by `set_source`.
pub(super) stats: ArrayStats,
}
impl SharedArray {
pub fn new(source: ArrayRef) -> Self {
Self {
dtype: source.dtype().clone(),
source,
cached: Arc::new(OnceLock::new()),
async_compute_lock: Arc::new(AsyncMutex::new(())),
stats: ArrayStats::default(),
}
}
pub(super) fn current_array_ref(&self) -> &ArrayRef {
match self.cached.get() {
Some(Ok(arr)) => arr,
_ => &self.source,
}
}
pub fn get_or_compute(
&self,
f: impl FnOnce(&ArrayRef) -> VortexResult<Canonical>,
) -> VortexResult<ArrayRef> {
let result = self
.cached
.get_or_init(|| f(&self.source).map(|c| c.into_array()).map_err(Arc::new));
result.clone().map_err(Into::into)
}
pub async fn get_or_compute_async<F, Fut>(&self, f: F) -> VortexResult<ArrayRef>
where
F: FnOnce(ArrayRef) -> Fut,
Fut: Future<Output = VortexResult<Canonical>>,
{
if let Some(result) = self.cached.get() {
return result.clone().map_err(Into::into);
}
let _guard = self.async_compute_lock.lock().await;
if let Some(result) = self.cached.get() {
return result.clone().map_err(Into::into);
}
let computed = f(self.source.clone())
.await
.map(|c| c.into_array())
.map_err(Arc::new);
let result = self.cached.get_or_init(|| computed);
result.clone().map_err(Into::into)
}
pub(super) fn set_source(&mut self, source: ArrayRef) {
self.dtype = source.dtype().clone();
self.source = source;
self.cached = Arc::new(OnceLock::new());
self.async_compute_lock = Arc::new(AsyncMutex::new(()));
}
}