// connector_arrow/postgres/query.rs

1use arrow::datatypes::*;
2use arrow::record_batch::RecordBatch;
3
4use postgres::fallible_iterator::FallibleIterator;
5use postgres::types::{FromSql, Type};
6use postgres::{Client, Row, RowIter};
7
8use crate::api::{ResultReader, Statement};
9use crate::types::{ArrowType, FixedSizeBinaryType};
10use crate::util::CellReader;
11use crate::util::{transport, ArrayCellRef};
12use crate::{errors::ConnectorError, util::RowsReader};
13
14use super::{types, PostgresError};
15
16pub struct PostgresStatement<'conn> {
17    pub(super) client: &'conn mut Client,
18    pub(super) stmt: postgres::Statement,
19}
20
21impl<'conn> Statement<'conn> for PostgresStatement<'conn> {
22    type Reader<'stmt>
23        = PostgresBatchStream<'stmt>
24    where
25        Self: 'stmt;
26
27    fn start_batch<'p>(
28        &mut self,
29        args: (&RecordBatch, usize),
30    ) -> Result<Self::Reader<'_>, ConnectorError> {
31        let stmt = &self.stmt;
32        let schema = types::pg_stmt_to_arrow(stmt)?;
33
34        let arg_row = ArrayCellRef::vec_from_batch(args.0, args.1);
35
36        // query
37        let rows = self
38            .client
39            .query_raw::<_, _, _>(stmt, &arg_row)
40            .map_err(PostgresError::from)?;
41
42        // create the row reader
43        let row_reader = PostgresRowStream::new(rows);
44        Ok(PostgresBatchStream { schema, row_reader })
45    }
46}
47
48pub struct PostgresBatchStream<'a> {
49    schema: SchemaRef,
50    row_reader: PostgresRowStream<'a>,
51}
52
53impl<'a> ResultReader<'a> for PostgresBatchStream<'a> {
54    fn get_schema(&mut self) -> Result<std::sync::Arc<arrow::datatypes::Schema>, ConnectorError> {
55        Ok(self.schema.clone())
56    }
57}
58
59impl Iterator for PostgresBatchStream<'_> {
60    type Item = Result<RecordBatch, ConnectorError>;
61
62    fn next(&mut self) -> Option<Self::Item> {
63        crate::util::next_batch_from_rows(&self.schema, &mut self.row_reader, 1024).transpose()
64    }
65}
66
67struct PostgresRowStream<'a> {
68    iter: postgres_fallible_iterator::Fuse<postgres::RowIter<'a>>,
69}
70
71impl<'a> PostgresRowStream<'a> {
72    pub fn new(iter: RowIter<'a>) -> Self {
73        Self { iter: iter.fuse() }
74    }
75}
76
77impl<'stmt> RowsReader<'stmt> for PostgresRowStream<'stmt> {
78    type CellReader<'row>
79        = PostgresCellReader
80    where
81        Self: 'row;
82
83    fn next_row(&mut self) -> Result<Option<Self::CellReader<'_>>, ConnectorError> {
84        let row = self.iter.next().map_err(PostgresError::from)?;
85
86        Ok(row.map(|row| PostgresCellReader { row, next_col: 0 }))
87    }
88}
89
90struct PostgresCellReader {
91    row: Row,
92    next_col: usize,
93}
94
95impl CellReader<'_> for PostgresCellReader {
96    type CellRef<'cell>
97        = CellRef<'cell>
98    where
99        Self: 'cell;
100
101    fn next_cell(&mut self) -> Option<Self::CellRef<'_>> {
102        if self.next_col >= self.row.columns().len() {
103            return None;
104        }
105        let col = self.next_col;
106        self.next_col += 1;
107        Some((&self.row, col))
108    }
109}
110
111type CellRef<'a> = (&'a Row, usize);
112
113impl<'c> transport::Produce<'c> for CellRef<'c> {}
114
// Implements `transport::ProduceTy<$t>` for `CellRef`.
//
// * `$t`             - the Arrow type being produced.
// * `$native`        - the Rust type the cell is decoded into (via `FromSql`).
// * `$conversion_fn` - maps the decoded value to `Result<<$t as ArrowType>::Native, _>`.
//
// Cells are read with `Row::try_get` rather than `Row::get`: `get` panics when
// the value cannot be decoded (for example a NULL read through `produce`),
// whereas `try_get` lets the failure surface as a `ConnectorError`.
macro_rules! impl_produce {
    ($t: ty, $native: ty, $conversion_fn: expr) => {
        impl<'c> transport::ProduceTy<'c, $t> for CellRef<'c> {
            fn produce(self) -> Result<<$t as ArrowType>::Native, ConnectorError> {
                let (row, col) = self;
                let value = row
                    .try_get::<_, $native>(col)
                    .map_err(PostgresError::from)?;
                $conversion_fn(value)
            }

            fn produce_opt(self) -> Result<Option<<$t as ArrowType>::Native>, ConnectorError> {
                let (row, col) = self;
                let value = row
                    .try_get::<_, Option<$native>>(col)
                    .map_err(PostgresError::from)?;
                // Convert only when a value is present, keeping NULL as `None`.
                value.map($conversion_fn).transpose()
            }
        }
    };
}
130
131impl_produce!(BooleanType, bool, Result::Ok);
132impl_produce!(Int8Type, i8, Result::Ok);
133impl_produce!(Int16Type, i16, Result::Ok);
134impl_produce!(Int32Type, i32, Result::Ok);
135impl_produce!(Int64Type, i64, Result::Ok);
136impl_produce!(Float32Type, f32, Result::Ok);
137impl_produce!(Float64Type, f64, Result::Ok);
138impl_produce!(BinaryType, Binary, Binary::into_arrow);
139impl_produce!(LargeBinaryType, Binary, Binary::into_arrow);
140impl_produce!(Utf8Type, StrOrNum, StrOrNum::into_arrow);
141impl_produce!(LargeUtf8Type, String, Result::Ok);
142impl_produce!(
143    TimestampMicrosecondType,
144    TimestampY2000,
145    TimestampY2000::into_microsecond
146);
147impl_produce!(Time64MicrosecondType, Time64, Time64::into_microsecond);
148impl_produce!(Date32Type, DaysSinceY2000, DaysSinceY2000::into_date32);
149impl_produce!(
150    IntervalMonthDayNanoType,
151    IntervalMonthDayMicros,
152    IntervalMonthDayMicros::into_arrow
153);
154
155crate::impl_produce_unsupported!(
156    CellRef<'r>,
157    (
158        UInt8Type,
159        UInt16Type,
160        UInt32Type,
161        UInt64Type,
162        Float16Type,
163        TimestampSecondType,
164        TimestampMillisecondType,
165        TimestampNanosecondType,
166        Date64Type,
167        Time32SecondType,
168        Time32MillisecondType,
169        Time64NanosecondType,
170        IntervalYearMonthType,
171        IntervalDayTimeType,
172        DurationSecondType,
173        DurationMillisecondType,
174        DurationMicrosecondType,
175        DurationNanosecondType,
176        FixedSizeBinaryType,
177        Decimal128Type,
178        Decimal256Type,
179    )
180);
181
182struct StrOrNum(String);
183
184impl StrOrNum {
185    fn into_arrow(self) -> Result<String, ConnectorError> {
186        Ok(self.0)
187    }
188}
189
190impl<'a> FromSql<'a> for StrOrNum {
191    fn from_sql(
192        ty: &Type,
193        raw: &'a [u8],
194    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
195        if matches!(ty, &Type::NUMERIC) {
196            Ok(super::decimal::from_sql(raw).map(StrOrNum)?)
197        } else {
198            let slice = postgres_protocol::types::text_from_sql(raw)?;
199            Ok(StrOrNum(slice.to_string()))
200        }
201    }
202
203    fn accepts(_ty: &Type) -> bool {
204        true
205    }
206}
207
/// Number of days from 1970-01-01 to 2000-01-01 (30 years, 7 of them leap years).
const DUR_1970_TO_2000_DAYS: i32 = 10957;
/// The same span expressed in seconds.
const DUR_1970_TO_2000_SEC: i64 = (24 * 60 * 60) * (DUR_1970_TO_2000_DAYS as i64);
210
211struct TimestampY2000(i64);
212
213impl<'a> FromSql<'a> for TimestampY2000 {
214    fn from_sql(
215        _ty: &Type,
216        raw: &'a [u8],
217    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
218        postgres_protocol::types::timestamp_from_sql(raw).map(TimestampY2000)
219    }
220
221    fn accepts(_ty: &Type) -> bool {
222        true
223    }
224}
225
226impl TimestampY2000 {
227    fn into_microsecond(self) -> Result<i64, ConnectorError> {
228        self.0
229            .checked_add(DUR_1970_TO_2000_SEC * 1000 * 1000)
230            .ok_or(ConnectorError::DataOutOfRange)
231    }
232}
233
234struct DaysSinceY2000(i32);
235
236impl<'a> FromSql<'a> for DaysSinceY2000 {
237    fn from_sql(
238        _ty: &Type,
239        raw: &'a [u8],
240    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
241        postgres_protocol::types::date_from_sql(raw).map(DaysSinceY2000)
242    }
243
244    fn accepts(_ty: &Type) -> bool {
245        true
246    }
247}
248
249impl DaysSinceY2000 {
250    fn into_date32(self) -> Result<i32, ConnectorError> {
251        self.0
252            .checked_add(DUR_1970_TO_2000_DAYS)
253            .ok_or(ConnectorError::DataOutOfRange)
254    }
255}
256
257struct Time64(i64);
258
259impl<'a> FromSql<'a> for Time64 {
260    fn from_sql(
261        _ty: &Type,
262        raw: &'a [u8],
263    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
264        postgres_protocol::types::time_from_sql(raw).map(Time64)
265    }
266    fn accepts(_ty: &Type) -> bool {
267        true
268    }
269}
270
271impl Time64 {
272    fn into_microsecond(self) -> Result<i64, ConnectorError> {
273        Ok(self.0)
274    }
275}
276
277struct IntervalMonthDayMicros {
278    months: i32,
279    days: i32,
280    micros: i64,
281}
282
283impl<'a> FromSql<'a> for IntervalMonthDayMicros {
284    fn from_sql(
285        _ty: &Type,
286        raw: &'a [u8],
287    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
288        let micros = postgres_protocol::types::time_from_sql(&raw[0..8])?;
289        let days = postgres_protocol::types::int4_from_sql(&raw[8..12])?;
290        let months = postgres_protocol::types::int4_from_sql(&raw[12..16])?;
291        Ok(IntervalMonthDayMicros {
292            months,
293            days,
294            micros,
295        })
296    }
297    fn accepts(_ty: &Type) -> bool {
298        true
299    }
300}
301
302impl IntervalMonthDayMicros {
303    fn into_arrow(self) -> Result<IntervalMonthDayNano, ConnectorError> {
304        let nanoseconds = (self.micros.checked_mul(1000)).ok_or(ConnectorError::DataOutOfRange)?;
305        Ok(IntervalMonthDayNano {
306            months: self.months,
307            days: self.days,
308            nanoseconds,
309        })
310    }
311}
312
313struct Binary<'a>(&'a [u8]);
314
315impl<'a> FromSql<'a> for Binary<'a> {
316    fn from_sql(
317        ty: &Type,
318        raw: &'a [u8],
319    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
320        Ok(if matches!(ty, &Type::VARBIT | &Type::BIT) {
321            let varbit = postgres_protocol::types::varbit_from_sql(raw)?;
322            Binary(varbit.bytes())
323        } else {
324            Binary(postgres_protocol::types::bytea_from_sql(raw))
325        })
326    }
327    fn accepts(_ty: &Type) -> bool {
328        true
329    }
330}
331
332impl Binary<'_> {
333    fn into_arrow(self) -> Result<Vec<u8>, ConnectorError> {
334        // this is a clone, that is needed because Produce requires Vec<u8>
335        Ok(self.0.to_vec())
336    }
337}