1use std::sync::Arc;
2
3use arrow::array::{
4 Array, BinaryBuilder, BooleanBuilder, Date32Builder, Float32Builder, Float64Builder,
5 Int16Builder, Int32Builder, Int64Builder, StringBuilder, TimestampMicrosecondBuilder,
6};
7use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
8use arrow::record_batch::RecordBatch;
9use postgres::types::Type;
10use postgres::{Client, NoTls, Row};
11
12use crate::error::Result;
13use crate::tuning::SourceTuning;
14use crate::types::CursorState;
15
16pub struct PostgresSource {
17 client: Client,
18}
19
20impl PostgresSource {
21 pub fn connect(url: &str) -> Result<Self> {
22 let client = Client::connect(url, NoTls)?;
23 Ok(Self { client })
24 }
25}
26
27impl super::Source for PostgresSource {
28 fn export(
29 &mut self,
30 query: &str,
31 cursor_column: Option<&str>,
32 cursor: Option<&CursorState>,
33 tuning: &SourceTuning,
34 sink: &mut dyn super::BatchSink,
35 ) -> Result<()> {
36 let effective_query = build_query(query, cursor_column, cursor);
37 log::info!("executing query: {}", effective_query);
38
39 if tuning.statement_timeout_s > 0 {
40 self.client.batch_execute(&format!(
41 "SET statement_timeout = '{}s'",
42 tuning.statement_timeout_s
43 ))?;
44 }
45 if tuning.lock_timeout_s > 0 {
46 self.client
47 .batch_execute(&format!("SET lock_timeout = '{}s'", tuning.lock_timeout_s))?;
48 }
49
50 self.client.batch_execute("BEGIN")?;
51 self.client.batch_execute(&format!(
52 "DECLARE _rivet NO SCROLL CURSOR FOR {}",
53 effective_query
54 ))?;
55
56 let mut fetch_size = tuning.batch_size;
57 let mut fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
58 let mut schema: Option<SchemaRef> = None;
59 let mut columns_cache: Option<Vec<(String, Type)>> = None;
60 let mut total_rows: usize = 0;
61
62 loop {
63 let rows = self.client.query(&fetch_sql, &[])?;
64 if rows.is_empty() {
65 break;
66 }
67
68 if schema.is_none() {
69 let stmt_cols: Vec<(String, Type)> = rows[0]
70 .columns()
71 .iter()
72 .map(|c| (c.name().to_string(), c.type_().clone()))
73 .collect();
74 let s = Arc::new(pg_columns_to_schema(rows[0].columns()));
75 sink.on_schema(s.clone())?;
76 schema = Some(s.clone());
77 columns_cache = Some(stmt_cols);
78
79 let effective = tuning.effective_batch_size(Some(&s));
80 if effective != fetch_size {
81 fetch_size = effective;
82 fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
83 }
84 }
85
86 let row_count = rows.len();
87 total_rows += row_count;
88
89 let s = schema.as_ref().expect("schema set on first iteration");
90 let cols = columns_cache
91 .as_ref()
92 .expect("columns set on first iteration");
93 let batch = rows_to_record_batch_typed(s, cols, &rows)?;
94 drop(rows);
95 sink.on_batch(&batch)?;
96
97 log::info!("fetched {} rows so far...", total_rows);
98
99 if row_count < fetch_size {
100 break;
101 }
102
103 if tuning.throttle_ms > 0 {
104 std::thread::sleep(std::time::Duration::from_millis(tuning.throttle_ms));
105 }
106 }
107
108 self.client.batch_execute("CLOSE _rivet")?;
109 self.client.batch_execute("COMMIT")?;
110 self.client.batch_execute("RESET statement_timeout")?;
111 self.client.batch_execute("RESET lock_timeout")?;
112
113 if schema.is_none() {
114 sink.on_schema(Arc::new(Schema::empty()))?;
115 }
116
117 log::info!("total: {} rows", total_rows);
118 Ok(())
119 }
120
121 fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
122 let rows = self.client.query(sql, &[])?;
123 if rows.is_empty() {
124 return Ok(None);
125 }
126 let row = &rows[0];
127 if let Ok(Some(v)) = row.try_get::<_, Option<i64>>(0) {
128 return Ok(Some(v.to_string()));
129 }
130 if let Ok(Some(v)) = row.try_get::<_, Option<i32>>(0) {
131 return Ok(Some(v.to_string()));
132 }
133 if let Ok(Some(v)) = row.try_get::<_, Option<f64>>(0) {
134 return Ok(Some(v.to_string()));
135 }
136 if let Ok(Some(v)) = row.try_get::<_, Option<String>>(0) {
137 return Ok(Some(v));
138 }
139 Ok(None)
140 }
141}
142
143pub(crate) fn build_query(
144 base_query: &str,
145 cursor_column: Option<&str>,
146 cursor: Option<&CursorState>,
147) -> String {
148 let has_cursor_value = cursor
149 .and_then(|c| c.last_cursor_value.as_deref())
150 .is_some();
151
152 if let (Some(col), true) = (cursor_column, has_cursor_value) {
153 let cursor_val = cursor
154 .expect("cursor checked above")
155 .last_cursor_value
156 .as_deref()
157 .expect("cursor value checked above");
158 format!(
159 "SELECT * FROM ({base}) AS _rivet WHERE {col} > '{val}' ORDER BY {col}",
160 base = base_query,
161 col = col,
162 val = cursor_val,
163 )
164 } else if let Some(col) = cursor_column {
165 format!(
166 "SELECT * FROM ({base}) AS _rivet ORDER BY {col}",
167 base = base_query,
168 col = col,
169 )
170 } else {
171 base_query.to_string()
172 }
173}
174
175fn pg_type_to_arrow(pg_type: &Type) -> DataType {
176 match *pg_type {
177 Type::BOOL => DataType::Boolean,
178 Type::INT2 => DataType::Int16,
179 Type::INT4 => DataType::Int32,
180 Type::INT8 => DataType::Int64,
181 Type::FLOAT4 => DataType::Float32,
182 Type::FLOAT8 => DataType::Float64,
183 Type::TEXT | Type::VARCHAR | Type::BPCHAR | Type::NAME => DataType::Utf8,
184 Type::BYTEA => DataType::Binary,
185 Type::DATE => DataType::Date32,
186 Type::TIMESTAMP | Type::TIMESTAMPTZ => DataType::Timestamp(TimeUnit::Microsecond, None),
187 Type::NUMERIC => DataType::Utf8,
188 Type::JSON | Type::JSONB => DataType::Utf8,
189 Type::UUID => DataType::Utf8,
190 Type::OID => DataType::Int64,
191 _ => {
192 log::warn!("unmapped PG type {:?}, falling back to Utf8", pg_type);
193 DataType::Utf8
194 }
195 }
196}
197
198fn pg_columns_to_schema(columns: &[postgres::Column]) -> Schema {
199 let fields: Vec<Field> = columns
200 .iter()
201 .map(|col| {
202 let dt = pg_type_to_arrow(col.type_());
203 Field::new(col.name(), dt, true)
204 })
205 .collect();
206 Schema::new(fields)
207}
208
209fn rows_to_record_batch_typed(
210 schema: &SchemaRef,
211 columns: &[(String, Type)],
212 rows: &[Row],
213) -> Result<RecordBatch> {
214 let mut arrays: Vec<Arc<dyn Array>> = Vec::with_capacity(columns.len());
215 for (col_idx, (_name, pg_type)) in columns.iter().enumerate() {
216 let array = build_array(pg_type, col_idx, rows)?;
217 arrays.push(array);
218 }
219 let batch = RecordBatch::try_new(schema.clone(), arrays)?;
220 Ok(batch)
221}
222
223fn build_array(pg_type: &Type, col_idx: usize, rows: &[Row]) -> Result<Arc<dyn Array>> {
224 match *pg_type {
225 Type::BOOL => {
226 let mut b = BooleanBuilder::with_capacity(rows.len());
227 for row in rows {
228 b.append_option(row.get(col_idx));
229 }
230 Ok(Arc::new(b.finish()))
231 }
232 Type::INT2 => {
233 let mut b = Int16Builder::with_capacity(rows.len());
234 for row in rows {
235 b.append_option(row.get(col_idx));
236 }
237 Ok(Arc::new(b.finish()))
238 }
239 Type::INT4 => {
240 let mut b = Int32Builder::with_capacity(rows.len());
241 for row in rows {
242 b.append_option(row.get(col_idx));
243 }
244 Ok(Arc::new(b.finish()))
245 }
246 Type::INT8 => {
247 let mut b = Int64Builder::with_capacity(rows.len());
248 for row in rows {
249 b.append_option(row.get(col_idx));
250 }
251 Ok(Arc::new(b.finish()))
252 }
253 Type::FLOAT4 => {
254 let mut b = Float32Builder::with_capacity(rows.len());
255 for row in rows {
256 b.append_option(row.get(col_idx));
257 }
258 Ok(Arc::new(b.finish()))
259 }
260 Type::FLOAT8 => {
261 let mut b = Float64Builder::with_capacity(rows.len());
262 for row in rows {
263 b.append_option(row.get(col_idx));
264 }
265 Ok(Arc::new(b.finish()))
266 }
267 Type::TEXT | Type::VARCHAR | Type::BPCHAR | Type::NAME => {
268 let mut b = StringBuilder::with_capacity(rows.len(), rows.len() * 32);
269 for row in rows {
270 let val: Option<String> = row.get(col_idx);
271 b.append_option(val.as_deref());
272 }
273 Ok(Arc::new(b.finish()))
274 }
275 Type::BYTEA => {
276 let mut b = BinaryBuilder::with_capacity(rows.len(), rows.len() * 64);
277 for row in rows {
278 match row.get::<_, Option<Vec<u8>>>(col_idx) {
279 Some(v) => b.append_value(&v),
280 None => b.append_null(),
281 }
282 }
283 Ok(Arc::new(b.finish()))
284 }
285 Type::DATE => {
286 let mut b = Date32Builder::with_capacity(rows.len());
287 for row in rows {
288 match row.get::<_, Option<chrono::NaiveDate>>(col_idx) {
289 Some(d) => {
290 let epoch =
291 chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid");
292 b.append_value((d - epoch).num_days() as i32);
293 }
294 None => b.append_null(),
295 }
296 }
297 Ok(Arc::new(b.finish()))
298 }
299 Type::TIMESTAMP => {
300 let mut b = TimestampMicrosecondBuilder::with_capacity(rows.len());
301 for row in rows {
302 match row.get::<_, Option<chrono::NaiveDateTime>>(col_idx) {
303 Some(ts) => b.append_value(ts.and_utc().timestamp_micros()),
304 None => b.append_null(),
305 }
306 }
307 Ok(Arc::new(b.finish()))
308 }
309 Type::TIMESTAMPTZ => {
310 let mut b = TimestampMicrosecondBuilder::with_capacity(rows.len());
311 for row in rows {
312 match row.get::<_, Option<chrono::DateTime<chrono::Utc>>>(col_idx) {
313 Some(ts) => b.append_value(ts.timestamp_micros()),
314 None => b.append_null(),
315 }
316 }
317 Ok(Arc::new(b.finish()))
318 }
319 Type::NUMERIC | Type::JSON | Type::JSONB | Type::UUID => {
320 let mut b = StringBuilder::with_capacity(rows.len(), rows.len() * 32);
321 for row in rows {
322 let val: Option<String> = row.try_get(col_idx).ok().flatten();
323 b.append_option(val.as_deref());
324 }
325 Ok(Arc::new(b.finish()))
326 }
327 Type::OID => {
328 let mut b = Int64Builder::with_capacity(rows.len());
329 for row in rows {
330 b.append_option(row.get::<_, Option<u32>>(col_idx).map(|v| v as i64));
331 }
332 Ok(Arc::new(b.finish()))
333 }
334 _ => {
335 log::warn!("unmapped PG type {:?}, extracting as text", pg_type);
336 let mut b = StringBuilder::with_capacity(rows.len(), rows.len() * 32);
337 for row in rows {
338 let val: Option<String> = row.try_get(col_idx).ok().flatten();
339 b.append_option(val.as_deref());
340 }
341 Ok(Arc::new(b.finish()))
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::CursorState;

    /// Cursor state for export "t" with the given stored values.
    fn cursor_state(last_value: Option<&str>, last_run: Option<&str>) -> CursorState {
        CursorState {
            export_name: "t".into(),
            last_cursor_value: last_value.map(Into::into),
            last_run_at: last_run.map(Into::into),
        }
    }

    #[test]
    fn test_build_query_full() {
        assert_eq!(
            build_query("SELECT * FROM users", None, None),
            "SELECT * FROM users"
        );
    }

    #[test]
    fn test_build_query_incremental_first_run() {
        let state = cursor_state(None, None);
        let sql = build_query("SELECT * FROM users", Some("updated_at"), Some(&state));
        assert!(sql.contains("ORDER BY updated_at"));
        assert!(!sql.contains("WHERE"));
    }

    #[test]
    fn test_build_query_incremental_with_cursor() {
        let state = cursor_state(Some("2024-01-01T00:00:00"), Some("2024-06-01"));
        let sql = build_query("SELECT * FROM orders", Some("updated_at"), Some(&state));
        assert!(
            sql.contains("WHERE updated_at > '2024-01-01T00:00:00'"),
            "got: {}",
            sql
        );
        assert!(sql.contains("ORDER BY updated_at"));
    }
}