oversync_connectors/
postgres.rs1use async_trait::async_trait;
2use futures::TryStreamExt;
3use sqlx::postgres::PgPoolOptions;
4use sqlx::{Column, PgPool, Row, TypeInfo, ValueRef};
5use tokio::sync::mpsc;
6use tracing::debug;
7
8use oversync_core::error::OversyncError;
9use oversync_core::model::RawRow;
10use oversync_core::traits::OriginConnector;
11
12pub struct PostgresConnector {
13 pool: PgPool,
14 source_name: String,
15}
16
17impl PostgresConnector {
18 pub async fn new(name: &str, dsn: &str) -> Result<Self, OversyncError> {
19 let pool = PgPoolOptions::new()
20 .max_connections(5)
21 .connect(dsn)
22 .await
23 .map_err(|e| OversyncError::Connector(format!("postgres connect: {e}")))?;
24
25 Ok(Self {
26 pool,
27 source_name: name.to_string(),
28 })
29 }
30
31 pub fn from_pool(name: &str, pool: PgPool) -> Self {
32 Self {
33 pool,
34 source_name: name.to_string(),
35 }
36 }
37}
38
39#[async_trait]
40impl OriginConnector for PostgresConnector {
41 fn name(&self) -> &str {
42 &self.source_name
43 }
44
45 async fn fetch_all(&self, sql: &str, key_column: &str) -> Result<Vec<RawRow>, OversyncError> {
46 let rows = sqlx::query(sql)
47 .fetch_all(&self.pool)
48 .await
49 .map_err(|e| OversyncError::Connector(format!("fetch_all: {e}")))?;
50
51 let mut result = Vec::with_capacity(rows.len());
52
53 for row in &rows {
54 let key: String = row
55 .try_get(key_column)
56 .map_err(|e| OversyncError::Connector(format!("key column '{key_column}': {e}")))?;
57
58 let columns = row.columns();
59 let mut data = serde_json::Map::with_capacity(columns.len());
60
61 for col in columns {
62 let name = col.name();
63 let raw = row.try_get_raw(name).ok();
64 let val = match raw {
65 Some(ref r) if !r.is_null() => {
66 decode_pg_value(row, name, col.type_info().name())
67 }
68 _ => serde_json::Value::Null,
69 };
70 data.insert(name.to_string(), val);
71 }
72
73 result.push(RawRow {
74 row_key: key,
75 row_data: serde_json::Value::Object(data),
76 });
77 }
78
79 debug!(count = result.len(), "fetched rows from postgres");
80 Ok(result)
81 }
82
83 async fn fetch_into(
84 &self,
85 sql: &str,
86 key_column: &str,
87 batch_size: usize,
88 tx: mpsc::Sender<Vec<RawRow>>,
89 ) -> Result<usize, OversyncError> {
90 let mut stream = sqlx::query(sql).fetch(&self.pool);
91 let mut batch = Vec::with_capacity(batch_size);
92 let mut total = 0;
93
94 while let Some(row) = stream
95 .try_next()
96 .await
97 .map_err(|e| OversyncError::Connector(format!("fetch_into stream: {e}")))?
98 {
99 let key: String = row
100 .try_get(key_column)
101 .map_err(|e| OversyncError::Connector(format!("key column '{key_column}': {e}")))?;
102
103 let columns = row.columns();
104 let mut data = serde_json::Map::with_capacity(columns.len());
105 for col in columns {
106 let name = col.name();
107 let raw = row.try_get_raw(name).ok();
108 let val = match raw {
109 Some(ref r) if !r.is_null() => {
110 decode_pg_value(&row, name, col.type_info().name())
111 }
112 _ => serde_json::Value::Null,
113 };
114 data.insert(name.to_string(), val);
115 }
116
117 batch.push(RawRow {
118 row_key: key,
119 row_data: serde_json::Value::Object(data),
120 });
121 total += 1;
122
123 if batch.len() >= batch_size {
124 tx.send(std::mem::replace(
125 &mut batch,
126 Vec::with_capacity(batch_size),
127 ))
128 .await
129 .map_err(|_| OversyncError::Internal("channel closed".into()))?;
130 }
131 }
132
133 if !batch.is_empty() {
134 tx.send(batch)
135 .await
136 .map_err(|_| OversyncError::Internal("channel closed".into()))?;
137 }
138
139 debug!(total, "streamed rows from postgres");
140 Ok(total)
141 }
142
143 async fn test_connection(&self) -> Result<(), OversyncError> {
144 sqlx::query("SELECT 1")
145 .fetch_one(&self.pool)
146 .await
147 .map_err(|e| OversyncError::Connector(format!("test_connection: {e}")))?;
148 Ok(())
149 }
150}
151
152fn decode_pg_value(row: &sqlx::postgres::PgRow, name: &str, type_name: &str) -> serde_json::Value {
153 match type_name {
154 "BOOL" => row
155 .try_get::<bool, _>(name)
156 .map(serde_json::Value::Bool)
157 .unwrap_or(serde_json::Value::Null),
158 "INT2" | "INT4" => row
159 .try_get::<i32, _>(name)
160 .map(|v| serde_json::json!(v))
161 .unwrap_or(serde_json::Value::Null),
162 "INT8" | "OID" => row
163 .try_get::<i64, _>(name)
164 .map(|v| serde_json::json!(v))
165 .unwrap_or(serde_json::Value::Null),
166 "FLOAT4" | "FLOAT8" | "NUMERIC" => row
167 .try_get::<f64, _>(name)
168 .map(|v| serde_json::json!(v))
169 .unwrap_or(serde_json::Value::Null),
170 "JSON" | "JSONB" => row
171 .try_get::<serde_json::Value, _>(name)
172 .unwrap_or(serde_json::Value::Null),
173 _ => row
174 .try_get::<String, _>(name)
175 .map(serde_json::Value::String)
176 .unwrap_or(serde_json::Value::Null),
177 }
178}