1use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit};
2use log::debug;
3use odbc_api::{ColumnDescription, DataType as OdbcDataType, ResultSetMetadata, sys::SqlDataType};
4use std::convert::TryInto;
5
6use crate::{ColumnFailure, Error};
7
8pub fn arrow_schema_from(
45 result_set_metadata: &mut impl ResultSetMetadata,
46 dbms_name: Option<&str>,
47 map_value_errors_to_null: bool,
48) -> Result<Schema, Error> {
49 let num_cols: u16 = result_set_metadata
50 .num_result_cols()
51 .map_err(Error::UnableToRetrieveNumCols)?
52 .try_into()
53 .unwrap();
54 let mut fields = Vec::new();
55 for index in 0..num_cols {
56 let field = arrow_field_from(
57 result_set_metadata,
58 dbms_name,
59 index,
60 map_value_errors_to_null,
61 )?;
62
63 fields.push(field)
64 }
65 Ok(Schema::new(fields))
66}
67
68fn arrow_field_from(
69 resut_set_metadata: &mut impl ResultSetMetadata,
70 dbms_name: Option<&str>,
71 index: u16,
72 map_value_errors_to_null: bool,
73) -> Result<Field, Error> {
74 let mut column_description = ColumnDescription::default();
75 resut_set_metadata
76 .describe_col(index + 1, &mut column_description)
77 .map_err(|cause| Error::ColumnFailure {
78 name: "Unknown".to_owned(),
79 index: index as usize,
80 source: ColumnFailure::FailedToDescribeColumn(cause),
81 })?;
82 let name = column_description
83 .name_to_string()
84 .map_err(|source| Error::EncodingInvalid { source })?;
85 debug!(
86 "ODBC driver reported for column {index}. Relational type: {:?}; Nullability: {:?}; \
87 Name: '{name}';",
88 column_description.data_type, column_description.nullability
89 );
90 let data_type = match column_description.data_type {
91 OdbcDataType::Numeric {
92 precision: p @ 0..=38,
93 scale,
94 }
95 | OdbcDataType::Decimal {
96 precision: p @ 0..=38,
97 scale,
98 } => ArrowDataType::Decimal128(p as u8, scale.try_into().unwrap()),
99 OdbcDataType::Integer => ArrowDataType::Int32,
100 OdbcDataType::SmallInt => ArrowDataType::Int16,
101 OdbcDataType::Real | OdbcDataType::Float { precision: 0..=24 } => ArrowDataType::Float32,
102 OdbcDataType::Float { precision: _ } | OdbcDataType::Double => ArrowDataType::Float64,
103 OdbcDataType::Date => ArrowDataType::Date32,
104 OdbcDataType::Timestamp { precision: 0 } => {
105 ArrowDataType::Timestamp(TimeUnit::Second, None)
106 }
107 OdbcDataType::Timestamp { precision: 1..=3 } => {
108 ArrowDataType::Timestamp(TimeUnit::Millisecond, None)
109 }
110 OdbcDataType::Timestamp { precision: 4..=6 } => {
111 ArrowDataType::Timestamp(TimeUnit::Microsecond, None)
112 }
113 OdbcDataType::Timestamp { precision: _ } => {
114 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)
115 }
116 OdbcDataType::BigInt => ArrowDataType::Int64,
117 OdbcDataType::TinyInt => {
118 let is_unsigned = resut_set_metadata
119 .column_is_unsigned(index + 1)
120 .map_err(|e| Error::ColumnFailure {
121 name: name.clone(),
122 index: index as usize,
123 source: ColumnFailure::FailedToDescribeColumn(e),
124 })?;
125 if is_unsigned {
126 ArrowDataType::UInt8
127 } else {
128 ArrowDataType::Int8
129 }
130 }
131 OdbcDataType::Bit => ArrowDataType::Boolean,
132 OdbcDataType::Binary { length } => {
133 let length = length
134 .ok_or_else(|| Error::ColumnFailure {
135 name: name.clone(),
136 index: index as usize,
137 source: ColumnFailure::ZeroSizedColumn {
138 sql_type: OdbcDataType::Binary { length },
139 },
140 })?
141 .get()
142 .try_into()
143 .unwrap();
144 ArrowDataType::FixedSizeBinary(length)
145 }
146 OdbcDataType::LongVarbinary { length: _ } | OdbcDataType::Varbinary { length: _ } => {
147 ArrowDataType::Binary
148 }
149 OdbcDataType::Time { precision } => precision_to_time(precision),
150 OdbcDataType::Other {
151 data_type: SqlDataType(-154),
152 column_size: _,
153 decimal_digits,
154 } => {
155 if dbms_name.is_some_and(|name| name == "Microsoft SQL Server") {
156 precision_to_time(decimal_digits)
159 } else {
160 ArrowDataType::Utf8
162 }
163 }
164 OdbcDataType::Other {
165 data_type: SqlDataType(-98),
166 column_size: _,
167 decimal_digits: _,
168 } => {
169 if dbms_name.is_some_and(|name| name.starts_with("DB2/")) {
171 ArrowDataType::Binary
173 } else {
174 ArrowDataType::Utf8
176 }
177 }
178 OdbcDataType::Unknown
179 | OdbcDataType::Numeric { .. }
180 | OdbcDataType::Decimal { .. }
181 | OdbcDataType::Other {
182 data_type: _,
183 column_size: _,
184 decimal_digits: _,
185 }
186 | OdbcDataType::WChar { length: _ }
187 | OdbcDataType::Char { length: _ }
188 | OdbcDataType::WVarchar { length: _ }
189 | OdbcDataType::WLongVarchar { length: _ }
190 | OdbcDataType::LongVarchar { length: _ }
191 | OdbcDataType::Varchar { length: _ } => ArrowDataType::Utf8,
192 };
193 let is_falliable = matches!(data_type, ArrowDataType::Timestamp(TimeUnit::Nanosecond, _));
194 let nullable =
195 column_description.could_be_nullable() || (is_falliable && map_value_errors_to_null);
196 let field = Field::new(name, data_type, nullable);
197 Ok(field)
198}
199
200fn precision_to_time(precision: i16) -> ArrowDataType {
201 match precision {
202 0 => ArrowDataType::Time32(TimeUnit::Second),
203 1..=3 => ArrowDataType::Time32(TimeUnit::Millisecond),
204 4..=6 => ArrowDataType::Time64(TimeUnit::Microsecond),
205 7..=9 => ArrowDataType::Time64(TimeUnit::Nanosecond),
206 _ => ArrowDataType::Utf8,
207 }
208}