use polars_utils::itertools::Itertools;
use recursive::recursive;
use super::*;
impl IR {
pub(crate) fn scan_schema(&self) -> &SchemaRef {
use IR::*;
match self {
Scan { file_info, .. } => &file_info.schema,
#[cfg(feature = "python")]
PythonScan { options, .. } => &options.schema,
_ => unreachable!(),
}
}
pub fn name(&self) -> &'static str {
use IR::*;
match self {
Scan { scan_type, .. } => (&**scan_type).into(),
#[cfg(feature = "python")]
PythonScan { .. } => "python_scan",
Slice { .. } => "slice",
Filter { .. } => "filter",
DataFrameScan { .. } => "df",
Select { .. } => "projection",
Sort { .. } => "sort",
Cache { .. } => "cache",
GroupBy { .. } => "aggregate",
Join { .. } => "join",
Gather { .. } => "gather",
HStack { .. } => "hstack",
Distinct { .. } => "distinct",
MapFunction { .. } => "map_function",
Union { .. } => "union",
HConcat { .. } => "hconcat",
ExtContext { .. } => "ext_context",
Sink { payload, .. } => match payload {
SinkTypeIR::Memory => "sink (memory)",
SinkTypeIR::Callback(..) => "sink (callback)",
SinkTypeIR::File { .. } => "sink (file)",
SinkTypeIR::Partitioned { .. } => "sink (partition)",
},
SinkMultiple { .. } => "sink multiple",
SimpleProjection { .. } => "simple_projection",
#[cfg(feature = "merge_sorted")]
MergeSorted { .. } => "merge_sorted",
UnoptimizedDispatch { .. } => "unoptimized_dispatch",
Invalid => "invalid",
}
}
pub fn input_schema<'a>(&'a self, arena: &'a Arena<IR>) -> Option<Cow<'a, SchemaRef>> {
use IR::*;
let schema = match self {
#[cfg(feature = "python")]
PythonScan { options } => &options.schema,
DataFrameScan { schema, .. } => schema,
Scan { file_info, .. } => &file_info.schema,
node => {
let input = node.get_input()?;
return Some(arena.get(input).schema(arena));
},
};
Some(Cow::Borrowed(schema))
}
#[recursive]
pub fn schema<'a>(&'a self, arena: &'a Arena<IR>) -> Cow<'a, SchemaRef> {
use IR::*;
let schema = match self {
#[cfg(feature = "python")]
PythonScan { options } => options.output_schema.as_ref().unwrap_or(&options.schema),
Union { inputs, .. } => return arena.get(inputs[0]).schema(arena),
HConcat { schema, .. } => schema,
Cache { input, .. } => return arena.get(*input).schema(arena),
Sort { input, .. } => return arena.get(*input).schema(arena),
Gather { input, .. } => return arena.get(*input).schema(arena),
Scan {
output_schema,
file_info,
..
} => output_schema.as_ref().unwrap_or(&file_info.schema),
DataFrameScan {
schema,
output_schema,
..
} => output_schema.as_ref().unwrap_or(schema),
Filter { input, .. } => return arena.get(*input).schema(arena),
Select { schema, .. } => schema,
SimpleProjection { columns, .. } => columns,
GroupBy { schema, .. } => schema,
Join { schema, .. } => schema,
HStack { schema, .. } => schema,
Distinct { input, .. }
| Sink {
input,
payload: SinkTypeIR::Memory,
} => return arena.get(*input).schema(arena),
Sink { .. } | SinkMultiple { .. } => return Cow::Owned(Arc::new(Schema::default())),
Slice { input, .. } => return arena.get(*input).schema(arena),
MapFunction { input, function } => {
let input_schema = arena.get(*input).schema(arena);
return match input_schema {
Cow::Owned(schema) => {
Cow::Owned(function.schema(&schema).unwrap().into_owned())
},
Cow::Borrowed(schema) => function.schema(schema).unwrap(),
};
},
ExtContext { schema, .. } => schema,
#[cfg(feature = "merge_sorted")]
MergeSorted { input_left, .. } => return arena.get(*input_left).schema(arena),
UnoptimizedDispatch {
inputs,
arg_map,
operation,
} => {
let input_schemas = inputs
.iter()
.map(|input| arena.get(*input).schema(arena).into_owned())
.collect_vec();
return Cow::Owned(operation.schema(&input_schemas, arg_map));
},
Invalid => unreachable!(),
};
Cow::Borrowed(schema)
}
#[recursive]
pub fn schema_with_cache<'a>(
node: Node,
arena: &'a Arena<IR>,
cache: &mut PlHashMap<Node, Arc<Schema>>,
) -> Arc<Schema> {
use IR::*;
if let Some(schema) = cache.get(&node) {
return schema.clone();
}
let schema = match arena.get(node) {
#[cfg(feature = "python")]
PythonScan { options } => options
.output_schema
.as_ref()
.unwrap_or(&options.schema)
.clone(),
Union { inputs, .. } => IR::schema_with_cache(inputs[0], arena, cache),
HConcat { schema, .. } => schema.clone(),
Cache { input, .. }
| Sort { input, .. }
| Filter { input, .. }
| Distinct { input, .. }
| Sink {
input,
payload: SinkTypeIR::Memory,
}
| Slice { input, .. }
| Gather { input, .. } => IR::schema_with_cache(*input, arena, cache),
Sink { .. } | SinkMultiple { .. } => Arc::new(Schema::default()),
Scan {
output_schema,
file_info,
..
} => output_schema.as_ref().unwrap_or(&file_info.schema).clone(),
DataFrameScan {
schema,
output_schema,
..
} => output_schema.as_ref().unwrap_or(schema).clone(),
Select { schema, .. }
| GroupBy { schema, .. }
| Join { schema, .. }
| HStack { schema, .. }
| ExtContext { schema, .. }
| SimpleProjection {
columns: schema, ..
} => schema.clone(),
MapFunction { input, function } => {
let input_schema = IR::schema_with_cache(*input, arena, cache);
function.schema(&input_schema).unwrap().into_owned()
},
#[cfg(feature = "merge_sorted")]
MergeSorted { input_left, .. } => IR::schema_with_cache(*input_left, arena, cache),
UnoptimizedDispatch {
inputs,
arg_map,
operation,
} => {
let input_schemas = inputs
.iter()
.map(|input| IR::schema_with_cache(*input, arena, cache))
.collect_vec();
operation.schema(&input_schemas, arg_map)
},
Invalid => unreachable!(),
};
cache.insert(node, schema.clone());
schema
}
}