oversync-connectors 0.6.3

Source connector implementations for oversync.
Documentation
use async_trait::async_trait;
use futures::TryStreamExt;
use sqlx::postgres::PgPoolOptions;
use sqlx::{Column, PgPool, Row, TypeInfo, ValueRef};
use tokio::sync::mpsc;
use tracing::debug;

use oversync_core::error::OversyncError;
use oversync_core::model::RawRow;
use oversync_core::traits::OriginConnector;

/// Name of a result column that is skipped when a row's columns are copied
/// into its JSON payload, so it never appears in `row_data`.
// NOTE(review): presumably a key column synthesized by the generated query — confirm with callers.
const SYNTHETIC_KEY_COLUMN: &str = "__oversync_key";

/// Origin connector that reads rows from PostgreSQL through an sqlx connection pool.
pub struct PostgresConnector {
	/// Connection pool that every query issued by this connector runs against.
	pool: PgPool,
	/// Name handed out by `OriginConnector::name`.
	source_name: String,
}

impl PostgresConnector {
	/// Opens a pool of at most five connections to `dsn` and wraps it in a
	/// connector named `name`.
	///
	/// # Errors
	///
	/// Returns `OversyncError::Connector` (prefixed with `postgres connect:`)
	/// when the pool cannot be established.
	pub async fn new(name: &str, dsn: &str) -> Result<Self, OversyncError> {
		let connecting = PgPoolOptions::new().max_connections(5).connect(dsn);

		match connecting.await {
			Ok(pool) => Ok(Self::from_pool(name, pool)),
			Err(e) => Err(OversyncError::Connector(format!("postgres connect: {e}"))),
		}
	}

	/// Builds a connector named `name` on top of an already configured pool.
	pub fn from_pool(name: &str, pool: PgPool) -> Self {
		let source_name = String::from(name);
		Self { pool, source_name }
	}
}

#[async_trait]
impl OriginConnector for PostgresConnector {
	fn name(&self) -> &str {
		&self.source_name
	}

	async fn fetch_all(&self, sql: &str, key_column: &str) -> Result<Vec<RawRow>, OversyncError> {
		let rows = sqlx::query(sql)
			.fetch_all(&self.pool)
			.await
			.map_err(|e| OversyncError::Connector(format!("fetch_all: {e}")))?;

		let mut result = Vec::with_capacity(rows.len());

		for row in &rows {
			let key = decode_pg_key(row, key_column)?;

			let columns = row.columns();
			let mut data = serde_json::Map::with_capacity(columns.len());

			for col in columns {
				let name = col.name();
				if name == SYNTHETIC_KEY_COLUMN {
					continue;
				}
				let raw = row.try_get_raw(name).ok();
				let val = match raw {
					Some(ref r) if !r.is_null() => {
						decode_pg_value(row, name, col.type_info().name())
					}
					_ => serde_json::Value::Null,
				};
				data.insert(name.to_string(), val);
			}

			result.push(RawRow {
				row_key: key,
				row_data: serde_json::Value::Object(data),
			});
		}

		debug!(count = result.len(), "fetched rows from postgres");
		Ok(result)
	}

	async fn fetch_into(
		&self,
		sql: &str,
		key_column: &str,
		batch_size: usize,
		tx: mpsc::Sender<Vec<RawRow>>,
	) -> Result<usize, OversyncError> {
		let mut stream = sqlx::query(sql).fetch(&self.pool);
		let mut batch = Vec::with_capacity(batch_size);
		let mut total = 0;

		while let Some(row) = stream
			.try_next()
			.await
			.map_err(|e| OversyncError::Connector(format!("fetch_into stream: {e}")))?
		{
			let key = decode_pg_key(&row, key_column)?;

			let columns = row.columns();
			let mut data = serde_json::Map::with_capacity(columns.len());
			for col in columns {
				let name = col.name();
				if name == SYNTHETIC_KEY_COLUMN {
					continue;
				}
				let raw = row.try_get_raw(name).ok();
				let val = match raw {
					Some(ref r) if !r.is_null() => {
						decode_pg_value(&row, name, col.type_info().name())
					}
					_ => serde_json::Value::Null,
				};
				data.insert(name.to_string(), val);
			}

			batch.push(RawRow {
				row_key: key,
				row_data: serde_json::Value::Object(data),
			});
			total += 1;

			if batch.len() >= batch_size {
				tx.send(std::mem::replace(
					&mut batch,
					Vec::with_capacity(batch_size),
				))
				.await
				.map_err(|_| OversyncError::Internal("channel closed".into()))?;
			}
		}

		if !batch.is_empty() {
			tx.send(batch)
				.await
				.map_err(|_| OversyncError::Internal("channel closed".into()))?;
		}

		debug!(total, "streamed rows from postgres");
		Ok(total)
	}

	async fn test_connection(&self) -> Result<(), OversyncError> {
		sqlx::query("SELECT 1")
			.fetch_one(&self.pool)
			.await
			.map_err(|e| OversyncError::Connector(format!("test_connection: {e}")))?;
		Ok(())
	}
}

/// Extracts the key of `row` from `key_column` and renders it as a string.
///
/// A value that decodes to a JSON string is returned as-is; any other decoded
/// value is rendered through its JSON text form.
///
/// # Errors
///
/// Returns `OversyncError::Connector` when:
/// - the row has no column named `key_column`,
/// - the raw value cannot be fetched,
/// - the value is SQL NULL, or
/// - the value is not SQL NULL but decodes to JSON null (its type could not be
///   decoded, or the column holds a JSON `null`).
fn decode_pg_key(row: &sqlx::postgres::PgRow, key_column: &str) -> Result<String, OversyncError> {
	let key_col = row
		.columns()
		.iter()
		.find(|col| col.name() == key_column)
		.ok_or_else(|| OversyncError::Connector(format!("key column '{key_column}' not found")))?;
	let type_name = key_col.type_info().name();

	let raw = row
		.try_get_raw(key_column)
		.map_err(|e| OversyncError::Connector(format!("key column '{key_column}': {e}")))?;
	if raw.is_null() {
		return Err(OversyncError::Connector(format!(
			"key column '{key_column}' is NULL"
		)));
	}

	match decode_pg_value(row, key_column, type_name) {
		serde_json::Value::String(s) => Ok(s),
		// SQL NULL was ruled out above, so JSON null here means the value could
		// not be turned into a usable key; say so instead of claiming NULL.
		serde_json::Value::Null => Err(OversyncError::Connector(format!(
			"key column '{key_column}' of type {type_name} could not be decoded into a key"
		))),
		other => Ok(other.to_string()),
	}
}

/// Decodes the column `name` of `row` into a JSON value, choosing the Rust
/// type to decode into from the PostgreSQL type name `type_name`.
///
/// Decoding is best-effort: any failure yields `serde_json::Value::Null`
/// rather than an error. Types without a dedicated arm are attempted as
/// `String`.
fn decode_pg_value(row: &sqlx::postgres::PgRow, name: &str, type_name: &str) -> serde_json::Value {
	match type_name {
		"BOOL" => row
			.try_get::<bool, _>(name)
			.map(serde_json::Value::Bool)
			.unwrap_or(serde_json::Value::Null),
		// sqlx checks type compatibility in `try_get`, so each integer width is
		// decoded into its own matching Rust type (INT2 -> i16, INT4 -> i32).
		"INT2" => row
			.try_get::<i16, _>(name)
			.map(|v| serde_json::json!(v))
			.unwrap_or(serde_json::Value::Null),
		"INT4" => row
			.try_get::<i32, _>(name)
			.map(|v| serde_json::json!(v))
			.unwrap_or(serde_json::Value::Null),
		// NOTE(review): sqlx documents `i64` as matching INT8 only; OID probably
		// always falls back to null here — confirm and decode via sqlx's Oid type.
		"INT8" | "OID" => row
			.try_get::<i64, _>(name)
			.map(|v| serde_json::json!(v))
			.unwrap_or(serde_json::Value::Null),
		// FLOAT4 is single precision and pairs with `f32`, not `f64`.
		"FLOAT4" => row
			.try_get::<f32, _>(name)
			.map(|v| serde_json::json!(v))
			.unwrap_or(serde_json::Value::Null),
		// NOTE(review): sqlx documents `f64` as matching FLOAT8 only; NUMERIC
		// probably always falls back to null here unless a decimal type is
		// enabled — confirm.
		"FLOAT8" | "NUMERIC" => row
			.try_get::<f64, _>(name)
			.map(|v| serde_json::json!(v))
			.unwrap_or(serde_json::Value::Null),
		"JSON" | "JSONB" => row
			.try_get::<serde_json::Value, _>(name)
			.unwrap_or(serde_json::Value::Null),
		_ => row
			.try_get::<String, _>(name)
			.map(serde_json::Value::String)
			.unwrap_or(serde_json::Value::Null),
	}
}