generic_bnp/generic_master_problem/
column_pool.rs

1use std::marker::PhantomData;
2use std::slice::Iter;
3use std::time::Instant;
4
/// Identifier of a column stored in a [`ColumnPool`].
///
/// Wraps the raw `u32` that the pool assigns when a column is added; the pool
/// also uses this value as the index for lookups by id.
///
/// `Hash` is intentionally not derived here: it is implemented by hand
/// further down in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ColumnId(pub u32);
7
8impl std::hash::Hash for ColumnId {
9    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
10        hasher.write_u32(self.0)
11    }
12}
13
// Marker impl (no methods) that opts `ColumnId` into the `nohash_hasher` crate.
// That crate expects the type's `Hash` impl to perform a single integer write,
// which the hand-written impl above does via `write_u32`.
impl nohash_hasher::IsEnabled for ColumnId {}
15
/// Ticket used to ensure that no column is generated twice in a
/// multi-threaded environment.
///
/// A worker receives a ticket when it fetches the columns of the pool; the
/// ticket wraps the number of columns the pool held at that moment. When the
/// worker later submits new columns, it hands the ticket back so that the
/// columns added by other workers in the meantime can be identified.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ColumnTicket(pub usize);
23
24#[derive(Clone, Debug, PartialEq)]
25pub struct Column<ColumnType> {
26    pub id: ColumnId,
27    pub data: ColumnType,
28}
29
30/// Holds all columns generated so far
31pub struct ColumnPool<ColumnType, FilterType> {
32    local_column_counter: u32,
33    columns: Vec<Column<ColumnType>>,
34    phantom: PhantomData<FilterType>,
35}
36
37impl<ColumnType, FilterType> ColumnPool<ColumnType, FilterType> {
38    pub fn new() -> Self {
39        ColumnPool {
40            local_column_counter: 0,
41            columns: Vec::new(),
42            phantom: PhantomData,
43        }
44    }
45}
46
47/// Trait to be implemented by the user.
48/// Given a list of branching filters, return only those
49/// columns that remain valid.
50pub trait ColumnPoolFilter<ColumnType, FilterType> {
51    fn get_columns(&self, filters: &[FilterType], ticket : Option<ColumnTicket>) -> (ColumnTicket, Vec<&Column<ColumnType>>);
52}
53
54impl<ColumnType, FilterType> ColumnPool<ColumnType, FilterType> {
55
56    /// Total number of columns in pool
57    pub fn count(&self) -> usize {
58        self.columns.len()
59    }
60
61    /// Returns a specific column from the pool
62    pub fn get_column(&self, id: ColumnId) -> &Column<ColumnType> {
63        let column_at_index = &self.columns[id.0 as usize];
64
65        // make sure that our assumption that columns are ordered in the column pool is correct
66        // i.e. ids of lookup key and value match
67        debug_assert_eq!(column_at_index.id.0, id.0);
68
69        column_at_index
70    }
71
72    /// Get Iterator of all columns in column pool and associated ticket
73    pub fn get_all_columns(&self) -> (ColumnTicket,Iter<Column<ColumnType>>) {
74        (ColumnTicket(self.columns.len()), self.columns.iter())
75    }
76
77
78    /// Adds a column to the column pool
79    /// The `column_ticket` that was handed out during get_columns 
80    /// must be returned, to ensure column pool consistency in a 
81    /// multithreaded environment
82    pub fn add_column(&mut self, column_data: ColumnType, column_ticket: ColumnTicket) -> bool
83    where
84        ColumnType: PartialEq,
85    {
86
87
88
89        // in an multi threaded environment, another thread
90        // might have added columns inbetween
91        // we got a ticket="num_cols known" when retrieving the columns
92        // we then only need to check those columns > ticket_no for potential conflicts
93        // if there is a conflict, return the previously added variable instead
94
95        let existing_column = self
96            .columns
97            .iter()
98            .enumerate() // uses enumerate and indecies due to lifetime constraints
99            .skip(column_ticket.0)
100            .filter_map(|( i ,c)| if c.data == column_data { Some(i)} else { None})
101            .next();
102
103
104        if let Some(column) = existing_column {
105            // there was the same column within the pool
106            return false;
107
108        } else {
109
110
111            #[cfg(feature = "validity_assertions")]
112            {
113                // assert that we never regenerate the same column twice
114                let existing_columns  = self
115                    .columns
116                    .iter()
117                    .filter(|c| c.data == column_data)
118                    .next();
119
120                    assert!(existing_columns.is_none());
121            }
122
123
124            let column = Column {
125                id: ColumnId(self.local_column_counter),
126                data: column_data,
127            };
128
129            self.local_column_counter += 1;
130            self.columns.push(column);
131
132            return true
133
134        }
135    }
136}
137
138impl<ColumnType, FilterType> Clone for ColumnPool<ColumnType, FilterType>
139where
140    ColumnType: Clone,
141{
142    fn clone(&self) -> Self {
143        Self {
144            local_column_counter: self.local_column_counter,
145            columns: self.columns.clone(),
146            phantom: PhantomData,
147        }
148    }
149}