sqlx-core-oldapi 0.6.49-beta

Core of SQLx, the Rust SQL toolkit. Not intended to be used directly.
Documentation
use crate::error::Error;
use crate::odbc::{
    connection::MaybePrepared, OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcQueryResult,
    OdbcRow, OdbcTypeInfo,
};
use either::Either;
use flume::{SendError, Sender};
use odbc_api::handles::{AsStatementRef, Statement};
use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, ResultSetMetadata};

/// One message on the execution channel: an execution summary (`Either::Left`),
/// a single result row (`Either::Right`), or an error.
pub type ExecuteResult = Result<Either<OdbcQueryResult, OdbcRow>, Error>;
/// Sending half of the channel that statement execution streams `ExecuteResult`s into.
pub type ExecuteSender = Sender<ExecuteResult>;

pub fn establish_connection(
    options: &crate::odbc::OdbcConnectOptions,
) -> Result<odbc_api::Connection<'static>, Error> {
    let env = odbc_api::environment().map_err(|e| Error::Configuration(e.to_string().into()))?;
    let conn = env
        .connect_with_connection_string(options.connection_string(), Default::default())
        .map_err(|e| Error::Configuration(e.to_string().into()))?;
    Ok(conn)
}

/// Executes one statement on `conn` and streams its results into `tx`.
///
/// Results go through the channel, not the return value. If the statement opens a cursor,
/// its rows are sent by `handle_cursor`. Afterwards an `OdbcQueryResult` carrying the row
/// count reported by the statement is always sent.
///
/// # Errors
///
/// Returns an error if allocating the statement handle or executing the statement fails.
/// Errors raised while reading rows are sent through `tx` instead of being returned.
///
/// # Panics
///
/// Panics if the mutex guarding a prepared statement is poisoned.
pub fn execute_sql(
    conn: &mut odbc_api::Connection<'static>,
    maybe_prepared: MaybePrepared,
    args: Option<OdbcArguments>,
    tx: &ExecuteSender,
) -> Result<(), Error> {
    let params = prepare_parameters(args);

    let affected = match maybe_prepared {
        MaybePrepared::Prepared(prepared) => {
            // The lock is held for the whole execution and row-count read below.
            let mut prepared = prepared.lock().expect("prepared statement lock");
            // The cursor (if any) is dropped at the end of this `if let`, which must
            // happen before the statement is borrowed again for the row count.
            if let Some(mut cursor) = prepared.execute(&params[..])? {
                handle_cursor(&mut cursor, tx);
            }
            extract_rows_affected(&mut *prepared)
        }
        MaybePrepared::NotPrepared(sql) => {
            // Ad-hoc SQL runs on a statement handle preallocated from the connection.
            let mut preallocated = conn.preallocate().map_err(Error::from)?;
            if let Some(mut cursor) = preallocated.execute(&sql, &params[..])? {
                handle_cursor(&mut cursor, tx);
            }
            extract_rows_affected(&mut preallocated)
        }
    };

    // Sent even when `handle_cursor` already emitted a 0-row summary; a failed send
    // (receiver dropped) is deliberately ignored.
    let _ = send_done(tx, affected);
    Ok(())
}

/// Returns the number of rows affected by the last execution of `stmt`.
///
/// Returns 0 when the count cannot be obtained. A driver error while querying the count
/// is logged as a warning.
fn extract_rows_affected<S: AsStatementRef>(stmt: &mut S) -> u64 {
    let mut stmt_ref = stmt.as_stmt_ref();
    match stmt_ref.row_count().into_result(&stmt_ref) {
        // Per the ODBC `SQLRowCount` documentation, a driver reports -1 when the number of
        // affected rows is unavailable (common after a SELECT). That is expected rather than
        // a failure, so any negative count maps to 0 without a warning.
        Ok(count) => u64::try_from(count).unwrap_or(0),
        Err(e) => {
            log::warn!("Failed to get row count: {}", e);
            0
        }
    }
}

/// Converts optional bound arguments into boxed ODBC input parameters, preserving order.
///
/// `None` yields an empty parameter list.
fn prepare_parameters(
    args: Option<OdbcArguments>,
) -> Vec<Box<dyn odbc_api::parameter::InputParameter>> {
    match args {
        Some(arguments) => arguments.values.into_iter().map(to_param).collect(),
        None => Vec::new(),
    }
}

/// Converts one argument value into a boxed ODBC input parameter.
///
/// SQL NULL is bound as an absent `String` parameter.
fn to_param(arg: OdbcArgumentValue) -> Box<dyn odbc_api::parameter::InputParameter + 'static> {
    /// Erases the concrete parameter type so every variant yields the same trait object.
    fn erase<P>(param: P) -> Box<dyn odbc_api::parameter::InputParameter + 'static>
    where
        P: odbc_api::parameter::InputParameter + 'static,
    {
        Box::new(param)
    }

    match arg {
        OdbcArgumentValue::Int(int) => erase(int.into_parameter()),
        OdbcArgumentValue::Float(float) => erase(float.into_parameter()),
        OdbcArgumentValue::Text(text) => erase(text.into_parameter()),
        OdbcArgumentValue::Bytes(bytes) => erase(bytes.into_parameter()),
        OdbcArgumentValue::Null => erase(None::<String>.into_parameter()),
    }
}

/// Streams every row of `cursor` into `tx`.
///
/// Once the cursor is exhausted and the receiver is still listening, a 0-row
/// `OdbcQueryResult` is sent. If fetching fails, the error is sent instead.
fn handle_cursor<C>(cursor: &mut C, tx: &ExecuteSender)
where
    C: Cursor + ResultSetMetadata,
{
    let metadata = collect_columns(cursor);

    match stream_rows(cursor, &metadata, tx) {
        Err(failure) => send_error(tx, failure),
        Ok(receiver_open) => {
            // If the receiver hung up mid-stream there is nobody to tell we are done.
            if receiver_open {
                let _ = send_done(tx, 0);
            }
        }
    }
}

/// Sends an execution summary carrying `rows_affected`.
///
/// Returns the send error if the receiver has been dropped.
fn send_done(tx: &ExecuteSender, rows_affected: u64) -> Result<(), SendError<ExecuteResult>> {
    let summary = OdbcQueryResult { rows_affected };
    tx.send(Ok(Either::Left(summary)))
}

/// Sends `error` to the receiver.
///
/// This is best effort: if the receiver is gone, the error is dropped.
fn send_error(tx: &ExecuteSender, error: Error) {
    tx.send(Err(error)).ok();
}

/// Sends a single result row.
///
/// Returns the send error if the receiver has been dropped.
fn send_row(tx: &ExecuteSender, row: OdbcRow) -> Result<(), SendError<ExecuteResult>> {
    let message: ExecuteResult = Ok(Either::Right(row));
    tx.send(message)
}

/// Describes every column of the current result set, in column order.
///
/// If the column count cannot be read, the result set is treated as having no columns.
fn collect_columns<C>(cursor: &mut C) -> Vec<OdbcColumn>
where
    C: ResultSetMetadata,
{
    let total = cursor.num_result_cols().unwrap_or_default();

    let mut columns = Vec::new();
    // ODBC column numbers are 1-based.
    for position in 1..=total {
        columns.push(create_column(cursor, position as u16));
    }
    columns
}

fn create_column<C>(cursor: &mut C, index: u16) -> OdbcColumn
where
    C: ResultSetMetadata,
{
    let mut cd = odbc_api::ColumnDescription::default();
    let _ = cursor.describe_col(index, &mut cd);

    OdbcColumn {
        name: decode_column_name(cd.name, index),
        type_info: OdbcTypeInfo::new(cd.data_type),
        ordinal: usize::from(index.checked_sub(1).unwrap()),
    }
}

/// Converts a raw column name into a `String`.
///
/// When the bytes are not valid UTF-8, falls back to `col{N}`, where N is the zero-based
/// column position (`index - 1`, with `index` being the 1-based ODBC column number).
fn decode_column_name(name_bytes: Vec<u8>, index: u16) -> String {
    match String::from_utf8(name_bytes) {
        Ok(name) => name,
        Err(_) => {
            let zero_based = index - 1;
            format!("col{}", zero_based)
        }
    }
}

/// Fetches rows from `cursor` and sends each one as an `OdbcRow` through `tx`.
///
/// Returns `Ok(true)` when the cursor is exhausted, and `Ok(false)` when a send fails
/// because the receiver was dropped; fetching stops at that point.
///
/// # Errors
///
/// Propagates the first error from fetching a row or reading one of its values.
fn stream_rows<C>(cursor: &mut C, columns: &[OdbcColumn], tx: &ExecuteSender) -> Result<bool, Error>
where
    C: Cursor,
{
    while let Some(mut fetched) = cursor.next_row()? {
        let pairs = collect_row_values(&mut fetched, columns)?;
        let row = OdbcRow {
            columns: columns.to_vec(),
            // Keep only the values; each value already carries its own type info.
            values: pairs.into_iter().map(|(_type_info, value)| value).collect(),
        };

        if send_row(tx, row).is_err() {
            return Ok(false);
        }
    }

    Ok(true)
}

/// Reads the value of every column of `row`, in column order.
///
/// Stops at, and returns, the first error encountered.
fn collect_row_values(
    row: &mut CursorRow<'_>,
    columns: &[OdbcColumn],
) -> Result<Vec<(OdbcTypeInfo, crate::odbc::OdbcValue)>, Error> {
    let mut collected = Vec::with_capacity(columns.len());
    for (position, column) in columns.iter().enumerate() {
        collected.push(collect_column_value(row, position, column)?);
    }
    Ok(collected)
}

/// Reads the value of the column at zero-based `index` from `row`.
///
/// The extraction strategy is chosen from the column's ODBC data type:
/// - integer-like types are read as `i64`;
/// - `Real` is read as `f32`, and `Float`/`Double` as `f64`;
/// - character, date/time and decimal types are read as text;
/// - binary types are read as bytes;
/// - unknown or driver-specific types are tried as text first, then as bytes.
///
/// Returns the column's type info together with the extracted value.
fn collect_column_value(
    row: &mut CursorRow<'_>,
    index: usize,
    column: &OdbcColumn,
) -> Result<(OdbcTypeInfo, crate::odbc::OdbcValue), Error> {
    use odbc_api::DataType as Dt;

    // ODBC column numbers are 1-based; `index` is the 0-based position in the row.
    let col_idx = (index + 1) as u16;
    let type_info = column.type_info.clone();

    let value = match type_info.data_type() {
        Dt::TinyInt | Dt::SmallInt | Dt::Integer | Dt::BigInt | Dt::Bit => {
            extract_int(row, col_idx, &type_info)
        }

        Dt::Real => extract_float::<f32>(row, col_idx, &type_info),
        Dt::Float { .. } | Dt::Double => extract_float::<f64>(row, col_idx, &type_info),

        Dt::Char { .. }
        | Dt::Varchar { .. }
        | Dt::LongVarchar { .. }
        | Dt::WChar { .. }
        | Dt::WVarchar { .. }
        | Dt::WLongVarchar { .. }
        | Dt::Date
        | Dt::Time { .. }
        | Dt::Timestamp { .. }
        | Dt::Decimal { .. }
        | Dt::Numeric { .. } => extract_text(row, col_idx, &type_info),

        Dt::Binary { .. } | Dt::Varbinary { .. } | Dt::LongVarbinary { .. } => {
            extract_binary(row, col_idx, &type_info)
        }

        // Any text-extraction error is discarded in favour of a binary attempt.
        Dt::Unknown | Dt::Other { .. } => extract_text(row, col_idx, &type_info)
            .or_else(|_| extract_binary(row, col_idx, &type_info)),
    }?;

    Ok((type_info, value))
}

/// Reads column `col_idx` of `row` as a nullable `i64`.
///
/// The resulting value has `int` set, or is marked null.
fn extract_int(
    row: &mut CursorRow<'_>,
    col_idx: u16,
    type_info: &OdbcTypeInfo,
) -> Result<crate::odbc::OdbcValue, Error> {
    let mut cell = Nullable::<i64>::null();
    row.get_data(col_idx, &mut cell)?;
    let int = cell.into_opt();

    Ok(crate::odbc::OdbcValue {
        type_info: type_info.clone(),
        is_null: int.is_none(),
        text: None,
        blob: None,
        int,
        float: None,
    })
}

/// Reads column `col_idx` of `row` as a nullable `T` and widens it to `f64`.
///
/// The resulting value has `float` set, or is marked null.
fn extract_float<T>(
    row: &mut CursorRow<'_>,
    col_idx: u16,
    type_info: &OdbcTypeInfo,
) -> Result<crate::odbc::OdbcValue, Error>
where
    T: Into<f64> + Default,
    odbc_api::Nullable<T>: odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
{
    let mut cell = Nullable::<T>::null();
    row.get_data(col_idx, &mut cell)?;
    let float: Option<f64> = cell.into_opt().map(Into::into);

    Ok(crate::odbc::OdbcValue {
        type_info: type_info.clone(),
        is_null: float.is_none(),
        text: None,
        blob: None,
        int: None,
        float,
    })
}

fn extract_text(
    row: &mut CursorRow<'_>,
    col_idx: u16,
    type_info: &OdbcTypeInfo,
) -> Result<crate::odbc::OdbcValue, Error> {
    let mut buf = Vec::new();
    let is_some = row.get_text(col_idx, &mut buf)?;

    let (is_null, text) = if !is_some {
        (true, None)
    } else {
        match String::from_utf8(buf) {
            Ok(s) => (false, Some(s)),
            Err(e) => return Err(Error::Decode(e.into())),
        }
    };

    Ok(crate::odbc::OdbcValue {
        type_info: type_info.clone(),
        is_null,
        text,
        blob: None,
        int: None,
        float: None,
    })
}

/// Reads column `col_idx` of `row` as raw bytes.
///
/// The resulting value has `blob` set, or is marked null when the driver returns no data.
fn extract_binary(
    row: &mut CursorRow<'_>,
    col_idx: u16,
    type_info: &OdbcTypeInfo,
) -> Result<crate::odbc::OdbcValue, Error> {
    let mut bytes = Vec::new();
    let blob = if row.get_binary(col_idx, &mut bytes)? {
        Some(bytes)
    } else {
        None
    };

    Ok(crate::odbc::OdbcValue {
        type_info: type_info.clone(),
        is_null: blob.is_none(),
        text: None,
        blob,
        int: None,
        float: None,
    })
}