// qail_pg/protocol/encoder.rs

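//! PostgreSQL Encoder (Visitor Pattern)
//!
//! Encodes PostgreSQL frontend wire-protocol messages (Query, Parse, Bind,
//! Describe, Execute, Sync, Close, CopyFail, Terminate) for queries produced
//! from the Qail AST.
//! This is pure, synchronous computation - no I/O, no async.
//!
//! # Architecture
//!
//! Layer 2 of the QAIL architecture:
//! - Input: Qail (AST), lowered to SQL text, statement/portal names, and
//!   already-encoded parameter bytes
//! - Output: `BytesMut` (ready to send over the wire)
//!
//! The async I/O layer (Layer 3) consumes these bytes.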
13
14use bytes::BytesMut;
15use super::EncodeError;
16
/// Stateless namespace for the PostgreSQL frontend message encoders.
///
/// Every encoder is an associated function. Nothing is stored on the value,
/// so it is a zero-sized type that is free to copy or default-construct.
#[derive(Debug, Clone, Copy, Default)]
pub struct PgEncoder;
20
21impl PgEncoder {
22    /// Encode a raw SQL string as a Simple Query message.
23    /// Wire format:
24    /// - 'Q' (1 byte) - message type
25    /// - length (4 bytes, big-endian, includes self)
26    /// - query string (null-terminated)
27    pub fn encode_query_string(sql: &str) -> BytesMut {
28        let mut buf = BytesMut::new();
29
30        // Bounds check: SQL + null terminator + 4 bytes length must fit in i32
31        let content_len = sql.len() + 1; // +1 for null terminator
32        if content_len > (i32::MAX as usize) - 4 {
33            // Return empty buffer — write will fail safely rather than
34            // producing a malformed message with overflowed length.
35            return buf;
36        }
37
38        // Message type 'Q' for Query
39        buf.extend_from_slice(b"Q");
40
41        let total_len = (content_len + 4) as i32; // +4 for length field itself
42
43        // Length (4 bytes, big-endian)
44        buf.extend_from_slice(&total_len.to_be_bytes());
45
46        // Query string
47        buf.extend_from_slice(sql.as_bytes());
48
49        // Null terminator
50        buf.extend_from_slice(&[0]);
51
52        buf
53    }
54
55    /// Encode a Terminate message to close the connection.
56    pub fn encode_terminate() -> BytesMut {
57        let mut buf = BytesMut::new();
58        buf.extend_from_slice(&[b'X', 0, 0, 0, 4]);
59        buf
60    }
61
62    /// Encode a Sync message (end of pipeline in extended query protocol).
63    pub fn encode_sync() -> BytesMut {
64        let mut buf = BytesMut::new();
65        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
66        buf
67    }
68
69    // ==================== Extended Query Protocol ====================
70
71    /// Encode a Parse message (prepare a statement).
72    /// Wire format:
73    /// - 'P' (1 byte) - message type
74    /// - length (4 bytes)
75    /// - statement name (null-terminated, "" for unnamed)
76    /// - query string (null-terminated)
77    /// - parameter count (2 bytes)
78    /// - parameter OIDs (4 bytes each, 0 = infer type)
79    pub fn encode_parse(name: &str, sql: &str, param_types: &[u32]) -> BytesMut {
80        let mut buf = BytesMut::new();
81
82        // Message type 'P'
83        buf.extend_from_slice(b"P");
84
85        let mut content = Vec::new();
86
87        // Statement name (null-terminated)
88        content.extend_from_slice(name.as_bytes());
89        content.push(0);
90
91        // Query string (null-terminated)
92        content.extend_from_slice(sql.as_bytes());
93        content.push(0);
94
95        // Parameter count
96        content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
97
98        // Parameter OIDs
99        for &oid in param_types {
100            content.extend_from_slice(&oid.to_be_bytes());
101        }
102
103        // Length (includes length field itself)
104        let len = (content.len() + 4) as i32;
105        buf.extend_from_slice(&len.to_be_bytes());
106        buf.extend_from_slice(&content);
107
108        buf
109    }
110
111    /// Encode a Bind message (bind parameters to a prepared statement).
112    /// Wire format:
113    /// - 'B' (1 byte) - message type
114    /// - length (4 bytes)
115    /// - portal name (null-terminated)
116    /// - statement name (null-terminated)
117    /// - format code count (2 bytes) - we use 0 (all text)
118    /// - parameter count (2 bytes)
119    /// - for each parameter: length (4 bytes, -1 for NULL), data
120    /// - result format count (2 bytes) - we use 0 (all text)
121    pub fn encode_bind(portal: &str, statement: &str, params: &[Option<Vec<u8>>]) -> Result<BytesMut, EncodeError> {
122        if params.len() > i16::MAX as usize {
123            return Err(EncodeError::TooManyParameters(params.len()));
124        }
125
126        let mut buf = BytesMut::new();
127
128        // Message type 'B'
129        buf.extend_from_slice(b"B");
130
131        let mut content = Vec::new();
132
133        // Portal name (null-terminated)
134        content.extend_from_slice(portal.as_bytes());
135        content.push(0);
136
137        // Statement name (null-terminated)
138        content.extend_from_slice(statement.as_bytes());
139        content.push(0);
140
141        // Format codes count (0 = use default text format)
142        content.extend_from_slice(&0i16.to_be_bytes());
143
144        // Parameter count
145        content.extend_from_slice(&(params.len() as i16).to_be_bytes());
146
147        // Parameters
148        for param in params {
149            match param {
150                None => {
151                    // NULL: length = -1
152                    content.extend_from_slice(&(-1i32).to_be_bytes());
153                }
154                Some(data) => {
155                    if data.len() > i32::MAX as usize {
156                        return Err(EncodeError::MessageTooLarge(data.len()));
157                    }
158                    content.extend_from_slice(&(data.len() as i32).to_be_bytes());
159                    content.extend_from_slice(data);
160                }
161            }
162        }
163
164        // Result format codes count (0 = use default text format)
165        content.extend_from_slice(&0i16.to_be_bytes());
166
167        // Length
168        let len = (content.len() + 4) as i32;
169        buf.extend_from_slice(&len.to_be_bytes());
170        buf.extend_from_slice(&content);
171
172        Ok(buf)
173    }
174
175    /// Encode an Execute message (execute a bound portal).
176    /// Wire format:
177    /// - 'E' (1 byte) - message type
178    /// - length (4 bytes)
179    /// - portal name (null-terminated)
180    /// - max rows (4 bytes, 0 = unlimited)
181    pub fn encode_execute(portal: &str, max_rows: i32) -> BytesMut {
182        let mut buf = BytesMut::new();
183
184        // Message type 'E'
185        buf.extend_from_slice(b"E");
186
187        let mut content = Vec::new();
188
189        // Portal name (null-terminated)
190        content.extend_from_slice(portal.as_bytes());
191        content.push(0);
192
193        // Max rows
194        content.extend_from_slice(&max_rows.to_be_bytes());
195
196        // Length
197        let len = (content.len() + 4) as i32;
198        buf.extend_from_slice(&len.to_be_bytes());
199        buf.extend_from_slice(&content);
200
201        buf
202    }
203
204    /// Encode a Describe message (get statement/portal metadata).
205    /// Wire format:
206    /// - 'D' (1 byte) - message type
207    /// - length (4 bytes)
208    /// - 'S' for statement or 'P' for portal
209    /// - name (null-terminated)
210    pub fn encode_describe(is_portal: bool, name: &str) -> BytesMut {
211        let mut buf = BytesMut::new();
212
213        // Message type 'D'
214        buf.extend_from_slice(b"D");
215
216        let mut content = Vec::new();
217
218        // Type: 'S' for statement, 'P' for portal
219        content.push(if is_portal { b'P' } else { b'S' });
220
221        // Name (null-terminated)
222        content.extend_from_slice(name.as_bytes());
223        content.push(0);
224
225        // Length
226        let len = (content.len() + 4) as i32;
227        buf.extend_from_slice(&len.to_be_bytes());
228        buf.extend_from_slice(&content);
229
230        buf
231    }
232
233    /// Encode a complete extended query pipeline (OPTIMIZED).
234    /// This combines Parse + Bind + Execute + Sync in a single buffer.
235    /// Zero intermediate allocations - writes directly to pre-sized BytesMut.
236    pub fn encode_extended_query(sql: &str, params: &[Option<Vec<u8>>]) -> Result<BytesMut, EncodeError> {
237        if params.len() > i16::MAX as usize {
238            return Err(EncodeError::TooManyParameters(params.len()));
239        }
240
241        // Calculate total size upfront to avoid reallocations
242        // Bind: 1 + 4 + 1 + 1 + 2 + 2 + params_data + 2 = 13 + params_data
243        // Execute: 1 + 4 + 1 + 4 = 10
244        // Sync: 5
245        let params_size: usize = params
246            .iter()
247            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
248            .sum();
249        let total_size = 9 + sql.len() + 13 + params_size + 10 + 5;
250
251        let mut buf = BytesMut::with_capacity(total_size);
252
253        // ===== PARSE =====
254        buf.extend_from_slice(b"P");
255        let parse_len = (1 + sql.len() + 1 + 2 + 4) as i32; // name + sql + param_count
256        buf.extend_from_slice(&parse_len.to_be_bytes());
257        buf.extend_from_slice(&[0]); // Unnamed statement
258        buf.extend_from_slice(sql.as_bytes());
259        buf.extend_from_slice(&[0]); // Null terminator
260        buf.extend_from_slice(&0i16.to_be_bytes()); // No param types (infer)
261
262        // ===== BIND =====
263        buf.extend_from_slice(b"B");
264        let bind_len = (1 + 1 + 2 + 2 + params_size + 2 + 4) as i32;
265        buf.extend_from_slice(&bind_len.to_be_bytes());
266        buf.extend_from_slice(&[0]); // Unnamed portal
267        buf.extend_from_slice(&[0]); // Unnamed statement
268        buf.extend_from_slice(&0i16.to_be_bytes()); // Format codes (default text)
269        buf.extend_from_slice(&(params.len() as i16).to_be_bytes());
270        for param in params {
271            match param {
272                None => buf.extend_from_slice(&(-1i32).to_be_bytes()),
273                Some(data) => {
274                    if data.len() > i32::MAX as usize {
275                        return Err(EncodeError::MessageTooLarge(data.len()));
276                    }
277                    buf.extend_from_slice(&(data.len() as i32).to_be_bytes());
278                    buf.extend_from_slice(data);
279                }
280            }
281        }
282        buf.extend_from_slice(&0i16.to_be_bytes()); // Result format (default text)
283
284        // ===== EXECUTE =====
285        buf.extend_from_slice(b"E");
286        buf.extend_from_slice(&9i32.to_be_bytes()); // len = 4 + 1 + 4
287        buf.extend_from_slice(&[0]); // Unnamed portal
288        buf.extend_from_slice(&0i32.to_be_bytes()); // Unlimited rows
289
290        // ===== SYNC =====
291        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
292
293        Ok(buf)
294    }
295
296    /// Encode a CopyFail message to abort a COPY IN with an error.
297    /// Wire format:
298    /// - 'f' (1 byte) - message type
299    /// - length (4 bytes)
300    /// - error message (null-terminated)
301    pub fn encode_copy_fail(reason: &str) -> BytesMut {
302        let mut buf = BytesMut::new();
303        buf.extend_from_slice(b"f");
304        let content_len = reason.len() + 1; // +1 for null terminator
305        let len = (content_len + 4) as i32;
306        buf.extend_from_slice(&len.to_be_bytes());
307        buf.extend_from_slice(reason.as_bytes());
308        buf.extend_from_slice(&[0]);
309        buf
310    }
311
312    /// Encode a Close message to release a prepared statement or portal.
313    /// Wire format:
314    /// - 'C' (1 byte) - message type
315    /// - length (4 bytes)
316    /// - 'S' for statement or 'P' for portal
317    /// - name (null-terminated)
318    pub fn encode_close(is_portal: bool, name: &str) -> BytesMut {
319        let mut buf = BytesMut::new();
320        buf.extend_from_slice(b"C");
321        let content_len = 1 + name.len() + 1; // type + name + null
322        let len = (content_len + 4) as i32;
323        buf.extend_from_slice(&len.to_be_bytes());
324        buf.extend_from_slice(&[if is_portal { b'P' } else { b'S' }]);
325        buf.extend_from_slice(name.as_bytes());
326        buf.extend_from_slice(&[0]);
327        buf
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Split back-to-back frontend messages into `(type, payload)` pairs.
    ///
    /// Panics if a header is truncated or a length field disagrees with the
    /// bytes that follow, so every caller also verifies the framing.
    fn frames(mut bytes: &[u8]) -> Vec<(u8, &[u8])> {
        let mut out = Vec::new();
        while !bytes.is_empty() {
            assert!(bytes.len() >= 5, "truncated message header");
            let len = i32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
            assert!(len >= 4, "length field must count itself");
            let end = 1 + len as usize;
            assert!(end <= bytes.len(), "length field overruns the buffer");
            out.push((bytes[0], &bytes[5..end]));
            bytes = &bytes[end..];
        }
        out
    }

    // NOTE: test_encode_simple_query removed - use AstEncoder instead
    #[test]
    fn test_encode_query_string() {
        let sql = "SELECT 1";
        let bytes = PgEncoder::encode_query_string(sql);

        // Message type
        assert_eq!(bytes[0], b'Q');

        // Length: 4 (length field) + 8 (query) + 1 (null) = 13
        let len = i32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
        assert_eq!(len, 13);

        // Query content
        assert_eq!(&bytes[5..13], b"SELECT 1");

        // Null terminator
        assert_eq!(bytes[13], 0);
        assert_eq!(bytes.len(), 14);
    }

    #[test]
    fn test_encode_query_string_empty() {
        let bytes = PgEncoder::encode_query_string("");
        assert_eq!(&bytes[..], &[b'Q', 0, 0, 0, 5, 0]);
    }

    #[test]
    fn test_encode_terminate() {
        let bytes = PgEncoder::encode_terminate();
        assert_eq!(&bytes[..], &[b'X', 0, 0, 0, 4]);
    }

    #[test]
    fn test_encode_sync() {
        let bytes = PgEncoder::encode_sync();
        assert_eq!(&bytes[..], &[b'S', 0, 0, 0, 4]);
    }

    #[test]
    fn test_encode_parse() {
        let bytes = PgEncoder::encode_parse("", "SELECT $1", &[]);

        // 'P', len 17, empty name NUL, query, NUL, zero parameter types.
        let mut expected = vec![b'P', 0, 0, 0, 17, 0];
        expected.extend_from_slice(b"SELECT $1");
        expected.extend_from_slice(&[0u8, 0, 0]);
        assert_eq!(&bytes[..], &expected[..]);
    }

    #[test]
    fn test_encode_parse_with_param_types() {
        let bytes = PgEncoder::encode_parse("s1", "SELECT $1", &[23]);

        // Payload: "s1\0" (3) + "SELECT $1\0" (10) + count (2) + one OID (4) = 19.
        let mut expected = vec![b'P', 0, 0, 0, 23];
        expected.extend_from_slice(b"s1\0SELECT $1\0");
        expected.extend_from_slice(&[0u8, 1, 0, 0, 0, 23]);
        assert_eq!(&bytes[..], &expected[..]);
    }

    #[test]
    fn test_encode_bind() {
        let params = vec![
            Some(b"42".to_vec()),
            None, // NULL
        ];
        let bytes = PgEncoder::encode_bind("", "", &params).unwrap();

        // 'B', len 22, portal NUL, statement NUL, 0 format codes, 2 params,
        // "42" with length 2, NULL as -1, 0 result format codes.
        let mut expected = vec![b'B', 0, 0, 0, 22, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, b'4', b'2'];
        expected.extend_from_slice(&[0xffu8, 0xff, 0xff, 0xff, 0, 0]);
        assert_eq!(&bytes[..], &expected[..]);
    }

    #[test]
    fn test_encode_bind_named_without_params() {
        let params: Vec<Option<Vec<u8>>> = Vec::new();
        let bytes = PgEncoder::encode_bind("p1", "s1", &params).unwrap();

        let mut expected = vec![b'B', 0, 0, 0, 16];
        expected.extend_from_slice(b"p1\0s1\0");
        expected.extend_from_slice(&[0u8, 0, 0, 0, 0, 0]);
        assert_eq!(&bytes[..], &expected[..]);
    }

    #[test]
    fn test_encode_bind_rejects_too_many_params() {
        let count = i16::MAX as usize + 1;
        let params: Vec<Option<Vec<u8>>> = vec![None; count];

        assert!(matches!(
            PgEncoder::encode_bind("", "", &params),
            Err(EncodeError::TooManyParameters(n)) if n == count
        ));

        // The in-place encoder must leave the caller's buffer untouched.
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"xyz");
        assert!(PgEncoder::encode_bind_to(&mut buf, "", &params).is_err());
        assert_eq!(&buf[..], b"xyz");
    }

    #[test]
    fn test_encode_bind_to_appends_same_bytes_as_encode_bind() {
        let params = vec![Some(b"42".to_vec()), None];
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"xyz");

        PgEncoder::encode_bind_to(&mut buf, "s1", &params).unwrap();

        assert_eq!(&buf[..3], b"xyz");
        assert_eq!(
            &buf[3..],
            &PgEncoder::encode_bind("", "s1", &params).unwrap()[..]
        );
    }

    #[test]
    fn test_encode_bind_ultra_matches_encode_bind() {
        let borrowed = [Param::Bytes(&b"42"[..]), Param::Null];
        let owned = vec![Some(b"42".to_vec()), None];
        let mut buf = BytesMut::new();

        PgEncoder::encode_bind_ultra(&mut buf, "s1", &borrowed).unwrap();

        assert_eq!(
            &buf[..],
            &PgEncoder::encode_bind("", "s1", &owned).unwrap()[..]
        );
    }

    #[test]
    fn test_encode_execute() {
        let bytes = PgEncoder::encode_execute("", 0);

        // Length: 4 + 1 (null) + 4 (max_rows) = 9
        assert_eq!(&bytes[..], &[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);

        let named = PgEncoder::encode_execute("p", 5);
        assert_eq!(&named[..], &[b'E', 0, 0, 0, 10, b'p', 0, 0, 0, 0, 5]);
    }

    #[test]
    fn test_hot_path_execute_and_sync_match_allocating_encoders() {
        let mut buf = BytesMut::new();
        PgEncoder::encode_execute_to(&mut buf);
        PgEncoder::encode_execute_ultra(&mut buf);
        PgEncoder::encode_sync_to(&mut buf);
        PgEncoder::encode_sync_ultra(&mut buf);

        let mut expected: Vec<u8> = Vec::new();
        expected.extend_from_slice(&PgEncoder::encode_execute("", 0));
        expected.extend_from_slice(&PgEncoder::encode_execute("", 0));
        expected.extend_from_slice(&PgEncoder::encode_sync());
        expected.extend_from_slice(&PgEncoder::encode_sync());
        assert_eq!(&buf[..], &expected[..]);
    }

    #[test]
    fn test_encode_describe() {
        let statement = PgEncoder::encode_describe(false, "s1");
        assert_eq!(&statement[..], &[b'D', 0, 0, 0, 8, b'S', b's', b'1', 0]);

        let portal = PgEncoder::encode_describe(true, "");
        assert_eq!(&portal[..], &[b'D', 0, 0, 0, 6, b'P', 0]);
    }

    #[test]
    fn test_encode_extended_query() {
        let params = vec![Some(b"hello".to_vec())];
        let bytes = PgEncoder::encode_extended_query("SELECT $1", &params).unwrap();

        // Exactly four well-framed messages, in pipeline order.
        let msgs = frames(&bytes);
        let tags: Vec<u8> = msgs.iter().map(|&(tag, _)| tag).collect();
        assert_eq!(tags, b"PBES".to_vec());

        let mut parse_body = vec![0u8];
        parse_body.extend_from_slice(b"SELECT $1");
        parse_body.extend_from_slice(&[0u8, 0, 0]);
        assert_eq!(msgs[0].1, &parse_body[..]);

        let mut bind_body = vec![0u8, 0, 0, 0, 0, 1, 0, 0, 0, 5];
        bind_body.extend_from_slice(b"hello");
        bind_body.extend_from_slice(&[0u8, 0]);
        assert_eq!(msgs[1].1, &bind_body[..]);

        assert_eq!(msgs[2].1, &[0u8, 0, 0, 0, 0][..]);
        assert!(msgs[3].1.is_empty());

        // The embedded Bind is byte-identical to the standalone encoder's.
        assert_eq!(bytes.len(), 55);
        assert_eq!(
            &bytes[18..40],
            &PgEncoder::encode_bind("", "", &params).unwrap()[..]
        );
    }

    #[test]
    fn test_encode_copy_fail() {
        let bytes = PgEncoder::encode_copy_fail("bad data");
        assert_eq!(bytes[0], b'f');
        let len = i32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
        assert_eq!(len as usize, 4 + "bad data".len() + 1);
        assert_eq!(&bytes[5..13], b"bad data");
        assert_eq!(bytes[13], 0);
        assert_eq!(bytes.len(), 14);
    }

    #[test]
    fn test_encode_close_statement() {
        let bytes = PgEncoder::encode_close(false, "my_stmt");
        assert_eq!(bytes[0], b'C');
        // Length: 4 + 1 (type) + 7 (name) + 1 (null) = 13
        assert_eq!(&bytes[1..5], &[0u8, 0, 0, 13]);
        assert_eq!(bytes[5], b'S'); // Statement type
        assert_eq!(&bytes[6..13], b"my_stmt");
        assert_eq!(bytes[13], 0);
    }

    #[test]
    fn test_encode_close_portal() {
        let bytes = PgEncoder::encode_close(true, "");
        assert_eq!(&bytes[..], &[b'C', 0, 0, 0, 6, b'P', 0]);
    }
}
446
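// ==================== Hot-Path Encoders ====================
//
// These encoders append directly into a caller-provided `BytesMut`:
// - Integers are written with `BufMut::put_i16`/`put_i32`, with no temporary
//   arrays. `BytesMut` still checks spare capacity on each write and grows
//   if needed, which is why Bind reserves its full size up front.
// - Parameters can be borrowed slices (`Param`), so the only copy of the data
//   is into the output buffer.
//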
454
455use bytes::BufMut;
456
/// A borrowed Bind parameter for the hot-path Bind encoder.
///
/// Borrowing the value avoids an owned allocation per parameter. The bytes
/// are copied once, into the output buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Param<'a> {
    /// SQL `NULL`, sent as a parameter length of -1 with no data.
    Null,
    /// A non-NULL value. It is sent as-is, under the default (text) format code.
    Bytes(&'a [u8]),
}
463
464impl PgEncoder {
465    /// Direct i32 write - no temp array, no bounds check.
466    /// LLVM emits a single store instruction.
467    #[inline(always)]
468    fn put_i32_be(buf: &mut BytesMut, v: i32) {
469        buf.put_i32(v);
470    }
471
472    #[inline(always)]
473    fn put_i16_be(buf: &mut BytesMut, v: i16) {
474        buf.put_i16(v);
475    }
476
477    /// Encode Bind message - ULTRA OPTIMIZED.
478    /// - Direct integer writes (no temp arrays)
479    /// - Borrowed params (zero-copy)
480    /// - Single allocation check
481    #[inline]
482    pub fn encode_bind_ultra<'a>(buf: &mut BytesMut, statement: &str, params: &[Param<'a>]) -> Result<(), EncodeError> {
483        if params.len() > i16::MAX as usize {
484            return Err(EncodeError::TooManyParameters(params.len()));
485        }
486
487        // Calculate content length upfront
488        let params_size: usize = params
489            .iter()
490            .map(|p| match p {
491                Param::Null => 4,
492                Param::Bytes(b) => 4 + b.len(),
493            })
494            .sum();
495        let content_len = 1 + statement.len() + 1 + 2 + 2 + params_size + 2;
496
497        // Single reserve - no more allocations
498        buf.reserve(1 + 4 + content_len);
499
500        // Message type 'B'
501        buf.put_u8(b'B');
502
503        // Length (includes itself) - DIRECT WRITE
504        Self::put_i32_be(buf, (content_len + 4) as i32);
505
506        // Portal name (empty, null-terminated)
507        buf.put_u8(0);
508
509        // Statement name (null-terminated)
510        buf.extend_from_slice(statement.as_bytes());
511        buf.put_u8(0);
512
513        // Format codes count (0 = default text)
514        Self::put_i16_be(buf, 0);
515
516        // Parameter count
517        Self::put_i16_be(buf, params.len() as i16);
518
519        // Parameters - ZERO COPY from borrowed slices
520        for param in params {
521            match param {
522                Param::Null => Self::put_i32_be(buf, -1),
523                Param::Bytes(data) => {
524                    if data.len() > i32::MAX as usize {
525                        return Err(EncodeError::MessageTooLarge(data.len()));
526                    }
527                    Self::put_i32_be(buf, data.len() as i32);
528                    buf.extend_from_slice(data);
529                }
530            }
531        }
532
533        // Result format codes count (0 = default text)
534        Self::put_i16_be(buf, 0);
535        Ok(())
536    }
537
538    /// Encode Execute message - ULTRA OPTIMIZED.
539    #[inline(always)]
540    pub fn encode_execute_ultra(buf: &mut BytesMut) {
541        // Execute: 'E' + len(9) + portal("") + max_rows(0)
542        // = 'E' 00 00 00 09 00 00 00 00 00
543        buf.extend_from_slice(&[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);
544    }
545
546    /// Encode Sync message - ULTRA OPTIMIZED.
547    #[inline(always)]
548    pub fn encode_sync_ultra(buf: &mut BytesMut) {
549        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
550    }
551
552    // Keep the original methods for compatibility
553
554    /// Encode Bind message directly into existing buffer (ZERO ALLOCATION).
555    /// This is the hot path optimization - no intermediate Vec allocation.
556    #[inline]
557    pub fn encode_bind_to(buf: &mut BytesMut, statement: &str, params: &[Option<Vec<u8>>]) -> Result<(), EncodeError> {
558        if params.len() > i16::MAX as usize {
559            return Err(EncodeError::TooManyParameters(params.len()));
560        }
561
562        // Calculate content length upfront
563        // portal(1) + statement(len+1) + format_codes(2) + param_count(2) + params_data + result_format(2)
564        let params_size: usize = params
565            .iter()
566            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
567            .sum();
568        let content_len = 1 + statement.len() + 1 + 2 + 2 + params_size + 2;
569
570        buf.reserve(1 + 4 + content_len);
571
572        // Message type 'B'
573        buf.put_u8(b'B');
574
575        // Length (includes itself) - DIRECT WRITE
576        Self::put_i32_be(buf, (content_len + 4) as i32);
577
578        // Portal name (empty, null-terminated)
579        buf.put_u8(0);
580
581        // Statement name (null-terminated)
582        buf.extend_from_slice(statement.as_bytes());
583        buf.put_u8(0);
584
585        // Format codes count (0 = default text)
586        Self::put_i16_be(buf, 0);
587
588        // Parameter count
589        Self::put_i16_be(buf, params.len() as i16);
590
591        // Parameters
592        for param in params {
593            match param {
594                None => Self::put_i32_be(buf, -1),
595                Some(data) => {
596                    if data.len() > i32::MAX as usize {
597                        return Err(EncodeError::MessageTooLarge(data.len()));
598                    }
599                    Self::put_i32_be(buf, data.len() as i32);
600                    buf.extend_from_slice(data);
601                }
602            }
603        }
604
605        // Result format codes count (0 = default text)
606        Self::put_i16_be(buf, 0);
607        Ok(())
608    }
609
610    /// Encode Execute message directly into existing buffer (ZERO ALLOCATION).
611    #[inline]
612    pub fn encode_execute_to(buf: &mut BytesMut) {
613        // Content: portal(1) + max_rows(4) = 5 bytes
614        buf.extend_from_slice(&[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);
615    }
616
617    /// Encode Sync message directly into existing buffer (ZERO ALLOCATION).
618    #[inline]
619    pub fn encode_sync_to(buf: &mut BytesMut) {
620        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
621    }
622}