Skip to main content

qail_pg/protocol/
encoder.rs

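//! PostgreSQL Wire-Protocol Encoder
//!
//! Encodes PostgreSQL frontend messages (Simple Query, the extended-query
//! Parse/Bind/Describe/Execute/Sync sequence, Close, CopyFail, Terminate)
//! into wire protocol bytes.
//! This is pure, synchronous computation - no I/O, no async.
//!
//! # Architecture
//!
//! Layer 2 of the QAIL architecture:
//! - Input: SQL text and already-serialized parameter values
//!   (encoding directly from a Qail AST is handled by `AstEncoder`)
//! - Output: BytesMut (ready to send over the wire)
//!
//! The async I/O layer (Layer 3) consumes these bytes.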
13
14use bytes::BytesMut;
15use super::EncodeError;
16
/// Stateless namespace for PostgreSQL frontend message encoders.
///
/// Every method is an associated function that either returns a freshly
/// allocated `BytesMut` or appends to a caller-supplied one; the type itself
/// carries no state.
#[derive(Debug, Clone, Copy, Default)]
pub struct PgEncoder;
20
21impl PgEncoder {
22    /// Encode a raw SQL string as a Simple Query message.
23    /// Wire format:
24    /// - 'Q' (1 byte) - message type
25    /// - length (4 bytes, big-endian, includes self)
26    /// - query string (null-terminated)
27    pub fn encode_query_string(sql: &str) -> BytesMut {
28        let mut buf = BytesMut::new();
29
30        // Bounds check: SQL + null terminator + 4 bytes length must fit in i32
31        let content_len = sql.len() + 1; // +1 for null terminator
32        if content_len > (i32::MAX as usize) - 4 {
33            // Return empty buffer — write will fail safely rather than
34            // producing a malformed message with overflowed length.
35            return buf;
36        }
37
38        // Message type 'Q' for Query
39        buf.extend_from_slice(b"Q");
40
41        let total_len = (content_len + 4) as i32; // +4 for length field itself
42
43        // Length (4 bytes, big-endian)
44        buf.extend_from_slice(&total_len.to_be_bytes());
45
46        // Query string
47        buf.extend_from_slice(sql.as_bytes());
48
49        // Null terminator
50        buf.extend_from_slice(&[0]);
51
52        buf
53    }
54
55    /// Encode a Terminate message to close the connection.
56    pub fn encode_terminate() -> BytesMut {
57        let mut buf = BytesMut::new();
58        buf.extend_from_slice(&[b'X', 0, 0, 0, 4]);
59        buf
60    }
61
62    /// Encode a Sync message (end of pipeline in extended query protocol).
63    pub fn encode_sync() -> BytesMut {
64        let mut buf = BytesMut::new();
65        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
66        buf
67    }
68
69    // ==================== Extended Query Protocol ====================
70
71    /// Encode a Parse message (prepare a statement).
72    /// Wire format:
73    /// - 'P' (1 byte) - message type
74    /// - length (4 bytes)
75    /// - statement name (null-terminated, "" for unnamed)
76    /// - query string (null-terminated)
77    /// - parameter count (2 bytes)
78    /// - parameter OIDs (4 bytes each, 0 = infer type)
79    pub fn encode_parse(name: &str, sql: &str, param_types: &[u32]) -> BytesMut {
80        let mut buf = BytesMut::new();
81
82        // Message type 'P'
83        buf.extend_from_slice(b"P");
84
85        let mut content = Vec::new();
86
87        // Statement name (null-terminated)
88        content.extend_from_slice(name.as_bytes());
89        content.push(0);
90
91        // Query string (null-terminated)
92        content.extend_from_slice(sql.as_bytes());
93        content.push(0);
94
95        // Parameter count
96        content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
97
98        // Parameter OIDs
99        for &oid in param_types {
100            content.extend_from_slice(&oid.to_be_bytes());
101        }
102
103        // Length (includes length field itself)
104        let len = (content.len() + 4) as i32;
105        buf.extend_from_slice(&len.to_be_bytes());
106        buf.extend_from_slice(&content);
107
108        buf
109    }
110
111    /// Encode a Bind message (bind parameters to a prepared statement).
112    /// Wire format:
113    /// - 'B' (1 byte) - message type
114    /// - length (4 bytes)
115    /// - portal name (null-terminated)
116    /// - statement name (null-terminated)
117    /// - format code count (2 bytes) - we use 0 (all text)
118    /// - parameter count (2 bytes)
119    /// - for each parameter: length (4 bytes, -1 for NULL), data
120    /// - result format count (2 bytes) - we use 0 (all text)
121    ///
122    /// # Arguments
123    ///
124    /// * `portal` — Destination portal name (empty string for unnamed).
125    /// * `statement` — Source prepared statement name (empty string for unnamed).
126    /// * `params` — Parameter values; `None` entries encode as SQL NULL.
127    pub fn encode_bind(portal: &str, statement: &str, params: &[Option<Vec<u8>>]) -> Result<BytesMut, EncodeError> {
128        if params.len() > i16::MAX as usize {
129            return Err(EncodeError::TooManyParameters(params.len()));
130        }
131
132        let mut buf = BytesMut::new();
133
134        // Message type 'B'
135        buf.extend_from_slice(b"B");
136
137        let mut content = Vec::new();
138
139        // Portal name (null-terminated)
140        content.extend_from_slice(portal.as_bytes());
141        content.push(0);
142
143        // Statement name (null-terminated)
144        content.extend_from_slice(statement.as_bytes());
145        content.push(0);
146
147        // Format codes count (0 = use default text format)
148        content.extend_from_slice(&0i16.to_be_bytes());
149
150        // Parameter count
151        content.extend_from_slice(&(params.len() as i16).to_be_bytes());
152
153        // Parameters
154        for param in params {
155            match param {
156                None => {
157                    // NULL: length = -1
158                    content.extend_from_slice(&(-1i32).to_be_bytes());
159                }
160                Some(data) => {
161                    if data.len() > i32::MAX as usize {
162                        return Err(EncodeError::MessageTooLarge(data.len()));
163                    }
164                    content.extend_from_slice(&(data.len() as i32).to_be_bytes());
165                    content.extend_from_slice(data);
166                }
167            }
168        }
169
170        // Result format codes count (0 = use default text format)
171        content.extend_from_slice(&0i16.to_be_bytes());
172
173        // Length
174        let len = (content.len() + 4) as i32;
175        buf.extend_from_slice(&len.to_be_bytes());
176        buf.extend_from_slice(&content);
177
178        Ok(buf)
179    }
180
181    /// Encode an Execute message (execute a bound portal).
182    /// Wire format:
183    /// - 'E' (1 byte) - message type
184    /// - length (4 bytes)
185    /// - portal name (null-terminated)
186    /// - max rows (4 bytes, 0 = unlimited)
187    pub fn encode_execute(portal: &str, max_rows: i32) -> BytesMut {
188        let mut buf = BytesMut::new();
189
190        // Message type 'E'
191        buf.extend_from_slice(b"E");
192
193        let mut content = Vec::new();
194
195        // Portal name (null-terminated)
196        content.extend_from_slice(portal.as_bytes());
197        content.push(0);
198
199        // Max rows
200        content.extend_from_slice(&max_rows.to_be_bytes());
201
202        // Length
203        let len = (content.len() + 4) as i32;
204        buf.extend_from_slice(&len.to_be_bytes());
205        buf.extend_from_slice(&content);
206
207        buf
208    }
209
210    /// Encode a Describe message (get statement/portal metadata).
211    /// Wire format:
212    /// - 'D' (1 byte) - message type
213    /// - length (4 bytes)
214    /// - 'S' for statement or 'P' for portal
215    /// - name (null-terminated)
216    pub fn encode_describe(is_portal: bool, name: &str) -> BytesMut {
217        let mut buf = BytesMut::new();
218
219        // Message type 'D'
220        buf.extend_from_slice(b"D");
221
222        let mut content = Vec::new();
223
224        // Type: 'S' for statement, 'P' for portal
225        content.push(if is_portal { b'P' } else { b'S' });
226
227        // Name (null-terminated)
228        content.extend_from_slice(name.as_bytes());
229        content.push(0);
230
231        // Length
232        let len = (content.len() + 4) as i32;
233        buf.extend_from_slice(&len.to_be_bytes());
234        buf.extend_from_slice(&content);
235
236        buf
237    }
238
239    /// Encode a complete extended query pipeline (OPTIMIZED).
240    /// This combines Parse + Bind + Execute + Sync in a single buffer.
241    /// Zero intermediate allocations - writes directly to pre-sized BytesMut.
242    pub fn encode_extended_query(sql: &str, params: &[Option<Vec<u8>>]) -> Result<BytesMut, EncodeError> {
243        if params.len() > i16::MAX as usize {
244            return Err(EncodeError::TooManyParameters(params.len()));
245        }
246
247        // Calculate total size upfront to avoid reallocations
248        // Bind: 1 + 4 + 1 + 1 + 2 + 2 + params_data + 2 = 13 + params_data
249        // Execute: 1 + 4 + 1 + 4 = 10
250        // Sync: 5
251        let params_size: usize = params
252            .iter()
253            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
254            .sum();
255        let total_size = 9 + sql.len() + 13 + params_size + 10 + 5;
256
257        let mut buf = BytesMut::with_capacity(total_size);
258
259        // ===== PARSE =====
260        buf.extend_from_slice(b"P");
261        let parse_len = (1 + sql.len() + 1 + 2 + 4) as i32; // name + sql + param_count
262        buf.extend_from_slice(&parse_len.to_be_bytes());
263        buf.extend_from_slice(&[0]); // Unnamed statement
264        buf.extend_from_slice(sql.as_bytes());
265        buf.extend_from_slice(&[0]); // Null terminator
266        buf.extend_from_slice(&0i16.to_be_bytes()); // No param types (infer)
267
268        // ===== BIND =====
269        buf.extend_from_slice(b"B");
270        let bind_len = (1 + 1 + 2 + 2 + params_size + 2 + 4) as i32;
271        buf.extend_from_slice(&bind_len.to_be_bytes());
272        buf.extend_from_slice(&[0]); // Unnamed portal
273        buf.extend_from_slice(&[0]); // Unnamed statement
274        buf.extend_from_slice(&0i16.to_be_bytes()); // Format codes (default text)
275        buf.extend_from_slice(&(params.len() as i16).to_be_bytes());
276        for param in params {
277            match param {
278                None => buf.extend_from_slice(&(-1i32).to_be_bytes()),
279                Some(data) => {
280                    if data.len() > i32::MAX as usize {
281                        return Err(EncodeError::MessageTooLarge(data.len()));
282                    }
283                    buf.extend_from_slice(&(data.len() as i32).to_be_bytes());
284                    buf.extend_from_slice(data);
285                }
286            }
287        }
288        buf.extend_from_slice(&0i16.to_be_bytes()); // Result format (default text)
289
290        // ===== EXECUTE =====
291        buf.extend_from_slice(b"E");
292        buf.extend_from_slice(&9i32.to_be_bytes()); // len = 4 + 1 + 4
293        buf.extend_from_slice(&[0]); // Unnamed portal
294        buf.extend_from_slice(&0i32.to_be_bytes()); // Unlimited rows
295
296        // ===== SYNC =====
297        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
298
299        Ok(buf)
300    }
301
302    /// Encode a CopyFail message to abort a COPY IN with an error.
303    /// Wire format:
304    /// - 'f' (1 byte) - message type
305    /// - length (4 bytes)
306    /// - error message (null-terminated)
307    pub fn encode_copy_fail(reason: &str) -> BytesMut {
308        let mut buf = BytesMut::new();
309        buf.extend_from_slice(b"f");
310        let content_len = reason.len() + 1; // +1 for null terminator
311        let len = (content_len + 4) as i32;
312        buf.extend_from_slice(&len.to_be_bytes());
313        buf.extend_from_slice(reason.as_bytes());
314        buf.extend_from_slice(&[0]);
315        buf
316    }
317
318    /// Encode a Close message to release a prepared statement or portal.
319    /// Wire format:
320    /// - 'C' (1 byte) - message type
321    /// - length (4 bytes)
322    /// - 'S' for statement or 'P' for portal
323    /// - name (null-terminated)
324    pub fn encode_close(is_portal: bool, name: &str) -> BytesMut {
325        let mut buf = BytesMut::new();
326        buf.extend_from_slice(b"C");
327        let content_len = 1 + name.len() + 1; // type + name + null
328        let len = (content_len + 4) as i32;
329        buf.extend_from_slice(&len.to_be_bytes());
330        buf.extend_from_slice(&[if is_portal { b'P' } else { b'S' }]);
331        buf.extend_from_slice(name.as_bytes());
332        buf.extend_from_slice(&[0]);
333        buf
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Decode the big-endian length field that follows the message-type byte.
    fn msg_len(bytes: &[u8]) -> i32 {
        i32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]])
    }

    // NOTE: test_encode_simple_query removed - use AstEncoder instead
    #[test]
    fn test_encode_query_string() {
        let bytes = PgEncoder::encode_query_string("SELECT 1");

        // 'Q' + length + query + NUL, and nothing more
        assert_eq!(bytes.len(), 14);
        assert_eq!(bytes[0], b'Q');
        // Length: 4 (length field) + 8 (query) + 1 (null) = 13
        assert_eq!(msg_len(&bytes), 13);
        assert_eq!(&bytes[5..13], b"SELECT 1");
        assert_eq!(bytes[13], 0);
    }

    #[test]
    fn test_encode_terminate() {
        let bytes = PgEncoder::encode_terminate();
        assert_eq!(&bytes[..], &[b'X', 0, 0, 0, 4]);
    }

    #[test]
    fn test_encode_sync() {
        let bytes = PgEncoder::encode_sync();
        assert_eq!(&bytes[..], &[b'S', 0, 0, 0, 4]);
    }

    #[test]
    fn test_encode_parse() {
        let bytes = PgEncoder::encode_parse("", "SELECT $1", &[]);

        // Length: 4 + 1 (empty name NUL) + 9 (query) + 1 (NUL) + 2 (param count) = 17
        let mut expected = vec![b'P', 0, 0, 0, 17, 0];
        expected.extend_from_slice(b"SELECT $1");
        expected.extend_from_slice(&[0, 0, 0]); // query NUL + zero param types
        assert_eq!(&bytes[..], &expected[..]);
    }

    #[test]
    fn test_encode_parse_with_param_types() {
        let bytes = PgEncoder::encode_parse("s1", "SELECT $1", &[23]);

        // Length: 4 + 3 ("s1" + NUL) + 10 (query + NUL) + 2 (count) + 4 (one OID) = 23
        let mut expected = vec![b'P', 0, 0, 0, 23];
        expected.extend_from_slice(b"s1\0SELECT $1\0");
        expected.extend_from_slice(&[0, 1, 0, 0, 0, 23]);
        assert_eq!(&bytes[..], &expected[..]);
    }

    #[test]
    fn test_encode_bind() {
        let params = vec![
            Some(b"42".to_vec()),
            None, // NULL
        ];
        let bytes = PgEncoder::encode_bind("", "", &params).unwrap();

        // Length: 4 + 1 + 1 + 2 + 2 + (4 + 2) + 4 + 2 = 22
        let expected: &[u8] = &[
            b'B', 0, 0, 0, 22,
            0, // portal ""
            0, // statement ""
            0, 0, // no parameter format codes (all text)
            0, 2, // two parameters
            0, 0, 0, 2, b'4', b'2', // "42"
            0xFF, 0xFF, 0xFF, 0xFF, // NULL
            0, 0, // no result format codes (all text)
        ];
        assert_eq!(&bytes[..], expected);
        assert_eq!(msg_len(&bytes) as usize, bytes.len() - 1);
    }

    #[test]
    fn test_encode_bind_rejects_too_many_params() {
        let params: Vec<Option<Vec<u8>>> = vec![None; i16::MAX as usize + 1];
        assert!(matches!(
            PgEncoder::encode_bind("", "", &params),
            Err(EncodeError::TooManyParameters(n)) if n == i16::MAX as usize + 1
        ));
    }

    #[test]
    fn test_encode_execute() {
        // Length: 4 + 1 (null) + 4 (max_rows) = 9
        let bytes = PgEncoder::encode_execute("", 0);
        assert_eq!(&bytes[..], &[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);

        let bytes = PgEncoder::encode_execute("p", 5);
        assert_eq!(&bytes[..], &[b'E', 0, 0, 0, 10, b'p', 0, 0, 0, 0, 5]);
    }

    #[test]
    fn test_encode_describe() {
        // Length: 4 + 1 (kind) + 2 (name) + 1 (null) = 8
        let bytes = PgEncoder::encode_describe(false, "s1");
        assert_eq!(&bytes[..], &[b'D', 0, 0, 0, 8, b'S', b's', b'1', 0]);

        let bytes = PgEncoder::encode_describe(true, "");
        assert_eq!(&bytes[..], &[b'D', 0, 0, 0, 6, b'P', 0]);
    }

    #[test]
    fn test_encode_extended_query() {
        let params = vec![Some(b"hello".to_vec())];
        let bytes = PgEncoder::encode_extended_query("SELECT $1", &params).unwrap();

        // The pipeline must be exactly Parse + Bind + Execute + Sync, in order,
        // byte-for-byte identical to the individual encoders.
        let mut expected = Vec::new();
        expected.extend_from_slice(&PgEncoder::encode_parse("", "SELECT $1", &[]));
        expected.extend_from_slice(&PgEncoder::encode_bind("", "", &params).unwrap());
        expected.extend_from_slice(&PgEncoder::encode_execute("", 0));
        expected.extend_from_slice(&PgEncoder::encode_sync());
        assert_eq!(&bytes[..], &expected[..]);
        assert_eq!(bytes.len(), 55);
    }

    #[test]
    fn test_bind_variants_agree() {
        let owned = vec![Some(b"42".to_vec()), None];
        let reference = PgEncoder::encode_bind("", "stmt", &owned).unwrap();

        let mut to_buf = BytesMut::new();
        PgEncoder::encode_bind_to(&mut to_buf, "stmt", &owned).unwrap();
        assert_eq!(to_buf, reference);

        let borrowed = [Param::Bytes(b"42"), Param::Null];
        let mut ultra_buf = BytesMut::new();
        PgEncoder::encode_bind_ultra(&mut ultra_buf, "stmt", &borrowed).unwrap();
        assert_eq!(ultra_buf, reference);
    }

    #[test]
    fn test_fixed_message_variants_agree() {
        let mut buf = BytesMut::new();
        PgEncoder::encode_execute_ultra(&mut buf);
        PgEncoder::encode_sync_ultra(&mut buf);
        PgEncoder::encode_execute_to(&mut buf);
        PgEncoder::encode_sync_to(&mut buf);

        let mut expected = Vec::new();
        for _ in 0..2 {
            expected.extend_from_slice(&PgEncoder::encode_execute("", 0));
            expected.extend_from_slice(&PgEncoder::encode_sync());
        }
        assert_eq!(&buf[..], &expected[..]);
    }

    #[test]
    fn test_encode_copy_fail() {
        let bytes = PgEncoder::encode_copy_fail("bad data");
        assert_eq!(bytes[0], b'f');
        assert_eq!(msg_len(&bytes) as usize, 4 + "bad data".len() + 1);
        assert_eq!(&bytes[5..13], b"bad data");
        assert_eq!(bytes[13], 0);
        assert_eq!(bytes.len(), 14);
    }

    #[test]
    fn test_encode_close_statement() {
        let bytes = PgEncoder::encode_close(false, "my_stmt");
        assert_eq!(bytes[0], b'C');
        assert_eq!(bytes[5], b'S'); // Statement type
        assert_eq!(&bytes[6..13], b"my_stmt");
        assert_eq!(bytes[13], 0);
        assert_eq!(msg_len(&bytes) as usize, bytes.len() - 1);
    }

    #[test]
    fn test_encode_close_portal() {
        let bytes = PgEncoder::encode_close(true, "");
        assert_eq!(&bytes[..], &[b'C', 0, 0, 0, 6, b'P', 0]);
    }
}
452
// ==================== Hot-Path Encoders ====================
//
// These encoders append directly to a caller-supplied buffer:
// - Integers are written with `BufMut::put_*` (no temporary byte arrays)
// - `Param` borrows parameter bytes instead of requiring owned `Vec<u8>`s
// - Bind messages compute their full size first and reserve capacity once
//
460
461use bytes::BufMut;
462
/// Zero-copy parameter for the hot-path Bind encoder.
///
/// Borrows the already-serialized value, so encoding never needs an owned
/// `Vec<u8>` per parameter. Being a pair of borrowed variants, it is cheap to
/// copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Param<'a> {
    /// SQL NULL value (encoded with a length of -1).
    Null,
    /// Non-null parameter as a borrowed byte slice.
    Bytes(&'a [u8]),
}
471
472impl PgEncoder {
473    /// Direct i32 write - no temp array, no bounds check.
474    /// LLVM emits a single store instruction.
475    #[inline(always)]
476    fn put_i32_be(buf: &mut BytesMut, v: i32) {
477        buf.put_i32(v);
478    }
479
480    #[inline(always)]
481    fn put_i16_be(buf: &mut BytesMut, v: i16) {
482        buf.put_i16(v);
483    }
484
485    /// Encode Bind message - ULTRA OPTIMIZED.
486    /// - Direct integer writes (no temp arrays)
487    /// - Borrowed params (zero-copy)
488    /// - Single allocation check
489    #[inline]
490    pub fn encode_bind_ultra<'a>(buf: &mut BytesMut, statement: &str, params: &[Param<'a>]) -> Result<(), EncodeError> {
491        if params.len() > i16::MAX as usize {
492            return Err(EncodeError::TooManyParameters(params.len()));
493        }
494
495        // Calculate content length upfront
496        let params_size: usize = params
497            .iter()
498            .map(|p| match p {
499                Param::Null => 4,
500                Param::Bytes(b) => 4 + b.len(),
501            })
502            .sum();
503        let content_len = 1 + statement.len() + 1 + 2 + 2 + params_size + 2;
504
505        // Single reserve - no more allocations
506        buf.reserve(1 + 4 + content_len);
507
508        // Message type 'B'
509        buf.put_u8(b'B');
510
511        // Length (includes itself) - DIRECT WRITE
512        Self::put_i32_be(buf, (content_len + 4) as i32);
513
514        // Portal name (empty, null-terminated)
515        buf.put_u8(0);
516
517        // Statement name (null-terminated)
518        buf.extend_from_slice(statement.as_bytes());
519        buf.put_u8(0);
520
521        // Format codes count (0 = default text)
522        Self::put_i16_be(buf, 0);
523
524        // Parameter count
525        Self::put_i16_be(buf, params.len() as i16);
526
527        // Parameters - ZERO COPY from borrowed slices
528        for param in params {
529            match param {
530                Param::Null => Self::put_i32_be(buf, -1),
531                Param::Bytes(data) => {
532                    if data.len() > i32::MAX as usize {
533                        return Err(EncodeError::MessageTooLarge(data.len()));
534                    }
535                    Self::put_i32_be(buf, data.len() as i32);
536                    buf.extend_from_slice(data);
537                }
538            }
539        }
540
541        // Result format codes count (0 = default text)
542        Self::put_i16_be(buf, 0);
543        Ok(())
544    }
545
546    /// Encode Execute message - ULTRA OPTIMIZED.
547    #[inline(always)]
548    pub fn encode_execute_ultra(buf: &mut BytesMut) {
549        // Execute: 'E' + len(9) + portal("") + max_rows(0)
550        // = 'E' 00 00 00 09 00 00 00 00 00
551        buf.extend_from_slice(&[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);
552    }
553
554    /// Encode Sync message - ULTRA OPTIMIZED.
555    #[inline(always)]
556    pub fn encode_sync_ultra(buf: &mut BytesMut) {
557        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
558    }
559
560    // Keep the original methods for compatibility
561
562    /// Encode Bind message directly into existing buffer (ZERO ALLOCATION).
563    /// This is the hot path optimization - no intermediate Vec allocation.
564    #[inline]
565    pub fn encode_bind_to(buf: &mut BytesMut, statement: &str, params: &[Option<Vec<u8>>]) -> Result<(), EncodeError> {
566        if params.len() > i16::MAX as usize {
567            return Err(EncodeError::TooManyParameters(params.len()));
568        }
569
570        // Calculate content length upfront
571        // portal(1) + statement(len+1) + format_codes(2) + param_count(2) + params_data + result_format(2)
572        let params_size: usize = params
573            .iter()
574            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
575            .sum();
576        let content_len = 1 + statement.len() + 1 + 2 + 2 + params_size + 2;
577
578        buf.reserve(1 + 4 + content_len);
579
580        // Message type 'B'
581        buf.put_u8(b'B');
582
583        // Length (includes itself) - DIRECT WRITE
584        Self::put_i32_be(buf, (content_len + 4) as i32);
585
586        // Portal name (empty, null-terminated)
587        buf.put_u8(0);
588
589        // Statement name (null-terminated)
590        buf.extend_from_slice(statement.as_bytes());
591        buf.put_u8(0);
592
593        // Format codes count (0 = default text)
594        Self::put_i16_be(buf, 0);
595
596        // Parameter count
597        Self::put_i16_be(buf, params.len() as i16);
598
599        // Parameters
600        for param in params {
601            match param {
602                None => Self::put_i32_be(buf, -1),
603                Some(data) => {
604                    if data.len() > i32::MAX as usize {
605                        return Err(EncodeError::MessageTooLarge(data.len()));
606                    }
607                    Self::put_i32_be(buf, data.len() as i32);
608                    buf.extend_from_slice(data);
609                }
610            }
611        }
612
613        // Result format codes count (0 = default text)
614        Self::put_i16_be(buf, 0);
615        Ok(())
616    }
617
618    /// Encode Execute message directly into existing buffer (ZERO ALLOCATION).
619    #[inline]
620    pub fn encode_execute_to(buf: &mut BytesMut) {
621        // Content: portal(1) + max_rows(4) = 5 bytes
622        buf.extend_from_slice(&[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);
623    }
624
625    /// Encode Sync message directly into existing buffer (ZERO ALLOCATION).
626    #[inline]
627    pub fn encode_sync_to(buf: &mut BytesMut) {
628        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
629    }
630}