connector_arrow/postgres/
query.rs1use arrow::datatypes::*;
2use arrow::record_batch::RecordBatch;
3
4use postgres::fallible_iterator::FallibleIterator;
5use postgres::types::{FromSql, Type};
6use postgres::{Client, Row, RowIter};
7
8use crate::api::{ResultReader, Statement};
9use crate::types::{ArrowType, FixedSizeBinaryType};
10use crate::util::CellReader;
11use crate::util::{transport, ArrayCellRef};
12use crate::{errors::ConnectorError, util::RowsReader};
13
14use super::{types, PostgresError};
15
16pub struct PostgresStatement<'conn> {
17 pub(super) client: &'conn mut Client,
18 pub(super) stmt: postgres::Statement,
19}
20
21impl<'conn> Statement<'conn> for PostgresStatement<'conn> {
22 type Reader<'stmt>
23 = PostgresBatchStream<'stmt>
24 where
25 Self: 'stmt;
26
27 fn start_batch<'p>(
28 &mut self,
29 args: (&RecordBatch, usize),
30 ) -> Result<Self::Reader<'_>, ConnectorError> {
31 let stmt = &self.stmt;
32 let schema = types::pg_stmt_to_arrow(stmt)?;
33
34 let arg_row = ArrayCellRef::vec_from_batch(args.0, args.1);
35
36 let rows = self
38 .client
39 .query_raw::<_, _, _>(stmt, &arg_row)
40 .map_err(PostgresError::from)?;
41
42 let row_reader = PostgresRowStream::new(rows);
44 Ok(PostgresBatchStream { schema, row_reader })
45 }
46}
47
48pub struct PostgresBatchStream<'a> {
49 schema: SchemaRef,
50 row_reader: PostgresRowStream<'a>,
51}
52
53impl<'a> ResultReader<'a> for PostgresBatchStream<'a> {
54 fn get_schema(&mut self) -> Result<std::sync::Arc<arrow::datatypes::Schema>, ConnectorError> {
55 Ok(self.schema.clone())
56 }
57}
58
59impl Iterator for PostgresBatchStream<'_> {
60 type Item = Result<RecordBatch, ConnectorError>;
61
62 fn next(&mut self) -> Option<Self::Item> {
63 crate::util::next_batch_from_rows(&self.schema, &mut self.row_reader, 1024).transpose()
64 }
65}
66
67struct PostgresRowStream<'a> {
68 iter: postgres_fallible_iterator::Fuse<postgres::RowIter<'a>>,
69}
70
71impl<'a> PostgresRowStream<'a> {
72 pub fn new(iter: RowIter<'a>) -> Self {
73 Self { iter: iter.fuse() }
74 }
75}
76
77impl<'stmt> RowsReader<'stmt> for PostgresRowStream<'stmt> {
78 type CellReader<'row>
79 = PostgresCellReader
80 where
81 Self: 'row;
82
83 fn next_row(&mut self) -> Result<Option<Self::CellReader<'_>>, ConnectorError> {
84 let row = self.iter.next().map_err(PostgresError::from)?;
85
86 Ok(row.map(|row| PostgresCellReader { row, next_col: 0 }))
87 }
88}
89
90struct PostgresCellReader {
91 row: Row,
92 next_col: usize,
93}
94
95impl CellReader<'_> for PostgresCellReader {
96 type CellRef<'cell>
97 = CellRef<'cell>
98 where
99 Self: 'cell;
100
101 fn next_cell(&mut self) -> Option<Self::CellRef<'_>> {
102 if self.next_col >= self.row.columns().len() {
103 return None;
104 }
105 let col = self.next_col;
106 self.next_col += 1;
107 Some((&self.row, col))
108 }
109}
110
111type CellRef<'a> = (&'a Row, usize);
112
113impl<'c> transport::Produce<'c> for CellRef<'c> {}
114
/// Implements `transport::ProduceTy<'c, $t>` for [`CellRef`].
///
/// * `$t` — the Arrow type being produced.
/// * `$native` — the type the Postgres value is decoded into via `FromSql`.
/// * `$conversion_fn` — maps a `$native` to `<$t as ArrowType>::Native`,
///   returning `Result<_, ConnectorError>`.
///
/// Values are read with `Row::try_get` rather than `Row::get`. `Row::get`
/// panics when the value cannot be converted to the requested type
/// (including a NULL read into a non-`Option` type). `try_get` reports the
/// failure, which is converted to `ConnectorError` through `PostgresError`.
macro_rules! impl_produce {
    ($t: ty, $native: ty, $conversion_fn: expr) => {
        impl<'c> transport::ProduceTy<'c, $t> for CellRef<'c> {
            fn produce(self) -> Result<<$t as ArrowType>::Native, ConnectorError> {
                let (row, col) = self;
                let value = row
                    .try_get::<_, $native>(col)
                    .map_err(PostgresError::from)?;
                $conversion_fn(value)
            }

            fn produce_opt(self) -> Result<Option<<$t as ArrowType>::Native>, ConnectorError> {
                let (row, col) = self;
                // SQL NULL decodes to `None` and is passed through unchanged.
                let value = row
                    .try_get::<_, Option<$native>>(col)
                    .map_err(PostgresError::from)?;
                value.map($conversion_fn).transpose()
            }
        }
    };
}
130
131impl_produce!(BooleanType, bool, Result::Ok);
132impl_produce!(Int8Type, i8, Result::Ok);
133impl_produce!(Int16Type, i16, Result::Ok);
134impl_produce!(Int32Type, i32, Result::Ok);
135impl_produce!(Int64Type, i64, Result::Ok);
136impl_produce!(Float32Type, f32, Result::Ok);
137impl_produce!(Float64Type, f64, Result::Ok);
138impl_produce!(BinaryType, Binary, Binary::into_arrow);
139impl_produce!(LargeBinaryType, Binary, Binary::into_arrow);
140impl_produce!(Utf8Type, StrOrNum, StrOrNum::into_arrow);
141impl_produce!(LargeUtf8Type, String, Result::Ok);
142impl_produce!(
143 TimestampMicrosecondType,
144 TimestampY2000,
145 TimestampY2000::into_microsecond
146);
147impl_produce!(Time64MicrosecondType, Time64, Time64::into_microsecond);
148impl_produce!(Date32Type, DaysSinceY2000, DaysSinceY2000::into_date32);
149impl_produce!(
150 IntervalMonthDayNanoType,
151 IntervalMonthDayMicros,
152 IntervalMonthDayMicros::into_arrow
153);
154
155crate::impl_produce_unsupported!(
156 CellRef<'r>,
157 (
158 UInt8Type,
159 UInt16Type,
160 UInt32Type,
161 UInt64Type,
162 Float16Type,
163 TimestampSecondType,
164 TimestampMillisecondType,
165 TimestampNanosecondType,
166 Date64Type,
167 Time32SecondType,
168 Time32MillisecondType,
169 Time64NanosecondType,
170 IntervalYearMonthType,
171 IntervalDayTimeType,
172 DurationSecondType,
173 DurationMillisecondType,
174 DurationMicrosecondType,
175 DurationNanosecondType,
176 FixedSizeBinaryType,
177 Decimal128Type,
178 Decimal256Type,
179 )
180);
181
182struct StrOrNum(String);
183
184impl StrOrNum {
185 fn into_arrow(self) -> Result<String, ConnectorError> {
186 Ok(self.0)
187 }
188}
189
190impl<'a> FromSql<'a> for StrOrNum {
191 fn from_sql(
192 ty: &Type,
193 raw: &'a [u8],
194 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
195 if matches!(ty, &Type::NUMERIC) {
196 Ok(super::decimal::from_sql(raw).map(StrOrNum)?)
197 } else {
198 let slice = postgres_protocol::types::text_from_sql(raw)?;
199 Ok(StrOrNum(slice.to_string()))
200 }
201 }
202
203 fn accepts(_ty: &Type) -> bool {
204 true
205 }
206}
207
/// Number of days from 1970-01-01 to 2000-01-01. The conversions in this
/// module add it to Postgres values counted from 2000-01-01.
const DUR_1970_TO_2000_DAYS: i32 = 10_957;
/// The same span expressed in seconds (86 400 seconds per day).
const DUR_1970_TO_2000_SEC: i64 = DUR_1970_TO_2000_DAYS as i64 * 86_400;
210
211struct TimestampY2000(i64);
212
213impl<'a> FromSql<'a> for TimestampY2000 {
214 fn from_sql(
215 _ty: &Type,
216 raw: &'a [u8],
217 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
218 postgres_protocol::types::timestamp_from_sql(raw).map(TimestampY2000)
219 }
220
221 fn accepts(_ty: &Type) -> bool {
222 true
223 }
224}
225
226impl TimestampY2000 {
227 fn into_microsecond(self) -> Result<i64, ConnectorError> {
228 self.0
229 .checked_add(DUR_1970_TO_2000_SEC * 1000 * 1000)
230 .ok_or(ConnectorError::DataOutOfRange)
231 }
232}
233
234struct DaysSinceY2000(i32);
235
236impl<'a> FromSql<'a> for DaysSinceY2000 {
237 fn from_sql(
238 _ty: &Type,
239 raw: &'a [u8],
240 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
241 postgres_protocol::types::date_from_sql(raw).map(DaysSinceY2000)
242 }
243
244 fn accepts(_ty: &Type) -> bool {
245 true
246 }
247}
248
249impl DaysSinceY2000 {
250 fn into_date32(self) -> Result<i32, ConnectorError> {
251 self.0
252 .checked_add(DUR_1970_TO_2000_DAYS)
253 .ok_or(ConnectorError::DataOutOfRange)
254 }
255}
256
257struct Time64(i64);
258
259impl<'a> FromSql<'a> for Time64 {
260 fn from_sql(
261 _ty: &Type,
262 raw: &'a [u8],
263 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
264 postgres_protocol::types::time_from_sql(raw).map(Time64)
265 }
266 fn accepts(_ty: &Type) -> bool {
267 true
268 }
269}
270
271impl Time64 {
272 fn into_microsecond(self) -> Result<i64, ConnectorError> {
273 Ok(self.0)
274 }
275}
276
277struct IntervalMonthDayMicros {
278 months: i32,
279 days: i32,
280 micros: i64,
281}
282
283impl<'a> FromSql<'a> for IntervalMonthDayMicros {
284 fn from_sql(
285 _ty: &Type,
286 raw: &'a [u8],
287 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
288 let micros = postgres_protocol::types::time_from_sql(&raw[0..8])?;
289 let days = postgres_protocol::types::int4_from_sql(&raw[8..12])?;
290 let months = postgres_protocol::types::int4_from_sql(&raw[12..16])?;
291 Ok(IntervalMonthDayMicros {
292 months,
293 days,
294 micros,
295 })
296 }
297 fn accepts(_ty: &Type) -> bool {
298 true
299 }
300}
301
302impl IntervalMonthDayMicros {
303 fn into_arrow(self) -> Result<IntervalMonthDayNano, ConnectorError> {
304 let nanoseconds = (self.micros.checked_mul(1000)).ok_or(ConnectorError::DataOutOfRange)?;
305 Ok(IntervalMonthDayNano {
306 months: self.months,
307 days: self.days,
308 nanoseconds,
309 })
310 }
311}
312
313struct Binary<'a>(&'a [u8]);
314
315impl<'a> FromSql<'a> for Binary<'a> {
316 fn from_sql(
317 ty: &Type,
318 raw: &'a [u8],
319 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
320 Ok(if matches!(ty, &Type::VARBIT | &Type::BIT) {
321 let varbit = postgres_protocol::types::varbit_from_sql(raw)?;
322 Binary(varbit.bytes())
323 } else {
324 Binary(postgres_protocol::types::bytea_from_sql(raw))
325 })
326 }
327 fn accepts(_ty: &Type) -> bool {
328 true
329 }
330}
331
332impl Binary<'_> {
333 fn into_arrow(self) -> Result<Vec<u8>, ConnectorError> {
334 Ok(self.0.to_vec())
336 }
337}