use crate::{
block::Block,
io::buffer_utils,
Error,
Result,
};
use bytes::{
Buf,
BufMut,
BytesMut,
};
use std::{
collections::HashMap,
sync::Arc,
};
/// One query setting: its textual value plus a bitmask of flags.
///
/// The flag bits are the associated constants [`QuerySettingsField::IMPORTANT`],
/// [`QuerySettingsField::CUSTOM`] and [`QuerySettingsField::OBSOLETE`]; they can
/// be OR-ed together and passed to [`QuerySettingsField::with_flags`].
#[derive(Clone, Debug, Default)]
pub struct QuerySettingsField {
    /// Setting value in string form.
    pub value: String,
    /// Bitwise OR of the flag constants; `0` means no flags are set.
    pub flags: u64,
}

impl QuerySettingsField {
    /// Flag bit `0x01`; set by [`QuerySettingsField::important`].
    pub const IMPORTANT: u64 = 0x01;
    /// Flag bit `0x02`; set by [`QuerySettingsField::custom`].
    pub const CUSTOM: u64 = 0x02;
    /// Flag bit `0x04`; tested by [`QuerySettingsField::is_obsolete`].
    pub const OBSOLETE: u64 = 0x04;

    /// Builds a field holding `value` with no flags set.
    pub fn new(value: impl Into<String>) -> Self {
        Self::with_flags(value, 0)
    }

    /// Builds a field holding `value` with exactly the bits in `flags`.
    pub fn with_flags(value: impl Into<String>, flags: u64) -> Self {
        let value: String = value.into();
        Self { value, flags }
    }

    /// Builds a field holding `value` with only [`Self::IMPORTANT`] set.
    pub fn important(value: impl Into<String>) -> Self {
        Self::with_flags(value, Self::IMPORTANT)
    }

    /// Builds a field holding `value` with only [`Self::CUSTOM`] set.
    pub fn custom(value: impl Into<String>) -> Self {
        Self::with_flags(value, Self::CUSTOM)
    }

    /// Returns `true` when the [`Self::IMPORTANT`] bit is set.
    pub fn is_important(&self) -> bool {
        self.has_flag(Self::IMPORTANT)
    }

    /// Returns `true` when the [`Self::CUSTOM`] bit is set.
    pub fn is_custom(&self) -> bool {
        self.has_flag(Self::CUSTOM)
    }

    /// Returns `true` when the [`Self::OBSOLETE`] bit is set.
    pub fn is_obsolete(&self) -> bool {
        self.has_flag(Self::OBSOLETE)
    }

    /// Returns `true` when any bit of `mask` is present in `self.flags`.
    fn has_flag(&self, mask: u64) -> bool {
        (self.flags & mask) != 0
    }
}
/// Per-query settings keyed by setting name; each entry holds the value and
/// its flag bits.
pub type QuerySettings = HashMap<String, QuerySettingsField>;
/// Tracing identifiers that can be attached to a query.
///
/// The field names mirror the W3C Trace Context pieces (trace id, span id,
/// `tracestate`, trace flags). NOTE(review): how they are forwarded to the
/// server is decided elsewhere; confirm against the query writer.
#[derive(Clone, Debug, Default)]
pub struct TracingContext {
    /// Trace identifier; `0` means "no trace" (see [`TracingContext::is_enabled`]).
    pub trace_id: u128,
    /// Span identifier.
    pub span_id: u64,
    /// Opaque `tracestate` string; empty by default.
    pub tracestate: String,
    /// Trace flag bits; `0` by default.
    pub trace_flags: u8,
}

impl TracingContext {
    /// Returns an all-zero / empty context (identical to `Default`).
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns a context with the given ids, an empty `tracestate` and no flags.
    pub fn with_ids(trace_id: u128, span_id: u64) -> Self {
        Self { trace_id, span_id, ..Self::default() }
    }

    /// Builder: replaces the trace id.
    pub fn trace_id(self, trace_id: u128) -> Self {
        Self { trace_id, ..self }
    }

    /// Builder: replaces the span id.
    pub fn span_id(self, span_id: u64) -> Self {
        Self { span_id, ..self }
    }

    /// Builder: replaces the `tracestate` string.
    pub fn tracestate(self, tracestate: impl Into<String>) -> Self {
        Self { tracestate: tracestate.into(), ..self }
    }

    /// Builder: replaces the trace flag bits.
    pub fn trace_flags(self, flags: u8) -> Self {
        Self { trace_flags: flags, ..self }
    }

    /// Returns `true` when a non-zero trace id is present; the other fields
    /// are not consulted.
    pub fn is_enabled(&self) -> bool {
        self.trace_id > 0
    }
}
/// A SQL query plus everything attached to it: a query id, settings,
/// parameters, an optional tracing context and optional callbacks.
///
/// Built with [`Query::new`] (or `From<&str>` / `From<String>`) followed by the
/// `with_*` / `on_*` builder methods. Only `Clone` is derived: the callback
/// fields hold `Arc<dyn Fn ...>` trait objects, which do not implement `Debug`.
#[derive(Clone)]
pub struct Query {
    /// SQL text; returned by `Query::text`.
    query_text: String,
    /// Query id; empty unless set with `Query::with_query_id`.
    query_id: String,
    /// Settings added by `with_setting`, `with_setting_flags` and
    /// `with_important_setting`.
    settings: QuerySettings,
    /// Named parameters added by `Query::with_parameter`.
    parameters: HashMap<String, String>,
    /// Tracing context attached by `Query::with_tracing_context`, if any.
    tracing_context: Option<TracingContext>,
    // Optional callbacks. Each is set by the `on_*` builder of the same name
    // and read back through the crate-internal `get_on_*` accessor.
    on_progress: Option<ProgressCallback>,
    on_profile: Option<ProfileCallback>,
    on_profile_events: Option<ProfileEventsCallback>,
    on_server_log: Option<ServerLogCallback>,
    on_exception: Option<ExceptionCallback>,
    on_data: Option<DataCallback>,
    on_data_cancelable: Option<DataCancelableCallback>,
}
impl Query {
    /// Creates a query for `query_text` with an empty id, no settings or
    /// parameters, no tracing context and no callbacks registered.
    pub fn new(query_text: impl Into<String>) -> Self {
        let query_text: String = query_text.into();
        Self {
            query_text,
            query_id: String::default(),
            settings: HashMap::default(),
            parameters: HashMap::default(),
            tracing_context: None,
            on_progress: None,
            on_profile: None,
            on_profile_events: None,
            on_server_log: None,
            on_exception: None,
            on_data: None,
            on_data_cancelable: None,
        }
    }
}
impl From<&str> for Query {
fn from(s: &str) -> Self {
Query::new(s)
}
}
impl From<String> for Query {
fn from(s: String) -> Self {
Query::new(s)
}
}
impl Query {
    /// Replaces the query id and returns the updated query.
    pub fn with_query_id(self, query_id: impl Into<String>) -> Self {
        Self { query_id: query_id.into(), ..self }
    }

    /// Inserts setting `key` = `value` with no flags, replacing any previous
    /// entry for `key`.
    pub fn with_setting(
        self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        self.with_setting_flags(key, value, 0)
    }

    /// Inserts setting `key` = `value` carrying the bits in `flags` (see the
    /// `QuerySettingsField` flag constants), replacing any previous entry for
    /// `key`.
    pub fn with_setting_flags(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
        flags: u64,
    ) -> Self {
        let name: String = key.into();
        self.settings.insert(name, QuerySettingsField::with_flags(value, flags));
        self
    }

    /// Inserts setting `key` = `value` with only the `IMPORTANT` flag set,
    /// replacing any previous entry for `key`.
    pub fn with_important_setting(
        self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        self.with_setting_flags(key, value, QuerySettingsField::IMPORTANT)
    }

    /// Inserts query parameter `key` = `value`, replacing any previous value.
    pub fn with_parameter(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        let (name, text) = (key.into(), value.into());
        self.parameters.insert(name, text);
        self
    }

    /// Attaches `context`, replacing any previously attached tracing context.
    pub fn with_tracing_context(self, context: TracingContext) -> Self {
        Self { tracing_context: Some(context), ..self }
    }

    /// Returns the SQL text.
    pub fn text(&self) -> &str {
        self.query_text.as_str()
    }

    /// Returns the attached tracing context, if any.
    pub fn tracing_context(&self) -> Option<&TracingContext> {
        self.tracing_context.as_ref()
    }

    /// Returns the query id (empty when none was set).
    pub fn id(&self) -> &str {
        self.query_id.as_str()
    }

    /// Returns all settings added to this query.
    pub fn settings(&self) -> &QuerySettings {
        &self.settings
    }

    /// Returns all named parameters added to this query.
    pub fn parameters(&self) -> &HashMap<String, String> {
        &self.parameters
    }

    /// Registers the progress callback, replacing any earlier one.
    pub fn on_progress<F>(self, callback: F) -> Self
    where
        F: Fn(&Progress) + Send + Sync + 'static,
    {
        Self { on_progress: Some(Arc::new(callback)), ..self }
    }

    /// Registers the profile callback, replacing any earlier one.
    pub fn on_profile<F>(self, callback: F) -> Self
    where
        F: Fn(&Profile) + Send + Sync + 'static,
    {
        Self { on_profile: Some(Arc::new(callback)), ..self }
    }

    /// Registers the profile-events callback, replacing any earlier one.
    pub fn on_profile_events<F>(self, callback: F) -> Self
    where
        F: Fn(&Block) -> bool + Send + Sync + 'static,
    {
        Self { on_profile_events: Some(Arc::new(callback)), ..self }
    }

    /// Registers the server-log callback, replacing any earlier one.
    pub fn on_server_log<F>(self, callback: F) -> Self
    where
        F: Fn(&Block) -> bool + Send + Sync + 'static,
    {
        Self { on_server_log: Some(Arc::new(callback)), ..self }
    }

    /// Registers the exception callback, replacing any earlier one.
    pub fn on_exception<F>(self, callback: F) -> Self
    where
        F: Fn(&Exception) + Send + Sync + 'static,
    {
        Self { on_exception: Some(Arc::new(callback)), ..self }
    }

    /// Registers the data callback, replacing any earlier one.
    pub fn on_data<F>(self, callback: F) -> Self
    where
        F: Fn(&Block) + Send + Sync + 'static,
    {
        Self { on_data: Some(Arc::new(callback)), ..self }
    }

    /// Registers the cancelable data callback, replacing any earlier one.
    pub fn on_data_cancelable<F>(self, callback: F) -> Self
    where
        F: Fn(&Block) -> bool + Send + Sync + 'static,
    {
        Self { on_data_cancelable: Some(Arc::new(callback)), ..self }
    }

    /// Crate-internal: the registered progress callback, if any.
    pub(crate) fn get_on_progress(&self) -> Option<&ProgressCallback> {
        self.on_progress.as_ref()
    }

    /// Crate-internal: the registered profile callback, if any.
    pub(crate) fn get_on_profile(&self) -> Option<&ProfileCallback> {
        self.on_profile.as_ref()
    }

    /// Crate-internal: the registered profile-events callback, if any.
    pub(crate) fn get_on_profile_events(
        &self,
    ) -> Option<&ProfileEventsCallback> {
        self.on_profile_events.as_ref()
    }

    /// Crate-internal: the registered server-log callback, if any.
    pub(crate) fn get_on_server_log(&self) -> Option<&ServerLogCallback> {
        self.on_server_log.as_ref()
    }

    /// Crate-internal: the registered exception callback, if any.
    pub(crate) fn get_on_exception(&self) -> Option<&ExceptionCallback> {
        self.on_exception.as_ref()
    }

    /// Crate-internal: the registered data callback, if any.
    pub(crate) fn get_on_data(&self) -> Option<&DataCallback> {
        self.on_data.as_ref()
    }

    /// Crate-internal: the registered cancelable data callback, if any.
    pub(crate) fn get_on_data_cancelable(
        &self,
    ) -> Option<&DataCancelableCallback> {
        self.on_data_cancelable.as_ref()
    }
}
/// Client identity fields.
///
/// `ClientInfo::write_to` / `ClientInfo::read_from` carry only a subset of
/// them; see those methods for the exact wire layout.
#[derive(Clone, Debug)]
pub struct ClientInfo {
    /// Interface type byte; `1` by default.
    pub interface_type: u8,
    /// Query kind byte; `0` by default.
    pub query_kind: u8,
    /// Initial user; empty by default.
    pub initial_user: String,
    /// Initial query id; empty by default.
    pub initial_query_id: String,
    /// Quota key; empty by default.
    pub quota_key: String,
    /// Operating-system user name.
    pub os_user: String,
    /// Client host name.
    pub client_hostname: String,
    /// Client name string.
    pub client_name: String,
    /// Client major version.
    pub client_version_major: u64,
    /// Client minor version.
    pub client_version_minor: u64,
    /// Client patch version.
    pub client_version_patch: u64,
    /// Protocol revision advertised by the client.
    pub client_revision: u64,
}

impl Default for ClientInfo {
    /// Identity of this client: interface type `1`, query kind `0`, client
    /// name `"clickhouse-rust"` version 1.0.0, revision 54459, host name
    /// `"localhost"`, empty initial user / initial query id / quota key, and
    /// the OS user taken from `$USER` (falling back to `"default"` when the
    /// variable is unset or not valid Unicode).
    fn default() -> Self {
        let os_user =
            std::env::var("USER").unwrap_or_else(|_| String::from("default"));
        Self {
            interface_type: 1,
            query_kind: 0,
            initial_user: String::new(),
            initial_query_id: String::new(),
            quota_key: String::new(),
            os_user,
            client_hostname: String::from("localhost"),
            client_name: String::from("clickhouse-rust"),
            client_version_major: 1,
            client_version_minor: 0,
            client_version_patch: 0,
            client_revision: 54459,
        }
    }
}
impl ClientInfo {
pub fn write_to(&self, buffer: &mut BytesMut) -> Result<()> {
buffer.put_u8(self.interface_type);
buffer_utils::write_string(buffer, &self.os_user);
buffer_utils::write_string(buffer, &self.client_hostname);
buffer_utils::write_string(buffer, &self.client_name);
buffer_utils::write_varint(buffer, self.client_version_major);
buffer_utils::write_varint(buffer, self.client_version_minor);
buffer_utils::write_varint(buffer, self.client_revision);
Ok(())
}
pub fn read_from(buffer: &mut &[u8]) -> Result<Self> {
if buffer.is_empty() {
return Err(Error::Protocol(
"Not enough data to read ClientInfo".to_string(),
));
}
let interface_type = buffer[0];
buffer.advance(1);
let os_user = buffer_utils::read_string(buffer)?;
let client_hostname = buffer_utils::read_string(buffer)?;
let client_name = buffer_utils::read_string(buffer)?;
let client_version_major = buffer_utils::read_varint(buffer)?;
let client_version_minor = buffer_utils::read_varint(buffer)?;
let client_revision = buffer_utils::read_varint(buffer)?;
Ok(Self {
interface_type,
query_kind: 0,
initial_user: String::new(),
initial_query_id: String::new(),
quota_key: String::new(),
os_user,
client_hostname,
client_name,
client_version_major,
client_version_minor,
client_version_patch: 0,
client_revision,
})
}
}
/// Server identification, encoded by `ServerInfo::write_to` and decoded by
/// `ServerInfo::read_from`.
///
/// On the wire, `timezone`, `display_name` and `version_patch` are present only
/// when `revision` is at least 54058, 54372 and 54401 respectively. Otherwise the
/// decoder leaves them empty / `0`.
#[derive(Clone, Debug, Default)]
pub struct ServerInfo {
    /// Server name string.
    pub name: String,
    pub version_major: u64,
    pub version_minor: u64,
    /// Only transmitted when `revision >= 54401`.
    pub version_patch: u64,
    /// Protocol revision; gates which optional fields are on the wire.
    pub revision: u64,
    /// Only transmitted when `revision >= 54058`.
    pub timezone: String,
    /// Only transmitted when `revision >= 54372`.
    pub display_name: String,
}
impl ServerInfo {
    // Lowest `revision` at which each optional field appears on the wire.
    const MIN_REVISION_WITH_TIMEZONE: u64 = 54058;
    const MIN_REVISION_WITH_DISPLAY_NAME: u64 = 54372;
    const MIN_REVISION_WITH_VERSION_PATCH: u64 = 54401;

    /// Serializes the server identity into `buffer`.
    ///
    /// Wire order: `name`, then `version_major`, `version_minor`, `revision`
    /// (varints). After those come `timezone`, `display_name` and
    /// `version_patch`, each written only when `self.revision` reaches its
    /// threshold.
    ///
    /// Always returns `Ok(())`.
    pub fn write_to(&self, buffer: &mut BytesMut) -> Result<()> {
        buffer_utils::write_string(buffer, &self.name);
        for number in [self.version_major, self.version_minor, self.revision] {
            buffer_utils::write_varint(buffer, number);
        }

        let revision = self.revision;
        if revision >= Self::MIN_REVISION_WITH_TIMEZONE {
            buffer_utils::write_string(buffer, &self.timezone);
        }
        if revision >= Self::MIN_REVISION_WITH_DISPLAY_NAME {
            buffer_utils::write_string(buffer, &self.display_name);
        }
        if revision >= Self::MIN_REVISION_WITH_VERSION_PATCH {
            buffer_utils::write_varint(buffer, self.version_patch);
        }
        Ok(())
    }

    /// Decodes the layout produced by [`ServerInfo::write_to`], advancing
    /// `buffer` past the consumed bytes.
    ///
    /// The decoded `revision` decides which optional fields are read. Fields
    /// that are absent keep their `Default` values.
    ///
    /// # Errors
    ///
    /// Propagates any error from `buffer_utils::read_string` / `read_varint`.
    pub fn read_from(buffer: &mut &[u8]) -> Result<Self> {
        let name = buffer_utils::read_string(buffer)?;
        let version_major = buffer_utils::read_varint(buffer)?;
        let version_minor = buffer_utils::read_varint(buffer)?;
        let revision = buffer_utils::read_varint(buffer)?;

        let mut info = Self {
            name,
            version_major,
            version_minor,
            revision,
            ..Self::default()
        };
        if revision >= Self::MIN_REVISION_WITH_TIMEZONE {
            info.timezone = buffer_utils::read_string(buffer)?;
        }
        if revision >= Self::MIN_REVISION_WITH_DISPLAY_NAME {
            info.display_name = buffer_utils::read_string(buffer)?;
        }
        if revision >= Self::MIN_REVISION_WITH_VERSION_PATCH {
            info.version_patch = buffer_utils::read_varint(buffer)?;
        }
        Ok(info)
    }
}
/// Query progress counters, encoded by `Progress::write_to` and decoded by
/// `Progress::read_from`; delivered to callbacks registered with
/// `Query::on_progress`.
///
/// `written_rows` and `written_bytes` are only on the wire when the
/// `server_revision` passed to those methods is at least 54405. Otherwise they
/// decode as `0`.
#[derive(Clone, Debug, Default)]
pub struct Progress {
    /// Row counter (presumably rows processed so far; confirm against server).
    pub rows: u64,
    /// Byte counter (presumably bytes processed so far; confirm).
    pub bytes: u64,
    /// Total-rows counter (presumably an estimate of rows to process; confirm).
    pub total_rows: u64,
    /// Only transmitted for `server_revision >= 54405`.
    pub written_rows: u64,
    /// Only transmitted for `server_revision >= 54405`.
    pub written_bytes: u64,
}
/// Profiling summary decoded by `Profile::read_from` and delivered to callbacks
/// registered with `Query::on_profile`.
///
/// Field declaration order differs from wire order. The decoder reads `rows`,
/// `blocks`, `bytes`, `applied_limit`, `rows_before_limit`, then
/// `calculated_rows_before_limit`.
#[derive(Clone, Debug, Default)]
pub struct Profile {
    pub rows: u64,
    pub blocks: u64,
    pub bytes: u64,
    pub rows_before_limit: u64,
    /// Decoded from a single byte; any non-zero value is `true`.
    pub applied_limit: bool,
    /// Decoded from a single byte; any non-zero value is `true`.
    pub calculated_rows_before_limit: bool,
}
/// A named block of data supplied alongside a query as an external table.
#[derive(Clone)]
pub struct ExternalTable {
    /// Table name.
    pub name: String,
    /// Table contents.
    pub data: Block,
}

impl ExternalTable {
    /// Pairs `data` with the table `name`.
    pub fn new(name: impl Into<String>, data: Block) -> Self {
        let name: String = name.into();
        Self { name, data }
    }
}
/// Shared, thread-safe callback registered with `Query::on_progress`.
pub type ProgressCallback = Arc<dyn Fn(&Progress) + Send + Sync>;
/// Shared, thread-safe callback registered with `Query::on_profile`.
pub type ProfileCallback = Arc<dyn Fn(&Profile) + Send + Sync>;
/// Shared, thread-safe callback registered with `Query::on_profile_events`.
/// NOTE(review): the meaning of the returned `bool` is decided by the code that
/// invokes it (presumably `false` stops processing); confirm.
pub type ProfileEventsCallback = Arc<dyn Fn(&Block) -> bool + Send + Sync>;
/// Shared, thread-safe callback registered with `Query::on_server_log`.
/// NOTE(review): returned `bool` semantics are defined by the caller; confirm.
pub type ServerLogCallback = Arc<dyn Fn(&Block) -> bool + Send + Sync>;
/// Shared, thread-safe callback registered with `Query::on_exception`.
pub type ExceptionCallback = Arc<dyn Fn(&Exception) + Send + Sync>;
/// Shared, thread-safe callback registered with `Query::on_data`.
pub type DataCallback = Arc<dyn Fn(&Block) + Send + Sync>;
/// Shared, thread-safe callback registered with `Query::on_data_cancelable`.
/// NOTE(review): by its name, returning `false` presumably cancels the query;
/// confirm against the caller.
pub type DataCancelableCallback = Arc<dyn Fn(&Block) -> bool + Send + Sync>;
impl Progress {
    // Lowest server revision whose progress payload carries the write counters.
    const MIN_REVISION_WITH_WRITE_INFO: u64 = 54405;

    /// Serializes the counters into `buffer` as varints.
    ///
    /// `rows`, `bytes` and `total_rows` are always written. `written_rows` and
    /// `written_bytes` follow only when `server_revision >= 54405`.
    ///
    /// Always returns `Ok(())`.
    pub fn write_to(
        &self,
        buffer: &mut BytesMut,
        server_revision: u64,
    ) -> Result<()> {
        for counter in [self.rows, self.bytes, self.total_rows] {
            buffer_utils::write_varint(buffer, counter);
        }
        if server_revision >= Self::MIN_REVISION_WITH_WRITE_INFO {
            for counter in [self.written_rows, self.written_bytes] {
                buffer_utils::write_varint(buffer, counter);
            }
        }
        Ok(())
    }

    /// Decodes the layout produced by [`Progress::write_to`] for the same
    /// `server_revision`, advancing `buffer` past the consumed bytes.
    ///
    /// For older revisions the write counters are not read and stay `0`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `buffer_utils::read_varint`.
    pub fn read_from(
        buffer: &mut &[u8],
        server_revision: u64,
    ) -> Result<Self> {
        let rows = buffer_utils::read_varint(buffer)?;
        let bytes = buffer_utils::read_varint(buffer)?;
        let total_rows = buffer_utils::read_varint(buffer)?;

        let mut progress = Self { rows, bytes, total_rows, ..Self::default() };
        if server_revision >= Self::MIN_REVISION_WITH_WRITE_INFO {
            progress.written_rows = buffer_utils::read_varint(buffer)?;
            progress.written_bytes = buffer_utils::read_varint(buffer)?;
        }
        Ok(progress)
    }
}
impl Profile {
pub fn read_from(buffer: &mut &[u8]) -> Result<Self> {
let rows = buffer_utils::read_varint(buffer)?;
let blocks = buffer_utils::read_varint(buffer)?;
let bytes = buffer_utils::read_varint(buffer)?;
let applied_limit = if !buffer.is_empty() {
let val = buffer[0];
buffer.advance(1);
val != 0
} else {
false
};
let rows_before_limit = buffer_utils::read_varint(buffer)?;
let calculated_rows_before_limit = if !buffer.is_empty() {
let val = buffer[0];
buffer.advance(1);
val != 0
} else {
false
};
Ok(Self {
rows,
blocks,
bytes,
rows_before_limit,
applied_limit,
calculated_rows_before_limit,
})
}
}
/// An exception as carried on the wire by `Exception::write_to` /
/// `Exception::read_from`.
///
/// Encoding per level:
/// - `code` as a little-endian `i32`;
/// - `name`, `display_text` and `stack_trace` as strings;
/// - one byte that is `1` when a nested exception follows (the decoder treats
///   any non-zero value as "nested"), then that nested exception in the same
///   encoding.
#[derive(Clone, Debug)]
pub struct Exception {
    pub code: i32,
    pub name: String,
    pub display_text: String,
    pub stack_trace: String,
    /// The next exception in the chain, if any.
    pub nested: Option<Box<Exception>>,
}
impl Exception {
pub fn write_to(&self, buffer: &mut BytesMut) -> Result<()> {
buffer.put_i32_le(self.code);
buffer_utils::write_string(buffer, &self.name);
buffer_utils::write_string(buffer, &self.display_text);
buffer_utils::write_string(buffer, &self.stack_trace);
let has_nested = self.nested.is_some();
buffer.put_u8(if has_nested { 1 } else { 0 });
if let Some(nested) = &self.nested {
nested.write_to(buffer)?;
}
Ok(())
}
pub fn read_from(buffer: &mut &[u8]) -> Result<Self> {
if buffer.len() < 4 {
return Err(Error::Protocol(
"Not enough data to read Exception".to_string(),
));
}
let code = {
let mut bytes = [0u8; 4];
bytes.copy_from_slice(&buffer[..4]);
buffer.advance(4);
i32::from_le_bytes(bytes)
};
let name = buffer_utils::read_string(buffer)?;
let display_text = buffer_utils::read_string(buffer)?;
let stack_trace = buffer_utils::read_string(buffer)?;
if buffer.is_empty() {
return Err(Error::Protocol(
"Not enough data to read nested exception flag".to_string(),
));
}
let has_nested = buffer[0] != 0;
buffer.advance(1);
let nested = if has_nested {
Some(Box::new(Exception::read_from(buffer)?))
} else {
None
};
Ok(Self { code, name, display_text, stack_trace, nested })
}
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    // Builds a test exception whose display text is derived from `name`.
    fn exception(code: i32, name: &str, nested: Option<Exception>) -> Exception {
        Exception {
            code,
            name: name.to_string(),
            display_text: format!("{name} text"),
            stack_trace: format!("{name} stack"),
            nested: nested.map(Box::new),
        }
    }

    #[test]
    fn test_query_creation() {
        let query = Query::new("SELECT 1");
        assert_eq!(query.text(), "SELECT 1");
        assert_eq!(query.id(), "");
        assert!(query.settings().is_empty());
        assert!(query.parameters().is_empty());
        assert!(query.tracing_context().is_none());
    }

    #[test]
    fn test_query_from_conversions() {
        assert_eq!(Query::from("SELECT 1").text(), "SELECT 1");
        assert_eq!(Query::from(String::from("SELECT 2")).text(), "SELECT 2");
    }

    #[test]
    fn test_query_with_id() {
        let query = Query::new("SELECT 1").with_query_id("test_query");
        assert_eq!(query.id(), "test_query");
    }

    #[test]
    fn test_query_with_settings() {
        let query = Query::new("SELECT 1")
            .with_setting("max_threads", "4")
            .with_setting("max_memory_usage", "10000000");
        assert_eq!(query.settings().len(), 2);
        assert_eq!(
            query.settings().get("max_threads").map(|f| f.value.as_str()),
            Some("4")
        );
        assert_eq!(query.settings().get("max_threads").unwrap().flags, 0);
    }

    #[test]
    fn test_query_with_important_settings() {
        let query = Query::new("SELECT 1")
            .with_important_setting("max_threads", "4")
            .with_setting_flags(
                "custom_setting",
                "value",
                QuerySettingsField::CUSTOM,
            );
        assert_eq!(query.settings().len(), 2);
        let max_threads = query.settings().get("max_threads").unwrap();
        assert_eq!(max_threads.value, "4");
        assert!(max_threads.is_important());
        assert!(!max_threads.is_custom());
        let custom = query.settings().get("custom_setting").unwrap();
        assert_eq!(custom.value, "value");
        assert!(custom.is_custom());
        assert!(!custom.is_important());
    }

    #[test]
    fn test_query_setting_overwrite_replaces_flags() {
        let query = Query::new("SELECT 1")
            .with_important_setting("max_threads", "4")
            .with_setting("max_threads", "8");
        assert_eq!(query.settings().len(), 1);
        let field = query.settings().get("max_threads").unwrap();
        assert_eq!(field.value, "8");
        assert_eq!(field.flags, 0);
    }

    #[test]
    fn test_query_with_parameters_and_tracing() {
        let query = Query::new("SELECT {p:UInt8}")
            .with_parameter("p", "1")
            .with_tracing_context(TracingContext::with_ids(7, 9));
        assert_eq!(
            query.parameters().get("p").map(String::as_str),
            Some("1")
        );
        let ctx = query.tracing_context().expect("tracing context was set");
        assert_eq!((ctx.trace_id, ctx.span_id), (7, 9));
        assert!(ctx.is_enabled());
    }

    #[test]
    fn test_tracing_context_builder() {
        let ctx = TracingContext::new();
        assert!(!ctx.is_enabled());
        let ctx = ctx.trace_id(1).span_id(2).tracestate("vendor=x").trace_flags(1);
        assert_eq!(ctx.trace_id, 1);
        assert_eq!(ctx.span_id, 2);
        assert_eq!(ctx.tracestate, "vendor=x");
        assert_eq!(ctx.trace_flags, 1);
        assert!(ctx.is_enabled());
    }

    #[test]
    fn test_client_info_roundtrip() {
        let info = ClientInfo::default();
        let mut buffer = BytesMut::new();
        info.write_to(&mut buffer).unwrap();
        let mut reader = &buffer[..];
        let decoded = ClientInfo::read_from(&mut reader).unwrap();
        assert!(reader.is_empty(), "ClientInfo decode left trailing bytes");
        assert_eq!(decoded.interface_type, 1);
        assert_eq!(decoded.os_user, info.os_user);
        assert_eq!(decoded.client_hostname, info.client_hostname);
        assert_eq!(decoded.client_name, "clickhouse-rust");
        assert_eq!(decoded.client_version_major, info.client_version_major);
        assert_eq!(decoded.client_version_minor, info.client_version_minor);
        assert_eq!(decoded.client_revision, info.client_revision);
    }

    #[test]
    fn test_client_info_read_empty_buffer() {
        let mut reader: &[u8] = &[];
        assert!(ClientInfo::read_from(&mut reader).is_err());
    }

    #[test]
    fn test_server_info_roundtrip() {
        let info = ServerInfo {
            name: "ClickHouse".to_string(),
            version_major: 21,
            version_minor: 8,
            version_patch: 5,
            revision: 54449,
            timezone: "UTC".to_string(),
            display_name: "ClickHouse server".to_string(),
        };
        let mut buffer = BytesMut::new();
        info.write_to(&mut buffer).unwrap();
        let mut reader = &buffer[..];
        let decoded = ServerInfo::read_from(&mut reader).unwrap();
        assert!(reader.is_empty(), "ServerInfo decode left trailing bytes");
        assert_eq!(decoded.name, "ClickHouse");
        assert_eq!(decoded.version_major, 21);
        assert_eq!(decoded.version_minor, 8);
        assert_eq!(decoded.version_patch, 5);
        assert_eq!(decoded.revision, 54449);
        assert_eq!(decoded.timezone, "UTC");
        assert_eq!(decoded.display_name, "ClickHouse server");
    }

    #[test]
    fn test_server_info_legacy_revision_omits_optional_fields() {
        let info = ServerInfo {
            name: "ClickHouse".to_string(),
            version_major: 1,
            version_minor: 1,
            version_patch: 9,
            revision: 54000,
            timezone: "UTC".to_string(),
            display_name: "ignored".to_string(),
        };
        let mut buffer = BytesMut::new();
        info.write_to(&mut buffer).unwrap();
        let mut reader = &buffer[..];
        let decoded = ServerInfo::read_from(&mut reader).unwrap();
        assert!(reader.is_empty());
        assert_eq!(decoded.revision, 54000);
        assert!(decoded.timezone.is_empty());
        assert!(decoded.display_name.is_empty());
        assert_eq!(decoded.version_patch, 0);
    }

    #[test]
    fn test_progress_roundtrip() {
        let progress = Progress {
            rows: 100,
            bytes: 1024,
            total_rows: 1000,
            written_rows: 50,
            written_bytes: 512,
        };
        let mut buffer = BytesMut::new();
        progress.write_to(&mut buffer, 54449).unwrap();
        let mut reader = &buffer[..];
        let decoded = Progress::read_from(&mut reader, 54449).unwrap();
        assert!(reader.is_empty(), "Progress decode left trailing bytes");
        assert_eq!(decoded.rows, 100);
        assert_eq!(decoded.bytes, 1024);
        assert_eq!(decoded.total_rows, 1000);
        assert_eq!(decoded.written_rows, 50);
        assert_eq!(decoded.written_bytes, 512);
    }

    #[test]
    fn test_progress_legacy_revision_omits_write_counters() {
        let progress = Progress {
            rows: 1,
            bytes: 2,
            total_rows: 3,
            written_rows: 4,
            written_bytes: 5,
        };
        let mut buffer = BytesMut::new();
        progress.write_to(&mut buffer, 54404).unwrap();
        let mut reader = &buffer[..];
        let decoded = Progress::read_from(&mut reader, 54404).unwrap();
        assert!(reader.is_empty());
        assert_eq!((decoded.rows, decoded.bytes, decoded.total_rows), (1, 2, 3));
        assert_eq!((decoded.written_rows, decoded.written_bytes), (0, 0));
    }

    #[test]
    fn test_profile_read() {
        let mut buffer = BytesMut::new();
        buffer_utils::write_varint(&mut buffer, 10);
        buffer_utils::write_varint(&mut buffer, 2);
        buffer_utils::write_varint(&mut buffer, 300);
        buffer.put_u8(1);
        buffer_utils::write_varint(&mut buffer, 5);
        buffer.put_u8(0);
        let mut reader = &buffer[..];
        let profile = Profile::read_from(&mut reader).unwrap();
        assert!(reader.is_empty());
        assert_eq!(profile.rows, 10);
        assert_eq!(profile.blocks, 2);
        assert_eq!(profile.bytes, 300);
        assert!(profile.applied_limit);
        assert_eq!(profile.rows_before_limit, 5);
        assert!(!profile.calculated_rows_before_limit);
    }

    #[test]
    fn test_exception_simple() {
        let exc = Exception {
            code: 42,
            name: "UNKNOWN_TABLE".to_string(),
            display_text: "Table doesn't exist".to_string(),
            stack_trace: "at query.cpp:123".to_string(),
            nested: None,
        };
        let mut buffer = BytesMut::new();
        exc.write_to(&mut buffer).unwrap();
        let mut reader = &buffer[..];
        let decoded = Exception::read_from(&mut reader).unwrap();
        assert!(reader.is_empty(), "Exception decode left trailing bytes");
        assert_eq!(decoded.code, 42);
        assert_eq!(decoded.name, "UNKNOWN_TABLE");
        assert_eq!(decoded.display_text, "Table doesn't exist");
        assert_eq!(decoded.stack_trace, "at query.cpp:123");
        assert!(decoded.nested.is_none());
    }

    #[test]
    fn test_exception_nested() {
        let nested_exc = Exception {
            code: 1,
            name: "INNER_ERROR".to_string(),
            display_text: "Inner error".to_string(),
            stack_trace: "inner stack".to_string(),
            nested: None,
        };
        let exc = Exception {
            code: 2,
            name: "OUTER_ERROR".to_string(),
            display_text: "Outer error".to_string(),
            stack_trace: "outer stack".to_string(),
            nested: Some(Box::new(nested_exc)),
        };
        let mut buffer = BytesMut::new();
        exc.write_to(&mut buffer).unwrap();
        let mut reader = &buffer[..];
        let decoded = Exception::read_from(&mut reader).unwrap();
        assert!(reader.is_empty());
        assert_eq!(decoded.code, 2);
        assert!(decoded.nested.is_some());
        assert_eq!(decoded.nested.as_ref().unwrap().code, 1);
        assert_eq!(decoded.nested.as_ref().unwrap().name, "INNER_ERROR");
    }

    #[test]
    fn test_exception_three_level_chain() {
        let exc = exception(
            3,
            "OUTER",
            Some(exception(2, "MIDDLE", Some(exception(1, "INNER", None)))),
        );
        let mut buffer = BytesMut::new();
        exc.write_to(&mut buffer).unwrap();
        let mut reader = &buffer[..];
        let decoded = Exception::read_from(&mut reader).unwrap();
        assert!(reader.is_empty());
        let chain: Vec<(i32, String)> =
            std::iter::successors(Some(&decoded), |e| e.nested.as_deref())
                .map(|e| (e.code, e.display_text.clone()))
                .collect();
        assert_eq!(
            chain,
            vec![
                (3, "OUTER text".to_string()),
                (2, "MIDDLE text".to_string()),
                (1, "INNER text".to_string()),
            ]
        );
    }

    #[test]
    fn test_exception_truncated_input() {
        let mut reader: &[u8] = &[0, 0, 0];
        assert!(Exception::read_from(&mut reader).is_err());
    }
}