Skip to main content

oversync_connectors/
postgres.rs

1use async_trait::async_trait;
2use futures::TryStreamExt;
3use sqlx::postgres::PgPoolOptions;
4use sqlx::{Column, PgPool, Row, TypeInfo, ValueRef};
5use tokio::sync::mpsc;
6use tracing::debug;
7
8use oversync_core::error::OversyncError;
9use oversync_core::model::RawRow;
10use oversync_core::traits::OriginConnector;
11
12pub struct PostgresConnector {
13	pool: PgPool,
14	source_name: String,
15}
16
17impl PostgresConnector {
18	pub async fn new(name: &str, dsn: &str) -> Result<Self, OversyncError> {
19		let pool = PgPoolOptions::new()
20			.max_connections(5)
21			.connect(dsn)
22			.await
23			.map_err(|e| OversyncError::Connector(format!("postgres connect: {e}")))?;
24
25		Ok(Self {
26			pool,
27			source_name: name.to_string(),
28		})
29	}
30
31	pub fn from_pool(name: &str, pool: PgPool) -> Self {
32		Self {
33			pool,
34			source_name: name.to_string(),
35		}
36	}
37}
38
39#[async_trait]
40impl OriginConnector for PostgresConnector {
41	fn name(&self) -> &str {
42		&self.source_name
43	}
44
45	async fn fetch_all(&self, sql: &str, key_column: &str) -> Result<Vec<RawRow>, OversyncError> {
46		let rows = sqlx::query(sql)
47			.fetch_all(&self.pool)
48			.await
49			.map_err(|e| OversyncError::Connector(format!("fetch_all: {e}")))?;
50
51		let mut result = Vec::with_capacity(rows.len());
52
53		for row in &rows {
54			let key: String = row
55				.try_get(key_column)
56				.map_err(|e| OversyncError::Connector(format!("key column '{key_column}': {e}")))?;
57
58			let columns = row.columns();
59			let mut data = serde_json::Map::with_capacity(columns.len());
60
61			for col in columns {
62				let name = col.name();
63				let raw = row.try_get_raw(name).ok();
64				let val = match raw {
65					Some(ref r) if !r.is_null() => {
66						decode_pg_value(row, name, col.type_info().name())
67					}
68					_ => serde_json::Value::Null,
69				};
70				data.insert(name.to_string(), val);
71			}
72
73			result.push(RawRow {
74				row_key: key,
75				row_data: serde_json::Value::Object(data),
76			});
77		}
78
79		debug!(count = result.len(), "fetched rows from postgres");
80		Ok(result)
81	}
82
83	async fn fetch_into(
84		&self,
85		sql: &str,
86		key_column: &str,
87		batch_size: usize,
88		tx: mpsc::Sender<Vec<RawRow>>,
89	) -> Result<usize, OversyncError> {
90		let mut stream = sqlx::query(sql).fetch(&self.pool);
91		let mut batch = Vec::with_capacity(batch_size);
92		let mut total = 0;
93
94		while let Some(row) = stream
95			.try_next()
96			.await
97			.map_err(|e| OversyncError::Connector(format!("fetch_into stream: {e}")))?
98		{
99			let key: String = row
100				.try_get(key_column)
101				.map_err(|e| OversyncError::Connector(format!("key column '{key_column}': {e}")))?;
102
103			let columns = row.columns();
104			let mut data = serde_json::Map::with_capacity(columns.len());
105			for col in columns {
106				let name = col.name();
107				let raw = row.try_get_raw(name).ok();
108				let val = match raw {
109					Some(ref r) if !r.is_null() => {
110						decode_pg_value(&row, name, col.type_info().name())
111					}
112					_ => serde_json::Value::Null,
113				};
114				data.insert(name.to_string(), val);
115			}
116
117			batch.push(RawRow {
118				row_key: key,
119				row_data: serde_json::Value::Object(data),
120			});
121			total += 1;
122
123			if batch.len() >= batch_size {
124				tx.send(std::mem::replace(
125					&mut batch,
126					Vec::with_capacity(batch_size),
127				))
128				.await
129				.map_err(|_| OversyncError::Internal("channel closed".into()))?;
130			}
131		}
132
133		if !batch.is_empty() {
134			tx.send(batch)
135				.await
136				.map_err(|_| OversyncError::Internal("channel closed".into()))?;
137		}
138
139		debug!(total, "streamed rows from postgres");
140		Ok(total)
141	}
142
143	async fn test_connection(&self) -> Result<(), OversyncError> {
144		sqlx::query("SELECT 1")
145			.fetch_one(&self.pool)
146			.await
147			.map_err(|e| OversyncError::Connector(format!("test_connection: {e}")))?;
148		Ok(())
149	}
150}
151
152fn decode_pg_value(row: &sqlx::postgres::PgRow, name: &str, type_name: &str) -> serde_json::Value {
153	match type_name {
154		"BOOL" => row
155			.try_get::<bool, _>(name)
156			.map(serde_json::Value::Bool)
157			.unwrap_or(serde_json::Value::Null),
158		"INT2" | "INT4" => row
159			.try_get::<i32, _>(name)
160			.map(|v| serde_json::json!(v))
161			.unwrap_or(serde_json::Value::Null),
162		"INT8" | "OID" => row
163			.try_get::<i64, _>(name)
164			.map(|v| serde_json::json!(v))
165			.unwrap_or(serde_json::Value::Null),
166		"FLOAT4" | "FLOAT8" | "NUMERIC" => row
167			.try_get::<f64, _>(name)
168			.map(|v| serde_json::json!(v))
169			.unwrap_or(serde_json::Value::Null),
170		"JSON" | "JSONB" => row
171			.try_get::<serde_json::Value, _>(name)
172			.unwrap_or(serde_json::Value::Null),
173		_ => row
174			.try_get::<String, _>(name)
175			.map(serde_json::Value::String)
176			.unwrap_or(serde_json::Value::Null),
177	}
178}