use std::sync::Arc;
use arrow_array::{cast::AsArray, FixedSizeListArray};
use arrow_schema::Schema as ArrowSchema;
use arrow_select::concat::concat_batches;
use futures::stream::TryStreamExt;
use snafu::{location, Location};
use crate::dataset::Dataset;
use crate::{Error, Result};
pub async fn maybe_sample_training_data(
dataset: &Dataset,
column: &str,
sample_size_hint: usize,
) -> Result<FixedSizeListArray> {
let num_rows = dataset.count_rows().await?;
let projection = dataset.schema().project(&[column])?;
let batch = if num_rows > sample_size_hint {
dataset.sample(sample_size_hint, &projection).await?
} else {
let mut scanner = dataset.scan();
scanner.project(&[column])?;
let batches = scanner
.try_into_stream()
.await?
.try_collect::<Vec<_>>()
.await?;
concat_batches(&Arc::new(ArrowSchema::from(&projection)), &batches)?
};
let array = batch.column_by_name(column).ok_or(Error::Index {
message: format!(
"Sample training data: column {} does not exist in return",
column
),
location: location!(),
})?;
Ok(array.as_fixed_size_list().clone())
}