1use crate::extend::ExtendNode;
2use crate::join::{SparqlJoinNode, SparqlJoinType, compute_sparql_join_columns};
3use crate::logical_plan_builder_context::RdfFusionLogicalPlanBuilderContext;
4use crate::minus::MinusNode;
5use crate::patterns::PatternNode;
6use crate::{RdfFusionExprBuilder, RdfFusionExprBuilderContext};
7use datafusion::arrow::datatypes::DataType;
8use datafusion::common::{Column, DFSchemaRef};
9use datafusion::logical_expr::{
10 Expr, ExprSchemable, Extension, LogicalPlan, LogicalPlanBuilder, Sort, SortExpr,
11 UserDefinedLogicalNode, col,
12};
13use rdf_fusion_encoding::EncodingName;
14use rdf_fusion_model::Variable;
15use rdf_fusion_model::{DFResult, TermPattern};
16use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18
19#[derive(Debug, Clone)]
68pub struct RdfFusionLogicalPlanBuilder {
69 plan_builder: LogicalPlanBuilder,
74 context: RdfFusionLogicalPlanBuilderContext,
76}
77
78impl RdfFusionLogicalPlanBuilder {
79 pub(crate) fn new(
81 context: RdfFusionLogicalPlanBuilderContext,
82 plan: Arc<LogicalPlan>,
83 ) -> Self {
84 let plan_builder = LogicalPlanBuilder::new_from_arc(plan);
85 Self {
86 plan_builder,
87 context,
88 }
89 }
90
91 pub fn project(self, variables: &[Variable]) -> DFResult<Self> {
93 let plan_builder = self.plan_builder.project(
94 variables
95 .iter()
96 .map(|v| col(Column::new_unqualified(v.as_str()))),
97 )?;
98 Ok(Self {
99 context: self.context.clone(),
100 plan_builder,
101 })
102 }
103
104 pub fn filter(self, expression: Expr) -> DFResult<RdfFusionLogicalPlanBuilder> {
115 let (datatype, _) = expression.data_type_and_nullable(self.schema())?;
116 let expression = match datatype {
117 DataType::Boolean => expression,
119 _ => self
121 .expr_builder(expression)?
122 .build_effective_boolean_value()?,
123 };
124
125 Ok(Self {
126 context: self.context.clone(),
127 plan_builder: self.plan_builder.filter(expression)?,
128 })
129 }
130
131 pub fn extend(
133 self,
134 variable: Variable,
135 expr: Expr,
136 ) -> DFResult<RdfFusionLogicalPlanBuilder> {
137 let inner = self.plan_builder.build()?;
138 let extend_node = ExtendNode::try_new(inner, variable, expr)?;
139 Ok(Self {
140 context: self.context.clone(),
141 plan_builder: create_extension_plan(extend_node),
142 })
143 }
144
145 pub fn join(
150 self,
151 rhs: LogicalPlan,
152 join_type: SparqlJoinType,
153 filter: Option<Expr>,
154 ) -> DFResult<RdfFusionLogicalPlanBuilder> {
155 let context = self.context.clone();
156
157 let (lhs, rhs) = self.align_encodings_of_common_columns(rhs)?;
158 let join_node = SparqlJoinNode::try_new(
159 context.encodings().clone(),
160 lhs.build()?,
161 rhs,
162 filter,
163 join_type,
164 )?;
165 Ok(Self {
166 context,
167 plan_builder: LogicalPlanBuilder::new(LogicalPlan::Extension(Extension {
168 node: Arc::new(join_node),
169 })),
170 })
171 }
172
173 pub fn slice(
175 self,
176 start: usize,
177 length: Option<usize>,
178 ) -> DFResult<RdfFusionLogicalPlanBuilder> {
179 Ok(Self {
180 context: self.context.clone(),
181 plan_builder: self.plan_builder.limit(start, length)?,
182 })
183 }
184
185 pub fn order_by(self, exprs: &[SortExpr]) -> DFResult<RdfFusionLogicalPlanBuilder> {
187 let exprs = exprs
188 .iter()
189 .map(|sort| self.ensure_sortable(sort))
190 .collect::<DFResult<Vec<_>>>()?;
191
192 let context = self.context.clone();
193 let plan = LogicalPlan::Sort(Sort {
194 input: Arc::new(self.build()?),
195 expr: exprs,
196 fetch: None,
197 });
198
199 Ok(Self {
200 context,
201 plan_builder: LogicalPlanBuilder::new(plan),
202 })
203 }
204
205 fn ensure_sortable(&self, e: &SortExpr) -> DFResult<SortExpr> {
207 let expr = self
208 .expr_builder(e.expr.clone())?
209 .with_encoding(EncodingName::Sortable)?
210 .build()?;
211 Ok(SortExpr::new(expr, e.asc, e.nulls_first))
212 }
213
214 pub fn union(self, rhs: LogicalPlan) -> DFResult<RdfFusionLogicalPlanBuilder> {
216 let context = self.context.clone();
217
218 let (lhs, rhs) = self.align_encodings_of_common_columns(rhs)?;
219 Ok(Self {
220 context,
221 plan_builder: lhs.plan_builder.union_by_name(rhs)?,
222 })
223 }
224
225 pub fn minus(self, rhs: LogicalPlan) -> DFResult<RdfFusionLogicalPlanBuilder> {
227 let minus_node = MinusNode::new(self.plan_builder.build()?, rhs);
228 Ok(Self {
229 context: self.context,
230 plan_builder: create_extension_plan(minus_node),
231 })
232 }
233
234 pub fn group(
236 self,
237 variables: &[Variable],
238 aggregates: &[(Variable, Expr)],
239 ) -> DFResult<RdfFusionLogicalPlanBuilder> {
240 let group_expr = variables
241 .iter()
242 .map(|v| self.create_group_expr(v))
243 .collect::<DFResult<Vec<_>>>()?;
244 let aggr_expr = aggregates
245 .iter()
246 .map(|(v, e)| e.clone().alias(v.as_str()))
247 .collect::<Vec<_>>();
248
249 Ok(Self {
250 context: self.context,
251 plan_builder: self.plan_builder.aggregate(group_expr, aggr_expr)?,
252 })
253 }
254
255 fn create_group_expr(&self, v: &Variable) -> DFResult<Expr> {
258 Ok(self
259 .expr_builder_root()
260 .variable(v.as_ref())?
261 .with_any_encoding(&[EncodingName::PlainTerm, EncodingName::ObjectId])?
262 .build()?
263 .alias(v.as_str()))
264 }
265
266 pub fn distinct(self) -> DFResult<RdfFusionLogicalPlanBuilder> {
268 self.distinct_with_sort(Vec::new())
269 }
270
271 pub fn distinct_with_sort(
273 self,
274 sorts: Vec<SortExpr>,
275 ) -> DFResult<RdfFusionLogicalPlanBuilder> {
276 let schema = self.plan_builder.schema();
277 let (on_expr, sorts) =
278 create_distinct_on_expressions(self.expr_builder_root(), sorts)?;
279 let select_expr = schema.columns().into_iter().map(col).collect();
280 let sorts = if sorts.is_empty() { None } else { Some(sorts) };
281
282 Ok(Self {
283 context: self.context,
284 plan_builder: self.plan_builder.distinct_on(on_expr, select_expr, sorts)?,
285 })
286 }
287
288 pub fn pattern(
290 self,
291 pattern: Vec<Option<TermPattern>>,
292 ) -> DFResult<RdfFusionLogicalPlanBuilder> {
293 let pattern_node = PatternNode::try_new(self.plan_builder.build()?, pattern)?;
294 Ok(Self {
295 context: self.context,
296 plan_builder: LogicalPlanBuilder::from(LogicalPlan::Extension(Extension {
297 node: Arc::new(pattern_node),
298 })),
299 })
300 }
301
302 pub fn with_plain_terms(self) -> DFResult<RdfFusionLogicalPlanBuilder> {
304 let with_correct_encoding = self
305 .schema()
306 .columns()
307 .into_iter()
308 .map(|c| {
309 let name = c.name().to_owned();
310 let expr = self
311 .expr_builder(col(c))?
312 .with_encoding(EncodingName::PlainTerm)?
313 .build()?
314 .alias(name);
315 Ok(expr)
316 })
317 .collect::<DFResult<Vec<_>>>()?;
318 Ok(Self {
319 context: self.context,
320 plan_builder: self.plan_builder.project(with_correct_encoding)?,
321 })
322 }
323
324 pub fn schema(&self) -> &DFSchemaRef {
326 self.plan_builder.schema()
327 }
328
329 pub fn context(&self) -> &RdfFusionLogicalPlanBuilderContext {
331 &self.context
332 }
333
334 pub fn into_inner(self) -> LogicalPlanBuilder {
336 self.plan_builder
337 }
338
339 pub fn build(self) -> DFResult<LogicalPlan> {
341 self.plan_builder.build()
342 }
343
344 pub fn expr_builder_root(&self) -> RdfFusionExprBuilderContext<'_> {
346 let schema = self.schema().as_ref();
347 self.context.expr_builder_context_with_schema(schema)
348 }
349
350 pub fn expr_builder(&self, expr: Expr) -> DFResult<RdfFusionExprBuilder<'_>> {
352 self.expr_builder_root().try_create_builder(expr)
353 }
354
355 fn align_encodings_of_common_columns(
358 self,
359 rhs: LogicalPlan,
360 ) -> DFResult<(Self, LogicalPlan)> {
361 let join_columns = compute_sparql_join_columns(
362 self.context.encodings(),
363 self.schema().as_ref(),
364 rhs.schema().as_ref(),
365 )?;
366
367 if join_columns.is_empty() {
368 return Ok((self, rhs));
369 }
370
371 let lhs_expr_builder =
372 self.context.expr_builder_context_with_schema(self.schema());
373 let rhs_expr_builder =
374 self.context.expr_builder_context_with_schema(rhs.schema());
375
376 let lhs_projections =
377 build_projections_for_encoding_alignment(lhs_expr_builder, &join_columns)?;
378 let lhs = match lhs_projections {
379 None => self.plan_builder.build()?,
380 Some(projections) => self.plan_builder.project(projections)?.build()?,
381 };
382
383 let rhs_projections =
384 build_projections_for_encoding_alignment(rhs_expr_builder, &join_columns)?;
385 let rhs = match rhs_projections {
386 None => rhs,
387 Some(projections) => {
388 LogicalPlanBuilder::new(rhs).project(projections)?.build()?
389 }
390 };
391
392 let context = self.context.clone();
393 Ok((Self::new(context, Arc::new(lhs)), rhs))
394 }
395}
396
397fn build_projections_for_encoding_alignment(
401 expr_builder_root: RdfFusionExprBuilderContext<'_>,
402 join_columns: &HashMap<String, HashSet<EncodingName>>,
403) -> DFResult<Option<Vec<Expr>>> {
404 let projections = expr_builder_root
405 .schema()
406 .fields()
407 .iter()
408 .map(|f| {
409 if let Some(encodings) = join_columns.get(f.name()) {
410 let expr = col(Column::new_unqualified(f.name()));
411
412 if encodings.len() > 1 {
413 let expr = expr_builder_root.try_create_builder(expr)?;
414 Ok(expr
415 .with_encoding(EncodingName::PlainTerm)?
416 .build()?
417 .alias(f.name()))
418 } else {
419 Ok(expr)
420 }
421 } else {
422 Ok(col(Column::new_unqualified(f.name())))
423 }
424 })
425 .collect::<DFResult<Vec<_>>>()?;
426
427 if projections.iter().all(|e| matches!(e, Expr::Column(_))) {
428 Ok(None)
429 } else {
430 Ok(Some(projections))
431 }
432}
433
434fn create_distinct_on_expressions(
435 expr_builder_root: RdfFusionExprBuilderContext<'_>,
436 mut sort_expr: Vec<SortExpr>,
437) -> DFResult<(Vec<Expr>, Vec<SortExpr>)> {
438 let mut on_expr = sort_expr
439 .iter()
440 .map(|se| se.expr.clone())
441 .collect::<Vec<_>>();
442
443 for column in expr_builder_root.schema().columns() {
444 let expr = col(column.clone());
445 let sortable_expr = expr_builder_root
446 .try_create_builder(expr.clone())?
447 .with_encoding(EncodingName::Sortable)?
448 .build()?;
449
450 if !on_expr.contains(&sortable_expr) {
452 on_expr.push(expr.clone());
453 sort_expr.push(SortExpr::new(expr, true, true))
454 }
455 }
456
457 Ok((on_expr, sort_expr))
458}
459
460fn create_extension_plan(
462 node: impl UserDefinedLogicalNode + 'static,
463) -> LogicalPlanBuilder {
464 LogicalPlanBuilder::new(LogicalPlan::Extension(Extension {
465 node: Arc::new(node),
466 }))
467}