db_mover/databases/mysql/
mod.rs

1use std::collections::HashMap;
2
3use anyhow::Context;
4use itertools::Itertools;
5use mysql::prelude::Queryable;
6use mysql::{Conn, Opts, params};
7use tracing::debug;
8pub use value::MysqlTypeOptions;
9
10use crate::databases::table::{Row, Value};
11use crate::databases::traits::{DBInfoProvider, DBReader};
12
13use super::table::{Column, ColumnType, TableInfo};
14use super::traits::{DBWriter, ReaderIterator, WriterError};
15
16mod value;
17
18pub struct MysqlDB {
19    uri: String,
20    connection: Conn,
21    is_mariadb: bool,
22    type_options: MysqlTypeOptions,
23    stmt_cache: HashMap<(String, usize, usize), mysql::Statement>,
24}
25
26impl MysqlDB {
27    pub fn new(uri: &str, type_options: MysqlTypeOptions) -> anyhow::Result<Self> {
28        let mut connection = Self::connect(uri)?;
29        debug!("Connected to mysql {uri}");
30        let version: String = connection
31            .query_first("SELECT VERSION()")
32            .context("Unable to fetch database version")?
33            .unwrap();
34        return Ok(Self {
35            uri: uri.to_string(),
36            connection,
37            is_mariadb: version.contains("MariaDB"),
38            type_options,
39            stmt_cache: HashMap::new(),
40        });
41    }
42
43    fn connect(uri: &str) -> Result<Conn, anyhow::Error> {
44        let opts = Opts::from_url(uri)?;
45        let mut conn = Conn::new(opts)?;
46        conn.query_drop("SET time_zone = 'UTC'")
47            .context("Failed to set UTC timezone")?;
48        return Ok(conn);
49    }
50
51    fn get_num_rows(&mut self, table: &str) -> anyhow::Result<u64> {
52        let count_query = format!("SELECT count(1) FROM {table}");
53        return self
54            .connection
55            .query_first(count_query)?
56            .context("Unable to get count of rows for table");
57    }
58
59    fn get_stmt(
60        &mut self,
61        table_name: &str,
62        values_per_row: usize,
63        rows: usize,
64    ) -> anyhow::Result<mysql::Statement> {
65        let key = (table_name.to_owned(), values_per_row, rows);
66        return match self.stmt_cache.get(&key) {
67            Some(stmt) => Ok(stmt.to_owned()),
68            None => {
69                let placeholder = generate_placeholders(values_per_row, rows);
70                let stmt = self
71                    .connection
72                    .prep(format!("INSERT INTO {table_name} VALUES {placeholder}"))
73                    .context("Unable to prepare insert query")?;
74                self.stmt_cache.insert(key, stmt.clone());
75                Ok(stmt)
76            }
77        };
78    }
79}
80
81impl DBInfoProvider for MysqlDB {
82    fn get_table_info(&mut self, table: &str, no_count: bool) -> anyhow::Result<TableInfo> {
83        let mut num_rows = None;
84        if !no_count {
85            num_rows = Some(
86                self.get_num_rows(table)
87                    .context("Failed to get number of rows in the table")?,
88            );
89        }
90
91        let info_rows: Vec<mysql::Row> = self.connection.exec(r"SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE 
92                                                                FROM INFORMATION_SCHEMA.COLUMNS 
93                                                                WHERE table_name = :table AND TABLE_SCHEMA = database()
94                                                                ORDER BY ORDINAL_POSITION", params! {table})?;
95        let mut columns = Vec::with_capacity(info_rows.len());
96        for row in info_rows {
97            let name = row
98                .get_opt(0)
99                .context("Value expected")?
100                .context("Couldn't parse column name")?;
101            let mut column_type: String = row
102                .get_opt(1)
103                .context("Value expected")?
104                .context("Couldn't parse column type")?;
105            let nullable: String = row
106                .get_opt(2)
107                .context("Value expected")?
108                .context("Couldn't parse column nullable")?;
109            if column_type == "longtext" && self.is_mariadb {
110                let num_json_constraints: usize = self.connection.exec_first(
111                    r"SELECT count(1) FROM INFORMATION_SCHEMA.check_constraints
112                    WHERE CONSTRAINT_SCHEMA = database() AND TABLE_NAME = :table AND CHECK_CLAUSE = :clause",
113                    params! {table, "clause" => format!("json_valid(`{name}`)")},
114                ).context("Failed to check json constraint")?.unwrap();
115                if num_json_constraints > 0 {
116                    column_type = String::from("json");
117                }
118            }
119            columns.push(Column {
120                name,
121                column_type: ColumnType::try_from_mysql_type(&column_type, &self.type_options)?,
122                nullable: nullable.as_str() == "YES",
123            });
124        }
125
126        return Ok(TableInfo {
127            name: table.to_string(),
128            num_rows,
129            columns,
130        });
131    }
132
133    fn get_tables(&mut self) -> anyhow::Result<Vec<String>> {
134        let rows: Vec<mysql::Row> = self.connection.query(
135            "SELECT table_name FROM information_schema.tables WHERE table_schema = database()",
136        )?;
137        return rows
138            .iter()
139            .map(|row| {
140                row.get_opt(0)
141                    .context("Value expected")?
142                    .context("Couldn't parse table name")
143            })
144            .collect::<anyhow::Result<Vec<String>>>();
145    }
146}
147
148struct MysqlRowsIter<'a> {
149    target_format: TableInfo,
150    rows: mysql::QueryResult<'a, 'a, 'a, mysql::Text>,
151}
152
153impl Iterator for MysqlRowsIter<'_> {
154    type Item = anyhow::Result<Row>;
155
156    fn next(&mut self) -> Option<Self::Item> {
157        return match self.rows.next() {
158            Some(Ok(row)) => {
159                let mut result: Row = Vec::with_capacity(self.target_format.columns.len());
160                let values = row.unwrap();
161                assert_eq!(values.len(), self.target_format.columns.len());
162                for (column, value) in std::iter::zip(&self.target_format.columns, values) {
163                    match Value::try_from((column, value)) {
164                        Ok(val) => result.push(val),
165                        Err(e) => return Some(Err(e)),
166                    }
167                }
168                Some(Ok(result))
169            }
170            Some(Err(err)) => Some(Err(err).context("Error while reading data from mysql")),
171            None => None,
172        };
173    }
174}
175
176impl DBReader for MysqlDB {
177    fn read_iter(&mut self, target_format: TableInfo) -> anyhow::Result<ReaderIterator<'_>> {
178        let query = format!(
179            "SELECT {} FROM {}",
180            target_format.column_names().join(", "),
181            target_format.name
182        );
183        let rows = self
184            .connection
185            .query_iter(query)
186            .context("Failed to get data from mysql source")?;
187        return Ok(Box::new(MysqlRowsIter {
188            target_format,
189            rows,
190        }));
191    }
192}
193
194impl DBWriter for MysqlDB {
195    fn opt_clone(&self) -> anyhow::Result<Box<dyn DBWriter>> {
196        return MysqlDB::new(&self.uri, self.type_options.clone())
197            .map(|writer| Box::new(writer) as _);
198    }
199
200    fn write_batch(&mut self, batch: &[Row], table: &TableInfo) -> Result<(), WriterError> {
201        let stmt = self.get_stmt(&table.name, batch[0].len(), batch.len())?;
202        let mut values = Vec::with_capacity(batch[0].len() * batch.len());
203        for row in batch {
204            for value in row {
205                values.push(mysql::Value::from(value));
206            }
207        }
208        self.connection
209            .exec_drop(stmt, mysql::Params::Positional(values))
210            .context("Unable to insert values into mysql")?;
211
212        return Ok(());
213    }
214
215    fn recover(&mut self) -> anyhow::Result<()> {
216        debug!("Trying to reconnect to the mysql");
217        self.connection = Self::connect(&self.uri)?;
218        debug!("Successfully reconnected to the mysql");
219        return Ok(());
220    }
221}
222
/// Builds the `VALUES` placeholder list for a multi-row insert.
///
/// Each row becomes a parenthesised group of `values_per_row` question marks
/// separated by `", "`; the `rows` groups are separated by `","`.
/// For example `generate_placeholders(2, 2)` yields `(?, ?),(?, ?)`.
fn generate_placeholders(values_per_row: usize, rows: usize) -> String {
    // Every row uses the same group, so it is rendered once and then repeated.
    let row_group = format!("({})", vec!["?"; values_per_row].join(", "));

    let mut sql = String::with_capacity(rows * (row_group.len() + 1));
    for row_idx in 0..rows {
        if row_idx != 0 {
            sql.push(',');
        }
        sql.push_str(&row_group);
    }
    sql
}