datafusion_table_providers/sql/db_connection_pool/
odbcpool.rs1use crate::sql::db_connection_pool::dbconnection::odbcconn::ODBCConnection;
18use crate::sql::db_connection_pool::dbconnection::odbcconn::{ODBCDbConnection, ODBCParameter};
19use crate::sql::db_connection_pool::{DbConnectionPool, JoinPushDown};
20use async_trait::async_trait;
21use odbc_api::{sys::AttrConnectionPooling, Connection, ConnectionOptions, Environment};
22use secrecy::{ExposeSecret, SecretBox, SecretString};
23use sha2::{Digest, Sha256};
24use snafu::prelude::*;
25use std::{
26 collections::HashMap,
27 sync::{Arc, LazyLock},
28};
29
30static ENV: LazyLock<Environment> = LazyLock::new(|| unsafe {
31 if let Err(e) = Environment::set_connection_pooling(AttrConnectionPooling::DriverAware) {
36 tracing::error!("Failed to set ODBC connection pooling: {e}");
37 };
38 match Environment::new() {
39 Ok(env) => env,
40 Err(e) => {
41 panic!("Failed to create ODBC environment: {e}");
42 }
43 }
44});
45
/// Errors reported by the ODBC connection pool in this module.
#[derive(Debug, Snafu)]
pub enum Error {
    /// The parameter map handed to `ODBCPool::new` had no connection string.
    // NOTE(review): the message names `odbc_connection_string`, while
    // `ODBCPool::new` looks up the key `connection_string`; confirm which name
    // callers actually see before changing either.
    #[snafu(display("Missing ODBC connection string parameter: odbc_connection_string"))]
    MissingConnectionString {},

    /// A parameter was invalid; `parameter_name` identifies it.
    // NOTE(review): nothing in this part of the file constructs this variant.
    #[snafu(display("Invalid parameter: {parameter_name}"))]
    InvalidParameterError { parameter_name: String },
}
54
/// Factory for ODBC connections to a single data source.
///
/// Holds a reference to the shared ODBC environment together with the
/// settings needed to open a connection on demand.
pub struct ODBCPool {
    /// ODBC environment through which connections are opened.
    pool: &'static Environment,
    /// Full parameter map given to `ODBCPool::new`; shared with each connection.
    params: Arc<HashMap<String, SecretString>>,
    /// Connection string taken from `params`, kept as a plain `String`.
    connection_string: String,
    /// Hex-encoded SHA-256 digest of `connection_string`; returned by `join_push_down`.
    connection_id: String,
}
61
62fn hash_string(val: &str) -> String {
63 let mut hasher = Sha256::new();
64 hasher.update(val);
65 hasher.finalize().iter().fold(String::new(), |mut hash, b| {
66 hash.push_str(&format!("{b:02x}"));
67 hash
68 })
69}
70
71impl ODBCPool {
72 pub fn new(params: HashMap<String, SecretString>) -> Result<Self, Error> {
78 let connection_string = params
79 .get("connection_string")
80 .map(SecretBox::expose_secret)
81 .map(ToString::to_string)
82 .context(MissingConnectionStringSnafu)?;
83
84 let connection_id = hash_string(&connection_string);
87
88 Ok(Self {
89 params: params.into(),
90 connection_string,
91 connection_id,
92 pool: &ENV,
93 })
94 }
95
96 #[must_use]
97 pub fn odbc_environment(&self) -> &'static Environment {
98 self.pool
99 }
100}
101
102#[async_trait]
103impl<'a> DbConnectionPool<Connection<'a>, ODBCParameter> for ODBCPool
104where
105 'a: 'static,
106{
107 async fn connect(
108 &self,
109 ) -> Result<Box<ODBCDbConnection<'a>>, Box<dyn std::error::Error + Send + Sync>> {
110 let cxn = self.pool.connect_with_connection_string(
111 &self.connection_string,
112 ConnectionOptions::default(),
113 )?;
114
115 let odbc_cxn = ODBCConnection {
116 conn: Arc::new(cxn.into()),
117 params: Arc::clone(&self.params),
118 };
119
120 Ok(Box::new(odbc_cxn))
121 }
122
123 fn join_push_down(&self) -> JoinPushDown {
124 JoinPushDown::AllowedFor(self.connection_id.clone())
125 }
126}