odbc_api_helper/executor/
database.rs

1use crate::executor::execute::ExecResult;
2use crate::executor::query::QueryResult;
3use crate::executor::statement::StatementInput;
4use crate::executor::table::TableDescResult;
5use crate::executor::SupportDatabase;
6use crate::extension::odbc::{OdbcColumn, OdbcColumnItem};
7use crate::{Convert, TryConvert};
8use dameng_helper::DmAdapter;
9use either::Either;
10use odbc_api::buffers::{AnySlice, BufferDescription, ColumnarAnyBuffer};
11use odbc_api::handles::StatementImpl;
12use odbc_api::{
13    ColumnDescription, Connection, Cursor, CursorImpl, ParameterCollectionRef, ResultSetMetadata,
14};
15use std::ops::IndexMut;
16
17pub trait ConnectionTrait {
18    /// Execute a `[Statement]`  INSETT,UPDATE,DELETE
19    fn execute<S>(&self, stmt: S) -> anyhow::Result<ExecResult>
20    where
21        S: StatementInput;
22
23    /// Execute a `[Statement]` and return a collection Vec<[QueryResult]> on success
24    fn query<S>(&self, stmt: S) -> anyhow::Result<QueryResult>
25    where
26        S: StatementInput;
27
28    fn show_table(&self, table_names: Vec<String>) -> anyhow::Result<TableDescResult>;
29
30    // begin transaction
31    fn begin(&self) -> anyhow::Result<()>;
32
33    // finish transaction
34    fn finish(&self) -> anyhow::Result<()>;
35
36    fn commit(&self) -> anyhow::Result<()>;
37
38    fn rollback(&self) -> anyhow::Result<()>;
39}
40
41#[allow(missing_debug_implementations)]
42pub struct OdbcDbConnection<'a> {
43    pub conn: Connection<'a>,
44    pub options: Options,
45}
46
47#[derive(Debug)]
48pub struct Options {
49    pub db_name: String,
50    pub database: SupportDatabase,
51    pub max_batch_size: usize,
52    pub max_str_len: usize,
53    pub max_binary_len: usize,
54    // ignore uppercase/lowercase,default is false.
55    // false:all column name convert uppercase
56    // true: ignore,keep original column name
57    pub case_sensitive: bool,
58}
59
60impl Options {
61    // Default Max Buffer Size 256
62    pub const MAX_BATCH_SIZE: usize = 1 << 7;
63    // Default Max string length 1MB
64    pub const MAX_STR_LEN: usize = 1024;
65    // Default Max binary length 1MB
66    pub const MAX_BINARY_LEN: usize = 1024 * 1024;
67
68    pub fn new(db_name: String, database: SupportDatabase) -> Self {
69        Options {
70            db_name,
71            database,
72            max_batch_size: Self::MAX_BATCH_SIZE,
73            max_str_len: Self::MAX_STR_LEN,
74            max_binary_len: Self::MAX_BINARY_LEN,
75            case_sensitive: false,
76        }
77    }
78
79    fn check(mut self) -> Self {
80        if self.max_batch_size == 0 {
81            self.max_batch_size = Self::MAX_BATCH_SIZE
82        }
83
84        if self.max_str_len == 0 {
85            // Add default size:1MB
86            self.max_str_len = Self::MAX_STR_LEN
87        }
88
89        if self.max_binary_len == 0 {
90            // Add default size:1MB
91            self.max_binary_len = Self::MAX_BINARY_LEN
92        }
93        self
94    }
95}
96
97impl<'a> ConnectionTrait for OdbcDbConnection<'a> {
98    fn execute<S>(&self, stmt: S) -> anyhow::Result<ExecResult>
99    where
100        S: StatementInput,
101    {
102        let sql = stmt.to_sql().to_string();
103        match stmt.values()? {
104            Either::Left(params) => self.exec_result(sql, &params[..]),
105            Either::Right(()) => self.exec_result(sql, ()),
106        }
107    }
108
109    fn query<S>(&self, stmt: S) -> anyhow::Result<QueryResult>
110    where
111        S: StatementInput,
112    {
113        let sql = stmt.to_sql().to_string();
114
115        match stmt.values()? {
116            Either::Left(params) => self.query_result(&sql, &params[..]),
117            Either::Right(()) => self.query_result(&sql, ()),
118        }
119    }
120
121    fn show_table(&self, table_names: Vec<String>) -> anyhow::Result<TableDescResult> {
122        self.table_desc(table_names)
123    }
124
125    fn begin(&self) -> anyhow::Result<()> {
126        Ok(self.conn.set_autocommit(false)?)
127    }
128
129    fn finish(&self) -> anyhow::Result<()> {
130        self.conn.set_autocommit(true)?;
131        Ok(())
132    }
133
134    fn commit(&self) -> anyhow::Result<()> {
135        self.conn.commit()?;
136        Ok(())
137    }
138
139    fn rollback(&self) -> anyhow::Result<()> {
140        self.conn.rollback()?;
141        Ok(())
142    }
143}
144
145impl<'a> OdbcDbConnection<'a> {
146    pub fn new(conn: Connection<'a>, options: Options) -> anyhow::Result<Self> {
147        let options = options.check();
148        let connection = Self { conn, options };
149        Ok(connection)
150    }
151
152    fn exec_result<S: Into<String>>(
153        &self,
154        sql: S,
155        params: impl ParameterCollectionRef,
156    ) -> anyhow::Result<ExecResult> {
157        let mut stmt = self.conn.preallocate()?;
158        stmt.execute(&sql.into(), params)?;
159        let row_op = stmt.row_count()?;
160        let result = row_op
161            .map(|r| ExecResult { rows_affected: r })
162            .unwrap_or_default();
163        Ok(result)
164    }
165
166    fn query_result(
167        &self,
168        sql: &str,
169        params: impl ParameterCollectionRef,
170    ) -> anyhow::Result<QueryResult> {
171        let mut cursor = self
172            .conn
173            .execute(sql, params)?
174            .ok_or_else(|| anyhow!("query error"))?;
175
176        let mut query_result = Self::get_cursor_columns(&mut cursor)?;
177        debug!("columns:{:?}", query_result.columns);
178
179        let descs = query_result.columns.iter().map(|c| {
180            <(&OdbcColumn, &Options) as TryConvert<BufferDescription>>::try_convert((
181                c,
182                &self.options,
183            ))
184            .unwrap()
185        });
186
187        let row_set_buffer =
188            ColumnarAnyBuffer::try_from_description(self.options.max_batch_size, descs).unwrap();
189
190        let mut row_set_cursor = cursor.bind_buffer(row_set_buffer).unwrap();
191
192        let mut total_row = vec![];
193        while let Some(row_set) = row_set_cursor.fetch()? {
194            for index in 0..query_result.columns.len() {
195                let column_view: AnySlice = row_set.column(index);
196                let column_types: Vec<OdbcColumnItem> = column_view.convert();
197                if index == 0 {
198                    for c in column_types.into_iter() {
199                        total_row.push(vec![c]);
200                    }
201                } else {
202                    for (col_index, c) in column_types.into_iter().enumerate() {
203                        let row = total_row.index_mut(col_index);
204                        row.push(c)
205                    }
206                }
207            }
208        }
209        query_result.data = total_row;
210        Ok(query_result)
211    }
212
213    fn get_cursor_columns(cursor: &mut CursorImpl<StatementImpl>) -> anyhow::Result<QueryResult> {
214        let mut query_result = QueryResult::default();
215        for index in 0..cursor.num_result_cols()?.try_into()? {
216            let mut column_description = ColumnDescription::default();
217            cursor.describe_col(index + 1, &mut column_description)?;
218
219            let column = OdbcColumn::new(
220                column_description.name_to_string()?,
221                column_description.data_type,
222                column_description.could_be_nullable(),
223            );
224            query_result.columns.push(column);
225        }
226        Ok(query_result)
227    }
228
229    fn table_desc(&self, table_names: Vec<String>) -> anyhow::Result<TableDescResult> {
230        let db = &self.options.database;
231        match db {
232            SupportDatabase::Dameng => {
233                let describe = CursorImpl::get_table_sql(
234                    table_names,
235                    &self.options.db_name,
236                    self.options.case_sensitive,
237                );
238                let cursor = self
239                    .conn
240                    .execute(&describe.describe_sql, ())?
241                    .ok_or_else(|| anyhow!("query error"))?;
242                cursor.get_table_desc(describe)
243            }
244            _ => {
245                bail!("current not support database:{:?}", db)
246            }
247        }
248    }
249}