fraiseql_core/compiler/window_functions/
planner.rs1use super::{
2 FactTableMetadata, FraiseQLError, FrameBoundary, FrameExclusion, FrameType, OrderByClause,
3 OrderDirection, Result, SelectColumn, WhereClause, WindowExecutionPlan, WindowFrame,
4 WindowFunction, WindowFunctionType,
5};
6use crate::compiler::window_allowlist::WindowAllowlist;
7
/// Stateless planner for window-function queries.
///
/// `plan` turns a JSON window query into a [`WindowExecutionPlan`]; `validate`
/// checks a plan against the window-function features of a target database.
pub struct WindowFunctionPlanner;
10
11fn validate_sql_expression(value: &str, context: &str) -> Result<()> {
18 let safe = value.chars().all(|c| {
19 c.is_alphanumeric() || matches!(c, '_' | '-' | '>' | '\'' | '.' | ' ' | '(' | ')')
20 });
21 if safe {
22 Ok(())
23 } else {
24 Err(FraiseQLError::Validation {
25 message: format!(
26 "Unsafe characters in window function {context}: {value:?}. \
27 Only identifiers, JSONB path operators (-> ->>), and quoted keys are allowed."
28 ),
29 path: None,
30 })
31 }
32}
33
34impl WindowFunctionPlanner {
35 pub fn plan(
60 query: &serde_json::Value,
61 metadata: &FactTableMetadata,
62 ) -> Result<WindowExecutionPlan> {
63 let allowlist = WindowAllowlist::from_metadata(metadata);
67
68 let table = query["table"]
70 .as_str()
71 .ok_or_else(|| FraiseQLError::validation("Missing 'table' field"))?
72 .to_string();
73
74 let select = Self::parse_select_columns(query)?;
76
77 let windows = Self::parse_window_functions(query, &allowlist)?;
79
80 let where_clause = query.get("where").map(|_| WhereClause::And(vec![]));
82
83 let order_by = query
85 .get("orderBy")
86 .and_then(|v| v.as_array())
87 .map(|arr| {
88 arr.iter()
89 .filter_map(|item| {
90 let direction = match item.get("direction").and_then(|d| d.as_str()) {
91 Some("DESC") => OrderDirection::Desc,
92 _ => OrderDirection::Asc,
93 };
94 Some(OrderByClause::new(item["field"].as_str()?.to_string(), direction))
95 })
96 .collect()
97 })
98 .unwrap_or_default();
99
100 let limit = query
102 .get("limit")
103 .and_then(|v| v.as_u64())
104 .map(|n| u32::try_from(n).unwrap_or(u32::MAX));
105 let offset = query
106 .get("offset")
107 .and_then(|v| v.as_u64())
108 .map(|n| u32::try_from(n).unwrap_or(u32::MAX));
109
110 Ok(WindowExecutionPlan {
111 table,
112 select,
113 windows,
114 where_clause,
115 order_by,
116 limit,
117 offset,
118 })
119 }
120
121 fn parse_select_columns(query: &serde_json::Value) -> Result<Vec<SelectColumn>> {
122 let default_array = vec![];
123 let select = query.get("select").and_then(|s| s.as_array()).unwrap_or(&default_array);
124
125 let columns = select
126 .iter()
127 .filter_map(|col| {
128 col.as_str().map(|col_str| SelectColumn {
129 expression: col_str.to_string(),
130 alias: col_str.to_string(),
131 })
132 })
133 .collect();
134
135 Ok(columns)
136 }
137
138 fn parse_window_functions(
139 query: &serde_json::Value,
140 allowlist: &WindowAllowlist,
141 ) -> Result<Vec<WindowFunction>> {
142 let default_array = vec![];
143 let windows = query.get("windows").and_then(|w| w.as_array()).unwrap_or(&default_array);
144
145 windows.iter().map(|w| Self::parse_single_window(w, allowlist)).collect()
146 }
147
148 fn parse_single_window(
149 window: &serde_json::Value,
150 allowlist: &WindowAllowlist,
151 ) -> Result<WindowFunction> {
152 let function = Self::parse_window_function_type(&window["function"])?;
153 let alias = window["alias"]
154 .as_str()
155 .ok_or_else(|| FraiseQLError::validation("Missing 'alias' in window function"))?
156 .to_string();
157
158 let partition_by = window
159 .get("partitionBy")
160 .and_then(|p| p.as_array())
161 .map(|arr| -> Result<Vec<String>> {
162 arr.iter()
163 .filter_map(|v| v.as_str())
164 .map(|col| {
165 validate_sql_expression(col, "partitionBy")?;
167 allowlist.validate(col, "PARTITION BY")?;
169 Ok(col.to_string())
170 })
171 .collect()
172 })
173 .transpose()?
174 .unwrap_or_default();
175
176 let order_by = window
177 .get("orderBy")
178 .and_then(|o| o.as_array())
179 .map(|arr| -> Result<Vec<OrderByClause>> {
180 arr.iter()
181 .filter_map(|item| {
182 let field = item["field"].as_str()?;
183 let direction = match item.get("direction").and_then(|d| d.as_str()) {
184 Some("DESC") => OrderDirection::Desc,
185 _ => OrderDirection::Asc,
186 };
187 Some((field, direction))
188 })
189 .map(|(field, direction)| {
190 validate_sql_expression(field, "orderBy.field")?;
192 allowlist.validate(field, "ORDER BY")?;
194 Ok(OrderByClause::new(field.to_string(), direction))
195 })
196 .collect()
197 })
198 .transpose()?
199 .unwrap_or_default();
200
201 let frame = window.get("frame").map(Self::parse_window_frame).transpose()?;
202
203 Ok(WindowFunction {
204 function,
205 alias,
206 partition_by,
207 order_by,
208 frame,
209 })
210 }
211
212 fn parse_window_function_type(func: &serde_json::Value) -> Result<WindowFunctionType> {
213 serde_json::from_value(func.clone()).map_err(|e| {
214 FraiseQLError::validation(format!("Unknown or invalid window function: {e}"))
215 })
216 }
217
218 fn parse_window_frame(frame: &serde_json::Value) -> Result<WindowFrame> {
219 let frame_type = match frame["frame_type"].as_str() {
220 Some("ROWS") => FrameType::Rows,
221 Some("RANGE") => FrameType::Range,
222 Some("GROUPS") => FrameType::Groups,
223 _ => return Err(FraiseQLError::validation("Invalid or missing 'frame_type'")),
224 };
225
226 let start = Self::parse_frame_boundary(&frame["start"])?;
227 let end = Self::parse_frame_boundary(&frame["end"])?;
228 let exclusion = frame.get("exclusion").map(|e| match e.as_str() {
229 Some("current_row") => FrameExclusion::CurrentRow,
230 Some("group") => FrameExclusion::Group,
231 Some("ties") => FrameExclusion::Ties,
232 _ => FrameExclusion::NoOthers,
234 });
235
236 Ok(WindowFrame {
237 frame_type,
238 start,
239 end,
240 exclusion,
241 })
242 }
243
244 fn parse_frame_boundary(boundary: &serde_json::Value) -> Result<FrameBoundary> {
245 match boundary["type"].as_str() {
246 Some("unbounded_preceding") => Ok(FrameBoundary::UnboundedPreceding),
247 Some("n_preceding") => {
248 let n = u32::try_from(
249 boundary["n"]
250 .as_u64()
251 .ok_or_else(|| FraiseQLError::validation("Missing 'n' in N PRECEDING"))?,
252 )
253 .unwrap_or(u32::MAX);
254 Ok(FrameBoundary::NPreceding { n })
255 },
256 Some("current_row") => Ok(FrameBoundary::CurrentRow),
257 Some("n_following") => {
258 let n = u32::try_from(
259 boundary["n"]
260 .as_u64()
261 .ok_or_else(|| FraiseQLError::validation("Missing 'n' in N FOLLOWING"))?,
262 )
263 .unwrap_or(u32::MAX);
264 Ok(FrameBoundary::NFollowing { n })
265 },
266 Some("unbounded_following") => Ok(FrameBoundary::UnboundedFollowing),
267 _ => Err(FraiseQLError::validation("Invalid frame boundary type")),
268 }
269 }
270
271 pub fn validate(
278 plan: &WindowExecutionPlan,
279 _metadata: &FactTableMetadata,
280 database_target: crate::db::types::DatabaseType,
281 ) -> Result<()> {
282 use crate::db::types::DatabaseType;
283
284 for window in &plan.windows {
286 if let Some(frame) = &window.frame {
287 if frame.frame_type == FrameType::Groups
288 && !matches!(database_target, DatabaseType::PostgreSQL)
289 {
290 return Err(FraiseQLError::validation(
291 "GROUPS frame type only supported on PostgreSQL",
292 ));
293 }
294
295 if frame.exclusion.is_some() && !matches!(database_target, DatabaseType::PostgreSQL)
297 {
298 return Err(FraiseQLError::validation(
299 "Frame exclusion only supported on PostgreSQL",
300 ));
301 }
302 }
303
304 match window.function {
306 WindowFunctionType::PercentRank | WindowFunctionType::CumeDist => {
307 if matches!(database_target, DatabaseType::SQLite) {
308 return Err(FraiseQLError::validation(
309 "PERCENT_RANK and CUME_DIST not supported on SQLite",
310 ));
311 }
312 },
313 _ => {},
314 }
315 }
316
317 Ok(())
318 }
319}