clickhouse_rs_async/types/mod.rs

1use std::{borrow::Cow, collections::HashMap, fmt, mem, pin::Pin, str::FromStr, sync::Mutex};
2
3use chrono::prelude::*;
4use chrono_tz::Tz;
5use hostname::get;
6
7use lazy_static::lazy_static;
8
9use crate::{errors::ServerError, types::column::datetime64::DEFAULT_TZ};
10
11pub use self::{
12    block::{Block, RCons, RNil, Row, RowBuilder, Rows},
13    column::{Column, ColumnType, Complex, Simple},
14    decimal::Decimal,
15    enums::{Enum16, Enum8},
16    from_sql::{FromSql, FromSqlResult},
17    options::Options,
18    options::{SettingType, SettingValue},
19    query::Query,
20    query_result::QueryResult,
21    value::Value,
22    value_ref::ValueRef,
23};
24
25pub(crate) use self::{
26    cmd::Cmd,
27    date_converter::DateConverter,
28    marshal::Marshal,
29    options::{IntoOptions, OptionsSource},
30    stat_buffer::StatBuffer,
31    unmarshal::Unmarshal,
32};
33
34pub mod column;
35mod marshal;
36mod stat_buffer;
37mod unmarshal;
38
39mod from_sql;
40mod value;
41mod value_ref;
42
43pub(crate) mod block;
44mod cmd;
45
46mod date_converter;
47mod query;
48pub(crate) mod query_result;
49
50mod decimal;
51mod enums;
52mod options;
53
/// Payload of a [`Packet::Progress`] packet.
///
/// The fields mirror the server's progress counters by name; whether they are cumulative or
/// per-packet deltas is decided by the code that decodes the packet — TODO confirm there.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub(crate) struct Progress {
    /// Rows counter reported by the server.
    pub rows: u64,
    /// Bytes counter reported by the server.
    pub bytes: u64,
    /// Total-rows counter reported by the server.
    pub total_rows: u64,
    /// Written-rows counter reported by the server.
    pub written_rows: u64,
    /// Written-bytes counter reported by the server.
    pub written_bytes: u64,
}
62
/// Payload of a [`Packet::ProfileInfo`] packet.
///
/// Field meanings follow their names; the decoder that fills this struct is the authority on
/// exact semantics — TODO confirm there.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub(crate) struct ProfileInfo {
    /// Row count reported in the profile.
    pub rows: u64,
    /// Byte count reported in the profile.
    pub bytes: u64,
    /// Block count reported in the profile.
    pub blocks: u64,
    /// Whether the server reports that a limit was applied.
    pub applied_limit: bool,
    /// Row count before the limit, as reported by the server.
    pub rows_before_limit: u64,
    /// Whether `rows_before_limit` was calculated by the server.
    pub calculated_rows_before_limit: bool,
}
72
/// Payload of a [`Packet::TableColumns`] packet.
#[derive(Clone, Debug, Default, PartialEq)]
pub(crate) struct TableColumns {
    /// Name of the table the description belongs to.
    pub table_name: String,
    /// Column description kept as a single string — presumably verbatim from the server; verify.
    pub columns: String,
}
78
79#[derive(Clone, PartialEq)]
80pub(crate) struct ServerInfo {
81    pub name: String,
82    pub revision: u64,
83    pub minor_version: u64,
84    pub major_version: u64,
85    pub timezone: Tz,
86    pub display_name: String,
87    pub patch_version: u64,
88}
89
90impl fmt::Debug for ServerInfo {
91    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
92        write!(
93            f,
94            "{} {}.{}.{}.{} ({:?})",
95            self.name,
96            self.major_version,
97            self.minor_version,
98            self.revision,
99            self.patch_version,
100            self.timezone
101        )
102    }
103}
104
105#[derive(Clone)]
106pub(crate) struct Context {
107    pub(crate) server_info: ServerInfo,
108    pub(crate) hostname: String,
109    pub(crate) options: OptionsSource,
110}
111
112impl Default for ServerInfo {
113    fn default() -> Self {
114        Self {
115            name: String::new(),
116            revision: 0,
117            minor_version: 0,
118            major_version: 0,
119            timezone: *DEFAULT_TZ,
120            display_name: "".into(),
121            patch_version: 0,
122        }
123    }
124}
125
126impl fmt::Debug for Context {
127    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
128        f.debug_struct("Context")
129            .field("options", &self.options)
130            .field("hostname", &self.hostname)
131            .finish()
132    }
133}
134
135impl Default for Context {
136    fn default() -> Self {
137        Self {
138            server_info: ServerInfo::default(),
139            hostname: get().unwrap().into_string().unwrap(),
140            options: OptionsSource::default(),
141        }
142    }
143}
144
145#[derive(Clone)]
146pub(crate) enum Packet<S> {
147    Hello(S, ServerInfo),
148    Pong(S),
149    Progress(Progress),
150    ProfileInfo(ProfileInfo),
151    TableColumns(TableColumns),
152    Exception(ServerError),
153    Block(Block),
154    Eof(S),
155}
156
157impl<S> fmt::Debug for Packet<S> {
158    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
159        match self {
160            Packet::Hello(_, info) => write!(f, "Hello({info:?})"),
161            Packet::Pong(_) => write!(f, "Pong"),
162            Packet::Progress(p) => write!(f, "Progress({p:?})"),
163            Packet::ProfileInfo(info) => write!(f, "ProfileInfo({info:?})"),
164            Packet::TableColumns(info) => write!(f, "TableColumns({info:?})"),
165            Packet::Exception(e) => write!(f, "Exception({e:?})"),
166            Packet::Block(b) => write!(f, "Block({b:?})"),
167            Packet::Eof(_) => write!(f, "Eof"),
168        }
169    }
170}
171
172impl<S> Packet<S> {
173    pub fn bind<N>(self, transport: &mut Option<N>) -> Packet<N> {
174        match self {
175            Packet::Hello(_, server_info) => Packet::Hello(transport.take().unwrap(), server_info),
176            Packet::Pong(_) => Packet::Pong(transport.take().unwrap()),
177            Packet::Progress(progress) => Packet::Progress(progress),
178            Packet::ProfileInfo(profile_info) => Packet::ProfileInfo(profile_info),
179            Packet::TableColumns(table_columns) => Packet::TableColumns(table_columns),
180            Packet::Exception(exception) => Packet::Exception(exception),
181            Packet::Block(block) => Packet::Block(block),
182            Packet::Eof(_) => Packet::Eof(transport.take().unwrap()),
183        }
184    }
185}
186
187pub trait HasSqlType {
188    fn get_sql_type() -> SqlType;
189}
190
// Generates `impl HasSqlType for <type>` from a comma-separated list of
// `<rust type> : <SqlType expression>` pairs (no trailing comma).
macro_rules! has_sql_type {
    ( $( $rust_ty:ty : $sql_ty:expr ),* ) => {
        $(
            impl HasSqlType for $rust_ty {
                fn get_sql_type() -> SqlType { $sql_ty }
            }
        )*
    };
}
202
203has_sql_type! {
204    bool: SqlType::Bool,
205    u8: SqlType::UInt8,
206    u16: SqlType::UInt16,
207    u32: SqlType::UInt32,
208    u64: SqlType::UInt64,
209    u128: SqlType::UInt128,
210    i8: SqlType::Int8,
211    i16: SqlType::Int16,
212    i32: SqlType::Int32,
213    i64: SqlType::Int64,
214    i128: SqlType::Int128,
215    &str: SqlType::String,
216    String: SqlType::String,
217    f32: SqlType::Float32,
218    f64: SqlType::Float64,
219    NaiveDate: SqlType::Date,
220    DateTime<Tz>: SqlType::DateTime(DateTimeType::DateTime32)
221}
222
223impl<K, V> HasSqlType for HashMap<K, V>
224where
225    K: HasSqlType,
226    V: HasSqlType,
227{
228    fn get_sql_type() -> SqlType {
229        SqlType::Map(K::get_sql_type().into(), V::get_sql_type().into())
230    }
231}
232
233#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
234pub enum DateTimeType {
235    DateTime32,
236    DateTime64(u32, Tz),
237    Chrono,
238}
239
/// Aggregate function of a `SimpleAggregateFunction(func, type)` column.
///
/// Each variant corresponds to a camelCase ClickHouse function name (`AnyLast` is `anyLast`,
/// `GroupBitXor` is `groupBitXor`, ...). The mapping is provided by
/// `From<SimpleAggFunc> for &str` and `FromStr`. Ordering follows declaration order.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Hash)]
pub enum SimpleAggFunc {
    Any,
    AnyLast,
    Min,
    Max,
    Sum,
    SumWithOverflow,
    GroupBitAnd,
    GroupBitOr,
    GroupBitXor,
    GroupArrayArray,
    GroupUniqArrayArray,
    SumMap,
    MinMap,
    MaxMap,
    ArgMin,
    ArgMax,
}
259
260impl From<SimpleAggFunc> for &str {
261    fn from(source: SimpleAggFunc) -> &'static str {
262        match source {
263            SimpleAggFunc::Any => "any",
264            SimpleAggFunc::AnyLast => "anyLast",
265            SimpleAggFunc::Min => "min",
266            SimpleAggFunc::Max => "max",
267            SimpleAggFunc::Sum => "sum",
268            SimpleAggFunc::SumWithOverflow => "sumWithOverflow",
269            SimpleAggFunc::GroupBitAnd => "groupBitAnd",
270            SimpleAggFunc::GroupBitOr => "groupBitOr",
271            SimpleAggFunc::GroupBitXor => "groupBitXor",
272            SimpleAggFunc::GroupArrayArray => "groupArrayArray",
273            SimpleAggFunc::GroupUniqArrayArray => "groupUniqArrayArray",
274            SimpleAggFunc::SumMap => "sumMap",
275            SimpleAggFunc::MinMap => "minMap",
276            SimpleAggFunc::MaxMap => "maxMap",
277            SimpleAggFunc::ArgMin => "argMin",
278            SimpleAggFunc::ArgMax => "argMax",
279        }
280    }
281}
282
283impl FromStr for SimpleAggFunc {
284    type Err = ();
285
286    fn from_str(s: &str) -> Result<Self, Self::Err> {
287        match s {
288            "any" => Ok(SimpleAggFunc::Any),
289            "anyLast" => Ok(SimpleAggFunc::AnyLast),
290            "min" => Ok(SimpleAggFunc::Min),
291            "max" => Ok(SimpleAggFunc::Max),
292            "sum" => Ok(SimpleAggFunc::Sum),
293            "sumWithOverflow" => Ok(SimpleAggFunc::SumWithOverflow),
294            "groupBitAnd" => Ok(SimpleAggFunc::GroupBitAnd),
295            "groupBitOr" => Ok(SimpleAggFunc::GroupBitOr),
296            "groupBitXor" => Ok(SimpleAggFunc::GroupBitXor),
297            "groupArrayArray" => Ok(SimpleAggFunc::GroupArrayArray),
298            "groupUniqArrayArray" => Ok(SimpleAggFunc::GroupUniqArrayArray),
299            "sumMap" => Ok(SimpleAggFunc::SumMap),
300            "minMap" => Ok(SimpleAggFunc::MinMap),
301            "maxMap" => Ok(SimpleAggFunc::MaxMap),
302            "argMin" => Ok(SimpleAggFunc::ArgMin),
303            "argMax" => Ok(SimpleAggFunc::ArgMax),
304            _ => Err(()),
305        }
306    }
307}
308
309#[derive(Clone, Debug, Eq, PartialEq, Hash)]
310pub enum SqlType {
311    Bool,
312    UInt8,
313    UInt16,
314    UInt32,
315    UInt64,
316    UInt128,
317    Int8,
318    Int16,
319    Int32,
320    Int64,
321    Int128,
322    String,
323    FixedString(usize),
324    Float32,
325    Float64,
326    Date,
327    DateTime(DateTimeType),
328    Ipv4,
329    Ipv6,
330    Uuid,
331    Nullable(&'static SqlType),
332    Array(&'static SqlType),
333    LowCardinality(&'static SqlType),
334    Decimal(u8, u8),
335    Enum8(Vec<(String, i8)>),
336    Enum16(Vec<(String, i16)>),
337    SimpleAggregateFunction(SimpleAggFunc, &'static SqlType),
338    Map(&'static SqlType, &'static SqlType),
339}
340
341lazy_static! {
342    static ref TYPES_CACHE: Mutex<HashMap<SqlType, Pin<Box<SqlType>>>> = Mutex::new(HashMap::new());
343}
344
345impl From<SqlType> for &'static SqlType {
346    fn from(value: SqlType) -> Self {
347        match value {
348            SqlType::UInt8 => &SqlType::UInt8,
349            SqlType::UInt16 => &SqlType::UInt16,
350            SqlType::UInt32 => &SqlType::UInt32,
351            SqlType::UInt64 => &SqlType::UInt64,
352            SqlType::Int8 => &SqlType::Int8,
353            SqlType::Int16 => &SqlType::Int16,
354            SqlType::Int32 => &SqlType::Int32,
355            SqlType::Int64 => &SqlType::Int64,
356            SqlType::String => &SqlType::String,
357            SqlType::Float32 => &SqlType::Float32,
358            SqlType::Float64 => &SqlType::Float64,
359            SqlType::Date => &SqlType::Date,
360            _ => {
361                let mut guard = TYPES_CACHE.lock().unwrap();
362                loop {
363                    if let Some(value_ref) = guard.get(&value.clone()) {
364                        return unsafe { mem::transmute(value_ref.as_ref()) };
365                    }
366                    guard.insert(value.clone(), Box::pin(value.clone()));
367                }
368            }
369        }
370    }
371}
372
373impl SqlType {
374    pub(crate) fn is_datetime(&self) -> bool {
375        matches!(self, SqlType::DateTime(_))
376    }
377
378    pub(crate) fn is_inner_low_cardinality(&self) -> bool {
379        matches!(
380            self,
381            SqlType::String
382                | SqlType::FixedString(_)
383                | SqlType::Date
384                | SqlType::DateTime(_)
385                | SqlType::UInt8
386                | SqlType::UInt16
387                | SqlType::UInt32
388                | SqlType::UInt64
389                | SqlType::Int8
390                | SqlType::Int16
391                | SqlType::Int32
392                | SqlType::Int64
393        )
394    }
395
396    pub fn to_string(&self) -> Cow<'static, str> {
397        match self.clone() {
398            SqlType::Bool => "Bool".into(),
399            SqlType::UInt8 => "UInt8".into(),
400            SqlType::UInt16 => "UInt16".into(),
401            SqlType::UInt32 => "UInt32".into(),
402            SqlType::UInt64 => "UInt64".into(),
403            SqlType::UInt128 => "UInt128".into(),
404            SqlType::Int8 => "Int8".into(),
405            SqlType::Int16 => "Int16".into(),
406            SqlType::Int32 => "Int32".into(),
407            SqlType::Int64 => "Int64".into(),
408            SqlType::Int128 => "Int128".into(),
409            SqlType::String => "String".into(),
410            SqlType::FixedString(str_len) => format!("FixedString({str_len})").into(),
411            SqlType::LowCardinality(inner) => format!("LowCardinality({})", &inner).into(),
412            SqlType::Float32 => "Float32".into(),
413            SqlType::Float64 => "Float64".into(),
414            SqlType::Date => "Date".into(),
415            SqlType::DateTime(DateTimeType::DateTime64(precision, tz)) => {
416                format!("DateTime64({precision}, '{tz:?}')").into()
417            }
418            SqlType::DateTime(_) => "DateTime".into(),
419            SqlType::Ipv4 => "IPv4".into(),
420            SqlType::Ipv6 => "IPv6".into(),
421            SqlType::Uuid => "UUID".into(),
422            SqlType::Nullable(nested) => format!("Nullable({})", &nested).into(),
423            SqlType::SimpleAggregateFunction(func, nested) => {
424                let func_str: &str = func.into();
425                format!("SimpleAggregateFunction({}, {})", func_str, &nested).into()
426            }
427            SqlType::Array(nested) => format!("Array({})", &nested).into(),
428            SqlType::Decimal(precision, scale) => format!("Decimal({precision}, {scale})").into(),
429            SqlType::Enum8(values) => {
430                let a: Vec<String> = values
431                    .iter()
432                    .map(|(name, value)| format!("'{name}' = {value}"))
433                    .collect();
434                format!("Enum8({})", a.join(",")).into()
435            }
436            SqlType::Enum16(values) => {
437                let a: Vec<String> = values
438                    .iter()
439                    .map(|(name, value)| format!("'{name}' = {value}"))
440                    .collect();
441                format!("Enum16({})", a.join(",")).into()
442            }
443            SqlType::Map(k, v) => format!("Map({}, {})", &k, &v).into(),
444        }
445    }
446
447    pub(crate) fn level(&self) -> u8 {
448        match self {
449            SqlType::Nullable(inner) => 1 + inner.level(),
450            SqlType::Array(inner) => 1 + inner.level(),
451            SqlType::Map(_, value) => 1 + value.level(),
452            SqlType::LowCardinality(_) => 1,
453            _ => 0,
454        }
455    }
456
457    pub(crate) fn map_level(&self) -> u8 {
458        match self {
459            SqlType::Nullable(inner) => inner.level(),
460            SqlType::Array(inner) => inner.level(),
461            SqlType::Map(_, value) => 1 + value.level(),
462            _ => 0,
463        }
464    }
465}
466
467impl fmt::Display for SqlType {
468    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
469        write!(f, "{}", Self::to_string(self))
470    }
471}
472
#[test]
fn test_display() {
    // `Display` must render a scalar type as its bare ClickHouse name.
    let rendered = format!("{}", SqlType::UInt8);
    assert_eq!(rendered, "UInt8".to_string());
}
479
#[test]
fn test_to_string() {
    // A wrapper type renders its inner type in parentheses.
    let nullable_u8 = SqlType::Nullable(&SqlType::UInt8);
    assert_eq!(Cow::<'static, str>::Borrowed("Nullable(UInt8)"), nullable_u8.to_string());
}