Skip to main content

fraiseql_core/compiler/window_functions/
planner.rs

1use super::{
2    FactTableMetadata, FraiseQLError, FrameBoundary, FrameExclusion, FrameType, OrderByClause,
3    OrderDirection, Result, SelectColumn, WhereClause, WindowExecutionPlan, WindowFrame,
4    WindowFunction, WindowFunctionType,
5};
6use crate::compiler::window_allowlist::WindowAllowlist;
7
/// Window function plan generator.
///
/// A stateless namespace: all functionality is exposed through associated functions
/// ([`WindowFunctionPlanner::plan`] and [`WindowFunctionPlanner::validate`]), so the
/// type carries no data and is trivially copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct WindowFunctionPlanner;
10
11/// Validate a SQL column/expression string for safe embedding.
12///
13/// Allows identifiers, JSONB path operators (`->`, `->>`), quoted string keys
14/// (single-quote), periods, spaces, and parentheses for function-style paths.
15/// Rejects characters that could enable SQL injection (`;`, `--`, `/*`, `=`,
16/// `\`, nul bytes, etc.).
17fn validate_sql_expression(value: &str, context: &str) -> Result<()> {
18    let safe = value.chars().all(|c| {
19        c.is_alphanumeric() || matches!(c, '_' | '-' | '>' | '\'' | '.' | ' ' | '(' | ')')
20    });
21    if safe {
22        Ok(())
23    } else {
24        Err(FraiseQLError::Validation {
25            message: format!(
26                "Unsafe characters in window function {context}: {value:?}. \
27                 Only identifiers, JSONB path operators (-> ->>), and quoted keys are allowed."
28            ),
29            path:    None,
30        })
31    }
32}
33
34impl WindowFunctionPlanner {
35    /// Generate window function execution plan from JSON query
36    ///
37    /// # Example Query Format
38    ///
39    /// ```json
40    /// {
41    ///   "table": "tf_sales",
42    ///   "select": ["revenue", "category"],
43    ///   "windows": [
44    ///     {
45    ///       "function": {"type": "row_number"},
46    ///       "alias": "rank",
47    ///       "partitionBy": ["data->>'category'"],
48    ///       "orderBy": [{"field": "revenue", "direction": "DESC"}]
49    ///     }
50    ///   ],
51    ///   "limit": 10
52    /// }
53    /// ```
54    ///
55    /// # Errors
56    ///
57    /// Returns `FraiseQLError::Validation` if the window function specification is invalid
58    /// (e.g., missing required fields, disallowed characters, or unsupported function names).
59    pub fn plan(
60        query: &serde_json::Value,
61        metadata: &FactTableMetadata,
62    ) -> Result<WindowExecutionPlan> {
63        // Build schema-based allowlist from metadata (defence-in-depth on top of
64        // character-level validation).  Empty metadata → empty allowlist → no
65        // schema-constraint enforcement (character validation still applies).
66        let allowlist = WindowAllowlist::from_metadata(metadata);
67
68        // Parse table name
69        let table = query["table"]
70            .as_str()
71            .ok_or_else(|| FraiseQLError::validation("Missing 'table' field"))?
72            .to_string();
73
74        // Parse SELECT columns
75        let select = Self::parse_select_columns(query)?;
76
77        // Parse window functions
78        let windows = Self::parse_window_functions(query, &allowlist)?;
79
80        // Parse WHERE clause (placeholder - full implementation would parse actual conditions)
81        let where_clause = query.get("where").map(|_| WhereClause::And(vec![]));
82
83        // Parse ORDER BY
84        let order_by = query
85            .get("orderBy")
86            .and_then(|v| v.as_array())
87            .map(|arr| {
88                arr.iter()
89                    .filter_map(|item| {
90                        let direction = match item.get("direction").and_then(|d| d.as_str()) {
91                            Some("DESC") => OrderDirection::Desc,
92                            _ => OrderDirection::Asc,
93                        };
94                        Some(OrderByClause {
95                            field: item["field"].as_str()?.to_string(),
96                            direction,
97                        })
98                    })
99                    .collect()
100            })
101            .unwrap_or_default();
102
103        // Parse LIMIT/OFFSET
104        let limit = query
105            .get("limit")
106            .and_then(|v| v.as_u64())
107            .map(|n| u32::try_from(n).unwrap_or(u32::MAX));
108        let offset = query
109            .get("offset")
110            .and_then(|v| v.as_u64())
111            .map(|n| u32::try_from(n).unwrap_or(u32::MAX));
112
113        Ok(WindowExecutionPlan {
114            table,
115            select,
116            windows,
117            where_clause,
118            order_by,
119            limit,
120            offset,
121        })
122    }
123
124    fn parse_select_columns(query: &serde_json::Value) -> Result<Vec<SelectColumn>> {
125        let default_array = vec![];
126        let select = query.get("select").and_then(|s| s.as_array()).unwrap_or(&default_array);
127
128        let columns = select
129            .iter()
130            .filter_map(|col| {
131                col.as_str().map(|col_str| SelectColumn {
132                    expression: col_str.to_string(),
133                    alias:      col_str.to_string(),
134                })
135            })
136            .collect();
137
138        Ok(columns)
139    }
140
141    fn parse_window_functions(
142        query: &serde_json::Value,
143        allowlist: &WindowAllowlist,
144    ) -> Result<Vec<WindowFunction>> {
145        let default_array = vec![];
146        let windows = query.get("windows").and_then(|w| w.as_array()).unwrap_or(&default_array);
147
148        windows.iter().map(|w| Self::parse_single_window(w, allowlist)).collect()
149    }
150
151    fn parse_single_window(
152        window: &serde_json::Value,
153        allowlist: &WindowAllowlist,
154    ) -> Result<WindowFunction> {
155        let function = Self::parse_window_function_type(&window["function"])?;
156        let alias = window["alias"]
157            .as_str()
158            .ok_or_else(|| FraiseQLError::validation("Missing 'alias' in window function"))?
159            .to_string();
160
161        let partition_by = window
162            .get("partitionBy")
163            .and_then(|p| p.as_array())
164            .map(|arr| -> Result<Vec<String>> {
165                arr.iter()
166                    .filter_map(|v| v.as_str())
167                    .map(|col| {
168                        // Layer 1: character-level validation (rejects SQL injection chars)
169                        validate_sql_expression(col, "partitionBy")?;
170                        // Layer 2: schema-based allowlist (defence-in-depth)
171                        allowlist.validate(col, "PARTITION BY")?;
172                        Ok(col.to_string())
173                    })
174                    .collect()
175            })
176            .transpose()?
177            .unwrap_or_default();
178
179        let order_by = window
180            .get("orderBy")
181            .and_then(|o| o.as_array())
182            .map(|arr| -> Result<Vec<OrderByClause>> {
183                arr.iter()
184                    .filter_map(|item| {
185                        let field = item["field"].as_str()?;
186                        let direction = match item.get("direction").and_then(|d| d.as_str()) {
187                            Some("DESC") => OrderDirection::Desc,
188                            _ => OrderDirection::Asc,
189                        };
190                        Some((field, direction))
191                    })
192                    .map(|(field, direction)| {
193                        // Layer 1: character-level validation
194                        validate_sql_expression(field, "orderBy.field")?;
195                        // Layer 2: schema-based allowlist (defence-in-depth)
196                        allowlist.validate(field, "ORDER BY")?;
197                        Ok(OrderByClause {
198                            field: field.to_string(),
199                            direction,
200                        })
201                    })
202                    .collect()
203            })
204            .transpose()?
205            .unwrap_or_default();
206
207        let frame = window.get("frame").map(Self::parse_window_frame).transpose()?;
208
209        Ok(WindowFunction {
210            function,
211            alias,
212            partition_by,
213            order_by,
214            frame,
215        })
216    }
217
218    fn parse_window_function_type(func: &serde_json::Value) -> Result<WindowFunctionType> {
219        serde_json::from_value(func.clone()).map_err(|e| {
220            FraiseQLError::validation(format!("Unknown or invalid window function: {e}"))
221        })
222    }
223
224    fn parse_window_frame(frame: &serde_json::Value) -> Result<WindowFrame> {
225        let frame_type = match frame["frame_type"].as_str() {
226            Some("ROWS") => FrameType::Rows,
227            Some("RANGE") => FrameType::Range,
228            Some("GROUPS") => FrameType::Groups,
229            _ => return Err(FraiseQLError::validation("Invalid or missing 'frame_type'")),
230        };
231
232        let start = Self::parse_frame_boundary(&frame["start"])?;
233        let end = Self::parse_frame_boundary(&frame["end"])?;
234        let exclusion = frame.get("exclusion").map(|e| match e.as_str() {
235            Some("current_row") => FrameExclusion::CurrentRow,
236            Some("group") => FrameExclusion::Group,
237            Some("ties") => FrameExclusion::Ties,
238            // "no_others" and unrecognised values default to NoOthers
239            _ => FrameExclusion::NoOthers,
240        });
241
242        Ok(WindowFrame {
243            frame_type,
244            start,
245            end,
246            exclusion,
247        })
248    }
249
250    fn parse_frame_boundary(boundary: &serde_json::Value) -> Result<FrameBoundary> {
251        match boundary["type"].as_str() {
252            Some("unbounded_preceding") => Ok(FrameBoundary::UnboundedPreceding),
253            Some("n_preceding") => {
254                let n = u32::try_from(
255                    boundary["n"]
256                        .as_u64()
257                        .ok_or_else(|| FraiseQLError::validation("Missing 'n' in N PRECEDING"))?,
258                )
259                .unwrap_or(u32::MAX);
260                Ok(FrameBoundary::NPreceding { n })
261            },
262            Some("current_row") => Ok(FrameBoundary::CurrentRow),
263            Some("n_following") => {
264                let n = u32::try_from(
265                    boundary["n"]
266                        .as_u64()
267                        .ok_or_else(|| FraiseQLError::validation("Missing 'n' in N FOLLOWING"))?,
268                )
269                .unwrap_or(u32::MAX);
270                Ok(FrameBoundary::NFollowing { n })
271            },
272            Some("unbounded_following") => Ok(FrameBoundary::UnboundedFollowing),
273            _ => Err(FraiseQLError::validation("Invalid frame boundary type")),
274        }
275    }
276
277    /// Validate window function plan
278    ///
279    /// # Errors
280    ///
281    /// Returns `FraiseQLError::Validation` if the plan uses features unsupported by the
282    /// target database (e.g., RANGE frames on MySQL/SQLite).
283    pub fn validate(
284        plan: &WindowExecutionPlan,
285        _metadata: &FactTableMetadata,
286        database_target: crate::db::types::DatabaseType,
287    ) -> Result<()> {
288        use crate::db::types::DatabaseType;
289
290        // Validate frame type supported by database
291        for window in &plan.windows {
292            if let Some(frame) = &window.frame {
293                if frame.frame_type == FrameType::Groups
294                    && !matches!(database_target, DatabaseType::PostgreSQL)
295                {
296                    return Err(FraiseQLError::validation(
297                        "GROUPS frame type only supported on PostgreSQL",
298                    ));
299                }
300
301                // Validate frame exclusion (PostgreSQL only)
302                if frame.exclusion.is_some() && !matches!(database_target, DatabaseType::PostgreSQL)
303                {
304                    return Err(FraiseQLError::validation(
305                        "Frame exclusion only supported on PostgreSQL",
306                    ));
307                }
308            }
309
310            // Validate PERCENT_RANK and CUME_DIST (not in SQLite)
311            match window.function {
312                WindowFunctionType::PercentRank | WindowFunctionType::CumeDist => {
313                    if matches!(database_target, DatabaseType::SQLite) {
314                        return Err(FraiseQLError::validation(
315                            "PERCENT_RANK and CUME_DIST not supported on SQLite",
316                        ));
317                    }
318                },
319                _ => {},
320            }
321        }
322
323        Ok(())
324    }
325}