fraiseql_core/compiler/window_functions/
planner.rs1use super::{
2 FactTableMetadata, FraiseQLError, FrameBoundary, FrameExclusion, FrameType, OrderByClause,
3 OrderDirection, Result, SelectColumn, WhereClause, WindowExecutionPlan, WindowFrame,
4 WindowFunction, WindowFunctionType,
5};
6use crate::compiler::window_allowlist::WindowAllowlist;
7
/// Stateless entry point for window-function planning.
///
/// The type carries no data; its associated functions (`plan`, `validate`)
/// are called directly on the type and never take `self`.
pub struct WindowFunctionPlanner;
10
11fn validate_sql_expression(value: &str, context: &str) -> Result<()> {
18 let safe = value.chars().all(|c| {
19 c.is_alphanumeric() || matches!(c, '_' | '-' | '>' | '\'' | '.' | ' ' | '(' | ')')
20 });
21 if safe {
22 Ok(())
23 } else {
24 Err(FraiseQLError::Validation {
25 message: format!(
26 "Unsafe characters in window function {context}: {value:?}. \
27 Only identifiers, JSONB path operators (-> ->>), and quoted keys are allowed."
28 ),
29 path: None,
30 })
31 }
32}
33
34impl WindowFunctionPlanner {
35 pub fn plan(
60 query: &serde_json::Value,
61 metadata: &FactTableMetadata,
62 ) -> Result<WindowExecutionPlan> {
63 let allowlist = WindowAllowlist::from_metadata(metadata);
67
68 let table = query["table"]
70 .as_str()
71 .ok_or_else(|| FraiseQLError::validation("Missing 'table' field"))?
72 .to_string();
73
74 let select = Self::parse_select_columns(query)?;
76
77 let windows = Self::parse_window_functions(query, &allowlist)?;
79
80 let where_clause = query.get("where").map(|_| WhereClause::And(vec![]));
82
83 let order_by = query
85 .get("orderBy")
86 .and_then(|v| v.as_array())
87 .map(|arr| {
88 arr.iter()
89 .filter_map(|item| {
90 let direction = match item.get("direction").and_then(|d| d.as_str()) {
91 Some("DESC") => OrderDirection::Desc,
92 _ => OrderDirection::Asc,
93 };
94 Some(OrderByClause {
95 field: item["field"].as_str()?.to_string(),
96 direction,
97 })
98 })
99 .collect()
100 })
101 .unwrap_or_default();
102
103 let limit = query
105 .get("limit")
106 .and_then(|v| v.as_u64())
107 .map(|n| u32::try_from(n).unwrap_or(u32::MAX));
108 let offset = query
109 .get("offset")
110 .and_then(|v| v.as_u64())
111 .map(|n| u32::try_from(n).unwrap_or(u32::MAX));
112
113 Ok(WindowExecutionPlan {
114 table,
115 select,
116 windows,
117 where_clause,
118 order_by,
119 limit,
120 offset,
121 })
122 }
123
124 fn parse_select_columns(query: &serde_json::Value) -> Result<Vec<SelectColumn>> {
125 let default_array = vec![];
126 let select = query.get("select").and_then(|s| s.as_array()).unwrap_or(&default_array);
127
128 let columns = select
129 .iter()
130 .filter_map(|col| {
131 col.as_str().map(|col_str| SelectColumn {
132 expression: col_str.to_string(),
133 alias: col_str.to_string(),
134 })
135 })
136 .collect();
137
138 Ok(columns)
139 }
140
141 fn parse_window_functions(
142 query: &serde_json::Value,
143 allowlist: &WindowAllowlist,
144 ) -> Result<Vec<WindowFunction>> {
145 let default_array = vec![];
146 let windows = query.get("windows").and_then(|w| w.as_array()).unwrap_or(&default_array);
147
148 windows.iter().map(|w| Self::parse_single_window(w, allowlist)).collect()
149 }
150
151 fn parse_single_window(
152 window: &serde_json::Value,
153 allowlist: &WindowAllowlist,
154 ) -> Result<WindowFunction> {
155 let function = Self::parse_window_function_type(&window["function"])?;
156 let alias = window["alias"]
157 .as_str()
158 .ok_or_else(|| FraiseQLError::validation("Missing 'alias' in window function"))?
159 .to_string();
160
161 let partition_by = window
162 .get("partitionBy")
163 .and_then(|p| p.as_array())
164 .map(|arr| -> Result<Vec<String>> {
165 arr.iter()
166 .filter_map(|v| v.as_str())
167 .map(|col| {
168 validate_sql_expression(col, "partitionBy")?;
170 allowlist.validate(col, "PARTITION BY")?;
172 Ok(col.to_string())
173 })
174 .collect()
175 })
176 .transpose()?
177 .unwrap_or_default();
178
179 let order_by = window
180 .get("orderBy")
181 .and_then(|o| o.as_array())
182 .map(|arr| -> Result<Vec<OrderByClause>> {
183 arr.iter()
184 .filter_map(|item| {
185 let field = item["field"].as_str()?;
186 let direction = match item.get("direction").and_then(|d| d.as_str()) {
187 Some("DESC") => OrderDirection::Desc,
188 _ => OrderDirection::Asc,
189 };
190 Some((field, direction))
191 })
192 .map(|(field, direction)| {
193 validate_sql_expression(field, "orderBy.field")?;
195 allowlist.validate(field, "ORDER BY")?;
197 Ok(OrderByClause {
198 field: field.to_string(),
199 direction,
200 })
201 })
202 .collect()
203 })
204 .transpose()?
205 .unwrap_or_default();
206
207 let frame = window.get("frame").map(Self::parse_window_frame).transpose()?;
208
209 Ok(WindowFunction {
210 function,
211 alias,
212 partition_by,
213 order_by,
214 frame,
215 })
216 }
217
218 fn parse_window_function_type(func: &serde_json::Value) -> Result<WindowFunctionType> {
219 serde_json::from_value(func.clone()).map_err(|e| {
220 FraiseQLError::validation(format!("Unknown or invalid window function: {e}"))
221 })
222 }
223
224 fn parse_window_frame(frame: &serde_json::Value) -> Result<WindowFrame> {
225 let frame_type = match frame["frame_type"].as_str() {
226 Some("ROWS") => FrameType::Rows,
227 Some("RANGE") => FrameType::Range,
228 Some("GROUPS") => FrameType::Groups,
229 _ => return Err(FraiseQLError::validation("Invalid or missing 'frame_type'")),
230 };
231
232 let start = Self::parse_frame_boundary(&frame["start"])?;
233 let end = Self::parse_frame_boundary(&frame["end"])?;
234 let exclusion = frame.get("exclusion").map(|e| match e.as_str() {
235 Some("current_row") => FrameExclusion::CurrentRow,
236 Some("group") => FrameExclusion::Group,
237 Some("ties") => FrameExclusion::Ties,
238 _ => FrameExclusion::NoOthers,
240 });
241
242 Ok(WindowFrame {
243 frame_type,
244 start,
245 end,
246 exclusion,
247 })
248 }
249
250 fn parse_frame_boundary(boundary: &serde_json::Value) -> Result<FrameBoundary> {
251 match boundary["type"].as_str() {
252 Some("unbounded_preceding") => Ok(FrameBoundary::UnboundedPreceding),
253 Some("n_preceding") => {
254 let n = u32::try_from(
255 boundary["n"]
256 .as_u64()
257 .ok_or_else(|| FraiseQLError::validation("Missing 'n' in N PRECEDING"))?,
258 )
259 .unwrap_or(u32::MAX);
260 Ok(FrameBoundary::NPreceding { n })
261 },
262 Some("current_row") => Ok(FrameBoundary::CurrentRow),
263 Some("n_following") => {
264 let n = u32::try_from(
265 boundary["n"]
266 .as_u64()
267 .ok_or_else(|| FraiseQLError::validation("Missing 'n' in N FOLLOWING"))?,
268 )
269 .unwrap_or(u32::MAX);
270 Ok(FrameBoundary::NFollowing { n })
271 },
272 Some("unbounded_following") => Ok(FrameBoundary::UnboundedFollowing),
273 _ => Err(FraiseQLError::validation("Invalid frame boundary type")),
274 }
275 }
276
277 pub fn validate(
284 plan: &WindowExecutionPlan,
285 _metadata: &FactTableMetadata,
286 database_target: crate::db::types::DatabaseType,
287 ) -> Result<()> {
288 use crate::db::types::DatabaseType;
289
290 for window in &plan.windows {
292 if let Some(frame) = &window.frame {
293 if frame.frame_type == FrameType::Groups
294 && !matches!(database_target, DatabaseType::PostgreSQL)
295 {
296 return Err(FraiseQLError::validation(
297 "GROUPS frame type only supported on PostgreSQL",
298 ));
299 }
300
301 if frame.exclusion.is_some() && !matches!(database_target, DatabaseType::PostgreSQL)
303 {
304 return Err(FraiseQLError::validation(
305 "Frame exclusion only supported on PostgreSQL",
306 ));
307 }
308 }
309
310 match window.function {
312 WindowFunctionType::PercentRank | WindowFunctionType::CumeDist => {
313 if matches!(database_target, DatabaseType::SQLite) {
314 return Err(FraiseQLError::validation(
315 "PERCENT_RANK and CUME_DIST not supported on SQLite",
316 ));
317 }
318 },
319 _ => {},
320 }
321 }
322
323 Ok(())
324 }
325}