use crate::backend::ScalarType;
use crate::data::utils::{array_major_minor_index_default, cs_major_minor_index2};
use crate::data::{DataFrameIndex, DynCsrMatrix};
use crate::{AnnDataOp, ArrayElemOp};
use anyhow::{ensure, Result};
use indexmap::IndexSet;
use itertools::Itertools;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::pattern::SparsityPattern;
use polars::chunked_array::builder::CategoricalChunkedBuilder;
use polars::frame::DataFrame;
use polars::prelude::{AnyValue, Categorical32Type, Column, DataType, IntoLazy, NamedFrom};
use polars::series::{IntoSeries, Series};
use crate::data::{ArrayData, DynArray};
/// Strategy used by [`concat`] to reconcile the variables of its inputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Keep only the variables present in every input (set intersection).
    Inner,
    /// Keep every variable present in at least one input (set union).
    Outer,
}
pub fn concat<A, O, S>(
adatas: &[A],
join: JoinType,
label: Option<&str>,
keys: Option<&[S]>,
out: &O,
) -> Result<()>
where
A: AnnDataOp,
O: AnnDataOp,
S: ToString,
{
let common_vars = adatas
.iter()
.map(|x| x.var_names().into_iter().collect::<IndexSet<_>>());
let common_vars: IndexSet<String> = match join {
JoinType::Inner => common_vars.reduce(|a, b| a.intersection(&b).cloned().collect()),
JoinType::Outer => common_vars.reduce(|a, b| a.union(&b).cloned().collect()),
}
.unwrap();
out.set_var_names(common_vars.iter().cloned().collect())?;
{
let df_var = adatas
.iter()
.map(|adata| {
let var = adata.read_var().unwrap();
let var_names = adata.var_names();
let columns = var
.get_columns()
.iter()
.map(|s| align_series(s, &var_names, &common_vars))
.collect::<Result<Vec<_>>>()?;
Ok(DataFrame::new(columns)?)
})
.reduce(|a, b| {
let mut a = a?;
merge_df(&mut a, &b?)?;
anyhow::Ok(a)
})
.unwrap()?;
out.set_var(df_var)?;
}
{
let obs_names = adatas.iter().flat_map(|adata| adata.obs_names()).collect();
out.set_obs_names(obs_names)?;
let mut dfs = adatas
.iter()
.map(|adata| adata.read_obs().unwrap())
.collect::<Vec<_>>();
if let Some(keys) = keys {
dfs.iter_mut().zip_eq(keys.iter()).for_each(|(df, key)| {
let s = Series::new(
label.unwrap_or("label").into(),
vec![key.to_string(); df.height()],
);
df.insert_column(0, s).unwrap();
});
}
let dfs = dfs.into_iter().map(|df| df.lazy()).collect::<Vec<_>>();
let mut args = polars::prelude::UnionArgs::default();
match join {
JoinType::Inner => args.diagonal = false,
JoinType::Outer => args.diagonal = true,
};
let df_obs = polars::prelude::concat(&dfs, args)?.collect()?;
out.set_obs(df_obs)?;
}
{
if adatas.iter().any(|adata| !adata.x().is_none()) {
let dtype = adatas
.iter()
.flat_map(|x| x.x().dtype().and_then(|d| d.scalar_type()))
.next()
.unwrap();
let x_arr = adatas.iter().map(|adata| {
let n_obs = adata.n_obs();
let n_vars = adata.n_vars();
let var_names = adata.var_names();
macro_rules! fun {
($variant:ident) => {
CsrMatrix::<$variant>::zeros(n_obs, n_vars).into()
};
}
adata
.x()
.get()
.unwrap()
.map(|arr| {
index_array(
arr,
&(0..adata.n_obs())
.into_iter()
.map(|x| Some(x))
.collect::<Vec<_>>(),
&common_vars
.iter()
.map(|x| var_names.get_index(x))
.collect::<Vec<_>>(),
)
})
.unwrap_or_else(|| crate::macros::dyn_match!(dtype, ScalarType, fun))
});
out.set_x_from_iter(x_arr)?;
}
}
Ok(())
}
fn merge_df(this: &mut DataFrame, other: &DataFrame) -> Result<()> {
if other.is_empty() {
return Ok(());
}
ensure!(
this.height() == other.height(),
"DataFrames must have the same number of rows"
);
other.get_columns().iter().try_for_each(|other_s| {
let name = other_s.name();
if let Some(i) = this.get_column_index(name) {
let this_s = this.column(name)?;
let new_column = this_s
.as_series()
.unwrap()
.iter()
.zip(other_s.as_series().unwrap().iter())
.map(|(this_v, other_v)| {
if other_v.is_null() {
this_v.clone()
} else {
other_v.clone()
}
})
.collect::<Vec<_>>();
let dtype = match (this_s.dtype(), other_s.dtype()) {
(DataType::Categorical(_, _), _) => this_s.dtype(),
(_, DataType::Categorical(_, _)) => other_s.dtype(),
_ => this_s.dtype(),
};
let new_column = match dtype {
DataType::Categorical(_, _) => {
let mut builder: CategoricalChunkedBuilder<Categorical32Type> = CategoricalChunkedBuilder::new(name.clone(), dtype.clone());
new_column
.iter()
.for_each(|x| {
if let Some(x) = x.get_str() {
builder.append_str(x).unwrap();
} else {
builder.append_null();
}
});
builder.finish().into_series()
}
_ => {
Series::from_any_values_and_dtype(name.clone(), &new_column, dtype, false)?
}
};
this.replace_column(i, new_column)?;
} else {
this.insert_column(this.width(), other_s.clone())?;
}
anyhow::Ok(())
})?;
Ok(())
}
/// Re-indexes `series` from `row_names` to `new_row_names`.
///
/// The returned column has one entry per name in `new_row_names`, in that
/// order, and keeps the name and dtype of `series`. Names that are absent
/// from `row_names` yield null.
///
/// # Errors
///
/// Returns an error if a value cannot be read from `series` or if the new
/// column cannot be built.
fn align_series(
    series: &Column,
    row_names: &DataFrameIndex,
    new_row_names: &IndexSet<String>,
) -> Result<Column> {
    let name = series.name();
    let dtype = series.dtype();
    let new_series = match dtype {
        DataType::Categorical(_, _) => {
            // Categorical columns are rebuilt from their string representation.
            let mut builder: CategoricalChunkedBuilder<Categorical32Type> =
                CategoricalChunkedBuilder::new(name.clone(), dtype.clone());
            for key in new_row_names {
                let item = row_names
                    .get_index(key)
                    .map(|i| series.get(i))
                    .transpose()?;
                if let Some(s) = item.as_ref().and_then(|x| x.get_str()) {
                    builder.append_str(s)?;
                } else {
                    builder.append_null();
                }
            }
            builder.finish().into_series()
        }
        _ => {
            let values = new_row_names
                .iter()
                .map(|key| match row_names.get_index(key) {
                    Some(i) => Ok(series.get(i)?),
                    None => Ok(AnyValue::Null),
                })
                .collect::<Result<Vec<_>>>()?;
            Series::from_any_values_and_dtype(name.clone(), &values, dtype, false)?
        }
    };
    Ok(new_series.into())
}
/// Re-indexes `arr` by the given row and column index lists and returns the
/// result as a new `ArrayData`.
///
/// Dense arrays are handled by `array_major_minor_index_default` and CSR
/// matrices by `cs_major_minor_index2`. For CSR input the result has
/// `row_indices.len()` rows and `col_indices.len()` columns.
///
/// NOTE(review): a `None` index presumably produces a default/empty (zero)
/// row or column in the output; confirm against the helper implementations.
///
/// # Panics
///
/// Panics via `todo!()` for any `ArrayData` variant other than `Array` and
/// `CsrMatrix`, and via `unwrap()` if the dense dimensionality conversion or
/// the CSR construction fails.
fn index_array(
arr: ArrayData,
row_indices: &[Option<usize>],
col_indices: &[Option<usize>],
) -> ArrayData {
// Dense path: convert to the dimensionality the helper expects, then index.
macro_rules! fun_array {
($variant:ident, $value:expr) => {
array_major_minor_index_default(
row_indices,
col_indices,
&$value.into_dimensionality().unwrap(),
)
.into()
};
}
// Sparse path: index the raw CSR buffers and reassemble a `CsrMatrix`.
macro_rules! fun_csr {
($variant:ident, $value:expr) => {{
let (offsets, indices, data) = $value.csr_data();
let (new_row_offsets, new_col_indices, new_data) = cs_major_minor_index2(
row_indices,
col_indices,
$value.ncols(),
offsets,
indices,
data,
);
// SAFETY (assumed, TODO confirm): this skips pattern validation and
// therefore relies on `cs_major_minor_index2` returning offsets and
// column indices that form a valid CSR pattern of shape
// `row_indices.len()` x `col_indices.len()`.
let pattern = unsafe {
SparsityPattern::from_offset_and_indices_unchecked(
row_indices.len(),
col_indices.len(),
new_row_offsets,
new_col_indices,
)
};
CsrMatrix::try_from_pattern_and_values(pattern, new_data)
.unwrap()
.into()
}};
}
// Dispatch on the storage layout; the element type is resolved by `dyn_map!`.
match arr {
ArrayData::Array(x) => crate::macros::dyn_map!(x, DynArray, fun_array),
ArrayData::CsrMatrix(x) => crate::macros::dyn_map!(x, DynCsrMatrix, fun_csr),
// Remaining variants are not implemented yet and panic if encountered.
_ => todo!(),
}
}