datafusion_table_providers/sql/arrow_sql_gen/postgres/composite.rs

1use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
2use fallible_iterator::FallibleIterator;
3use snafu::prelude::*;
4use std::{fmt, ops::Range};
5use tokio_postgres::{
6    row::RowIndex,
7    types::{Field, FromSql, Kind, Type, WrongType},
8};
9
10#[derive(Debug, Snafu)]
11pub enum Error {
12    #[snafu(display("Unable to parse composite type ranges: {source}"))]
13    UnableToParseCompositeTypeRanges { source: std::io::Error },
14
15    #[snafu(display("Unable to find column {column} in the fields {}", fields.join(", ")))]
16    UnableToFindColumnInFields { column: String, fields: Vec<String> },
17
18    #[snafu(display("{source}"))]
19    UnableToConvertType { source: WrongType },
20
21    #[snafu(display("Unable to conver raw bytes into expected type: {source}"))]
22    UnableToConvertBytesToType {
23        source: Box<dyn std::error::Error + Sync + Send>,
24    },
25}
26
27pub type Result<T, E = Error> = std::result::Result<T, E>;
28
29/// A `PostgreSQL` composite type.
30/// Fields of a type can be accessed using `CompositeType::get` and `CompositeType::try_get` methods.
31///
32/// Adapted from <https://github.com/sfackler/rust-postgres/pull/565>
33pub struct CompositeType<'a> {
34    type_: Type,
35    body: &'a [u8],
36    ranges: Vec<Option<Range<usize>>>,
37}
38
39#[allow(clippy::cast_sign_loss)]
40#[allow(clippy::cast_possible_truncation)]
41impl<'a> FromSql<'a> for CompositeType<'a> {
42    fn from_sql(
43        type_: &Type,
44        body: &'a [u8],
45    ) -> Result<CompositeType<'a>, Box<dyn std::error::Error + Sync + Send>> {
46        match *type_.kind() {
47            Kind::Composite(_) => {
48                let fields: &[Field] = composite_type_fields(type_);
49                if body.len() < 4 {
50                    let message = format!("invalid composite type body length: {}", body.len());
51                    return Err(message.into());
52                }
53                let num_fields: i32 = BigEndian::read_i32(&body[0..4]);
54                if num_fields as usize != fields.len() {
55                    let message =
56                        format!("invalid field count: {} vs {}", num_fields, fields.len());
57                    return Err(message.into());
58                }
59                let ranges = CompositeTypeRanges::new(&body[4..], body.len(), num_fields as u16)
60                    .collect()
61                    .context(UnableToParseCompositeTypeRangesSnafu)?;
62                Ok(CompositeType {
63                    type_: type_.clone(),
64                    body,
65                    ranges,
66                })
67            }
68            _ => Err(format!("expected composite type, got {type_}").into()),
69        }
70    }
71    fn accepts(ty: &Type) -> bool {
72        matches!(*ty.kind(), Kind::Composite(_))
73    }
74}
75
76fn composite_type_fields(type_: &Type) -> &[Field] {
77    match type_.kind() {
78        Kind::Composite(ref fields) => fields,
79        _ => unreachable!(),
80    }
81}
82
83impl CompositeType<'_> {
84    /// Returns information about the fields of the composite type.
85    #[must_use]
86    pub fn fields(&self) -> &[Field] {
87        composite_type_fields(&self.type_)
88    }
89
90    /// Determines if the composite contains no values.
91    #[must_use]
92    pub fn is_empty(&self) -> bool {
93        self.len() == 0
94    }
95
96    /// Returns the number of fields of the composite type.
97    #[must_use]
98    pub fn len(&self) -> usize {
99        self.fields().len()
100    }
101
102    /// Deserializes a value from the composite type.
103    ///
104    /// The value can be specified either by its numeric index, or by its field name.
105    ///
106    /// # Panics
107    ///
108    /// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
109    pub fn get<'b, I, T>(&'b self, idx: I) -> T
110    where
111        I: RowIndex + fmt::Display,
112        T: FromSql<'b>,
113    {
114        match self.get_inner(&idx) {
115            Ok(ok) => ok,
116            Err(err) => panic!("error retrieving column {idx}: {err}"),
117        }
118    }
119
120    /// Like `CompositeType::get`, but returns a `Result` rather than panicking.
121    ///
122    /// # Errors
123    ///
124    /// Returns an error if the index is out of bounds or if the value cannot be converted to the specified type.
125    pub fn try_get<'b, I, T>(&'b self, idx: I) -> Result<T, Error>
126    where
127        I: RowIndex + fmt::Display,
128        T: FromSql<'b>,
129    {
130        self.get_inner(&idx)
131    }
132
133    fn get_inner<'b, I, T>(&'b self, idx: &I) -> Result<T, Error>
134    where
135        I: RowIndex + fmt::Display,
136        T: FromSql<'b>,
137    {
138        let fields_vec = self
139            .fields()
140            .iter()
141            .map(|f| f.name().to_string())
142            .collect::<Vec<_>>();
143        let idx = match idx.__idx(&fields_vec) {
144            Some(idx) => idx,
145            None => UnableToFindColumnInFieldsSnafu {
146                column: idx.to_string(),
147                fields: fields_vec,
148            }
149            .fail()?,
150        };
151
152        let ty = self.fields()[idx].type_();
153        if !T::accepts(ty) {
154            return Err(WrongType::new::<T>(ty.clone())).context(UnableToConvertTypeSnafu);
155        }
156
157        let buf = self.ranges[idx].clone().map(|r| &self.body[r]);
158        FromSql::from_sql_nullable(ty, buf).context(UnableToConvertBytesToTypeSnafu)
159    }
160}
161
/// A fallible iterator over the byte ranges of a composite value's fields.
///
/// Each item is `None` for a NULL field or `Some(range)`. A `range` is an offset range
/// into the complete composite body whose length was given to `CompositeTypeRanges::new`.
#[derive(Debug)]
pub struct CompositeTypeRanges<'a> {
    /// Bytes not yet parsed; advanced as each field header and value is consumed.
    buf: &'a [u8],
    /// Length of the complete body that `buf` is a suffix of. It converts positions
    /// in `buf` into absolute offsets.
    len: usize,
    /// Number of fields still expected in `buf`.
    remaining: u16,
}

impl<'a> CompositeTypeRanges<'a> {
    /// Returns a fallible iterator over the fields of the composite type.
    ///
    /// * `buf` - the per-field data, i.e. the composite body after its 4-byte
    ///   field-count header.
    /// * `len` - the length of the complete body that `buf` is a suffix of. Yielded
    ///   ranges are offsets from the start of that complete body.
    /// * `remaining` - the number of fields expected in `buf`.
    #[inline]
    #[must_use]
    pub fn new(buf: &'a [u8], len: usize, remaining: u16) -> CompositeTypeRanges<'a> {
        CompositeTypeRanges {
            buf,
            len,
            remaining,
        }
    }
}
181
182#[allow(clippy::cast_sign_loss)]
183impl FallibleIterator for CompositeTypeRanges<'_> {
184    type Item = Option<std::ops::Range<usize>>;
185    type Error = std::io::Error;
186
187    #[inline]
188    fn next(&mut self) -> std::io::Result<Option<Option<std::ops::Range<usize>>>> {
189        if self.remaining == 0 {
190            if self.buf.is_empty() {
191                return Ok(None);
192            }
193            return Err(std::io::Error::new(
194                std::io::ErrorKind::InvalidInput,
195                "invalid buffer length: compositetyperanges is not empty",
196            ));
197        }
198
199        self.remaining -= 1;
200
201        // Binary format of a composite type:
202        // [for each field]
203        //     <OID of field's type: 4 bytes>
204        //     [if value is NULL]
205        //         <-1: 4 bytes>
206        //     [else]
207        //         <length of value: 4 bytes>
208        //         <value: <length> bytes>
209        //     [end if]
210        // [end for]
211        // https://www.postgresql.org/message-id/16CCB2D3-197E-4D9F-BC6F-9B123EA0D40D%40phlo.org
212        // https://github.com/postgres/postgres/blob/29e321cdd63ea48fd0223447d58f4742ad729eb0/src/backend/utils/adt/rowtypes.c#L736
213
214        let _oid = self.buf.read_i32::<BigEndian>()?;
215        let len = self.buf.read_i32::<BigEndian>()?;
216        if len < 0 {
217            Ok(Some(None))
218        } else {
219            let len = len as usize;
220            if self.buf.len() < len {
221                return Err(std::io::Error::new(
222                    std::io::ErrorKind::UnexpectedEof,
223                    "unexpected EOF",
224                ));
225            }
226            let base = self.len - self.buf.len();
227            self.buf = &self.buf[len..];
228            Ok(Some(Some(base..base + len)))
229        }
230    }
231
232    #[inline]
233    fn size_hint(&self) -> (usize, Option<usize>) {
234        let len = self.remaining as usize;
235        (len, Some(len))
236    }
237}