// odbc_api_helper/executor/
database.rs1use crate::executor::execute::ExecResult;
2use crate::executor::query::QueryResult;
3use crate::executor::statement::StatementInput;
4use crate::executor::table::TableDescResult;
5use crate::executor::SupportDatabase;
6use crate::extension::odbc::{OdbcColumn, OdbcColumnItem};
7use crate::{Convert, TryConvert};
8use dameng_helper::DmAdapter;
9use either::Either;
10use odbc_api::buffers::{AnySlice, BufferDescription, ColumnarAnyBuffer};
11use odbc_api::handles::StatementImpl;
12use odbc_api::{
13 ColumnDescription, Connection, Cursor, CursorImpl, ParameterCollectionRef, ResultSetMetadata,
14};
15use std::ops::IndexMut;
16
17pub trait ConnectionTrait {
18 fn execute<S>(&self, stmt: S) -> anyhow::Result<ExecResult>
20 where
21 S: StatementInput;
22
23 fn query<S>(&self, stmt: S) -> anyhow::Result<QueryResult>
25 where
26 S: StatementInput;
27
28 fn show_table(&self, table_names: Vec<String>) -> anyhow::Result<TableDescResult>;
29
30 fn begin(&self) -> anyhow::Result<()>;
32
33 fn finish(&self) -> anyhow::Result<()>;
35
36 fn commit(&self) -> anyhow::Result<()>;
37
38 fn rollback(&self) -> anyhow::Result<()>;
39}
40
41#[allow(missing_debug_implementations)]
42pub struct OdbcDbConnection<'a> {
43 pub conn: Connection<'a>,
44 pub options: Options,
45}
46
47#[derive(Debug)]
48pub struct Options {
49 pub db_name: String,
50 pub database: SupportDatabase,
51 pub max_batch_size: usize,
52 pub max_str_len: usize,
53 pub max_binary_len: usize,
54 pub case_sensitive: bool,
58}
59
60impl Options {
61 pub const MAX_BATCH_SIZE: usize = 1 << 7;
63 pub const MAX_STR_LEN: usize = 1024;
65 pub const MAX_BINARY_LEN: usize = 1024 * 1024;
67
68 pub fn new(db_name: String, database: SupportDatabase) -> Self {
69 Options {
70 db_name,
71 database,
72 max_batch_size: Self::MAX_BATCH_SIZE,
73 max_str_len: Self::MAX_STR_LEN,
74 max_binary_len: Self::MAX_BINARY_LEN,
75 case_sensitive: false,
76 }
77 }
78
79 fn check(mut self) -> Self {
80 if self.max_batch_size == 0 {
81 self.max_batch_size = Self::MAX_BATCH_SIZE
82 }
83
84 if self.max_str_len == 0 {
85 self.max_str_len = Self::MAX_STR_LEN
87 }
88
89 if self.max_binary_len == 0 {
90 self.max_binary_len = Self::MAX_BINARY_LEN
92 }
93 self
94 }
95}
96
97impl<'a> ConnectionTrait for OdbcDbConnection<'a> {
98 fn execute<S>(&self, stmt: S) -> anyhow::Result<ExecResult>
99 where
100 S: StatementInput,
101 {
102 let sql = stmt.to_sql().to_string();
103 match stmt.values()? {
104 Either::Left(params) => self.exec_result(sql, ¶ms[..]),
105 Either::Right(()) => self.exec_result(sql, ()),
106 }
107 }
108
109 fn query<S>(&self, stmt: S) -> anyhow::Result<QueryResult>
110 where
111 S: StatementInput,
112 {
113 let sql = stmt.to_sql().to_string();
114
115 match stmt.values()? {
116 Either::Left(params) => self.query_result(&sql, ¶ms[..]),
117 Either::Right(()) => self.query_result(&sql, ()),
118 }
119 }
120
121 fn show_table(&self, table_names: Vec<String>) -> anyhow::Result<TableDescResult> {
122 self.table_desc(table_names)
123 }
124
125 fn begin(&self) -> anyhow::Result<()> {
126 Ok(self.conn.set_autocommit(false)?)
127 }
128
129 fn finish(&self) -> anyhow::Result<()> {
130 self.conn.set_autocommit(true)?;
131 Ok(())
132 }
133
134 fn commit(&self) -> anyhow::Result<()> {
135 self.conn.commit()?;
136 Ok(())
137 }
138
139 fn rollback(&self) -> anyhow::Result<()> {
140 self.conn.rollback()?;
141 Ok(())
142 }
143}
144
145impl<'a> OdbcDbConnection<'a> {
146 pub fn new(conn: Connection<'a>, options: Options) -> anyhow::Result<Self> {
147 let options = options.check();
148 let connection = Self { conn, options };
149 Ok(connection)
150 }
151
152 fn exec_result<S: Into<String>>(
153 &self,
154 sql: S,
155 params: impl ParameterCollectionRef,
156 ) -> anyhow::Result<ExecResult> {
157 let mut stmt = self.conn.preallocate()?;
158 stmt.execute(&sql.into(), params)?;
159 let row_op = stmt.row_count()?;
160 let result = row_op
161 .map(|r| ExecResult { rows_affected: r })
162 .unwrap_or_default();
163 Ok(result)
164 }
165
166 fn query_result(
167 &self,
168 sql: &str,
169 params: impl ParameterCollectionRef,
170 ) -> anyhow::Result<QueryResult> {
171 let mut cursor = self
172 .conn
173 .execute(sql, params)?
174 .ok_or_else(|| anyhow!("query error"))?;
175
176 let mut query_result = Self::get_cursor_columns(&mut cursor)?;
177 debug!("columns:{:?}", query_result.columns);
178
179 let descs = query_result.columns.iter().map(|c| {
180 <(&OdbcColumn, &Options) as TryConvert<BufferDescription>>::try_convert((
181 c,
182 &self.options,
183 ))
184 .unwrap()
185 });
186
187 let row_set_buffer =
188 ColumnarAnyBuffer::try_from_description(self.options.max_batch_size, descs).unwrap();
189
190 let mut row_set_cursor = cursor.bind_buffer(row_set_buffer).unwrap();
191
192 let mut total_row = vec![];
193 while let Some(row_set) = row_set_cursor.fetch()? {
194 for index in 0..query_result.columns.len() {
195 let column_view: AnySlice = row_set.column(index);
196 let column_types: Vec<OdbcColumnItem> = column_view.convert();
197 if index == 0 {
198 for c in column_types.into_iter() {
199 total_row.push(vec![c]);
200 }
201 } else {
202 for (col_index, c) in column_types.into_iter().enumerate() {
203 let row = total_row.index_mut(col_index);
204 row.push(c)
205 }
206 }
207 }
208 }
209 query_result.data = total_row;
210 Ok(query_result)
211 }
212
213 fn get_cursor_columns(cursor: &mut CursorImpl<StatementImpl>) -> anyhow::Result<QueryResult> {
214 let mut query_result = QueryResult::default();
215 for index in 0..cursor.num_result_cols()?.try_into()? {
216 let mut column_description = ColumnDescription::default();
217 cursor.describe_col(index + 1, &mut column_description)?;
218
219 let column = OdbcColumn::new(
220 column_description.name_to_string()?,
221 column_description.data_type,
222 column_description.could_be_nullable(),
223 );
224 query_result.columns.push(column);
225 }
226 Ok(query_result)
227 }
228
229 fn table_desc(&self, table_names: Vec<String>) -> anyhow::Result<TableDescResult> {
230 let db = &self.options.database;
231 match db {
232 SupportDatabase::Dameng => {
233 let describe = CursorImpl::get_table_sql(
234 table_names,
235 &self.options.db_name,
236 self.options.case_sensitive,
237 );
238 let cursor = self
239 .conn
240 .execute(&describe.describe_sql, ())?
241 .ok_or_else(|| anyhow!("query error"))?;
242 cursor.get_table_desc(describe)
243 }
244 _ => {
245 bail!("current not support database:{:?}", db)
246 }
247 }
248 }
249}