1use std::{borrow::Cow, collections::HashMap, fmt, mem, pin::Pin, str::FromStr, sync::Mutex};
2
3use chrono::prelude::*;
4use chrono_tz::Tz;
5use hostname::get;
6
7use lazy_static::lazy_static;
8
9use crate::{errors::ServerError, types::column::datetime64::DEFAULT_TZ};
10
11pub use self::{
12 block::{Block, RCons, RNil, Row, RowBuilder, Rows},
13 column::{Column, ColumnType, Complex, Simple},
14 decimal::Decimal,
15 enums::{Enum16, Enum8},
16 from_sql::{FromSql, FromSqlResult},
17 options::Options,
18 options::{SettingType, SettingValue},
19 query::Query,
20 query_result::QueryResult,
21 value::Value,
22 value_ref::ValueRef,
23};
24
25pub(crate) use self::{
26 cmd::Cmd,
27 date_converter::DateConverter,
28 marshal::Marshal,
29 options::{IntoOptions, OptionsSource},
30 stat_buffer::StatBuffer,
31 unmarshal::Unmarshal,
32};
33
34pub mod column;
35mod marshal;
36mod stat_buffer;
37mod unmarshal;
38
39mod from_sql;
40mod value;
41mod value_ref;
42
43pub(crate) mod block;
44mod cmd;
45
46mod date_converter;
47mod query;
48pub(crate) mod query_result;
49
50mod decimal;
51mod enums;
52mod options;
53
/// Progress counters carried by a `Packet::Progress` payload.
///
/// NOTE(review): whether these are per-packet increments or running totals is
/// decided by the protocol reader, which is not visible here — confirm before
/// aggregating them.
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub(crate) struct Progress {
    /// Rows counter reported by the server (presumably rows read — TODO confirm).
    pub rows: u64,
    /// Bytes counter reported by the server (presumably bytes read — TODO confirm).
    pub bytes: u64,
    /// Total-rows counter reported by the server (presumably an estimate — TODO confirm).
    pub total_rows: u64,
    /// Written-rows counter reported by the server.
    pub written_rows: u64,
    /// Written-bytes counter reported by the server.
    pub written_bytes: u64,
}
62
/// Query profiling summary carried by a `Packet::ProfileInfo` payload.
#[derive(Copy, Clone, Default, Debug, PartialEq)]
pub(crate) struct ProfileInfo {
    /// Row count reported by the server.
    pub rows: u64,
    /// Byte count reported by the server.
    pub bytes: u64,
    /// Block count reported by the server.
    pub blocks: u64,
    /// Flag reported by the server; presumably whether a LIMIT was applied — TODO confirm.
    pub applied_limit: bool,
    /// Row count reported by the server; presumably rows seen before LIMIT — TODO confirm.
    pub rows_before_limit: u64,
    /// Flag reported by the server; presumably whether `rows_before_limit` is meaningful — TODO confirm.
    pub calculated_rows_before_limit: bool,
}
72
/// Table-column description carried by a `Packet::TableColumns` payload.
#[derive(Clone, Default, Debug, PartialEq)]
pub(crate) struct TableColumns {
    /// Name of the table the description refers to.
    pub table_name: String,
    /// Column description as a single string, kept unparsed (format decided by the server).
    pub columns: String,
}
78
/// Identity and version of the connected server, carried by `Packet::Hello`.
#[derive(Clone, PartialEq)]
pub(crate) struct ServerInfo {
    /// Server name as reported by the server.
    pub name: String,
    /// Protocol/server revision number.
    pub revision: u64,
    /// Minor version component.
    pub minor_version: u64,
    /// Major version component.
    pub major_version: u64,
    /// Server timezone; defaults to the crate's `DEFAULT_TZ`.
    pub timezone: Tz,
    /// Human-readable server name (empty by default).
    pub display_name: String,
    /// Patch version component.
    pub patch_version: u64,
}
89
90impl fmt::Debug for ServerInfo {
91 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
92 write!(
93 f,
94 "{} {}.{}.{}.{} ({:?})",
95 self.name,
96 self.major_version,
97 self.minor_version,
98 self.revision,
99 self.patch_version,
100 self.timezone
101 )
102 }
103}
104
/// Bundles what the server reported about itself, this client's hostname and
/// the connection options.
#[derive(Clone)]
pub(crate) struct Context {
    /// Server identity/version; `ServerInfo::default()` until filled in.
    pub(crate) server_info: ServerInfo,
    /// Local machine hostname, looked up via `hostname::get` in `Default`.
    pub(crate) hostname: String,
    /// Connection options source.
    pub(crate) options: OptionsSource,
}
111
112impl Default for ServerInfo {
113 fn default() -> Self {
114 Self {
115 name: String::new(),
116 revision: 0,
117 minor_version: 0,
118 major_version: 0,
119 timezone: *DEFAULT_TZ,
120 display_name: "".into(),
121 patch_version: 0,
122 }
123 }
124}
125
126impl fmt::Debug for Context {
127 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
128 f.debug_struct("Context")
129 .field("options", &self.options)
130 .field("hostname", &self.hostname)
131 .finish()
132 }
133}
134
135impl Default for Context {
136 fn default() -> Self {
137 Self {
138 server_info: ServerInfo::default(),
139 hostname: get().unwrap().into_string().unwrap(),
140 options: OptionsSource::default(),
141 }
142 }
143}
144
/// A packet decoded from the server.
///
/// `S` is a transport value carried only by `Hello`, `Pong` and `Eof`;
/// `Packet::bind` swaps it for a transport of another type. The other
/// variants carry only their payload.
#[derive(Clone)]
pub(crate) enum Packet<S> {
    /// Transport plus the server's self-description.
    Hello(S, ServerInfo),
    /// Transport only.
    Pong(S),
    /// Progress counters.
    Progress(Progress),
    /// Profiling summary.
    ProfileInfo(ProfileInfo),
    /// Table-column description.
    TableColumns(TableColumns),
    /// Error reported by the server.
    Exception(ServerError),
    /// A block of data.
    Block(Block),
    /// Transport only.
    Eof(S),
}
156
157impl<S> fmt::Debug for Packet<S> {
158 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
159 match self {
160 Packet::Hello(_, info) => write!(f, "Hello({info:?})"),
161 Packet::Pong(_) => write!(f, "Pong"),
162 Packet::Progress(p) => write!(f, "Progress({p:?})"),
163 Packet::ProfileInfo(info) => write!(f, "ProfileInfo({info:?})"),
164 Packet::TableColumns(info) => write!(f, "TableColumns({info:?})"),
165 Packet::Exception(e) => write!(f, "Exception({e:?})"),
166 Packet::Block(b) => write!(f, "Block({b:?})"),
167 Packet::Eof(_) => write!(f, "Eof"),
168 }
169 }
170}
171
172impl<S> Packet<S> {
173 pub fn bind<N>(self, transport: &mut Option<N>) -> Packet<N> {
174 match self {
175 Packet::Hello(_, server_info) => Packet::Hello(transport.take().unwrap(), server_info),
176 Packet::Pong(_) => Packet::Pong(transport.take().unwrap()),
177 Packet::Progress(progress) => Packet::Progress(progress),
178 Packet::ProfileInfo(profile_info) => Packet::ProfileInfo(profile_info),
179 Packet::TableColumns(table_columns) => Packet::TableColumns(table_columns),
180 Packet::Exception(exception) => Packet::Exception(exception),
181 Packet::Block(block) => Packet::Block(block),
182 Packet::Eof(_) => Packet::Eof(transport.take().unwrap()),
183 }
184 }
185}
186
/// Associates a Rust type with the `SqlType` used to represent it.
///
/// Implemented in this module for primitives, strings and dates via
/// `has_sql_type!`, and for `HashMap<K, V>` as a `Map` type.
pub trait HasSqlType {
    /// Returns the `SqlType` corresponding to `Self`.
    fn get_sql_type() -> SqlType;
}
190
/// Implements `HasSqlType` for each listed Rust type.
///
/// Input is a comma-separated list of `Type: expr` pairs, where `expr` is the
/// `SqlType` value that `get_sql_type` returns for that type. A trailing comma
/// is accepted.
macro_rules! has_sql_type {
    ( $( $t:ty : $k:expr ),* $(,)? ) => {
        $(
            impl HasSqlType for $t {
                fn get_sql_type() -> SqlType {
                    $k
                }
            }
        )*
    };
}
202
203has_sql_type! {
204 bool: SqlType::Bool,
205 u8: SqlType::UInt8,
206 u16: SqlType::UInt16,
207 u32: SqlType::UInt32,
208 u64: SqlType::UInt64,
209 u128: SqlType::UInt128,
210 i8: SqlType::Int8,
211 i16: SqlType::Int16,
212 i32: SqlType::Int32,
213 i64: SqlType::Int64,
214 i128: SqlType::Int128,
215 &str: SqlType::String,
216 String: SqlType::String,
217 f32: SqlType::Float32,
218 f64: SqlType::Float64,
219 NaiveDate: SqlType::Date,
220 DateTime<Tz>: SqlType::DateTime(DateTimeType::DateTime32)
221}
222
223impl<K, V> HasSqlType for HashMap<K, V>
224where
225 K: HasSqlType,
226 V: HasSqlType,
227{
228 fn get_sql_type() -> SqlType {
229 SqlType::Map(K::get_sql_type().into(), V::get_sql_type().into())
230 }
231}
232
/// Which flavour of date-time a `SqlType::DateTime` refers to.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum DateTimeType {
    /// Plain `DateTime`; this is what `DateTime<Tz>` maps to via `HasSqlType`.
    DateTime32,
    /// `DateTime64(precision, timezone)`, rendered as `DateTime64(<precision>, '<tz>')`.
    DateTime64(u32, Tz),
    /// NOTE(review): semantics not visible in this file; `SqlType::to_string`
    /// renders it as plain `DateTime`.
    Chrono,
}
239
/// Aggregate functions allowed inside `SimpleAggregateFunction(...)`.
///
/// Converts to its ClickHouse function name via `From<SimpleAggFunc> for &str`
/// and parses back from that exact (case-sensitive) name via `FromStr`.
#[derive(Debug, Copy, Clone, PartialOrd, Eq, PartialEq, Hash)]
pub enum SimpleAggFunc {
    Any,
    AnyLast,
    Min,
    Max,
    Sum,
    SumWithOverflow,
    GroupBitAnd,
    GroupBitOr,
    GroupBitXor,
    GroupArrayArray,
    GroupUniqArrayArray,
    SumMap,
    MinMap,
    MaxMap,
    ArgMin,
    ArgMax,
}

impl SimpleAggFunc {
    /// Every variant, used by `FromStr` so that parsing reuses the single
    /// name table in `From<SimpleAggFunc> for &str`.
    const VARIANTS: [SimpleAggFunc; 16] = [
        SimpleAggFunc::Any,
        SimpleAggFunc::AnyLast,
        SimpleAggFunc::Min,
        SimpleAggFunc::Max,
        SimpleAggFunc::Sum,
        SimpleAggFunc::SumWithOverflow,
        SimpleAggFunc::GroupBitAnd,
        SimpleAggFunc::GroupBitOr,
        SimpleAggFunc::GroupBitXor,
        SimpleAggFunc::GroupArrayArray,
        SimpleAggFunc::GroupUniqArrayArray,
        SimpleAggFunc::SumMap,
        SimpleAggFunc::MinMap,
        SimpleAggFunc::MaxMap,
        SimpleAggFunc::ArgMin,
        SimpleAggFunc::ArgMax,
    ];
}

impl From<SimpleAggFunc> for &str {
    /// Returns the ClickHouse spelling of the function (e.g. `anyLast`).
    fn from(source: SimpleAggFunc) -> &'static str {
        let name = match source {
            SimpleAggFunc::Any => "any",
            SimpleAggFunc::AnyLast => "anyLast",
            SimpleAggFunc::Min => "min",
            SimpleAggFunc::Max => "max",
            SimpleAggFunc::Sum => "sum",
            SimpleAggFunc::SumWithOverflow => "sumWithOverflow",
            SimpleAggFunc::GroupBitAnd => "groupBitAnd",
            SimpleAggFunc::GroupBitOr => "groupBitOr",
            SimpleAggFunc::GroupBitXor => "groupBitXor",
            SimpleAggFunc::GroupArrayArray => "groupArrayArray",
            SimpleAggFunc::GroupUniqArrayArray => "groupUniqArrayArray",
            SimpleAggFunc::SumMap => "sumMap",
            SimpleAggFunc::MinMap => "minMap",
            SimpleAggFunc::MaxMap => "maxMap",
            SimpleAggFunc::ArgMin => "argMin",
            SimpleAggFunc::ArgMax => "argMax",
        };
        name
    }
}

impl FromStr for SimpleAggFunc {
    type Err = ();

    /// Parses an exact, case-sensitive ClickHouse function name; any other
    /// input yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::VARIANTS
            .iter()
            .copied()
            .find(|&func| <&str>::from(func) == s)
            .ok_or(())
    }
}
308
/// A ClickHouse column type.
///
/// Composite variants hold their nested types as `&'static SqlType`; the
/// `From<SqlType> for &'static SqlType` conversion in this module produces
/// such references. `SqlType::to_string` (and `Display`) renders the
/// ClickHouse type-name syntax.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum SqlType {
    Bool,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    UInt128,
    Int8,
    Int16,
    Int32,
    Int64,
    Int128,
    String,
    /// `FixedString(len)`.
    FixedString(usize),
    Float32,
    Float64,
    Date,
    /// One of the date-time flavours described by `DateTimeType`.
    DateTime(DateTimeType),
    Ipv4,
    Ipv6,
    Uuid,
    /// `Nullable(inner)`.
    Nullable(&'static SqlType),
    /// `Array(element)`.
    Array(&'static SqlType),
    /// `LowCardinality(inner)`.
    LowCardinality(&'static SqlType),
    /// `Decimal(precision, scale)`.
    Decimal(u8, u8),
    /// `Enum8` as `(name, value)` pairs.
    Enum8(Vec<(String, i8)>),
    /// `Enum16` as `(name, value)` pairs.
    Enum16(Vec<(String, i16)>),
    /// `SimpleAggregateFunction(func, inner)`.
    SimpleAggregateFunction(SimpleAggFunc, &'static SqlType),
    /// `Map(key, value)`.
    Map(&'static SqlType, &'static SqlType),
}
340
341lazy_static! {
342 static ref TYPES_CACHE: Mutex<HashMap<SqlType, Pin<Box<SqlType>>>> = Mutex::new(HashMap::new());
343}
344
345impl From<SqlType> for &'static SqlType {
346 fn from(value: SqlType) -> Self {
347 match value {
348 SqlType::UInt8 => &SqlType::UInt8,
349 SqlType::UInt16 => &SqlType::UInt16,
350 SqlType::UInt32 => &SqlType::UInt32,
351 SqlType::UInt64 => &SqlType::UInt64,
352 SqlType::Int8 => &SqlType::Int8,
353 SqlType::Int16 => &SqlType::Int16,
354 SqlType::Int32 => &SqlType::Int32,
355 SqlType::Int64 => &SqlType::Int64,
356 SqlType::String => &SqlType::String,
357 SqlType::Float32 => &SqlType::Float32,
358 SqlType::Float64 => &SqlType::Float64,
359 SqlType::Date => &SqlType::Date,
360 _ => {
361 let mut guard = TYPES_CACHE.lock().unwrap();
362 loop {
363 if let Some(value_ref) = guard.get(&value.clone()) {
364 return unsafe { mem::transmute(value_ref.as_ref()) };
365 }
366 guard.insert(value.clone(), Box::pin(value.clone()));
367 }
368 }
369 }
370 }
371}
372
373impl SqlType {
374 pub(crate) fn is_datetime(&self) -> bool {
375 matches!(self, SqlType::DateTime(_))
376 }
377
378 pub(crate) fn is_inner_low_cardinality(&self) -> bool {
379 matches!(
380 self,
381 SqlType::String
382 | SqlType::FixedString(_)
383 | SqlType::Date
384 | SqlType::DateTime(_)
385 | SqlType::UInt8
386 | SqlType::UInt16
387 | SqlType::UInt32
388 | SqlType::UInt64
389 | SqlType::Int8
390 | SqlType::Int16
391 | SqlType::Int32
392 | SqlType::Int64
393 )
394 }
395
396 pub fn to_string(&self) -> Cow<'static, str> {
397 match self.clone() {
398 SqlType::Bool => "Bool".into(),
399 SqlType::UInt8 => "UInt8".into(),
400 SqlType::UInt16 => "UInt16".into(),
401 SqlType::UInt32 => "UInt32".into(),
402 SqlType::UInt64 => "UInt64".into(),
403 SqlType::UInt128 => "UInt128".into(),
404 SqlType::Int8 => "Int8".into(),
405 SqlType::Int16 => "Int16".into(),
406 SqlType::Int32 => "Int32".into(),
407 SqlType::Int64 => "Int64".into(),
408 SqlType::Int128 => "Int128".into(),
409 SqlType::String => "String".into(),
410 SqlType::FixedString(str_len) => format!("FixedString({str_len})").into(),
411 SqlType::LowCardinality(inner) => format!("LowCardinality({})", &inner).into(),
412 SqlType::Float32 => "Float32".into(),
413 SqlType::Float64 => "Float64".into(),
414 SqlType::Date => "Date".into(),
415 SqlType::DateTime(DateTimeType::DateTime64(precision, tz)) => {
416 format!("DateTime64({precision}, '{tz:?}')").into()
417 }
418 SqlType::DateTime(_) => "DateTime".into(),
419 SqlType::Ipv4 => "IPv4".into(),
420 SqlType::Ipv6 => "IPv6".into(),
421 SqlType::Uuid => "UUID".into(),
422 SqlType::Nullable(nested) => format!("Nullable({})", &nested).into(),
423 SqlType::SimpleAggregateFunction(func, nested) => {
424 let func_str: &str = func.into();
425 format!("SimpleAggregateFunction({}, {})", func_str, &nested).into()
426 }
427 SqlType::Array(nested) => format!("Array({})", &nested).into(),
428 SqlType::Decimal(precision, scale) => format!("Decimal({precision}, {scale})").into(),
429 SqlType::Enum8(values) => {
430 let a: Vec<String> = values
431 .iter()
432 .map(|(name, value)| format!("'{name}' = {value}"))
433 .collect();
434 format!("Enum8({})", a.join(",")).into()
435 }
436 SqlType::Enum16(values) => {
437 let a: Vec<String> = values
438 .iter()
439 .map(|(name, value)| format!("'{name}' = {value}"))
440 .collect();
441 format!("Enum16({})", a.join(",")).into()
442 }
443 SqlType::Map(k, v) => format!("Map({}, {})", &k, &v).into(),
444 }
445 }
446
447 pub(crate) fn level(&self) -> u8 {
448 match self {
449 SqlType::Nullable(inner) => 1 + inner.level(),
450 SqlType::Array(inner) => 1 + inner.level(),
451 SqlType::Map(_, value) => 1 + value.level(),
452 SqlType::LowCardinality(_) => 1,
453 _ => 0,
454 }
455 }
456
457 pub(crate) fn map_level(&self) -> u8 {
458 match self {
459 SqlType::Nullable(inner) => inner.level(),
460 SqlType::Array(inner) => inner.level(),
461 SqlType::Map(_, value) => 1 + value.level(),
462 _ => 0,
463 }
464 }
465}
466
467impl fmt::Display for SqlType {
468 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
469 write!(f, "{}", Self::to_string(self))
470 }
471}
472
#[test]
fn test_display() {
    // `format!` goes through the `Display` impl, not the inherent `to_string`.
    let rendered = format!("{}", SqlType::UInt8);
    assert_eq!(String::from("UInt8"), rendered);
}
479
#[test]
fn test_to_string() {
    let nullable_u8 = SqlType::Nullable(&SqlType::UInt8);
    let expected: Cow<'static, str> = Cow::Borrowed("Nullable(UInt8)");
    assert_eq!(expected, nullable_u8.to_string())
}