1use core::fmt::{Debug, Formatter};
2use core::hash::{Hash, Hasher};
3
4use indexmap::map::MutableKeys;
5use polars_error::{PolarsError, PolarsResult, polars_bail, polars_ensure, polars_err};
6use polars_utils::aliases::{InitHashMaps, PlIndexMap};
7use polars_utils::pl_str::PlSmallStr;
8
9#[derive(Clone, Default)]
10#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
11#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
12pub struct Schema<D> {
13 fields: PlIndexMap<PlSmallStr, D>,
14}
15
16impl<D: Eq> Eq for Schema<D> {}
17
18impl<D> Schema<D> {
19 pub fn with_capacity(capacity: usize) -> Self {
20 let fields = PlIndexMap::with_capacity(capacity);
21 Self { fields }
22 }
23
24 pub fn reserve(&mut self, additional: usize) {
26 self.fields.reserve(additional);
27 }
28
29 #[inline]
31 pub fn len(&self) -> usize {
32 self.fields.len()
33 }
34
35 #[inline]
36 pub fn is_empty(&self) -> bool {
37 self.fields.is_empty()
38 }
39
40 pub fn rename(&mut self, old: &str, new: PlSmallStr) -> Option<PlSmallStr> {
45 let (old_index, old_name, dtype) = self.fields.swap_remove_full(old)?;
47 let (new_index, _) = self.fields.insert_full(new, dtype);
49 self.fields.swap_indices(old_index, new_index);
52
53 Some(old_name)
54 }
55
56 pub fn insert(&mut self, key: PlSmallStr, value: D) -> Option<D> {
57 self.fields.insert(key, value)
58 }
59
60 pub fn insert_at_index(
75 &mut self,
76 mut index: usize,
77 name: PlSmallStr,
78 dtype: D,
79 ) -> PolarsResult<Option<D>> {
80 polars_ensure!(
81 index <= self.len(),
82 OutOfBounds:
83 "index {} is out of bounds for schema with length {} (the max index allowed is self.len())",
84 index,
85 self.len()
86 );
87
88 let (old_index, old_dtype) = self.fields.insert_full(name, dtype);
89
90 if old_dtype.is_some() && index == self.len() {
93 index -= 1;
94 }
95 self.fields.move_index(old_index, index);
96 Ok(old_dtype)
97 }
98
99 pub fn get(&self, name: &str) -> Option<&D> {
101 self.fields.get(name)
102 }
103
104 pub fn get_mut(&mut self, name: &str) -> Option<&mut D> {
106 self.fields.get_mut(name)
107 }
108
109 pub fn try_get(&self, name: &str) -> PolarsResult<&D> {
111 self.get(name)
112 .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
113 }
114
115 pub fn try_get_mut(&mut self, name: &str) -> PolarsResult<&mut D> {
117 self.fields
118 .get_mut(name)
119 .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
120 }
121
122 pub fn get_full(&self, name: &str) -> Option<(usize, &PlSmallStr, &D)> {
126 self.fields.get_full(name)
127 }
128
129 pub fn try_get_full(&self, name: &str) -> PolarsResult<(usize, &PlSmallStr, &D)> {
133 self.fields
134 .get_full(name)
135 .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
136 }
137
138 pub fn get_at_index(&self, index: usize) -> Option<(&PlSmallStr, &D)> {
143 self.fields.get_index(index)
144 }
145
146 pub fn try_get_at_index(&self, index: usize) -> PolarsResult<(&PlSmallStr, &D)> {
147 self.fields.get_index(index).ok_or_else(|| polars_err!(ComputeError: "index {index} out of bounds with 'schema' of len: {}", self.len()))
148 }
149
150 pub fn get_at_index_mut(&mut self, index: usize) -> Option<(&mut PlSmallStr, &mut D)> {
155 self.fields.get_index_mut2(index)
156 }
157
158 pub fn remove(&mut self, name: &str) -> Option<D> {
166 self.fields.swap_remove(name)
167 }
168
169 pub fn shift_remove(&mut self, name: &str) -> Option<D> {
176 self.fields.shift_remove(name)
177 }
178
179 pub fn shift_remove_index(&mut self, index: usize) -> Option<(PlSmallStr, D)> {
186 self.fields.shift_remove_index(index)
187 }
188
189 pub fn contains(&self, name: &str) -> bool {
191 self.get(name).is_some()
192 }
193
194 pub fn set_dtype(&mut self, name: &str, dtype: D) -> Option<D> {
202 let old_dtype = self.fields.get_mut(name)?;
203 Some(std::mem::replace(old_dtype, dtype))
204 }
205
206 pub fn set_dtype_at_index(&mut self, index: usize, dtype: D) -> Option<D> {
214 let (_, old_dtype) = self.fields.get_index_mut(index)?;
215 Some(std::mem::replace(old_dtype, dtype))
216 }
217
218 pub fn with_column(&mut self, name: PlSmallStr, dtype: D) -> Option<D> {
225 self.fields.insert(name, dtype)
226 }
227
228 pub fn try_insert(&mut self, name: PlSmallStr, value: D) -> PolarsResult<()> {
230 if self.fields.contains_key(&name) {
231 polars_bail!(Duplicate: "column '{}' is duplicate", name)
232 }
233
234 self.fields.insert(name, value);
235
236 Ok(())
237 }
238
239 pub fn hstack_mut(
243 &mut self,
244 columns: impl IntoIterator<Item = impl Into<(PlSmallStr, D)>>,
245 ) -> PolarsResult<()> {
246 for v in columns {
247 let (k, v) = v.into();
248 self.try_insert(k, v)?;
249 }
250
251 Ok(())
252 }
253
254 pub fn hstack(
258 mut self,
259 columns: impl IntoIterator<Item = impl Into<(PlSmallStr, D)>>,
260 ) -> PolarsResult<Self> {
261 self.hstack_mut(columns)?;
262 Ok(self)
263 }
264
265 pub fn merge(&mut self, other: Self) {
273 self.fields.extend(other.fields)
274 }
275
276 pub fn iter(&self) -> impl ExactSizeIterator<Item = (&PlSmallStr, &D)> + '_ {
280 self.fields.iter()
281 }
282
283 pub fn iter_mut(&mut self) -> impl ExactSizeIterator<Item = (&PlSmallStr, &mut D)> + '_ {
284 self.fields.iter_mut()
285 }
286
287 pub fn iter_names(&self) -> impl '_ + ExactSizeIterator<Item = &PlSmallStr> {
289 self.fields.iter().map(|(name, _dtype)| name)
290 }
291
292 pub fn iter_names_cloned(&self) -> impl '_ + ExactSizeIterator<Item = PlSmallStr> {
293 self.iter_names().cloned()
294 }
295
296 pub fn iter_values(&self) -> impl '_ + ExactSizeIterator<Item = &D> {
298 self.fields.iter().map(|(_name, dtype)| dtype)
299 }
300
301 pub fn into_iter_values(self) -> impl ExactSizeIterator<Item = D> {
302 self.fields.into_values()
303 }
304
305 pub fn iter_values_mut(&mut self) -> impl '_ + ExactSizeIterator<Item = &mut D> {
307 self.fields.iter_mut().map(|(_name, dtype)| dtype)
308 }
309
310 pub fn index_of(&self, name: &str) -> Option<usize> {
311 self.fields.get_index_of(name)
312 }
313
314 pub fn try_index_of(&self, name: &str) -> PolarsResult<usize> {
315 let Some(i) = self.fields.get_index_of(name) else {
316 polars_bail!(
317 ColumnNotFound:
318 "unable to find column {:?}; valid columns: {:?}",
319 name, self.iter_names().collect::<Vec<_>>(),
320 )
321 };
322
323 Ok(i)
324 }
325
326 pub fn field_compare<'a, 'b>(
328 &'a self,
329 other: &'b Self,
330 self_extra: &mut Vec<(usize, (&'a PlSmallStr, &'a D))>,
331 other_extra: &mut Vec<(usize, (&'b PlSmallStr, &'b D))>,
332 ) {
333 self_extra.extend(
334 self.iter()
335 .enumerate()
336 .filter(|(_, (n, _))| !other.contains(n)),
337 );
338 other_extra.extend(
339 other
340 .iter()
341 .enumerate()
342 .filter(|(_, (n, _))| !self.contains(n)),
343 );
344 }
345}
346
347impl<D> Schema<D>
348where
349 D: Clone + Default,
350{
351 pub fn new_inserting_at_index(
364 &self,
365 index: usize,
366 name: PlSmallStr,
367 field: D,
368 ) -> PolarsResult<Self> {
369 polars_ensure!(
370 index <= self.len(),
371 OutOfBounds:
372 "index {} is out of bounds for schema with length {} (the max index allowed is self.len())",
373 index,
374 self.len()
375 );
376
377 let mut new = Self::default();
378 let mut iter = self.fields.iter().filter_map(|(fld_name, dtype)| {
379 (fld_name != &name).then_some((fld_name.clone(), dtype.clone()))
380 });
381 new.fields.extend(iter.by_ref().take(index));
382 new.fields.insert(name.clone(), field);
383 new.fields.extend(iter);
384 Ok(new)
385 }
386
387 pub fn merge_from_ref(&mut self, other: &Self) {
395 self.fields.extend(
396 other
397 .iter()
398 .map(|(column, field)| (column.clone(), field.clone())),
399 )
400 }
401
402 pub fn try_project<I>(&self, columns: I) -> PolarsResult<Self>
404 where
405 I: IntoIterator,
406 I::Item: AsRef<str>,
407 {
408 let schema = columns
409 .into_iter()
410 .map(|c| {
411 let name = c.as_ref();
412 let (_, name, dtype) = self
413 .fields
414 .get_full(name)
415 .ok_or_else(|| polars_err!(col_not_found = name))?;
416 PolarsResult::Ok((name.clone(), dtype.clone()))
417 })
418 .collect::<PolarsResult<PlIndexMap<PlSmallStr, _>>>()?;
419 Ok(Self::from(schema))
420 }
421
422 pub fn try_project_indices(&self, indices: &[usize]) -> PolarsResult<Self> {
423 let fields = indices
424 .iter()
425 .map(|&i| {
426 let Some((k, v)) = self.fields.get_index(i) else {
427 polars_bail!(
428 SchemaFieldNotFound:
429 "projection index {} is out of bounds for schema of length {}",
430 i, self.fields.len()
431 );
432 };
433
434 Ok((k.clone(), v.clone()))
435 })
436 .collect::<PolarsResult<PlIndexMap<_, _>>>()?;
437
438 Ok(Self { fields })
439 }
440
441 pub fn filter<F: Fn(usize, &D) -> bool>(self, predicate: F) -> Self {
444 let fields = self
445 .fields
446 .into_iter()
447 .enumerate()
448 .filter_map(|(index, (name, d))| {
449 if (predicate)(index, &d) {
450 Some((name, d))
451 } else {
452 None
453 }
454 })
455 .collect();
456
457 Self { fields }
458 }
459
460 pub fn from_iter_check_duplicates<I, F>(iter: I) -> PolarsResult<Self>
461 where
462 I: IntoIterator<Item = F>,
463 F: Into<(PlSmallStr, D)>,
464 {
465 let iter = iter.into_iter();
466 let mut slf = Self::with_capacity(iter.size_hint().1.unwrap_or(0));
467
468 for v in iter {
469 let (name, d) = v.into();
470
471 if slf.contains(&name) {
472 return Err(err_msg(&name));
473
474 fn err_msg(name: &str) -> PolarsError {
475 polars_err!(Duplicate: "duplicate name when building schema '{}'", &name)
476 }
477 }
478
479 slf.fields.insert(name, d);
480 }
481
482 Ok(slf)
483 }
484}
485
486pub fn ensure_matching_schema_names<D>(lhs: &Schema<D>, rhs: &Schema<D>) -> PolarsResult<()> {
487 let lhs_names = lhs.iter_names();
488 let rhs_names = rhs.iter_names();
489
490 if !(lhs_names.len() == rhs_names.len() && lhs_names.zip(rhs_names).all(|(l, r)| l == r)) {
491 polars_bail!(
492 SchemaMismatch:
493 "lhs: {:?} rhs: {:?}",
494 lhs.iter_names().collect::<Vec<_>>(), rhs.iter_names().collect::<Vec<_>>()
495 )
496 }
497
498 Ok(())
499}
500
501impl<D: Debug> Debug for Schema<D> {
502 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
503 writeln!(f, "Schema:")?;
504 for (name, field) in self.fields.iter() {
505 writeln!(f, "name: {name}, field: {field:?}")?;
506 }
507 Ok(())
508 }
509}
510
511impl<D: Hash> Hash for Schema<D> {
512 fn hash<H: Hasher>(&self, state: &mut H) {
513 self.fields.iter().for_each(|v| v.hash(state))
514 }
515}
516
517impl<D: PartialEq> PartialEq for Schema<D> {
520 fn eq(&self, other: &Self) -> bool {
521 self.fields.len() == other.fields.len()
522 && self
523 .fields
524 .iter()
525 .zip(other.fields.iter())
526 .all(|(a, b)| a == b)
527 }
528}
529
530impl<D> From<PlIndexMap<PlSmallStr, D>> for Schema<D> {
531 fn from(fields: PlIndexMap<PlSmallStr, D>) -> Self {
532 Self { fields }
533 }
534}
535
536impl<F, D> FromIterator<F> for Schema<D>
537where
538 F: Into<(PlSmallStr, D)>,
539{
540 fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
541 let fields = PlIndexMap::from_iter(iter.into_iter().map(|x| x.into()));
542 Self { fields }
543 }
544}
545
546impl<F, D> Extend<F> for Schema<D>
547where
548 F: Into<(PlSmallStr, D)>,
549{
550 fn extend<T: IntoIterator<Item = F>>(&mut self, iter: T) {
551 self.fields.extend(iter.into_iter().map(|x| x.into()))
552 }
553}
554
555impl<D> IntoIterator for Schema<D> {
556 type IntoIter = <PlIndexMap<PlSmallStr, D> as IntoIterator>::IntoIter;
557 type Item = (PlSmallStr, D);
558
559 fn into_iter(self) -> Self::IntoIter {
560 self.fields.into_iter()
561 }
562}