Skip to main content

sqlx_odbc/
any.rs

//! Runtime `Any` driver support for ODBC.
2
3use crate::{
4    connection::OdbcExecution, DataTypeExt, Odbc, OdbcArgumentValue, OdbcArguments, OdbcColumn,
5    OdbcConnectOptions, OdbcConnection, OdbcQueryResult, OdbcTransactionManager, OdbcTypeInfo,
6};
7use futures_core::future::BoxFuture;
8use futures_core::stream::BoxStream;
9use futures_util::{future, stream, FutureExt, StreamExt};
10use sqlx_core::any::driver::AnyDriver;
11use sqlx_core::any::{
12    AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow,
13    AnyStatement, AnyTypeInfo, AnyTypeInfoKind, AnyValueKind,
14};
15use sqlx_core::column::Column;
16use sqlx_core::connection::{ConnectOptions, Connection};
17use sqlx_core::database::Database;
18use sqlx_core::ext::ustr::UStr;
19use sqlx_core::row::Row;
20use sqlx_core::sql_str::SqlStr;
21use sqlx_core::statement::Statement;
22use sqlx_core::transaction::TransactionManager;
23use sqlx_core::{Either, HashMap};
24use std::sync::Arc;
25
26/// Installable ODBC driver for SQLx `Any` connections.
27pub const DRIVER: AnyDriver = AnyDriver::without_migrate::<Odbc>();
28
29impl AnyConnectionBackend for OdbcConnection {
30    fn name(&self) -> &str {
31        <Odbc as Database>::NAME
32    }
33
34    fn close(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
35        Connection::close(*self).boxed()
36    }
37
38    fn close_hard(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
39        Connection::close_hard(*self).boxed()
40    }
41
42    fn ping(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
43        Connection::ping(self).boxed()
44    }
45
46    fn begin(&mut self, statement: Option<SqlStr>) -> BoxFuture<'_, sqlx_core::Result<()>> {
47        OdbcTransactionManager::begin(self, statement).boxed()
48    }
49
50    fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
51        OdbcTransactionManager::commit(self).boxed()
52    }
53
54    fn rollback(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
55        OdbcTransactionManager::rollback(self).boxed()
56    }
57
58    fn start_rollback(&mut self) {
59        OdbcTransactionManager::start_rollback(self);
60    }
61
62    fn get_transaction_depth(&self) -> usize {
63        OdbcTransactionManager::get_transaction_depth(self)
64    }
65
66    fn shrink_buffers(&mut self) {
67        Connection::shrink_buffers(self);
68    }
69
70    fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
71        Connection::flush(self).boxed()
72    }
73
74    fn should_flush(&self) -> bool {
75        Connection::should_flush(self)
76    }
77
78    fn fetch_many(
79        &mut self,
80        query: SqlStr,
81        _persistent: bool,
82        arguments: Option<AnyArguments>,
83    ) -> BoxStream<'_, sqlx_core::Result<Either<AnyQueryResult, AnyRow>>> {
84        let arguments = arguments.map(map_arguments).transpose();
85
86        stream::once(async move {
87            let arguments = arguments?;
88            self.run_blocking_sql(query.as_str(), arguments.as_ref())
89        })
90        .map(|result| match result {
91            Ok(OdbcExecution::Done(result)) => {
92                stream::once(future::ready(Ok(Either::Left(map_result(result))))).boxed()
93            }
94            Ok(OdbcExecution::Rows(rows)) => {
95                if rows.is_empty() {
96                    stream::once(future::ready(Ok(Either::Left(map_result(
97                        OdbcQueryResult::new(0),
98                    )))))
99                    .boxed()
100                } else {
101                    let column_names =
102                        column_names(rows.first().expect("rows is not empty").columns());
103                    let rows = rows.into_iter().map(move |row| {
104                        AnyRow::map_from(&row, Arc::clone(&column_names)).map(Either::Right)
105                    });
106                    let done =
107                        std::iter::once(Ok(Either::Left(map_result(OdbcQueryResult::new(0)))));
108                    stream::iter(rows.chain(done)).boxed()
109                }
110            }
111            Err(error) => stream::once(future::ready(Err(error))).boxed(),
112        })
113        .flatten()
114        .boxed()
115    }
116
117    fn fetch_optional(
118        &mut self,
119        query: SqlStr,
120        _persistent: bool,
121        arguments: Option<AnyArguments>,
122    ) -> BoxFuture<'_, sqlx_core::Result<Option<AnyRow>>> {
123        let arguments = arguments.map(map_arguments).transpose();
124
125        Box::pin(async move {
126            let arguments = arguments?;
127            match self.run_blocking_sql(query.as_str(), arguments.as_ref())? {
128                OdbcExecution::Done(_) => Ok(None),
129                OdbcExecution::Rows(rows) => rows
130                    .into_iter()
131                    .next()
132                    .map(|row| {
133                        let column_names = column_names(row.columns());
134                        AnyRow::map_from(&row, column_names)
135                    })
136                    .transpose(),
137            }
138        })
139    }
140
141    fn prepare_with<'c, 'q: 'c>(
142        &'c mut self,
143        sql: SqlStr,
144        _parameters: &[AnyTypeInfo],
145    ) -> BoxFuture<'c, sqlx_core::Result<AnyStatement>> {
146        Box::pin(async move {
147            let statement = self.prepare_blocking(sql)?;
148            let column_names = column_names(statement.columns());
149            AnyStatement::try_from_statement(statement, column_names)
150        })
151    }
152}
153
154impl<'a> TryFrom<&'a AnyConnectOptions> for OdbcConnectOptions {
155    type Error = sqlx_core::Error;
156
157    fn try_from(options: &'a AnyConnectOptions) -> Result<Self, Self::Error> {
158        let mut options_out = OdbcConnectOptions::from_url(&options.database_url)?;
159        options_out.log_statements = options.log_settings.statements_level;
160        options_out.log_slow_statements = options.log_settings.slow_statements_level;
161        options_out.log_slow_statement_duration = options.log_settings.slow_statements_duration;
162        Ok(options_out)
163    }
164}
165
166impl<'a> TryFrom<&'a OdbcTypeInfo> for AnyTypeInfo {
167    type Error = sqlx_core::Error;
168
169    fn try_from(type_info: &'a OdbcTypeInfo) -> Result<Self, Self::Error> {
170        let kind = match type_info.data_type() {
171            odbc_api::DataType::Unknown => AnyTypeInfoKind::Null,
172            odbc_api::DataType::Bit => AnyTypeInfoKind::Bool,
173            odbc_api::DataType::TinyInt | odbc_api::DataType::SmallInt => AnyTypeInfoKind::SmallInt,
174            odbc_api::DataType::Integer => AnyTypeInfoKind::Integer,
175            odbc_api::DataType::BigInt => AnyTypeInfoKind::BigInt,
176            odbc_api::DataType::Real => AnyTypeInfoKind::Real,
177            odbc_api::DataType::Float { .. } | odbc_api::DataType::Double => {
178                AnyTypeInfoKind::Double
179            }
180            data_type if data_type.accepts_character_data() => AnyTypeInfoKind::Text,
181            data_type if data_type.accepts_binary_data() => AnyTypeInfoKind::Blob,
182            data_type => {
183                return Err(sqlx_core::Error::AnyDriverError(
184                    format!(
185                        "ODBC Any conversion does not support result column type {data_type:?}"
186                    )
187                    .into(),
188                ));
189            }
190        };
191
192        Ok(AnyTypeInfo { kind })
193    }
194}
195
196impl<'a> TryFrom<&'a OdbcColumn> for AnyColumn {
197    type Error = sqlx_core::Error;
198
199    fn try_from(column: &'a OdbcColumn) -> Result<Self, Self::Error> {
200        let type_info = AnyTypeInfo::try_from(column.type_info()).map_err(|error| {
201            sqlx_core::Error::ColumnDecode {
202                index: column.name().to_owned(),
203                source: error.into(),
204            }
205        })?;
206
207        Ok(Self {
208            ordinal: column.ordinal(),
209            name: UStr::new(column.name()),
210            type_info,
211        })
212    }
213}
214
215fn map_arguments(arguments: AnyArguments) -> sqlx_core::Result<OdbcArguments> {
216    let mut out = OdbcArguments::default();
217
218    for value in arguments.values.0 {
219        out.add_value(match value {
220            AnyValueKind::Null(kind) => OdbcArgumentValue::Null(any_type_to_odbc(kind)),
221            AnyValueKind::Bool(value) => OdbcArgumentValue::Bit(value),
222            AnyValueKind::SmallInt(value) => OdbcArgumentValue::Int(i64::from(value)),
223            AnyValueKind::Integer(value) => OdbcArgumentValue::Int(i64::from(value)),
224            AnyValueKind::BigInt(value) => OdbcArgumentValue::Int(value),
225            AnyValueKind::Real(value) => OdbcArgumentValue::Float(f64::from(value)),
226            AnyValueKind::Double(value) => OdbcArgumentValue::Float(value),
227            AnyValueKind::Text(value) => OdbcArgumentValue::Text(value.to_string()),
228            AnyValueKind::TextSlice(value) => OdbcArgumentValue::Text(value.to_string()),
229            AnyValueKind::Blob(value) => OdbcArgumentValue::Bytes(value.to_vec()),
230            other => {
231                return Err(sqlx_core::Error::AnyDriverError(
232                    format!("ODBC Any arguments do not support value kind {other:?}").into(),
233                ))
234            }
235        });
236    }
237
238    Ok(out)
239}
240
241fn any_type_to_odbc(kind: AnyTypeInfoKind) -> OdbcTypeInfo {
242    OdbcTypeInfo::new(match kind {
243        AnyTypeInfoKind::Null => odbc_api::DataType::Unknown,
244        AnyTypeInfoKind::Bool => odbc_api::DataType::Bit,
245        AnyTypeInfoKind::SmallInt => odbc_api::DataType::SmallInt,
246        AnyTypeInfoKind::Integer => odbc_api::DataType::Integer,
247        AnyTypeInfoKind::BigInt => odbc_api::DataType::BigInt,
248        AnyTypeInfoKind::Real => odbc_api::DataType::Real,
249        AnyTypeInfoKind::Double => odbc_api::DataType::Double,
250        AnyTypeInfoKind::Text => odbc_api::DataType::WVarchar { length: None },
251        AnyTypeInfoKind::Blob => odbc_api::DataType::Varbinary { length: None },
252    })
253}
254
255fn map_result(result: OdbcQueryResult) -> AnyQueryResult {
256    AnyQueryResult {
257        rows_affected: result.rows_affected(),
258        last_insert_id: None,
259    }
260}
261
262fn column_names(columns: &[OdbcColumn]) -> Arc<HashMap<UStr, usize>> {
263    Arc::new(
264        columns
265            .iter()
266            .map(|column| (UStr::new(column.name()), column.ordinal()))
267            .collect(),
268    )
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use odbc_api::DataType;

    /// Converts a bare ODBC data type into its `Any` kind, panicking on failure.
    fn any_kind(data_type: DataType) -> AnyTypeInfoKind {
        AnyTypeInfo::try_from(&OdbcTypeInfo::new(data_type))
            .expect("data type should map to an Any kind")
            .kind()
    }

    /// Every explicitly matched ODBC type, plus one character and one binary
    /// representative, maps to the documented `Any` kind.
    #[test]
    fn maps_stable_odbc_types_to_any_types() {
        let cases = [
            (DataType::Unknown, AnyTypeInfoKind::Null),
            (DataType::Bit, AnyTypeInfoKind::Bool),
            (DataType::TinyInt, AnyTypeInfoKind::SmallInt),
            (DataType::SmallInt, AnyTypeInfoKind::SmallInt),
            (DataType::Integer, AnyTypeInfoKind::Integer),
            (DataType::BigInt, AnyTypeInfoKind::BigInt),
            (DataType::Real, AnyTypeInfoKind::Real),
            (DataType::Float { precision: 53 }, AnyTypeInfoKind::Double),
            (DataType::Double, AnyTypeInfoKind::Double),
            (DataType::WVarchar { length: None }, AnyTypeInfoKind::Text),
            (DataType::Varbinary { length: None }, AnyTypeInfoKind::Blob),
        ];

        for (data_type, expected) in cases {
            assert_eq!(any_kind(data_type), expected, "mapping {data_type:?}");
        }
    }

    /// The ODBC type chosen for a typed `NULL` argument must map back to the
    /// same `Any` kind, so binds and result columns agree.
    #[test]
    fn null_argument_types_round_trip_through_any_mapping() {
        let kinds = [
            AnyTypeInfoKind::Null,
            AnyTypeInfoKind::Bool,
            AnyTypeInfoKind::SmallInt,
            AnyTypeInfoKind::Integer,
            AnyTypeInfoKind::BigInt,
            AnyTypeInfoKind::Real,
            AnyTypeInfoKind::Double,
            AnyTypeInfoKind::Text,
            AnyTypeInfoKind::Blob,
        ];

        for kind in kinds {
            let odbc = any_type_to_odbc(kind);
            let back = AnyTypeInfo::try_from(&odbc)
                .expect("ODBC type chosen for Any kind should map back")
                .kind();
            assert_eq!(back, kind);
        }
    }

    /// Types without a stable `Any` equivalent are rejected with a driver error.
    #[test]
    fn rejects_unstable_odbc_types_for_any_mapping() {
        let result =
            AnyTypeInfo::try_from(&OdbcTypeInfo::new(DataType::Timestamp { precision: 6 }));

        assert!(matches!(result, Err(sqlx_core::Error::AnyDriverError(_))));
    }
}