// clickhouse_native_client/wire_format.rs

1use crate::{
2    Error,
3    Result,
4};
5use tokio::io::{
6    AsyncRead,
7    AsyncReadExt,
8    AsyncWrite,
9    AsyncWriteExt,
10};
11
/// Wire format utilities for ClickHouse protocol
///
/// A unit struct used purely as a namespace: every operation in its `impl`
/// is an associated function that takes the reader or writer explicitly, so
/// no instance is ever needed.
pub struct WireFormat;
14
15impl WireFormat {
16    /// Read a varint-encoded u64
17    pub async fn read_varint64<R: AsyncRead + Unpin>(
18        reader: &mut R,
19    ) -> Result<u64> {
20        let mut result: u64 = 0;
21        let mut shift = 0;
22
23        loop {
24            let byte = reader.read_u8().await?;
25            result |= ((byte & 0x7F) as u64) << shift;
26
27            if byte & 0x80 == 0 {
28                break;
29            }
30
31            shift += 7;
32            if shift >= 64 {
33                return Err(Error::Protocol("Varint overflow".to_string()));
34            }
35        }
36
37        Ok(result)
38    }
39
40    /// Write a varint-encoded u64
41    pub async fn write_varint64<W: AsyncWrite + Unpin>(
42        writer: &mut W,
43        mut value: u64,
44    ) -> Result<()> {
45        loop {
46            let mut byte = (value & 0x7F) as u8;
47            value >>= 7;
48
49            if value != 0 {
50                byte |= 0x80;
51            }
52
53            writer.write_u8(byte).await?;
54
55            if value == 0 {
56                break;
57            }
58        }
59
60        Ok(())
61    }
62
63    /// Read a fixed-size value (little-endian)
64    pub async fn read_fixed<R: AsyncRead + Unpin + Send, T: FixedSize>(
65        reader: &mut R,
66    ) -> Result<T> {
67        T::read_from(reader).await
68    }
69
70    /// Write a fixed-size value (little-endian)
71    pub async fn write_fixed<W: AsyncWrite + Unpin + Send, T: FixedSize>(
72        writer: &mut W,
73        value: T,
74    ) -> Result<()> {
75        value.write_to(writer).await
76    }
77
78    /// Read a length-prefixed string
79    pub async fn read_string<R: AsyncRead + Unpin>(
80        reader: &mut R,
81    ) -> Result<String> {
82        let len = Self::read_varint64(reader).await? as usize;
83
84        if len > 0x00FFFFFF {
85            return Err(Error::Protocol(format!(
86                "String length too large: {}",
87                len
88            )));
89        }
90
91        let mut buf = vec![0u8; len];
92        reader.read_exact(&mut buf).await?;
93
94        String::from_utf8(buf)
95            .map_err(|e| Error::Protocol(format!("Invalid UTF-8: {}", e)))
96    }
97
98    /// Write a length-prefixed string
99    pub async fn write_string<W: AsyncWrite + Unpin>(
100        writer: &mut W,
101        value: &str,
102    ) -> Result<()> {
103        Self::write_varint64(writer, value.len() as u64).await?;
104        writer.write_all(value.as_bytes()).await?;
105        Ok(())
106    }
107
108    /// Read raw bytes of specified length
109    pub async fn read_bytes<R: AsyncRead + Unpin>(
110        reader: &mut R,
111        len: usize,
112    ) -> Result<Vec<u8>> {
113        let mut buf = vec![0u8; len];
114        reader.read_exact(&mut buf).await?;
115        Ok(buf)
116    }
117
118    /// Write raw bytes
119    pub async fn write_bytes<W: AsyncWrite + Unpin>(
120        writer: &mut W,
121        bytes: &[u8],
122    ) -> Result<()> {
123        writer.write_all(bytes).await?;
124        Ok(())
125    }
126
127    /// Skip a string without reading it into memory
128    pub async fn skip_string<R: AsyncRead + Unpin>(
129        reader: &mut R,
130    ) -> Result<()> {
131        let len = Self::read_varint64(reader).await? as usize;
132
133        if len > 0x00FFFFFF {
134            return Err(Error::Protocol(format!(
135                "String length too large: {}",
136                len
137            )));
138        }
139
140        // Skip bytes
141        let mut remaining = len;
142        let mut buf = [0u8; 8192];
143        while remaining > 0 {
144            let to_read = remaining.min(buf.len());
145            reader.read_exact(&mut buf[..to_read]).await?;
146            remaining -= to_read;
147        }
148
149        Ok(())
150    }
151
152    /// Write a quoted string for query parameters (1:1 port of C++
153    /// WriteQuotedString)
154    ///
155    /// Format: varint(length) + quoted_string
156    /// Special chars escaped: \0, \b, \t, \n, ', \
157    ///
158    /// Escaping rules:
159    /// - \0 → \x00
160    /// - \b → \x08
161    /// - \t → \\t
162    /// - \n → \\n
163    /// - '  → \x27
164    /// - \  → \\\
165    pub async fn write_quoted_string<W: AsyncWrite + Unpin>(
166        writer: &mut W,
167        value: &str,
168    ) -> Result<()> {
169        const QUOTED_CHARS: &[u8] = b"\0\x08\t\n'\\";
170
171        // Check if we need escaping (fast path)
172        let bytes = value.as_bytes();
173        let first_special =
174            bytes.iter().position(|&b| QUOTED_CHARS.contains(&b));
175
176        if first_special.is_none() {
177            // Fast path: no special characters
178            Self::write_varint64(writer, (value.len() + 2) as u64).await?;
179            writer.write_all(b"'").await?;
180            writer.write_all(bytes).await?;
181            writer.write_all(b"'").await?;
182            return Ok(());
183        }
184
185        // Count special characters for length calculation
186        let quoted_count =
187            bytes.iter().filter(|&&b| QUOTED_CHARS.contains(&b)).count();
188
189        // Write length: original + 2 quotes + 3 bytes per special char
190        let total_len = value.len() + 2 + 3 * quoted_count;
191        Self::write_varint64(writer, total_len as u64).await?;
192
193        // Write opening quote
194        writer.write_all(b"'").await?;
195
196        // Write string with escaping
197        let mut start = 0;
198        for (i, &byte) in bytes.iter().enumerate() {
199            if QUOTED_CHARS.contains(&byte) {
200                // Write chunk before special char
201                if i > start {
202                    writer.write_all(&bytes[start..i]).await?;
203                }
204
205                // Write escape sequence
206                writer.write_all(b"\\").await?;
207                match byte {
208                    b'\0' => writer.write_all(b"x00").await?,
209                    b'\x08' => writer.write_all(b"x08").await?,
210                    b'\t' => writer.write_all(b"\\t").await?,
211                    b'\n' => writer.write_all(b"\\n").await?,
212                    b'\'' => writer.write_all(b"x27").await?,
213                    b'\\' => writer.write_all(b"\\\\").await?,
214                    _ => unreachable!(),
215                }
216
217                start = i + 1;
218            }
219        }
220
221        // Write final chunk
222        if start < bytes.len() {
223            writer.write_all(&bytes[start..]).await?;
224        }
225
226        // Write closing quote
227        writer.write_all(b"'").await?;
228
229        Ok(())
230    }
231}
232
/// Trait for types that can be read/written as fixed-size values
///
/// `WireFormat::read_fixed` and `WireFormat::write_fixed` dispatch through
/// this trait. The implementations in this file cover the primitive integer
/// and float types and use the little-endian tokio read/write helpers for
/// every multi-byte type.
#[async_trait::async_trait]
pub trait FixedSize: Sized + Send {
    /// Reads a fixed-size value from the given async reader.
    ///
    /// Returns the decoded value, or an error if the read fails.
    async fn read_from<R: AsyncRead + Unpin + Send>(
        reader: &mut R,
    ) -> Result<Self>;
    /// Writes this fixed-size value to the given async writer.
    ///
    /// Consumes `self`; returns an error if the write fails.
    async fn write_to<W: AsyncWrite + Unpin + Send>(
        self,
        writer: &mut W,
    ) -> Result<()>;
}
246
247// Implement FixedSize for primitive types
248macro_rules! impl_fixed_size {
249    ($type:ty, $read:ident, $write:ident) => {
250        #[async_trait::async_trait]
251        impl FixedSize for $type {
252            async fn read_from<R: AsyncRead + Unpin + Send>(
253                reader: &mut R,
254            ) -> Result<Self> {
255                Ok(reader.$read().await?)
256            }
257
258            async fn write_to<W: AsyncWrite + Unpin + Send>(
259                self,
260                writer: &mut W,
261            ) -> Result<()> {
262                Ok(writer.$write(self).await?)
263            }
264        }
265    };
266}
267
268impl_fixed_size!(u8, read_u8, write_u8);
269impl_fixed_size!(u16, read_u16_le, write_u16_le);
270impl_fixed_size!(u32, read_u32_le, write_u32_le);
271impl_fixed_size!(u64, read_u64_le, write_u64_le);
272impl_fixed_size!(i8, read_i8, write_i8);
273impl_fixed_size!(i16, read_i16_le, write_i16_le);
274impl_fixed_size!(i32, read_i32_le, write_i32_le);
275impl_fixed_size!(i64, read_i64_le, write_i64_le);
276impl_fixed_size!(f32, read_f32_le, write_f32_le);
277impl_fixed_size!(f64, read_f64_le, write_f64_le);
278
279// i128/u128 implementation
280#[async_trait::async_trait]
281impl FixedSize for i128 {
282    async fn read_from<R: AsyncRead + Unpin + Send>(
283        reader: &mut R,
284    ) -> Result<Self> {
285        Ok(reader.read_i128_le().await?)
286    }
287
288    async fn write_to<W: AsyncWrite + Unpin + Send>(
289        self,
290        writer: &mut W,
291    ) -> Result<()> {
292        Ok(writer.write_i128_le(self).await?)
293    }
294}
295
296#[async_trait::async_trait]
297impl FixedSize for u128 {
298    async fn read_from<R: AsyncRead + Unpin + Send>(
299        reader: &mut R,
300    ) -> Result<Self> {
301        Ok(reader.read_u128_le().await?)
302    }
303
304    async fn write_to<W: AsyncWrite + Unpin + Send>(
305        self,
306        writer: &mut W,
307    ) -> Result<()> {
308        Ok(writer.write_u128_le(self).await?)
309    }
310}
311
312#[cfg(test)]
313#[cfg_attr(coverage_nightly, coverage(off))]
314mod tests {
315    use super::*;
316
317    #[tokio::test]
318    async fn test_varint64_encoding() {
319        let test_cases =
320            vec![0u64, 1, 127, 128, 255, 256, 65535, 65536, u64::MAX];
321
322        for value in test_cases {
323            let mut buf = Vec::new();
324            WireFormat::write_varint64(&mut buf, value).await.unwrap();
325
326            let mut reader = &buf[..];
327            let decoded =
328                WireFormat::read_varint64(&mut reader).await.unwrap();
329
330            assert_eq!(value, decoded, "Varint encoding failed for {}", value);
331        }
332    }
333
334    #[tokio::test]
335    async fn test_string_encoding() {
336        let test_strings = vec!["", "hello", "мир", "🦀"];
337
338        for s in test_strings {
339            let mut buf = Vec::new();
340            WireFormat::write_string(&mut buf, s).await.unwrap();
341
342            let mut reader = &buf[..];
343            let decoded = WireFormat::read_string(&mut reader).await.unwrap();
344
345            assert_eq!(s, decoded, "String encoding failed for '{}'", s);
346        }
347    }
348
349    #[tokio::test]
350    async fn test_fixed_u32() {
351        let value = 0x12345678u32;
352        let mut buf = Vec::new();
353        WireFormat::write_fixed(&mut buf, value).await.unwrap();
354
355        assert_eq!(buf, vec![0x78, 0x56, 0x34, 0x12]); // Little-endian
356
357        let mut reader = &buf[..];
358        let decoded: u32 = WireFormat::read_fixed(&mut reader).await.unwrap();
359
360        assert_eq!(value, decoded);
361    }
362
363    #[tokio::test]
364    async fn test_fixed_i64() {
365        let value = -12345i64;
366        let mut buf = Vec::new();
367        WireFormat::write_fixed(&mut buf, value).await.unwrap();
368
369        let mut reader = &buf[..];
370        let decoded: i64 = WireFormat::read_fixed(&mut reader).await.unwrap();
371
372        assert_eq!(value, decoded);
373    }
374
375    #[tokio::test]
376    async fn test_fixed_float() {
377        let value = std::f32::consts::PI;
378        let mut buf = Vec::new();
379        WireFormat::write_fixed(&mut buf, value).await.unwrap();
380
381        let mut reader = &buf[..];
382        let decoded: f32 = WireFormat::read_fixed(&mut reader).await.unwrap();
383
384        assert!((value - decoded).abs() < 1e-6);
385    }
386
387    #[tokio::test]
388    async fn test_bytes() {
389        let data = vec![1u8, 2, 3, 4, 5];
390        let mut buf = Vec::new();
391        WireFormat::write_bytes(&mut buf, &data).await.unwrap();
392
393        let mut reader = &buf[..];
394        let decoded =
395            WireFormat::read_bytes(&mut reader, data.len()).await.unwrap();
396
397        assert_eq!(data, decoded);
398    }
399
400    #[tokio::test]
401    async fn test_write_quoted_string_no_escaping() {
402        let mut buf = Vec::new();
403        WireFormat::write_quoted_string(&mut buf, "hello").await.unwrap();
404
405        // Length: 7 (5 + 2 quotes)
406        // Content: 'hello'
407        let mut expected = Vec::new();
408        WireFormat::write_varint64(&mut expected, 7).await.unwrap();
409        expected.extend_from_slice(b"'hello'");
410
411        assert_eq!(buf, expected);
412    }
413
414    #[tokio::test]
415    async fn test_write_quoted_string_with_tab() {
416        let mut buf = Vec::new();
417        WireFormat::write_quoted_string(&mut buf, "a\tb").await.unwrap();
418
419        // Length: original(3) + 2(quotes) + 3(one special char) = 8
420        // Content: 'a\\tb'
421        let mut expected = Vec::new();
422        WireFormat::write_varint64(&mut expected, 8).await.unwrap();
423        expected.extend_from_slice(b"'a\\\\tb'");
424
425        assert_eq!(buf, expected);
426    }
427
428    #[tokio::test]
429    async fn test_write_quoted_string_with_null() {
430        let mut buf = Vec::new();
431        WireFormat::write_quoted_string(&mut buf, "a\0b").await.unwrap();
432
433        // Length: 3 + 2 + 3 = 8
434        // Content: 'a\x00b'
435        let mut expected = Vec::new();
436        WireFormat::write_varint64(&mut expected, 8).await.unwrap();
437        expected.extend_from_slice(b"'a\\x00b'");
438
439        assert_eq!(buf, expected);
440    }
441
442    #[tokio::test]
443    async fn test_write_quoted_string_all_special_chars() {
444        let test_str = "\0\x08\t\n'\\";
445        let mut buf = Vec::new();
446        WireFormat::write_quoted_string(&mut buf, test_str).await.unwrap();
447
448        // 6 chars, each becomes 4 bytes: 6 + 2 + 3*6 = 26
449        let mut expected = Vec::new();
450        WireFormat::write_varint64(&mut expected, 26).await.unwrap();
451        // \0 → \x00, \b → \x08, \t → \\t, \n → \\n, ' → \x27, \ → \\\
452        expected.extend_from_slice(b"'\\x00\\x08\\\\t\\\\n\\x27\\\\\\'");
453
454        assert_eq!(buf, expected);
455    }
456
457    #[tokio::test]
458    async fn test_write_quoted_string_single_quote() {
459        let mut buf = Vec::new();
460        WireFormat::write_quoted_string(&mut buf, "a'b").await.unwrap();
461
462        // Length: 3 + 2 + 3 = 8
463        // Content: 'a\x27b'
464        let mut expected = Vec::new();
465        WireFormat::write_varint64(&mut expected, 8).await.unwrap();
466        expected.extend_from_slice(b"'a\\x27b'");
467
468        assert_eq!(buf, expected);
469    }
470
471    #[tokio::test]
472    async fn test_write_quoted_string_backslash() {
473        let mut buf = Vec::new();
474        WireFormat::write_quoted_string(&mut buf, "a\\b").await.unwrap();
475
476        // Length: 3 + 2 + 3 = 8
477        // Content: 'a\\\b' (backslash becomes \\\ which is 3 backslashes)
478        let mut expected = Vec::new();
479        WireFormat::write_varint64(&mut expected, 8).await.unwrap();
480        expected.extend_from_slice(b"'a\\\\\\b'");
481
482        assert_eq!(buf, expected);
483    }
484
485    #[tokio::test]
486    async fn test_write_quoted_string_utf8() {
487        let mut buf = Vec::new();
488        WireFormat::write_quoted_string(&mut buf, "utf8Русский")
489            .await
490            .unwrap();
491
492        // UTF-8 doesn't need escaping unless it contains special chars
493        let content = "utf8Русский";
494        let expected_len = content.len() + 2;
495        let mut expected = Vec::new();
496        WireFormat::write_varint64(&mut expected, expected_len as u64)
497            .await
498            .unwrap();
499        expected.push(b'\'');
500        expected.extend_from_slice(content.as_bytes());
501        expected.push(b'\'');
502
503        assert_eq!(buf, expected);
504    }
505}