Skip to main content

arrow_pg/
row_encoder.rs

1use std::sync::Arc;
2
3#[cfg(not(feature = "datafusion"))]
4use arrow::array::RecordBatch;
5#[cfg(feature = "datafusion")]
6use datafusion::arrow::array::RecordBatch;
7
8use pgwire::{
9    api::results::{DataRowEncoder, FieldInfo},
10    error::PgWireResult,
11    messages::data::DataRow,
12};
13
14use crate::encoder::encode_value;
15
16pub struct RowEncoder {
17    rb: RecordBatch,
18    curr_idx: usize,
19    fields: Arc<Vec<FieldInfo>>,
20    row_encoder: DataRowEncoder,
21}
22
23impl RowEncoder {
24    pub fn new(rb: RecordBatch, fields: Arc<Vec<FieldInfo>>) -> Self {
25        assert_eq!(rb.num_columns(), fields.len());
26        Self {
27            rb,
28            fields: fields.clone(),
29            curr_idx: 0,
30            row_encoder: DataRowEncoder::new(fields),
31        }
32    }
33
34    pub fn next_row(&mut self) -> Option<PgWireResult<DataRow>> {
35        if self.curr_idx == self.rb.num_rows() {
36            return None;
37        }
38
39        let arrow_schema = self.rb.schema_ref();
40        for col in 0..self.rb.num_columns() {
41            let array = self.rb.column(col);
42            let arrow_field = arrow_schema.field(col);
43            let pg_field = &self.fields[col];
44
45            if let Err(e) = encode_value(
46                &mut self.row_encoder,
47                array,
48                self.curr_idx,
49                arrow_field,
50                pg_field,
51            ) {
52                return Some(Err(e));
53            };
54        }
55        self.curr_idx += 1;
56        Some(Ok(self.row_encoder.take_row()))
57    }
58}