1use std::collections::BTreeSet;
2use std::collections::HashMap;
3use std::collections::HashSet;
4use std::convert::TryInto;
5
6use indexmap::IndexSet;
7use rv::data::CategoricalSuffStat;
8use rv::dist::Categorical;
9use rv::dist::SymmetricDirichlet;
10use serde::Deserialize;
11use serde::Serialize;
12
13use super::error::InsertDataError;
14use crate::cc::feature::ColModel;
15use crate::cc::feature::Column;
16use crate::cc::feature::FType;
17use crate::codebook::Codebook;
18use crate::codebook::ColMetadataList;
19use crate::codebook::ColType;
20use crate::codebook::ValueMap;
21use crate::codebook::ValueMapExtension;
22use crate::codebook::ValueMapExtensionError;
23use crate::data::Category;
24use crate::data::Datum;
25use crate::data::SparseContainer;
26use crate::interface::HasCodebook;
27use crate::ColumnIndex;
28use crate::Engine;
29use crate::HasStates;
30use crate::OracleT;
31use crate::RowIndex;
32
/// Policy for writing into cells that already exist (existing row and
/// existing column).
///
/// Enforced by `InsertDataTasks::validate_insert_mode`. Variant names are
/// (de)serialized in snake_case.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OverwriteMode {
    /// Existing cells may be written to whether they are missing or present.
    Allow,
    /// Existing cells may not be written to at all, not even missing ones.
    Deny,
    /// Only existing cells whose current datum is missing may be written to.
    MissingOnly,
}
46
/// Policy for whether an insert may create new rows and/or new columns.
///
/// Enforced by `InsertDataTasks::validate_insert_mode`. Variant names are
/// (de)serialized in snake_case.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InsertMode {
    /// New rows and new columns are both permitted.
    Unrestricted,
    /// The insert is rejected if it would create any new row.
    DenyNewRows,
    /// The insert is rejected if it would create any new column.
    DenyNewColumns,
    /// The insert is rejected if it would create any new row or new column.
    DenyNewRowsAndColumns,
}
60
/// Strategy selected through `WriteMode::append_strategy`; defaults to
/// `None`.
///
/// NOTE(review): nothing in this part of the file consumes this type, so the
/// per-variant notes below are inferred from names only — confirm against the
/// code that applies the strategy.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum AppendStrategy {
    /// The default: no special append behavior.
    #[default]
    None,
    /// Presumably a sliding window over rows — TODO confirm.
    Window,
    /// Presumably a bounded table with a designated removal position —
    /// TODO confirm.
    Trench {
        /// Presumably the maximum number of rows allowed — TODO confirm.
        max_n_rows: usize,
        /// Presumably the row index at which the "trench" sits — TODO confirm.
        trench_ix: usize,
    },
}
80
/// Bundle of policies governing what an insert is allowed to do.
///
/// `allow_extend_support` and `append_strategy` fall back to their `Default`
/// values (`false` and `AppendStrategy::None`) when absent from serialized
/// input.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct WriteMode {
    /// Whether new rows and/or columns may be created.
    pub insert: InsertMode,
    /// Whether cells of existing rows/columns may be written to.
    pub overwrite: OverwriteMode,
    /// When `false`, `maybe_add_categories` rejects inserts that would add
    /// new categories to an existing categorical column.
    #[serde(default)]
    pub allow_extend_support: bool,
    /// Append strategy. NOTE(review): not consumed in this part of the file.
    #[serde(default)]
    pub append_strategy: AppendStrategy,
}
122
123impl WriteMode {
124 #[inline]
127 pub fn new() -> Self {
128 Self {
129 insert: InsertMode::Unrestricted,
130 overwrite: OverwriteMode::Deny,
131 allow_extend_support: false,
132 append_strategy: AppendStrategy::None,
133 }
134 }
135
136 #[inline]
137 pub fn unrestricted() -> Self {
138 Self {
139 insert: InsertMode::Unrestricted,
140 overwrite: OverwriteMode::Allow,
141 allow_extend_support: true,
142 append_strategy: AppendStrategy::None,
143 }
144 }
145}
146
147impl Default for WriteMode {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
/// A single datum addressed by a column index of type `C`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Value<C: ColumnIndex> {
    /// Index (e.g. name or integer) of the column the datum belongs to.
    pub col_ix: C,
    /// The datum to write.
    pub value: Datum,
}
161
162impl<C: ColumnIndex> From<(C, Datum)> for Value<C> {
163 fn from(value: (C, Datum)) -> Self {
164 Self {
165 col_ix: value.0,
166 value: value.1,
167 }
168 }
169}
170
/// A set of values that all belong to the same row.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Row<R: RowIndex, C: ColumnIndex> {
    /// Index (e.g. name or integer) of the row.
    pub row_ix: R,
    /// The cells to write in this row.
    pub values: Vec<Value<C>>,
}
219
220impl<R, C> From<(R, Vec<(C, Datum)>)> for Row<R, C>
221where
222 R: RowIndex,
223 C: ColumnIndex,
224{
225 fn from(mut row: (R, Vec<(C, Datum)>)) -> Self {
226 Self {
227 row_ix: row.0,
228 values: row.1.drain(..).map(Value::from).collect(),
229 }
230 }
231}
232
233impl<R: RowIndex, C: ColumnIndex> From<(R, Vec<Value<C>>)> for Row<R, C> {
234 fn from(row: (R, Vec<Value<C>>)) -> Self {
235 Self {
236 row_ix: row.0,
237 values: row.1,
238 }
239 }
240}
241
242impl<R: RowIndex, C: ColumnIndex> Row<R, C> {
243 #[inline]
245 pub fn len(&self) -> usize {
246 self.values.len()
247 }
248
249 #[inline]
251 pub fn is_empty(&self) -> bool {
252 self.values.is_empty()
253 }
254}
255
/// A value whose column has been resolved to an integer index.
///
/// Built by `validate_row_values`; for a column that does not exist yet the
/// index is the column's position in the supplied metadata plus the engine's
/// current number of columns.
#[derive(Debug, PartialEq)]
pub(crate) struct IndexValue {
    /// Integer column index.
    pub col_ix: usize,
    /// The datum to write.
    pub value: Datum,
}
262
/// A row whose row and column indices have been resolved to integers.
///
/// Built by `validate_row_values`.
#[derive(Debug, PartialEq)]
pub(crate) struct IndexRow {
    /// Integer row index.
    pub row_ix: usize,
    /// The values of the row, with integer column indices.
    pub values: Vec<IndexValue>,
}
268
/// Record of a column whose support was extended during an insert.
///
/// Produced by `maybe_add_categories`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SupportExtension {
    /// New categories were added to a categorical column.
    Categorical {
        /// Integer index of the extended column.
        col_ix: usize,
        /// Name of the extended column, as recorded in the codebook.
        col_name: String,
        /// The categories that were added to the column's value map.
        value_map_extension: ValueMapExtension,
    },
}
281
/// Collection of the new rows, new columns and support extensions associated
/// with an insert.
///
/// NOTE(review): presumably returned to the caller of the insert as a summary
/// of what was changed — confirm against the code that populates it.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct InsertDataActions {
    /// Names of new rows; `IndexSet` keeps them in insertion order.
    pub(crate) new_rows: IndexSet<String>,
    /// Names of new columns.
    pub(crate) new_cols: HashSet<String>,
    /// Columns whose support was extended.
    pub(crate) support_extensions: Vec<SupportExtension>,
}
290
291impl Default for InsertDataActions {
292 fn default() -> Self {
293 Self::new()
294 }
295}
296
297impl InsertDataActions {
298 pub fn new() -> Self {
299 Self {
300 new_rows: IndexSet::new(),
301 new_cols: HashSet::new(),
302 support_extensions: Vec::new(),
303 }
304 }
305
306 pub fn new_rows(&self) -> Option<&IndexSet<String>> {
308 if self.new_rows.is_empty() {
309 None
310 } else {
311 Some(&self.new_rows)
312 }
313 }
314
315 pub fn new_cols(&self) -> Option<&HashSet<String>> {
317 if self.new_cols.is_empty() {
318 None
319 } else {
320 Some(&self.new_cols)
321 }
322 }
323
324 pub fn support_extensions(&self) -> Option<&Vec<SupportExtension>> {
327 if self.support_extensions.is_empty() {
328 None
329 } else {
330 Some(&self.support_extensions)
331 }
332 }
333}
334
/// What an insert would have to do, as determined by `insert_data_tasks`
/// before any data is written.
#[derive(Debug)]
pub(crate) struct InsertDataTasks {
    /// Names of rows that do not exist yet, in the order they were first
    /// encountered.
    pub new_rows: IndexSet<String>,
    /// Names of columns that do not exist yet.
    pub new_cols: HashSet<String>,
    /// `true` if the insert writes to an existing cell that is currently
    /// missing.
    pub overwrite_missing: bool,
    /// `true` if the insert writes to an existing cell that currently holds
    /// a value.
    pub overwrite_present: bool,
}
349
350impl InsertDataTasks {
351 fn new() -> Self {
352 Self {
353 new_rows: IndexSet::new(),
354 new_cols: HashSet::new(),
355 overwrite_missing: false,
356 overwrite_present: false,
357 }
358 }
359
360 pub(crate) fn validate_insert_mode(
361 &self,
362 mode: WriteMode,
363 ) -> Result<(), InsertDataError> {
364 match mode.overwrite {
365 OverwriteMode::Deny => {
366 if self.overwrite_present || self.overwrite_missing {
367 Err(InsertDataError::ModeForbidsOverwrite)
368 } else {
369 Ok(())
370 }
371 }
372 OverwriteMode::MissingOnly => {
373 if self.overwrite_present {
374 Err(InsertDataError::ModeForbidsOverwrite)
375 } else {
376 Ok(())
377 }
378 }
379 OverwriteMode::Allow => Ok(()),
380 }
381 .and_then(|_| match mode.insert {
382 InsertMode::DenyNewRows => {
383 if !self.new_rows.is_empty() {
384 Err(InsertDataError::ModeForbidsNewRows)
385 } else {
386 Ok(())
387 }
388 }
389 InsertMode::DenyNewColumns => {
390 if !self.new_cols.is_empty() {
391 Err(InsertDataError::ModeForbidsNewColumns)
392 } else {
393 Ok(())
394 }
395 }
396 InsertMode::DenyNewRowsAndColumns => {
397 if !(self.new_rows.is_empty() && self.new_cols.is_empty()) {
398 Err(InsertDataError::ModeForbidsNewRowsOrColumns)
399 } else {
400 Ok(())
401 }
402 }
403 _ => Ok(()),
404 })
405 }
406}
407
408#[inline]
409fn ix_lookup_from_codebook(
410 col_metadata: &Option<ColMetadataList>,
411) -> Option<HashMap<&str, usize>> {
412 col_metadata.as_ref().map(|colmds| {
413 colmds
414 .iter()
415 .enumerate()
416 .map(|(ix, md)| (md.name.as_str(), ix))
417 .collect()
418 })
419}
420
421#[inline]
422fn col_ix_from_lookup(
423 col: &str,
424 lookup: &Option<HashMap<&str, usize>>,
425) -> Result<usize, InsertDataError> {
426 match lookup {
427 Some(lkp) => lkp
428 .get(col)
429 .ok_or_else(|| {
430 InsertDataError::NewColumnNotInColumnMetadata(col.to_owned())
431 })
432 .copied(),
433 None => Err(InsertDataError::NewColumnNotInColumnMetadata(
434 String::from(col),
435 )),
436 }
437}
438
439pub(crate) fn append_empty_columns(
442 tasks: &InsertDataTasks,
443 col_metadata: Option<ColMetadataList>,
444 engine: &mut Engine,
445) -> Result<(), InsertDataError> {
446 match col_metadata {
447 Some(colmds) if !tasks.new_cols.is_empty() => {
449 tasks.new_cols.iter().try_for_each(|col| {
452 if colmds.contains_key(col) {
453 Ok(())
454 } else {
455 Err(InsertDataError::NewColumnNotInColumnMetadata(
456 col.clone(),
457 ))
458 }
459 })?;
460
461 if colmds.len() != tasks.new_cols.len() {
462 Err(InsertDataError::WrongNumberOfColumnMetadataEntries {
465 ncolmd: colmds.len(),
466 nnew: tasks.new_cols.len(),
467 })
468 } else {
469 let shape = (engine.n_rows(), engine.n_cols());
472 create_new_columns(&colmds, shape, &mut engine.rng).map(
473 |col_models| {
474 let mut rng = &mut engine.rng;
478 engine.states.iter_mut().for_each(|state| {
479 state.append_blank_features(
480 col_models.clone(),
481 &mut rng,
482 );
483 });
484
485 engine.codebook.append_col_metadata(colmds).unwrap();
489 },
490 )
491 }
492 }
493 None if !tasks.new_cols.is_empty() => {
495 Err(InsertDataError::WrongNumberOfColumnMetadataEntries {
496 ncolmd: 0,
497 nnew: tasks.new_cols.len(),
498 })
499 }
500 _ => Ok(()),
502 }
503}
504
505fn validate_new_col_ftype(
506 new_metadata: &Option<ColMetadataList>,
507 value: &Value<&str>,
508) -> Result<(), InsertDataError> {
509 let col_ftype = new_metadata
510 .as_ref()
511 .ok_or_else(|| {
512 InsertDataError::NewColumnNotInColumnMetadata(value.col_ix.into())
513 })?
514 .get(value.col_ix)
515 .ok_or_else(|| {
516 InsertDataError::NewColumnNotInColumnMetadata(value.col_ix.into())
517 })
518 .map(|(_, md)| FType::from_coltype(&md.coltype))?;
519
520 let (is_compat, compat_info) = col_ftype.datum_compatible(&value.value);
521
522 let bad_continuous_value = match value.value {
523 Datum::Continuous(ref x) => !x.is_finite(),
524 _ => false,
525 };
526
527 if is_compat {
528 if bad_continuous_value {
529 Err(InsertDataError::NonFiniteContinuousValue {
530 col: value.col_ix.to_owned(),
531 value: value.value.to_f64_opt().unwrap(),
532 })
533 } else {
534 Ok(())
535 }
536 } else {
537 Err(InsertDataError::DatumIncompatibleWithColumn {
538 col: value.col_ix.to_owned(),
539 ftype: compat_info.ftype,
540 ftype_req: compat_info.ftype_req,
541 })
542 }
543}
544
545fn validate_row_values<R: RowIndex, C: ColumnIndex>(
546 row: &Row<R, C>,
547 row_ix: usize,
548 row_exists: bool,
549 col_metadata: &Option<ColMetadataList>,
550 col_ix_lookup: &Option<HashMap<&str, usize>>,
551 insert_tasks: &mut InsertDataTasks,
552 engine: &Engine,
553) -> Result<IndexRow, InsertDataError> {
554 let n_cols = engine.n_cols();
555
556 let mut index_row = IndexRow {
557 row_ix,
558 values: vec![],
559 };
560
561 row.values.iter().try_for_each(|value| {
562 match value.col_ix.col_ix(engine.codebook()) {
563 Ok(col_ix) => {
564 if row_exists {
566 if engine.datum(row_ix, col_ix).unwrap().is_missing() {
567 insert_tasks.overwrite_missing = true;
568 } else {
569 insert_tasks.overwrite_present = true;
570 }
571 }
572
573 let ftype_compat = engine
576 .ftype(col_ix)
577 .unwrap()
578 .datum_compatible(&value.value);
579
580 let bad_continuous_value = match value.value {
581 Datum::Continuous(ref x) => !x.is_finite(),
582 _ => false,
583 };
584
585 if ftype_compat.0 {
586 if bad_continuous_value {
587 let col = &engine.codebook.col_metadata[col_ix].name;
588 Err(InsertDataError::NonFiniteContinuousValue {
589 col: col.clone(),
590 value: value.value.to_f64_opt().unwrap(),
591 })
592 } else {
593 Ok(col_ix)
594 }
595 } else {
596 let col = &engine.codebook.col_metadata[col_ix].name;
597 Err(InsertDataError::DatumIncompatibleWithColumn {
598 col: col.clone(),
599 ftype_req: ftype_compat.1.ftype_req,
600 ftype: ftype_compat.1.ftype,
601 })
602 }
603 }
604 Err(_) => {
605 value
606 .col_ix
607 .col_str()
608 .ok_or_else(|| {
609 InsertDataError::IntegerIndexNewColumn(
610 value
611 .col_ix
612 .col_usize()
613 .expect("Column index does not have a string or usize representation")
614 )
615 })
616 .and_then(|name| {
617 let new_val = Value {
619 col_ix: name,
620 value: value.value.clone(),
621 };
622 validate_new_col_ftype(col_metadata, &new_val).and_then(
623 |_| {
624 insert_tasks.new_cols.insert(name.to_owned());
625 col_ix_from_lookup(name, col_ix_lookup)
626 .map(|ix| ix + n_cols)
627 },
628 )
629 })
630 }
631 }
632 .map(|col_ix| {
633 index_row.values.push(IndexValue {
634 col_ix,
635 value: value.value.clone(),
636 });
637 })
638 })?;
639 Ok(index_row)
640}
641
642pub(crate) fn insert_data_tasks<R: RowIndex, C: ColumnIndex>(
644 rows: &[Row<R, C>],
645 col_metadata: &Option<ColMetadataList>,
646 engine: &Engine,
647) -> Result<(InsertDataTasks, Vec<IndexRow>), InsertDataError> {
648 const EXISTING_ROW: bool = true;
649 const NEW_ROW: bool = false;
650
651 let col_ix_lookup = ix_lookup_from_codebook(col_metadata);
653
654 let n_rows = engine.n_rows();
657
658 let mut tasks = InsertDataTasks::new();
659
660 let index_rows: Vec<IndexRow> = rows
661 .iter()
662 .map(|row| match row.row_ix.row_ix(engine.codebook()) {
663 Ok(row_ix) => {
664 if row.is_empty() {
665 let name = engine.codebook.row_names.name(row_ix).unwrap();
666 Err(InsertDataError::EmptyRow(name.clone()))
667 } else {
668 validate_row_values(
669 row,
670 row_ix,
671 EXISTING_ROW,
672 col_metadata,
673 &col_ix_lookup,
674 &mut tasks,
675 engine,
676 )
677 }
678 }
679 Err(_) => {
680 if row.is_empty() {
682 Err(InsertDataError::EmptyRow(format!("{:?}", row.row_ix)))
683 } else {
684 validate_row_values(
685 row,
686 {
687 let n = tasks.new_rows.len();
688 row.row_ix
689 .row_str()
690 .ok_or_else(|| {
691 let ix = row
692 .row_ix
693 .row_usize()
694 .expect("Index doesn't have a string or usize representation");
695 InsertDataError::IntegerIndexNewRow(ix)
696 })
697 .map(|row_name| {
698 tasks
699 .new_rows
700 .insert(String::from(row_name));
701 })?;
702 n_rows + n
703 },
704 NEW_ROW,
705 col_metadata,
706 &col_ix_lookup,
707 &mut tasks,
708 engine,
709 )
710 }
711 }
712 })
713 .collect::<Result<Vec<IndexRow>, InsertDataError>>()?;
714 Ok((tasks, index_rows))
715}
716
717pub(crate) fn maybe_add_categories<R: RowIndex, C: ColumnIndex>(
718 rows: &[Row<R, C>],
719 engine: &mut Engine,
720 mode: WriteMode,
721) -> Result<Vec<SupportExtension>, InsertDataError> {
722 let mut extended_value_map: HashMap<usize, ValueMapExtension> =
723 HashMap::new();
724
725 rows.iter().try_for_each(|row| {
729 row.values.iter().try_for_each(|new_value| {
730 match new_value.col_ix.col_ix(engine.codebook()) {
733 Err(_) => Ok(()), Ok(col_ix) => {
735 let col_metadata = &engine.codebook.col_metadata[col_ix];
736 let col_name = col_metadata.name.as_str();
737
738 match &col_metadata.coltype {
739 ColType::Categorical { value_map, .. } => {
740 match (&new_value.value, value_map) {
741 (Datum::Categorical(Category::String(x)), ValueMap::String(vm)) => {
742 if !vm.contains_cat(x) {
743 let ext_vm = extended_value_map.entry(col_ix).or_insert_with(ValueMapExtension::new_string);
744 ext_vm.extend(Category::String(x.clone())).map_err(
745 |e| match e {
746 ValueMapExtensionError::ExtensionOfDifferingType(a, b) => {
747 InsertDataError::WrongCategoryAndType(a, b, col_name.to_string())
748 }
749 })?;
750 };
751 Ok(())
752 },
753 (Datum::Categorical(Category::UInt(x)), ValueMap::UInt(old_max)) => {
754 if *old_max as u32 <= *x {
755 let ext_max = extended_value_map.entry(col_ix).or_insert_with(ValueMapExtension::new_uint);
756 ext_max.extend(Category::UInt(*x)).map_err(
757 |e| match e {
758 ValueMapExtensionError::ExtensionOfDifferingType(a, b) => {
759 InsertDataError::WrongCategoryAndType(a, b, col_name.to_string())
760 }
761 })?;
762 }
763 Ok(())
764 },
765 (Datum::Missing, _) |
766 (Datum::Categorical(Category::Bool(_)), ValueMap::Bool) => {
767 Ok(())
768 },
769 _ => {
770 Err(
771 InsertDataError::DatumIncompatibleWithColumn {
772 col: (*col_name).into(),
773 ftype_req: FType::Categorical,
774 ftype: (&new_value.value).try_into().unwrap(),
778 },
779 )}
780 }
781 }
782 _ => Ok(())
783 }
784 }
785 }
786 })
787 })?;
788
789 if !mode.allow_extend_support && !extended_value_map.is_empty() {
790 return Err(InsertDataError::ModeForbidsCategoryExtension);
792 }
793
794 let mut cols_extended: Vec<SupportExtension> = Vec::new();
795
796 for (col_ix, value_map_extension) in extended_value_map.drain() {
802 incr_column_categories(engine, col_ix, &value_map_extension)?;
803 cols_extended.push(SupportExtension::Categorical {
804 col_ix,
805 col_name: engine.codebook.col_metadata[col_ix].name.clone(),
806 value_map_extension,
807 })
808 }
809
810 Ok(cols_extended)
811}
812
813fn incr_category_in_codebook(
814 codebook: &mut Codebook,
815 col_ix: usize,
816 value_map_extension: &ValueMapExtension,
817) -> Result<(), InsertDataError> {
818 let col_name = codebook.col_metadata[col_ix].name.clone();
819 match codebook.col_metadata[col_ix].coltype {
820 ColType::Categorical {
821 ref mut k,
822 ref mut value_map,
823 ..
824 } => {
825 value_map.extend(value_map_extension.clone()).map_err(
827 |e| match e {
828 ValueMapExtensionError::ExtensionOfDifferingType(a, b) => {
829 InsertDataError::WrongCategoryAndType(a, b, col_name)
830 }
831 },
832 )?;
833 *k = value_map.len();
834 Ok(())
835 }
836 _ => panic!("Tried to change cardinality of non-categorical column"),
837 }
838}
839
840fn incr_column_categories(
841 engine: &mut Engine,
842 col_ix: usize,
843 extended_value_map: &ValueMapExtension,
844) -> Result<(), InsertDataError> {
845 incr_category_in_codebook(
847 &mut engine.codebook,
848 col_ix,
849 extended_value_map,
850 )?;
851
852 let n_cats_req = match engine.codebook.col_metadata[col_ix].coltype {
853 ColType::Categorical { k, .. } => k,
854 _ => panic!("Requested non-categorical column"),
855 };
856
857 engine.states.iter_mut().for_each(|state| {
859 match state.feature_mut(col_ix) {
860 ColModel::Categorical(column) => {
861 column.prior = SymmetricDirichlet::new_unchecked(
862 column.prior.alpha(),
863 n_cats_req,
864 );
865 column.components.iter_mut().for_each(|cpnt| {
866 cpnt.stat = CategoricalSuffStat::from_parts_unchecked(
867 cpnt.stat.n(),
868 {
869 let mut counts = cpnt.stat.counts().clone();
870 counts.resize(n_cats_req, 0.0);
871 counts
872 },
873 );
874
875 cpnt.fx = Categorical::new_unchecked({
876 let mut ln_weights = cpnt.fx.ln_weights().clone();
877 ln_weights.resize(n_cats_req, f64::NEG_INFINITY);
878 ln_weights
879 });
880 })
881 }
882 _ => panic!("Requested non-categorical column"),
883 }
884 });
885 Ok(())
886}
887
/// Builds an all-missing `ColModel::$coltype` column for a new column.
///
/// Arguments, in order:
/// - `$coltype`: the `ColModel` variant to construct.
/// - `$htype`: the hyper-prior type; its `Default` is used when only a prior
///   is given.
/// - `$errvar`: the `InsertDataError` variant returned when neither a hyper
///   nor a prior is given.
/// - `$colmd`: the column metadata (its name goes into the error).
/// - `$hyper`, `$prior`: the optional hyper-prior and prior from the
///   metadata.
/// - `$n_rows`: number of (missing) rows the column holds.
/// - `$id`: the column id.
/// - `$xtype`: the element type of the column's data container.
/// - `$rng`: the random number generator used to draw a prior from the
///   hyper.
///
/// With a hyper, the given prior is used if present, otherwise one is drawn
/// from the hyper. With only a prior, the column is built with a default
/// hyper and `ignore_hyper` set.
macro_rules! new_col_arm {
    (
        $coltype: ident,
        $htype: ty,
        $errvar: ident,
        $colmd: ident,
        $hyper: ident,
        $prior: ident,
        $n_rows: ident,
        $id: ident,
        $xtype: ty,
        $rng: ident
    ) => {{
        let data: SparseContainer<$xtype> =
            SparseContainer::all_missing($n_rows);

        match ($hyper, $prior) {
            (Some(h), _) => {
                let pr = if let Some(pr) = $prior {
                    pr.clone()
                } else {
                    h.draw(&mut $rng)
                };
                let column = Column::new($id, data, pr, h.clone());
                Ok(ColModel::$coltype(column))
            }
            (None, Some(pr)) => {
                let mut column =
                    Column::new($id, data, pr.clone(), <$htype>::default());
                column.ignore_hyper = true;
                Ok(ColModel::$coltype(column))
            }
            // Report the error variant chosen by the caller, so that e.g. a
            // count column does not raise the Gaussian error.
            (None, None) => {
                Err(InsertDataError::$errvar($colmd.name.clone()))
            }
        }
    }};
}
927
928pub(crate) fn create_new_columns<R: rand::Rng>(
929 col_metadata: &ColMetadataList,
930 state_shape: (usize, usize),
931 mut rng: &mut R,
932) -> Result<Vec<ColModel>, InsertDataError> {
933 let (n_rows, n_cols) = state_shape;
934 col_metadata
935 .iter()
936 .enumerate()
937 .map(|(i, colmd)| {
938 let id = i + n_cols;
939 match &colmd.coltype {
940 ColType::Continuous { hyper, prior } => new_col_arm!(
941 Continuous,
942 crate::stats::prior::nix::NixHyper,
943 NoGaussianHyperForNewColumn,
944 colmd,
945 hyper,
946 prior,
947 n_rows,
948 id,
949 f64,
950 rng
951 ),
952 ColType::Count { hyper, prior } => new_col_arm!(
953 Count,
954 crate::stats::prior::pg::PgHyper,
955 NoPoissonHyperForNewColumn,
956 colmd,
957 hyper,
958 prior,
959 n_rows,
960 id,
961 u32,
962 rng
963 ),
964 ColType::Categorical {
965 k, hyper, prior, ..
966 } => {
967 let data: SparseContainer<u32> =
968 SparseContainer::all_missing(n_rows);
969
970 let id = i + n_cols;
971 match (hyper, prior) {
972 (Some(h), _) => {
973 let pr = if let Some(pr) = prior {
974 pr.clone()
975 } else {
976 h.draw(*k, &mut rng)
977 };
978 let column = Column::new(id, data, pr, h.clone());
979 Ok(ColModel::Categorical(column))
980 }
981 (None, Some(pr)) => {
982 use crate::stats::prior::csd::CsdHyper;
983 let mut column = Column::new(
984 id,
985 data,
986 pr.clone(),
987 CsdHyper::default(),
988 );
989 column.ignore_hyper = true;
990 Ok(ColModel::Categorical(column))
991 }
992 (None, None) => Err(
993 InsertDataError::NoCategoricalHyperForNewColumn(
994 colmd.name.clone(),
995 ),
996 ),
997 }
998 }
999 }
1000 })
1001 .collect()
1002}
1003
1004pub(crate) fn remove_cell(engine: &mut Engine, row_ix: usize, col_ix: usize) {
1005 engine.states.iter_mut().for_each(|state| {
1006 state.remove_datum(row_ix, col_ix);
1007 })
1008}
1009
1010pub(crate) fn remove_col(engine: &mut Engine, col_ix: usize) {
1011 engine.codebook.col_metadata.remove_by_index(col_ix);
1013 let mut rng = engine.rng.clone();
1014 engine.states.iter_mut().for_each(|state| {
1015 state.del_col(col_ix, &mut rng);
1017 });
1018}
1019
1020pub(crate) fn check_if_removes_col(
1021 engine: &Engine,
1022 rm_rows: &BTreeSet<usize>,
1023 mut rm_cell_cols: HashMap<usize, i64>,
1024) -> BTreeSet<usize> {
1025 let mut to_rm: BTreeSet<usize> = BTreeSet::new();
1026 rm_cell_cols.iter_mut().for_each(|(col_ix, val)| {
1028 let mut present_count = 0_i64;
1029 let mut remove = true;
1030 for row_ix in 0..engine.n_rows() {
1031 if present_count > *val {
1032 remove = false;
1033 break;
1034 }
1035 if !rm_rows.contains(&row_ix)
1036 && !engine.datum(row_ix, *col_ix).unwrap().is_missing()
1037 {
1038 present_count += 1;
1039 }
1040 }
1041 if remove {
1042 to_rm.insert(*col_ix);
1043 }
1044 });
1045 to_rm
1046}
1047
1048pub(crate) fn check_if_removes_row(
1049 engine: &Engine,
1050 rm_cols: &BTreeSet<usize>,
1051 mut rm_cell_rows: HashMap<usize, i64>,
1052) -> BTreeSet<usize> {
1053 let mut to_rm: BTreeSet<usize> = BTreeSet::new();
1054 rm_cell_rows.iter_mut().for_each(|(row_ix, val)| {
1055 let mut present_count = 0_i64;
1056 let mut remove = true;
1057 for col_ix in 0..engine.n_cols() {
1058 if present_count > *val {
1059 remove = false;
1060 break;
1061 }
1062 if !rm_cols.contains(&col_ix)
1063 && !engine.datum(*row_ix, col_ix).unwrap().is_missing()
1064 {
1065 present_count += 1;
1066 }
1067 }
1068 if remove {
1069 to_rm.insert(*row_ix);
1070 }
1071 });
1072 to_rm
1073}
1074
1075#[cfg(test)]
1076mod tests {
1077 use rand::SeedableRng;
1078
1079 use super::*;
1080 use crate::codebook::ColMetadata;
1081 use crate::codebook::ColType;
1082 use crate::codebook::ValueMap;
1083 use crate::data::data_source;
1084 use crate::stats::prior::csd::CsdHyper;
1085
1086 #[cfg(feature = "examples")]
1087 mod requiring_examples {
1088 use super::*;
1089 use crate::examples::Example;
1090
1091 #[test]
1092 fn errors_when_no_col_metadata_when_new_columns() {
1093 let engine = Example::Animals.engine().unwrap();
1094 let moose_updates = Row::<String, String> {
1095 row_ix: "moose".into(),
1096 values: vec![
1097 Value {
1098 col_ix: "does+taxes".into(),
1099 value: Datum::Categorical(1_u32.into()),
1100 },
1101 Value {
1102 col_ix: "flys".into(),
1103 value: Datum::Categorical(1_u32.into()),
1104 },
1105 ],
1106 };
1107
1108 let result = insert_data_tasks(&[moose_updates], &None, &engine);
1109
1110 assert!(result.is_err());
1111 match result {
1112 Err(InsertDataError::NewColumnNotInColumnMetadata(s)) => {
1113 assert_eq!(s, String::from("does+taxes"))
1114 }
1115 Err(err) => panic!("wrong error: {:?}", err),
1116 Ok(_) => panic!("failed to fail"),
1117 }
1118 }
1119
1120 #[test]
1121 fn errors_when_new_column_not_in_col_metadata() {
1122 let engine = Example::Animals.engine().unwrap();
1123 let moose_updates = Row::<String, String> {
1124 row_ix: "moose".into(),
1125 values: vec![
1126 Value {
1127 col_ix: "does+taxes".into(),
1128 value: Datum::Categorical(1_u32.into()),
1129 },
1130 Value {
1131 col_ix: "flys".into(),
1132 value: Datum::Categorical(1_u32.into()),
1133 },
1134 ],
1135 };
1136
1137 let col_metadata = ColMetadataList::new(vec![ColMetadata {
1138 name: "dances".into(),
1139 coltype: ColType::Categorical {
1140 k: 2,
1141 hyper: None,
1142 prior: None,
1143 value_map: ValueMap::UInt(2),
1144 },
1145 notes: None,
1146 missing_not_at_random: false,
1147 }])
1148 .unwrap();
1149
1150 let result = insert_data_tasks(
1151 &[moose_updates],
1152 &Some(col_metadata),
1153 &engine,
1154 );
1155
1156 assert!(result.is_err());
1157 assert_eq!(
1158 result.unwrap_err(),
1159 InsertDataError::NewColumnNotInColumnMetadata(
1160 "does+taxes".into()
1161 )
1162 );
1163 }
1164
1165 #[test]
1166 fn tasks_on_one_existing_row() {
1167 let engine = Example::Animals.engine().unwrap();
1168 let moose_updates = Row::<String, String> {
1169 row_ix: "moose".into(),
1170 values: vec![
1171 Value {
1172 col_ix: "swims".into(),
1173 value: Datum::Categorical(1_u32.into()),
1174 },
1175 Value {
1176 col_ix: "flys".into(),
1177 value: Datum::Categorical(1_u32.into()),
1178 },
1179 ],
1180 };
1181 let rows = vec![moose_updates];
1182 let (tasks, ixrows) =
1183 insert_data_tasks(&rows, &None, &engine).unwrap();
1184
1185 assert!(tasks.new_rows.is_empty());
1186 assert!(tasks.new_cols.is_empty());
1187 assert!(!tasks.overwrite_missing);
1188 assert!(tasks.overwrite_present);
1189
1190 assert_eq!(
1191 ixrows,
1192 vec![IndexRow {
1193 row_ix: 15,
1194 values: vec![
1195 IndexValue {
1196 col_ix: 36,
1197 value: Datum::Categorical(1_u32.into())
1198 },
1199 IndexValue {
1200 col_ix: 34,
1201 value: Datum::Categorical(1_u32.into())
1202 },
1203 ]
1204 }]
1205 );
1206 }
1207
1208 #[test]
1209 fn tasks_on_one_new_row() {
1210 let engine = Example::Animals.engine().unwrap();
1211 let pegasus = Row::<String, String> {
1212 row_ix: "pegasus".into(),
1213 values: vec![
1214 Value {
1215 col_ix: "swims".into(),
1216 value: Datum::Categorical(1_u32.into()),
1217 },
1218 Value {
1219 col_ix: "flys".into(),
1220 value: Datum::Categorical(1_u32.into()),
1221 },
1222 ],
1223 };
1224 let rows = vec![pegasus];
1225 let (tasks, ixrows) =
1226 insert_data_tasks(&rows, &None, &engine).unwrap();
1227
1228 assert_eq!(tasks.new_rows.len(), 1);
1229 assert!(tasks.new_rows.contains("pegasus"));
1230 assert!(tasks.new_cols.is_empty());
1231 assert!(!tasks.overwrite_missing);
1232 assert!(!tasks.overwrite_present);
1233
1234 assert_eq!(
1235 ixrows,
1236 vec![IndexRow {
1237 row_ix: 50,
1238 values: vec![
1239 IndexValue {
1240 col_ix: 36,
1241 value: Datum::Categorical(1_u32.into())
1242 },
1243 IndexValue {
1244 col_ix: 34,
1245 value: Datum::Categorical(1_u32.into())
1246 },
1247 ]
1248 }]
1249 );
1250 }
1251
1252 #[test]
1253 fn tasks_on_two_new_rows() {
1254 let engine = Example::Animals.engine().unwrap();
1255 let pegasus = Row::<String, String> {
1256 row_ix: "pegasus".into(),
1257 values: vec![
1258 Value {
1259 col_ix: "swims".into(),
1260 value: Datum::Categorical(1_u32.into()),
1261 },
1262 Value {
1263 col_ix: "flys".into(),
1264 value: Datum::Categorical(1_u32.into()),
1265 },
1266 ],
1267 };
1268
1269 let man = Row::<String, String> {
1270 row_ix: "man".into(),
1271 values: vec![
1272 Value {
1273 col_ix: "smart".into(),
1274 value: Datum::Categorical(1_u32.into()),
1275 },
1276 Value {
1277 col_ix: "hunter".into(),
1278 value: Datum::Categorical(0_u32.into()),
1279 },
1280 ],
1281 };
1282 let rows = vec![pegasus, man];
1283 let (tasks, ixrows) =
1284 insert_data_tasks(&rows, &None, &engine).unwrap();
1285
1286 assert_eq!(tasks.new_rows.len(), 2);
1287 assert!(tasks.new_rows.contains("pegasus"));
1288 assert!(tasks.new_rows.contains("man"));
1289
1290 assert!(tasks.new_cols.is_empty());
1291 assert!(!tasks.overwrite_missing);
1292 assert!(!tasks.overwrite_present);
1293
1294 assert_eq!(
1295 ixrows,
1296 vec![
1297 IndexRow {
1298 row_ix: 50,
1299 values: vec![
1300 IndexValue {
1301 col_ix: 36,
1302 value: Datum::Categorical(1_u32.into())
1303 },
1304 IndexValue {
1305 col_ix: 34,
1306 value: Datum::Categorical(1_u32.into())
1307 },
1308 ]
1309 },
1310 IndexRow {
1311 row_ix: 51,
1312 values: vec![
1313 IndexValue {
1314 col_ix: 80,
1315 value: Datum::Categorical(1_u32.into())
1316 },
1317 IndexValue {
1318 col_ix: 58,
1319 value: Datum::Categorical(0_u32.into())
1320 },
1321 ]
1322 }
1323 ]
1324 );
1325 }
1326
1327 #[test]
1328 fn tasks_on_one_new_and_one_existing_row() {
1329 let engine = Example::Animals.engine().unwrap();
1330 let pegasus = Row::<String, String> {
1331 row_ix: "pegasus".into(),
1332 values: vec![
1333 Value {
1334 col_ix: "swims".into(),
1335 value: Datum::Categorical(1_u32.into()),
1336 },
1337 Value {
1338 col_ix: "flys".into(),
1339 value: Datum::Categorical(1_u32.into()),
1340 },
1341 ],
1342 };
1343
1344 let moose = Row::<String, String> {
1345 row_ix: "moose".into(),
1346 values: vec![
1347 Value {
1348 col_ix: "smart".into(),
1349 value: Datum::Categorical(1_u32.into()),
1350 },
1351 Value {
1352 col_ix: "hunter".into(),
1353 value: Datum::Categorical(0_u32.into()),
1354 },
1355 ],
1356 };
1357 let rows = vec![pegasus, moose];
1358 let (tasks, ixrows) =
1359 insert_data_tasks(&rows, &None, &engine).unwrap();
1360
1361 assert_eq!(tasks.new_rows.len(), 1);
1362 assert!(tasks.new_rows.contains("pegasus"));
1363
1364 assert!(tasks.new_cols.is_empty());
1365 assert!(!tasks.overwrite_missing);
1366 assert!(tasks.overwrite_present);
1367
1368 assert_eq!(
1369 ixrows,
1370 vec![
1371 IndexRow {
1372 row_ix: 50,
1373 values: vec![
1374 IndexValue {
1375 col_ix: 36,
1376 value: Datum::Categorical(1_u32.into())
1377 },
1378 IndexValue {
1379 col_ix: 34,
1380 value: Datum::Categorical(1_u32.into())
1381 },
1382 ]
1383 },
1384 IndexRow {
1385 row_ix: 15,
1386 values: vec![
1387 IndexValue {
1388 col_ix: 80,
1389 value: Datum::Categorical(1_u32.into())
1390 },
1391 IndexValue {
1392 col_ix: 58,
1393 value: Datum::Categorical(0_u32.into())
1394 },
1395 ]
1396 }
1397 ]
1398 );
1399 }
1400
1401 #[test]
1402 fn tasks_on_one_new_col_in_existing_row() {
1403 let engine = Example::Animals.engine().unwrap();
1404 let col_metadata = ColMetadataList::new(vec![ColMetadata {
1405 name: "dances".into(),
1406 coltype: ColType::Categorical {
1407 k: 2,
1408 hyper: None,
1409 prior: None,
1410 value_map: ValueMap::UInt(2),
1411 },
1412 notes: None,
1413 missing_not_at_random: false,
1414 }])
1415 .unwrap();
1416 let moose_updates = Row::<String, String> {
1417 row_ix: "moose".into(),
1418 values: vec![
1419 Value {
1420 col_ix: "dances".into(),
1421 value: Datum::Categorical(1_u32.into()),
1422 },
1423 Value {
1424 col_ix: "flys".into(),
1425 value: Datum::Categorical(1_u32.into()),
1426 },
1427 ],
1428 };
1429 let rows = vec![moose_updates];
1430 let (tasks, ixrows) =
1431 insert_data_tasks(&rows, &Some(col_metadata), &engine).unwrap();
1432
1433 assert!(tasks.new_rows.is_empty());
1434 assert_eq!(tasks.new_cols.len(), 1);
1435 assert!(tasks.new_cols.contains("dances"));
1436
1437 assert!(!tasks.overwrite_missing);
1438 assert!(tasks.overwrite_present);
1439
1440 assert_eq!(
1441 ixrows,
1442 vec![IndexRow {
1443 row_ix: 15,
1444 values: vec![
1445 IndexValue {
1446 col_ix: 85,
1447 value: Datum::Categorical(1_u32.into())
1448 },
1449 IndexValue {
1450 col_ix: 34,
1451 value: Datum::Categorical(1_u32.into())
1452 },
1453 ]
1454 }]
1455 );
1456 }
1457
#[test]
fn tasks_on_one_new_col_in_new_row() {
    // Scenario: a single row keyed "peanut" carries values for "dances"
    // (described by the supplemental metadata built below) and for "flys".
    // The assertions expect exactly one new row, exactly one new column,
    // and neither overwrite flag to be set.
    let engine = Example::Animals.engine().unwrap();

    let dances_md = ColMetadata {
        name: "dances".into(),
        coltype: ColType::Categorical {
            k: 2,
            hyper: None,
            prior: None,
            value_map: ValueMap::UInt(2),
        },
        notes: None,
        missing_not_at_random: false,
    };
    let suppl_metadata = Some(ColMetadataList::new(vec![dances_md]).unwrap());

    let rows = vec![Row::<String, String> {
        row_ix: "peanut".into(),
        values: vec![
            Value {
                col_ix: "dances".into(),
                value: Datum::Categorical(1_u32.into()),
            },
            Value {
                col_ix: "flys".into(),
                value: Datum::Categorical(0_u32.into()),
            },
        ],
    }];

    let (tasks, ixrows) =
        insert_data_tasks(&rows, &suppl_metadata, &engine).unwrap();

    // Row bookkeeping: only "peanut" should be reported as new.
    assert_eq!(tasks.new_rows.len(), 1);
    assert!(tasks.new_rows.contains("peanut"));

    // Column bookkeeping: only "dances" should be reported as new.
    assert_eq!(tasks.new_cols.len(), 1);
    assert!(tasks.new_cols.contains("dances"));

    // No overwrite of either kind is expected for this insert.
    assert!(!tasks.overwrite_missing);
    assert!(!tasks.overwrite_present);

    // The names must come back as integer indices, with the values in the
    // same order in which they were supplied.
    let expected = vec![IndexRow {
        row_ix: 50,
        values: vec![
            IndexValue {
                col_ix: 85,
                value: Datum::Categorical(1_u32.into()),
            },
            IndexValue {
                col_ix: 34,
                value: Datum::Categorical(0_u32.into()),
            },
        ],
    }];
    assert_eq!(ixrows, expected);
}
1518
#[test]
fn tasks_on_two_new_cols_in_existing_row() {
    // Scenario: the row keyed "moose" receives values for "flys" plus two
    // columns described by supplemental metadata ("dances", "eats+figs").
    // The assertions expect no new rows, two new columns, and only the
    // `overwrite_present` flag to be set.
    let engine = Example::Animals.engine().unwrap();

    // Both supplemental columns share the same two-category layout, so
    // build their metadata from a single template.
    let binary_categorical = |name: &str| ColMetadata {
        name: name.into(),
        coltype: ColType::Categorical {
            k: 2,
            hyper: None,
            prior: None,
            value_map: ValueMap::UInt(2),
        },
        notes: None,
        missing_not_at_random: false,
    };
    let suppl_metadata = Some(
        ColMetadataList::new(vec![
            binary_categorical("dances"),
            binary_categorical("eats+figs"),
        ])
        .unwrap(),
    );

    let rows = vec![Row::<String, String> {
        row_ix: "moose".into(),
        values: vec![
            Value {
                col_ix: "flys".into(),
                value: Datum::Categorical(1_u32.into()),
            },
            Value {
                col_ix: "eats+figs".into(),
                value: Datum::Categorical(0_u32.into()),
            },
            Value {
                col_ix: "dances".into(),
                value: Datum::Categorical(1_u32.into()),
            },
        ],
    }];

    let (tasks, ixrows) =
        insert_data_tasks(&rows, &suppl_metadata, &engine).unwrap();

    // No row should be reported as new; both supplemental columns should.
    assert!(tasks.new_rows.is_empty());
    assert_eq!(tasks.new_cols.len(), 2);
    assert!(tasks.new_cols.contains("dances"));
    assert!(tasks.new_cols.contains("eats+figs"));

    assert!(!tasks.overwrite_missing);
    assert!(tasks.overwrite_present);

    // Values keep the order in which they were supplied; the expected
    // column indices follow the metadata order (dances = 85, eats+figs = 86).
    let expected = vec![IndexRow {
        row_ix: 15,
        values: vec![
            IndexValue {
                col_ix: 34,
                value: Datum::Categorical(1_u32.into()),
            },
            IndexValue {
                col_ix: 86,
                value: Datum::Categorical(0_u32.into()),
            },
            IndexValue {
                col_ix: 85,
                value: Datum::Categorical(1_u32.into()),
            },
        ],
    }];
    assert_eq!(ixrows, expected);
}
1598 }
1599
1600 fn quick_codebook() -> Codebook {
1601 let coltype = ColType::Categorical {
1602 k: 2,
1603 hyper: None,
1604 prior: None,
1605 value_map: ValueMap::try_from(vec![
1606 String::from("red"),
1607 String::from("green"),
1608 ])
1609 .unwrap(),
1610 };
1611 let md0 = ColMetadata {
1612 name: "0".to_string(),
1613 coltype: coltype.clone(),
1614 notes: None,
1615 missing_not_at_random: false,
1616 };
1617 let md1 = ColMetadata {
1618 name: "1".to_string(),
1619 coltype,
1620 notes: None,
1621 missing_not_at_random: false,
1622 };
1623 let md2 = ColMetadata {
1624 name: "2".to_string(),
1625 coltype: ColType::Categorical {
1626 k: 3,
1627 hyper: None,
1628 prior: None,
1629 value_map: ValueMap::UInt(3),
1630 },
1631 notes: None,
1632 missing_not_at_random: false,
1633 };
1634
1635 let col_metadata = ColMetadataList::new(vec![md0, md1, md2]).unwrap();
1636 Codebook::new("table".to_string(), col_metadata)
1637 }
1638
#[test]
fn incr_cats_in_codebook_without_suppl_metadata_for_no_valmap_col() {
    // Column 2 of the quick codebook uses `ValueMap::UInt(3)`. Extending it
    // with one extra unsigned category is expected to bump both `k` and the
    // value map to 4.
    let mut codebook = quick_codebook();

    let n_cats_before = match &codebook.col_metadata[2].coltype {
        ColType::Categorical {
            k,
            value_map: ValueMap::UInt(3),
            ..
        } => *k,
        ColType::Categorical { value_map, .. } => {
            // Name the variant actually matched above so that the failure
            // message agrees with the `Debug` output printed after it.
            panic!(
                "starting value_map should have been UInt(3), was {:?}",
                value_map
            );
        }
        _ => panic!("should've been categorical"),
    };

    assert_eq!(n_cats_before, 3);

    let mut extension = ValueMapExtension::new_uint();
    extension.extend(Category::UInt(3)).unwrap();

    incr_category_in_codebook(&mut codebook, 2, &extension).unwrap();

    let n_cats_after = match &codebook.col_metadata[2].coltype {
        ColType::Categorical {
            k,
            value_map: ValueMap::UInt(4),
            ..
        } => *k,
        ColType::Categorical { value_map, .. } => {
            panic!("value_map should be UInt(4), was: {:?}", value_map)
        }
        _ => panic!("should've been categorical"),
    };

    assert_eq!(n_cats_after, 4);
}
1680
#[test]
fn incr_cats_in_codebook_with_suppl_metadata_for_valmap_col() {
    // Column 0 of the quick codebook has a two-entry string value map.
    // Extending it with "blue" is expected to grow both `k` and the value
    // map to 3.
    let mut codebook = quick_codebook();

    match &codebook.col_metadata[0].coltype {
        ColType::Categorical {
            k, value_map: vm, ..
        } => {
            assert_eq!(*k, 2);
            assert_eq!(vm.len(), 2);
        }
        _ => panic!("should've been categorical with valmap"),
    };

    let mut extension = ValueMapExtension::new_string();
    extension
        .extend(Category::String("blue".to_string()))
        .unwrap();

    // Unwrap rather than `assert!(.. .is_ok())` so that a failure prints the
    // underlying error instead of a bare "assertion failed".
    incr_category_in_codebook(&mut codebook, 0, &extension)
        .expect("extending the value map of column 0 should succeed");

    match &codebook.col_metadata[0].coltype {
        ColType::Categorical {
            k, value_map: vm, ..
        } => {
            assert_eq!(vm.len(), 3);
            assert_eq!(*k, 3);
        }
        _ => panic!("should've been categorical with valmap"),
    };
}
1714
#[test]
fn append_bool() {
    // Smoke test: inserting boolean categorical data into an engine that
    // starts from an empty codebook and an empty data source must succeed,
    // both when the column is introduced (first insert, with supplemental
    // metadata) and when a further row is added to it (second insert,
    // without metadata). The test passes if neither insert returns an error.
    let bool_col_md = ColMetadata {
        name: "bool_col".to_string(),
        // Built in place and moved: the column type is used exactly once,
        // so no clone is needed.
        coltype: ColType::Categorical {
            k: 2,
            hyper: Some(CsdHyper::default()),
            prior: None,
            value_map: ValueMap::Bool,
        },
        notes: None,
        missing_not_at_random: false,
    };

    let mut engine = Engine::new(
        1,
        Codebook::new(
            "test".to_string(),
            ColMetadataList::new(vec![]).unwrap(),
        ),
        data_source::DataSource::Empty,
        0,
        rand_xoshiro::Xoshiro256Plus::seed_from_u64(0x1234),
    )
    .unwrap();

    // First insert: new row "abc" together with the new column "bool_col".
    engine
        .insert_data(
            vec![(
                "abc",
                vec![(
                    "bool_col",
                    Datum::Categorical(Category::Bool(false)),
                )],
            )
                .into()],
            Some(ColMetadataList::new(vec![bool_col_md]).unwrap()),
            WriteMode::unrestricted(),
        )
        .unwrap();

    // Second insert: new row "def" into the column created above.
    engine
        .insert_data(
            vec![(
                "def",
                vec![(
                    "bool_col",
                    Datum::Categorical(Category::Bool(false)),
                )],
            )
                .into()],
            None,
            WriteMode::unrestricted(),
        )
        .unwrap();
}
1774}