Skip to main content

rivet_cli/source/
postgres.rs

1use std::sync::Arc;
2
3use arrow::array::{
4    Array, BinaryBuilder, BooleanBuilder, Date32Builder, Float32Builder, Float64Builder,
5    Int16Builder, Int32Builder, Int64Builder, StringBuilder, TimestampMicrosecondBuilder,
6};
7use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
8use arrow::record_batch::RecordBatch;
9use postgres::types::Type;
10use postgres::{Client, NoTls, Row};
11
12use crate::error::Result;
13use crate::tuning::SourceTuning;
14use crate::types::CursorState;
15
16pub struct PostgresSource {
17    client: Client,
18}
19
20impl PostgresSource {
21    pub fn connect(url: &str) -> Result<Self> {
22        let client = Client::connect(url, NoTls)?;
23        Ok(Self { client })
24    }
25}
26
27impl super::Source for PostgresSource {
28    fn export(
29        &mut self,
30        query: &str,
31        cursor_column: Option<&str>,
32        cursor: Option<&CursorState>,
33        tuning: &SourceTuning,
34        sink: &mut dyn super::BatchSink,
35    ) -> Result<()> {
36        let effective_query = build_query(query, cursor_column, cursor);
37        log::info!("executing query: {}", effective_query);
38
39        if tuning.statement_timeout_s > 0 {
40            self.client.batch_execute(&format!(
41                "SET statement_timeout = '{}s'",
42                tuning.statement_timeout_s
43            ))?;
44        }
45        if tuning.lock_timeout_s > 0 {
46            self.client
47                .batch_execute(&format!("SET lock_timeout = '{}s'", tuning.lock_timeout_s))?;
48        }
49
50        self.client.batch_execute("BEGIN")?;
51        self.client.batch_execute(&format!(
52            "DECLARE _rivet NO SCROLL CURSOR FOR {}",
53            effective_query
54        ))?;
55
56        let mut fetch_size = tuning.batch_size;
57        let mut fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
58        let mut schema: Option<SchemaRef> = None;
59        let mut columns_cache: Option<Vec<(String, Type)>> = None;
60        let mut total_rows: usize = 0;
61
62        loop {
63            let rows = self.client.query(&fetch_sql, &[])?;
64            if rows.is_empty() {
65                break;
66            }
67
68            if schema.is_none() {
69                let stmt_cols: Vec<(String, Type)> = rows[0]
70                    .columns()
71                    .iter()
72                    .map(|c| (c.name().to_string(), c.type_().clone()))
73                    .collect();
74                let s = Arc::new(pg_columns_to_schema(rows[0].columns()));
75                sink.on_schema(s.clone())?;
76                schema = Some(s.clone());
77                columns_cache = Some(stmt_cols);
78
79                let effective = tuning.effective_batch_size(Some(&s));
80                if effective != fetch_size {
81                    fetch_size = effective;
82                    fetch_sql = format!("FETCH {} FROM _rivet", fetch_size);
83                }
84            }
85
86            let row_count = rows.len();
87            total_rows += row_count;
88
89            let s = schema.as_ref().expect("schema set on first iteration");
90            let cols = columns_cache
91                .as_ref()
92                .expect("columns set on first iteration");
93            let batch = rows_to_record_batch_typed(s, cols, &rows)?;
94            drop(rows);
95            sink.on_batch(&batch)?;
96
97            log::info!("fetched {} rows so far...", total_rows);
98
99            if row_count < fetch_size {
100                break;
101            }
102
103            if tuning.throttle_ms > 0 {
104                std::thread::sleep(std::time::Duration::from_millis(tuning.throttle_ms));
105            }
106        }
107
108        self.client.batch_execute("CLOSE _rivet")?;
109        self.client.batch_execute("COMMIT")?;
110        self.client.batch_execute("RESET statement_timeout")?;
111        self.client.batch_execute("RESET lock_timeout")?;
112
113        if schema.is_none() {
114            sink.on_schema(Arc::new(Schema::empty()))?;
115        }
116
117        log::info!("total: {} rows", total_rows);
118        Ok(())
119    }
120
121    fn query_scalar(&mut self, sql: &str) -> Result<Option<String>> {
122        let rows = self.client.query(sql, &[])?;
123        if rows.is_empty() {
124            return Ok(None);
125        }
126        let row = &rows[0];
127        if let Ok(Some(v)) = row.try_get::<_, Option<i64>>(0) {
128            return Ok(Some(v.to_string()));
129        }
130        if let Ok(Some(v)) = row.try_get::<_, Option<i32>>(0) {
131            return Ok(Some(v.to_string()));
132        }
133        if let Ok(Some(v)) = row.try_get::<_, Option<f64>>(0) {
134            return Ok(Some(v.to_string()));
135        }
136        if let Ok(Some(v)) = row.try_get::<_, Option<String>>(0) {
137            return Ok(Some(v));
138        }
139        Ok(None)
140    }
141}
142
143pub(crate) fn build_query(
144    base_query: &str,
145    cursor_column: Option<&str>,
146    cursor: Option<&CursorState>,
147) -> String {
148    let has_cursor_value = cursor
149        .and_then(|c| c.last_cursor_value.as_deref())
150        .is_some();
151
152    if let (Some(col), true) = (cursor_column, has_cursor_value) {
153        let cursor_val = cursor
154            .expect("cursor checked above")
155            .last_cursor_value
156            .as_deref()
157            .expect("cursor value checked above");
158        format!(
159            "SELECT * FROM ({base}) AS _rivet WHERE {col} > '{val}' ORDER BY {col}",
160            base = base_query,
161            col = col,
162            val = cursor_val,
163        )
164    } else if let Some(col) = cursor_column {
165        format!(
166            "SELECT * FROM ({base}) AS _rivet ORDER BY {col}",
167            base = base_query,
168            col = col,
169        )
170    } else {
171        base_query.to_string()
172    }
173}
174
175fn pg_type_to_arrow(pg_type: &Type) -> DataType {
176    match *pg_type {
177        Type::BOOL => DataType::Boolean,
178        Type::INT2 => DataType::Int16,
179        Type::INT4 => DataType::Int32,
180        Type::INT8 => DataType::Int64,
181        Type::FLOAT4 => DataType::Float32,
182        Type::FLOAT8 => DataType::Float64,
183        Type::TEXT | Type::VARCHAR | Type::BPCHAR | Type::NAME => DataType::Utf8,
184        Type::BYTEA => DataType::Binary,
185        Type::DATE => DataType::Date32,
186        Type::TIMESTAMP | Type::TIMESTAMPTZ => DataType::Timestamp(TimeUnit::Microsecond, None),
187        Type::NUMERIC => DataType::Utf8,
188        Type::JSON | Type::JSONB => DataType::Utf8,
189        Type::UUID => DataType::Utf8,
190        Type::OID => DataType::Int64,
191        _ => {
192            log::warn!("unmapped PG type {:?}, falling back to Utf8", pg_type);
193            DataType::Utf8
194        }
195    }
196}
197
198fn pg_columns_to_schema(columns: &[postgres::Column]) -> Schema {
199    let fields: Vec<Field> = columns
200        .iter()
201        .map(|col| {
202            let dt = pg_type_to_arrow(col.type_());
203            Field::new(col.name(), dt, true)
204        })
205        .collect();
206    Schema::new(fields)
207}
208
209fn rows_to_record_batch_typed(
210    schema: &SchemaRef,
211    columns: &[(String, Type)],
212    rows: &[Row],
213) -> Result<RecordBatch> {
214    let mut arrays: Vec<Arc<dyn Array>> = Vec::with_capacity(columns.len());
215    for (col_idx, (_name, pg_type)) in columns.iter().enumerate() {
216        let array = build_array(pg_type, col_idx, rows)?;
217        arrays.push(array);
218    }
219    let batch = RecordBatch::try_new(schema.clone(), arrays)?;
220    Ok(batch)
221}
222
223fn build_array(pg_type: &Type, col_idx: usize, rows: &[Row]) -> Result<Arc<dyn Array>> {
224    match *pg_type {
225        Type::BOOL => {
226            let mut b = BooleanBuilder::with_capacity(rows.len());
227            for row in rows {
228                b.append_option(row.get(col_idx));
229            }
230            Ok(Arc::new(b.finish()))
231        }
232        Type::INT2 => {
233            let mut b = Int16Builder::with_capacity(rows.len());
234            for row in rows {
235                b.append_option(row.get(col_idx));
236            }
237            Ok(Arc::new(b.finish()))
238        }
239        Type::INT4 => {
240            let mut b = Int32Builder::with_capacity(rows.len());
241            for row in rows {
242                b.append_option(row.get(col_idx));
243            }
244            Ok(Arc::new(b.finish()))
245        }
246        Type::INT8 => {
247            let mut b = Int64Builder::with_capacity(rows.len());
248            for row in rows {
249                b.append_option(row.get(col_idx));
250            }
251            Ok(Arc::new(b.finish()))
252        }
253        Type::FLOAT4 => {
254            let mut b = Float32Builder::with_capacity(rows.len());
255            for row in rows {
256                b.append_option(row.get(col_idx));
257            }
258            Ok(Arc::new(b.finish()))
259        }
260        Type::FLOAT8 => {
261            let mut b = Float64Builder::with_capacity(rows.len());
262            for row in rows {
263                b.append_option(row.get(col_idx));
264            }
265            Ok(Arc::new(b.finish()))
266        }
267        Type::TEXT | Type::VARCHAR | Type::BPCHAR | Type::NAME => {
268            let mut b = StringBuilder::with_capacity(rows.len(), rows.len() * 32);
269            for row in rows {
270                let val: Option<String> = row.get(col_idx);
271                b.append_option(val.as_deref());
272            }
273            Ok(Arc::new(b.finish()))
274        }
275        Type::BYTEA => {
276            let mut b = BinaryBuilder::with_capacity(rows.len(), rows.len() * 64);
277            for row in rows {
278                match row.get::<_, Option<Vec<u8>>>(col_idx) {
279                    Some(v) => b.append_value(&v),
280                    None => b.append_null(),
281                }
282            }
283            Ok(Arc::new(b.finish()))
284        }
285        Type::DATE => {
286            let mut b = Date32Builder::with_capacity(rows.len());
287            for row in rows {
288                match row.get::<_, Option<chrono::NaiveDate>>(col_idx) {
289                    Some(d) => {
290                        let epoch =
291                            chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid");
292                        b.append_value((d - epoch).num_days() as i32);
293                    }
294                    None => b.append_null(),
295                }
296            }
297            Ok(Arc::new(b.finish()))
298        }
299        Type::TIMESTAMP => {
300            let mut b = TimestampMicrosecondBuilder::with_capacity(rows.len());
301            for row in rows {
302                match row.get::<_, Option<chrono::NaiveDateTime>>(col_idx) {
303                    Some(ts) => b.append_value(ts.and_utc().timestamp_micros()),
304                    None => b.append_null(),
305                }
306            }
307            Ok(Arc::new(b.finish()))
308        }
309        Type::TIMESTAMPTZ => {
310            let mut b = TimestampMicrosecondBuilder::with_capacity(rows.len());
311            for row in rows {
312                match row.get::<_, Option<chrono::DateTime<chrono::Utc>>>(col_idx) {
313                    Some(ts) => b.append_value(ts.timestamp_micros()),
314                    None => b.append_null(),
315                }
316            }
317            Ok(Arc::new(b.finish()))
318        }
319        Type::NUMERIC | Type::JSON | Type::JSONB | Type::UUID => {
320            let mut b = StringBuilder::with_capacity(rows.len(), rows.len() * 32);
321            for row in rows {
322                let val: Option<String> = row.try_get(col_idx).ok().flatten();
323                b.append_option(val.as_deref());
324            }
325            Ok(Arc::new(b.finish()))
326        }
327        Type::OID => {
328            let mut b = Int64Builder::with_capacity(rows.len());
329            for row in rows {
330                b.append_option(row.get::<_, Option<u32>>(col_idx).map(|v| v as i64));
331            }
332            Ok(Arc::new(b.finish()))
333        }
334        _ => {
335            log::warn!("unmapped PG type {:?}, extracting as text", pg_type);
336            let mut b = StringBuilder::with_capacity(rows.len(), rows.len() * 32);
337            for row in rows {
338                let val: Option<String> = row.try_get(col_idx).ok().flatten();
339                b.append_option(val.as_deref());
340            }
341            Ok(Arc::new(b.finish()))
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::CursorState;

    /// Builds the cursor state of export "t" with the given stored values.
    fn cursor_state(last_value: Option<&str>, last_run: Option<&str>) -> CursorState {
        CursorState {
            export_name: "t".into(),
            last_cursor_value: last_value.map(Into::into),
            last_run_at: last_run.map(Into::into),
        }
    }

    /// Without a cursor column the base query is passed through untouched.
    #[test]
    fn test_build_query_full() {
        assert_eq!(
            build_query("SELECT * FROM users", None, None),
            "SELECT * FROM users"
        );
    }

    /// A cursor column without a stored value only adds the ordering.
    #[test]
    fn test_build_query_incremental_first_run() {
        let state = cursor_state(None, None);
        let sql = build_query("SELECT * FROM users", Some("updated_at"), Some(&state));
        assert!(sql.contains("ORDER BY updated_at"));
        assert!(!sql.contains("WHERE"));
    }

    /// A stored cursor value adds both the filter and the ordering.
    #[test]
    fn test_build_query_incremental_with_cursor() {
        let state = cursor_state(Some("2024-01-01T00:00:00"), Some("2024-06-01"));
        let sql = build_query("SELECT * FROM orders", Some("updated_at"), Some(&state));
        assert!(
            sql.contains("WHERE updated_at > '2024-01-01T00:00:00'"),
            "got: {}",
            sql
        );
        assert!(sql.contains("ORDER BY updated_at"));
    }
}
384}