Skip to main content

rivet_cli/source/
mysql.rs

1use std::sync::Arc;
2
3use arrow::array::{
4    Array, BinaryBuilder, BooleanBuilder, Date32Builder, Float32Builder, Float64Builder,
5    Int16Builder, Int32Builder, Int64Builder, StringBuilder, TimestampMicrosecondBuilder,
6};
7use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
8use arrow::record_batch::RecordBatch;
9use mysql::prelude::*;
10use mysql::{Opts, Pool, Value};
11
12use crate::error::Result;
13use crate::tuning::SourceTuning;
14use crate::types::CursorState;
15
16pub struct MysqlSource {
17    pool: Pool,
18}
19
20impl MysqlSource {
21    pub fn connect(url: &str) -> Result<Self> {
22        let opts = Opts::from_url(url)?;
23        let pool = Pool::new(opts)?;
24        Ok(Self { pool })
25    }
26}
27
28impl super::Source for MysqlSource {
29    fn export(
30        &mut self,
31        query: &str,
32        cursor_column: Option<&str>,
33        cursor: Option<&CursorState>,
34        tuning: &SourceTuning,
35        sink: &mut dyn super::BatchSink,
36    ) -> Result<()> {
37        let effective_query = build_query(query, cursor_column, cursor);
38        log::info!("executing query: {}", effective_query);
39
40        let mut conn = self.pool.get_conn()?;
41
42        if tuning.statement_timeout_s > 0 {
43            conn.query_drop(format!(
44                "SET SESSION max_execution_time = {}",
45                tuning.statement_timeout_s * 1000
46            ))?;
47        }
48
49        let mut result = conn.query_iter(&effective_query)?;
50        let columns = result.columns().as_ref().to_vec();
51        let schema = Arc::new(mysql_columns_to_schema(&columns));
52        let arrow_types: Vec<DataType> = columns.iter().map(mysql_type_to_arrow).collect();
53
54        sink.on_schema(schema.clone())?;
55
56        let effective_bs = tuning.effective_batch_size(Some(&schema));
57        let row_set = result
58            .iter()
59            .ok_or_else(|| anyhow::anyhow!("no result set"))?;
60        let mut row_buf: Vec<mysql::Row> = Vec::with_capacity(effective_bs);
61        let mut total_rows: usize = 0;
62
63        for row_result in row_set {
64            let row = row_result?;
65            row_buf.push(row);
66
67            if row_buf.len() >= effective_bs {
68                total_rows += row_buf.len();
69                let batch = rows_to_record_batch_typed(&schema, &arrow_types, &row_buf)?;
70                sink.on_batch(&batch)?;
71                row_buf.clear();
72
73                log::info!("fetched {} rows so far...", total_rows);
74
75                if tuning.throttle_ms > 0 {
76                    std::thread::sleep(std::time::Duration::from_millis(tuning.throttle_ms));
77                }
78            }
79        }
80
81        if !row_buf.is_empty() {
82            total_rows += row_buf.len();
83            let batch = rows_to_record_batch_typed(&schema, &arrow_types, &row_buf)?;
84            sink.on_batch(&batch)?;
85        }
86
87        drop(result);
88
89        if tuning.statement_timeout_s > 0 {
90            conn.query_drop("SET SESSION max_execution_time = 0")?;
91        }
92
93        log::info!("total: {} rows", total_rows);
94        Ok(())
95    }
96
97    fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
98        use mysql::prelude::*;
99        let mut conn = self.pool.get_conn()?;
100        let row: Option<mysql::Row> = conn.query_first(sql)?;
101        match row {
102            Some(r) => {
103                let val: Option<mysql::Value> = r.get(0);
104                match val {
105                    Some(mysql::Value::Bytes(b)) => {
106                        Ok(Some(String::from_utf8_lossy(&b).into_owned()))
107                    }
108                    Some(mysql::Value::Int(v)) => Ok(Some(v.to_string())),
109                    Some(mysql::Value::UInt(v)) => Ok(Some(v.to_string())),
110                    Some(mysql::Value::Float(v)) => Ok(Some(v.to_string())),
111                    Some(mysql::Value::Double(v)) => Ok(Some(v.to_string())),
112                    _ => Ok(None),
113                }
114            }
115            None => Ok(None),
116        }
117    }
118}
119
120pub(crate) fn build_query(
121    base_query: &str,
122    cursor_column: Option<&str>,
123    cursor: Option<&CursorState>,
124) -> String {
125    let has_cursor_value = cursor
126        .and_then(|c| c.last_cursor_value.as_deref())
127        .is_some();
128
129    if let (Some(col), true) = (cursor_column, has_cursor_value) {
130        let cursor_val = cursor
131            .expect("cursor checked above")
132            .last_cursor_value
133            .as_deref()
134            .expect("cursor value checked above");
135        format!(
136            "SELECT * FROM ({base}) AS _rivet WHERE {col} > '{val}' ORDER BY {col}",
137            base = base_query,
138            col = col,
139            val = cursor_val,
140        )
141    } else if let Some(col) = cursor_column {
142        format!(
143            "SELECT * FROM ({base}) AS _rivet ORDER BY {col}",
144            base = base_query,
145            col = col,
146        )
147    } else {
148        base_query.to_string()
149    }
150}
151
152fn mysql_type_to_arrow(col: &mysql::Column) -> DataType {
153    use mysql::consts::ColumnType::*;
154    match col.column_type() {
155        MYSQL_TYPE_TINY | MYSQL_TYPE_SHORT => DataType::Int16,
156        MYSQL_TYPE_INT24 | MYSQL_TYPE_LONG => DataType::Int32,
157        MYSQL_TYPE_LONGLONG => DataType::Int64,
158        MYSQL_TYPE_FLOAT => DataType::Float32,
159        MYSQL_TYPE_DOUBLE => DataType::Float64,
160        MYSQL_TYPE_DECIMAL | MYSQL_TYPE_NEWDECIMAL => DataType::Utf8,
161        MYSQL_TYPE_VARCHAR
162        | MYSQL_TYPE_VAR_STRING
163        | MYSQL_TYPE_STRING
164        | MYSQL_TYPE_ENUM
165        | MYSQL_TYPE_SET => DataType::Utf8,
166        MYSQL_TYPE_JSON => DataType::Utf8,
167        MYSQL_TYPE_TINY_BLOB | MYSQL_TYPE_MEDIUM_BLOB | MYSQL_TYPE_LONG_BLOB | MYSQL_TYPE_BLOB => {
168            if col.character_set() == 63 {
169                DataType::Binary
170            } else {
171                DataType::Utf8
172            }
173        }
174        MYSQL_TYPE_DATE | MYSQL_TYPE_NEWDATE => DataType::Date32,
175        MYSQL_TYPE_DATETIME
176        | MYSQL_TYPE_DATETIME2
177        | MYSQL_TYPE_TIMESTAMP
178        | MYSQL_TYPE_TIMESTAMP2 => DataType::Timestamp(TimeUnit::Microsecond, None),
179        MYSQL_TYPE_BIT => DataType::Boolean,
180        MYSQL_TYPE_YEAR => DataType::Int16,
181        _ => {
182            log::warn!(
183                "unmapped MySQL type {:?}, falling back to Utf8",
184                col.column_type()
185            );
186            DataType::Utf8
187        }
188    }
189}
190
191fn mysql_columns_to_schema(columns: &[mysql::Column]) -> Schema {
192    let fields: Vec<Field> = columns
193        .iter()
194        .map(|col| Field::new(col.name_str().to_string(), mysql_type_to_arrow(col), true))
195        .collect();
196    Schema::new(fields)
197}
198
199fn rows_to_record_batch_typed(
200    schema: &SchemaRef,
201    arrow_types: &[DataType],
202    rows: &[mysql::Row],
203) -> Result<RecordBatch> {
204    let mut arrays: Vec<Arc<dyn Array>> = Vec::with_capacity(arrow_types.len());
205    for (col_idx, arrow_type) in arrow_types.iter().enumerate() {
206        arrays.push(build_array(arrow_type, col_idx, rows)?);
207    }
208    Ok(RecordBatch::try_new(schema.clone(), arrays)?)
209}
210
/// Borrows `b` as a string slice when it is valid UTF-8; otherwise `None`.
fn bytes_to_str(b: &[u8]) -> Option<&str> {
    match std::str::from_utf8(b) {
        Ok(text) => Some(text),
        Err(_) => None,
    }
}
214
215fn build_array(
216    arrow_type: &DataType,
217    col_idx: usize,
218    rows: &[mysql::Row],
219) -> Result<Arc<dyn Array>> {
220    match arrow_type {
221        DataType::Boolean => {
222            let mut b = BooleanBuilder::with_capacity(rows.len());
223            for row in rows {
224                match row.as_ref(col_idx) {
225                    Some(Value::Int(v)) => b.append_value(*v != 0),
226                    Some(Value::UInt(v)) => b.append_value(*v != 0),
227                    Some(Value::Bytes(bv)) => {
228                        let v = bytes_to_str(bv)
229                            .and_then(|s| s.parse::<i64>().ok())
230                            .unwrap_or(0);
231                        b.append_value(v != 0);
232                    }
233                    _ => b.append_null(),
234                }
235            }
236            Ok(Arc::new(b.finish()))
237        }
238        DataType::Int16 => {
239            let mut b = Int16Builder::with_capacity(rows.len());
240            for row in rows {
241                match row.as_ref(col_idx) {
242                    Some(Value::Int(v)) => b.append_value(*v as i16),
243                    Some(Value::UInt(v)) => b.append_value(*v as i16),
244                    Some(Value::Bytes(bv)) => match bytes_to_str(bv).and_then(|s| s.parse().ok()) {
245                        Some(v) => b.append_value(v),
246                        None => b.append_null(),
247                    },
248                    _ => b.append_null(),
249                }
250            }
251            Ok(Arc::new(b.finish()))
252        }
253        DataType::Int32 => {
254            let mut b = Int32Builder::with_capacity(rows.len());
255            for row in rows {
256                match row.as_ref(col_idx) {
257                    Some(Value::Int(v)) => b.append_value(*v as i32),
258                    Some(Value::UInt(v)) => b.append_value(*v as i32),
259                    Some(Value::Bytes(bv)) => match bytes_to_str(bv).and_then(|s| s.parse().ok()) {
260                        Some(v) => b.append_value(v),
261                        None => b.append_null(),
262                    },
263                    _ => b.append_null(),
264                }
265            }
266            Ok(Arc::new(b.finish()))
267        }
268        DataType::Int64 => {
269            let mut b = Int64Builder::with_capacity(rows.len());
270            for row in rows {
271                match row.as_ref(col_idx) {
272                    Some(Value::Int(v)) => b.append_value(*v),
273                    Some(Value::UInt(v)) => b.append_value(*v as i64),
274                    Some(Value::Bytes(bv)) => match bytes_to_str(bv).and_then(|s| s.parse().ok()) {
275                        Some(v) => b.append_value(v),
276                        None => b.append_null(),
277                    },
278                    _ => b.append_null(),
279                }
280            }
281            Ok(Arc::new(b.finish()))
282        }
283        DataType::Float32 => {
284            let mut b = Float32Builder::with_capacity(rows.len());
285            for row in rows {
286                match row.as_ref(col_idx) {
287                    Some(Value::Float(v)) => b.append_value(*v),
288                    Some(Value::Double(v)) => b.append_value(*v as f32),
289                    Some(Value::Bytes(bv)) => match bytes_to_str(bv).and_then(|s| s.parse().ok()) {
290                        Some(v) => b.append_value(v),
291                        None => b.append_null(),
292                    },
293                    _ => b.append_null(),
294                }
295            }
296            Ok(Arc::new(b.finish()))
297        }
298        DataType::Float64 => {
299            let mut b = Float64Builder::with_capacity(rows.len());
300            for row in rows {
301                match row.as_ref(col_idx) {
302                    Some(Value::Float(v)) => b.append_value(*v as f64),
303                    Some(Value::Double(v)) => b.append_value(*v),
304                    Some(Value::Bytes(bv)) => match bytes_to_str(bv).and_then(|s| s.parse().ok()) {
305                        Some(v) => b.append_value(v),
306                        None => b.append_null(),
307                    },
308                    _ => b.append_null(),
309                }
310            }
311            Ok(Arc::new(b.finish()))
312        }
313        DataType::Utf8 => {
314            let mut b = StringBuilder::with_capacity(rows.len(), rows.len() * 32);
315            for row in rows {
316                match row.as_ref(col_idx) {
317                    Some(Value::Bytes(bv)) => b.append_value(String::from_utf8_lossy(bv).as_ref()),
318                    Some(Value::Int(v)) => b.append_value(v.to_string()),
319                    Some(Value::UInt(v)) => b.append_value(v.to_string()),
320                    Some(Value::Float(v)) => b.append_value(v.to_string()),
321                    Some(Value::Double(v)) => b.append_value(v.to_string()),
322                    Some(Value::Date(y, m, d, h, mi, s, us)) => {
323                        b.append_value(format!(
324                            "{y:04}-{m:02}-{d:02} {h:02}:{mi:02}:{s:02}.{us:06}"
325                        ));
326                    }
327                    _ => b.append_null(),
328                }
329            }
330            Ok(Arc::new(b.finish()))
331        }
332        DataType::Binary => {
333            let mut b = BinaryBuilder::with_capacity(rows.len(), rows.len() * 64);
334            for row in rows {
335                match row.as_ref(col_idx) {
336                    Some(Value::Bytes(bv)) => b.append_value(bv),
337                    _ => b.append_null(),
338                }
339            }
340            Ok(Arc::new(b.finish()))
341        }
342        DataType::Date32 => {
343            let mut b = Date32Builder::with_capacity(rows.len());
344            for row in rows {
345                let d = match row.as_ref(col_idx) {
346                    Some(Value::Date(y, m, d, _, _, _, _)) => {
347                        chrono::NaiveDate::from_ymd_opt(*y as i32, *m as u32, *d as u32)
348                    }
349                    Some(Value::Bytes(bv)) => bytes_to_str(bv).and_then(|s| {
350                        chrono::NaiveDate::parse_from_str(
351                            s.split(' ').next().unwrap_or(s),
352                            "%Y-%m-%d",
353                        )
354                        .ok()
355                    }),
356                    _ => None,
357                };
358                match d {
359                    Some(date) => {
360                        let epoch =
361                            chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid");
362                        b.append_value((date - epoch).num_days() as i32);
363                    }
364                    None => b.append_null(),
365                }
366            }
367            Ok(Arc::new(b.finish()))
368        }
369        DataType::Timestamp(TimeUnit::Microsecond, _) => {
370            let mut b = TimestampMicrosecondBuilder::with_capacity(rows.len());
371            for row in rows {
372                let dt = match row.as_ref(col_idx) {
373                    Some(Value::Date(y, mo, d, h, mi, s, us)) => chrono::NaiveDate::from_ymd_opt(
374                        *y as i32, *mo as u32, *d as u32,
375                    )
376                    .and_then(|d| d.and_hms_micro_opt(*h as u32, *mi as u32, *s as u32, *us)),
377                    Some(Value::Bytes(bv)) => bytes_to_str(bv).and_then(|s| {
378                        chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S").ok()
379                    }),
380                    _ => None,
381                };
382                match dt {
383                    Some(dt) => b.append_value(dt.and_utc().timestamp_micros()),
384                    None => b.append_null(),
385                }
386            }
387            Ok(Arc::new(b.finish()))
388        }
389        _ => {
390            log::warn!(
391                "unhandled Arrow type {:?} for MySQL, writing nulls",
392                arrow_type
393            );
394            let mut b = StringBuilder::with_capacity(rows.len(), 0);
395            for _ in rows {
396                b.append_null();
397            }
398            Ok(Arc::new(b.finish()))
399        }
400    }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::CursorState;

    /// Cursor state for a dummy export named "t" whose last value is `last_value`.
    fn cursor_with(last_value: Option<&str>) -> CursorState {
        CursorState {
            export_name: "t".into(),
            last_cursor_value: last_value.map(Into::into),
            last_run_at: None,
        }
    }

    #[test]
    fn test_build_query_full() {
        // Without a cursor column the base query passes through verbatim.
        let q = build_query("SELECT * FROM users", None, None);
        assert_eq!(q, "SELECT * FROM users");
    }

    #[test]
    fn test_build_query_incremental_first_run() {
        // First incremental run: ordered by the cursor, with no lower bound yet.
        let state = cursor_with(None);
        let q = build_query("SELECT * FROM users", Some("id"), Some(&state));
        assert!(q.contains("ORDER BY id"));
        assert!(!q.contains("WHERE"));
    }

    #[test]
    fn test_build_query_incremental_with_cursor() {
        // Later runs only select rows past the stored cursor value.
        let state = cursor_with(Some("42"));
        let q = build_query("SELECT * FROM events", Some("id"), Some(&state));
        assert!(q.contains("WHERE id > '42'"), "got: {}", q);
        assert!(q.contains("ORDER BY id"));
    }
}