// polars_plan/dsl/function_expr/struct_.rs

1use polars_core::utils::slice_offsets;
2use polars_utils::format_pl_smallstr;
3
4use super::*;
5use crate::{map, map_as_slice};
6
/// Expression-level operations on `Struct` columns.
///
/// `StructFunction::get_field` derives each variant's output field, and the
/// `From<StructFunction> for SpecialEq<Arc<dyn ColumnsUdf>>` impl lowers a variant to
/// the kernel that executes it.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StructFunction {
    /// Select a field by position. The `i64` is resolved through `slice_offsets`, so
    /// negative values presumably count from the end — TODO confirm.
    /// Must be replaced before execution: lowering this variant to a kernel panics.
    FieldByIndex(i64),
    /// Select the field with the given name.
    FieldByName(PlSmallStr),
    /// Rename the struct's fields positionally to the given names.
    RenameFields(Arc<[PlSmallStr]>),
    /// Prepend the given string to every field name.
    PrefixFields(PlSmallStr),
    /// Append the given string to every field name.
    SuffixFields(PlSmallStr),
    /// Serialize the struct column to JSON text (output dtype `String`).
    #[cfg(feature = "json")]
    JsonEncode,
    /// Add fields from the extra input expressions, replacing same-named ones.
    WithFields,
    /// Select several fields by name. Must be expanded before schema resolution or
    /// execution: both `get_field` and kernel lowering abort on this variant.
    MultipleFields(Arc<[PlSmallStr]>),
}
20
21impl StructFunction {
22    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
23        use StructFunction::*;
24
25        match self {
26            FieldByIndex(index) => mapper.try_map_field(|field| {
27                let (index, _) = slice_offsets(*index, 0, mapper.get_fields_lens());
28                if let DataType::Struct(ref fields) = field.dtype {
29                    fields.get(index).cloned().ok_or_else(
30                        || polars_err!(ComputeError: "index out of bounds in `struct.field`"),
31                    )
32                } else {
33                    polars_bail!(
34                        ComputeError: "expected struct dtype, got: `{}`", &field.dtype
35                    )
36                }
37            }),
38            FieldByName(name) => mapper.try_map_field(|field| {
39                if let DataType::Struct(ref fields) = field.dtype {
40                    let fld = fields
41                        .iter()
42                        .find(|fld| fld.name() == name)
43                        .ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name))?;
44                    Ok(fld.clone())
45                } else {
46                    polars_bail!(StructFieldNotFound: "{}", name);
47                }
48            }),
49            RenameFields(names) => mapper.map_dtype(|dt| match dt {
50                DataType::Struct(fields) => {
51                    let fields = fields
52                        .iter()
53                        .zip(names.as_ref())
54                        .map(|(fld, name)| Field::new(name.clone(), fld.dtype().clone()))
55                        .collect();
56                    DataType::Struct(fields)
57                },
58                // The types will be incorrect, but its better than nothing
59                // we can get an incorrect type with python lambdas, because we only know return type when running
60                // the query
61                dt => DataType::Struct(
62                    names
63                        .iter()
64                        .map(|name| Field::new(name.clone(), dt.clone()))
65                        .collect(),
66                ),
67            }),
68            PrefixFields(prefix) => mapper.try_map_dtype(|dt| match dt {
69                DataType::Struct(fields) => {
70                    let fields = fields
71                        .iter()
72                        .map(|fld| {
73                            let name = fld.name();
74                            Field::new(format_pl_smallstr!("{prefix}{name}"), fld.dtype().clone())
75                        })
76                        .collect();
77                    Ok(DataType::Struct(fields))
78                },
79                _ => polars_bail!(op = "prefix_fields", got = dt, expected = "Struct"),
80            }),
81            SuffixFields(suffix) => mapper.try_map_dtype(|dt| match dt {
82                DataType::Struct(fields) => {
83                    let fields = fields
84                        .iter()
85                        .map(|fld| {
86                            let name = fld.name();
87                            Field::new(format_pl_smallstr!("{name}{suffix}"), fld.dtype().clone())
88                        })
89                        .collect();
90                    Ok(DataType::Struct(fields))
91                },
92                _ => polars_bail!(op = "suffix_fields", got = dt, expected = "Struct"),
93            }),
94            #[cfg(feature = "json")]
95            JsonEncode => mapper.with_dtype(DataType::String),
96            WithFields => {
97                let args = mapper.args();
98                let struct_ = &args[0];
99
100                if let DataType::Struct(fields) = struct_.dtype() {
101                    let mut name_2_dtype = PlIndexMap::with_capacity(fields.len() * 2);
102
103                    for field in fields {
104                        name_2_dtype.insert(field.name(), field.dtype());
105                    }
106                    for arg in &args[1..] {
107                        name_2_dtype.insert(arg.name(), arg.dtype());
108                    }
109                    let dtype = DataType::Struct(
110                        name_2_dtype
111                            .iter()
112                            .map(|(&name, &dtype)| Field::new(name.clone(), dtype.clone()))
113                            .collect(),
114                    );
115                    let mut out = struct_.clone();
116                    out.coerce(dtype);
117                    Ok(out)
118                } else {
119                    let dt = struct_.dtype();
120                    polars_bail!(op = "with_fields", got = dt, expected = "Struct")
121                }
122            },
123            MultipleFields(_) => panic!("should be expanded"),
124        }
125    }
126
127    pub fn function_options(&self) -> FunctionOptions {
128        use StructFunction as S;
129        match self {
130            S::FieldByIndex(_) | S::FieldByName(_) => {
131                FunctionOptions::elementwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
132            },
133            S::RenameFields(_) | S::PrefixFields(_) | S::SuffixFields(_) => {
134                FunctionOptions::elementwise()
135            },
136            #[cfg(feature = "json")]
137            S::JsonEncode => FunctionOptions::elementwise(),
138            S::WithFields => FunctionOptions::elementwise().with_flags(|f| {
139                f | FunctionFlags::INPUT_WILDCARD_EXPANSION | FunctionFlags::PASS_NAME_TO_APPLY
140            }),
141            S::MultipleFields(_) => {
142                FunctionOptions::elementwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
143            },
144        }
145    }
146}
147
148impl Display for StructFunction {
149    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
150        use StructFunction::*;
151        match self {
152            FieldByIndex(index) => write!(f, "struct.field_by_index({index})"),
153            FieldByName(name) => write!(f, "struct.field_by_name({name})"),
154            RenameFields(names) => write!(f, "struct.rename_fields({:?})", names),
155            PrefixFields(_) => write!(f, "name.prefix_fields"),
156            SuffixFields(_) => write!(f, "name.suffixFields"),
157            #[cfg(feature = "json")]
158            JsonEncode => write!(f, "struct.to_json"),
159            WithFields => write!(f, "with_fields"),
160            MultipleFields(_) => write!(f, "multiple_fields"),
161        }
162    }
163}
164
165impl From<StructFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
166    fn from(func: StructFunction) -> Self {
167        use StructFunction::*;
168        match func {
169            FieldByIndex(_) => panic!("should be replaced"),
170            FieldByName(name) => map!(get_by_name, &name),
171            RenameFields(names) => map!(rename_fields, names.clone()),
172            PrefixFields(prefix) => map!(prefix_fields, prefix.as_str()),
173            SuffixFields(suffix) => map!(suffix_fields, suffix.as_str()),
174            #[cfg(feature = "json")]
175            JsonEncode => map!(to_json),
176            WithFields => map_as_slice!(with_fields),
177            MultipleFields(_) => unimplemented!(),
178        }
179    }
180}
181
182pub(super) fn get_by_name(s: &Column, name: &str) -> PolarsResult<Column> {
183    let ca = s.struct_()?;
184    ca.field_by_name(name).map(Column::from)
185}
186
187pub(super) fn rename_fields(s: &Column, names: Arc<[PlSmallStr]>) -> PolarsResult<Column> {
188    let ca = s.struct_()?;
189    let fields = ca
190        .fields_as_series()
191        .iter()
192        .zip(names.as_ref())
193        .map(|(s, name)| {
194            let mut s = s.clone();
195            s.rename(name.clone());
196            s
197        })
198        .collect::<Vec<_>>();
199    let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?;
200    out.zip_outer_validity(ca);
201    Ok(out.into_column())
202}
203
204pub(super) fn prefix_fields(s: &Column, prefix: &str) -> PolarsResult<Column> {
205    let ca = s.struct_()?;
206    let fields = ca
207        .fields_as_series()
208        .iter()
209        .map(|s| {
210            let mut s = s.clone();
211            let name = s.name();
212            s.rename(format_pl_smallstr!("{prefix}{name}"));
213            s
214        })
215        .collect::<Vec<_>>();
216    let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?;
217    out.zip_outer_validity(ca);
218    Ok(out.into_column())
219}
220
221pub(super) fn suffix_fields(s: &Column, suffix: &str) -> PolarsResult<Column> {
222    let ca = s.struct_()?;
223    let fields = ca
224        .fields_as_series()
225        .iter()
226        .map(|s| {
227            let mut s = s.clone();
228            let name = s.name();
229            s.rename(format_pl_smallstr!("{name}{suffix}"));
230            s
231        })
232        .collect::<Vec<_>>();
233    let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?;
234    out.zip_outer_validity(ca);
235    Ok(out.into_column())
236}
237
/// Serialize struct column `s` to a `String` column of JSON text.
///
/// Each chunk is first cast to the struct's arrow dtype at `CompatLevel::newest()` and
/// then serialized with `polars_json`.
///
/// # Errors
/// Fails if `s` is not a struct column or if casting a chunk fails. A failed cast is
/// reported as an error instead of panicking.
#[cfg(feature = "json")]
pub(super) fn to_json(s: &Column) -> PolarsResult<Column> {
    let ca = s.struct_()?;
    let dtype = ca.dtype().to_arrow(CompatLevel::newest());

    let chunks = ca
        .chunks()
        .iter()
        .map(|arr| -> PolarsResult<_> {
            let arr = polars_compute::cast::cast_unchecked(arr.as_ref(), &dtype)?;
            Ok(polars_json::json::write::serialize_to_utf8(arr.as_ref()))
        })
        .collect::<PolarsResult<Vec<_>>>()?;

    Ok(StringChunked::from_chunk_iter(ca.name().clone(), chunks).into_column())
}
250
251pub(super) fn with_fields(args: &[Column]) -> PolarsResult<Column> {
252    let s = &args[0];
253
254    let ca = s.struct_()?;
255    let current = ca.fields_as_series();
256
257    let mut fields = PlIndexMap::with_capacity(current.len() + s.len() - 1);
258
259    for field in current.iter() {
260        fields.insert(field.name(), field);
261    }
262
263    for field in &args[1..] {
264        fields.insert(field.name(), field.as_materialized_series());
265    }
266
267    let new_fields = fields.into_values().cloned().collect::<Vec<_>>();
268    let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), new_fields.iter())?;
269    out.zip_outer_validity(ca);
270    Ok(out.into_column())
271}