1use crate::{
2 block::Block,
3 io::buffer_utils,
4 Error,
5 Result,
6};
7use bytes::{
8 Buf,
9 BufMut,
10 BytesMut,
11};
12use std::{
13 collections::HashMap,
14 sync::Arc,
15};
16
/// A single query setting: a textual value together with a bitmask of flags.
#[derive(Clone, Debug, Default)]
pub struct QuerySettingsField {
    /// Setting value in textual form.
    pub value: String,
    /// Bitwise OR of the flag constants declared on this type.
    pub flags: u64,
}

impl QuerySettingsField {
    /// Flag bit tested by [`Self::is_important`].
    pub const IMPORTANT: u64 = 0x01;
    /// Flag bit tested by [`Self::is_custom`].
    pub const CUSTOM: u64 = 0x02;
    /// Flag bit tested by [`Self::is_obsolete`].
    pub const OBSOLETE: u64 = 0x04;

    /// Creates a field holding `value` with no flags set.
    pub fn new(value: impl Into<String>) -> Self {
        Self::with_flags(value, 0)
    }

    /// Creates a field holding `value` with exactly the given `flags` bitmask.
    pub fn with_flags(value: impl Into<String>, flags: u64) -> Self {
        let value = value.into();
        Self { value, flags }
    }

    /// Creates a field whose flags are exactly [`Self::IMPORTANT`].
    pub fn important(value: impl Into<String>) -> Self {
        Self { value: value.into(), flags: Self::IMPORTANT }
    }

    /// Creates a field whose flags are exactly [`Self::CUSTOM`].
    pub fn custom(value: impl Into<String>) -> Self {
        Self { value: value.into(), flags: Self::CUSTOM }
    }

    /// Returns `true` when the [`Self::IMPORTANT`] bit is set.
    pub fn is_important(&self) -> bool {
        self.has_any(Self::IMPORTANT)
    }

    /// Returns `true` when the [`Self::CUSTOM`] bit is set.
    pub fn is_custom(&self) -> bool {
        self.has_any(Self::CUSTOM)
    }

    /// Returns `true` when the [`Self::OBSOLETE`] bit is set.
    pub fn is_obsolete(&self) -> bool {
        self.has_any(Self::OBSOLETE)
    }

    /// Reports whether at least one bit of `mask` is present in `flags`.
    fn has_any(&self, mask: u64) -> bool {
        (self.flags & mask) != 0
    }
}
74
/// Query settings keyed by setting name; each entry carries the value and
/// its flags as a [`QuerySettingsField`].
pub type QuerySettings = HashMap<String, QuerySettingsField>;
77
/// Tracing identifiers that can be attached to a query.
///
/// The default value has every numeric field set to zero and an empty
/// `tracestate`; such a context reports `false` from [`Self::is_enabled`].
#[derive(Clone, Debug, Default)]
pub struct TracingContext {
    /// Trace identifier; zero means tracing is not enabled.
    pub trace_id: u128,
    /// Span identifier.
    pub span_id: u64,
    /// Free-form trace state string.
    pub tracestate: String,
    /// Trace flags byte.
    pub trace_flags: u8,
}

impl TracingContext {
    /// Creates an empty (disabled) context, identical to `Default::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a context with the given ids, an empty `tracestate` and zero flags.
    pub fn with_ids(trace_id: u128, span_id: u64) -> Self {
        Self { trace_id, span_id, ..Self::default() }
    }

    /// Builder: replaces the trace id.
    pub fn trace_id(self, trace_id: u128) -> Self {
        Self { trace_id, ..self }
    }

    /// Builder: replaces the span id.
    pub fn span_id(self, span_id: u64) -> Self {
        Self { span_id, ..self }
    }

    /// Builder: replaces the trace state string.
    pub fn tracestate(self, tracestate: impl Into<String>) -> Self {
        Self { tracestate: tracestate.into(), ..self }
    }

    /// Builder: replaces the trace flags byte.
    pub fn trace_flags(self, flags: u8) -> Self {
        Self { trace_flags: flags, ..self }
    }

    /// Returns `true` when the trace id is non-zero.
    pub fn is_enabled(&self) -> bool {
        self.trace_id != 0
    }
}
132
/// A query together with its identifier, settings, parameters, optional
/// tracing context and optional callbacks.
///
/// Values are assembled with the consuming builder methods on this type.
/// Cloning copies the strings and maps; callbacks are held in `Arc`s, so a
/// clone shares them with the original.
#[derive(Clone)]
pub struct Query {
    /// Query text, as passed to `Query::new`.
    query_text: String,
    /// Query identifier; empty unless set via `with_query_id`.
    query_id: String,
    /// Per-query settings keyed by setting name.
    settings: QuerySettings,
    /// Query parameters keyed by parameter name.
    parameters: HashMap<String, String>,
    /// Tracing context, if one was attached.
    tracing_context: Option<TracingContext>,
    /// Callback registered via `on_progress`.
    on_progress: Option<ProgressCallback>,
    /// Callback registered via `on_profile`.
    on_profile: Option<ProfileCallback>,
    /// Callback registered via `on_profile_events`.
    on_profile_events: Option<ProfileEventsCallback>,
    /// Callback registered via `on_server_log`.
    on_server_log: Option<ServerLogCallback>,
    /// Callback registered via `on_exception`.
    on_exception: Option<ExceptionCallback>,
    /// Callback registered via `on_data`.
    on_data: Option<DataCallback>,
    /// Callback registered via `on_data_cancelable`.
    on_data_cancelable: Option<DataCancelableCallback>,
}
161
162impl Query {
163 pub fn new(query_text: impl Into<String>) -> Self {
165 Self {
166 query_text: query_text.into(),
167 query_id: String::new(),
168 settings: HashMap::new(),
169 parameters: HashMap::new(),
170 tracing_context: None,
171 on_progress: None,
172 on_profile: None,
173 on_profile_events: None,
174 on_server_log: None,
175 on_exception: None,
176 on_data: None,
177 on_data_cancelable: None,
178 }
179 }
180}
181
182impl From<&str> for Query {
183 fn from(s: &str) -> Self {
184 Query::new(s)
185 }
186}
187
188impl From<String> for Query {
189 fn from(s: String) -> Self {
190 Query::new(s)
191 }
192}
193
194impl Query {
195 pub fn with_query_id(mut self, query_id: impl Into<String>) -> Self {
197 self.query_id = query_id.into();
198 self
199 }
200
201 pub fn with_setting(
203 mut self,
204 key: impl Into<String>,
205 value: impl Into<String>,
206 ) -> Self {
207 self.settings.insert(key.into(), QuerySettingsField::new(value));
208 self
209 }
210
211 pub fn with_setting_flags(
213 mut self,
214 key: impl Into<String>,
215 value: impl Into<String>,
216 flags: u64,
217 ) -> Self {
218 self.settings
219 .insert(key.into(), QuerySettingsField::with_flags(value, flags));
220 self
221 }
222
223 pub fn with_important_setting(
225 mut self,
226 key: impl Into<String>,
227 value: impl Into<String>,
228 ) -> Self {
229 self.settings.insert(key.into(), QuerySettingsField::important(value));
230 self
231 }
232
233 pub fn with_parameter(
235 mut self,
236 key: impl Into<String>,
237 value: impl Into<String>,
238 ) -> Self {
239 self.parameters.insert(key.into(), value.into());
240 self
241 }
242
243 pub fn with_tracing_context(mut self, context: TracingContext) -> Self {
245 self.tracing_context = Some(context);
246 self
247 }
248
249 pub fn text(&self) -> &str {
251 &self.query_text
252 }
253
254 pub fn tracing_context(&self) -> Option<&TracingContext> {
256 self.tracing_context.as_ref()
257 }
258
259 pub fn id(&self) -> &str {
261 &self.query_id
262 }
263
264 pub fn settings(&self) -> &QuerySettings {
266 &self.settings
267 }
268
269 pub fn parameters(&self) -> &HashMap<String, String> {
271 &self.parameters
272 }
273
274 pub fn on_progress<F>(mut self, callback: F) -> Self
276 where
277 F: Fn(&Progress) + Send + Sync + 'static,
278 {
279 self.on_progress = Some(Arc::new(callback));
280 self
281 }
282
283 pub fn on_profile<F>(mut self, callback: F) -> Self
285 where
286 F: Fn(&Profile) + Send + Sync + 'static,
287 {
288 self.on_profile = Some(Arc::new(callback));
289 self
290 }
291
292 pub fn on_profile_events<F>(mut self, callback: F) -> Self
294 where
295 F: Fn(&Block) -> bool + Send + Sync + 'static,
296 {
297 self.on_profile_events = Some(Arc::new(callback));
298 self
299 }
300
301 pub fn on_server_log<F>(mut self, callback: F) -> Self
303 where
304 F: Fn(&Block) -> bool + Send + Sync + 'static,
305 {
306 self.on_server_log = Some(Arc::new(callback));
307 self
308 }
309
310 pub fn on_exception<F>(mut self, callback: F) -> Self
312 where
313 F: Fn(&Exception) + Send + Sync + 'static,
314 {
315 self.on_exception = Some(Arc::new(callback));
316 self
317 }
318
319 pub fn on_data<F>(mut self, callback: F) -> Self
321 where
322 F: Fn(&Block) + Send + Sync + 'static,
323 {
324 self.on_data = Some(Arc::new(callback));
325 self
326 }
327
328 pub fn on_data_cancelable<F>(mut self, callback: F) -> Self
330 where
331 F: Fn(&Block) -> bool + Send + Sync + 'static,
332 {
333 self.on_data_cancelable = Some(Arc::new(callback));
334 self
335 }
336
337 pub(crate) fn get_on_progress(&self) -> Option<&ProgressCallback> {
340 self.on_progress.as_ref()
341 }
342
343 pub(crate) fn get_on_profile(&self) -> Option<&ProfileCallback> {
344 self.on_profile.as_ref()
345 }
346
347 pub(crate) fn get_on_profile_events(
348 &self,
349 ) -> Option<&ProfileEventsCallback> {
350 self.on_profile_events.as_ref()
351 }
352
353 pub(crate) fn get_on_server_log(&self) -> Option<&ServerLogCallback> {
354 self.on_server_log.as_ref()
355 }
356
357 pub(crate) fn get_on_exception(&self) -> Option<&ExceptionCallback> {
358 self.on_exception.as_ref()
359 }
360
361 pub(crate) fn get_on_data(&self) -> Option<&DataCallback> {
362 self.on_data.as_ref()
363 }
364
365 pub(crate) fn get_on_data_cancelable(
366 &self,
367 ) -> Option<&DataCancelableCallback> {
368 self.on_data_cancelable.as_ref()
369 }
370}
371
/// Information describing the client side of a connection.
#[derive(Clone, Debug)]
pub struct ClientInfo {
    /// Interface type byte; the default is `1`.
    pub interface_type: u8,
    /// Query kind byte; the default is `0`.
    pub query_kind: u8,
    /// Initial user name (empty by default).
    pub initial_user: String,
    /// Initial query id (empty by default).
    pub initial_query_id: String,
    /// Quota key (empty by default).
    pub quota_key: String,
    /// Operating-system user name.
    pub os_user: String,
    /// Host name reported by the client.
    pub client_hostname: String,
    /// Client name.
    pub client_name: String,
    /// Client version, major component.
    pub client_version_major: u64,
    /// Client version, minor component.
    pub client_version_minor: u64,
    /// Client version, patch component.
    pub client_version_patch: u64,
    /// Client revision number.
    pub client_revision: u64,
}

impl Default for ClientInfo {
    /// Builds the default client description.
    ///
    /// `os_user` is read from the `USER` environment variable and falls
    /// back to `"default"` when the variable cannot be read; every other
    /// field is a fixed value.
    fn default() -> Self {
        let os_user = match std::env::var("USER") {
            Ok(user) => user,
            Err(_) => String::from("default"),
        };

        Self {
            interface_type: 1,
            query_kind: 0,
            initial_user: String::new(),
            initial_query_id: String::new(),
            quota_key: String::new(),
            os_user,
            client_hostname: String::from("localhost"),
            client_name: String::from("clickhouse-rust"),
            client_version_major: 1,
            client_version_minor: 0,
            client_version_patch: 0,
            client_revision: 54459,
        }
    }
}
420
421impl ClientInfo {
422 pub fn write_to(&self, buffer: &mut BytesMut) -> Result<()> {
424 buffer.put_u8(self.interface_type);
425
426 buffer_utils::write_string(buffer, &self.os_user);
427 buffer_utils::write_string(buffer, &self.client_hostname);
428 buffer_utils::write_string(buffer, &self.client_name);
429
430 buffer_utils::write_varint(buffer, self.client_version_major);
431 buffer_utils::write_varint(buffer, self.client_version_minor);
432 buffer_utils::write_varint(buffer, self.client_revision);
433
434 Ok(())
435 }
436
437 pub fn read_from(buffer: &mut &[u8]) -> Result<Self> {
439 if buffer.is_empty() {
440 return Err(Error::Protocol(
441 "Not enough data to read ClientInfo".to_string(),
442 ));
443 }
444
445 let interface_type = buffer[0];
446 buffer.advance(1);
447
448 let os_user = buffer_utils::read_string(buffer)?;
449 let client_hostname = buffer_utils::read_string(buffer)?;
450 let client_name = buffer_utils::read_string(buffer)?;
451
452 let client_version_major = buffer_utils::read_varint(buffer)?;
453 let client_version_minor = buffer_utils::read_varint(buffer)?;
454 let client_revision = buffer_utils::read_varint(buffer)?;
455
456 Ok(Self {
457 interface_type,
458 query_kind: 0,
459 initial_user: String::new(),
460 initial_query_id: String::new(),
461 quota_key: String::new(),
462 os_user,
463 client_hostname,
464 client_name,
465 client_version_major,
466 client_version_minor,
467 client_version_patch: 0,
468 client_revision,
469 })
470 }
471}
472
/// Information describing the server side of a connection.
///
/// Which fields travel on the wire depends on `revision`; see
/// `ServerInfo::write_to` / `ServerInfo::read_from`.
#[derive(Clone, Debug, Default)]
pub struct ServerInfo {
    /// Server name.
    pub name: String,
    /// Server version, major component.
    pub version_major: u64,
    /// Server version, minor component.
    pub version_minor: u64,
    /// Server version, patch component (only encoded for revision >= 54401).
    pub version_patch: u64,
    /// Server revision; gates the optional fields of the encoding.
    pub revision: u64,
    /// Server timezone (only encoded for revision >= 54058).
    pub timezone: String,
    /// Server display name (only encoded for revision >= 54372).
    pub display_name: String,
}
491
492impl ServerInfo {
493 pub fn write_to(&self, buffer: &mut BytesMut) -> Result<()> {
495 buffer_utils::write_string(buffer, &self.name);
496 buffer_utils::write_varint(buffer, self.version_major);
497 buffer_utils::write_varint(buffer, self.version_minor);
498 buffer_utils::write_varint(buffer, self.revision);
499
500 if self.revision >= 54058 {
501 buffer_utils::write_string(buffer, &self.timezone);
502 }
503
504 if self.revision >= 54372 {
505 buffer_utils::write_string(buffer, &self.display_name);
506 }
507
508 if self.revision >= 54401 {
509 buffer_utils::write_varint(buffer, self.version_patch);
510 }
511
512 Ok(())
513 }
514
515 pub fn read_from(buffer: &mut &[u8]) -> Result<Self> {
517 let name = buffer_utils::read_string(buffer)?;
518 let version_major = buffer_utils::read_varint(buffer)?;
519 let version_minor = buffer_utils::read_varint(buffer)?;
520 let revision = buffer_utils::read_varint(buffer)?;
521
522 let timezone = if revision >= 54058 {
523 buffer_utils::read_string(buffer)?
524 } else {
525 String::new()
526 };
527
528 let display_name = if revision >= 54372 {
529 buffer_utils::read_string(buffer)?
530 } else {
531 String::new()
532 };
533
534 let version_patch = if revision >= 54401 {
535 buffer_utils::read_varint(buffer)?
536 } else {
537 0
538 };
539
540 Ok(Self {
541 name,
542 version_major,
543 version_minor,
544 version_patch,
545 revision,
546 timezone,
547 display_name,
548 })
549 }
550}
551
/// Progress counters, encoded on the wire as consecutive varints in field
/// order (see `Progress::write_to` / `Progress::read_from`).
#[derive(Clone, Debug, Default)]
pub struct Progress {
    /// Row counter (first varint).
    pub rows: u64,
    /// Byte counter (second varint).
    pub bytes: u64,
    /// Total row counter (third varint).
    pub total_rows: u64,
    /// Written row counter; only encoded for server revision >= 54405.
    pub written_rows: u64,
    /// Written byte counter; only encoded for server revision >= 54405.
    pub written_bytes: u64,
}
566
/// Profile information decoded by `Profile::read_from`.
#[derive(Clone, Debug, Default)]
pub struct Profile {
    /// Row counter (first varint of the payload).
    pub rows: u64,
    /// Block counter (second varint of the payload).
    pub blocks: u64,
    /// Byte counter (third varint of the payload).
    pub bytes: u64,
    /// Rows-before-limit counter (varint following the `applied_limit` byte).
    pub rows_before_limit: u64,
    /// Boolean flag decoded from a single byte (non-zero means `true`).
    pub applied_limit: bool,
    /// Boolean flag decoded from the final byte (non-zero means `true`).
    pub calculated_rows_before_limit: bool,
}
583
584#[derive(Clone)]
611pub struct ExternalTable {
612 pub name: String,
614 pub data: Block,
616}
617
618impl ExternalTable {
619 pub fn new(name: impl Into<String>, data: Block) -> Self {
621 Self { name: name.into(), data }
622 }
623}
624
// Callback type aliases. Each is a shared (`Arc`), thread-safe
// (`Send + Sync`) closure so that `Query` can be cloned cheaply.
// NOTE(review): the meaning of the `bool` returned by the block callbacks is
// decided by the code that invokes them (not visible here) — presumably it
// signals whether to continue; confirm against the caller.

/// Callback receiving a [`Progress`] value.
pub type ProgressCallback = Arc<dyn Fn(&Progress) + Send + Sync>;
/// Callback receiving a [`Profile`] value.
pub type ProfileCallback = Arc<dyn Fn(&Profile) + Send + Sync>;
/// Callback receiving a profile-events [`Block`]; returns a `bool`.
pub type ProfileEventsCallback = Arc<dyn Fn(&Block) -> bool + Send + Sync>;
/// Callback receiving a server-log [`Block`]; returns a `bool`.
pub type ServerLogCallback = Arc<dyn Fn(&Block) -> bool + Send + Sync>;
/// Callback receiving an [`Exception`].
pub type ExceptionCallback = Arc<dyn Fn(&Exception) + Send + Sync>;
/// Callback receiving a data [`Block`].
pub type DataCallback = Arc<dyn Fn(&Block) + Send + Sync>;
/// Callback receiving a data [`Block`]; returns a `bool`.
pub type DataCancelableCallback = Arc<dyn Fn(&Block) -> bool + Send + Sync>;
639
640impl Progress {
641 pub fn write_to(
643 &self,
644 buffer: &mut BytesMut,
645 server_revision: u64,
646 ) -> Result<()> {
647 buffer_utils::write_varint(buffer, self.rows);
648 buffer_utils::write_varint(buffer, self.bytes);
649 buffer_utils::write_varint(buffer, self.total_rows);
650
651 if server_revision >= 54405 {
652 buffer_utils::write_varint(buffer, self.written_rows);
653 buffer_utils::write_varint(buffer, self.written_bytes);
654 }
655
656 Ok(())
657 }
658
659 pub fn read_from(
661 buffer: &mut &[u8],
662 server_revision: u64,
663 ) -> Result<Self> {
664 let rows = buffer_utils::read_varint(buffer)?;
665 let bytes = buffer_utils::read_varint(buffer)?;
666 let total_rows = buffer_utils::read_varint(buffer)?;
667
668 let (written_rows, written_bytes) = if server_revision >= 54405 {
669 (
670 buffer_utils::read_varint(buffer)?,
671 buffer_utils::read_varint(buffer)?,
672 )
673 } else {
674 (0, 0)
675 };
676
677 Ok(Self { rows, bytes, total_rows, written_rows, written_bytes })
678 }
679}
680
681impl Profile {
682 pub fn read_from(buffer: &mut &[u8]) -> Result<Self> {
684 let rows = buffer_utils::read_varint(buffer)?;
685 let blocks = buffer_utils::read_varint(buffer)?;
686 let bytes = buffer_utils::read_varint(buffer)?;
687
688 let applied_limit = if !buffer.is_empty() {
689 let val = buffer[0];
690 buffer.advance(1);
691 val != 0
692 } else {
693 false
694 };
695
696 let rows_before_limit = buffer_utils::read_varint(buffer)?;
697
698 let calculated_rows_before_limit = if !buffer.is_empty() {
699 let val = buffer[0];
700 buffer.advance(1);
701 val != 0
702 } else {
703 false
704 };
705
706 Ok(Self {
707 rows,
708 blocks,
709 bytes,
710 rows_before_limit,
711 applied_limit,
712 calculated_rows_before_limit,
713 })
714 }
715}
716
/// An exception value that can be encoded to and decoded from bytes,
/// optionally chaining a nested exception.
#[derive(Clone, Debug)]
pub struct Exception {
    /// Error code (encoded as a little-endian `i32`).
    pub code: i32,
    /// Exception name.
    pub name: String,
    /// Display text of the exception.
    pub display_text: String,
    /// Stack trace text.
    pub stack_trace: String,
    /// Nested exception, if any; boxed because the type is recursive.
    pub nested: Option<Box<Exception>>,
}
731
732impl Exception {
733 pub fn write_to(&self, buffer: &mut BytesMut) -> Result<()> {
735 buffer.put_i32_le(self.code);
736 buffer_utils::write_string(buffer, &self.name);
737 buffer_utils::write_string(buffer, &self.display_text);
738 buffer_utils::write_string(buffer, &self.stack_trace);
739
740 let has_nested = self.nested.is_some();
742 buffer.put_u8(if has_nested { 1 } else { 0 });
743
744 if let Some(nested) = &self.nested {
745 nested.write_to(buffer)?;
746 }
747
748 Ok(())
749 }
750
751 pub fn read_from(buffer: &mut &[u8]) -> Result<Self> {
753 if buffer.len() < 4 {
754 return Err(Error::Protocol(
755 "Not enough data to read Exception".to_string(),
756 ));
757 }
758
759 let code = {
760 let mut bytes = [0u8; 4];
761 bytes.copy_from_slice(&buffer[..4]);
762 buffer.advance(4);
763 i32::from_le_bytes(bytes)
764 };
765
766 let name = buffer_utils::read_string(buffer)?;
767 let display_text = buffer_utils::read_string(buffer)?;
768 let stack_trace = buffer_utils::read_string(buffer)?;
769
770 if buffer.is_empty() {
771 return Err(Error::Protocol(
772 "Not enough data to read nested exception flag".to_string(),
773 ));
774 }
775
776 let has_nested = buffer[0] != 0;
777 buffer.advance(1);
778
779 let nested = if has_nested {
780 Some(Box::new(Exception::read_from(buffer)?))
781 } else {
782 None
783 };
784
785 Ok(Self { code, name, display_text, stack_trace, nested })
786 }
787}
788
// Unit tests for the query builder and the wire codecs defined in this file.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    // A freshly created query has its text, an empty id and no settings.
    #[test]
    fn test_query_creation() {
        let query = Query::new("SELECT 1");
        assert_eq!(query.text(), "SELECT 1");
        assert_eq!(query.id(), "");
        assert!(query.settings().is_empty());
    }

    // `with_query_id` stores the identifier.
    #[test]
    fn test_query_with_id() {
        let query = Query::new("SELECT 1").with_query_id("test_query");
        assert_eq!(query.id(), "test_query");
    }

    // Plain settings are stored by key with a zero flags bitmask.
    #[test]
    fn test_query_with_settings() {
        let query = Query::new("SELECT 1")
            .with_setting("max_threads", "4")
            .with_setting("max_memory_usage", "10000000");

        assert_eq!(query.settings().len(), 2);
        assert_eq!(
            query.settings().get("max_threads").map(|f| f.value.as_str()),
            Some("4")
        );
        assert_eq!(query.settings().get("max_threads").unwrap().flags, 0);
    }

    // Important and custom settings carry only their own flag bit.
    #[test]
    fn test_query_with_important_settings() {
        let query = Query::new("SELECT 1")
            .with_important_setting("max_threads", "4")
            .with_setting_flags(
                "custom_setting",
                "value",
                QuerySettingsField::CUSTOM,
            );

        assert_eq!(query.settings().len(), 2);

        let max_threads = query.settings().get("max_threads").unwrap();
        assert_eq!(max_threads.value, "4");
        assert!(max_threads.is_important());
        assert!(!max_threads.is_custom());

        let custom = query.settings().get("custom_setting").unwrap();
        assert_eq!(custom.value, "value");
        assert!(custom.is_custom());
        assert!(!custom.is_important());
    }

    // Encoding then decoding the default client info preserves the
    // interface type and client name.
    #[test]
    fn test_client_info_roundtrip() {
        let info = ClientInfo::default();
        let mut buffer = BytesMut::new();
        info.write_to(&mut buffer).unwrap();

        let mut reader = &buffer[..];
        let decoded = ClientInfo::read_from(&mut reader).unwrap();

        assert_eq!(decoded.interface_type, 1);
        assert_eq!(decoded.client_name, "clickhouse-rust");
    }

    // Revision 54449 is above every gate, so all optional fields are encoded.
    #[test]
    fn test_server_info_roundtrip() {
        let info = ServerInfo {
            name: "ClickHouse".to_string(),
            version_major: 21,
            version_minor: 8,
            version_patch: 5,
            revision: 54449,
            timezone: "UTC".to_string(),
            display_name: "ClickHouse server".to_string(),
        };

        let mut buffer = BytesMut::new();
        info.write_to(&mut buffer).unwrap();

        let mut reader = &buffer[..];
        let decoded = ServerInfo::read_from(&mut reader).unwrap();

        assert_eq!(decoded.name, "ClickHouse");
        assert_eq!(decoded.version_major, 21);
        assert_eq!(decoded.timezone, "UTC");
    }

    // With revision 54449 the written_* counters are part of the encoding.
    #[test]
    fn test_progress_roundtrip() {
        let progress = Progress {
            rows: 100,
            bytes: 1024,
            total_rows: 1000,
            written_rows: 50,
            written_bytes: 512,
        };

        let mut buffer = BytesMut::new();
        progress.write_to(&mut buffer, 54449).unwrap();

        let mut reader = &buffer[..];
        let decoded = Progress::read_from(&mut reader, 54449).unwrap();

        assert_eq!(decoded.rows, 100);
        assert_eq!(decoded.bytes, 1024);
        assert_eq!(decoded.written_rows, 50);
    }

    // An exception without a nested cause round-trips with `nested == None`.
    #[test]
    fn test_exception_simple() {
        let exc = Exception {
            code: 42,
            name: "UNKNOWN_TABLE".to_string(),
            display_text: "Table doesn't exist".to_string(),
            stack_trace: "at query.cpp:123".to_string(),
            nested: None,
        };

        let mut buffer = BytesMut::new();
        exc.write_to(&mut buffer).unwrap();

        let mut reader = &buffer[..];
        let decoded = Exception::read_from(&mut reader).unwrap();

        assert_eq!(decoded.code, 42);
        assert_eq!(decoded.name, "UNKNOWN_TABLE");
        assert!(decoded.nested.is_none());
    }

    // A nested exception is encoded after its parent and decoded back into
    // the `nested` field.
    #[test]
    fn test_exception_nested() {
        let nested_exc = Exception {
            code: 1,
            name: "INNER_ERROR".to_string(),
            display_text: "Inner error".to_string(),
            stack_trace: "inner stack".to_string(),
            nested: None,
        };

        let exc = Exception {
            code: 2,
            name: "OUTER_ERROR".to_string(),
            display_text: "Outer error".to_string(),
            stack_trace: "outer stack".to_string(),
            nested: Some(Box::new(nested_exc)),
        };

        let mut buffer = BytesMut::new();
        exc.write_to(&mut buffer).unwrap();

        let mut reader = &buffer[..];
        let decoded = Exception::read_from(&mut reader).unwrap();

        assert_eq!(decoded.code, 2);
        assert!(decoded.nested.is_some());
        assert_eq!(decoded.nested.as_ref().unwrap().code, 1);
    }
}