polars_schema/
schema.rs

1use core::fmt::{Debug, Formatter};
2use core::hash::{Hash, Hasher};
3
4use indexmap::map::MutableKeys;
5use polars_error::{PolarsError, PolarsResult, polars_bail, polars_ensure, polars_err};
6use polars_utils::aliases::{InitHashMaps, PlIndexMap};
7use polars_utils::pl_str::PlSmallStr;
8
// An insertion-ordered mapping from unique field names to per-field data `D` (the methods below call it the
// field's "dtype").
//
// Field order is significant: `PartialEq` and `Hash` for `Schema` both take it into account.
#[derive(Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct Schema<D> {
    // Keys are the field names (unique by construction of the map); iteration order is the schema order.
    fields: PlIndexMap<PlSmallStr, D>,
}
15
16impl<D: Eq> Eq for Schema<D> {}
17
18impl<D> Schema<D> {
19    pub fn with_capacity(capacity: usize) -> Self {
20        let fields = PlIndexMap::with_capacity(capacity);
21        Self { fields }
22    }
23
24    /// Reserve `additional` memory spaces in the schema.
25    pub fn reserve(&mut self, additional: usize) {
26        self.fields.reserve(additional);
27    }
28
29    /// The number of fields in the schema.
30    #[inline]
31    pub fn len(&self) -> usize {
32        self.fields.len()
33    }
34
35    #[inline]
36    pub fn is_empty(&self) -> bool {
37        self.fields.is_empty()
38    }
39
40    /// Rename field `old` to `new`, and return the (owned) old name.
41    ///
42    /// If `old` is not present in the schema, the schema is not modified and `None` is returned. Otherwise the schema
43    /// is updated and `Some(old_name)` is returned.
44    pub fn rename(&mut self, old: &str, new: PlSmallStr) -> Option<PlSmallStr> {
45        // Remove `old`, get the corresponding index and dtype, and move the last item in the map to that position
46        let (old_index, old_name, dtype) = self.fields.swap_remove_full(old)?;
47        // Insert the same dtype under the new name at the end of the map and store that index
48        let (new_index, _) = self.fields.insert_full(new, dtype);
49        // Swap the two indices to move the originally last element back to the end and to move the new element back to
50        // its original position
51        self.fields.swap_indices(old_index, new_index);
52
53        Some(old_name)
54    }
55
56    pub fn insert(&mut self, key: PlSmallStr, value: D) -> Option<D> {
57        self.fields.insert(key, value)
58    }
59
60    /// Insert a field with `name` and `dtype` at the given `index` into this schema.
61    ///
62    /// If a field named `name` already exists, it is updated with the new dtype. Regardless, the field named `name` is
63    /// always moved to the given index. Valid indices range from `0` (front of the schema) to `self.len()` (after the
64    /// end of the schema).
65    ///
66    /// For a non-mutating version that clones the schema, see [`new_inserting_at_index`][Self::new_inserting_at_index].
67    ///
68    /// Runtime: **O(n)** where `n` is the number of fields in the schema.
69    ///
70    /// Returns:
71    /// - If index is out of bounds, `Err(PolarsError)`
72    /// - Else if `name` was already in the schema, `Ok(Some(old_dtype))`
73    /// - Else `Ok(None)`
74    pub fn insert_at_index(
75        &mut self,
76        mut index: usize,
77        name: PlSmallStr,
78        dtype: D,
79    ) -> PolarsResult<Option<D>> {
80        polars_ensure!(
81            index <= self.len(),
82            OutOfBounds:
83                "index {} is out of bounds for schema with length {} (the max index allowed is self.len())",
84                    index,
85                    self.len()
86        );
87
88        let (old_index, old_dtype) = self.fields.insert_full(name, dtype);
89
90        // If we're moving an existing field, one-past-the-end will actually be out of bounds. Also, self.len() won't
91        // have changed after inserting, so `index == self.len()` is the same as it was before inserting.
92        if old_dtype.is_some() && index == self.len() {
93            index -= 1;
94        }
95        self.fields.move_index(old_index, index);
96        Ok(old_dtype)
97    }
98
99    /// Get a reference to the dtype of the field named `name`, or `None` if the field doesn't exist.
100    pub fn get(&self, name: &str) -> Option<&D> {
101        self.fields.get(name)
102    }
103
104    /// Get a mutable reference to the dtype of the field named `name`, or `None` if the field doesn't exist.
105    pub fn get_mut(&mut self, name: &str) -> Option<&mut D> {
106        self.fields.get_mut(name)
107    }
108
109    /// Get a reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist.
110    pub fn try_get(&self, name: &str) -> PolarsResult<&D> {
111        self.get(name)
112            .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
113    }
114
115    /// Get a mutable reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist.
116    pub fn try_get_mut(&mut self, name: &str) -> PolarsResult<&mut D> {
117        self.fields
118            .get_mut(name)
119            .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
120    }
121
122    /// Return all data about the field named `name`: its index in the schema, its name, and its dtype.
123    ///
124    /// Returns `Some((index, &name, &dtype))` if the field exists, `None` if it doesn't.
125    pub fn get_full(&self, name: &str) -> Option<(usize, &PlSmallStr, &D)> {
126        self.fields.get_full(name)
127    }
128
129    /// Return all data about the field named `name`: its index in the schema, its name, and its dtype.
130    ///
131    /// Returns `Ok((index, &name, &dtype))` if the field exists, `Err(PolarsErr)` if it doesn't.
132    pub fn try_get_full(&self, name: &str) -> PolarsResult<(usize, &PlSmallStr, &D)> {
133        self.fields
134            .get_full(name)
135            .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
136    }
137
138    /// Get references to the name and dtype of the field at `index`.
139    ///
140    /// If `index` is inbounds, returns `Some((&name, &dtype))`, else `None`. See
141    /// [`get_at_index_mut`][Self::get_at_index_mut] for a mutable version.
142    pub fn get_at_index(&self, index: usize) -> Option<(&PlSmallStr, &D)> {
143        self.fields.get_index(index)
144    }
145
146    pub fn try_get_at_index(&self, index: usize) -> PolarsResult<(&PlSmallStr, &D)> {
147        self.fields.get_index(index).ok_or_else(|| polars_err!(ComputeError: "index {index} out of bounds with 'schema' of len: {}", self.len()))
148    }
149
150    /// Get mutable references to the name and dtype of the field at `index`.
151    ///
152    /// If `index` is inbounds, returns `Some((&mut name, &mut dtype))`, else `None`. See
153    /// [`get_at_index`][Self::get_at_index] for an immutable version.
154    pub fn get_at_index_mut(&mut self, index: usize) -> Option<(&mut PlSmallStr, &mut D)> {
155        self.fields.get_index_mut2(index)
156    }
157
158    /// Swap-remove a field by name and, if the field existed, return its dtype.
159    ///
160    /// If the field does not exist, the schema is not modified and `None` is returned.
161    ///
162    /// This method does a `swap_remove`, which is O(1) but **changes the order of the schema**: the field named `name`
163    /// is replaced by the last field, which takes its position. For a slower, but order-preserving, method, use
164    /// [`shift_remove`][Self::shift_remove].
165    pub fn remove(&mut self, name: &str) -> Option<D> {
166        self.fields.swap_remove(name)
167    }
168
169    /// Remove a field by name, preserving order, and, if the field existed, return its dtype.
170    ///
171    /// If the field does not exist, the schema is not modified and `None` is returned.
172    ///
173    /// This method does a `shift_remove`, which preserves the order of the fields in the schema but **is O(n)**. For a
174    /// faster, but not order-preserving, method, use [`remove`][Self::remove].
175    pub fn shift_remove(&mut self, name: &str) -> Option<D> {
176        self.fields.shift_remove(name)
177    }
178
179    /// Remove a field by name, preserving order, and, if the field existed, return its dtype.
180    ///
181    /// If the field does not exist, the schema is not modified and `None` is returned.
182    ///
183    /// This method does a `shift_remove`, which preserves the order of the fields in the schema but **is O(n)**. For a
184    /// faster, but not order-preserving, method, use [`remove`][Self::remove].
185    pub fn shift_remove_index(&mut self, index: usize) -> Option<(PlSmallStr, D)> {
186        self.fields.shift_remove_index(index)
187    }
188
189    /// Whether the schema contains a field named `name`.
190    pub fn contains(&self, name: &str) -> bool {
191        self.get(name).is_some()
192    }
193
194    /// Change the field named `name` to the given `dtype` and return the previous dtype.
195    ///
196    /// If `name` doesn't already exist in the schema, the schema is not modified and `None` is returned. Otherwise
197    /// returns `Some(old_dtype)`.
198    ///
199    /// This method only ever modifies an existing field and never adds a new field to the schema. To add a new field,
200    /// use [`with_column`][Self::with_column] or [`insert_at_index`][Self::insert_at_index].
201    pub fn set_dtype(&mut self, name: &str, dtype: D) -> Option<D> {
202        let old_dtype = self.fields.get_mut(name)?;
203        Some(std::mem::replace(old_dtype, dtype))
204    }
205
206    /// Change the field at the given index to the given `dtype` and return the previous dtype.
207    ///
208    /// If the index is out of bounds, the schema is not modified and `None` is returned. Otherwise returns
209    /// `Some(old_dtype)`.
210    ///
211    /// This method only ever modifies an existing index and never adds a new field to the schema. To add a new field,
212    /// use [`with_column`][Self::with_column] or [`insert_at_index`][Self::insert_at_index].
213    pub fn set_dtype_at_index(&mut self, index: usize, dtype: D) -> Option<D> {
214        let (_, old_dtype) = self.fields.get_index_mut(index)?;
215        Some(std::mem::replace(old_dtype, dtype))
216    }
217
218    /// Insert a column into the [`Schema`].
219    ///
220    /// If the schema already has this column, this instead updates it with the new value and
221    /// returns the old one. Otherwise, the column is inserted at the end.
222    ///
223    /// To enforce the index of the resulting field, use [`insert_at_index`][Self::insert_at_index].
224    pub fn with_column(&mut self, name: PlSmallStr, dtype: D) -> Option<D> {
225        self.fields.insert(name, dtype)
226    }
227
228    /// Raises DuplicateError if this column already exists in the schema.
229    pub fn try_insert(&mut self, name: PlSmallStr, value: D) -> PolarsResult<()> {
230        if self.fields.contains_key(&name) {
231            polars_bail!(Duplicate: "column '{}' is duplicate", name)
232        }
233
234        self.fields.insert(name, value);
235
236        Ok(())
237    }
238
239    /// Performs [`Schema::try_insert`] for every column.
240    ///
241    /// Raises DuplicateError if a column already exists in the schema.
242    pub fn hstack_mut(
243        &mut self,
244        columns: impl IntoIterator<Item = impl Into<(PlSmallStr, D)>>,
245    ) -> PolarsResult<()> {
246        for v in columns {
247            let (k, v) = v.into();
248            self.try_insert(k, v)?;
249        }
250
251        Ok(())
252    }
253
254    /// Performs [`Schema::try_insert`] for every column.
255    ///
256    /// Raises DuplicateError if a column already exists in the schema.
257    pub fn hstack(
258        mut self,
259        columns: impl IntoIterator<Item = impl Into<(PlSmallStr, D)>>,
260    ) -> PolarsResult<Self> {
261        self.hstack_mut(columns)?;
262        Ok(self)
263    }
264
265    /// Merge `other` into `self`.
266    ///
267    /// Merging logic:
268    /// - Fields that occur in `self` but not `other` are unmodified
269    /// - Fields that occur in `other` but not `self` are appended, in order, to the end of `self`
270    /// - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original
271    ///   index
272    pub fn merge(&mut self, other: Self) {
273        self.fields.extend(other.fields)
274    }
275
276    /// Iterates over the `(&name, &dtype)` pairs in this schema.
277    ///
278    /// For an owned version, use [`iter_fields`][Self::iter_fields], which clones the data to iterate owned `Field`s
279    pub fn iter(&self) -> impl ExactSizeIterator<Item = (&PlSmallStr, &D)> + '_ {
280        self.fields.iter()
281    }
282
283    pub fn iter_mut(&mut self) -> impl ExactSizeIterator<Item = (&PlSmallStr, &mut D)> + '_ {
284        self.fields.iter_mut()
285    }
286
287    /// Iterates over references to the names in this schema.
288    pub fn iter_names(&self) -> impl '_ + ExactSizeIterator<Item = &PlSmallStr> {
289        self.fields.iter().map(|(name, _dtype)| name)
290    }
291
292    pub fn iter_names_cloned(&self) -> impl '_ + ExactSizeIterator<Item = PlSmallStr> {
293        self.iter_names().cloned()
294    }
295
296    /// Iterates over references to the dtypes in this schema.
297    pub fn iter_values(&self) -> impl '_ + ExactSizeIterator<Item = &D> {
298        self.fields.iter().map(|(_name, dtype)| dtype)
299    }
300
301    pub fn into_iter_values(self) -> impl ExactSizeIterator<Item = D> {
302        self.fields.into_values()
303    }
304
305    /// Iterates over mut references to the dtypes in this schema.
306    pub fn iter_values_mut(&mut self) -> impl '_ + ExactSizeIterator<Item = &mut D> {
307        self.fields.iter_mut().map(|(_name, dtype)| dtype)
308    }
309
310    pub fn index_of(&self, name: &str) -> Option<usize> {
311        self.fields.get_index_of(name)
312    }
313
314    pub fn try_index_of(&self, name: &str) -> PolarsResult<usize> {
315        let Some(i) = self.fields.get_index_of(name) else {
316            polars_bail!(
317                ColumnNotFound:
318                "unable to find column {:?}; valid columns: {:?}",
319                name, self.iter_names().collect::<Vec<_>>(),
320            )
321        };
322
323        Ok(i)
324    }
325
326    /// Compare the fields between two schema returning the additional columns that each schema has.
327    pub fn field_compare<'a, 'b>(
328        &'a self,
329        other: &'b Self,
330        self_extra: &mut Vec<(usize, (&'a PlSmallStr, &'a D))>,
331        other_extra: &mut Vec<(usize, (&'b PlSmallStr, &'b D))>,
332    ) {
333        self_extra.extend(
334            self.iter()
335                .enumerate()
336                .filter(|(_, (n, _))| !other.contains(n)),
337        );
338        other_extra.extend(
339            other
340                .iter()
341                .enumerate()
342                .filter(|(_, (n, _))| !self.contains(n)),
343        );
344    }
345}
346
347impl<D> Schema<D>
348where
349    D: Clone + Default,
350{
351    /// Create a new schema from this one, inserting a field with `name` and `dtype` at the given `index`.
352    ///
353    /// If a field named `name` already exists, it is updated with the new dtype. Regardless, the field named `name` is
354    /// always moved to the given index. Valid indices range from `0` (front of the schema) to `self.len()` (after the
355    /// end of the schema).
356    ///
357    /// For a mutating version that doesn't clone, see [`insert_at_index`][Self::insert_at_index].
358    ///
359    /// Runtime: **O(m * n)** where `m` is the (average) length of the field names and `n` is the number of fields in
360    /// the schema. This method clones every field in the schema.
361    ///
362    /// Returns: `Ok(new_schema)` if `index <= self.len()`, else `Err(PolarsError)`
363    pub fn new_inserting_at_index(
364        &self,
365        index: usize,
366        name: PlSmallStr,
367        field: D,
368    ) -> PolarsResult<Self> {
369        polars_ensure!(
370            index <= self.len(),
371            OutOfBounds:
372                "index {} is out of bounds for schema with length {} (the max index allowed is self.len())",
373                    index,
374                    self.len()
375        );
376
377        let mut new = Self::default();
378        let mut iter = self.fields.iter().filter_map(|(fld_name, dtype)| {
379            (fld_name != &name).then_some((fld_name.clone(), dtype.clone()))
380        });
381        new.fields.extend(iter.by_ref().take(index));
382        new.fields.insert(name.clone(), field);
383        new.fields.extend(iter);
384        Ok(new)
385    }
386
387    /// Merge borrowed `other` into `self`.
388    ///
389    /// Merging logic:
390    /// - Fields that occur in `self` but not `other` are unmodified
391    /// - Fields that occur in `other` but not `self` are appended, in order, to the end of `self`
392    /// - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original
393    ///   index
394    pub fn merge_from_ref(&mut self, other: &Self) {
395        self.fields.extend(
396            other
397                .iter()
398                .map(|(column, field)| (column.clone(), field.clone())),
399        )
400    }
401
402    /// Generates another schema with just the specified columns selected from this one.
403    pub fn try_project<I>(&self, columns: I) -> PolarsResult<Self>
404    where
405        I: IntoIterator,
406        I::Item: AsRef<str>,
407    {
408        let schema = columns
409            .into_iter()
410            .map(|c| {
411                let name = c.as_ref();
412                let (_, name, dtype) = self
413                    .fields
414                    .get_full(name)
415                    .ok_or_else(|| polars_err!(col_not_found = name))?;
416                PolarsResult::Ok((name.clone(), dtype.clone()))
417            })
418            .collect::<PolarsResult<PlIndexMap<PlSmallStr, _>>>()?;
419        Ok(Self::from(schema))
420    }
421
422    pub fn try_project_indices(&self, indices: &[usize]) -> PolarsResult<Self> {
423        let fields = indices
424            .iter()
425            .map(|&i| {
426                let Some((k, v)) = self.fields.get_index(i) else {
427                    polars_bail!(
428                        SchemaFieldNotFound:
429                        "projection index {} is out of bounds for schema of length {}",
430                        i, self.fields.len()
431                    );
432                };
433
434                Ok((k.clone(), v.clone()))
435            })
436            .collect::<PolarsResult<PlIndexMap<_, _>>>()?;
437
438        Ok(Self { fields })
439    }
440
441    /// Returns a new [`Schema`] with a subset of all fields whose `predicate`
442    /// evaluates to true.
443    pub fn filter<F: Fn(usize, &D) -> bool>(self, predicate: F) -> Self {
444        let fields = self
445            .fields
446            .into_iter()
447            .enumerate()
448            .filter_map(|(index, (name, d))| {
449                if (predicate)(index, &d) {
450                    Some((name, d))
451                } else {
452                    None
453                }
454            })
455            .collect();
456
457        Self { fields }
458    }
459
460    pub fn from_iter_check_duplicates<I, F>(iter: I) -> PolarsResult<Self>
461    where
462        I: IntoIterator<Item = F>,
463        F: Into<(PlSmallStr, D)>,
464    {
465        let iter = iter.into_iter();
466        let mut slf = Self::with_capacity(iter.size_hint().1.unwrap_or(0));
467
468        for v in iter {
469            let (name, d) = v.into();
470
471            if slf.contains(&name) {
472                return Err(err_msg(&name));
473
474                fn err_msg(name: &str) -> PolarsError {
475                    polars_err!(Duplicate: "duplicate name when building schema '{}'", &name)
476                }
477            }
478
479            slf.fields.insert(name, d);
480        }
481
482        Ok(slf)
483    }
484}
485
486pub fn ensure_matching_schema_names<D>(lhs: &Schema<D>, rhs: &Schema<D>) -> PolarsResult<()> {
487    let lhs_names = lhs.iter_names();
488    let rhs_names = rhs.iter_names();
489
490    if !(lhs_names.len() == rhs_names.len() && lhs_names.zip(rhs_names).all(|(l, r)| l == r)) {
491        polars_bail!(
492            SchemaMismatch:
493            "lhs: {:?} rhs: {:?}",
494            lhs.iter_names().collect::<Vec<_>>(), rhs.iter_names().collect::<Vec<_>>()
495        )
496    }
497
498    Ok(())
499}
500
501impl<D: Debug> Debug for Schema<D> {
502    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
503        writeln!(f, "Schema:")?;
504        for (name, field) in self.fields.iter() {
505            writeln!(f, "name: {name}, field: {field:?}")?;
506        }
507        Ok(())
508    }
509}
510
511impl<D: Hash> Hash for Schema<D> {
512    fn hash<H: Hasher>(&self, state: &mut H) {
513        self.fields.iter().for_each(|v| v.hash(state))
514    }
515}
516
517// Schemas will only compare equal if they have the same fields in the same order. We can't use `self.inner ==
518// other.inner` because [`IndexMap`] ignores order when checking equality, but we don't want to ignore it.
519impl<D: PartialEq> PartialEq for Schema<D> {
520    fn eq(&self, other: &Self) -> bool {
521        self.fields.len() == other.fields.len()
522            && self
523                .fields
524                .iter()
525                .zip(other.fields.iter())
526                .all(|(a, b)| a == b)
527    }
528}
529
530impl<D> From<PlIndexMap<PlSmallStr, D>> for Schema<D> {
531    fn from(fields: PlIndexMap<PlSmallStr, D>) -> Self {
532        Self { fields }
533    }
534}
535
536impl<F, D> FromIterator<F> for Schema<D>
537where
538    F: Into<(PlSmallStr, D)>,
539{
540    fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
541        let fields = PlIndexMap::from_iter(iter.into_iter().map(|x| x.into()));
542        Self { fields }
543    }
544}
545
546impl<F, D> Extend<F> for Schema<D>
547where
548    F: Into<(PlSmallStr, D)>,
549{
550    fn extend<T: IntoIterator<Item = F>>(&mut self, iter: T) {
551        self.fields.extend(iter.into_iter().map(|x| x.into()))
552    }
553}
554
555impl<D> IntoIterator for Schema<D> {
556    type IntoIter = <PlIndexMap<PlSmallStr, D> as IntoIterator>::IntoIter;
557    type Item = (PlSmallStr, D);
558
559    fn into_iter(self) -> Self::IntoIter {
560        self.fields.into_iter()
561    }
562}