rdf_fusion_functions/aggregates/
group_concat.rs1use datafusion::arrow::array::{ArrayRef, AsArray};
2use datafusion::arrow::datatypes::DataType;
3use datafusion::common::plan_err;
4use datafusion::logical_expr::expr::AggregateFunction;
5use datafusion::logical_expr::function::{
6 AccumulatorArgs, AggregateFunctionSimplification,
7};
8use datafusion::logical_expr::{
9 AggregateUDF, AggregateUDFImpl, Expr, Signature, Volatility,
10};
11use datafusion::scalar::ScalarValue;
12use datafusion::{error::Result, physical_plan::Accumulator};
13use rdf_fusion_encoding::typed_value::TYPED_VALUE_ENCODING;
14use rdf_fusion_encoding::typed_value::decoders::{
15 DefaultTypedValueDecoder, StringLiteralRefTermValueDecoder,
16};
17use rdf_fusion_encoding::typed_value::encoders::StringLiteralRefTermValueEncoder;
18use rdf_fusion_encoding::{TermDecoder, TermEncoder, TermEncoding};
19use rdf_fusion_extensions::functions::BuiltinName;
20use rdf_fusion_model::DFResult;
21use rdf_fusion_model::{StringLiteralRef, ThinError, TypedValueRef};
22use std::any::Any;
23use std::sync::Arc;
24
25pub fn group_concat_typed_value() -> AggregateUDF {
26 AggregateUDF::new_from_impl(SparqlGroupConcat::new())
27}
28
/// Two-argument `GROUP_CONCAT` aggregate (value, separator).
///
/// This type never produces an accumulator itself: its `simplify` hook rewrites
/// every call into a single-argument aggregate that carries the separator as
/// part of its own state.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparqlGroupConcat {
    /// Function name reported through `AggregateUDFImpl::name`.
    name: String,
    /// Argument signature reported through `AggregateUDFImpl::signature`.
    signature: Signature,
}
38
39impl SparqlGroupConcat {
40 pub fn new() -> Self {
42 let name = BuiltinName::GroupConcat.to_string();
43 let signature = Signature::uniform(
44 2,
45 vec![TYPED_VALUE_ENCODING.data_type()],
46 Volatility::Stable,
47 );
48 SparqlGroupConcat { name, signature }
49 }
50}
51
52impl Default for SparqlGroupConcat {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl AggregateUDFImpl for SparqlGroupConcat {
59 fn as_any(&self) -> &dyn Any {
60 self
61 }
62
63 fn name(&self) -> &str {
64 &self.name
65 }
66
67 fn signature(&self) -> &Signature {
68 &self.signature
69 }
70
71 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
72 Ok(TYPED_VALUE_ENCODING.data_type())
73 }
74
75 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
76 unreachable!("GROUP_CONCAT should have been simplified by the optimizer")
77 }
78
79 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
80 Some(Box::new(|function, _info| {
81 debug_assert!(
82 function.params.args.len() == 2,
83 "Separator should be the second argument"
84 );
85
86 let separator_expr = &function.params.args[1];
87 let separator = match separator_expr {
88 Expr::Literal(value, _) => {
89 let scalar = TYPED_VALUE_ENCODING.try_new_scalar(value.clone())?;
90 let term = DefaultTypedValueDecoder::decode_term(&scalar);
91 match term {
92 Ok(TypedValueRef::SimpleLiteral(literal)) => {
93 literal.value.to_owned()
94 }
95 Err(_) => " ".to_owned(),
96 _ => return plan_err!("Separator should be a simple literal"),
97 }
98 }
99 _ => return plan_err!("Separator should be a literal"),
100 };
101
102 Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
103 AggregateUDF::new_from_impl(SparqlGroupConcatWithSeparator::new(
104 separator,
105 ))
106 .into(),
107 vec![function.params.args[0].clone()],
108 function.params.distinct,
109 function.params.filter.clone(),
110 function.params.order_by.clone(),
111 function.params.null_treatment,
112 )))
113 }))
114 }
115}
116
/// Single-argument `GROUP_CONCAT` aggregate with a fixed separator.
///
/// The separator takes part in the derived equality and hash, so two instances
/// with different separators compare as different functions.
#[derive(Debug, PartialEq, Eq, Hash)]
struct SparqlGroupConcatWithSeparator {
    /// Function name reported through `AggregateUDFImpl::name`.
    name: String,
    /// Argument signature reported through `AggregateUDFImpl::signature`.
    signature: Signature,
    /// String handed to each accumulator created by this aggregate.
    separator: String,
}
123
124impl SparqlGroupConcatWithSeparator {
125 pub fn new(separator: String) -> Self {
127 let name = BuiltinName::GroupConcat.to_string();
128 let signature =
129 Signature::exact(vec![TYPED_VALUE_ENCODING.data_type()], Volatility::Stable);
130 SparqlGroupConcatWithSeparator {
131 name,
132 signature,
133 separator,
134 }
135 }
136}
137
138impl AggregateUDFImpl for SparqlGroupConcatWithSeparator {
139 fn as_any(&self) -> &dyn Any {
140 self
141 }
142
143 fn name(&self) -> &str {
144 &self.name
145 }
146
147 fn signature(&self) -> &Signature {
148 &self.signature
149 }
150
151 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
152 Ok(TYPED_VALUE_ENCODING.data_type())
153 }
154
155 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
156 Ok(Box::new(SparqlGroupConcatAccumulator::new(
157 self.separator.clone(),
158 )))
159 }
160}
161
/// Running state of one `GROUP_CONCAT` group.
#[derive(Debug)]
struct SparqlGroupConcatAccumulator {
    /// Inserted between consecutive concatenated values.
    separator: String,
    /// Set once an input could not be decoded as a string literal; the final
    /// result is then an error value.
    error: bool,
    /// Concatenation built so far.
    value: Option<String>,
    /// Set once differing language tags have been observed.
    language_error: bool,
    /// Language tag attached to the result, if any.
    language: Option<String>,
}
170
171impl SparqlGroupConcatAccumulator {
172 pub fn new(separator: String) -> Self {
173 SparqlGroupConcatAccumulator {
174 separator,
175 error: false,
176 value: None,
177 language_error: false,
178 language: None,
179 }
180 }
181}
182
183impl Accumulator for SparqlGroupConcatAccumulator {
184 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
185 if self.error || values.is_empty() {
186 return Ok(());
187 }
188
189 let mut value_exists = self.value.is_some();
190 let mut value = self.value.take().unwrap_or_default();
191
192 let arr = TYPED_VALUE_ENCODING.try_new_array(Arc::clone(&values[0]))?;
193 for string in StringLiteralRefTermValueDecoder::decode_terms(&arr) {
194 if let Ok(string) = string {
195 if value_exists {
196 value += self.separator.as_str();
197 }
198 value += string.0;
199 value_exists = true;
200 if let Some(lang) = &self.language {
201 if Some(lang.as_str()) != string.1 {
202 self.language_error = true;
203 self.language = None;
204 }
205 } else {
206 self.language = string.1.map(ToOwned::to_owned);
207 }
208 } else {
209 self.error = true;
210 self.value = None;
211 return Ok(());
212 }
213 }
214
215 self.value = Some(value);
216 Ok(())
217 }
218
219 fn evaluate(&mut self) -> DFResult<ScalarValue> {
220 if self.error {
221 return StringLiteralRefTermValueEncoder::encode_term(ThinError::expected())
222 .map(rdf_fusion_encoding::EncodingScalar::into_scalar_value);
223 }
224
225 let value = self.value.as_deref().unwrap_or("");
226 let literal = StringLiteralRef(value, self.language.as_deref());
227 StringLiteralRefTermValueEncoder::encode_term(Ok(literal))
228 .map(rdf_fusion_encoding::EncodingScalar::into_scalar_value)
229 }
230
231 fn size(&self) -> usize {
232 size_of_val(self)
233 }
234
235 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
236 Ok(vec![
237 ScalarValue::Boolean(Some(self.error)),
238 ScalarValue::Utf8(self.value.clone()),
239 ScalarValue::Boolean(Some(self.language_error)),
240 ScalarValue::Utf8(self.language.clone()),
241 ])
242 }
243
244 #[allow(clippy::missing_asserts_for_indexing)]
245 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
246 let error = states[0].as_boolean().iter().any(|e| e == Some(true));
247 if error {
248 self.error = true;
249 self.value = None;
250 self.language = None;
251 return Ok(());
252 }
253
254 let old_values = states[1].as_string::<i32>();
255 for old_value in old_values.iter().flatten() {
256 self.value = match self.value.take() {
257 None => Some(old_value.to_owned()),
258 Some(value) => Some(value + self.separator.as_str() + old_value),
259 };
260 }
261
262 let existing_language_error =
263 states[2].as_boolean().iter().any(|e| e == Some(true));
264 if existing_language_error {
265 self.language_error = true;
266 self.language = None;
267 return Ok(());
268 }
269
270 let old_languages = states[3].as_string::<i32>();
271 for old_language in old_languages {
272 self.language = match (self.language.take(), old_language) {
273 (None, other) => other.map(ToOwned::to_owned),
274 (other, None) => other,
275 (Some(language), Some(old_language)) => {
276 if language.as_str() == old_language {
277 Some(language)
278 } else {
279 self.language_error = true;
280 None
281 }
282 }
283 };
284 }
285
286 Ok(())
287 }
288}