use crate::{
execution::{context::SessionState, session_state::SessionStateBuilder},
object_store::{
Error, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta,
ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult,
memory::InMemory, path::Path,
},
prelude::SessionContext,
};
use futures::{FutureExt, stream::BoxStream};
use object_store::{CopyOptions, ObjectStoreExt};
use std::{
fmt::{Debug, Display, Formatter},
sync::Arc,
};
use tokio::{
sync::Barrier,
time::{Duration, timeout},
};
use url::Url;
/// Registers an in-memory object store under the `test://` URL on `ctx`.
///
/// The store is populated by [`make_test_store_and_state`] from `files`,
/// a list of `(path, size_in_bytes)` pairs; the session state produced
/// alongside the store is discarded.
///
/// # Panics
///
/// Panics if the fixed `test://` URL fails to parse or if populating the
/// store fails.
pub fn register_test_store(ctx: &SessionContext, files: &[(&str, u64)]) {
    let (store, _) = make_test_store_and_state(files);
    let test_url = Url::parse("test://").unwrap();
    ctx.register_object_store(&test_url, store);
}
/// Builds an in-memory object store pre-filled with zeroed files, together
/// with a `SessionState` built with default features.
///
/// Each `(name, size)` entry in `files` becomes an object at path `name`
/// whose contents are `size` zero bytes.
///
/// # Panics
///
/// Panics if a `put` does not complete immediately or returns an error.
pub fn make_test_store_and_state(files: &[(&str, u64)]) -> (Arc<InMemory>, SessionState) {
    let store = InMemory::new();
    files.iter().for_each(|(name, size)| {
        let location = Path::from(*name);
        let zeros = vec![0u8; *size as usize];
        // The write is driven synchronously: the first `unwrap` asserts the
        // future was ready on its first poll, the second that it succeeded.
        store
            .put(&location, zeros.into())
            .now_or_never()
            .unwrap()
            .unwrap();
    });
    let state = SessionStateBuilder::new().with_default_features().build();
    (Arc::new(store), state)
}
/// Builds an [`ObjectMeta`] describing a file on the local filesystem.
///
/// The location is derived from `path`, and the size and modification time
/// are read from the file's filesystem metadata. `e_tag` and `version` are
/// left unset.
///
/// # Panics
///
/// Panics if `path` cannot be converted to an object store path, if the
/// file's metadata cannot be read, or if its modification time is
/// unavailable.
pub fn local_unpartitioned_file(path: impl AsRef<std::path::Path>) -> ObjectMeta {
    let fs_path = path.as_ref();
    let location = Path::from_filesystem_path(fs_path).unwrap();
    let metadata = std::fs::metadata(fs_path).expect("Local file metadata");
    let modified = metadata.modified().unwrap();
    ObjectMeta {
        location,
        last_modified: chrono::DateTime::from(modified),
        size: metadata.len(),
        e_tag: None,
        version: None,
    }
}
/// Wraps `object_store` in a `BlockingObjectStore` whose barrier is sized
/// to `concurrency`.
///
/// On the returned store, every `get_opts` call with `options.head` set
/// waits on that shared barrier before being forwarded, and fails with an
/// error if the wait does not complete within one second. All other
/// operations are forwarded to `object_store` directly.
pub fn ensure_head_concurrency(
    object_store: Arc<dyn ObjectStore>,
    concurrency: usize,
) -> Arc<dyn ObjectStore> {
    let blocking = BlockingObjectStore::new(object_store, concurrency);
    Arc::new(blocking)
}
/// An [`ObjectStore`] wrapper that makes head requests wait on a shared
/// barrier before being forwarded to the wrapped store.
#[derive(Debug)]
struct BlockingObjectStore {
    /// The wrapped store that every operation is ultimately delegated to.
    inner: Arc<dyn ObjectStore>,
    /// Barrier that head requests wait on before reaching `inner`.
    barrier: Arc<Barrier>,
}
impl BlockingObjectStore {
    /// Name used in log output and as the `store` field of generated errors.
    const NAME: &'static str = "BlockingObjectStore";

    /// Creates a wrapper around `inner` with a barrier sized to
    /// `expected_concurrency`.
    fn new(inner: Arc<dyn ObjectStore>, expected_concurrency: usize) -> Self {
        let barrier = Arc::new(Barrier::new(expected_concurrency));
        Self { inner, barrier }
    }
}
impl Display for BlockingObjectStore {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.inner, f)
}
}
#[async_trait::async_trait]
impl ObjectStore for BlockingObjectStore {
async fn put_opts(
&self,
location: &Path,
payload: PutPayload,
opts: PutOptions,
) -> object_store::Result<PutResult> {
self.inner.put_opts(location, payload, opts).await
}
async fn put_multipart_opts(
&self,
location: &Path,
opts: PutMultipartOptions,
) -> object_store::Result<Box<dyn MultipartUpload>> {
self.inner.put_multipart_opts(location, opts).await
}
async fn get_opts(
&self,
location: &Path,
options: GetOptions,
) -> object_store::Result<GetResult> {
if options.head {
println!(
"{} received head call for {location}",
BlockingObjectStore::NAME
);
let wait_result = timeout(Duration::from_secs(1), self.barrier.wait()).await;
match wait_result {
Ok(_) => println!(
"{} barrier reached for {location}",
BlockingObjectStore::NAME
),
Err(_) => {
let error_message = format!(
"{} barrier wait timed out for {location}",
BlockingObjectStore::NAME
);
log::error!("{error_message}");
return Err(Error::Generic {
store: BlockingObjectStore::NAME,
source: error_message.into(),
});
}
}
}
self.inner.get_opts(location, options).await
}
fn delete_stream(
&self,
locations: BoxStream<'static, object_store::Result<Path>>,
) -> BoxStream<'static, object_store::Result<Path>> {
self.inner.delete_stream(locations)
}
fn list(
&self,
prefix: Option<&Path>,
) -> BoxStream<'static, object_store::Result<ObjectMeta>> {
self.inner.list(prefix)
}
async fn list_with_delimiter(
&self,
prefix: Option<&Path>,
) -> object_store::Result<ListResult> {
self.inner.list_with_delimiter(prefix).await
}
async fn copy_opts(
&self,
from: &Path,
to: &Path,
options: CopyOptions,
) -> object_store::Result<()> {
self.inner.copy_opts(from, to, options).await
}
}