1use recursive::recursive;
2
3use super::*;
4
5impl IR {
6 pub(crate) fn scan_schema(&self) -> &SchemaRef {
9 use IR::*;
10 match self {
11 Scan { file_info, .. } => &file_info.schema,
12 #[cfg(feature = "python")]
13 PythonScan { options, .. } => &options.schema,
14 _ => unreachable!(),
15 }
16 }
17
18 pub fn name(&self) -> &'static str {
19 use IR::*;
20 match self {
21 Scan { scan_type, .. } => scan_type.into(),
22 #[cfg(feature = "python")]
23 PythonScan { .. } => "python_scan",
24 Slice { .. } => "slice",
25 Filter { .. } => "selection",
26 DataFrameScan { .. } => "df",
27 Select { .. } => "projection",
28 Sort { .. } => "sort",
29 Cache { .. } => "cache",
30 GroupBy { .. } => "aggregate",
31 Join { .. } => "join",
32 HStack { .. } => "hstack",
33 Distinct { .. } => "distinct",
34 MapFunction { .. } => "map_function",
35 Union { .. } => "union",
36 HConcat { .. } => "hconcat",
37 ExtContext { .. } => "ext_context",
38 Sink { payload, .. } => match payload {
39 SinkType::Memory => "sink (memory)",
40 SinkType::File { .. } => "sink (file)",
41 },
42 SimpleProjection { .. } => "simple_projection",
43 Invalid => "invalid",
44 }
45 }
46
47 pub fn input_schema<'a>(&'a self, arena: &'a Arena<IR>) -> Option<Cow<'a, SchemaRef>> {
48 use IR::*;
49 let schema = match self {
50 #[cfg(feature = "python")]
51 PythonScan { options } => &options.schema,
52 DataFrameScan { schema, .. } => schema,
53 Scan { file_info, .. } => &file_info.schema,
54 node => {
55 let input = node.get_input()?;
56 return Some(arena.get(input).schema(arena));
57 },
58 };
59 Some(Cow::Borrowed(schema))
60 }
61
62 #[recursive]
64 pub fn schema<'a>(&'a self, arena: &'a Arena<IR>) -> Cow<'a, SchemaRef> {
65 use IR::*;
66 let schema = match self {
67 #[cfg(feature = "python")]
68 PythonScan { options } => options.output_schema.as_ref().unwrap_or(&options.schema),
69 Union { inputs, .. } => return arena.get(inputs[0]).schema(arena),
70 HConcat { schema, .. } => schema,
71 Cache { input, .. } => return arena.get(*input).schema(arena),
72 Sort { input, .. } => return arena.get(*input).schema(arena),
73 Scan {
74 output_schema,
75 file_info,
76 ..
77 } => output_schema.as_ref().unwrap_or(&file_info.schema),
78 DataFrameScan {
79 schema,
80 output_schema,
81 ..
82 } => output_schema.as_ref().unwrap_or(schema),
83 Filter { input, .. } => return arena.get(*input).schema(arena),
84 Select { schema, .. } => schema,
85 SimpleProjection { columns, .. } => columns,
86 GroupBy { schema, .. } => schema,
87 Join { schema, .. } => schema,
88 HStack { schema, .. } => schema,
89 Distinct { input, .. } | Sink { input, .. } => return arena.get(*input).schema(arena),
90 Slice { input, .. } => return arena.get(*input).schema(arena),
91 MapFunction { input, function } => {
92 let input_schema = arena.get(*input).schema(arena);
93
94 return match input_schema {
95 Cow::Owned(schema) => {
96 Cow::Owned(function.schema(&schema).unwrap().into_owned())
97 },
98 Cow::Borrowed(schema) => function.schema(schema).unwrap(),
99 };
100 },
101 ExtContext { schema, .. } => schema,
102 Invalid => unreachable!(),
103 };
104 Cow::Borrowed(schema)
105 }
106
107 #[recursive]
109 pub fn schema_with_cache<'a>(
110 node: Node,
111 arena: &'a Arena<IR>,
112 cache: &mut PlHashMap<Node, Arc<Schema>>,
113 ) -> Arc<Schema> {
114 use IR::*;
115 if let Some(schema) = cache.get(&node) {
116 return schema.clone();
117 }
118
119 let schema = match arena.get(node) {
120 #[cfg(feature = "python")]
121 PythonScan { options } => options
122 .output_schema
123 .as_ref()
124 .unwrap_or(&options.schema)
125 .clone(),
126 Union { inputs, .. } => IR::schema_with_cache(inputs[0], arena, cache),
127 HConcat { schema, .. } => schema.clone(),
128 Cache { input, .. }
129 | Sort { input, .. }
130 | Filter { input, .. }
131 | Distinct { input, .. }
132 | Sink { input, .. }
133 | Slice { input, .. } => IR::schema_with_cache(*input, arena, cache),
134 Scan {
135 output_schema,
136 file_info,
137 ..
138 } => output_schema.as_ref().unwrap_or(&file_info.schema).clone(),
139 DataFrameScan {
140 schema,
141 output_schema,
142 ..
143 } => output_schema.as_ref().unwrap_or(schema).clone(),
144 Select { schema, .. }
145 | GroupBy { schema, .. }
146 | Join { schema, .. }
147 | HStack { schema, .. }
148 | ExtContext { schema, .. }
149 | SimpleProjection {
150 columns: schema, ..
151 } => schema.clone(),
152 MapFunction { input, function } => {
153 let input_schema = IR::schema_with_cache(*input, arena, cache);
154 function.schema(&input_schema).unwrap().into_owned()
155 },
156 Invalid => unreachable!(),
157 };
158 cache.insert(node, schema.clone());
159 schema
160 }
161}