1mod arrow_convert;
22mod proxy;
23
24use std::sync::Arc;
25
26use arrow::datatypes::Schema;
27use mysql::prelude::*;
28use mysql::{Opts, OptsBuilder, Pool, PoolConstraints, PoolOpts, SslOpts};
29
30use crate::config::{SourceType, TlsConfig, TlsMode};
31use crate::error::Result;
32use crate::source::query::build_incremental_query;
33use crate::tuning::{ADAPTIVE_SAMPLE_INTERVAL, SourceTuning, next_adaptive_batch_size};
34use crate::types::ColumnOverrides;
35
36use arrow_convert::{
37 mysql_native_type_name, mysql_schema_and_arrow_types, mysql_type_to_rivet,
38 rows_to_record_batch_typed,
39};
40#[cfg(test)]
43use arrow_convert::bit_bytes_to_u64;
44use proxy::{detect_mysql_proxy_kind, warn_proxy_kind};
45
46pub use proxy::MysqlProxyKind;
50
/// MySQL-backed export source: a connection pool together with the proxy kind
/// that was detected for that pool when the source was constructed.
pub struct MysqlSource {
    /// Pool from which every connection used by this source is drawn.
    pool: Pool,
    /// Result of `detect_mysql_proxy_kind` for `pool`, captured at construction.
    proxy_kind: MysqlProxyKind,
}
55
56fn lean_pool_opts() -> PoolOpts {
60 PoolOpts::default()
61 .with_constraints(PoolConstraints::new(1, 100).expect("valid pool constraints"))
62}
63
64fn mysql_sample_innodb_log_waits(pool: &Pool) -> Option<u64> {
67 let mut conn = pool.get_conn().ok()?;
68 conn.query_first::<(String, u64), _>("SHOW GLOBAL STATUS LIKE 'Innodb_log_waits'")
69 .ok()
70 .flatten()
71 .map(|(_, v)| v)
72}
73
74impl MysqlSource {
75 #[allow(dead_code)]
78 pub fn from_pool(pool: Pool) -> Self {
79 let proxy_kind = detect_mysql_proxy_kind(&pool);
80 warn_proxy_kind(proxy_kind);
81 Self { pool, proxy_kind }
82 }
83
84 pub fn connect(url: &str) -> Result<Self> {
86 let opts =
87 Opts::from(OptsBuilder::from_opts(Opts::from_url(url)?).pool_opts(lean_pool_opts()));
88 let pool = Pool::new(opts)?;
89 let proxy_kind = detect_mysql_proxy_kind(&pool);
90 warn_proxy_kind(proxy_kind);
91 Ok(Self { pool, proxy_kind })
92 }
93
94 pub fn connect_with_tls(url: &str, tls: Option<&TlsConfig>) -> Result<Self> {
96 match tls {
97 Some(cfg) if cfg.mode.is_enforced() => {
98 let base = Opts::from_url(url)?;
99 let ssl = build_mysql_ssl_opts(cfg);
100 let opts = Opts::from(
101 OptsBuilder::from_opts(base)
102 .ssl_opts(Some(ssl))
103 .pool_opts(lean_pool_opts()),
104 );
105 let pool = Pool::new(opts)?;
106 let proxy_kind = detect_mysql_proxy_kind(&pool);
107 warn_proxy_kind(proxy_kind);
108 Ok(Self { pool, proxy_kind })
109 }
110 _ => Self::connect(url),
111 }
112 }
113
114 #[allow(dead_code)]
122 pub fn proxy_kind(&self) -> MysqlProxyKind {
123 self.proxy_kind
124 }
125}
126
127pub(crate) fn connect_pool(url: &str, tls: Option<&TlsConfig>) -> Result<Pool> {
132 match tls {
133 Some(cfg) if cfg.mode.is_enforced() => {
134 let base = Opts::from_url(url)?;
135 let ssl = build_mysql_ssl_opts(cfg);
136 let opts = Opts::from(
137 OptsBuilder::from_opts(base)
138 .ssl_opts(Some(ssl))
139 .pool_opts(lean_pool_opts()),
140 );
141 Ok(Pool::new(opts)?)
142 }
143 _ => {
144 let opts = Opts::from(
145 OptsBuilder::from_opts(Opts::from_url(url)?).pool_opts(lean_pool_opts()),
146 );
147 Ok(Pool::new(opts)?)
148 }
149 }
150}
151
/// Per-row byte estimates strictly above this value (8 KiB) are scaled down by
/// [`correct_innodb_avg_row_length`].
// NOTE(review): the name suggests this targets estimates inflated by InnoDB
// off-page BLOB/TEXT storage — confirm the rationale with the original author.
const INNODB_BLOB_OVERFLOW_THRESHOLD_BYTES: i64 = 8 * 1024;

/// Divisor applied to estimates that exceed the threshold.
const INNODB_BLOB_OVERFLOW_DIVISOR: i64 = 3;

/// Adjusts a raw per-row byte estimate.
///
/// Values at or below the threshold are returned unchanged. Larger values are
/// divided by [`INNODB_BLOB_OVERFLOW_DIVISOR`], but never reduced below half
/// the threshold.
fn correct_innodb_avg_row_length(raw_bytes: i64) -> i64 {
    if raw_bytes <= INNODB_BLOB_OVERFLOW_THRESHOLD_BYTES {
        return raw_bytes;
    }
    let floor = INNODB_BLOB_OVERFLOW_THRESHOLD_BYTES / 2;
    let scaled = raw_bytes / INNODB_BLOB_OVERFLOW_DIVISOR;
    std::cmp::max(scaled, floor)
}
175
176pub(crate) fn introspect_mysql_table_for_chunking(
195 url: &str,
196 tls: Option<&TlsConfig>,
197 qualified_table: &str,
198) -> Result<crate::source::TableIntrospection> {
199 let pool = connect_pool(url, tls)?;
200 let mut conn = pool.get_conn()?;
201 let default_db: Option<String> = conn.query_first("SELECT DATABASE()")?;
202 let default_db = default_db.unwrap_or_default();
203
204 let (schema, table) = match qualified_table.split_once('.') {
205 Some((s, t)) => (s.to_string(), t.to_string()),
206 None => (default_db, qualified_table.to_string()),
207 };
208
209 let row_stats: Option<(i64, i64, i64)> = conn.exec_first(
214 "SELECT CAST(IFNULL(TABLE_ROWS, 0) AS SIGNED), \
215 CAST(IFNULL(AVG_ROW_LENGTH, 0) AS SIGNED), \
216 CAST(IFNULL(DATA_LENGTH, 0) AS SIGNED) \
217 FROM information_schema.TABLES \
218 WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?",
219 (&schema, &table),
220 )?;
221 let (row_estimate, avg_row_bytes) = match row_stats {
222 Some((rows, avg, data_len)) => {
223 let row_count = rows.max(0);
224 let raw_per_row = if avg > 0 {
225 Some(avg)
226 } else if row_count > 0 {
227 Some(data_len / row_count)
228 } else {
229 None
230 };
231 let per_row = raw_per_row.map(correct_innodb_avg_row_length);
246 (row_count, per_row.filter(|b| *b > 0))
247 }
248 None => (0, None),
249 };
250
251 let pk_first: Option<(String,)> = conn.exec_first(
255 "SELECT COLUMN_NAME \
256 FROM information_schema.STATISTICS \
257 WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND INDEX_NAME = 'PRIMARY' AND SEQ_IN_INDEX = 1",
258 (&schema, &table),
259 )?;
260 let single_int_pk = if let Some((col,)) = pk_first {
261 let composite: Option<(String,)> = conn.exec_first(
262 "SELECT COLUMN_NAME FROM information_schema.STATISTICS \
263 WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND INDEX_NAME = 'PRIMARY' AND SEQ_IN_INDEX = 2 \
264 LIMIT 1",
265 (&schema, &table),
266 )?;
267 if composite.is_some() {
268 log::debug!(
269 "introspect_mysql_table: composite PK on {schema}.{table} — skipping auto-resolve"
270 );
271 None
272 } else {
273 let type_row: Option<(String,)> = conn.exec_first(
275 "SELECT DATA_TYPE FROM information_schema.COLUMNS \
276 WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?",
277 (&schema, &table, &col),
278 )?;
279 match type_row.map(|(t,)| t.to_ascii_lowercase()) {
280 Some(t)
281 if matches!(
282 t.as_str(),
283 "tinyint" | "smallint" | "mediumint" | "int" | "bigint"
284 ) =>
285 {
286 Some(col)
287 }
288 Some(t) => {
289 log::debug!(
290 "introspect_mysql_table: PK '{col}' on {schema}.{table} has non-int type '{t}' — skipping auto-resolve"
291 );
292 None
293 }
294 None => None,
295 }
296 }
297 } else {
298 None
299 };
300
301 Ok(crate::source::TableIntrospection {
302 single_int_pk,
303 row_estimate,
304 avg_row_bytes,
305 })
306}
307
308fn build_mysql_ssl_opts(cfg: &TlsConfig) -> SslOpts {
309 let mut ssl = SslOpts::default();
310 if let Some(path) = &cfg.ca_file {
311 ssl = ssl.with_root_cert_path(Some(std::path::PathBuf::from(path)));
312 }
313 match cfg.mode {
314 TlsMode::Require => {
315 ssl = ssl
316 .with_danger_accept_invalid_certs(true)
317 .with_danger_skip_domain_validation(true);
318 }
319 TlsMode::VerifyCa => {
320 ssl = ssl.with_danger_skip_domain_validation(true);
321 }
322 TlsMode::VerifyFull => {
323 }
325 TlsMode::Disable => {
326 }
328 }
329 if cfg.accept_invalid_certs {
330 ssl = ssl.with_danger_accept_invalid_certs(true);
331 }
332 if cfg.accept_invalid_hostnames {
333 ssl = ssl.with_danger_skip_domain_validation(true);
334 }
335 ssl
336}
337
338fn mysql_run_export(
347 conn: &mut mysql::PooledConn,
348 sample_pool: Option<Pool>,
349 sql: &str,
350 cursor_param: Option<&str>,
351 tuning: &SourceTuning,
352 column_overrides: &ColumnOverrides,
353 sink: &mut dyn super::BatchSink,
354) -> Result<usize> {
355 let mut result = match cursor_param {
359 Some(val) => conn.exec_iter(sql, (val,))?,
360 None => conn.exec_iter(sql, ())?,
361 };
362 let columns = result.columns().as_ref().to_vec();
363
364 let (schema, arrow_types) = mysql_schema_and_arrow_types(&columns, column_overrides)?;
367 let schema = Arc::new(schema);
368
369 sink.on_schema(schema.clone())?;
370
371 const PROBE_BATCH_SIZE: usize = 500;
386 const MYSQL_BATCH_TARGET_MB_DEFAULT: usize = 64;
387
388 let configured_batch_size = tuning.effective_batch_size(Some(&schema));
389 let mut effective_bs = configured_batch_size.min(PROBE_BATCH_SIZE);
390 let mut base_fetch_size = effective_bs;
391 let mut adaptive_last_waits: Option<u64> = if tuning.adaptive {
392 sample_pool.as_ref().and_then(mysql_sample_innodb_log_waits)
393 } else {
394 None
395 };
396 let mut batch_count: usize = 0;
397 let row_set = result
398 .iter()
399 .ok_or_else(|| anyhow::anyhow!("no result set"))?;
400 let mut row_buf: Vec<mysql::Row> = Vec::with_capacity(effective_bs);
401 let mut total_rows: usize = 0;
402 let mut memory_cap_applied = false;
403
404 for row_result in row_set {
405 let row = row_result?;
406 row_buf.push(row);
407
408 if row_buf.len() >= effective_bs {
409 total_rows += row_buf.len();
410 batch_count += 1;
411 let batch = rows_to_record_batch_typed(&schema, &arrow_types, &row_buf)?;
412 let batch_rows = row_buf.len();
413 row_buf.clear();
414
415 if !memory_cap_applied && batch_rows > 0 {
419 let arrow_bytes = crate::tuning::SourceTuning::batch_memory_bytes(&batch);
420 let arrow_per_row = (arrow_bytes / batch_rows).max(64);
421 let target_mb = tuning
422 .batch_size_memory_mb
423 .unwrap_or(MYSQL_BATCH_TARGET_MB_DEFAULT);
424 let safe = ((target_mb * 1024 * 1024) / arrow_per_row).max(PROBE_BATCH_SIZE);
425 let target = safe.min(configured_batch_size);
426 if target != effective_bs {
427 log::info!(
428 "MySQL row_buf cap: arrow≈{} B/row, target={} MB → batch_size {} → {} (configured={})",
429 arrow_per_row,
430 target_mb,
431 effective_bs,
432 target,
433 configured_batch_size
434 );
435 effective_bs = target;
436 base_fetch_size = effective_bs;
437 row_buf.reserve(effective_bs.saturating_sub(row_buf.capacity()));
438 }
439 memory_cap_applied = true;
440 }
441
442 sink.on_batch(&batch)?;
443
444 if tuning.adaptive
445 && batch_count.is_multiple_of(ADAPTIVE_SAMPLE_INTERVAL)
446 && let Some(ref pool) = sample_pool
447 && let Some(cur) = mysql_sample_innodb_log_waits(pool)
448 {
449 let under_pressure = adaptive_last_waits.is_some_and(|prev| cur > prev);
450 adaptive_last_waits = Some(cur);
451 let next = next_adaptive_batch_size(effective_bs, base_fetch_size, under_pressure);
452 if next != effective_bs {
453 effective_bs = next;
454 log::info!(
455 "adaptive batch size → {} ({})",
456 effective_bs,
457 if under_pressure {
458 "pressure"
459 } else {
460 "recovery"
461 }
462 );
463 }
464 }
465
466 log::info!("fetched {} rows so far...", total_rows);
467
468 if tuning.throttle_ms > 0 {
469 std::thread::sleep(std::time::Duration::from_millis(tuning.throttle_ms));
470 }
471 }
472 }
473
474 if !row_buf.is_empty() {
475 total_rows += row_buf.len();
476 let batch = rows_to_record_batch_typed(&schema, &arrow_types, &row_buf)?;
477 sink.on_batch(&batch)?;
478 }
479
480 drop(result);
481 Ok(total_rows)
482}
483
484impl super::Source for MysqlSource {
485 fn export(
486 &mut self,
487 request: &super::ExportRequest<'_>,
488 sink: &mut dyn super::BatchSink,
489 ) -> Result<()> {
490 let built = build_incremental_query(
491 request.query,
492 request.incremental,
493 request.cursor,
494 SourceType::Mysql,
495 );
496 log::debug!(
497 "executing query (connection={}): {}",
498 self.proxy_kind.log_label(),
499 built.sql
500 );
501
502 let mut conn = self.pool.get_conn()?;
503
504 conn.query_drop("SET time_zone = '+00:00'")?;
507
508 if request.tuning.statement_timeout_s > 0 {
509 conn.query_drop(format!(
510 "SET SESSION max_execution_time = {}",
511 request.tuning.statement_timeout_s * 1000
512 ))?;
513 }
514
515 let sample_pool = if request.tuning.adaptive {
516 Some(self.pool.clone())
517 } else {
518 None
519 };
520 let result = mysql_run_export(
521 &mut conn,
522 sample_pool,
523 &built.sql,
524 built.cursor_param.as_deref(),
525 request.tuning,
526 request.column_overrides,
527 sink,
528 );
529
530 let _ = conn.query_drop("SET time_zone = @@global.time_zone");
533 if request.tuning.statement_timeout_s > 0 {
534 let _ = conn.query_drop("SET SESSION max_execution_time = 0");
535 }
536
537 let total_rows = result?;
542 if total_rows == 0 {
543 sink.on_schema(Arc::new(Schema::empty()))?;
544 }
545 log::info!("total: {} rows", total_rows);
546 Ok(())
547 }
548
549 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
550 let mut conn = self.pool.get_conn()?;
551 let row: Option<mysql::Row> = conn.query_first(sql)?;
552 match row {
553 Some(r) => {
554 let val: Option<mysql::Value> = r.get(0);
555 match val {
556 Some(mysql::Value::Bytes(b)) => {
557 Ok(Some(String::from_utf8_lossy(&b).into_owned()))
558 }
559 Some(mysql::Value::Int(v)) => Ok(Some(v.to_string())),
560 Some(mysql::Value::UInt(v)) => Ok(Some(v.to_string())),
561 Some(mysql::Value::Float(v)) => Ok(Some(v.to_string())),
562 Some(mysql::Value::Double(v)) => Ok(Some(v.to_string())),
563 _ => Ok(None),
564 }
565 }
566 None => Ok(None),
567 }
568 }
569
570 fn type_mappings(
571 &mut self,
572 query: &str,
573 column_overrides: &ColumnOverrides,
574 ) -> Result<Vec<crate::types::TypeMapping>> {
575 let wrapped = format!("SELECT * FROM ({}) AS _rivet_type_probe LIMIT 0", query);
576 let mut conn = self.pool.get_conn()?;
577 let result = conn.exec_iter(&wrapped, ())?;
578 let columns = result.columns().as_ref().to_vec();
579 drop(result);
580 let mappings = columns
581 .iter()
582 .map(|col| {
583 let rivet = column_overrides
584 .get(col.name_str().as_ref())
585 .cloned()
586 .unwrap_or_else(|| mysql_type_to_rivet(col));
587 let source = crate::types::SourceColumn::simple(
588 col.name_str().as_ref(),
589 mysql_native_type_name(col),
590 true,
591 );
592 crate::types::TypeMapping::from_source(&source, rivet)
593 })
594 .collect();
595 Ok(mappings)
596 }
597}
598
#[cfg(test)]
mod tests {
    use super::{bit_bytes_to_u64, correct_innodb_avg_row_length};

    /// One-byte inputs map directly to their numeric value.
    #[test]
    fn bit_bytes_single_byte() {
        let cases = [([0x00u8], 0u64), ([0x01], 1), ([0xFF], 255)];
        for (bytes, expected) in cases {
            assert_eq!(bit_bytes_to_u64(&bytes), expected);
        }
    }

    /// Multi-byte inputs are combined with the first byte most significant.
    #[test]
    fn bit_bytes_multi_byte() {
        let two_bytes = [0x01u8, 0x02];
        assert_eq!(bit_bytes_to_u64(&two_bytes), 258);

        let all_ones = [0xFFu8; 8];
        assert_eq!(bit_bytes_to_u64(&all_ones), u64::MAX);
    }

    /// An empty input decodes to zero.
    #[test]
    fn bit_bytes_empty() {
        assert_eq!(bit_bytes_to_u64(&[]), 0);
    }

    /// Estimates up to and including the 8 KiB threshold pass through untouched.
    #[test]
    fn innodb_correction_below_threshold_is_identity() {
        for raw in [82_i64, 314, 2_048, 8 * 1024] {
            assert_eq!(correct_innodb_avg_row_length(raw), raw);
        }
    }

    /// Estimates well above the threshold are divided by three.
    #[test]
    fn innodb_correction_above_threshold_divides_by_three() {
        let cases = [(40_978_i64, 40_978_i64 / 3), (120_000, 40_000)];
        for (raw, expected) in cases {
            assert_eq!(correct_innodb_avg_row_length(raw), expected);
        }
    }

    /// Just above the threshold the result is clamped to half the threshold.
    #[test]
    fn innodb_correction_does_not_undershoot_floor() {
        let just_above = 8 * 1024 + 1;
        let divided = correct_innodb_avg_row_length(just_above);
        assert!(divided >= 4 * 1024, "got {divided}");
    }
}