Skip to main content

fraiseql_core/compiler/window_functions/
planner.rs

1use super::{
2    FactTableMetadata, FraiseQLError, FrameBoundary, FrameExclusion, FrameType, OrderByClause,
3    OrderDirection, Result, SelectColumn, WhereClause, WindowExecutionPlan, WindowFrame,
4    WindowFunction, WindowFunctionType,
5};
6use crate::compiler::window_allowlist::WindowAllowlist;
7
/// Window function plan generator.
///
/// A stateless, field-less planner: `WindowFunctionPlanner::plan` turns a JSON
/// window-function query into a [`WindowExecutionPlan`], and
/// `WindowFunctionPlanner::validate` checks such a plan against a target database.
pub struct WindowFunctionPlanner;
10
11/// Validate a SQL column/expression string for safe embedding.
12///
13/// Allows identifiers, JSONB path operators (`->`, `->>`), quoted string keys
14/// (single-quote), periods, spaces, and parentheses for function-style paths.
15/// Rejects characters that could enable SQL injection (`;`, `--`, `/*`, `=`,
16/// `\`, nul bytes, etc.).
17fn validate_sql_expression(value: &str, context: &str) -> Result<()> {
18    let safe = value.chars().all(|c| {
19        c.is_alphanumeric() || matches!(c, '_' | '-' | '>' | '\'' | '.' | ' ' | '(' | ')')
20    });
21    if safe {
22        Ok(())
23    } else {
24        Err(FraiseQLError::Validation {
25            message: format!(
26                "Unsafe characters in window function {context}: {value:?}. \
27                 Only identifiers, JSONB path operators (-> ->>), and quoted keys are allowed."
28            ),
29            path:    None,
30        })
31    }
32}
33
34impl WindowFunctionPlanner {
35    /// Generate window function execution plan from JSON query
36    ///
37    /// # Example Query Format
38    ///
39    /// ```json
40    /// {
41    ///   "table": "tf_sales",
42    ///   "select": ["revenue", "category"],
43    ///   "windows": [
44    ///     {
45    ///       "function": {"type": "row_number"},
46    ///       "alias": "rank",
47    ///       "partitionBy": ["data->>'category'"],
48    ///       "orderBy": [{"field": "revenue", "direction": "DESC"}]
49    ///     }
50    ///   ],
51    ///   "limit": 10
52    /// }
53    /// ```
54    ///
55    /// # Errors
56    ///
57    /// Returns `FraiseQLError::Validation` if the window function specification is invalid
58    /// (e.g., missing required fields, disallowed characters, or unsupported function names).
59    pub fn plan(
60        query: &serde_json::Value,
61        metadata: &FactTableMetadata,
62    ) -> Result<WindowExecutionPlan> {
63        // Build schema-based allowlist from metadata (defence-in-depth on top of
64        // character-level validation).  Empty metadata → empty allowlist → no
65        // schema-constraint enforcement (character validation still applies).
66        let allowlist = WindowAllowlist::from_metadata(metadata);
67
68        // Parse table name
69        let table = query["table"]
70            .as_str()
71            .ok_or_else(|| FraiseQLError::validation("Missing 'table' field"))?
72            .to_string();
73
74        // Parse SELECT columns
75        let select = Self::parse_select_columns(query)?;
76
77        // Parse window functions
78        let windows = Self::parse_window_functions(query, &allowlist)?;
79
80        // Parse WHERE clause (placeholder - full implementation would parse actual conditions)
81        let where_clause = query.get("where").map(|_| WhereClause::And(vec![]));
82
83        // Parse ORDER BY
84        let order_by = query
85            .get("orderBy")
86            .and_then(|v| v.as_array())
87            .map(|arr| {
88                arr.iter()
89                    .filter_map(|item| {
90                        let direction = match item.get("direction").and_then(|d| d.as_str()) {
91                            Some("DESC") => OrderDirection::Desc,
92                            _ => OrderDirection::Asc,
93                        };
94                        Some(OrderByClause::new(item["field"].as_str()?.to_string(), direction))
95                    })
96                    .collect()
97            })
98            .unwrap_or_default();
99
100        // Parse LIMIT/OFFSET
101        let limit = query
102            .get("limit")
103            .and_then(|v| v.as_u64())
104            .map(|n| u32::try_from(n).unwrap_or(u32::MAX));
105        let offset = query
106            .get("offset")
107            .and_then(|v| v.as_u64())
108            .map(|n| u32::try_from(n).unwrap_or(u32::MAX));
109
110        Ok(WindowExecutionPlan {
111            table,
112            select,
113            windows,
114            where_clause,
115            order_by,
116            limit,
117            offset,
118        })
119    }
120
121    fn parse_select_columns(query: &serde_json::Value) -> Result<Vec<SelectColumn>> {
122        let default_array = vec![];
123        let select = query.get("select").and_then(|s| s.as_array()).unwrap_or(&default_array);
124
125        let columns = select
126            .iter()
127            .filter_map(|col| {
128                col.as_str().map(|col_str| SelectColumn {
129                    expression: col_str.to_string(),
130                    alias:      col_str.to_string(),
131                })
132            })
133            .collect();
134
135        Ok(columns)
136    }
137
138    fn parse_window_functions(
139        query: &serde_json::Value,
140        allowlist: &WindowAllowlist,
141    ) -> Result<Vec<WindowFunction>> {
142        let default_array = vec![];
143        let windows = query.get("windows").and_then(|w| w.as_array()).unwrap_or(&default_array);
144
145        windows.iter().map(|w| Self::parse_single_window(w, allowlist)).collect()
146    }
147
148    fn parse_single_window(
149        window: &serde_json::Value,
150        allowlist: &WindowAllowlist,
151    ) -> Result<WindowFunction> {
152        let function = Self::parse_window_function_type(&window["function"])?;
153        let alias = window["alias"]
154            .as_str()
155            .ok_or_else(|| FraiseQLError::validation("Missing 'alias' in window function"))?
156            .to_string();
157
158        let partition_by = window
159            .get("partitionBy")
160            .and_then(|p| p.as_array())
161            .map(|arr| -> Result<Vec<String>> {
162                arr.iter()
163                    .filter_map(|v| v.as_str())
164                    .map(|col| {
165                        // Layer 1: character-level validation (rejects SQL injection chars)
166                        validate_sql_expression(col, "partitionBy")?;
167                        // Layer 2: schema-based allowlist (defence-in-depth)
168                        allowlist.validate(col, "PARTITION BY")?;
169                        Ok(col.to_string())
170                    })
171                    .collect()
172            })
173            .transpose()?
174            .unwrap_or_default();
175
176        let order_by = window
177            .get("orderBy")
178            .and_then(|o| o.as_array())
179            .map(|arr| -> Result<Vec<OrderByClause>> {
180                arr.iter()
181                    .filter_map(|item| {
182                        let field = item["field"].as_str()?;
183                        let direction = match item.get("direction").and_then(|d| d.as_str()) {
184                            Some("DESC") => OrderDirection::Desc,
185                            _ => OrderDirection::Asc,
186                        };
187                        Some((field, direction))
188                    })
189                    .map(|(field, direction)| {
190                        // Layer 1: character-level validation
191                        validate_sql_expression(field, "orderBy.field")?;
192                        // Layer 2: schema-based allowlist (defence-in-depth)
193                        allowlist.validate(field, "ORDER BY")?;
194                        Ok(OrderByClause::new(field.to_string(), direction))
195                    })
196                    .collect()
197            })
198            .transpose()?
199            .unwrap_or_default();
200
201        let frame = window.get("frame").map(Self::parse_window_frame).transpose()?;
202
203        Ok(WindowFunction {
204            function,
205            alias,
206            partition_by,
207            order_by,
208            frame,
209        })
210    }
211
212    fn parse_window_function_type(func: &serde_json::Value) -> Result<WindowFunctionType> {
213        serde_json::from_value(func.clone()).map_err(|e| {
214            FraiseQLError::validation(format!("Unknown or invalid window function: {e}"))
215        })
216    }
217
218    fn parse_window_frame(frame: &serde_json::Value) -> Result<WindowFrame> {
219        let frame_type = match frame["frame_type"].as_str() {
220            Some("ROWS") => FrameType::Rows,
221            Some("RANGE") => FrameType::Range,
222            Some("GROUPS") => FrameType::Groups,
223            _ => return Err(FraiseQLError::validation("Invalid or missing 'frame_type'")),
224        };
225
226        let start = Self::parse_frame_boundary(&frame["start"])?;
227        let end = Self::parse_frame_boundary(&frame["end"])?;
228        let exclusion = frame.get("exclusion").map(|e| match e.as_str() {
229            Some("current_row") => FrameExclusion::CurrentRow,
230            Some("group") => FrameExclusion::Group,
231            Some("ties") => FrameExclusion::Ties,
232            // "no_others" and unrecognised values default to NoOthers
233            _ => FrameExclusion::NoOthers,
234        });
235
236        Ok(WindowFrame {
237            frame_type,
238            start,
239            end,
240            exclusion,
241        })
242    }
243
244    fn parse_frame_boundary(boundary: &serde_json::Value) -> Result<FrameBoundary> {
245        match boundary["type"].as_str() {
246            Some("unbounded_preceding") => Ok(FrameBoundary::UnboundedPreceding),
247            Some("n_preceding") => {
248                let n = u32::try_from(
249                    boundary["n"]
250                        .as_u64()
251                        .ok_or_else(|| FraiseQLError::validation("Missing 'n' in N PRECEDING"))?,
252                )
253                .unwrap_or(u32::MAX);
254                Ok(FrameBoundary::NPreceding { n })
255            },
256            Some("current_row") => Ok(FrameBoundary::CurrentRow),
257            Some("n_following") => {
258                let n = u32::try_from(
259                    boundary["n"]
260                        .as_u64()
261                        .ok_or_else(|| FraiseQLError::validation("Missing 'n' in N FOLLOWING"))?,
262                )
263                .unwrap_or(u32::MAX);
264                Ok(FrameBoundary::NFollowing { n })
265            },
266            Some("unbounded_following") => Ok(FrameBoundary::UnboundedFollowing),
267            _ => Err(FraiseQLError::validation("Invalid frame boundary type")),
268        }
269    }
270
271    /// Validate window function plan
272    ///
273    /// # Errors
274    ///
275    /// Returns `FraiseQLError::Validation` if the plan uses features unsupported by the
276    /// target database (e.g., RANGE frames on MySQL/SQLite).
277    pub fn validate(
278        plan: &WindowExecutionPlan,
279        _metadata: &FactTableMetadata,
280        database_target: crate::db::types::DatabaseType,
281    ) -> Result<()> {
282        use crate::db::types::DatabaseType;
283
284        // Validate frame type supported by database
285        for window in &plan.windows {
286            if let Some(frame) = &window.frame {
287                if frame.frame_type == FrameType::Groups
288                    && !matches!(database_target, DatabaseType::PostgreSQL)
289                {
290                    return Err(FraiseQLError::validation(
291                        "GROUPS frame type only supported on PostgreSQL",
292                    ));
293                }
294
295                // Validate frame exclusion (PostgreSQL only)
296                if frame.exclusion.is_some() && !matches!(database_target, DatabaseType::PostgreSQL)
297                {
298                    return Err(FraiseQLError::validation(
299                        "Frame exclusion only supported on PostgreSQL",
300                    ));
301                }
302            }
303
304            // Validate PERCENT_RANK and CUME_DIST (not in SQLite)
305            match window.function {
306                WindowFunctionType::PercentRank | WindowFunctionType::CumeDist => {
307                    if matches!(database_target, DatabaseType::SQLite) {
308                        return Err(FraiseQLError::validation(
309                            "PERCENT_RANK and CUME_DIST not supported on SQLite",
310                        ));
311                    }
312                },
313                _ => {},
314            }
315        }
316
317        Ok(())
318    }
319}