Skip to main content

sqlx_odbc/
connection.rs

1use crate::{
2    OdbcArguments, OdbcBufferSettings, OdbcColumn, OdbcConnectOptions, OdbcParameterCollection,
3    OdbcQueryResult, OdbcRow, OdbcStatement, OdbcTypeInfo, OdbcValue, OdbcValueKind, Result,
4};
5use futures_core::future::BoxFuture;
6use futures_core::stream::BoxStream;
7use futures_util::{future, stream, StreamExt};
8use odbc_api::buffers::{AnyColumnBufferSlice, BufferDesc, ColumnarDynBuffer, NullableSlice};
9use odbc_api::{Cursor, DataType, Nullable, ResultSetMetadata};
10use sqlx_core::column::Column;
11use sqlx_core::executor::{Execute, Executor};
12use sqlx_core::transaction::Transaction;
13use sqlx_core::Either;
14use std::future::Future;
15
16/// Blocking ODBC connection wrapper.
17///
18/// This is the minimal smoke-test surface. The SQLx async `Connection` and `Executor` traits will
19/// be implemented as the port progresses.
20pub struct OdbcConnection {
21    conn: odbc_api::Connection<'static>,
22    buffer_settings: OdbcBufferSettings,
23    transaction_depth: usize,
24}
25
26impl std::fmt::Debug for OdbcConnection {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("OdbcConnection").finish_non_exhaustive()
29    }
30}
31
32impl OdbcConnection {
33    /// Opens a blocking ODBC connection with the provided options.
34    pub fn connect_blocking(options: &OdbcConnectOptions) -> Result<Self> {
35        let env = odbc_api::environment()
36            .map_err(|error| crate::OdbcError::Configuration(error.to_string()))?;
37
38        let conn =
39            env.connect_with_connection_string(options.connection_string(), Default::default())?;
40
41        Ok(Self {
42            conn,
43            buffer_settings: options.buffer_settings,
44            transaction_depth: 0,
45        })
46    }
47
48    /// Executes a minimal connectivity query.
49    pub fn ping_blocking(&mut self) -> Result<()> {
50        let query = self
51            .conn
52            .database_management_system_name()
53            .map(|name| ping_query_for_dbms_name(&name))
54            .unwrap_or("SELECT 1");
55        self.conn.execute(query, (), None)?;
56        Ok(())
57    }
58
59    /// Returns the DBMS name reported by the ODBC driver.
60    pub fn dbms_name(&self) -> Result<String> {
61        Ok(self.conn.database_management_system_name()?)
62    }
63
64    pub(crate) fn begin_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
65        if self.transaction_depth > 0 {
66            return Err(sqlx_core::Error::InvalidSavePointStatement);
67        }
68
69        self.conn
70            .set_autocommit(false)
71            .map_err(crate::OdbcError::from)?;
72        self.transaction_depth = 1;
73        Ok(())
74    }
75
76    pub(crate) fn commit_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
77        if self.transaction_depth == 0 {
78            return Ok(());
79        }
80
81        self.conn.commit().map_err(crate::OdbcError::from)?;
82        self.conn
83            .set_autocommit(true)
84            .map_err(crate::OdbcError::from)?;
85        self.transaction_depth = 0;
86        Ok(())
87    }
88
89    pub(crate) fn rollback_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
90        if self.transaction_depth == 0 {
91            return Ok(());
92        }
93
94        self.conn.rollback().map_err(crate::OdbcError::from)?;
95        self.conn
96            .set_autocommit(true)
97            .map_err(crate::OdbcError::from)?;
98        self.transaction_depth = 0;
99        Ok(())
100    }
101
102    pub(crate) fn start_rollback(&mut self) {
103        if self.transaction_depth == 0 {
104            return;
105        }
106
107        if self.conn.rollback().is_ok() {
108            let _ = self.conn.set_autocommit(true);
109            self.transaction_depth = 0;
110        }
111    }
112
113    pub(crate) const fn transaction_depth(&self) -> usize {
114        self.transaction_depth
115    }
116
117    /// Prepares a statement and returns the metadata reported by the ODBC driver.
118    pub fn prepare_blocking(
119        &mut self,
120        sql: sqlx_core::sql_str::SqlStr,
121    ) -> std::result::Result<OdbcStatement, sqlx_core::Error> {
122        let mut prepared = self
123            .conn
124            .prepare(sql.as_str())
125            .map_err(crate::OdbcError::from)?;
126        let parameters = prepared.num_params().map_err(crate::OdbcError::from)?;
127        let columns = collect_prepared_columns(&mut prepared, parameters)?;
128
129        Ok(OdbcStatement::new(sql, columns, usize::from(parameters)))
130    }
131
132    pub(crate) fn run_blocking_sql(
133        &mut self,
134        sql: &str,
135        arguments: Option<&OdbcArguments>,
136    ) -> std::result::Result<OdbcExecution, sqlx_core::Error> {
137        let mut statement = self.conn.preallocate().map_err(crate::OdbcError::from)?;
138        let parameters = odbc_parameters(arguments);
139
140        if let Some(cursor) = statement
141            .execute(sql, parameters.as_slice())
142            .map_err(crate::OdbcError::from)?
143        {
144            return collect_rows(cursor, self.buffer_settings).map(OdbcExecution::Rows);
145        }
146
147        let rows_affected = statement
148            .row_count()
149            .map_err(crate::OdbcError::from)?
150            .unwrap_or(0);
151
152        let rows_affected = rows_affected.try_into().map_err(|_| {
153            sqlx_core::Error::Protocol("ODBC row count does not fit in u64".to_owned())
154        })?;
155
156        Ok(OdbcExecution::Done(OdbcQueryResult::new(rows_affected)))
157    }
158}
159
160impl sqlx_core::connection::Connection for OdbcConnection {
161    type Database = crate::Odbc;
162    type Options = OdbcConnectOptions;
163
164    async fn close(self) -> std::result::Result<(), sqlx_core::Error> {
165        drop(self);
166        Ok(())
167    }
168
169    async fn close_hard(self) -> std::result::Result<(), sqlx_core::Error> {
170        drop(self);
171        Ok(())
172    }
173
174    async fn ping(&mut self) -> std::result::Result<(), sqlx_core::Error> {
175        self.ping_blocking().map_err(Into::into)
176    }
177
178    fn begin(
179        &mut self,
180    ) -> impl Future<Output = std::result::Result<Transaction<'_, Self::Database>, sqlx_core::Error>>
181           + Send
182           + '_ {
183        Transaction::begin(self, None)
184    }
185
186    fn shrink_buffers(&mut self) {}
187
188    async fn flush(&mut self) -> std::result::Result<(), sqlx_core::Error> {
189        Ok(())
190    }
191
192    fn should_flush(&self) -> bool {
193        false
194    }
195}
196
197impl<'c> Executor<'c> for &'c mut OdbcConnection {
198    type Database = crate::Odbc;
199
200    fn fetch_many<'e, 'q, E>(
201        self,
202        mut query: E,
203    ) -> BoxStream<'e, std::result::Result<Either<OdbcQueryResult, OdbcRow>, sqlx_core::Error>>
204    where
205        'c: 'e,
206        E: Execute<'q, Self::Database>,
207        'q: 'e,
208        E: 'q,
209    {
210        let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
211        let sql = query.sql();
212
213        stream::once(async move {
214            let arguments = arguments?;
215            self.run_blocking_sql(sql.as_str(), arguments.as_ref())
216        })
217        .map(|result| match result {
218            Ok(OdbcExecution::Done(result)) => {
219                stream::once(future::ready(Ok(Either::Left(result)))).boxed()
220            }
221            Ok(OdbcExecution::Rows(rows)) => {
222                stream::iter(rows.into_iter().map(|row| Ok(Either::Right(row)))).boxed()
223            }
224            Err(error) => stream::once(future::ready(Err(error))).boxed(),
225        })
226        .flatten()
227        .boxed()
228    }
229
230    fn fetch_optional<'e, 'q, E>(
231        self,
232        mut query: E,
233    ) -> BoxFuture<'e, std::result::Result<Option<OdbcRow>, sqlx_core::Error>>
234    where
235        'c: 'e,
236        E: Execute<'q, Self::Database>,
237        'q: 'e,
238        E: 'q,
239    {
240        let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
241        let sql = query.sql();
242
243        Box::pin(async move {
244            let arguments = arguments?;
245
246            match self.run_blocking_sql(sql.as_str(), arguments.as_ref())? {
247                OdbcExecution::Rows(rows) => Ok(rows.into_iter().next()),
248                OdbcExecution::Done(_) => Ok(None),
249            }
250        })
251    }
252
253    fn prepare_with<'e>(
254        self,
255        sql: sqlx_core::sql_str::SqlStr,
256        _parameters: &[crate::OdbcTypeInfo],
257    ) -> BoxFuture<'e, std::result::Result<OdbcStatement, sqlx_core::Error>>
258    where
259        'c: 'e,
260    {
261        Box::pin(async move { self.prepare_blocking(sql) })
262    }
263}
264
265pub(crate) enum OdbcExecution {
266    Done(OdbcQueryResult),
267    Rows(Vec<OdbcRow>),
268}
269
270fn odbc_parameters(arguments: Option<&OdbcArguments>) -> OdbcParameterCollection {
271    arguments
272        .map(OdbcArguments::to_odbc_parameter_collection)
273        .unwrap_or_default()
274}
275
/// Chooses the connectivity-check statement from the DBMS name reported by the driver.
///
/// IBM DB2-family servers get `SELECT 1 FROM SYSIBM.SYSDUMMY1`; every other server gets
/// `SELECT 1`. A name counts as DB2-family when it contains `DB2`, `DB/2`, `ISERIES` or
/// `AS/400`, or the words `IBM i`. Matching ignores ASCII case.
///
/// `IBM i` is matched as two whole words. This keeps longer IBM product names that merely start
/// with `IBM I…` (e.g. `IBM Informix`, `IBM IMS`) from being classified as DB2.
fn ping_query_for_dbms_name(dbms_name: &str) -> &'static str {
    const DB2_PING: &str = "SELECT 1 FROM SYSIBM.SYSDUMMY1";
    const GENERIC_PING: &str = "SELECT 1";
    // Markers that identify a DB2-family server wherever they appear in the name.
    const DB2_MARKERS: [&str; 4] = ["DB2", "DB/2", "ISERIES", "AS/400"];

    let upper = dbms_name.to_ascii_uppercase();
    if DB2_MARKERS.iter().any(|marker| upper.contains(*marker)) {
        return DB2_PING;
    }

    let words: Vec<&str> = upper
        .split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|word| !word.is_empty())
        .collect();
    if words.windows(2).any(|pair| matches!(pair, ["IBM", "I"])) {
        DB2_PING
    } else {
        GENERIC_PING
    }
}
290
291fn collect_columns(
292    cursor: &mut impl ResultSetMetadata,
293) -> std::result::Result<Vec<OdbcColumn>, sqlx_core::Error> {
294    let count = cursor.num_result_cols().map_err(crate::OdbcError::from)?;
295    let count = usize::try_from(count).map_err(|_| {
296        sqlx_core::Error::Protocol(format!("ODBC returned a negative column count: {count}"))
297    })?;
298
299    let mut columns = Vec::with_capacity(count);
300    for ordinal in 0..count {
301        let column_number = u16::try_from(ordinal + 1).map_err(|_| {
302            sqlx_core::Error::Protocol(format!("ODBC column index exceeds u16: {}", ordinal + 1))
303        })?;
304
305        let mut description = odbc_api::ColumnDescription::default();
306        cursor
307            .describe_col(column_number, &mut description)
308            .map_err(crate::OdbcError::from)?;
309        let name = description
310            .name_to_string()
311            .unwrap_or_else(|_| format!("col{ordinal}"));
312
313        columns.push(OdbcColumn::new(
314            ordinal,
315            name,
316            OdbcTypeInfo::new(description.data_type),
317        ));
318    }
319
320    Ok(columns)
321}
322
323fn collect_prepared_columns(
324    prepared: &mut impl PreparedStatementMetadata,
325    parameter_count: u16,
326) -> std::result::Result<Vec<OdbcColumn>, sqlx_core::Error> {
327    match collect_columns(prepared) {
328        Ok(columns) => Ok(columns),
329        Err(error) if parameter_count > 0 => {
330            validate_parameter_metadata(prepared, parameter_count)?;
331            log::debug!("ODBC driver deferred result-column metadata until execution: {error}");
332            Ok(Vec::new())
333        }
334        Err(error) => Err(error),
335    }
336}
337
338trait PreparedStatementMetadata: ResultSetMetadata {
339    fn describe_prepared_parameter(
340        &mut self,
341        index: u16,
342    ) -> std::result::Result<(), odbc_api::Error>;
343}
344
345impl<S> PreparedStatementMetadata for odbc_api::Prepared<S>
346where
347    S: odbc_api::handles::AsStatementRef,
348{
349    fn describe_prepared_parameter(
350        &mut self,
351        index: u16,
352    ) -> std::result::Result<(), odbc_api::Error> {
353        self.describe_param(index).map(|_| ())
354    }
355}
356
357fn validate_parameter_metadata(
358    prepared: &mut impl PreparedStatementMetadata,
359    parameter_count: u16,
360) -> std::result::Result<(), sqlx_core::Error> {
361    for index in 1..=parameter_count {
362        prepared
363            .describe_prepared_parameter(index)
364            .map_err(crate::OdbcError::from)?;
365    }
366
367    Ok(())
368}
369
370fn collect_rows<C>(
371    cursor: C,
372    settings: OdbcBufferSettings,
373) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
374where
375    C: Cursor + ResultSetMetadata,
376{
377    if let Some(max_column_size) = settings.max_column_size {
378        collect_rows_buffered(cursor, settings.batch_size, max_column_size)
379    } else {
380        collect_rows_unbuffered(cursor)
381    }
382}
383
/// A result-set column paired with the buffer description used to bind it for block fetching.
#[derive(Debug)]
struct ColumnBinding {
    /// Column metadata as described by the driver.
    column: OdbcColumn,
    /// Buffer layout chosen for this column by `map_buffer_desc`.
    buffer_desc: BufferDesc,
}
389
390fn collect_rows_buffered<C>(
391    cursor: C,
392    batch_size: usize,
393    max_column_size: usize,
394) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
395where
396    C: Cursor + ResultSetMetadata,
397{
398    let mut cursor = cursor;
399    let bindings = build_buffer_bindings(&mut cursor, max_column_size)?;
400    let buffer_descriptions = bindings
401        .iter()
402        .map(|binding| binding.buffer_desc)
403        .collect::<Vec<_>>();
404    let mut row_set_cursor = cursor
405        .bind_buffer(ColumnarDynBuffer::from_descs(
406            batch_size,
407            buffer_descriptions,
408        ))
409        .map_err(crate::OdbcError::from)?;
410    let columns = bindings
411        .iter()
412        .map(|binding| binding.column.clone())
413        .collect::<Vec<_>>();
414    let mut rows = Vec::new();
415
416    while let Some(batch) = row_set_cursor.fetch().map_err(crate::OdbcError::from)? {
417        let column_values = bindings
418            .iter()
419            .enumerate()
420            .map(|(index, binding)| {
421                buffered_column_values(batch.column(index), binding.buffer_desc)
422            })
423            .collect::<std::result::Result<Vec<_>, _>>()?;
424
425        for row_index in 0..batch.num_rows() {
426            let values = column_values
427                .iter()
428                .map(|values| OdbcValue::new(values[row_index].clone()))
429                .collect::<Vec<_>>();
430            rows.push(OdbcRow::new(columns.clone(), values));
431        }
432    }
433
434    Ok(rows)
435}
436
437fn build_buffer_bindings(
438    cursor: &mut impl ResultSetMetadata,
439    max_column_size: usize,
440) -> std::result::Result<Vec<ColumnBinding>, sqlx_core::Error> {
441    collect_columns(cursor).map(|columns| {
442        columns
443            .into_iter()
444            .map(|column| ColumnBinding {
445                buffer_desc: map_buffer_desc(column.type_info().data_type(), max_column_size),
446                column,
447            })
448            .collect()
449    })
450}
451
452fn map_buffer_desc(data_type: DataType, max_column_size: usize) -> BufferDesc {
453    match data_type {
454        DataType::TinyInt | DataType::SmallInt | DataType::Integer | DataType::BigInt => {
455            BufferDesc::I64 { nullable: true }
456        }
457        DataType::Real => BufferDesc::F32 { nullable: true },
458        DataType::Float { .. } | DataType::Double => BufferDesc::F64 { nullable: true },
459        DataType::Bit => BufferDesc::Bit { nullable: true },
460        DataType::Date => BufferDesc::Date { nullable: true },
461        DataType::Time { .. } => BufferDesc::Time { nullable: true },
462        DataType::Timestamp { .. } => BufferDesc::Timestamp { nullable: true },
463        DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
464            BufferDesc::Binary {
465                max_bytes: max_column_size,
466            }
467        }
468        DataType::Char { .. }
469        | DataType::WChar { .. }
470        | DataType::Varchar { .. }
471        | DataType::WVarchar { .. }
472        | DataType::LongVarchar { .. }
473        | DataType::WLongVarchar { .. }
474        | DataType::Other { .. }
475        | DataType::Unknown
476        | DataType::Decimal { .. }
477        | DataType::Numeric { .. } => BufferDesc::Text {
478            max_str_len: max_column_size,
479        },
480    }
481}
482
/// Decodes one column of a fetched batch into one `OdbcValueKind` per row.
///
/// `desc` must be the descriptor the column was bound with. If `slice` cannot be viewed as the
/// type `desc` implies, a protocol error is returned (via `expect_buffer_slice`). NULL entries
/// become `OdbcValueKind::Null`. `BufferDesc::Numeric` is rejected with a protocol error.
///
/// `map_buffer_desc` in this module only produces the `I64`, `F32`, `F64`, `Bit`, `Date`,
/// `Time`, `Timestamp`, `Text` and `Binary` descriptors. The remaining arms cover the other
/// descriptor kinds.
fn buffered_column_values(
    slice: AnyColumnBufferSlice<'_>,
    desc: BufferDesc,
) -> std::result::Result<Vec<OdbcValueKind>, sqlx_core::Error> {
    Ok(match desc {
        BufferDesc::I8 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
            OdbcValueKind::TinyInt(value)
        })?,
        BufferDesc::I16 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
            OdbcValueKind::SmallInt(value)
        })?,
        BufferDesc::I32 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
            OdbcValueKind::Integer(value)
        })?,
        BufferDesc::I64 { nullable } => {
            buffered_numeric(&slice, desc, nullable, OdbcValueKind::BigInt)?
        }
        // Unsigned bytes are widened losslessly into the BigInt variant.
        BufferDesc::U8 { nullable } => buffered_numeric(&slice, desc, nullable, |value: u8| {
            OdbcValueKind::BigInt(i64::from(value))
        })?,
        BufferDesc::F32 { nullable } => {
            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Real)?
        }
        BufferDesc::F64 { nullable } => {
            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Double)?
        }
        BufferDesc::Bit { nullable } => {
            buffered_numeric(&slice, desc, nullable, |value: odbc_api::Bit| {
                OdbcValueKind::Bit(value.as_bool())
            })?
        }
        BufferDesc::Date { nullable } => {
            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Date)?
        }
        BufferDesc::Time { nullable } => {
            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Time)?
        }
        BufferDesc::Timestamp { nullable } => {
            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Timestamp)?
        }
        // Narrow text is decoded as UTF-8 with lossy replacement.
        // NOTE(review): drivers that emit another narrow encoding will yield replacement chars.
        BufferDesc::Text { .. } => {
            let text = expect_buffer_slice(slice.as_text(), desc)?;
            text.iter()
                .map(|value| {
                    value
                        .map(|bytes| {
                            OdbcValueKind::Text(String::from_utf8_lossy(bytes).into_owned())
                        })
                        .unwrap_or(OdbcValueKind::Null)
                })
                .collect()
        }
        // Wide text is decoded as UTF-16 with lossy replacement.
        BufferDesc::WText { .. } => {
            let text = expect_buffer_slice(slice.as_wide_text(), desc)?;
            text.iter()
                .map(|value| {
                    value
                        .map(|chars| OdbcValueKind::Text(String::from_utf16_lossy(chars.into())))
                        .unwrap_or(OdbcValueKind::Null)
                })
                .collect()
        }
        BufferDesc::Binary { .. } => {
            let binary = expect_buffer_slice(slice.as_binary(), desc)?;
            binary
                .iter()
                .map(|value| {
                    value
                        .map(|bytes| OdbcValueKind::Binary(bytes.to_vec()))
                        .unwrap_or(OdbcValueKind::Null)
                })
                .collect()
        }
        // No decoding for the Numeric buffer kind exists here.
        BufferDesc::Numeric => {
            return Err(sqlx_core::Error::Protocol(format!(
                "unsupported ODBC buffer descriptor: {desc:?}"
            )))
        }
    })
}
563
564fn buffered_numeric<T, F>(
565    slice: &AnyColumnBufferSlice<'_>,
566    desc: BufferDesc,
567    nullable: bool,
568    map: F,
569) -> std::result::Result<Vec<OdbcValueKind>, sqlx_core::Error>
570where
571    T: Copy + odbc_api::Pod,
572    F: FnMut(T) -> OdbcValueKind,
573{
574    if nullable {
575        Ok(buffered_nullable_numeric(
576            expect_buffer_slice(slice.as_nullable_slice::<T>(), desc)?,
577            map,
578        ))
579    } else {
580        Ok(expect_buffer_slice(slice.as_slice::<T>(), desc)?
581            .iter()
582            .copied()
583            .map(map)
584            .collect())
585    }
586}
587
588fn buffered_nullable_numeric<T, F>(slice: NullableSlice<'_, T>, mut map: F) -> Vec<OdbcValueKind>
589where
590    T: Copy,
591    F: FnMut(T) -> OdbcValueKind,
592{
593    slice
594        .map(|value| value.copied().map(&mut map).unwrap_or(OdbcValueKind::Null))
595        .collect()
596}
597
598fn expect_buffer_slice<T>(
599    slice: Option<T>,
600    desc: BufferDesc,
601) -> std::result::Result<T, sqlx_core::Error> {
602    slice.ok_or_else(|| {
603        sqlx_core::Error::Protocol(format!(
604            "ODBC column buffer {desc:?} did not match fetched slice"
605        ))
606    })
607}
608
609fn collect_rows_unbuffered<C>(mut cursor: C) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
610where
611    C: Cursor + ResultSetMetadata,
612{
613    let columns = collect_columns(&mut cursor)?;
614    let mut rows = Vec::new();
615
616    while let Some(mut cursor_row) = cursor.next_row().map_err(crate::OdbcError::from)? {
617        let mut values = Vec::with_capacity(columns.len());
618
619        for column in &columns {
620            let column_number = u16::try_from(sqlx_core::column::Column::ordinal(column) + 1)
621                .map_err(|_| {
622                    sqlx_core::Error::Protocol("ODBC column index exceeds u16".to_owned())
623                })?;
624            values.push(fetch_value(
625                &mut cursor_row,
626                column_number,
627                column.type_info().data_type(),
628            )?);
629        }
630
631        rows.push(OdbcRow::new(columns.clone(), values));
632    }
633
634    Ok(rows)
635}
636
/// Reads column `column_number` (1-based) of the current row, choosing the fetch method from the
/// column's SQL type `data_type`.
///
/// - Bit, integer, floating-point and temporal types are read as fixed-size values. Integer
///   widths are preserved: `TinyInt` maps to `OdbcValueKind::TinyInt`, and so on.
/// - Binary types are read as bytes.
/// - Every other type — character, decimal/numeric, driver-specific and unknown — is read as
///   UTF-16 text and decoded lossily.
///
/// NULL values become `OdbcValueKind::Null`.
fn fetch_value(
    row: &mut odbc_api::CursorRow<'_>,
    column_number: u16,
    data_type: DataType,
) -> std::result::Result<OdbcValue, sqlx_core::Error> {
    let kind = match data_type {
        DataType::Bit => {
            let mut value = Nullable::<odbc_api::Bit>::null();
            row.get_data(column_number, &mut value)
                .map_err(crate::OdbcError::from)?;
            value
                .into_opt()
                .map(|value| OdbcValueKind::Bit(value.as_bool()))
                .unwrap_or(OdbcValueKind::Null)
        }
        DataType::TinyInt => fetch_nullable(row, column_number, OdbcValueKind::TinyInt)?,
        DataType::SmallInt => fetch_nullable(row, column_number, OdbcValueKind::SmallInt)?,
        DataType::Integer => fetch_nullable(row, column_number, OdbcValueKind::Integer)?,
        DataType::BigInt => fetch_nullable(row, column_number, OdbcValueKind::BigInt)?,
        DataType::Real => fetch_nullable(row, column_number, OdbcValueKind::Real)?,
        DataType::Float { .. } | DataType::Double => {
            fetch_nullable(row, column_number, OdbcValueKind::Double)?
        }
        DataType::Date => fetch_nullable(row, column_number, OdbcValueKind::Date)?,
        DataType::Time { .. } => fetch_nullable(row, column_number, OdbcValueKind::Time)?,
        DataType::Timestamp { .. } => fetch_nullable(row, column_number, OdbcValueKind::Timestamp)?,
        // `get_binary` returns false for NULL.
        DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
            let mut value = Vec::new();
            if row
                .get_binary(column_number, &mut value)
                .map_err(crate::OdbcError::from)?
            {
                OdbcValueKind::Binary(value)
            } else {
                OdbcValueKind::Null
            }
        }
        // Fallback: fetch as wide (UTF-16) text; `get_wide_text` returns false for NULL.
        _ => {
            let mut value = Vec::new();
            if row
                .get_wide_text(column_number, &mut value)
                .map_err(crate::OdbcError::from)?
            {
                OdbcValueKind::Text(String::from_utf16_lossy(&value))
            } else {
                OdbcValueKind::Null
            }
        }
    };

    Ok(OdbcValue::new(kind))
}
689
690fn fetch_nullable<T, F>(
691    row: &mut odbc_api::CursorRow<'_>,
692    column_number: u16,
693    map: F,
694) -> std::result::Result<OdbcValueKind, sqlx_core::Error>
695where
696    T: Default + Copy + odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
697    Nullable<T>: odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
698    F: FnOnce(T) -> OdbcValueKind,
699{
700    let mut value = Nullable::<T>::null();
701    row.get_data(column_number, &mut value)
702        .map_err(crate::OdbcError::from)?;
703    Ok(value.into_opt().map(map).unwrap_or(OdbcValueKind::Null))
704}
705
#[cfg(test)]
mod tests {
    use super::*;

    const DB2_PING: &str = "SELECT 1 FROM SYSIBM.SYSDUMMY1";

    #[test]
    fn buffered_fetch_maps_numeric_types_to_nullable_64_bit_buffers() {
        for data_type in [DataType::TinyInt, DataType::Integer, DataType::BigInt] {
            assert!(matches!(
                map_buffer_desc(data_type, 64),
                BufferDesc::I64 { nullable: true }
            ));
        }
    }

    #[test]
    fn buffered_fetch_uses_configured_limits_for_variable_sized_data() {
        let text = map_buffer_desc(DataType::Varchar { length: None }, 32);
        let binary = map_buffer_desc(DataType::Varbinary { length: None }, 16);

        assert_eq!(text, BufferDesc::Text { max_str_len: 32 });
        assert_eq!(binary, BufferDesc::Binary { max_bytes: 16 });
    }

    #[test]
    fn ping_query_uses_db2_dummy_table_for_db2_drivers() {
        let db2_names = [
            "DB2",
            "DB2 UDB for AS/400",
            "IBM DB2 for i",
            "iSeries",
            "IBM i",
        ];
        for name in db2_names {
            assert_eq!(ping_query_for_dbms_name(name), DB2_PING);
        }
    }

    #[test]
    fn ping_query_keeps_select_one_for_non_db2_drivers() {
        for name in ["DuckDB", "Microsoft SQL Server", "PostgreSQL"] {
            assert_eq!(ping_query_for_dbms_name(name), "SELECT 1");
        }
    }
}