// arrow_odbc/schema.rs

1use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit};
2use log::debug;
3use odbc_api::{ColumnDescription, DataType as OdbcDataType, ResultSetMetadata, sys::SqlDataType};
4use std::convert::TryInto;
5
6use crate::{ColumnFailure, Error};
7
8/// Query the metadata to create an arrow schema. This method is invoked automatically for you by
9/// [`crate::OdbcReaderBuilder::build`]. You may want to call this method in situation there you
10/// want to create an arrow schema without creating the reader yet.
11///
12/// # Parameters
13///
14/// * `result_set_metadata`: Used to query metadata about the columns in the result set, which is
15///   used to determine the arrow schema.
16/// * `dbms_name`: If provided, it is used to account for Database specific behavior than mapping
17///   types. Currently it is used to map `TIME` types from 'Microsoft SQL Server' to `Time32` or
18///   `Time64`
19/// * `map_value_errors_to_null`: In case falliable conversions should result in `NULL` the arrow
20///   field must be nullable, even if the source column on the database is not nullable.
21///
22/// # Example
23///
24/// ```
25/// use anyhow::Error;
26///
27/// use arrow_odbc::{arrow_schema_from, arrow::datatypes::Schema, odbc_api::Connection};
28///
29/// fn fetch_schema_for_table(
30///     table_name: &str,
31///     connection: &Connection<'_>
32/// ) -> Result<Schema, Error> {
33///     // Query column with values to get a cursor
34///     let sql = format!("SELECT * FROM {}", table_name);
35///     let mut prepared = connection.prepare(&sql)?;
36///
37///     // Now that we have prepared statement, we want to use it to query metadata.
38///     let map_errors_to_null = false;
39///     let dbms_name = None;
40///     let schema = arrow_schema_from(&mut prepared, dbms_name, map_errors_to_null)?;
41///     Ok(schema)
42/// }
43/// ```
44pub fn arrow_schema_from(
45    result_set_metadata: &mut impl ResultSetMetadata,
46    dbms_name: Option<&str>,
47    map_value_errors_to_null: bool,
48) -> Result<Schema, Error> {
49    let num_cols: u16 = result_set_metadata
50        .num_result_cols()
51        .map_err(Error::UnableToRetrieveNumCols)?
52        .try_into()
53        .unwrap();
54    let mut fields = Vec::new();
55    for index in 0..num_cols {
56        let field = arrow_field_from(
57            result_set_metadata,
58            dbms_name,
59            index,
60            map_value_errors_to_null,
61        )?;
62
63        fields.push(field)
64    }
65    Ok(Schema::new(fields))
66}
67
68fn arrow_field_from(
69    resut_set_metadata: &mut impl ResultSetMetadata,
70    dbms_name: Option<&str>,
71    index: u16,
72    map_value_errors_to_null: bool,
73) -> Result<Field, Error> {
74    let mut column_description = ColumnDescription::default();
75    resut_set_metadata
76        .describe_col(index + 1, &mut column_description)
77        .map_err(|cause| Error::ColumnFailure {
78            name: "Unknown".to_owned(),
79            index: index as usize,
80            source: ColumnFailure::FailedToDescribeColumn(cause),
81        })?;
82    let name = column_description
83        .name_to_string()
84        .map_err(|source| Error::EncodingInvalid { source })?;
85    debug!(
86        "ODBC driver reported for column {index}. Relational type: {:?}; Nullability: {:?}; \
87            Name: '{name}';",
88        column_description.data_type, column_description.nullability
89    );
90    let data_type = match column_description.data_type {
91        OdbcDataType::Numeric {
92            precision: p @ 0..=38,
93            scale,
94        }
95        | OdbcDataType::Decimal {
96            precision: p @ 0..=38,
97            scale,
98        } => ArrowDataType::Decimal128(p as u8, scale.try_into().unwrap()),
99        OdbcDataType::Integer => ArrowDataType::Int32,
100        OdbcDataType::SmallInt => ArrowDataType::Int16,
101        OdbcDataType::Real | OdbcDataType::Float { precision: 0..=24 } => ArrowDataType::Float32,
102        OdbcDataType::Float { precision: _ } | OdbcDataType::Double => ArrowDataType::Float64,
103        OdbcDataType::Date => ArrowDataType::Date32,
104        OdbcDataType::Timestamp { precision: 0 } => {
105            ArrowDataType::Timestamp(TimeUnit::Second, None)
106        }
107        OdbcDataType::Timestamp { precision: 1..=3 } => {
108            ArrowDataType::Timestamp(TimeUnit::Millisecond, None)
109        }
110        OdbcDataType::Timestamp { precision: 4..=6 } => {
111            ArrowDataType::Timestamp(TimeUnit::Microsecond, None)
112        }
113        OdbcDataType::Timestamp { precision: _ } => {
114            ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)
115        }
116        OdbcDataType::BigInt => ArrowDataType::Int64,
117        OdbcDataType::TinyInt => {
118            let is_unsigned = resut_set_metadata
119                .column_is_unsigned(index + 1)
120                .map_err(|e| Error::ColumnFailure {
121                    name: name.clone(),
122                    index: index as usize,
123                    source: ColumnFailure::FailedToDescribeColumn(e),
124                })?;
125            if is_unsigned {
126                ArrowDataType::UInt8
127            } else {
128                ArrowDataType::Int8
129            }
130        }
131        OdbcDataType::Bit => ArrowDataType::Boolean,
132        OdbcDataType::Binary { length } => {
133            let length = length
134                .ok_or_else(|| Error::ColumnFailure {
135                    name: name.clone(),
136                    index: index as usize,
137                    source: ColumnFailure::ZeroSizedColumn {
138                        sql_type: OdbcDataType::Binary { length },
139                    },
140                })?
141                .get()
142                .try_into()
143                .unwrap();
144            ArrowDataType::FixedSizeBinary(length)
145        }
146        OdbcDataType::LongVarbinary { length: _ } | OdbcDataType::Varbinary { length: _ } => {
147            ArrowDataType::Binary
148        }
149        OdbcDataType::Time { precision } => precision_to_time(precision),
150        OdbcDataType::Other {
151            data_type: SqlDataType(-154),
152            column_size: _,
153            decimal_digits,
154        } => {
155            if dbms_name.is_some_and(|name| name == "Microsoft SQL Server") {
156                // SQL Server's -154 is used by Microsoft SQL Server for Timestamps without a time
157                // zone.
158                precision_to_time(decimal_digits)
159            } else {
160                // Other databases may use -154 for other purposes, so we treat it as a string.
161                ArrowDataType::Utf8
162            }
163        }
164        OdbcDataType::Other {
165            data_type: SqlDataType(-98),
166            column_size: _,
167            decimal_digits: _,
168        } => {
169            // IBM DB2 names seem platform specific. E.g.; "DB2/LINUXX8664"
170            if dbms_name.is_some_and(|name| name.starts_with("DB2/")) {
171                // IBM DB2's -98 is used for binary blob types.
172                ArrowDataType::Binary
173            } else {
174                // Other databases may use -98 for other purposes, so we treat it as a string.
175                ArrowDataType::Utf8
176            }
177        }
178        OdbcDataType::Unknown
179        | OdbcDataType::Numeric { .. }
180        | OdbcDataType::Decimal { .. }
181        | OdbcDataType::Other {
182            data_type: _,
183            column_size: _,
184            decimal_digits: _,
185        }
186        | OdbcDataType::WChar { length: _ }
187        | OdbcDataType::Char { length: _ }
188        | OdbcDataType::WVarchar { length: _ }
189        | OdbcDataType::WLongVarchar { length: _ }
190        | OdbcDataType::LongVarchar { length: _ }
191        | OdbcDataType::Varchar { length: _ } => ArrowDataType::Utf8,
192    };
193    let is_falliable = matches!(data_type, ArrowDataType::Timestamp(TimeUnit::Nanosecond, _));
194    let nullable =
195        column_description.could_be_nullable() || (is_falliable && map_value_errors_to_null);
196    let field = Field::new(name, data_type, nullable);
197    Ok(field)
198}
199
200fn precision_to_time(precision: i16) -> ArrowDataType {
201    match precision {
202        0 => ArrowDataType::Time32(TimeUnit::Second),
203        1..=3 => ArrowDataType::Time32(TimeUnit::Millisecond),
204        4..=6 => ArrowDataType::Time64(TimeUnit::Microsecond),
205        7..=9 => ArrowDataType::Time64(TimeUnit::Nanosecond),
206        _ => ArrowDataType::Utf8,
207    }
208}