polars_plan/dsl/function_expr/struct_.rs

1use polars_core::utils::slice_offsets;
2use polars_utils::format_pl_smallstr;
3
4use super::*;
5use crate::{map, map_as_slice};
6
/// Expression-level operations on struct-typed columns.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StructFunction {
    /// Select one field by position; negative values count from the end.
    /// Must be replaced (e.g. by a by-name selection) before it is lowered to a UDF.
    FieldByIndex(i64),
    /// Select one field by name.
    FieldByName(PlSmallStr),
    /// Rename the fields positionally: the `i`-th field receives the `i`-th name.
    RenameFields(Arc<[PlSmallStr]>),
    /// Prepend the given string to every field name.
    PrefixFields(PlSmallStr),
    /// Append the given string to every field name.
    SuffixFields(PlSmallStr),
    /// Encode the struct column as a string column of JSON.
    #[cfg(feature = "json")]
    JsonEncode,
    /// Add fields from the additional input columns, replacing fields with the same name.
    WithFields,
    /// Select several fields (presumably by the given names — confirm against the expander).
    /// Must be expanded into single-field selections before schema resolution or execution.
    MultipleFields(Arc<[PlSmallStr]>),
}
20
21impl StructFunction {
22    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
23        use StructFunction::*;
24
25        match self {
26            FieldByIndex(index) => mapper.try_map_field(|field| {
27                let (index, _) = slice_offsets(*index, 0, mapper.get_fields_lens());
28                if let DataType::Struct(ref fields) = field.dtype {
29                    fields.get(index).cloned().ok_or_else(
30                        || polars_err!(ComputeError: "index out of bounds in `struct.field`"),
31                    )
32                } else {
33                    polars_bail!(
34                        ComputeError: "expected struct dtype, got: `{}`", &field.dtype
35                    )
36                }
37            }),
38            FieldByName(name) => mapper.try_map_field(|field| {
39                if let DataType::Struct(ref fields) = field.dtype {
40                    let fld = fields
41                        .iter()
42                        .find(|fld| fld.name() == name)
43                        .ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name))?;
44                    Ok(fld.clone())
45                } else {
46                    polars_bail!(StructFieldNotFound: "{}", name);
47                }
48            }),
49            RenameFields(names) => mapper.map_dtype(|dt| match dt {
50                DataType::Struct(fields) => {
51                    let fields = fields
52                        .iter()
53                        .zip(names.as_ref())
54                        .map(|(fld, name)| Field::new(name.clone(), fld.dtype().clone()))
55                        .collect();
56                    DataType::Struct(fields)
57                },
58                // The types will be incorrect, but its better than nothing
59                // we can get an incorrect type with python lambdas, because we only know return type when running
60                // the query
61                dt => DataType::Struct(
62                    names
63                        .iter()
64                        .map(|name| Field::new(name.clone(), dt.clone()))
65                        .collect(),
66                ),
67            }),
68            PrefixFields(prefix) => mapper.try_map_dtype(|dt| match dt {
69                DataType::Struct(fields) => {
70                    let fields = fields
71                        .iter()
72                        .map(|fld| {
73                            let name = fld.name();
74                            Field::new(format_pl_smallstr!("{prefix}{name}"), fld.dtype().clone())
75                        })
76                        .collect();
77                    Ok(DataType::Struct(fields))
78                },
79                _ => polars_bail!(op = "prefix_fields", got = dt, expected = "Struct"),
80            }),
81            SuffixFields(suffix) => mapper.try_map_dtype(|dt| match dt {
82                DataType::Struct(fields) => {
83                    let fields = fields
84                        .iter()
85                        .map(|fld| {
86                            let name = fld.name();
87                            Field::new(format_pl_smallstr!("{name}{suffix}"), fld.dtype().clone())
88                        })
89                        .collect();
90                    Ok(DataType::Struct(fields))
91                },
92                _ => polars_bail!(op = "suffix_fields", got = dt, expected = "Struct"),
93            }),
94            #[cfg(feature = "json")]
95            JsonEncode => mapper.with_dtype(DataType::String),
96            WithFields => {
97                let args = mapper.args();
98                let struct_ = &args[0];
99
100                if let DataType::Struct(fields) = struct_.dtype() {
101                    let mut name_2_dtype = PlIndexMap::with_capacity(fields.len() * 2);
102
103                    for field in fields {
104                        name_2_dtype.insert(field.name(), field.dtype());
105                    }
106                    for arg in &args[1..] {
107                        name_2_dtype.insert(arg.name(), arg.dtype());
108                    }
109                    let dtype = DataType::Struct(
110                        name_2_dtype
111                            .iter()
112                            .map(|(&name, &dtype)| Field::new(name.clone(), dtype.clone()))
113                            .collect(),
114                    );
115                    let mut out = struct_.clone();
116                    out.coerce(dtype);
117                    Ok(out)
118                } else {
119                    let dt = struct_.dtype();
120                    polars_bail!(op = "with_fields", got = dt, expected = "Struct")
121                }
122            },
123            MultipleFields(_) => panic!("should be expanded"),
124        }
125    }
126}
127
128impl Display for StructFunction {
129    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
130        use StructFunction::*;
131        match self {
132            FieldByIndex(index) => write!(f, "struct.field_by_index({index})"),
133            FieldByName(name) => write!(f, "struct.field_by_name({name})"),
134            RenameFields(names) => write!(f, "struct.rename_fields({:?})", names),
135            PrefixFields(_) => write!(f, "name.prefix_fields"),
136            SuffixFields(_) => write!(f, "name.suffixFields"),
137            #[cfg(feature = "json")]
138            JsonEncode => write!(f, "struct.to_json"),
139            WithFields => write!(f, "with_fields"),
140            MultipleFields(_) => write!(f, "multiple_fields"),
141        }
142    }
143}
144
145impl From<StructFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
146    fn from(func: StructFunction) -> Self {
147        use StructFunction::*;
148        match func {
149            FieldByIndex(_) => panic!("should be replaced"),
150            FieldByName(name) => map!(get_by_name, &name),
151            RenameFields(names) => map!(rename_fields, names.clone()),
152            PrefixFields(prefix) => map!(prefix_fields, prefix.as_str()),
153            SuffixFields(suffix) => map!(suffix_fields, suffix.as_str()),
154            #[cfg(feature = "json")]
155            JsonEncode => map!(to_json),
156            WithFields => map_as_slice!(with_fields),
157            MultipleFields(_) => unimplemented!(),
158        }
159    }
160}
161
162pub(super) fn get_by_name(s: &Column, name: &str) -> PolarsResult<Column> {
163    let ca = s.struct_()?;
164    ca.field_by_name(name).map(Column::from)
165}
166
167pub(super) fn rename_fields(s: &Column, names: Arc<[PlSmallStr]>) -> PolarsResult<Column> {
168    let ca = s.struct_()?;
169    let fields = ca
170        .fields_as_series()
171        .iter()
172        .zip(names.as_ref())
173        .map(|(s, name)| {
174            let mut s = s.clone();
175            s.rename(name.clone());
176            s
177        })
178        .collect::<Vec<_>>();
179    let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?;
180    out.zip_outer_validity(ca);
181    Ok(out.into_column())
182}
183
184pub(super) fn prefix_fields(s: &Column, prefix: &str) -> PolarsResult<Column> {
185    let ca = s.struct_()?;
186    let fields = ca
187        .fields_as_series()
188        .iter()
189        .map(|s| {
190            let mut s = s.clone();
191            let name = s.name();
192            s.rename(format_pl_smallstr!("{prefix}{name}"));
193            s
194        })
195        .collect::<Vec<_>>();
196    let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?;
197    out.zip_outer_validity(ca);
198    Ok(out.into_column())
199}
200
201pub(super) fn suffix_fields(s: &Column, suffix: &str) -> PolarsResult<Column> {
202    let ca = s.struct_()?;
203    let fields = ca
204        .fields_as_series()
205        .iter()
206        .map(|s| {
207            let mut s = s.clone();
208            let name = s.name();
209            s.rename(format_pl_smallstr!("{name}{suffix}"));
210            s
211        })
212        .collect::<Vec<_>>();
213    let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?;
214    out.zip_outer_validity(ca);
215    Ok(out.into_column())
216}
217
/// Encodes the struct column `s` as a string column of JSON, one chunk at a time.
///
/// Each chunk is first cast to the struct's Arrow dtype at `CompatLevel::newest()` and then
/// serialized with `polars_json`.
///
/// # Errors
/// Fails if `s` is not a struct column.
///
/// # Panics
/// Panics if casting a chunk to the struct's Arrow dtype fails.
#[cfg(feature = "json")]
pub(super) fn to_json(s: &Column) -> PolarsResult<Column> {
    let struct_ca = s.struct_()?;
    let arrow_dtype = struct_ca.dtype().to_arrow(CompatLevel::newest());

    let json_chunks = struct_ca.chunks().iter().map(|chunk| {
        let normalized =
            polars_compute::cast::cast_unchecked(chunk.as_ref(), &arrow_dtype).unwrap();
        polars_json::json::write::serialize_to_utf8(normalized.as_ref())
    });

    let out = StringChunked::from_chunk_iter(struct_ca.name().clone(), json_chunks);
    Ok(out.into_column())
}
230
231pub(super) fn with_fields(args: &[Column]) -> PolarsResult<Column> {
232    let s = &args[0];
233
234    let ca = s.struct_()?;
235    let current = ca.fields_as_series();
236
237    let mut fields = PlIndexMap::with_capacity(current.len() + s.len() - 1);
238
239    for field in current.iter() {
240        fields.insert(field.name(), field);
241    }
242
243    for field in &args[1..] {
244        fields.insert(field.name(), field.as_materialized_series());
245    }
246
247    let new_fields = fields.into_values().cloned().collect::<Vec<_>>();
248    let mut out =
249        StructChunked::from_series(ca.name().clone(), new_fields[0].len(), new_fields.iter())?;
250    out.zip_outer_validity(ca);
251    Ok(out.into_column())
252}