db_mover/databases/mysql/
mod.rs1use std::collections::HashMap;
2
3use anyhow::Context;
4use itertools::Itertools;
5use mysql::prelude::Queryable;
6use mysql::{Conn, Opts, params};
7use tracing::debug;
8pub use value::MysqlTypeOptions;
9
10use crate::databases::table::{Row, Value};
11use crate::databases::traits::{DBInfoProvider, DBReader};
12
13use super::table::{Column, ColumnType, TableInfo};
14use super::traits::{DBWriter, ReaderIterator, WriterError};
15
16mod value;
17
/// MySQL/MariaDB backend: implements `DBInfoProvider`, `DBReader` and `DBWriter`
/// on top of a single blocking `mysql::Conn`.
pub struct MysqlDB {
    /// Connection URI, kept so `recover` can reconnect and `opt_clone` can open a new connection.
    uri: String,
    /// Active connection; replaced wholesale by `recover`.
    connection: Conn,
    /// True when the server's `VERSION()` string contains "MariaDB".
    is_mariadb: bool,
    /// Options passed to `ColumnType::try_from_mysql_type` when mapping column types.
    type_options: MysqlTypeOptions,
    /// Prepared `INSERT` statements keyed by (table name, values per row, number of rows).
    stmt_cache: HashMap<(String, usize, usize), mysql::Statement>,
}
25
26impl MysqlDB {
27 pub fn new(uri: &str, type_options: MysqlTypeOptions) -> anyhow::Result<Self> {
28 let mut connection = Self::connect(uri)?;
29 debug!("Connected to mysql {uri}");
30 let version: String = connection
31 .query_first("SELECT VERSION()")
32 .context("Unable to fetch database version")?
33 .unwrap();
34 return Ok(Self {
35 uri: uri.to_string(),
36 connection,
37 is_mariadb: version.contains("MariaDB"),
38 type_options,
39 stmt_cache: HashMap::new(),
40 });
41 }
42
43 fn connect(uri: &str) -> Result<Conn, anyhow::Error> {
44 let opts = Opts::from_url(uri)?;
45 let mut conn = Conn::new(opts)?;
46 conn.query_drop("SET time_zone = 'UTC'")
47 .context("Failed to set UTC timezone")?;
48 return Ok(conn);
49 }
50
51 fn get_num_rows(&mut self, table: &str) -> anyhow::Result<u64> {
52 let count_query = format!("SELECT count(1) FROM {table}");
53 return self
54 .connection
55 .query_first(count_query)?
56 .context("Unable to get count of rows for table");
57 }
58
59 fn get_stmt(
60 &mut self,
61 table_name: &str,
62 values_per_row: usize,
63 rows: usize,
64 ) -> anyhow::Result<mysql::Statement> {
65 let key = (table_name.to_owned(), values_per_row, rows);
66 return match self.stmt_cache.get(&key) {
67 Some(stmt) => Ok(stmt.to_owned()),
68 None => {
69 let placeholder = generate_placeholders(values_per_row, rows);
70 let stmt = self
71 .connection
72 .prep(format!("INSERT INTO {table_name} VALUES {placeholder}"))
73 .context("Unable to prepare insert query")?;
74 self.stmt_cache.insert(key, stmt.clone());
75 Ok(stmt)
76 }
77 };
78 }
79}
80
81impl DBInfoProvider for MysqlDB {
82 fn get_table_info(&mut self, table: &str, no_count: bool) -> anyhow::Result<TableInfo> {
83 let mut num_rows = None;
84 if !no_count {
85 num_rows = Some(
86 self.get_num_rows(table)
87 .context("Failed to get number of rows in the table")?,
88 );
89 }
90
91 let info_rows: Vec<mysql::Row> = self.connection.exec(r"SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE
92 FROM INFORMATION_SCHEMA.COLUMNS
93 WHERE table_name = :table AND TABLE_SCHEMA = database()
94 ORDER BY ORDINAL_POSITION", params! {table})?;
95 let mut columns = Vec::with_capacity(info_rows.len());
96 for row in info_rows {
97 let name = row
98 .get_opt(0)
99 .context("Value expected")?
100 .context("Couldn't parse column name")?;
101 let mut column_type: String = row
102 .get_opt(1)
103 .context("Value expected")?
104 .context("Couldn't parse column type")?;
105 let nullable: String = row
106 .get_opt(2)
107 .context("Value expected")?
108 .context("Couldn't parse column nullable")?;
109 if column_type == "longtext" && self.is_mariadb {
110 let num_json_constraints: usize = self.connection.exec_first(
111 r"SELECT count(1) FROM INFORMATION_SCHEMA.check_constraints
112 WHERE CONSTRAINT_SCHEMA = database() AND TABLE_NAME = :table AND CHECK_CLAUSE = :clause",
113 params! {table, "clause" => format!("json_valid(`{name}`)")},
114 ).context("Failed to check json constraint")?.unwrap();
115 if num_json_constraints > 0 {
116 column_type = String::from("json");
117 }
118 }
119 columns.push(Column {
120 name,
121 column_type: ColumnType::try_from_mysql_type(&column_type, &self.type_options)?,
122 nullable: nullable.as_str() == "YES",
123 });
124 }
125
126 return Ok(TableInfo {
127 name: table.to_string(),
128 num_rows,
129 columns,
130 });
131 }
132
133 fn get_tables(&mut self) -> anyhow::Result<Vec<String>> {
134 let rows: Vec<mysql::Row> = self.connection.query(
135 "SELECT table_name FROM information_schema.tables WHERE table_schema = database()",
136 )?;
137 return rows
138 .iter()
139 .map(|row| {
140 row.get_opt(0)
141 .context("Value expected")?
142 .context("Couldn't parse table name")
143 })
144 .collect::<anyhow::Result<Vec<String>>>();
145 }
146}
147
/// Row iterator returned by `MysqlDB::read_iter`. It wraps a text-protocol query
/// result and converts each fetched row to the layout in `target_format`.
struct MysqlRowsIter<'a> {
    /// Column layout each fetched row is converted to, in column order.
    target_format: TableInfo,
    /// Pending rows of the `SELECT` issued by `read_iter`; `'a` ties it to the
    /// borrowed connection.
    rows: mysql::QueryResult<'a, 'a, 'a, mysql::Text>,
}
152
153impl Iterator for MysqlRowsIter<'_> {
154 type Item = anyhow::Result<Row>;
155
156 fn next(&mut self) -> Option<Self::Item> {
157 return match self.rows.next() {
158 Some(Ok(row)) => {
159 let mut result: Row = Vec::with_capacity(self.target_format.columns.len());
160 let values = row.unwrap();
161 assert_eq!(values.len(), self.target_format.columns.len());
162 for (column, value) in std::iter::zip(&self.target_format.columns, values) {
163 match Value::try_from((column, value)) {
164 Ok(val) => result.push(val),
165 Err(e) => return Some(Err(e)),
166 }
167 }
168 Some(Ok(result))
169 }
170 Some(Err(err)) => Some(Err(err).context("Error while reading data from mysql")),
171 None => None,
172 };
173 }
174}
175
176impl DBReader for MysqlDB {
177 fn read_iter(&mut self, target_format: TableInfo) -> anyhow::Result<ReaderIterator<'_>> {
178 let query = format!(
179 "SELECT {} FROM {}",
180 target_format.column_names().join(", "),
181 target_format.name
182 );
183 let rows = self
184 .connection
185 .query_iter(query)
186 .context("Failed to get data from mysql source")?;
187 return Ok(Box::new(MysqlRowsIter {
188 target_format,
189 rows,
190 }));
191 }
192}
193
194impl DBWriter for MysqlDB {
195 fn opt_clone(&self) -> anyhow::Result<Box<dyn DBWriter>> {
196 return MysqlDB::new(&self.uri, self.type_options.clone())
197 .map(|writer| Box::new(writer) as _);
198 }
199
200 fn write_batch(&mut self, batch: &[Row], table: &TableInfo) -> Result<(), WriterError> {
201 let stmt = self.get_stmt(&table.name, batch[0].len(), batch.len())?;
202 let mut values = Vec::with_capacity(batch[0].len() * batch.len());
203 for row in batch {
204 for value in row {
205 values.push(mysql::Value::from(value));
206 }
207 }
208 self.connection
209 .exec_drop(stmt, mysql::Params::Positional(values))
210 .context("Unable to insert values into mysql")?;
211
212 return Ok(());
213 }
214
215 fn recover(&mut self) -> anyhow::Result<()> {
216 debug!("Trying to reconnect to the mysql");
217 self.connection = Self::connect(&self.uri)?;
218 debug!("Successfully reconnected to the mysql");
219 return Ok(());
220 }
221}
222
/// Builds the `VALUES` placeholder list for a multi-row insert.
///
/// The list has `rows` parenthesised groups of `values_per_row` `?` markers, e.g.
/// `(?, ?),(?, ?)` for 2 values × 2 rows. Returns an empty string when `rows` is 0.
fn generate_placeholders(values_per_row: usize, rows: usize) -> String {
    let block_inner = vec!["?"; values_per_row].join(", ");
    // Each row needs its markers, two parentheses and (after the first) one comma.
    let mut result = String::with_capacity(rows * (block_inner.len() + 3));
    for i in 0..rows {
        if i > 0 {
            result.push(',');
        }
        result.push('(');
        result.push_str(&block_inner);
        result.push(')');
    }
    result
}