// rdf_fusion_functions/aggregates/group_concat.rs

use datafusion::arrow::array::{ArrayRef, AsArray};
use datafusion::arrow::datatypes::{DataType, Field, FieldRef};
use datafusion::common::{internal_err, plan_err};
use datafusion::logical_expr::expr::AggregateFunction;
use datafusion::logical_expr::function::{
    AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
use datafusion::logical_expr::{
    AggregateUDF, AggregateUDFImpl, Expr, Signature, Volatility,
};
use datafusion::scalar::ScalarValue;
use datafusion::{error::Result, physical_plan::Accumulator};
use rdf_fusion_encoding::typed_value::TYPED_VALUE_ENCODING;
use rdf_fusion_encoding::typed_value::decoders::{
    DefaultTypedValueDecoder, StringLiteralRefTermValueDecoder,
};
use rdf_fusion_encoding::typed_value::encoders::StringLiteralRefTermValueEncoder;
use rdf_fusion_encoding::{TermDecoder, TermEncoder, TermEncoding};
use rdf_fusion_extensions::functions::BuiltinName;
use rdf_fusion_model::DFResult;
use rdf_fusion_model::{StringLiteralRef, ThinError, TypedValueRef};
use std::any::Any;
use std::sync::Arc;
24
25pub fn group_concat_typed_value() -> AggregateUDF {
26    AggregateUDF::new_from_impl(SparqlGroupConcat::new())
27}
28
/// Concatenates the strings in a set with a given separator.
///
/// This is the planning-time form that takes two typed-value arguments (the values and the
/// separator). Its `simplify` hook replaces it with an aggregate that carries the separator
/// as configuration and takes only the values.
///
/// Relevant Resources:
/// - [SPARQL 1.1 - GROUP CONCAT](https://www.w3.org/TR/sparql11-query/#defn_aggGroupConcat)
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparqlGroupConcat {
    /// Function name, taken from `BuiltinName::GroupConcat`.
    name: String,
    /// Two arguments, both of the typed-value encoding's data type.
    signature: Signature,
}
38
39impl SparqlGroupConcat {
40    /// Creates a new [SparqlGroupConcat] aggregate UDF.
41    pub fn new() -> Self {
42        let name = BuiltinName::GroupConcat.to_string();
43        let signature = Signature::uniform(
44            2,
45            vec![TYPED_VALUE_ENCODING.data_type()],
46            Volatility::Stable,
47        );
48        SparqlGroupConcat { name, signature }
49    }
50}
51
52impl Default for SparqlGroupConcat {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl AggregateUDFImpl for SparqlGroupConcat {
59    fn as_any(&self) -> &dyn Any {
60        self
61    }
62
63    fn name(&self) -> &str {
64        &self.name
65    }
66
67    fn signature(&self) -> &Signature {
68        &self.signature
69    }
70
71    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
72        Ok(TYPED_VALUE_ENCODING.data_type())
73    }
74
75    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
76        unreachable!("GROUP_CONCAT should have been simplified by the optimizer")
77    }
78
79    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
80        Some(Box::new(|function, _info| {
81            debug_assert!(
82                function.params.args.len() == 2,
83                "Separator should be the second argument"
84            );
85
86            let separator_expr = &function.params.args[1];
87            let separator = match separator_expr {
88                Expr::Literal(value, _) => {
89                    let scalar = TYPED_VALUE_ENCODING.try_new_scalar(value.clone())?;
90                    let term = DefaultTypedValueDecoder::decode_term(&scalar);
91                    match term {
92                        Ok(TypedValueRef::SimpleLiteral(literal)) => {
93                            literal.value.to_owned()
94                        }
95                        Err(_) => " ".to_owned(),
96                        _ => return plan_err!("Separator should be a simple literal"),
97                    }
98                }
99                _ => return plan_err!("Separator should be a literal"),
100            };
101
102            Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
103                AggregateUDF::new_from_impl(SparqlGroupConcatWithSeparator::new(
104                    separator,
105                ))
106                .into(),
107                vec![function.params.args[0].clone()],
108                function.params.distinct,
109                function.params.filter.clone(),
110                function.params.order_by.clone(),
111                function.params.null_treatment,
112            )))
113        }))
114    }
115}
116
/// Single-argument GROUP_CONCAT whose separator was fixed during simplification of
/// [SparqlGroupConcat].
///
/// It takes only the values to concatenate.
#[derive(Debug, PartialEq, Eq, Hash)]
struct SparqlGroupConcatWithSeparator {
    /// Function name, the same `BuiltinName::GroupConcat` name as the two-argument form.
    name: String,
    /// Exactly one typed-value argument.
    signature: Signature,
    /// String inserted between consecutive values.
    separator: String,
}
123
124impl SparqlGroupConcatWithSeparator {
125    /// Creates a new [SparqlGroupConcatWithSeparator] aggregate UDF.
126    pub fn new(separator: String) -> Self {
127        let name = BuiltinName::GroupConcat.to_string();
128        let signature =
129            Signature::exact(vec![TYPED_VALUE_ENCODING.data_type()], Volatility::Stable);
130        SparqlGroupConcatWithSeparator {
131            name,
132            signature,
133            separator,
134        }
135    }
136}
137
138impl AggregateUDFImpl for SparqlGroupConcatWithSeparator {
139    fn as_any(&self) -> &dyn Any {
140        self
141    }
142
143    fn name(&self) -> &str {
144        &self.name
145    }
146
147    fn signature(&self) -> &Signature {
148        &self.signature
149    }
150
151    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
152        Ok(TYPED_VALUE_ENCODING.data_type())
153    }
154
155    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
156        Ok(Box::new(SparqlGroupConcatAccumulator::new(
157            self.separator.clone(),
158        )))
159    }
160}
161
/// Accumulator backing [SparqlGroupConcatWithSeparator].
#[derive(Debug)]
struct SparqlGroupConcatAccumulator {
    /// String inserted between consecutive values.
    separator: String,
    /// Set when an input could not be decoded as a string literal (or a merged partial state
    /// reported such an error). `evaluate` then produces an error term.
    error: bool,
    /// The concatenation built so far. It is cleared when `error` is set.
    value: Option<String>,
    /// Set when the concatenated values were found to carry differing language tags.
    language_error: bool,
    /// Language tag attached to the result, if any.
    language: Option<String>,
}
170
171impl SparqlGroupConcatAccumulator {
172    pub fn new(separator: String) -> Self {
173        SparqlGroupConcatAccumulator {
174            separator,
175            error: false,
176            value: None,
177            language_error: false,
178            language: None,
179        }
180    }
181}
182
183impl Accumulator for SparqlGroupConcatAccumulator {
184    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
185        if self.error || values.is_empty() {
186            return Ok(());
187        }
188
189        let mut value_exists = self.value.is_some();
190        let mut value = self.value.take().unwrap_or_default();
191
192        let arr = TYPED_VALUE_ENCODING.try_new_array(Arc::clone(&values[0]))?;
193        for string in StringLiteralRefTermValueDecoder::decode_terms(&arr) {
194            if let Ok(string) = string {
195                if value_exists {
196                    value += self.separator.as_str();
197                }
198                value += string.0;
199                value_exists = true;
200                if let Some(lang) = &self.language {
201                    if Some(lang.as_str()) != string.1 {
202                        self.language_error = true;
203                        self.language = None;
204                    }
205                } else {
206                    self.language = string.1.map(ToOwned::to_owned);
207                }
208            } else {
209                self.error = true;
210                self.value = None;
211                return Ok(());
212            }
213        }
214
215        self.value = Some(value);
216        Ok(())
217    }
218
219    fn evaluate(&mut self) -> DFResult<ScalarValue> {
220        if self.error {
221            return StringLiteralRefTermValueEncoder::encode_term(ThinError::expected())
222                .map(rdf_fusion_encoding::EncodingScalar::into_scalar_value);
223        }
224
225        let value = self.value.as_deref().unwrap_or("");
226        let literal = StringLiteralRef(value, self.language.as_deref());
227        StringLiteralRefTermValueEncoder::encode_term(Ok(literal))
228            .map(rdf_fusion_encoding::EncodingScalar::into_scalar_value)
229    }
230
231    fn size(&self) -> usize {
232        size_of_val(self)
233    }
234
235    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
236        Ok(vec![
237            ScalarValue::Boolean(Some(self.error)),
238            ScalarValue::Utf8(self.value.clone()),
239            ScalarValue::Boolean(Some(self.language_error)),
240            ScalarValue::Utf8(self.language.clone()),
241        ])
242    }
243
244    #[allow(clippy::missing_asserts_for_indexing)]
245    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
246        let error = states[0].as_boolean().iter().any(|e| e == Some(true));
247        if error {
248            self.error = true;
249            self.value = None;
250            self.language = None;
251            return Ok(());
252        }
253
254        let old_values = states[1].as_string::<i32>();
255        for old_value in old_values.iter().flatten() {
256            self.value = match self.value.take() {
257                None => Some(old_value.to_owned()),
258                Some(value) => Some(value + self.separator.as_str() + old_value),
259            };
260        }
261
262        let existing_language_error =
263            states[2].as_boolean().iter().any(|e| e == Some(true));
264        if existing_language_error {
265            self.language_error = true;
266            self.language = None;
267            return Ok(());
268        }
269
270        let old_languages = states[3].as_string::<i32>();
271        for old_language in old_languages {
272            self.language = match (self.language.take(), old_language) {
273                (None, other) => other.map(ToOwned::to_owned),
274                (other, None) => other,
275                (Some(language), Some(old_language)) => {
276                    if language.as_str() == old_language {
277                        Some(language)
278                    } else {
279                        self.language_error = true;
280                        None
281                    }
282                }
283            };
284        }
285
286        Ok(())
287    }
288}