1use std::sync::Arc;
2
3use arrow::array::{
4 Array, BinaryBuilder, BooleanBuilder, Date32Builder, Float32Builder, Float64Builder,
5 Int16Builder, Int32Builder, Int64Builder, StringBuilder, TimestampMicrosecondBuilder,
6};
7use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
8use arrow::record_batch::RecordBatch;
9use mysql::prelude::*;
10use mysql::{Opts, Pool, Value};
11
12use crate::error::Result;
13use crate::tuning::SourceTuning;
14use crate::types::CursorState;
15
16pub struct MysqlSource {
17 pool: Pool,
18}
19
20impl MysqlSource {
21 pub fn connect(url: &str) -> Result<Self> {
22 let opts = Opts::from_url(url)?;
23 let pool = Pool::new(opts)?;
24 Ok(Self { pool })
25 }
26}
27
28impl super::Source for MysqlSource {
29 fn export(
30 &mut self,
31 query: &str,
32 cursor_column: Option<&str>,
33 cursor: Option<&CursorState>,
34 tuning: &SourceTuning,
35 sink: &mut dyn super::BatchSink,
36 ) -> Result<()> {
37 let effective_query = build_query(query, cursor_column, cursor);
38 log::info!("executing query: {}", effective_query);
39
40 let mut conn = self.pool.get_conn()?;
41
42 if tuning.statement_timeout_s > 0 {
43 conn.query_drop(format!(
44 "SET SESSION max_execution_time = {}",
45 tuning.statement_timeout_s * 1000
46 ))?;
47 }
48
49 let mut result = conn.query_iter(&effective_query)?;
50 let columns = result.columns().as_ref().to_vec();
51 let schema = Arc::new(mysql_columns_to_schema(&columns));
52 let arrow_types: Vec<DataType> = columns.iter().map(mysql_type_to_arrow).collect();
53
54 sink.on_schema(schema.clone())?;
55
56 let effective_bs = tuning.effective_batch_size(Some(&schema));
57 let row_set = result
58 .iter()
59 .ok_or_else(|| anyhow::anyhow!("no result set"))?;
60 let mut row_buf: Vec<mysql::Row> = Vec::with_capacity(effective_bs);
61 let mut total_rows: usize = 0;
62
63 for row_result in row_set {
64 let row = row_result?;
65 row_buf.push(row);
66
67 if row_buf.len() >= effective_bs {
68 total_rows += row_buf.len();
69 let batch = rows_to_record_batch_typed(&schema, &arrow_types, &row_buf)?;
70 sink.on_batch(&batch)?;
71 row_buf.clear();
72
73 log::info!("fetched {} rows so far...", total_rows);
74
75 if tuning.throttle_ms > 0 {
76 std::thread::sleep(std::time::Duration::from_millis(tuning.throttle_ms));
77 }
78 }
79 }
80
81 if !row_buf.is_empty() {
82 total_rows += row_buf.len();
83 let batch = rows_to_record_batch_typed(&schema, &arrow_types, &row_buf)?;
84 sink.on_batch(&batch)?;
85 }
86
87 drop(result);
88
89 if tuning.statement_timeout_s > 0 {
90 conn.query_drop("SET SESSION max_execution_time = 0")?;
91 }
92
93 log::info!("total: {} rows", total_rows);
94 Ok(())
95 }
96
97 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
98 use mysql::prelude::*;
99 let mut conn = self.pool.get_conn()?;
100 let row: Option<mysql::Row> = conn.query_first(sql)?;
101 match row {
102 Some(r) => {
103 let val: Option<mysql::Value> = r.get(0);
104 match val {
105 Some(mysql::Value::Bytes(b)) => {
106 Ok(Some(String::from_utf8_lossy(&b).into_owned()))
107 }
108 Some(mysql::Value::Int(v)) => Ok(Some(v.to_string())),
109 Some(mysql::Value::UInt(v)) => Ok(Some(v.to_string())),
110 Some(mysql::Value::Float(v)) => Ok(Some(v.to_string())),
111 Some(mysql::Value::Double(v)) => Ok(Some(v.to_string())),
112 _ => Ok(None),
113 }
114 }
115 None => Ok(None),
116 }
117 }
118}
119
120pub(crate) fn build_query(
121 base_query: &str,
122 cursor_column: Option<&str>,
123 cursor: Option<&CursorState>,
124) -> String {
125 let has_cursor_value = cursor
126 .and_then(|c| c.last_cursor_value.as_deref())
127 .is_some();
128
129 if let (Some(col), true) = (cursor_column, has_cursor_value) {
130 let cursor_val = cursor
131 .expect("cursor checked above")
132 .last_cursor_value
133 .as_deref()
134 .expect("cursor value checked above");
135 format!(
136 "SELECT * FROM ({base}) AS _rivet WHERE {col} > '{val}' ORDER BY {col}",
137 base = base_query,
138 col = col,
139 val = cursor_val,
140 )
141 } else if let Some(col) = cursor_column {
142 format!(
143 "SELECT * FROM ({base}) AS _rivet ORDER BY {col}",
144 base = base_query,
145 col = col,
146 )
147 } else {
148 base_query.to_string()
149 }
150}
151
152fn mysql_type_to_arrow(col: &mysql::Column) -> DataType {
153 use mysql::consts::ColumnType::*;
154 match col.column_type() {
155 MYSQL_TYPE_TINY | MYSQL_TYPE_SHORT => DataType::Int16,
156 MYSQL_TYPE_INT24 | MYSQL_TYPE_LONG => DataType::Int32,
157 MYSQL_TYPE_LONGLONG => DataType::Int64,
158 MYSQL_TYPE_FLOAT => DataType::Float32,
159 MYSQL_TYPE_DOUBLE => DataType::Float64,
160 MYSQL_TYPE_DECIMAL | MYSQL_TYPE_NEWDECIMAL => DataType::Utf8,
161 MYSQL_TYPE_VARCHAR
162 | MYSQL_TYPE_VAR_STRING
163 | MYSQL_TYPE_STRING
164 | MYSQL_TYPE_ENUM
165 | MYSQL_TYPE_SET => DataType::Utf8,
166 MYSQL_TYPE_JSON => DataType::Utf8,
167 MYSQL_TYPE_TINY_BLOB | MYSQL_TYPE_MEDIUM_BLOB | MYSQL_TYPE_LONG_BLOB | MYSQL_TYPE_BLOB => {
168 if col.character_set() == 63 {
169 DataType::Binary
170 } else {
171 DataType::Utf8
172 }
173 }
174 MYSQL_TYPE_DATE | MYSQL_TYPE_NEWDATE => DataType::Date32,
175 MYSQL_TYPE_DATETIME
176 | MYSQL_TYPE_DATETIME2
177 | MYSQL_TYPE_TIMESTAMP
178 | MYSQL_TYPE_TIMESTAMP2 => DataType::Timestamp(TimeUnit::Microsecond, None),
179 MYSQL_TYPE_BIT => DataType::Boolean,
180 MYSQL_TYPE_YEAR => DataType::Int16,
181 _ => {
182 log::warn!(
183 "unmapped MySQL type {:?}, falling back to Utf8",
184 col.column_type()
185 );
186 DataType::Utf8
187 }
188 }
189}
190
191fn mysql_columns_to_schema(columns: &[mysql::Column]) -> Schema {
192 let fields: Vec<Field> = columns
193 .iter()
194 .map(|col| Field::new(col.name_str().to_string(), mysql_type_to_arrow(col), true))
195 .collect();
196 Schema::new(fields)
197}
198
199fn rows_to_record_batch_typed(
200 schema: &SchemaRef,
201 arrow_types: &[DataType],
202 rows: &[mysql::Row],
203) -> Result<RecordBatch> {
204 let mut arrays: Vec<Arc<dyn Array>> = Vec::with_capacity(arrow_types.len());
205 for (col_idx, arrow_type) in arrow_types.iter().enumerate() {
206 arrays.push(build_array(arrow_type, col_idx, rows)?);
207 }
208 Ok(RecordBatch::try_new(schema.clone(), arrays)?)
209}
210
/// Borrows `b` as `&str` when it is valid UTF-8, or returns `None` otherwise.
fn bytes_to_str(b: &[u8]) -> Option<&str> {
    match std::str::from_utf8(b) {
        Ok(text) => Some(text),
        Err(_) => None,
    }
}
214
215fn build_array(
216 arrow_type: &DataType,
217 col_idx: usize,
218 rows: &[mysql::Row],
219) -> Result<Arc<dyn Array>> {
220 match arrow_type {
221 DataType::Boolean => {
222 let mut b = BooleanBuilder::with_capacity(rows.len());
223 for row in rows {
224 match row.as_ref(col_idx) {
225 Some(Value::Int(v)) => b.append_value(*v != 0),
226 Some(Value::UInt(v)) => b.append_value(*v != 0),
227 Some(Value::Bytes(bv)) => {
228 let v = bytes_to_str(bv)
229 .and_then(|s| s.parse::<i64>().ok())
230 .unwrap_or(0);
231 b.append_value(v != 0);
232 }
233 _ => b.append_null(),
234 }
235 }
236 Ok(Arc::new(b.finish()))
237 }
238 DataType::Int16 => {
239 let mut b = Int16Builder::with_capacity(rows.len());
240 for row in rows {
241 match row.as_ref(col_idx) {
242 Some(Value::Int(v)) => b.append_value(*v as i16),
243 Some(Value::UInt(v)) => b.append_value(*v as i16),
244 Some(Value::Bytes(bv)) => match bytes_to_str(bv).and_then(|s| s.parse().ok()) {
245 Some(v) => b.append_value(v),
246 None => b.append_null(),
247 },
248 _ => b.append_null(),
249 }
250 }
251 Ok(Arc::new(b.finish()))
252 }
253 DataType::Int32 => {
254 let mut b = Int32Builder::with_capacity(rows.len());
255 for row in rows {
256 match row.as_ref(col_idx) {
257 Some(Value::Int(v)) => b.append_value(*v as i32),
258 Some(Value::UInt(v)) => b.append_value(*v as i32),
259 Some(Value::Bytes(bv)) => match bytes_to_str(bv).and_then(|s| s.parse().ok()) {
260 Some(v) => b.append_value(v),
261 None => b.append_null(),
262 },
263 _ => b.append_null(),
264 }
265 }
266 Ok(Arc::new(b.finish()))
267 }
268 DataType::Int64 => {
269 let mut b = Int64Builder::with_capacity(rows.len());
270 for row in rows {
271 match row.as_ref(col_idx) {
272 Some(Value::Int(v)) => b.append_value(*v),
273 Some(Value::UInt(v)) => b.append_value(*v as i64),
274 Some(Value::Bytes(bv)) => match bytes_to_str(bv).and_then(|s| s.parse().ok()) {
275 Some(v) => b.append_value(v),
276 None => b.append_null(),
277 },
278 _ => b.append_null(),
279 }
280 }
281 Ok(Arc::new(b.finish()))
282 }
283 DataType::Float32 => {
284 let mut b = Float32Builder::with_capacity(rows.len());
285 for row in rows {
286 match row.as_ref(col_idx) {
287 Some(Value::Float(v)) => b.append_value(*v),
288 Some(Value::Double(v)) => b.append_value(*v as f32),
289 Some(Value::Bytes(bv)) => match bytes_to_str(bv).and_then(|s| s.parse().ok()) {
290 Some(v) => b.append_value(v),
291 None => b.append_null(),
292 },
293 _ => b.append_null(),
294 }
295 }
296 Ok(Arc::new(b.finish()))
297 }
298 DataType::Float64 => {
299 let mut b = Float64Builder::with_capacity(rows.len());
300 for row in rows {
301 match row.as_ref(col_idx) {
302 Some(Value::Float(v)) => b.append_value(*v as f64),
303 Some(Value::Double(v)) => b.append_value(*v),
304 Some(Value::Bytes(bv)) => match bytes_to_str(bv).and_then(|s| s.parse().ok()) {
305 Some(v) => b.append_value(v),
306 None => b.append_null(),
307 },
308 _ => b.append_null(),
309 }
310 }
311 Ok(Arc::new(b.finish()))
312 }
313 DataType::Utf8 => {
314 let mut b = StringBuilder::with_capacity(rows.len(), rows.len() * 32);
315 for row in rows {
316 match row.as_ref(col_idx) {
317 Some(Value::Bytes(bv)) => b.append_value(String::from_utf8_lossy(bv).as_ref()),
318 Some(Value::Int(v)) => b.append_value(v.to_string()),
319 Some(Value::UInt(v)) => b.append_value(v.to_string()),
320 Some(Value::Float(v)) => b.append_value(v.to_string()),
321 Some(Value::Double(v)) => b.append_value(v.to_string()),
322 Some(Value::Date(y, m, d, h, mi, s, us)) => {
323 b.append_value(format!(
324 "{y:04}-{m:02}-{d:02} {h:02}:{mi:02}:{s:02}.{us:06}"
325 ));
326 }
327 _ => b.append_null(),
328 }
329 }
330 Ok(Arc::new(b.finish()))
331 }
332 DataType::Binary => {
333 let mut b = BinaryBuilder::with_capacity(rows.len(), rows.len() * 64);
334 for row in rows {
335 match row.as_ref(col_idx) {
336 Some(Value::Bytes(bv)) => b.append_value(bv),
337 _ => b.append_null(),
338 }
339 }
340 Ok(Arc::new(b.finish()))
341 }
342 DataType::Date32 => {
343 let mut b = Date32Builder::with_capacity(rows.len());
344 for row in rows {
345 let d = match row.as_ref(col_idx) {
346 Some(Value::Date(y, m, d, _, _, _, _)) => {
347 chrono::NaiveDate::from_ymd_opt(*y as i32, *m as u32, *d as u32)
348 }
349 Some(Value::Bytes(bv)) => bytes_to_str(bv).and_then(|s| {
350 chrono::NaiveDate::parse_from_str(
351 s.split(' ').next().unwrap_or(s),
352 "%Y-%m-%d",
353 )
354 .ok()
355 }),
356 _ => None,
357 };
358 match d {
359 Some(date) => {
360 let epoch =
361 chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid");
362 b.append_value((date - epoch).num_days() as i32);
363 }
364 None => b.append_null(),
365 }
366 }
367 Ok(Arc::new(b.finish()))
368 }
369 DataType::Timestamp(TimeUnit::Microsecond, _) => {
370 let mut b = TimestampMicrosecondBuilder::with_capacity(rows.len());
371 for row in rows {
372 let dt = match row.as_ref(col_idx) {
373 Some(Value::Date(y, mo, d, h, mi, s, us)) => chrono::NaiveDate::from_ymd_opt(
374 *y as i32, *mo as u32, *d as u32,
375 )
376 .and_then(|d| d.and_hms_micro_opt(*h as u32, *mi as u32, *s as u32, *us)),
377 Some(Value::Bytes(bv)) => bytes_to_str(bv).and_then(|s| {
378 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S").ok()
379 }),
380 _ => None,
381 };
382 match dt {
383 Some(dt) => b.append_value(dt.and_utc().timestamp_micros()),
384 None => b.append_null(),
385 }
386 }
387 Ok(Arc::new(b.finish()))
388 }
389 _ => {
390 log::warn!(
391 "unhandled Arrow type {:?} for MySQL, writing nulls",
392 arrow_type
393 );
394 let mut b = StringBuilder::with_capacity(rows.len(), 0);
395 for _ in rows {
396 b.append_null();
397 }
398 Ok(Arc::new(b.finish()))
399 }
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::CursorState;

    /// Builds the cursor state of export `t` with the given last cursor value.
    fn cursor_at(last: Option<&str>) -> CursorState {
        CursorState {
            export_name: "t".into(),
            last_cursor_value: last.map(Into::into),
            last_run_at: None,
        }
    }

    #[test]
    fn test_build_query_full() {
        let q = build_query("SELECT * FROM users", None, None);
        assert_eq!(q, "SELECT * FROM users");
    }

    #[test]
    fn test_build_query_incremental_first_run() {
        let state = cursor_at(None);
        let q = build_query("SELECT * FROM users", Some("id"), Some(&state));
        assert!(q.contains("ORDER BY id"));
        assert!(!q.contains("WHERE"));
    }

    #[test]
    fn test_build_query_incremental_with_cursor() {
        let state = cursor_at(Some("42"));
        let q = build_query("SELECT * FROM events", Some("id"), Some(&state));
        assert!(q.contains("WHERE id > '42'"), "got: {}", q);
        assert!(q.contains("ORDER BY id"));
    }
}