use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;
use arrow_array::RecordBatch;
use arrow_schema::Schema as ArrowSchema;
use object_store::path::Path;
use snafu::{location, Location};
use tempfile::TempDir;
use crate::datatypes::Schema;
use crate::io::reader::batches_stream;
use crate::io::FileReader;
use crate::{
error::{Error, Result},
io::{FileWriter, ObjectStore, RecordBatchStream},
};
/// File name of the Lance file, created inside the shuffler's temporary directory,
/// that receives every flushed group of batches.
const BUFFER_FILE_NAME: &str = "buffer.lance";
/// Groups record batches by a `u32` key and spills them to a single Lance buffer file.
///
/// Batches are held in memory per key until that key has buffered at least
/// `flush_size` rows, at which point they are written out as one batch of the buffer
/// file. [`ShufflerBuilder::finish`] writes whatever is left and returns a
/// [`Shuffler`] that can stream the batches of one key back.
// NOTE(review): field declaration order is also drop order — the `temp_dir` Arc is
// released before `writer` is dropped. Keep that in mind before reordering fields.
#[allow(dead_code)]
pub struct ShufflerBuilder {
    /// Batches buffered in memory per key that have not been written yet.
    buffer: BTreeMap<u32, Vec<RecordBatch>>,
    /// Row count at which a key's buffered batches are written to the buffer file.
    flush_size: usize,
    /// For each key, the ids of the buffer-file batches that hold its rows.
    parted_groups: BTreeMap<u32, Vec<u32>>,
    /// Temporary directory containing the buffer file; handed on to the `Shuffler`.
    temp_dir: Arc<TempDir>,
    /// Writer for the buffer file.
    writer: FileWriter,
}
fn lance_buffer_path(dir: &TempDir) -> Result<Path> {
let tmp_dir_path = Path::from_filesystem_path(dir.path()).map_err(|e| Error::IO {
message: format!("failed to get buffer path in shuffler: {}", e),
location: location!(),
})?;
Ok(tmp_dir_path.child(BUFFER_FILE_NAME))
}
impl ShufflerBuilder {
#[allow(dead_code)]
pub async fn try_new(schema: &ArrowSchema, flush_threshold: usize) -> Result<Self> {
let temp_dir = Arc::new(tempfile::tempdir()?);
let object_store = ObjectStore::local();
let path = lance_buffer_path(&temp_dir)?;
let writer = object_store.create(&path).await?;
let lance_schema = Schema::try_from(schema)?;
Ok(Self {
buffer: BTreeMap::new(),
flush_size: flush_threshold, temp_dir,
parted_groups: BTreeMap::new(),
writer: FileWriter::with_object_writer(writer, lance_schema, &Default::default())?,
})
}
#[allow(dead_code)]
pub async fn insert(&mut self, key: u32, batch: RecordBatch) -> Result<()> {
let batches = self.buffer.entry(key).or_default();
batches.push(batch);
let total = batches.iter().map(|b| b.num_rows()).sum::<usize>();
if total >= self.flush_size {
self.parted_groups
.entry(key)
.or_default()
.push(self.writer.next_batch_id() as u32);
self.writer.write(batches).await?;
batches.clear();
};
Ok(())
}
#[allow(dead_code)]
pub async fn finish(&mut self) -> Result<Shuffler> {
for (key, batches) in self.buffer.iter() {
if !batches.is_empty() {
self.parted_groups
.entry(*key)
.or_default()
.push(self.writer.next_batch_id() as u32);
self.writer.write(batches.as_slice()).await?;
}
}
self.writer.finish().await?;
Ok(Shuffler::new(&self.parted_groups, self.temp_dir.clone()))
}
}
/// Read side of the shuffle: streams back the batches that were written for one key.
pub struct Shuffler {
    /// Shared handle on the temporary directory holding the buffer file, kept so the
    /// directory outlives this reader.
    temp_dir: Arc<TempDir>,
    /// Maps each key to the ids of the buffer-file batches that contain its rows.
    parted_groups: BTreeMap<u32, Vec<u32>>,
}
impl Shuffler {
    /// Creates a reader over the finished buffer file in `temp_dir`, using
    /// `parted_groups` (copied) to locate each key's batches.
    fn new(parted_groups: &BTreeMap<u32, Vec<u32>>, temp_dir: Arc<TempDir>) -> Self {
        Self {
            parted_groups: parted_groups.clone(),
            temp_dir,
        }
    }

    /// Streams every batch that was written for `key`.
    ///
    /// Returns `Ok(None)` when no batch was ever written for `key`; in that case the
    /// buffer file is not opened.
    ///
    /// # Errors
    /// Returns [`Error::IO`] if the buffer path cannot be built or the buffer file
    /// cannot be opened.
    pub async fn key_iter(&self, key: u32) -> Result<Option<impl RecordBatchStream + '_>> {
        // A single lookup decides both "unknown key" and which batch ids to keep.
        let group_ids: HashSet<u32> = match self.parted_groups.get(&key) {
            Some(ids) => ids.iter().copied().collect(),
            None => return Ok(None),
        };
        let object_store = ObjectStore::local();
        let path = lance_buffer_path(self.temp_dir.as_ref())?;
        let reader = FileReader::try_new(&object_store, &path)
            .await
            .map_err(|e| Error::IO {
                message: format!("failed to open shuffler buffer file: {}, {}", path, e),
                location: location!(),
            })?;
        let schema = reader.schema().clone();
        // Only the batches recorded for `key` are read from the buffer file.
        let stream = batches_stream(reader, schema, move |id| group_ids.contains(&(*id as u32)));
        Ok(Some(stream))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use arrow_array::{Array, UInt32Array};
    use arrow_schema::{DataType, Field, Schema};
    use futures::TryStreamExt;

    /// Builds a one-column, one-row batch holding `value`.
    fn single_row_batch(schema: &Arc<Schema>, value: u32) -> RecordBatch {
        RecordBatch::try_new(schema.clone(), vec![Arc::new(UInt32Array::from(vec![value]))])
            .unwrap()
    }

    #[tokio::test]
    async fn test_shuffler() {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
        let mut shuffler = ShufflerBuilder::try_new(&schema, 4).await.unwrap();
        let schema = Arc::new(schema);
        for i in 0..20 {
            shuffler
                .insert(i % 3, single_row_batch(&schema, i))
                .await
                .unwrap();
        }
        let reader = shuffler.finish().await.unwrap();
        for key in 0..3 {
            let stream = reader.key_iter(key).await.unwrap().expect("key exists");
            let batches = stream.try_collect::<Vec<_>>().await.unwrap();
            // Each key gets 6-7 one-row inserts with a flush threshold of 4: one flush
            // during `insert` and one for the remainder in `finish`.
            assert_eq!(batches.len(), 2, "key {} has {} batches", key, batches.len());

            // The stream must contain exactly the rows inserted under `key`.
            let mut values: Vec<u32> = batches
                .iter()
                .flat_map(|batch| {
                    batch
                        .column(0)
                        .as_any()
                        .downcast_ref::<UInt32Array>()
                        .expect("column `a` is UInt32")
                        .values()
                        .iter()
                        .copied()
                })
                .collect();
            values.sort_unstable();
            let expected: Vec<u32> = (0..20).filter(|v| v % 3 == key).collect();
            assert_eq!(values, expected, "rows returned for key {}", key);
        }
        assert!(reader.key_iter(5).await.unwrap().is_none())
    }
}