odbc_api_helper/extension/pg.rs

1use crate::executor::database::Options;
2use crate::executor::query::QueryResult;
3use crate::executor::statement::SqlValue;
4use crate::extension::odbc::{OdbcColumn, OdbcColumnItem, OdbcColumnType};
5use crate::{Convert, TryConvert};
6use bytes::BytesMut;
7use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
8use either::Either;
9use odbc_api::buffers::BufferKind;
10use odbc_api::parameter::InputParameter;
11use odbc_api::Bit;
12use odbc_api::IntoParameter;
13use pg_helper::table::PgTableItem;
14use postgres_protocol::types as pp_type;
15use postgres_types::{Oid, Type as PgType};
16use std::collections::BTreeMap;
17
18use crate::executor::table::TableDescResult;
19use crate::executor::SupportDatabase;
20use dameng_helper::table::DmTableDesc;
21use pg_helper::table::PgTableDesc;
22
/// A typed input value to be bound as an ODBC parameter; each variant is named
/// after the PostgreSQL type it represents.
#[derive(Debug)]
pub enum PgValueInput {
    /// 16-bit integer (`INT2`).
    INT2(i16),
    /// 32-bit integer (`INT4`).
    INT4(i32),
    /// 64-bit integer (`INT8`).
    INT8(i64),
    /// Single-precision float (`FLOAT4`).
    FLOAT4(f32),
    /// Double-precision float (`FLOAT8`).
    FLOAT8(f64),
    /// Character value tagged as `CHAR`.
    CHAR(String),
    /// Character value tagged as `VARCHAR`.
    VARCHAR(String),
    /// Character value tagged as `TEXT`.
    TEXT(String),
    /// Boolean value.
    Bool(bool),
    /// Raw bytes.
    Binary(Vec<u8>),
}
36
37impl SqlValue for PgValueInput {
38    fn to_value(self) -> Either<Box<dyn InputParameter>, ()> {
39        macro_rules! left_param {
40            ($($arg:tt)*) => {{
41                Either::Left(Box::new($($arg)*))
42            }};
43        }
44
45        match self {
46            Self::INT2(i) => left_param!(i.into_parameter()),
47            Self::INT4(i) => left_param!(i.into_parameter()),
48            Self::INT8(i) => left_param!(i.into_parameter()),
49            Self::FLOAT4(i) => left_param!(i.into_parameter()),
50            Self::FLOAT8(i) => left_param!(i.into_parameter()),
51            Self::CHAR(i) => left_param!(i.into_parameter()),
52            Self::VARCHAR(i) => left_param!(i.into_parameter()),
53            Self::TEXT(i) => left_param!(i.into_parameter()),
54            Self::Bool(i) => left_param!(Bit::from_bool(i).into_parameter()),
55            PgValueInput::Binary(bytes) => left_param!(bytes.into_parameter()),
56        }
57    }
58}
59
60#[derive(Debug, Default, Eq, PartialEq)]
61pub struct PgQueryResult {
62    pub columns: Vec<PgColumn>,
63    pub data: Vec<Vec<PgColumnItem>>,
64}
65
66#[derive(Debug, Eq, PartialEq)]
67pub struct PgColumn {
68    pub name: String,
69    pub pg_type: PgType,
70    pub oid: Oid,
71    pub nullable: bool,
72}
73
74#[derive(Debug, Eq, PartialEq)]
75pub struct PgColumnItem {
76    pub data: Option<BytesMut>,
77    pub pg_type: PgType,
78    pub oid: Oid,
79}
80
81impl PgColumnItem {
82    fn new(data: BytesMut, pg_type: PgType) -> Self {
83        let oid = pg_type.oid();
84        Self {
85            data: Some(data),
86            pg_type,
87            oid,
88        }
89    }
90
91    fn new_pg_type(pg_type: PgType) -> Self {
92        let oid = pg_type.oid();
93        Self {
94            data: None,
95            pg_type,
96            oid,
97        }
98    }
99}
100
101impl Convert<PgColumn> for OdbcColumn {
102    fn convert(self) -> PgColumn {
103        let buffer_kind = BufferKind::from_data_type(self.data_type).unwrap();
104        let pg_type = match buffer_kind {
105            BufferKind::Binary { .. } => PgType::BYTEA,
106            BufferKind::Text { .. } => PgType::TEXT,
107            BufferKind::WText { .. } => PgType::TEXT,
108            BufferKind::F64 => PgType::FLOAT8,
109            BufferKind::F32 => PgType::FLOAT4,
110            BufferKind::Date => PgType::DATE,
111            BufferKind::Time => PgType::TIME,
112            BufferKind::Timestamp => PgType::TIMESTAMP,
113            BufferKind::I8 => PgType::CHAR,
114            BufferKind::I16 => PgType::INT2,
115            BufferKind::I32 => PgType::INT4,
116            BufferKind::I64 => PgType::INT8,
117            BufferKind::U8 => {
118                panic!("not coverage U8");
119            }
120            BufferKind::Bit => PgType::BOOL,
121        };
122        let oid = pg_type.oid();
123        PgColumn {
124            name: self.name,
125            pg_type,
126            oid,
127            nullable: self.nullable,
128        }
129    }
130}
131
132/// referring to link:`<https://docs.rs/postgres-protocol/0.6.4/postgres_protocol/types/index.html#functions>`
133impl Convert<PgColumnItem> for OdbcColumnItem {
134    fn convert(self) -> PgColumnItem {
135        let mut buf = BytesMut::new();
136
137        let (_, t) = match self.odbc_type {
138            OdbcColumnType::Text => (
139                self.value
140                    .map(|x| pp_type::text_to_sql(&String::from_utf8_lossy(x.as_ref()), &mut buf)),
141                PgType::TEXT,
142            ),
143
144            OdbcColumnType::WText => (
145                self.value
146                    .map(|x| pp_type::text_to_sql(&String::from_utf8_lossy(x.as_ref()), &mut buf)),
147                PgType::TEXT,
148            ),
149
150            OdbcColumnType::Binary => (
151                self.value
152                    .map(|x| pp_type::bytea_to_sql(x.as_ref(), &mut buf)),
153                PgType::BYTEA,
154            ),
155            OdbcColumnType::Date => (
156                self.value.map(|x| {
157                    let val = String::from_utf8_lossy(x.as_ref()).to_string();
158                    let date = NaiveDate::parse_from_str(val.as_str(), "%Y-%m-%d").unwrap();
159                    let days = (date - NaiveDate::from_ymd(2000, 1, 1)).num_days();
160                    if days > i64::from(i32::max_value()) || days < i64::from(i32::min_value()) {
161                        panic!("value too large to transmit");
162                    }
163                    pp_type::date_to_sql(days as i32, &mut buf)
164                }),
165                PgType::DATE,
166            ),
167            OdbcColumnType::Time => (
168                self.value.map(|x| {
169                    let val = String::from_utf8_lossy(x.as_ref()).to_string();
170                    let time = NaiveTime::parse_from_str(val.as_str(), "%H:%M:%S%.f").unwrap();
171                    let delta = (time - NaiveTime::from_hms(0, 0, 0))
172                        .num_microseconds()
173                        .unwrap();
174                    pp_type::time_to_sql(delta, &mut buf)
175                }),
176                PgType::TIME,
177            ),
178            OdbcColumnType::Timestamp => (
179                self.value.map(|x| {
180                    let val = String::from_utf8_lossy(x.as_ref()).to_string();
181                    let date_time = NaiveDateTime::parse_from_str(
182                        val.as_str(),
183                        if val.contains('+') {
184                            "%Y-%m-%d %H:%M:%S%.f%#z"
185                        } else {
186                            "%Y-%m-%d %H:%M:%S%.f"
187                        },
188                    )
189                    .unwrap();
190                    let epoch = NaiveDate::from_ymd(2000, 1, 1).and_hms(0, 0, 0);
191                    let ms = (date_time - epoch).num_microseconds().unwrap();
192                    pp_type::timestamp_to_sql(ms, &mut buf)
193                }),
194                PgType::TIMESTAMP,
195            ),
196            OdbcColumnType::F64 => (
197                self.value.map(|x| {
198                    let val = &String::from_utf8_lossy(x.as_ref())
199                        .to_string()
200                        .parse::<f64>()
201                        .unwrap();
202                    pp_type::float8_to_sql(*val, &mut buf)
203                }),
204                PgType::FLOAT8,
205            ),
206            OdbcColumnType::F32 => (
207                self.value.map(|x| {
208                    let val = &String::from_utf8_lossy(x.as_ref())
209                        .to_string()
210                        .parse::<f32>()
211                        .unwrap();
212                    pp_type::float4_to_sql(*val, &mut buf)
213                }),
214                PgType::FLOAT4,
215            ),
216            OdbcColumnType::I8 => (
217                self.value.map(|x| {
218                    let val = &String::from_utf8_lossy(x.as_ref())
219                        .to_string()
220                        .parse::<i8>()
221                        .unwrap();
222                    pp_type::char_to_sql(*val, &mut buf)
223                }),
224                PgType::CHAR,
225            ),
226            OdbcColumnType::I16 => (
227                self.value.map(|x| {
228                    let val = &String::from_utf8_lossy(x.as_ref())
229                        .to_string()
230                        .parse::<i16>()
231                        .unwrap();
232                    pp_type::int2_to_sql(*val, &mut buf)
233                }),
234                PgType::INT2,
235            ),
236            OdbcColumnType::I32 => (
237                self.value.map(|x| {
238                    let val = &String::from_utf8_lossy(x.as_ref())
239                        .to_string()
240                        .parse::<i32>()
241                        .unwrap();
242                    pp_type::int4_to_sql(*val, &mut buf)
243                }),
244                PgType::INT4,
245            ),
246            OdbcColumnType::I64 => (
247                self.value.map(|x| {
248                    let val = &String::from_utf8_lossy(x.as_ref())
249                        .to_string()
250                        .parse::<i64>()
251                        .unwrap();
252                    pp_type::int8_to_sql(*val, &mut buf)
253                }),
254                PgType::INT8,
255            ),
256            OdbcColumnType::U8 => (
257                self.value.map(|x| {
258                    let val = x.as_ref().first().unwrap();
259                    pp_type::char_to_sql(*val as i8, &mut buf)
260                }),
261                PgType::CHAR,
262            ),
263            OdbcColumnType::Bit => (
264                self.value.map(|x| {
265                    let val = &String::from_utf8_lossy(x.as_ref())
266                        .to_string()
267                        .parse::<bool>()
268                        .unwrap();
269                    pp_type::bool_to_sql(*val, &mut buf)
270                }),
271                PgType::BOOL,
272            ),
273        };
274        PgColumnItem::new(buf, t)
275    }
276}
277
278impl From<QueryResult> for PgQueryResult {
279    fn from(result: QueryResult) -> Self {
280        PgQueryResult {
281            columns: result.columns.into_iter().map(|x| x.convert()).collect(),
282            data: result
283                .data
284                .into_iter()
285                .map(|x| x.into_iter().map(|x| x.convert()).collect())
286                .collect(),
287        }
288    }
289}
290
291impl Convert<PgType> for Oid {
292    fn convert(self) -> PgType {
293        PgType::from_oid(self).unwrap()
294    }
295}
296
297impl Convert<PgType> for PgType {
298    fn convert(self) -> PgType {
299        self
300    }
301}
302
303pub fn oid_typlen<C: Convert<PgType>>(c: C) -> i16 {
304    let pg_type = c.convert();
305    pg_helper::oid_typlen(pg_type)
306}
307
308impl TryConvert<PgTableDesc> for (TableDescResult, &Options) {
309    type Error = anyhow::Error;
310
311    fn try_convert(self) -> Result<PgTableDesc, Self::Error> {
312        let pg = match self.1.database {
313            SupportDatabase::Dameng => {
314                let dm = DmTableDesc::new(self.0 .0, self.0 .1)?;
315                let mut pg = BTreeMap::new();
316                for (k, v) in dm.data.into_iter() {
317                    let mut pg_item = Vec::new();
318                    for dm in v {
319                        pg_item.push(dm.try_convert()?)
320                    }
321                    pg.insert(k.to_string(), pg_item);
322                }
323                PgTableDesc { data: pg }
324            }
325            _ => PgTableDesc::default(),
326        };
327
328        Ok(pg)
329    }
330}
331
332impl TryConvert<PgColumnItem> for (&OdbcColumnItem, &PgColumn) {
333    type Error = String;
334
335    fn try_convert(self) -> Result<PgColumnItem, Self::Error> {
336        let odbc_data = self.0.value.clone();
337        let original_empty = odbc_data.is_none();
338        let pg_column = self.1;
339        let mut buf = BytesMut::new();
340
341        match pg_column.pg_type {
342            PgType::TEXT | PgType::VARCHAR => {
343                if let Some(x) = odbc_data {
344                    pp_type::text_to_sql(&String::from_utf8_lossy(x.as_ref()), &mut buf);
345                }
346            }
347
348            PgType::BYTEA => {
349                if let Some(x) = odbc_data {
350                    pp_type::bytea_to_sql(x.as_ref(), &mut buf);
351                }
352            }
353
354            PgType::DATE => {
355                if let Some(x) = odbc_data {
356                    let val = String::from_utf8_lossy(x.as_ref()).to_string();
357                    let date = NaiveDate::parse_from_str(val.as_str(), "%Y-%m-%d").unwrap();
358                    let days = (date - NaiveDate::from_ymd(2000, 1, 1)).num_days();
359                    if days > i64::from(i32::max_value()) || days < i64::from(i32::min_value()) {
360                        panic!("value too large to transmit");
361                    }
362                    pp_type::date_to_sql(days as i32, &mut buf);
363                }
364            }
365
366            PgType::TIME | PgType::TIMETZ => {
367                if let Some(x) = odbc_data {
368                    let val = String::from_utf8_lossy(x.as_ref()).to_string();
369                    let time = NaiveTime::parse_from_str(
370                        val.as_str(),
371                        if val.contains('+') {
372                            "%H:%M:%S%.f%#z"
373                        } else {
374                            "%H:%M:%S%.f"
375                        },
376                    )
377                    .unwrap();
378                    let delta = (time - NaiveTime::from_hms(0, 0, 0))
379                        .num_microseconds()
380                        .unwrap();
381                    pp_type::time_to_sql(delta, &mut buf);
382                }
383            }
384
385            PgType::TIMESTAMP | PgType::TIMESTAMPTZ => {
386                if let Some(x) = odbc_data {
387                    let val = String::from_utf8_lossy(x.as_ref()).to_string();
388                    let date_time = NaiveDateTime::parse_from_str(
389                        val.as_str(),
390                        if val.contains('+') {
391                            "%Y-%m-%d %H:%M:%S%.f%#z"
392                        } else {
393                            "%Y-%m-%d %H:%M:%S%.f"
394                        },
395                    )
396                    .unwrap();
397                    let epoch = NaiveDate::from_ymd(2000, 1, 1).and_hms(0, 0, 0);
398                    let ms = (date_time - epoch).num_microseconds().unwrap();
399                    pp_type::timestamp_to_sql(ms, &mut buf);
400                }
401            }
402            PgType::FLOAT8 => {
403                if let Some(x) = odbc_data {
404                    let val = &String::from_utf8_lossy(x.as_ref())
405                        .to_string()
406                        .parse::<f64>()
407                        .unwrap();
408                    pp_type::float8_to_sql(*val, &mut buf);
409                }
410            }
411            PgType::FLOAT4 => {
412                if let Some(x) = odbc_data {
413                    let val = &String::from_utf8_lossy(x.as_ref())
414                        .to_string()
415                        .parse::<f32>()
416                        .unwrap();
417                    pp_type::float4_to_sql(*val, &mut buf);
418                }
419            }
420            PgType::CHAR => {
421                if let Some(x) = odbc_data {
422                    let val = &String::from_utf8_lossy(x.as_ref())
423                        .to_string()
424                        .parse::<i8>()
425                        .unwrap();
426                    pp_type::char_to_sql(*val, &mut buf);
427                }
428            }
429            PgType::INT2 => {
430                if let Some(x) = odbc_data {
431                    let val = &String::from_utf8_lossy(x.as_ref())
432                        .to_string()
433                        .parse::<i16>()
434                        .unwrap();
435                    pp_type::int2_to_sql(*val, &mut buf);
436                }
437            }
438            PgType::INT4 | PgType::NUMERIC => {
439                if let Some(x) = odbc_data {
440                    let val = &String::from_utf8_lossy(x.as_ref())
441                        .to_string()
442                        .parse::<i32>()
443                        .unwrap();
444                    pp_type::int4_to_sql(*val, &mut buf);
445                }
446            }
447            PgType::INT8 => {
448                if let Some(x) = odbc_data {
449                    let val = &String::from_utf8_lossy(x.as_ref())
450                        .to_string()
451                        .parse::<i64>()
452                        .unwrap();
453                    pp_type::int8_to_sql(*val, &mut buf);
454                }
455            }
456            PgType::BOOL | PgType::BIT => {
457                if let Some(x) = odbc_data {
458                    let val = &String::from_utf8_lossy(x.as_ref())
459                        .to_string()
460                        .parse::<bool>()
461                        .unwrap();
462                    pp_type::bool_to_sql(*val, &mut buf);
463                }
464            }
465            _ => {}
466        };
467
468        if original_empty && buf.is_empty() {
469            return Ok(PgColumnItem::new_pg_type(pg_column.pg_type.clone()));
470        }
471
472        Ok(PgColumnItem::new(buf, pg_column.pg_type.clone()))
473    }
474}
475
476impl TryConvert<PgQueryResult> for (QueryResult, &Vec<PgTableItem>, &Options) {
477    type Error = String;
478
479    fn try_convert(self) -> Result<PgQueryResult, Self::Error> {
480        let res = self.0;
481        let pg_all_columns = self.1;
482        let options = self.2;
483        let mut result = PgQueryResult::default();
484        if let Ok(cols) = <(&Vec<OdbcColumn>, &Vec<PgTableItem>, &Options) as TryConvert<
485            Vec<PgColumn>,
486        >>::try_convert((&res.columns, pg_all_columns, options))
487        {
488            let cols: Vec<PgColumn> = cols;
489            result.columns = cols;
490
491            // if column name is count(*),but this name not exist Vec<PgTableItem>
492            // So,could find result.columns is empty.
493            if result.columns.is_empty() {
494                return Ok(PgQueryResult::from(res));
495            }
496
497            if let crate::executor::SupportDatabase::Dameng = options.database {
498                for v in res.data.iter() {
499                    let mut row: Vec<PgColumnItem> = vec![];
500                    for (index, odbc_item) in v.iter().enumerate() {
501                        if let Some(col) = result.columns.get(index) {
502                            row.push(
503                                <(&OdbcColumnItem, &PgColumn) as TryConvert<PgColumnItem>>::try_convert((
504                                    odbc_item, col,
505                                ))
506                                    .unwrap(),
507                            );
508                        }
509                    }
510                    result.data.push(row);
511                }
512            }
513        }
514        Ok(result)
515    }
516}
517
518impl TryConvert<Vec<PgColumn>> for (&Vec<OdbcColumn>, &Vec<PgTableItem>, &Options) {
519    type Error = String;
520
521    fn try_convert(self) -> Result<Vec<PgColumn>, Self::Error> {
522        let odbc_columns = self.0;
523        let pg_all_columns = self.1;
524        let options = self.2;
525        let mut result = vec![];
526        for v in odbc_columns.iter() {
527            let find_name = |source: &str, target: &str| -> bool {
528                if options.case_sensitive {
529                    source == target
530                } else {
531                    source.to_uppercase() == target.to_uppercase()
532                }
533            };
534
535            if let Some(pg) = pg_all_columns.iter().find(|&p| find_name(&p.name, &v.name)) {
536                result.push(PgColumn {
537                    name: pg.name.clone(),
538                    pg_type: pg.r#type.clone(),
539                    oid: pg.r#type.oid(),
540                    nullable: pg.nullable,
541                });
542            } else {
543                result.push(v.clone().convert());
544            }
545        }
546
547        Ok(result)
548    }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use odbc_api::DataType;

    /// A NULL cell in a Dameng column that exists in the catalog as `VARCHAR`
    /// keeps the catalog type (OID 1043) and stays NULL.
    #[test]
    fn test_query_result_convert() {
        let options = Options {
            db_name: "test_db".to_string(),
            database: SupportDatabase::Dameng,
            max_batch_size: 1024,
            max_str_len: 1024,
            max_binary_len: 1024,
            case_sensitive: false,
        };
        let catalog = vec![PgTableItem {
            name: "trace_id".to_string(),
            table_id: 0,
            col_index: 0,
            r#type: PgType::VARCHAR,
            length: 255,
            scale: 0,
            nullable: true,
            default_val: None,
            table_name: "".to_string(),
            create_time: "".to_string(),
        }];
        let query_result = QueryResult {
            columns: vec![OdbcColumn {
                name: "trace_id".to_string(),
                data_type: DataType::Varchar { length: 255 },
                nullable: true,
            }],
            data: vec![vec![OdbcColumnItem {
                odbc_type: OdbcColumnType::Text,
                value: None,
            }]],
        };

        let expected = PgQueryResult {
            columns: vec![PgColumn {
                name: "trace_id".to_string(),
                pg_type: PgType::VARCHAR,
                oid: 1043,
                nullable: true,
            }],
            data: vec![vec![PgColumnItem {
                data: None,
                pg_type: PgType::VARCHAR,
                oid: 1043,
            }]],
        };

        let actual: PgQueryResult = (query_result, &catalog, &options)
            .try_convert()
            .unwrap();
        assert_eq!(actual, expected);
    }
}
612}