use std::{collections::HashMap, sync::Arc};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_core::ROW_ID;
use lance_datafusion::exec::SessionContextExt;
use crate::{
Error, Result, Table,
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
connect,
database::{CreateTableRequest, Database},
dataloader::permutation::{
shuffle::{Shuffler, ShufflerConfig},
split::{SPLIT_ID_COLUMN, SplitStrategy, Splitter},
util::{TemporaryDirectory, rename_column},
},
query::{ExecutableQuery, QueryBase, Select},
};
/// Name of the column in the permutation table that holds the source row id.
///
/// `PermutationBuilder::build` renames the `ROW_ID` column to this name before
/// the table is created.
pub const SRC_ROW_ID_COL: &str = "row_id";
/// Schema metadata key under which the JSON-serialized split names are stored
/// when split names are configured on the builder.
pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
/// Memory limit (100 * 1024 * 1024) handed to the DataFusion runtime used for
/// sorting by split id when `LANCEDB_PERM_BUILDER_MEMORY_LIMIT` is unset or
/// cannot be parsed as a `usize`.
///
/// NOTE(review): presumably expressed in bytes (i.e. 100 MiB) — confirm against
/// `RuntimeEnvBuilder::with_memory_limit`.
pub const DEFAULT_MEMORY_LIMIT: usize = 100 * 1024 * 1024;
/// Where `PermutationBuilder::build` creates the permutation table.
#[derive(Debug, Clone, Default)]
enum PermutationDestination {
/// Default destination: `build` opens a connection to `memory:///` and
/// creates a table named `permutation` in that connection's database.
#[default]
Temporary,
/// Create the table in the given database under the given table name.
Permanent(Arc<dyn Database>, String),
}
/// Settings collected by `PermutationBuilder` and consumed by its `build` method.
#[derive(Debug, Default)]
pub struct PermutationConfig {
// Strategy handed to the `Splitter` that assigns rows to splits.
split_strategy: SplitStrategy,
// Optional split names; when present they are written to the output schema's
// metadata under `SPLIT_NAMES_CONFIG_KEY`.
split_names: Option<Vec<String>>,
// Whether (and how) the split data is shuffled before being written.
shuffle_strategy: ShuffleStrategy,
// Optional filter applied to the base table query (via `only_if`) and to the
// row count used to size the split/shuffle steps.
filter: Option<String>,
// Temporary directory shared by the splitter, the shuffler, and the
// DataFusion disk manager used while sorting.
temp_dir: TemporaryDirectory,
// Where the resulting permutation table is created.
destination: PermutationDestination,
}
/// How the split data is shuffled when building a permutation.
#[derive(Debug, Clone, Default)]
pub enum ShuffleStrategy {
/// Shuffle the data with a `Shuffler`; both fields are forwarded unchanged
/// into its `ShufflerConfig`.
Random {
// Forwarded as `ShufflerConfig::seed`.
// NOTE(review): presumably makes the shuffle reproducible — confirm in `Shuffler`.
seed: Option<u64>,
// Forwarded as `ShufflerConfig::clump_size`; semantics are defined by `Shuffler`.
clump_size: Option<u64>,
},
/// Default: the split data is passed through without shuffling.
#[default]
None,
}
/// Builder that produces a permutation table for a base table.
///
/// The resulting table stores the source row ids (column `SRC_ROW_ID_COL`)
/// together with the split assignment produced by the configured split strategy.
pub struct PermutationBuilder {
// Accumulated settings; starts out as `PermutationConfig::default()`.
config: PermutationConfig,
// Table whose rows are being permuted.
base_table: Table,
}
impl PermutationBuilder {
/// Creates a builder for `base_table` using the default configuration.
pub fn new(base_table: Table) -> Self {
    let config = PermutationConfig::default();
    Self { base_table, config }
}
/// Sets the split strategy and the optional split names, replacing any
/// previously configured values for both.
pub fn with_split_strategy(
    mut self,
    split_strategy: SplitStrategy,
    split_names: Option<Vec<String>>,
) -> Self {
    let config = &mut self.config;
    config.split_strategy = split_strategy;
    config.split_names = split_names;
    self
}
pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
self.config.shuffle_strategy = shuffle_strategy;
self
}
pub fn with_filter(mut self, filter: String) -> Self {
self.config.filter = Some(filter);
self
}
/// Sets the temporary directory used by the split, shuffle, and sort steps.
pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
    // The previously configured directory handle is dropped right away.
    drop(std::mem::replace(&mut self.config.temp_dir, temp_dir));
    self
}
/// Requests that the permutation be created as table `table_name` in
/// `database` instead of the default temporary destination.
pub fn persist(mut self, database: Arc<dyn Database>, table_name: String) -> Self {
    let destination = PermutationDestination::Permanent(database, table_name);
    self.config.destination = destination;
    self
}
/// Sorts `data` by the `SPLIT_ID_COLUMN` column using a dedicated DataFusion session.
///
/// The session's runtime is given a memory limit read from the
/// `LANCEDB_PERM_BUILDER_MEMORY_LIMIT` environment variable (falling back to
/// `DEFAULT_MEMORY_LIMIT` when the variable is unset or not a valid `usize`)
/// and a disk manager whose mode is derived from the configured temp dir.
///
/// # Errors
///
/// Returns `Error::Other` if the runtime cannot be built, if the one-shot
/// read cannot be set up, or if planning or starting the sort fails. Errors
/// raised while the sorted stream is being consumed are surfaced as
/// `Error::Other` items on the returned stream.
async fn sort_by_split_id(
    &self,
    data: SendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
    // An unset variable silently uses the default; a set-but-invalid value is
    // logged before falling back so misconfiguration is visible.
    let memory_limit = match std::env::var("LANCEDB_PERM_BUILDER_MEMORY_LIMIT") {
        Ok(raw) => raw.parse::<usize>().unwrap_or_else(|_| {
            log::error!(
                "Failed to parse LANCEDB_PERM_BUILDER_MEMORY_LIMIT, using default: {}",
                DEFAULT_MEMORY_LIMIT
            );
            DEFAULT_MEMORY_LIMIT
        }),
        Err(_) => DEFAULT_MEMORY_LIMIT,
    };

    // Building the runtime can fail; report that as an error instead of
    // panicking, like every other DataFusion failure in this function.
    let runtime = RuntimeEnvBuilder::new()
        .with_memory_limit(memory_limit, 1.0)
        .with_disk_manager_builder(
            DiskManagerBuilder::default()
                .with_mode(self.config.temp_dir.to_disk_manager_mode()),
        )
        .build_arc()
        .map_err(|e| Error::Other {
            message: format!("Failed to build runtime for sort by split id: {}", e),
            source: Some(e.into()),
        })?;
    let ctx = SessionContext::new_with_config_rt(SessionConfig::default(), runtime);

    let df = ctx
        .read_one_shot(data.into_df_stream())
        .map_err(|e| Error::Other {
            message: format!("Failed to setup sort by split id: {}", e),
            source: Some(e.into()),
        })?;
    let df_stream = df
        .sort_by(vec![col(SPLIT_ID_COLUMN)])
        .map_err(|e| Error::Other {
            message: format!("Failed to plan sort by split id: {}", e),
            source: Some(e.into()),
        })?
        .execute_stream()
        .await
        .map_err(|e| Error::Other {
            message: format!("Failed to sort by split id: {}", e),
            source: Some(e.into()),
        })?;

    // Capture the schema before `map_err` wraps the stream in an adapter.
    let schema = df_stream.schema();
    let stream = df_stream.map_err(|e| Error::Other {
        message: format!("Failed to execute sort by split id: {}", e),
        source: Some(e.into()),
    });
    Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
}
fn add_split_names(
data: SendableRecordBatchStream,
split_names: &[String],
) -> Result<SendableRecordBatchStream> {
let schema = data
.schema()
.as_ref()
.clone()
.with_metadata(HashMap::from([(
SPLIT_NAMES_CONFIG_KEY.to_string(),
serde_json::to_string(split_names).map_err(|e| Error::Other {
message: format!("Failed to serialize split names: {}", e),
source: Some(e.into()),
})?,
)]));
let schema = Arc::new(schema);
let schema_clone = schema.clone();
let stream = data.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap());
Ok(Box::pin(SimpleRecordBatchStream {
schema: schema_clone,
stream,
}))
}
/// Builds the permutation table described by this builder.
///
/// Steps, in order:
/// 1. Query the base table for its `ROW_ID` column, applying the optional filter.
/// 2. Let the `Splitter` project the query and assign rows to splits.
/// 3. Optionally shuffle the split data with a `Shuffler`.
/// 4. Sort by split id when the splitter does not already order by split id
///    or when a shuffle was performed.
/// 5. Rename `ROW_ID` to `SRC_ROW_ID_COL` and, if split names were given,
///    attach them to the schema metadata.
/// 6. Create the table at the configured destination: a table named
///    `permutation` in a `memory:///` connection by default, or the
///    database/table passed to `persist`.
///
/// # Errors
///
/// Propagates any error returned by the query, split, shuffle, sort, rename,
/// connection, or table-creation steps.
pub async fn build(self) -> Result<Table> {
    let row_id_query = self.base_table.query().select(Select::columns(&[ROW_ID]));
    let filtered_query = match &self.config.filter {
        Some(filter) => row_id_query.only_if(filter),
        None => row_id_query,
    };

    let splitter = Splitter::new(
        self.config.temp_dir.clone(),
        self.config.split_strategy.clone(),
    );
    let already_ordered = splitter.orders_by_split_id();
    let projected_query = splitter.project(filtered_query);

    // The row count must use the same filter as the query above.
    let count_filter = self.config.filter.clone();
    let num_rows = self.base_table.count_rows(count_filter).await? as u64;

    let row_stream = projected_query.execute().await?;
    let split_data = splitter.apply(row_stream, num_rows).await?;

    let shuffled = match &self.config.shuffle_strategy {
        ShuffleStrategy::Random { seed, clump_size } => {
            let shuffler = Shuffler::new(ShufflerConfig {
                seed: *seed,
                clump_size: *clump_size,
                temp_dir: self.config.temp_dir.clone(),
                max_rows_per_file: 10 * 1024 * 1024,
            });
            shuffler.shuffle(split_data, num_rows).await?
        }
        ShuffleStrategy::None => split_data,
    };

    // A shuffle always forces a sort, even when the splitter's own output was
    // already ordered by split id.
    let was_shuffled = !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
    let needs_sort = !already_ordered || was_shuffled;
    let sorted = if needs_sort {
        self.sort_by_split_id(shuffled).await?
    } else {
        shuffled
    };

    let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;
    let streaming_data = match self.config.split_names.as_deref() {
        Some(split_names) => Self::add_split_names(renamed, split_names)?,
        None => renamed,
    };

    let (database, name) = match &self.config.destination {
        PermutationDestination::Temporary => {
            let conn = connect("memory:///").execute().await?;
            (conn.database().clone(), "permutation")
        }
        PermutationDestination::Permanent(database, table_name) => {
            (database.clone(), table_name.as_str())
        }
    };

    let request = CreateTableRequest::new(name.to_string(), Box::new(streaming_data));
    let inner = database.create_table(request).await?;
    Ok(Table::new(inner, database))
}
}
#[cfg(test)]
mod tests {
use arrow::datatypes::Int32Type;
use lance_datagen::{BatchCount, RowCount};
use crate::{arrow::LanceDbDatagenExt, connect, dataloader::permutation::split::SplitSizes};
use super::*;
/// The permutation table must expose exactly the `row_id` and `split_id`
/// columns, and none of the base table's data columns.
#[tokio::test]
async fn test_permutation_table_only_stores_row_id_and_split_id() {
    let workdir = tempfile::tempdir().unwrap();
    let uri = workdir.path().to_str().unwrap();
    let db = connect(uri).execute().await.unwrap();

    // Base table with two data columns: 10 batches of 100 rows each.
    let source = lance_datagen::gen_batch()
        .col("col_a", lance_datagen::array::step::<Int32Type>())
        .col("col_b", lance_datagen::array::step::<Int32Type>())
        .into_ldb_stream(RowCount::from(100), BatchCount::from(10));
    let data_table = db
        .create_table("base_tbl", source)
        .execute()
        .await
        .unwrap();

    let strategy = SplitStrategy::Sequential {
        sizes: SplitSizes::Percentages(vec![0.5, 0.5]),
    };
    let builder = PermutationBuilder::new(data_table.clone())
        .with_split_strategy(strategy, None)
        .with_filter("col_a > 57".to_string());
    let permutation_table = builder.build().await.unwrap();

    let schema = permutation_table.schema().await.unwrap();
    let field_names: Vec<&str> = schema
        .fields()
        .iter()
        .map(|field| field.name().as_str())
        .collect();
    assert_eq!(
        field_names,
        vec!["row_id", "split_id"],
        "Permutation table should only contain row_id and split_id columns, but found: {:?}",
        field_names,
    );
}
/// Builds a filtered, randomly split permutation and checks the total row
/// count as well as the per-split row counts.
#[tokio::test]
async fn test_permutation_builder() {
    let workdir = tempfile::tempdir().unwrap();
    let uri = workdir.path().to_str().unwrap();
    let db = connect(uri).execute().await.unwrap();

    // Base table with a single stepping column: 10 batches of 100 rows each.
    let source = lance_datagen::gen_batch()
        .col("some_value", lance_datagen::array::step::<Int32Type>())
        .into_ldb_stream(RowCount::from(100), BatchCount::from(10));
    let data_table = db.create_table("mytbl", source).execute().await.unwrap();

    let strategy = SplitStrategy::Random {
        seed: Some(42),
        sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
    };
    let builder = PermutationBuilder::new(data_table.clone())
        .with_filter("some_value > 57".to_string())
        .with_split_strategy(strategy, None);
    let permutation_table = builder.build().await.unwrap();

    assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);

    // Expected number of rows in each split.
    let expected_per_split = [("split_id = 0", 47), ("split_id = 1", 283)];
    for (predicate, expected) in expected_per_split {
        let actual = permutation_table
            .count_rows(Some(predicate.to_string()))
            .await
            .unwrap();
        assert_eq!(actual, expected);
    }
}
}