#[cfg(feature = "pivot")]
use polars_core::utils::try_get_supertype;
use super::*;
use crate::constants::get_len_name;
impl FunctionIR {
    /// Reset every memoized output schema stored on this node.
    ///
    /// `Unpivot` (with the `pivot` feature), `RowIndex` and `Explode` keep their computed
    /// output schema in a mutex-guarded cache. This sets that cache back to `None`, so the
    /// next call to [`FunctionIR::schema`] recomputes it from the input schema. Other
    /// variants hold no cache and are left untouched.
    ///
    /// # Panics
    /// Panics if a cache mutex is poisoned.
    pub(crate) fn clear_cached_schema(&self) {
        use FunctionIR::*;
        // With the `pivot` feature disabled only one non-wildcard arm remains, which would
        // otherwise trigger `clippy::single_match`.
        #[allow(clippy::single_match)]
        match self {
            #[cfg(feature = "pivot")]
            Unpivot { schema, .. } => {
                let mut guard = schema.lock().unwrap();
                *guard = None;
            },
            RowIndex { schema, .. } | Explode { schema, .. } => {
                let mut guard = schema.lock().unwrap();
                *guard = None;
            },
            _ => {},
        }
    }

    /// Compute the output schema of this function node from the schema of its input.
    ///
    /// Returns `Cow::Borrowed(input_schema)` when the schema passes through unchanged
    /// (`Rechunk`, `Hint`, and opaque UDFs without a declared schema). Otherwise it returns
    /// an owned schema. `RowIndex`, `Explode` and `Unpivot` return their cached schema when
    /// one is present; see [`FunctionIR::clear_cached_schema`].
    ///
    /// # Errors
    /// - an opaque UDF's schema function fails;
    /// - `Unnest` targets a column whose dtype is neither `Struct` nor `Unknown`;
    /// - `Explode`/`Unpivot` reference a column that is not in `input_schema`;
    /// - the `Unpivot` value columns have no common supertype.
    ///
    /// # Panics
    /// `Unnest` panics when built without the `dtype-struct` feature. Variants with a
    /// cache panic if its mutex is poisoned.
    pub fn schema<'a>(&self, input_schema: &'a SchemaRef) -> PolarsResult<Cow<'a, SchemaRef>> {
        use FunctionIR::*;
        match self {
            Opaque { schema, .. } => match schema {
                None => Ok(Cow::Borrowed(input_schema)),
                Some(schema_fn) => {
                    let output_schema = schema_fn.get_schema(input_schema)?;
                    Ok(Cow::Owned(output_schema))
                },
            },
            #[cfg(feature = "python")]
            OpaquePython(OpaquePythonUdf { schema, .. }) => Ok(schema
                .as_ref()
                .map(|schema| Cow::Owned(schema.clone()))
                .unwrap_or_else(|| Cow::Borrowed(input_schema))),
            FastCount { alias, .. } => {
                // Single count column, named by the alias or the default length name.
                let mut schema: Schema = Schema::with_capacity(1);
                let name = alias.clone().unwrap_or_else(get_len_name);
                schema.insert_at_index(0, name, IDX_DTYPE)?;
                Ok(Cow::Owned(Arc::new(schema)))
            },
            Rechunk => Ok(Cow::Borrowed(input_schema)),
            Unnest { columns, separator } => {
                #[cfg(feature = "dtype-struct")]
                {
                    // Extra headroom: each unnested struct is replaced by its fields.
                    let mut new_schema = Schema::with_capacity(input_schema.len() * 2);
                    for (name, dtype) in input_schema.iter() {
                        if columns.iter().any(|item| item == name) {
                            match dtype {
                                DataType::Struct(flds) => {
                                    // Emit the struct's fields at the struct column's position,
                                    // prefixed with `{name}{sep}` when a separator is given.
                                    for fld in flds {
                                        let fld_name = match separator {
                                            None => fld.name().clone(),
                                            Some(sep) => {
                                                polars_utils::format_pl_smallstr!(
                                                    "{name}{sep}{}",
                                                    fld.name()
                                                )
                                            },
                                        };
                                        new_schema.with_column(fld_name, fld.dtype().clone());
                                    }
                                },
                                DataType::Unknown(_) => {
                                    // The fields of an unresolved dtype are not known here, so
                                    // this column contributes nothing to the output schema.
                                    // NOTE(review): it is dropped rather than passed through;
                                    // confirm this is intended.
                                },
                                _ => {
                                    polars_bail!(
                                        SchemaMismatch: "expected struct dtype, got: `{}`", dtype
                                    );
                                },
                            }
                        } else {
                            new_schema.with_column(name.clone(), dtype.clone());
                        }
                    }
                    Ok(Cow::Owned(Arc::new(new_schema)))
                }
                #[cfg(not(feature = "dtype-struct"))]
                {
                    // Only the `dtype-struct` branch reads these bindings. Consume them here so
                    // builds without that feature do not emit `unused_variables` warnings.
                    let _ = (columns, separator);
                    panic!("activate feature 'dtype-struct'")
                }
            },
            RowIndex { schema, name, .. } => Ok(Cow::Owned(row_index_schema(
                schema,
                input_schema,
                name.clone(),
            ))),
            Explode {
                schema,
                options: _,
                columns,
            } => explode_schema(schema, input_schema, columns),
            #[cfg(feature = "pivot")]
            Unpivot { schema, args } => unpivot_schema(args, schema, input_schema),
            Hint(_) => Ok(Cow::Borrowed(input_schema)),
        }
    }
}
fn row_index_schema(
cached_schema: &CachedSchema,
input_schema: &SchemaRef,
name: PlSmallStr,
) -> SchemaRef {
let mut guard = cached_schema.lock().unwrap();
if let Some(schema) = &*guard {
return schema.clone();
}
let mut schema = (**input_schema).clone();
schema.insert_at_index(0, name, IDX_DTYPE).unwrap();
let schema_ref = Arc::new(schema);
*guard = Some(schema_ref.clone());
schema_ref
}
/// Output schema of exploding `columns`.
///
/// Each listed column of dtype `List(inner)`, or `Array(inner, _)` with the `dtype-array`
/// feature, is retyped to `inner` and keeps its position. Listed columns of any other dtype
/// are left unchanged. The result is memoized in `cached_schema`, and later calls return
/// the cached schema directly.
///
/// # Errors
/// Returns an error if a name in `columns` is not present in `schema`.
///
/// # Panics
/// Panics if the cache mutex is poisoned.
fn explode_schema<'a>(
    cached_schema: &CachedSchema,
    schema: &'a Schema,
    columns: &[PlSmallStr],
) -> PolarsResult<Cow<'a, SchemaRef>> {
    // The lock stays held while computing, so a concurrent caller waits and then reuses the
    // stored result.
    let mut cache = cached_schema.lock().unwrap();
    if let Some(hit) = cache.as_ref() {
        return Ok(Cow::Owned(Arc::clone(hit)));
    }

    let mut exploded = schema.clone();
    for name in columns {
        let element_dtype = match exploded.try_get(name)? {
            DataType::List(inner) => inner.as_ref().clone(),
            #[cfg(feature = "dtype-array")]
            DataType::Array(inner, _) => inner.as_ref().clone(),
            _ => continue,
        };
        exploded.with_column(name.clone(), element_dtype);
    }

    let exploded = Arc::new(exploded);
    *cache = Some(Arc::clone(&exploded));
    Ok(Cow::Owned(exploded))
}
/// Output schema of an unpivot.
///
/// The columns are, in order: the `args.index` columns with their input dtypes,
/// `args.variable_name` as `String`, and `args.value_name` typed as the supertype of all
/// `args.on` columns (`Null` when `args.on` is empty). The result is memoized in
/// `cached_schema`.
///
/// # Errors
/// Returns an error if an `index` or `on` column is missing from `input_schema`, or if the
/// `on` dtypes have no common supertype.
///
/// # Panics
/// Panics if the cache mutex is poisoned.
#[cfg(feature = "pivot")]
fn unpivot_schema<'a>(
    args: &UnpivotArgsIR,
    cached_schema: &CachedSchema,
    input_schema: &'a Schema,
) -> PolarsResult<Cow<'a, SchemaRef>> {
    let mut cache = cached_schema.lock().unwrap();
    if let Some(hit) = cache.as_ref() {
        return Ok(Cow::Owned(Arc::clone(hit)));
    }

    let mut long_schema: Schema = args
        .index
        .iter()
        .map(|id| {
            input_schema
                .try_get(id)
                .map(|dtype| Field::new(id.clone(), dtype.clone()))
        })
        .collect::<PolarsResult<_>>()?;

    // All `on` columns feed the single value column, so its dtype must accommodate each one.
    let value_dtype = args.on.iter().try_fold(DataType::Null, |acc, name| {
        let dtype = input_schema.try_get(name)?;
        try_get_supertype(&acc, dtype)
    })?;

    long_schema.with_column(args.variable_name.clone(), DataType::String);
    long_schema.with_column(args.value_name.clone(), value_dtype);

    let long_schema = Arc::new(long_schema);
    *cache = Some(Arc::clone(&long_schema));
    Ok(Cow::Owned(long_schema))
}