Skip to main content

connector_arrow/tiberius/
schema.rs

1use std::borrow::Cow;
2use std::sync::Arc;
3
4use arrow::datatypes::{DataType, Schema, SchemaRef};
5use futures::{AsyncRead, AsyncWrite};
6use itertools::Itertools;
7
8use crate::api::{Connector, SchemaEdit, SchemaGet};
9use crate::util::escape::escaped_ident;
10use crate::{ConnectorError, TableCreateError, TableDropError};
11
12impl<S: AsyncRead + AsyncWrite + Unpin + Send> SchemaGet for super::TiberiusConnection<S> {
13    fn table_list(&mut self) -> Result<Vec<String>, ConnectorError> {
14        let query = "
15            SELECT TABLE_NAME
16            FROM INFORMATION_SCHEMA.TABLES
17            WHERE
18                TABLE_CATALOG = DB_NAME() AND
19                TABLE_SCHEMA = SCHEMA_NAME() AND
20                TABLE_TYPE='BASE TABLE'
21            ORDER BY TABLE_NAME
22        ";
23        let res = self.client.query(query, &[]);
24        let res = self.rt.block_on(res)?;
25
26        let res = res.into_first_result();
27        let res = self.rt.block_on(res)?;
28
29        let table_names = res
30            .into_iter()
31            .map(|r| r.get::<&str, _>(0).unwrap().to_string())
32            .collect_vec();
33
34        Ok(table_names)
35    }
36
37    fn table_get(
38        &mut self,
39        table_name: &str,
40    ) -> Result<arrow::datatypes::SchemaRef, ConnectorError> {
41        let query = "
42            SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, NUMERIC_PRECISION
43            FROM INFORMATION_SCHEMA.COLUMNS
44            WHERE
45                TABLE_CATALOG = DB_NAME() AND
46                TABLE_SCHEMA = SCHEMA_NAME() AND
47                TABLE_NAME = @P1
48            ORDER BY ORDINAL_POSITION;
49        ";
50        let params: [&dyn tiberius::ToSql; 1] = [&table_name.to_string()];
51        let res = self.client.query(query, &params);
52        let res = self.rt.block_on(res)?;
53
54        let res = res.into_first_result();
55        let res = self.rt.block_on(res)?;
56
57        let fields: Vec<_> = res
58            .into_iter()
59            .map(|row| -> Result<_, ConnectorError> {
60                let name: &str = row.get(0).unwrap();
61                let data_type: &str = row.get(1).unwrap();
62                let is_nullable: bool = row.get::<&str, _>(2).unwrap() != "NO";
63                let numeric_precision: Option<u8> = row.get(3);
64
65                let db_type_name = if let Some(numeric_precision) = numeric_precision {
66                    Cow::from(format!("{data_type}({numeric_precision})"))
67                } else {
68                    Cow::from(data_type)
69                };
70
71                Ok(super::types::create_field(name, &db_type_name, is_nullable))
72            })
73            .try_collect()?;
74
75        Ok(Arc::new(Schema::new(fields)))
76    }
77}
78
79impl<S: AsyncRead + AsyncWrite + Unpin + Send> SchemaEdit for super::TiberiusConnection<S> {
80    fn table_create(&mut self, name: &str, schema: SchemaRef) -> Result<(), TableCreateError> {
81        let column_defs = schema
82            .fields()
83            .iter()
84            .map(|field| {
85                let ty_name = Self::type_arrow_into_db(field.data_type()).unwrap_or_else(|| {
86                    unimplemented!("cannot store type {} in MS SQL Server", field.data_type());
87                });
88
89                let is_nullable =
90                    field.is_nullable() || matches!(field.data_type(), DataType::Null);
91                let not_null = if is_nullable { "" } else { " NOT NULL" };
92
93                let name = escaped_ident(field.name());
94                format!("{name} {ty_name}{not_null}",)
95            })
96            .join(",");
97
98        let ddl = format!("CREATE TABLE {} ({column_defs});", escaped_ident(name));
99
100        let res = self.client.execute(&ddl, &[]);
101        let res = self.rt.block_on(res);
102
103        match res {
104            Ok(_) => Ok(()),
105            Err(tiberius::error::Error::Server(e)) if e.code() == 2714 => {
106                Err(TableCreateError::TableExists)
107            }
108            Err(e) => Err(TableCreateError::Connector(e.into())),
109        }
110    }
111
112    fn table_drop(&mut self, name: &str) -> Result<(), TableDropError> {
113        let ddl = format!("DROP TABLE {}", escaped_ident(name));
114        let res = self.client.execute(&ddl, &[]);
115        let res = self.rt.block_on(res);
116
117        match res {
118            Ok(_) => Ok(()),
119            Err(tiberius::error::Error::Server(e)) if e.code() == 3701 => {
120                Err(TableDropError::TableNonexistent)
121            }
122            Err(e) => Err(TableDropError::Connector(e.into())),
123        }
124    }
125}