Skip to main content

lace/interface/engine/
data.rs

1use std::collections::BTreeSet;
2use std::collections::HashMap;
3use std::collections::HashSet;
4use std::convert::TryInto;
5
6use indexmap::IndexSet;
7use rv::data::CategoricalSuffStat;
8use rv::dist::Categorical;
9use rv::dist::SymmetricDirichlet;
10use serde::Deserialize;
11use serde::Serialize;
12
13use super::error::InsertDataError;
14use crate::cc::feature::ColModel;
15use crate::cc::feature::Column;
16use crate::cc::feature::FType;
17use crate::codebook::Codebook;
18use crate::codebook::ColMetadataList;
19use crate::codebook::ColType;
20use crate::codebook::ValueMap;
21use crate::codebook::ValueMapExtension;
22use crate::codebook::ValueMapExtensionError;
23use crate::data::Category;
24use crate::data::Datum;
25use crate::data::SparseContainer;
26use crate::interface::HasCodebook;
27use crate::ColumnIndex;
28use crate::Engine;
29use crate::HasStates;
30use crate::OracleT;
31use crate::RowIndex;
32
/// Defines which existing cells may be overwritten when data are inserted
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OverwriteMode {
    /// Overwrite anything, whether the existing cell is present or missing
    Allow,
    /// Do not overwrite any existing cells. Only allow data in new rows or
    /// columns.
    Deny,
    /// Same as deny, but also allow existing cells that are empty (missing)
    /// to be overwritten.
    MissingOnly,
}
46
/// Defines insert data behavior -- where data may be inserted, i.e., whether
/// an insert operation is allowed to create new rows and/or new columns.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InsertMode {
    /// Can add new rows or columns
    Unrestricted,
    /// Cannot add new rows, but can add new columns
    DenyNewRows,
    /// Cannot add new columns, but can add new rows
    DenyNewColumns,
    /// No adding new rows or columns
    DenyNewRowsAndColumns,
}
60
/// Defines the behavior of the data table when new rows are appended
///
/// The default strategy is [`AppendStrategy::None`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum AppendStrategy {
    /// New rows will be appended and the rest of the table will be unchanged
    #[default]
    None,
    /// If `n` rows are added, the top `n` rows will be removed
    Window,
    /// For each row added that exceeds `max_n_rows`, the row at `trench_ix`
    /// will be removed.
    Trench {
        /// The max number of rows allowed
        max_n_rows: usize,
        /// The index to remove data from
        trench_ix: usize,
    },
}
80
/// Defines how/where data may be inserted, which data may and may not be
/// overwritten, and whether data may extend the domain
///
/// # Example
///
/// Default `WriteMode` only allows appending supported values to new rows or
/// columns
/// ```
/// use lace::{WriteMode, InsertMode, OverwriteMode, AppendStrategy};
/// let mode_new = WriteMode::new();
/// let mode_def = WriteMode::default();
///
/// assert_eq!(
///     mode_new,
///     WriteMode {
///         insert: InsertMode::Unrestricted,
///         overwrite: OverwriteMode::Deny,
///         allow_extend_support: false,
///         append_strategy: AppendStrategy::None,
///     }
/// );
/// assert_eq!(mode_def, mode_new);
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct WriteMode {
    /// Determines whether new rows or columns can be appended or if data may
    /// be entered into existing cells.
    pub insert: InsertMode,
    /// Determines if existing cells may or may not be overwritten or whether
    /// only missing cells may be overwritten.
    pub overwrite: OverwriteMode,
    /// If `true`, allow column support to be extended to accommodate new data
    /// that fall outside the range. For example, a binary column extends to
    /// ternary after the user inserts `Datum::Categorical(2)`.
    ///
    /// Falls back to `bool::default()` (`false`) when absent from
    /// deserialized input.
    #[serde(default)]
    pub allow_extend_support: bool,
    /// The behavior of the table when new rows are appended
    ///
    /// Falls back to `AppendStrategy::default()` when absent from
    /// deserialized input.
    #[serde(default)]
    pub append_strategy: AppendStrategy,
}
122
123impl WriteMode {
124    /// Allows new data to be appended only to new rows/columns. No overwriting
125    /// and no support extension.
126    #[inline]
127    pub fn new() -> Self {
128        Self {
129            insert: InsertMode::Unrestricted,
130            overwrite: OverwriteMode::Deny,
131            allow_extend_support: false,
132            append_strategy: AppendStrategy::None,
133        }
134    }
135
136    #[inline]
137    pub fn unrestricted() -> Self {
138        Self {
139            insert: InsertMode::Unrestricted,
140            overwrite: OverwriteMode::Allow,
141            allow_extend_support: true,
142            append_strategy: AppendStrategy::None,
143        }
144    }
145}
146
impl Default for WriteMode {
    /// The default mode is the same as the one returned by `WriteMode::new`
    fn default() -> Self {
        Self::new()
    }
}
152
/// A datum for insertion into a certain column
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Value<C: ColumnIndex> {
    /// Index of the column the value belongs to (any `ColumnIndex`, e.g. the
    /// column name)
    pub col_ix: C,
    /// The value of the cell
    pub value: Datum,
}
161
162impl<C: ColumnIndex> From<(C, Datum)> for Value<C> {
163    fn from(value: (C, Datum)) -> Self {
164        Self {
165            col_ix: value.0,
166            value: value.1,
167        }
168    }
169}
170
/// A list of data for insertion into a certain row
///
/// # Example
///
/// ```
/// # use lace::Row;
/// use lace::Value;
/// use lace::data::Datum;
///
/// let row = Row::<&str, &str> {
///     row_ix: "vampire",
///     values: vec![
///         Value {
///             col_ix: "sucks_blood",
///             value: Datum::Categorical(1_u32.into()),
///         },
///         Value {
///             col_ix: "drinks_wine",
///             value: Datum::Categorical(0_u32.into()),
///         },
///     ],
/// };
///
/// assert_eq!(row.len(), 2);
/// ```
///
/// There are converters for convenience.
///
/// ```
/// # use lace::Row;
/// # use lace::data::Datum;
/// let row: Row<&str, &str>  = (
///     "vampire",
///     vec![
///         ("sucks_blood", Datum::Categorical(1_u32.into())),
///         ("drinks_wine", Datum::Categorical(0_u32.into())),
///     ]
/// ).into();
///
/// assert_eq!(row.len(), 2);
/// ```
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Row<R: RowIndex, C: ColumnIndex> {
    /// Index of the row (any `RowIndex`, e.g. the row name)
    pub row_ix: R,
    /// The cells and values to fill in
    pub values: Vec<Value<C>>,
}
219
220impl<R, C> From<(R, Vec<(C, Datum)>)> for Row<R, C>
221where
222    R: RowIndex,
223    C: ColumnIndex,
224{
225    fn from(mut row: (R, Vec<(C, Datum)>)) -> Self {
226        Self {
227            row_ix: row.0,
228            values: row.1.drain(..).map(Value::from).collect(),
229        }
230    }
231}
232
233impl<R: RowIndex, C: ColumnIndex> From<(R, Vec<Value<C>>)> for Row<R, C> {
234    fn from(row: (R, Vec<Value<C>>)) -> Self {
235        Self {
236            row_ix: row.0,
237            values: row.1,
238        }
239    }
240}
241
impl<R: RowIndex, C: ColumnIndex> Row<R, C> {
    /// The number of values (cells to fill in) in the Row
    #[inline]
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Return true if there are no values in the Row
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }
}
255
/// A datum paired with the integer index of its column.
///
/// Exists because lace uses integer indices for rows and columns internally.
#[derive(Debug, PartialEq)]
pub(crate) struct IndexValue {
    /// Integer index of the column
    pub col_ix: usize,
    /// The value of the cell
    pub value: Datum,
}
262
/// A row of values for insertion, addressed by integer row and column indices
#[derive(Debug, PartialEq)]
pub(crate) struct IndexRow {
    /// Integer index of the row
    pub row_ix: usize,
    /// The cells and values to fill in
    pub values: Vec<IndexValue>,
}
268
/// Describes the support extension action taken
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SupportExtension {
    /// The support of a categorical column was extended with new categories
    Categorical {
        /// The index of the column
        col_ix: usize,
        /// The name of the column
        col_name: String,
        /// New mapped values
        value_map_extension: ValueMapExtension,
    },
}
281
/// Describes table-extending actions taken when inserting data
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct InsertDataActions {
    // the types of the members match the types in InsertDataTasks
    // Names of the appended rows, in insertion order
    pub(crate) new_rows: IndexSet<String>,
    // Names of the appended columns
    pub(crate) new_cols: HashSet<String>,
    // One entry per column whose support was extended
    pub(crate) support_extensions: Vec<SupportExtension>,
}
290
impl Default for InsertDataActions {
    /// The default is the empty set of actions, same as
    /// `InsertDataActions::new`
    fn default() -> Self {
        Self::new()
    }
}
296
297impl InsertDataActions {
298    pub fn new() -> Self {
299        Self {
300            new_rows: IndexSet::new(),
301            new_cols: HashSet::new(),
302            support_extensions: Vec::new(),
303        }
304    }
305
306    /// If any new rows were appended, returns their names and order
307    pub fn new_rows(&self) -> Option<&IndexSet<String>> {
308        if self.new_rows.is_empty() {
309            None
310        } else {
311            Some(&self.new_rows)
312        }
313    }
314
315    /// If any new columns were appended, returns their names
316    pub fn new_cols(&self) -> Option<&HashSet<String>> {
317        if self.new_cols.is_empty() {
318            None
319        } else {
320            Some(&self.new_cols)
321        }
322    }
323
324    // The any columns had their supports extended, returns the support
325    // actions taken
326    pub fn support_extensions(&self) -> Option<&Vec<SupportExtension>> {
327        if self.support_extensions.is_empty() {
328            None
329        } else {
330            Some(&self.support_extensions)
331        }
332    }
333}
334
/// A summary of the tasks required to insert certain data into an `Engine`
///
/// The summary is checked against a `WriteMode` to decide whether the
/// insertion is permitted.
#[derive(Debug)]
pub(crate) struct InsertDataTasks {
    /// The names of new rows to be created. The order of the items is the
    /// order in which the rows are inserted.
    pub new_rows: IndexSet<String>,
    /// The names of new columns to be created
    pub new_cols: HashSet<String>,
    /// True if the operation would insert a value into an empty cell in the
    /// existing table
    pub overwrite_missing: bool,
    /// True if the operation would overwrite an existing (non-missing) value
    pub overwrite_present: bool,
}
349
350impl InsertDataTasks {
351    fn new() -> Self {
352        Self {
353            new_rows: IndexSet::new(),
354            new_cols: HashSet::new(),
355            overwrite_missing: false,
356            overwrite_present: false,
357        }
358    }
359
360    pub(crate) fn validate_insert_mode(
361        &self,
362        mode: WriteMode,
363    ) -> Result<(), InsertDataError> {
364        match mode.overwrite {
365            OverwriteMode::Deny => {
366                if self.overwrite_present || self.overwrite_missing {
367                    Err(InsertDataError::ModeForbidsOverwrite)
368                } else {
369                    Ok(())
370                }
371            }
372            OverwriteMode::MissingOnly => {
373                if self.overwrite_present {
374                    Err(InsertDataError::ModeForbidsOverwrite)
375                } else {
376                    Ok(())
377                }
378            }
379            OverwriteMode::Allow => Ok(()),
380        }
381        .and_then(|_| match mode.insert {
382            InsertMode::DenyNewRows => {
383                if !self.new_rows.is_empty() {
384                    Err(InsertDataError::ModeForbidsNewRows)
385                } else {
386                    Ok(())
387                }
388            }
389            InsertMode::DenyNewColumns => {
390                if !self.new_cols.is_empty() {
391                    Err(InsertDataError::ModeForbidsNewColumns)
392                } else {
393                    Ok(())
394                }
395            }
396            InsertMode::DenyNewRowsAndColumns => {
397                if !(self.new_rows.is_empty() && self.new_cols.is_empty()) {
398                    Err(InsertDataError::ModeForbidsNewRowsOrColumns)
399                } else {
400                    Ok(())
401                }
402            }
403            _ => Ok(()),
404        })
405    }
406}
407
408#[inline]
409fn ix_lookup_from_codebook(
410    col_metadata: &Option<ColMetadataList>,
411) -> Option<HashMap<&str, usize>> {
412    col_metadata.as_ref().map(|colmds| {
413        colmds
414            .iter()
415            .enumerate()
416            .map(|(ix, md)| (md.name.as_str(), ix))
417            .collect()
418    })
419}
420
421#[inline]
422fn col_ix_from_lookup(
423    col: &str,
424    lookup: &Option<HashMap<&str, usize>>,
425) -> Result<usize, InsertDataError> {
426    match lookup {
427        Some(lkp) => lkp
428            .get(col)
429            .ok_or_else(|| {
430                InsertDataError::NewColumnNotInColumnMetadata(col.to_owned())
431            })
432            .copied(),
433        None => Err(InsertDataError::NewColumnNotInColumnMetadata(
434            String::from(col),
435        )),
436    }
437}
438
439/// Determine whether we need to add new columns to the Engine and then add
440/// them.
441pub(crate) fn append_empty_columns(
442    tasks: &InsertDataTasks,
443    col_metadata: Option<ColMetadataList>,
444    engine: &mut Engine,
445) -> Result<(), InsertDataError> {
446    match col_metadata {
447        // There is partial codebook and there are new columns to add
448        Some(colmds) if !tasks.new_cols.is_empty() => {
449            // make sure that each of the new columns to be added is listed in
450            // the column metadata
451            tasks.new_cols.iter().try_for_each(|col| {
452                if colmds.contains_key(col) {
453                    Ok(())
454                } else {
455                    Err(InsertDataError::NewColumnNotInColumnMetadata(
456                        col.clone(),
457                    ))
458                }
459            })?;
460
461            if colmds.len() != tasks.new_cols.len() {
462                // There are more columns in the partial codebook than are
463                // in the inserted data.
464                Err(InsertDataError::WrongNumberOfColumnMetadataEntries {
465                    ncolmd: colmds.len(),
466                    nnew: tasks.new_cols.len(),
467                })
468            } else {
469                // create blank (data-less) columns and insert them into
470                // the States
471                let shape = (engine.n_rows(), engine.n_cols());
472                create_new_columns(&colmds, shape, &mut engine.rng).map(
473                    |col_models| {
474                        // Inserts blank columns into random existing views.
475                        // It is assumed that another reassignment transition
476                        // will be run after the data are inserted.
477                        let mut rng = &mut engine.rng;
478                        engine.states.iter_mut().for_each(|state| {
479                            state.append_blank_features(
480                                col_models.clone(),
481                                &mut rng,
482                            );
483                        });
484
485                        // Combine the codebooks
486                        // NOTE: if a panic happens here its our fault.
487                        // TODO: only append the ones that are new
488                        engine.codebook.append_col_metadata(colmds).unwrap();
489                    },
490                )
491            }
492        }
493        // There are new columns, but no partial codebook
494        None if !tasks.new_cols.is_empty() => {
495            Err(InsertDataError::WrongNumberOfColumnMetadataEntries {
496                ncolmd: 0,
497                nnew: tasks.new_cols.len(),
498            })
499        }
500        // Can ignore other cases (no new columns)
501        _ => Ok(()),
502    }
503}
504
505fn validate_new_col_ftype(
506    new_metadata: &Option<ColMetadataList>,
507    value: &Value<&str>,
508) -> Result<(), InsertDataError> {
509    let col_ftype = new_metadata
510        .as_ref()
511        .ok_or_else(|| {
512            InsertDataError::NewColumnNotInColumnMetadata(value.col_ix.into())
513        })?
514        .get(value.col_ix)
515        .ok_or_else(|| {
516            InsertDataError::NewColumnNotInColumnMetadata(value.col_ix.into())
517        })
518        .map(|(_, md)| FType::from_coltype(&md.coltype))?;
519
520    let (is_compat, compat_info) = col_ftype.datum_compatible(&value.value);
521
522    let bad_continuous_value = match value.value {
523        Datum::Continuous(ref x) => !x.is_finite(),
524        _ => false,
525    };
526
527    if is_compat {
528        if bad_continuous_value {
529            Err(InsertDataError::NonFiniteContinuousValue {
530                col: value.col_ix.to_owned(),
531                value: value.value.to_f64_opt().unwrap(),
532            })
533        } else {
534            Ok(())
535        }
536    } else {
537        Err(InsertDataError::DatumIncompatibleWithColumn {
538            col: value.col_ix.to_owned(),
539            ftype: compat_info.ftype,
540            ftype_req: compat_info.ftype_req,
541        })
542    }
543}
544
545fn validate_row_values<R: RowIndex, C: ColumnIndex>(
546    row: &Row<R, C>,
547    row_ix: usize,
548    row_exists: bool,
549    col_metadata: &Option<ColMetadataList>,
550    col_ix_lookup: &Option<HashMap<&str, usize>>,
551    insert_tasks: &mut InsertDataTasks,
552    engine: &Engine,
553) -> Result<IndexRow, InsertDataError> {
554    let n_cols = engine.n_cols();
555
556    let mut index_row = IndexRow {
557        row_ix,
558        values: vec![],
559    };
560
561    row.values.iter().try_for_each(|value| {
562        match value.col_ix.col_ix(engine.codebook()) {
563            Ok(col_ix) => {
564                // check whether the datum is missing.
565                if row_exists {
566                    if engine.datum(row_ix, col_ix).unwrap().is_missing() {
567                        insert_tasks.overwrite_missing = true;
568                    } else {
569                        insert_tasks.overwrite_present = true;
570                    }
571                }
572
573                // determine whether the value is compatible
574                // with the FType of the column
575                let ftype_compat = engine
576                    .ftype(col_ix)
577                    .unwrap()
578                    .datum_compatible(&value.value);
579
580                let bad_continuous_value = match value.value {
581                    Datum::Continuous(ref x) => !x.is_finite(),
582                    _ => false,
583                };
584
585                if ftype_compat.0 {
586                    if bad_continuous_value {
587                        let col = &engine.codebook.col_metadata[col_ix].name;
588                        Err(InsertDataError::NonFiniteContinuousValue {
589                            col: col.clone(),
590                            value: value.value.to_f64_opt().unwrap(),
591                        })
592                    } else {
593                        Ok(col_ix)
594                    }
595                } else {
596                    let col = &engine.codebook.col_metadata[col_ix].name;
597                    Err(InsertDataError::DatumIncompatibleWithColumn {
598                        col: col.clone(),
599                        ftype_req: ftype_compat.1.ftype_req,
600                        ftype: ftype_compat.1.ftype,
601                    })
602                }
603            }
604            Err(_) => {
605                value
606                    .col_ix
607                    .col_str()
608                    .ok_or_else(|| {
609                        InsertDataError::IntegerIndexNewColumn(
610                            value
611                                .col_ix
612                                .col_usize()
613                                .expect("Column index does not have a string or usize representation")
614                        )
615                    })
616                    .and_then(|name| {
617                        // TODO: get rid of this clone
618                        let new_val = Value {
619                            col_ix: name,
620                            value: value.value.clone(),
621                        };
622                        validate_new_col_ftype(col_metadata, &new_val).and_then(
623                            |_| {
624                                insert_tasks.new_cols.insert(name.to_owned());
625                                col_ix_from_lookup(name, col_ix_lookup)
626                                    .map(|ix| ix + n_cols)
627                            },
628                        )
629                    })
630            }
631        }
632        .map(|col_ix| {
633            index_row.values.push(IndexValue {
634                col_ix,
635                value: value.value.clone(),
636            });
637        })
638    })?;
639    Ok(index_row)
640}
641
642/// Get a summary of the tasks required to insert `rows` into `Engine`.
643pub(crate) fn insert_data_tasks<R: RowIndex, C: ColumnIndex>(
644    rows: &[Row<R, C>],
645    col_metadata: &Option<ColMetadataList>,
646    engine: &Engine,
647) -> Result<(InsertDataTasks, Vec<IndexRow>), InsertDataError> {
648    const EXISTING_ROW: bool = true;
649    const NEW_ROW: bool = false;
650
651    // Get a map into the new column indices if they exist
652    let col_ix_lookup = ix_lookup_from_codebook(col_metadata);
653
654    // Get a list of all the row names. The row names must be included in the
655    // codebook in order to insert data.
656    let n_rows = engine.n_rows();
657
658    let mut tasks = InsertDataTasks::new();
659
660    let index_rows: Vec<IndexRow> = rows
661        .iter()
662        .map(|row| match row.row_ix.row_ix(engine.codebook()) {
663            Ok(row_ix) => {
664                if row.is_empty() {
665                    let name = engine.codebook.row_names.name(row_ix).unwrap();
666                    Err(InsertDataError::EmptyRow(name.clone()))
667                } else {
668                    validate_row_values(
669                        row,
670                        row_ix,
671                        EXISTING_ROW,
672                        col_metadata,
673                        &col_ix_lookup,
674                        &mut tasks,
675                        engine,
676                    )
677                }
678            }
679            Err(_) => {
680                // row index is either out of bounds or does not exist in the codebook
681                if row.is_empty() {
682                    Err(InsertDataError::EmptyRow(format!("{:?}", row.row_ix)))
683                } else {
684                    validate_row_values(
685                        row,
686                        {
687                            let n = tasks.new_rows.len();
688                            row.row_ix
689                                .row_str()
690                                .ok_or_else(|| {
691                                    let ix = row
692                                        .row_ix
693                                        .row_usize()
694                                        .expect("Index doesn't have a string or usize representation");
695                                    InsertDataError::IntegerIndexNewRow(ix)
696                                })
697                                .map(|row_name| {
698                                    tasks
699                                        .new_rows
700                                        .insert(String::from(row_name));
701                                })?;
702                            n_rows + n
703                        },
704                        NEW_ROW,
705                        col_metadata,
706                        &col_ix_lookup,
707                        &mut tasks,
708                        engine,
709                    )
710                }
711            }
712        })
713        .collect::<Result<Vec<IndexRow>, InsertDataError>>()?;
714    Ok((tasks, index_rows))
715}
716
717pub(crate) fn maybe_add_categories<R: RowIndex, C: ColumnIndex>(
718    rows: &[Row<R, C>],
719    engine: &mut Engine,
720    mode: WriteMode,
721) -> Result<Vec<SupportExtension>, InsertDataError> {
722    let mut extended_value_map: HashMap<usize, ValueMapExtension> =
723        HashMap::new();
724
725    // This code gets all the supports for all the categorical columns for
726    // which data are to be inserted.
727    // For each value (cell) in each row...
728    rows.iter().try_for_each(|row| {
729        row.values.iter().try_for_each(|new_value| {
730            // if the column is categorical, see if we need to add support,
731            // otherwise carry on.
732            match new_value.col_ix.col_ix(engine.codebook()) {
733                Err(_) => Ok(()), // IndexError means new column
734                Ok(col_ix) => {
735                    let col_metadata = &engine.codebook.col_metadata[col_ix];
736                    let col_name = col_metadata.name.as_str();
737
738                    match &col_metadata.coltype {
739                        ColType::Categorical { value_map, .. } => {
740                            match (&new_value.value, value_map) {
741                                (Datum::Categorical(Category::String(x)), ValueMap::String(vm)) => {
742                                    if !vm.contains_cat(x) {
743                                        let ext_vm = extended_value_map.entry(col_ix).or_insert_with(ValueMapExtension::new_string);
744                                        ext_vm.extend(Category::String(x.clone())).map_err(
745                                            |e| match e {
746                                                ValueMapExtensionError::ExtensionOfDifferingType(a, b) => {
747                                                    InsertDataError::WrongCategoryAndType(a, b, col_name.to_string())
748                                                }
749                                            })?;
750                                    };
751                                    Ok(())
752                                },
753                                (Datum::Categorical(Category::UInt(x)), ValueMap::UInt(old_max)) => {
754                                    if *old_max as u32 <= *x {
755                                        let ext_max = extended_value_map.entry(col_ix).or_insert_with(ValueMapExtension::new_uint);
756                                        ext_max.extend(Category::UInt(*x)).map_err(
757                                            |e| match e {
758                                                ValueMapExtensionError::ExtensionOfDifferingType(a, b) => {
759                                                    InsertDataError::WrongCategoryAndType(a, b, col_name.to_string())
760                                                }
761                                            })?;
762                                    }
763                                    Ok(())
764                                },
765                                (Datum::Missing, _) |
766                                (Datum::Categorical(Category::Bool(_)), ValueMap::Bool) => {
767                                    Ok(())
768                                },
769                                _ => {
770                                    Err(
771                                    InsertDataError::DatumIncompatibleWithColumn {
772                                        col: (*col_name).into(),
773                                        ftype_req: FType::Categorical,
774                                        // this should never fail because TryFrom only
775                                        // fails for Datum::Missing, and that case is
776                                        // handled above
777                                        ftype: (&new_value.value).try_into().unwrap(),
778                                    },
779                                )}
780                            }
781                        }
782                        _ => Ok(())
783                    }
784                }
785            }
786        })
787    })?;
788
789    if !mode.allow_extend_support && !extended_value_map.is_empty() {
790        // support extension not allowed
791        return Err(InsertDataError::ModeForbidsCategoryExtension);
792    }
793
794    let mut cols_extended: Vec<SupportExtension> = Vec::new();
795
796    // Here we loop through all the categorical insertions generated above and
797    // determine whether we need to extend categorical support by comparing the
798    // existing support (n_cats, or k) for each column with the maximum value
799    // requested to be inserted into that column. If the value exceeds the
800    // support of that column, we extend the support.
801    for (col_ix, value_map_extension) in extended_value_map.drain() {
802        incr_column_categories(engine, col_ix, &value_map_extension)?;
803        cols_extended.push(SupportExtension::Categorical {
804            col_ix,
805            col_name: engine.codebook.col_metadata[col_ix].name.clone(),
806            value_map_extension,
807        })
808    }
809
810    Ok(cols_extended)
811}
812
813fn incr_category_in_codebook(
814    codebook: &mut Codebook,
815    col_ix: usize,
816    value_map_extension: &ValueMapExtension,
817) -> Result<(), InsertDataError> {
818    let col_name = codebook.col_metadata[col_ix].name.clone();
819    match codebook.col_metadata[col_ix].coltype {
820        ColType::Categorical {
821            ref mut k,
822            ref mut value_map,
823            ..
824        } => {
825            // TODO: Does this capture any errors from a user or just errors from within lace?
826            value_map.extend(value_map_extension.clone()).map_err(
827                |e| match e {
828                    ValueMapExtensionError::ExtensionOfDifferingType(a, b) => {
829                        InsertDataError::WrongCategoryAndType(a, b, col_name)
830                    }
831                },
832            )?;
833            *k = value_map.len();
834            Ok(())
835        }
836        _ => panic!("Tried to change cardinality of non-categorical column"),
837    }
838}
839
840fn incr_column_categories(
841    engine: &mut Engine,
842    col_ix: usize,
843    extended_value_map: &ValueMapExtension,
844) -> Result<(), InsertDataError> {
845    // Adjust in codebook
846    incr_category_in_codebook(
847        &mut engine.codebook,
848        col_ix,
849        extended_value_map,
850    )?;
851
852    let n_cats_req = match engine.codebook.col_metadata[col_ix].coltype {
853        ColType::Categorical { k, .. } => k,
854        _ => panic!("Requested non-categorical column"),
855    };
856
857    // Adjust component models, priors, suffstats
858    engine.states.iter_mut().for_each(|state| {
859        match state.feature_mut(col_ix) {
860            ColModel::Categorical(column) => {
861                column.prior = SymmetricDirichlet::new_unchecked(
862                    column.prior.alpha(),
863                    n_cats_req,
864                );
865                column.components.iter_mut().for_each(|cpnt| {
866                    cpnt.stat = CategoricalSuffStat::from_parts_unchecked(
867                        cpnt.stat.n(),
868                        {
869                            let mut counts = cpnt.stat.counts().clone();
870                            counts.resize(n_cats_req, 0.0);
871                            counts
872                        },
873                    );
874
875                    cpnt.fx = Categorical::new_unchecked({
876                        let mut ln_weights = cpnt.fx.ln_weights().clone();
877                        ln_weights.resize(n_cats_req, f64::NEG_INFINITY);
878                        ln_weights
879                    });
880                })
881            }
882            _ => panic!("Requested non-categorical column"),
883        }
884    });
885    Ok(())
886}
887
/// Builds a `ColModel::$coltype` for a new column whose `$n_rows` cells are
/// all missing.
///
/// Arguments, in order:
/// - `$coltype`: the `ColModel` variant to construct
/// - `$htype`: the hyper-prior type; must provide `default()`
/// - `$errvar`: the `InsertDataError` variant returned when neither a hyper
///   nor a prior is available
/// - `$colmd`: the column metadata; only its `name` is read, for the error
/// - `$hyper`, `$prior`: the optional hyper-prior and prior
/// - `$n_rows`: number of (missing) cells in the new column
/// - `$id`: the id given to the new column
/// - `$xtype`: the data type stored in the column's container
/// - `$rng`: the RNG used to draw a prior from the hyper when needed
///
/// Evaluates to `Ok(ColModel::$coltype(..))`, or to
/// `Err(InsertDataError::$errvar(column_name))` when both `$hyper` and
/// `$prior` are `None`.
macro_rules! new_col_arm {
    (
        $coltype: ident,
        $htype: ty,
        $errvar: ident,
        $colmd: ident,
        $hyper: ident,
        $prior: ident,
        $n_rows: ident,
        $id: ident,
        $xtype: ty,
        $rng: ident
    ) => {{
        let data: SparseContainer<$xtype> =
            SparseContainer::all_missing($n_rows);

        match ($hyper, $prior) {
            (Some(h), _) => {
                // An explicit prior takes precedence; otherwise draw one
                // from the hyper-prior.
                let pr = if let Some(pr) = $prior {
                    pr.clone()
                } else {
                    h.draw(&mut $rng)
                };
                let column = Column::new($id, data, pr, h.clone());
                Ok(ColModel::$coltype(column))
            }
            (None, Some(pr)) => {
                // use a dummy hyper, we're going to ignore it
                let mut column =
                    Column::new($id, data, pr.clone(), <$htype>::default());
                column.ignore_hyper = true;
                Ok(ColModel::$coltype(column))
            }
            // Use the variant the call site asked for, so that e.g. a Count
            // column is not reported as missing a Gaussian hyper.
            (None, None) => {
                Err(InsertDataError::$errvar($colmd.name.clone()))
            }
        }
    }};
}
927
928pub(crate) fn create_new_columns<R: rand::Rng>(
929    col_metadata: &ColMetadataList,
930    state_shape: (usize, usize),
931    mut rng: &mut R,
932) -> Result<Vec<ColModel>, InsertDataError> {
933    let (n_rows, n_cols) = state_shape;
934    col_metadata
935        .iter()
936        .enumerate()
937        .map(|(i, colmd)| {
938            let id = i + n_cols;
939            match &colmd.coltype {
940                ColType::Continuous { hyper, prior } => new_col_arm!(
941                    Continuous,
942                    crate::stats::prior::nix::NixHyper,
943                    NoGaussianHyperForNewColumn,
944                    colmd,
945                    hyper,
946                    prior,
947                    n_rows,
948                    id,
949                    f64,
950                    rng
951                ),
952                ColType::Count { hyper, prior } => new_col_arm!(
953                    Count,
954                    crate::stats::prior::pg::PgHyper,
955                    NoPoissonHyperForNewColumn,
956                    colmd,
957                    hyper,
958                    prior,
959                    n_rows,
960                    id,
961                    u32,
962                    rng
963                ),
964                ColType::Categorical {
965                    k, hyper, prior, ..
966                } => {
967                    let data: SparseContainer<u32> =
968                        SparseContainer::all_missing(n_rows);
969
970                    let id = i + n_cols;
971                    match (hyper, prior) {
972                        (Some(h), _) => {
973                            let pr = if let Some(pr) = prior {
974                                pr.clone()
975                            } else {
976                                h.draw(*k, &mut rng)
977                            };
978                            let column = Column::new(id, data, pr, h.clone());
979                            Ok(ColModel::Categorical(column))
980                        }
981                        (None, Some(pr)) => {
982                            use crate::stats::prior::csd::CsdHyper;
983                            let mut column = Column::new(
984                                id,
985                                data,
986                                pr.clone(),
987                                CsdHyper::default(),
988                            );
989                            column.ignore_hyper = true;
990                            Ok(ColModel::Categorical(column))
991                        }
992                        (None, None) => Err(
993                            InsertDataError::NoCategoricalHyperForNewColumn(
994                                colmd.name.clone(),
995                            ),
996                        ),
997                    }
998                }
999            }
1000        })
1001        .collect()
1002}
1003
1004pub(crate) fn remove_cell(engine: &mut Engine, row_ix: usize, col_ix: usize) {
1005    engine.states.iter_mut().for_each(|state| {
1006        state.remove_datum(row_ix, col_ix);
1007    })
1008}
1009
1010pub(crate) fn remove_col(engine: &mut Engine, col_ix: usize) {
1011    // remove the column from the codebook and re-index
1012    engine.codebook.col_metadata.remove_by_index(col_ix);
1013    let mut rng = engine.rng.clone();
1014    engine.states.iter_mut().for_each(|state| {
1015        // deletes the column and re-indexes
1016        state.del_col(col_ix, &mut rng);
1017    });
1018}
1019
1020pub(crate) fn check_if_removes_col(
1021    engine: &Engine,
1022    rm_rows: &BTreeSet<usize>,
1023    mut rm_cell_cols: HashMap<usize, i64>,
1024) -> BTreeSet<usize> {
1025    let mut to_rm: BTreeSet<usize> = BTreeSet::new();
1026    // rm_cell_cols.values_mut().for_each(|val| {*val -= rm_rows.len() as i64});
1027    rm_cell_cols.iter_mut().for_each(|(col_ix, val)| {
1028        let mut present_count = 0_i64;
1029        let mut remove = true;
1030        for row_ix in 0..engine.n_rows() {
1031            if present_count > *val {
1032                remove = false;
1033                break;
1034            }
1035            if !rm_rows.contains(&row_ix)
1036                && !engine.datum(row_ix, *col_ix).unwrap().is_missing()
1037            {
1038                present_count += 1;
1039            }
1040        }
1041        if remove {
1042            to_rm.insert(*col_ix);
1043        }
1044    });
1045    to_rm
1046}
1047
1048pub(crate) fn check_if_removes_row(
1049    engine: &Engine,
1050    rm_cols: &BTreeSet<usize>,
1051    mut rm_cell_rows: HashMap<usize, i64>,
1052) -> BTreeSet<usize> {
1053    let mut to_rm: BTreeSet<usize> = BTreeSet::new();
1054    rm_cell_rows.iter_mut().for_each(|(row_ix, val)| {
1055        let mut present_count = 0_i64;
1056        let mut remove = true;
1057        for col_ix in 0..engine.n_cols() {
1058            if present_count > *val {
1059                remove = false;
1060                break;
1061            }
1062            if !rm_cols.contains(&col_ix)
1063                && !engine.datum(*row_ix, col_ix).unwrap().is_missing()
1064            {
1065                present_count += 1;
1066            }
1067        }
1068        if remove {
1069            to_rm.insert(*row_ix);
1070        }
1071    });
1072    to_rm
1073}
1074
1075#[cfg(test)]
1076mod tests {
1077    use rand::SeedableRng;
1078
1079    use super::*;
1080    use crate::codebook::ColMetadata;
1081    use crate::codebook::ColType;
1082    use crate::codebook::ValueMap;
1083    use crate::data::data_source;
1084    use crate::stats::prior::csd::CsdHyper;
1085
1086    #[cfg(feature = "examples")]
1087    mod requiring_examples {
1088        use super::*;
1089        use crate::examples::Example;
1090
1091        #[test]
1092        fn errors_when_no_col_metadata_when_new_columns() {
1093            let engine = Example::Animals.engine().unwrap();
1094            let moose_updates = Row::<String, String> {
1095                row_ix: "moose".into(),
1096                values: vec![
1097                    Value {
1098                        col_ix: "does+taxes".into(),
1099                        value: Datum::Categorical(1_u32.into()),
1100                    },
1101                    Value {
1102                        col_ix: "flys".into(),
1103                        value: Datum::Categorical(1_u32.into()),
1104                    },
1105                ],
1106            };
1107
1108            let result = insert_data_tasks(&[moose_updates], &None, &engine);
1109
1110            assert!(result.is_err());
1111            match result {
1112                Err(InsertDataError::NewColumnNotInColumnMetadata(s)) => {
1113                    assert_eq!(s, String::from("does+taxes"))
1114                }
1115                Err(err) => panic!("wrong error: {:?}", err),
1116                Ok(_) => panic!("failed to fail"),
1117            }
1118        }
1119
1120        #[test]
1121        fn errors_when_new_column_not_in_col_metadata() {
1122            let engine = Example::Animals.engine().unwrap();
1123            let moose_updates = Row::<String, String> {
1124                row_ix: "moose".into(),
1125                values: vec![
1126                    Value {
1127                        col_ix: "does+taxes".into(),
1128                        value: Datum::Categorical(1_u32.into()),
1129                    },
1130                    Value {
1131                        col_ix: "flys".into(),
1132                        value: Datum::Categorical(1_u32.into()),
1133                    },
1134                ],
1135            };
1136
1137            let col_metadata = ColMetadataList::new(vec![ColMetadata {
1138                name: "dances".into(),
1139                coltype: ColType::Categorical {
1140                    k: 2,
1141                    hyper: None,
1142                    prior: None,
1143                    value_map: ValueMap::UInt(2),
1144                },
1145                notes: None,
1146                missing_not_at_random: false,
1147            }])
1148            .unwrap();
1149
1150            let result = insert_data_tasks(
1151                &[moose_updates],
1152                &Some(col_metadata),
1153                &engine,
1154            );
1155
1156            assert!(result.is_err());
1157            assert_eq!(
1158                result.unwrap_err(),
1159                InsertDataError::NewColumnNotInColumnMetadata(
1160                    "does+taxes".into()
1161                )
1162            );
1163        }
1164
1165        #[test]
1166        fn tasks_on_one_existing_row() {
1167            let engine = Example::Animals.engine().unwrap();
1168            let moose_updates = Row::<String, String> {
1169                row_ix: "moose".into(),
1170                values: vec![
1171                    Value {
1172                        col_ix: "swims".into(),
1173                        value: Datum::Categorical(1_u32.into()),
1174                    },
1175                    Value {
1176                        col_ix: "flys".into(),
1177                        value: Datum::Categorical(1_u32.into()),
1178                    },
1179                ],
1180            };
1181            let rows = vec![moose_updates];
1182            let (tasks, ixrows) =
1183                insert_data_tasks(&rows, &None, &engine).unwrap();
1184
1185            assert!(tasks.new_rows.is_empty());
1186            assert!(tasks.new_cols.is_empty());
1187            assert!(!tasks.overwrite_missing);
1188            assert!(tasks.overwrite_present);
1189
1190            assert_eq!(
1191                ixrows,
1192                vec![IndexRow {
1193                    row_ix: 15,
1194                    values: vec![
1195                        IndexValue {
1196                            col_ix: 36,
1197                            value: Datum::Categorical(1_u32.into())
1198                        },
1199                        IndexValue {
1200                            col_ix: 34,
1201                            value: Datum::Categorical(1_u32.into())
1202                        },
1203                    ]
1204                }]
1205            );
1206        }
1207
1208        #[test]
1209        fn tasks_on_one_new_row() {
1210            let engine = Example::Animals.engine().unwrap();
1211            let pegasus = Row::<String, String> {
1212                row_ix: "pegasus".into(),
1213                values: vec![
1214                    Value {
1215                        col_ix: "swims".into(),
1216                        value: Datum::Categorical(1_u32.into()),
1217                    },
1218                    Value {
1219                        col_ix: "flys".into(),
1220                        value: Datum::Categorical(1_u32.into()),
1221                    },
1222                ],
1223            };
1224            let rows = vec![pegasus];
1225            let (tasks, ixrows) =
1226                insert_data_tasks(&rows, &None, &engine).unwrap();
1227
1228            assert_eq!(tasks.new_rows.len(), 1);
1229            assert!(tasks.new_rows.contains("pegasus"));
1230            assert!(tasks.new_cols.is_empty());
1231            assert!(!tasks.overwrite_missing);
1232            assert!(!tasks.overwrite_present);
1233
1234            assert_eq!(
1235                ixrows,
1236                vec![IndexRow {
1237                    row_ix: 50,
1238                    values: vec![
1239                        IndexValue {
1240                            col_ix: 36,
1241                            value: Datum::Categorical(1_u32.into())
1242                        },
1243                        IndexValue {
1244                            col_ix: 34,
1245                            value: Datum::Categorical(1_u32.into())
1246                        },
1247                    ]
1248                }]
1249            );
1250        }
1251
1252        #[test]
1253        fn tasks_on_two_new_rows() {
1254            let engine = Example::Animals.engine().unwrap();
1255            let pegasus = Row::<String, String> {
1256                row_ix: "pegasus".into(),
1257                values: vec![
1258                    Value {
1259                        col_ix: "swims".into(),
1260                        value: Datum::Categorical(1_u32.into()),
1261                    },
1262                    Value {
1263                        col_ix: "flys".into(),
1264                        value: Datum::Categorical(1_u32.into()),
1265                    },
1266                ],
1267            };
1268
1269            let man = Row::<String, String> {
1270                row_ix: "man".into(),
1271                values: vec![
1272                    Value {
1273                        col_ix: "smart".into(),
1274                        value: Datum::Categorical(1_u32.into()),
1275                    },
1276                    Value {
1277                        col_ix: "hunter".into(),
1278                        value: Datum::Categorical(0_u32.into()),
1279                    },
1280                ],
1281            };
1282            let rows = vec![pegasus, man];
1283            let (tasks, ixrows) =
1284                insert_data_tasks(&rows, &None, &engine).unwrap();
1285
1286            assert_eq!(tasks.new_rows.len(), 2);
1287            assert!(tasks.new_rows.contains("pegasus"));
1288            assert!(tasks.new_rows.contains("man"));
1289
1290            assert!(tasks.new_cols.is_empty());
1291            assert!(!tasks.overwrite_missing);
1292            assert!(!tasks.overwrite_present);
1293
1294            assert_eq!(
1295                ixrows,
1296                vec![
1297                    IndexRow {
1298                        row_ix: 50,
1299                        values: vec![
1300                            IndexValue {
1301                                col_ix: 36,
1302                                value: Datum::Categorical(1_u32.into())
1303                            },
1304                            IndexValue {
1305                                col_ix: 34,
1306                                value: Datum::Categorical(1_u32.into())
1307                            },
1308                        ]
1309                    },
1310                    IndexRow {
1311                        row_ix: 51,
1312                        values: vec![
1313                            IndexValue {
1314                                col_ix: 80,
1315                                value: Datum::Categorical(1_u32.into())
1316                            },
1317                            IndexValue {
1318                                col_ix: 58,
1319                                value: Datum::Categorical(0_u32.into())
1320                            },
1321                        ]
1322                    }
1323                ]
1324            );
1325        }
1326
1327        #[test]
1328        fn tasks_on_one_new_and_one_existing_row() {
1329            let engine = Example::Animals.engine().unwrap();
1330            let pegasus = Row::<String, String> {
1331                row_ix: "pegasus".into(),
1332                values: vec![
1333                    Value {
1334                        col_ix: "swims".into(),
1335                        value: Datum::Categorical(1_u32.into()),
1336                    },
1337                    Value {
1338                        col_ix: "flys".into(),
1339                        value: Datum::Categorical(1_u32.into()),
1340                    },
1341                ],
1342            };
1343
1344            let moose = Row::<String, String> {
1345                row_ix: "moose".into(),
1346                values: vec![
1347                    Value {
1348                        col_ix: "smart".into(),
1349                        value: Datum::Categorical(1_u32.into()),
1350                    },
1351                    Value {
1352                        col_ix: "hunter".into(),
1353                        value: Datum::Categorical(0_u32.into()),
1354                    },
1355                ],
1356            };
1357            let rows = vec![pegasus, moose];
1358            let (tasks, ixrows) =
1359                insert_data_tasks(&rows, &None, &engine).unwrap();
1360
1361            assert_eq!(tasks.new_rows.len(), 1);
1362            assert!(tasks.new_rows.contains("pegasus"));
1363
1364            assert!(tasks.new_cols.is_empty());
1365            assert!(!tasks.overwrite_missing);
1366            assert!(tasks.overwrite_present);
1367
1368            assert_eq!(
1369                ixrows,
1370                vec![
1371                    IndexRow {
1372                        row_ix: 50,
1373                        values: vec![
1374                            IndexValue {
1375                                col_ix: 36,
1376                                value: Datum::Categorical(1_u32.into())
1377                            },
1378                            IndexValue {
1379                                col_ix: 34,
1380                                value: Datum::Categorical(1_u32.into())
1381                            },
1382                        ]
1383                    },
1384                    IndexRow {
1385                        row_ix: 15,
1386                        values: vec![
1387                            IndexValue {
1388                                col_ix: 80,
1389                                value: Datum::Categorical(1_u32.into())
1390                            },
1391                            IndexValue {
1392                                col_ix: 58,
1393                                value: Datum::Categorical(0_u32.into())
1394                            },
1395                        ]
1396                    }
1397                ]
1398            );
1399        }
1400
1401        #[test]
1402        fn tasks_on_one_new_col_in_existing_row() {
1403            let engine = Example::Animals.engine().unwrap();
1404            let col_metadata = ColMetadataList::new(vec![ColMetadata {
1405                name: "dances".into(),
1406                coltype: ColType::Categorical {
1407                    k: 2,
1408                    hyper: None,
1409                    prior: None,
1410                    value_map: ValueMap::UInt(2),
1411                },
1412                notes: None,
1413                missing_not_at_random: false,
1414            }])
1415            .unwrap();
1416            let moose_updates = Row::<String, String> {
1417                row_ix: "moose".into(),
1418                values: vec![
1419                    Value {
1420                        col_ix: "dances".into(),
1421                        value: Datum::Categorical(1_u32.into()),
1422                    },
1423                    Value {
1424                        col_ix: "flys".into(),
1425                        value: Datum::Categorical(1_u32.into()),
1426                    },
1427                ],
1428            };
1429            let rows = vec![moose_updates];
1430            let (tasks, ixrows) =
1431                insert_data_tasks(&rows, &Some(col_metadata), &engine).unwrap();
1432
1433            assert!(tasks.new_rows.is_empty());
1434            assert_eq!(tasks.new_cols.len(), 1);
1435            assert!(tasks.new_cols.contains("dances"));
1436
1437            assert!(!tasks.overwrite_missing);
1438            assert!(tasks.overwrite_present);
1439
1440            assert_eq!(
1441                ixrows,
1442                vec![IndexRow {
1443                    row_ix: 15,
1444                    values: vec![
1445                        IndexValue {
1446                            col_ix: 85,
1447                            value: Datum::Categorical(1_u32.into())
1448                        },
1449                        IndexValue {
1450                            col_ix: 34,
1451                            value: Datum::Categorical(1_u32.into())
1452                        },
1453                    ]
1454                }]
1455            );
1456        }
1457
1458        #[test]
1459        fn tasks_on_one_new_col_in_new_row() {
1460            let engine = Example::Animals.engine().unwrap();
1461
1462            let col_metadata = ColMetadataList::new(vec![ColMetadata {
1463                name: "dances".into(),
1464                coltype: ColType::Categorical {
1465                    k: 2,
1466                    hyper: None,
1467                    prior: None,
1468                    value_map: ValueMap::UInt(2),
1469                },
1470                notes: None,
1471                missing_not_at_random: false,
1472            }])
1473            .unwrap();
1474
1475            let peanut = Row::<String, String> {
1476                row_ix: "peanut".into(),
1477                values: vec![
1478                    Value {
1479                        col_ix: "dances".into(),
1480                        value: Datum::Categorical(1_u32.into()),
1481                    },
1482                    Value {
1483                        col_ix: "flys".into(),
1484                        value: Datum::Categorical(0_u32.into()),
1485                    },
1486                ],
1487            };
1488            let rows = vec![peanut];
1489            let (tasks, ixrows) =
1490                insert_data_tasks(&rows, &Some(col_metadata), &engine).unwrap();
1491
1492            assert_eq!(tasks.new_rows.len(), 1);
1493            assert!(tasks.new_rows.contains("peanut"));
1494
1495            assert_eq!(tasks.new_cols.len(), 1);
1496            assert!(tasks.new_cols.contains("dances"));
1497
1498            assert!(!tasks.overwrite_missing);
1499            assert!(!tasks.overwrite_present);
1500
1501            assert_eq!(
1502                ixrows,
1503                vec![IndexRow {
1504                    row_ix: 50,
1505                    values: vec![
1506                        IndexValue {
1507                            col_ix: 85,
1508                            value: Datum::Categorical(1_u32.into())
1509                        },
1510                        IndexValue {
1511                            col_ix: 34,
1512                            value: Datum::Categorical(0_u32.into())
1513                        },
1514                    ]
1515                }]
1516            );
1517        }
1518
1519        #[test]
1520        fn tasks_on_two_new_cols_in_existing_row() {
1521            let engine = Example::Animals.engine().unwrap();
1522            let col_metadata = ColMetadataList::new(vec![
1523                ColMetadata {
1524                    name: "dances".into(),
1525                    coltype: ColType::Categorical {
1526                        k: 2,
1527                        hyper: None,
1528                        prior: None,
1529                        value_map: ValueMap::UInt(2),
1530                    },
1531                    notes: None,
1532                    missing_not_at_random: false,
1533                },
1534                ColMetadata {
1535                    name: "eats+figs".into(),
1536                    coltype: ColType::Categorical {
1537                        k: 2,
1538                        hyper: None,
1539                        prior: None,
1540                        value_map: ValueMap::UInt(2),
1541                    },
1542                    notes: None,
1543                    missing_not_at_random: false,
1544                },
1545            ])
1546            .unwrap();
1547
1548            let moose_updates = Row::<String, String> {
1549                row_ix: "moose".into(),
1550                values: vec![
1551                    Value {
1552                        col_ix: "flys".into(),
1553                        value: Datum::Categorical(1_u32.into()),
1554                    },
1555                    Value {
1556                        col_ix: "eats+figs".into(),
1557                        value: Datum::Categorical(0_u32.into()),
1558                    },
1559                    Value {
1560                        col_ix: "dances".into(),
1561                        value: Datum::Categorical(1_u32.into()),
1562                    },
1563                ],
1564            };
1565            let rows = vec![moose_updates];
1566            let (tasks, ixrows) =
1567                insert_data_tasks(&rows, &Some(col_metadata), &engine).unwrap();
1568
1569            assert!(tasks.new_rows.is_empty());
1570            assert_eq!(tasks.new_cols.len(), 2);
1571            assert!(tasks.new_cols.contains("dances"));
1572            assert!(tasks.new_cols.contains("eats+figs"));
1573
1574            assert!(!tasks.overwrite_missing);
1575            assert!(tasks.overwrite_present);
1576
1577            assert_eq!(
1578                ixrows,
1579                vec![IndexRow {
1580                    row_ix: 15,
1581                    values: vec![
1582                        IndexValue {
1583                            col_ix: 34,
1584                            value: Datum::Categorical(1_u32.into())
1585                        },
1586                        IndexValue {
1587                            col_ix: 86,
1588                            value: Datum::Categorical(0_u32.into())
1589                        },
1590                        IndexValue {
1591                            col_ix: 85,
1592                            value: Datum::Categorical(1_u32.into())
1593                        },
1594                    ]
1595                }]
1596            );
1597        }
1598    }
1599
1600    fn quick_codebook() -> Codebook {
1601        let coltype = ColType::Categorical {
1602            k: 2,
1603            hyper: None,
1604            prior: None,
1605            value_map: ValueMap::try_from(vec![
1606                String::from("red"),
1607                String::from("green"),
1608            ])
1609            .unwrap(),
1610        };
1611        let md0 = ColMetadata {
1612            name: "0".to_string(),
1613            coltype: coltype.clone(),
1614            notes: None,
1615            missing_not_at_random: false,
1616        };
1617        let md1 = ColMetadata {
1618            name: "1".to_string(),
1619            coltype,
1620            notes: None,
1621            missing_not_at_random: false,
1622        };
1623        let md2 = ColMetadata {
1624            name: "2".to_string(),
1625            coltype: ColType::Categorical {
1626                k: 3,
1627                hyper: None,
1628                prior: None,
1629                value_map: ValueMap::UInt(3),
1630            },
1631            notes: None,
1632            missing_not_at_random: false,
1633        };
1634
1635        let col_metadata = ColMetadataList::new(vec![md0, md1, md2]).unwrap();
1636        Codebook::new("table".to_string(), col_metadata)
1637    }
1638
1639    #[test]
1640    fn incr_cats_in_codebook_without_suppl_metadata_for_no_valmap_col() {
1641        let mut codebook = quick_codebook();
1642
1643        let n_cats_before = match &codebook.col_metadata[2].coltype {
1644            ColType::Categorical {
1645                k,
1646                value_map: ValueMap::UInt(3),
1647                ..
1648            } => *k,
1649            ColType::Categorical { value_map, .. } => {
1650                panic!(
1651                    "starting value_map should have been U32(3), was {:?}",
1652                    value_map
1653                );
1654            }
1655            _ => panic!("should've been categorical"),
1656        };
1657
1658        assert_eq!(n_cats_before, 3);
1659
1660        let mut extension = ValueMapExtension::new_uint();
1661        extension.extend(Category::UInt(3)).unwrap();
1662
1663        let result = incr_category_in_codebook(&mut codebook, 2, &extension);
1664        result.unwrap();
1665
1666        let n_cats_after = match &codebook.col_metadata[2].coltype {
1667            ColType::Categorical {
1668                k,
1669                value_map: ValueMap::UInt(4),
1670                ..
1671            } => *k,
1672            ColType::Categorical { value_map, .. } => {
1673                panic!("value_map should be U32(4), was: {:?}", value_map)
1674            }
1675            _ => panic!("should've been categorical"),
1676        };
1677
1678        assert_eq!(n_cats_after, 4);
1679    }
1680
1681    #[test]
1682    fn incr_cats_in_codebook_with_suppl_metadata_for_valmap_col() {
1683        let mut codebook = quick_codebook();
1684
1685        match &codebook.col_metadata[0].coltype {
1686            ColType::Categorical {
1687                k, value_map: vm, ..
1688            } => {
1689                assert_eq!(*k, 2);
1690                assert_eq!(vm.len(), 2);
1691            }
1692            _ => panic!("should've been categorical with valmap"),
1693        };
1694
1695        let mut extension = ValueMapExtension::new_string();
1696        extension
1697            .extend(Category::String("blue".to_string()))
1698            .unwrap();
1699
1700        let result = incr_category_in_codebook(&mut codebook, 0, &extension);
1701
1702        assert!(result.is_ok());
1703
1704        match &codebook.col_metadata[0].coltype {
1705            ColType::Categorical {
1706                k, value_map: vm, ..
1707            } => {
1708                assert_eq!(vm.len(), 3);
1709                assert_eq!(*k, 3);
1710            }
1711            _ => panic!("should've been categorical with valmap"),
1712        };
1713    }
1714
1715    #[test]
1716    fn append_bool() {
1717        let coltype = ColType::Categorical {
1718            k: 2,
1719            hyper: Some(CsdHyper::default()),
1720            prior: None,
1721            value_map: ValueMap::Bool,
1722        };
1723        let md0 = ColMetadata {
1724            name: "bool_col".to_string(),
1725            coltype: coltype.clone(),
1726            notes: None,
1727            missing_not_at_random: false,
1728        };
1729
1730        let mut engine = Engine::new(
1731            1,
1732            Codebook::new(
1733                "test".to_string(),
1734                ColMetadataList::new(vec![]).unwrap(),
1735            ),
1736            data_source::DataSource::Empty,
1737            0,
1738            rand_xoshiro::Xoshiro256Plus::seed_from_u64(0x1234),
1739        )
1740        .unwrap();
1741
1742        // Insert once with specific metadata.
1743        engine
1744            .insert_data(
1745                vec![(
1746                    "abc",
1747                    vec![(
1748                        "bool_col",
1749                        Datum::Categorical(Category::Bool(false)),
1750                    )],
1751                )
1752                    .into()],
1753                Some(ColMetadataList::new(vec![md0]).unwrap()),
1754                WriteMode::unrestricted(),
1755            )
1756            .unwrap();
1757
1758        // Insert again without metadata for the bool column.
1759        engine
1760            .insert_data(
1761                vec![(
1762                    "def",
1763                    vec![(
1764                        "bool_col",
1765                        Datum::Categorical(Category::Bool(false)),
1766                    )],
1767                )
1768                    .into()],
1769                None,
1770                WriteMode::unrestricted(),
1771            )
1772            .unwrap();
1773    }
1774}