// oracle_sql_tools/statements/insert_utils.rs

1use std::{fmt::Display, sync::Arc, thread::{self, JoinHandle}};
2use indicatif::ProgressBar;
3use oracle::{Batch, Connection};
4
5use crate::{format_data::FormattedData, types::{errors::OracleSqlToolsError, BatchPrep, CellProperties, DatatypeIndexes, GridProperties}, utils::remove_invalid_chars};
6use super::mutate_grid::MutateGrid;
7
8impl BatchPrep {
9    pub(crate) fn split_batch_by_threads(mut self) -> Result<Arc<Connection>, OracleSqlToolsError> {
10        // wrapping these variables in an Arc because they're going to be passed to multiple threads
11        let conn: Arc<Connection> = Arc::new(self.conn);
12        let insert_stmt: Arc<String> = Arc::new(self.insert_stmt);
13        let datatype_indexes: Arc<DatatypeIndexes> = Arc::new(self.data_indexes);
14
15        // divides the length of the data by the number of threads on the host CPU
16        let len = self.data.len();
17        let nthreads = num_cpus::get();
18        let num = (len / nthreads + if len % nthreads == 0 { 0 } else { 1 }) as f32;
19
20        // initialize the progress bar
21        let pb = ProgressBar::new(len as u64);
22        let progress_bar = Arc::new(pb);
23
24        // captures the spawned threads into a vector
25        let mut handles: Vec<JoinHandle<Result<(), OracleSqlToolsError>>> = Vec::new();
26        // iterates as many times as there are threads
27        for n in 0..nthreads {
28            // each thread needs to have its own clone of the data
29            let conn = Arc::clone(&conn);
30            let insert = Arc::clone(&insert_stmt);
31            let datatype_indexes = Arc::clone(&datatype_indexes);
32            let progress_bar = Arc::clone(&progress_bar);
33            let arc_data: Arc<Vec<Vec<FormattedData>>>;
34            if n + 1 < nthreads {
35                // splits up the 2d vector to have an even amount per thread
36                let d = self.data.divide(num);
37                // new ARC per each split vector
38                arc_data = Arc::new(d);
39            // collecting the remaining data
40            } else { arc_data = Arc::new(self.data.to_owned()); }
41            handles.push(thread::spawn(move || {
42                // creates a unique Batch per thread
43                let mut batch: Batch<'_> = conn
44                    .batch(&insert.as_str(), arc_data.len())
45                    .build()?;
46                // each thread iterates over their slice of the data
47                GridProperties {
48                    data: arc_data,
49                    num: (num.ceil() as usize - 1) * n,
50                    datatype_indexes,
51                }.get_cell_props(&mut batch, progress_bar)
52            }));
53        }
54        // executes all threads
55        for handle in handles {
56            handle.join().unwrap()?;
57        }
58        Ok(conn)
59    }
60
61    pub(crate) fn single_thread_batch(self) -> Result<Arc<Connection>, OracleSqlToolsError> {
62        let body_len = &self.data.len();
63        // initialize the progress bar
64        let pb = ProgressBar::new(*body_len as u64);
65        let progress_bar = Arc::new(pb);
66        let conn: Arc<Connection> = Arc::new(self.conn);
67        let conn_clone = Arc::clone(&conn);
68        let mut batch: Batch<'_> = conn_clone
69            .batch(&self.insert_stmt.as_str(), body_len.to_owned())
70            .build()?;
71        // CellProperties expects an Arc<DatatypeIndexes>
72        let datatype_indexes = Arc::new(self.data_indexes);
73        GridProperties {
74            data: self.data.into(),
75            num: 0usize,
76            datatype_indexes,
77        }.get_cell_props(&mut batch, progress_bar)?;
78        Ok(conn)
79    }
80}
81
82impl GridProperties {
83    fn get_cell_props(self, batch: &mut Batch<'_>, progress_bar: Arc<ProgressBar>) -> Result<(), OracleSqlToolsError> {
84        self.data.iter().enumerate().try_for_each(|(y, row)| 
85        -> Result<(), OracleSqlToolsError> {
86            row.iter().enumerate().try_for_each(|(x, cell)| 
87            -> Result<(), OracleSqlToolsError> {
88                CellProperties {
89                    cell,
90                    datatype_indexes: &self.datatype_indexes,
91                    x_ind: x,
92                    y_ind: self.num + y,
93                }.bind_cell_to_batch(batch)
94            })?;
95            batch.append_row(&[])?;
96            progress_bar.inc(1u64);
97            Ok(())
98        })?;
99
100        batch.execute()?;
101        Ok(())
102    }
103}
104
/// Binds a typed NULL (`None::<$data_type>`) at column `$cell_props.x_ind + 1` of
/// `$batch` and RETURNS from the enclosing function.
///
/// The enclosing function must return `Result<(), OracleSqlToolsError>`. A failed bind
/// becomes `CellPropertyError` with the cell value reported as `"NULL"`.
macro_rules! empty_batch_set {
    ($cell_props:ident, $data_type:ty, $batch:ident) => {
        return $batch
            .set($cell_props.x_ind + 1, &None::<$data_type>)
            .map(|_| ())
            .map_err(|e| OracleSqlToolsError::CellPropertyError {
                error_message: e,
                cell_value: "NULL".to_string(),
                x_index: $cell_props.x_ind,
                y_index: $cell_props.y_ind,
            })
    };
}

119impl<'props> CellProperties<'props> {
120    fn bind_cell_to_batch(self, batch: &mut Batch<'_>) -> Result<(), OracleSqlToolsError> {
121        match &self.cell {
122            FormattedData::STRING(val) => batch_set(self, batch, val.to_string()),
123            FormattedData::INT(val) => match self.datatype_indexes.is_varchar.contains(&self.x_ind) {
124                true => batch_set(self, batch, val.to_string()),
125                false => match self.datatype_indexes.is_float.contains(&self.x_ind) {
126                    true => batch_set(self, batch, val.to_string().parse::<f64>().unwrap()),
127                    false => batch_set(self, batch, *val),
128                },
129            },
130            FormattedData::FLOAT(val) => match self.datatype_indexes.is_varchar.contains(&self.x_ind) {
131                true => batch_set(self, batch, val.to_string()),
132                false => batch_set(self, batch, *val),
133            },
134            FormattedData::DATE(val) => {
135                match self.datatype_indexes {
136                    ind if ind.is_varchar.contains(&self.x_ind) => batch_set(self, batch, val.to_string()),
137                    ind if ind.is_date.contains(&self.x_ind) => batch_set(self, batch, *val),
138                    ind if ind.is_int.contains(&self.x_ind) => {
139                        let to_num = remove_invalid_chars(&val.to_string());
140                        batch_set(self, batch, to_num.parse::<i64>().unwrap())
141                    },
142                    ind if ind.is_float.contains(&self.x_ind) => {
143                        let to_num = remove_invalid_chars(&val.to_string());
144                        batch_set(self, batch, to_num.parse::<f64>().unwrap())
145                    },
146                    _ => batch_set(self, batch, *val),
147                }
148            },
149            FormattedData::TIMESTAMP(val) => {
150                match self.datatype_indexes {
151                    ind if ind.is_varchar.contains(&self.x_ind) => batch_set(self, batch, val.to_string()),
152                    ind if ind.is_date.contains(&self.x_ind) => batch_set(self, batch, *val),
153                    ind if ind.is_int.contains(&self.x_ind) => {
154                        let to_num = remove_invalid_chars(&val.to_string());
155                        batch_set(self, batch, to_num.parse::<i64>().unwrap())
156                    },
157                    ind if ind.is_float.contains(&self.x_ind) => {
158                        let to_num = remove_invalid_chars(&val.to_string());
159                        batch_set(self, batch, to_num.parse::<f64>().unwrap())
160                    },
161                    _ => batch_set(self, batch, *val),
162                }
163            },
164            FormattedData::EMPTY => {
165                match self.datatype_indexes {
166                    ind if ind.is_varchar.contains(&self.x_ind) => empty_batch_set!(self, String, batch),
167                    ind if ind.is_date.contains(&self.x_ind) => empty_batch_set!(self, chrono::NaiveDateTime, batch),
168                    ind if ind.is_int.contains(&self.x_ind) => empty_batch_set!(self, i8, batch),
169                    ind if ind.is_float.contains(&self.x_ind) => empty_batch_set!(self, f32, batch),
170                    _ => empty_batch_set!(self, String, batch),
171                }
172            },
173        }
174    }
175}
176
177fn batch_set<T> (cell_props: CellProperties, batch: &mut Batch<'_>, value: T) 
178-> Result<(), OracleSqlToolsError>
179where T: oracle::sql_type::ToSql + Display {
180    match batch.set(cell_props.x_ind + 1, &value) {
181        Ok(_) => Ok(()),
182        Err(e) => return Err(OracleSqlToolsError::CellPropertyError { 
183            error_message: e, 
184            cell_value: value.to_string(),
185            x_index: cell_props.x_ind, 
186            y_index: cell_props.y_ind 
187        }),
188    }
189}