polars_plan/dsl/function_expr/
struct_.rs1use polars_core::utils::slice_offsets;
2use polars_utils::format_pl_smallstr;
3
4use super::*;
5use crate::{map, map_as_slice};
6
/// Struct-namespace functions (`struct.*` and `name.*_fields`) that can appear in an expression.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StructFunction {
    /// Select a single struct field by signed position.
    ///
    /// NOTE(review): converting this variant into an executable UDF panics with
    /// "should be replaced", so it is presumably rewritten into another form before
    /// execution — confirm where.
    FieldByIndex(i64),
    /// Select a single struct field by name.
    FieldByName(PlSmallStr),
    /// Rename the struct's fields positionally to the given names.
    RenameFields(Arc<[PlSmallStr]>),
    /// Prepend the given string to every field name.
    PrefixFields(PlSmallStr),
    /// Append the given string to every field name.
    SuffixFields(PlSmallStr),
    /// Serialize each struct value to a JSON string (output dtype is `String`).
    #[cfg(feature = "json")]
    JsonEncode,
    /// Add or replace fields: the first input is the struct, every further input becomes a
    /// field named after that input.
    WithFields,
    /// Select several struct fields by name.
    ///
    /// NOTE(review): schema resolution panics with "should be expanded" for this variant,
    /// so it is presumably expanded into one expression per field earlier — confirm.
    MultipleFields(Arc<[PlSmallStr]>),
}
20
21impl StructFunction {
22 pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
23 use StructFunction::*;
24
25 match self {
26 FieldByIndex(index) => mapper.try_map_field(|field| {
27 let (index, _) = slice_offsets(*index, 0, mapper.get_fields_lens());
28 if let DataType::Struct(ref fields) = field.dtype {
29 fields.get(index).cloned().ok_or_else(
30 || polars_err!(ComputeError: "index out of bounds in `struct.field`"),
31 )
32 } else {
33 polars_bail!(
34 ComputeError: "expected struct dtype, got: `{}`", &field.dtype
35 )
36 }
37 }),
38 FieldByName(name) => mapper.try_map_field(|field| {
39 if let DataType::Struct(ref fields) = field.dtype {
40 let fld = fields
41 .iter()
42 .find(|fld| fld.name() == name)
43 .ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name))?;
44 Ok(fld.clone())
45 } else {
46 polars_bail!(StructFieldNotFound: "{}", name);
47 }
48 }),
49 RenameFields(names) => mapper.map_dtype(|dt| match dt {
50 DataType::Struct(fields) => {
51 let fields = fields
52 .iter()
53 .zip(names.as_ref())
54 .map(|(fld, name)| Field::new(name.clone(), fld.dtype().clone()))
55 .collect();
56 DataType::Struct(fields)
57 },
58 dt => DataType::Struct(
62 names
63 .iter()
64 .map(|name| Field::new(name.clone(), dt.clone()))
65 .collect(),
66 ),
67 }),
68 PrefixFields(prefix) => mapper.try_map_dtype(|dt| match dt {
69 DataType::Struct(fields) => {
70 let fields = fields
71 .iter()
72 .map(|fld| {
73 let name = fld.name();
74 Field::new(format_pl_smallstr!("{prefix}{name}"), fld.dtype().clone())
75 })
76 .collect();
77 Ok(DataType::Struct(fields))
78 },
79 _ => polars_bail!(op = "prefix_fields", got = dt, expected = "Struct"),
80 }),
81 SuffixFields(suffix) => mapper.try_map_dtype(|dt| match dt {
82 DataType::Struct(fields) => {
83 let fields = fields
84 .iter()
85 .map(|fld| {
86 let name = fld.name();
87 Field::new(format_pl_smallstr!("{name}{suffix}"), fld.dtype().clone())
88 })
89 .collect();
90 Ok(DataType::Struct(fields))
91 },
92 _ => polars_bail!(op = "suffix_fields", got = dt, expected = "Struct"),
93 }),
94 #[cfg(feature = "json")]
95 JsonEncode => mapper.with_dtype(DataType::String),
96 WithFields => {
97 let args = mapper.args();
98 let struct_ = &args[0];
99
100 if let DataType::Struct(fields) = struct_.dtype() {
101 let mut name_2_dtype = PlIndexMap::with_capacity(fields.len() * 2);
102
103 for field in fields {
104 name_2_dtype.insert(field.name(), field.dtype());
105 }
106 for arg in &args[1..] {
107 name_2_dtype.insert(arg.name(), arg.dtype());
108 }
109 let dtype = DataType::Struct(
110 name_2_dtype
111 .iter()
112 .map(|(&name, &dtype)| Field::new(name.clone(), dtype.clone()))
113 .collect(),
114 );
115 let mut out = struct_.clone();
116 out.coerce(dtype);
117 Ok(out)
118 } else {
119 let dt = struct_.dtype();
120 polars_bail!(op = "with_fields", got = dt, expected = "Struct")
121 }
122 },
123 MultipleFields(_) => panic!("should be expanded"),
124 }
125 }
126
127 pub fn function_options(&self) -> FunctionOptions {
128 use StructFunction as S;
129 match self {
130 S::FieldByIndex(_) | S::FieldByName(_) => {
131 FunctionOptions::elementwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
132 },
133 S::RenameFields(_) | S::PrefixFields(_) | S::SuffixFields(_) => {
134 FunctionOptions::elementwise()
135 },
136 #[cfg(feature = "json")]
137 S::JsonEncode => FunctionOptions::elementwise(),
138 S::WithFields => FunctionOptions::elementwise().with_flags(|f| {
139 f | FunctionFlags::INPUT_WILDCARD_EXPANSION | FunctionFlags::PASS_NAME_TO_APPLY
140 }),
141 S::MultipleFields(_) => {
142 FunctionOptions::elementwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
143 },
144 }
145 }
146}
147
148impl Display for StructFunction {
149 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
150 use StructFunction::*;
151 match self {
152 FieldByIndex(index) => write!(f, "struct.field_by_index({index})"),
153 FieldByName(name) => write!(f, "struct.field_by_name({name})"),
154 RenameFields(names) => write!(f, "struct.rename_fields({:?})", names),
155 PrefixFields(_) => write!(f, "name.prefix_fields"),
156 SuffixFields(_) => write!(f, "name.suffixFields"),
157 #[cfg(feature = "json")]
158 JsonEncode => write!(f, "struct.to_json"),
159 WithFields => write!(f, "with_fields"),
160 MultipleFields(_) => write!(f, "multiple_fields"),
161 }
162 }
163}
164
165impl From<StructFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
166 fn from(func: StructFunction) -> Self {
167 use StructFunction::*;
168 match func {
169 FieldByIndex(_) => panic!("should be replaced"),
170 FieldByName(name) => map!(get_by_name, &name),
171 RenameFields(names) => map!(rename_fields, names.clone()),
172 PrefixFields(prefix) => map!(prefix_fields, prefix.as_str()),
173 SuffixFields(suffix) => map!(suffix_fields, suffix.as_str()),
174 #[cfg(feature = "json")]
175 JsonEncode => map!(to_json),
176 WithFields => map_as_slice!(with_fields),
177 MultipleFields(_) => unimplemented!(),
178 }
179 }
180}
181
182pub(super) fn get_by_name(s: &Column, name: &str) -> PolarsResult<Column> {
183 let ca = s.struct_()?;
184 ca.field_by_name(name).map(Column::from)
185}
186
187pub(super) fn rename_fields(s: &Column, names: Arc<[PlSmallStr]>) -> PolarsResult<Column> {
188 let ca = s.struct_()?;
189 let fields = ca
190 .fields_as_series()
191 .iter()
192 .zip(names.as_ref())
193 .map(|(s, name)| {
194 let mut s = s.clone();
195 s.rename(name.clone());
196 s
197 })
198 .collect::<Vec<_>>();
199 let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?;
200 out.zip_outer_validity(ca);
201 Ok(out.into_column())
202}
203
204pub(super) fn prefix_fields(s: &Column, prefix: &str) -> PolarsResult<Column> {
205 let ca = s.struct_()?;
206 let fields = ca
207 .fields_as_series()
208 .iter()
209 .map(|s| {
210 let mut s = s.clone();
211 let name = s.name();
212 s.rename(format_pl_smallstr!("{prefix}{name}"));
213 s
214 })
215 .collect::<Vec<_>>();
216 let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?;
217 out.zip_outer_validity(ca);
218 Ok(out.into_column())
219}
220
221pub(super) fn suffix_fields(s: &Column, suffix: &str) -> PolarsResult<Column> {
222 let ca = s.struct_()?;
223 let fields = ca
224 .fields_as_series()
225 .iter()
226 .map(|s| {
227 let mut s = s.clone();
228 let name = s.name();
229 s.rename(format_pl_smallstr!("{name}{suffix}"));
230 s
231 })
232 .collect::<Vec<_>>();
233 let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?;
234 out.zip_outer_validity(ca);
235 Ok(out.into_column())
236}
237
/// Serialize every value of the struct column `s` to a JSON string.
///
/// # Errors
/// Fails if `s` is not a struct column, or if a chunk cannot be cast to the struct's Arrow
/// representation (previously this case panicked).
#[cfg(feature = "json")]
pub(super) fn to_json(s: &Column) -> PolarsResult<Column> {
    let ca = s.struct_()?;
    let dtype = ca.dtype().to_arrow(CompatLevel::newest());

    // Cast each chunk to the struct's Arrow type before serializing it. A failing cast is
    // returned as an error instead of aborting the whole query with a panic.
    let chunks = ca
        .chunks()
        .iter()
        .map(|arr| -> PolarsResult<_> {
            let arr = polars_compute::cast::cast_unchecked(arr.as_ref(), &dtype)?;
            Ok(polars_json::json::write::serialize_to_utf8(arr.as_ref()))
        })
        .collect::<PolarsResult<Vec<_>>>()?;

    Ok(StringChunked::from_chunk_iter(ca.name().clone(), chunks).into_column())
}
250
251pub(super) fn with_fields(args: &[Column]) -> PolarsResult<Column> {
252 let s = &args[0];
253
254 let ca = s.struct_()?;
255 let current = ca.fields_as_series();
256
257 let mut fields = PlIndexMap::with_capacity(current.len() + s.len() - 1);
258
259 for field in current.iter() {
260 fields.insert(field.name(), field);
261 }
262
263 for field in &args[1..] {
264 fields.insert(field.name(), field.as_materialized_series());
265 }
266
267 let new_fields = fields.into_values().cloned().collect::<Vec<_>>();
268 let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), new_fields.iter())?;
269 out.zip_outer_validity(ca);
270 Ok(out.into_column())
271}