qail_pg/protocol/
encoder.rs

1//! PostgreSQL Encoder (Visitor Pattern)
2//!
3//! Compiles Qail AST into PostgreSQL wire protocol bytes.
4//! This is pure, synchronous computation - no I/O, no async.
5//!
6//! # Architecture
7//!
8//! Layer 2 of the QAIL architecture:
9//! - Input: Qail (AST)
10//! - Output: BytesMut (ready to send over the wire)
11//!
12//! The async I/O layer (Layer 3) consumes these bytes.
13
14use bytes::BytesMut;
15use super::EncodeError;
16
/// Namespace type for the PostgreSQL frontend message encoders.
///
/// `PgEncoder` is a unit struct: it carries no state, and the encoders
/// visible in this file are associated functions that turn raw SQL strings
/// and pre-serialized parameter bytes into wire protocol messages.
///
/// The common traits are derived so the type can be stored, copied and
/// printed like any other value.
#[derive(Debug, Clone, Copy, Default)]
pub struct PgEncoder;
20
21impl PgEncoder {
22    /// Encode a raw SQL string as a Simple Query message.
23    /// Wire format:
24    /// - 'Q' (1 byte) - message type
25    /// - length (4 bytes, big-endian, includes self)
26    /// - query string (null-terminated)
27    pub fn encode_query_string(sql: &str) -> BytesMut {
28        let mut buf = BytesMut::new();
29
30        // Message type 'Q' for Query
31        buf.extend_from_slice(b"Q");
32
33        // Content: query string + null terminator
34        let content_len = sql.len() + 1; // +1 for null terminator
35        let total_len = (content_len + 4) as i32; // +4 for length field itself
36
37        // Length (4 bytes, big-endian)
38        buf.extend_from_slice(&total_len.to_be_bytes());
39
40        // Query string
41        buf.extend_from_slice(sql.as_bytes());
42
43        // Null terminator
44        buf.extend_from_slice(&[0]);
45
46        buf
47    }
48
49    /// Encode a Terminate message to close the connection.
50    pub fn encode_terminate() -> BytesMut {
51        let mut buf = BytesMut::new();
52        buf.extend_from_slice(&[b'X', 0, 0, 0, 4]);
53        buf
54    }
55
56    /// Encode a Sync message (end of pipeline in extended query protocol).
57    pub fn encode_sync() -> BytesMut {
58        let mut buf = BytesMut::new();
59        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
60        buf
61    }
62
63    // ==================== Extended Query Protocol ====================
64
65    /// Encode a Parse message (prepare a statement).
66    /// Wire format:
67    /// - 'P' (1 byte) - message type
68    /// - length (4 bytes)
69    /// - statement name (null-terminated, "" for unnamed)
70    /// - query string (null-terminated)
71    /// - parameter count (2 bytes)
72    /// - parameter OIDs (4 bytes each, 0 = infer type)
73    pub fn encode_parse(name: &str, sql: &str, param_types: &[u32]) -> BytesMut {
74        let mut buf = BytesMut::new();
75
76        // Message type 'P'
77        buf.extend_from_slice(b"P");
78
79        let mut content = Vec::new();
80
81        // Statement name (null-terminated)
82        content.extend_from_slice(name.as_bytes());
83        content.push(0);
84
85        // Query string (null-terminated)
86        content.extend_from_slice(sql.as_bytes());
87        content.push(0);
88
89        // Parameter count
90        content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
91
92        // Parameter OIDs
93        for &oid in param_types {
94            content.extend_from_slice(&oid.to_be_bytes());
95        }
96
97        // Length (includes length field itself)
98        let len = (content.len() + 4) as i32;
99        buf.extend_from_slice(&len.to_be_bytes());
100        buf.extend_from_slice(&content);
101
102        buf
103    }
104
105    /// Encode a Bind message (bind parameters to a prepared statement).
106    /// Wire format:
107    /// - 'B' (1 byte) - message type
108    /// - length (4 bytes)
109    /// - portal name (null-terminated)
110    /// - statement name (null-terminated)
111    /// - format code count (2 bytes) - we use 0 (all text)
112    /// - parameter count (2 bytes)
113    /// - for each parameter: length (4 bytes, -1 for NULL), data
114    /// - result format count (2 bytes) - we use 0 (all text)
115    pub fn encode_bind(portal: &str, statement: &str, params: &[Option<Vec<u8>>]) -> Result<BytesMut, EncodeError> {
116        if params.len() > i16::MAX as usize {
117            return Err(EncodeError::TooManyParameters(params.len()));
118        }
119
120        let mut buf = BytesMut::new();
121
122        // Message type 'B'
123        buf.extend_from_slice(b"B");
124
125        let mut content = Vec::new();
126
127        // Portal name (null-terminated)
128        content.extend_from_slice(portal.as_bytes());
129        content.push(0);
130
131        // Statement name (null-terminated)
132        content.extend_from_slice(statement.as_bytes());
133        content.push(0);
134
135        // Format codes count (0 = use default text format)
136        content.extend_from_slice(&0i16.to_be_bytes());
137
138        // Parameter count
139        content.extend_from_slice(&(params.len() as i16).to_be_bytes());
140
141        // Parameters
142        for param in params {
143            match param {
144                None => {
145                    // NULL: length = -1
146                    content.extend_from_slice(&(-1i32).to_be_bytes());
147                }
148                Some(data) => {
149                    content.extend_from_slice(&(data.len() as i32).to_be_bytes());
150                    content.extend_from_slice(data);
151                }
152            }
153        }
154
155        // Result format codes count (0 = use default text format)
156        content.extend_from_slice(&0i16.to_be_bytes());
157
158        // Length
159        let len = (content.len() + 4) as i32;
160        buf.extend_from_slice(&len.to_be_bytes());
161        buf.extend_from_slice(&content);
162
163        Ok(buf)
164    }
165
166    /// Encode an Execute message (execute a bound portal).
167    /// Wire format:
168    /// - 'E' (1 byte) - message type
169    /// - length (4 bytes)
170    /// - portal name (null-terminated)
171    /// - max rows (4 bytes, 0 = unlimited)
172    pub fn encode_execute(portal: &str, max_rows: i32) -> BytesMut {
173        let mut buf = BytesMut::new();
174
175        // Message type 'E'
176        buf.extend_from_slice(b"E");
177
178        let mut content = Vec::new();
179
180        // Portal name (null-terminated)
181        content.extend_from_slice(portal.as_bytes());
182        content.push(0);
183
184        // Max rows
185        content.extend_from_slice(&max_rows.to_be_bytes());
186
187        // Length
188        let len = (content.len() + 4) as i32;
189        buf.extend_from_slice(&len.to_be_bytes());
190        buf.extend_from_slice(&content);
191
192        buf
193    }
194
195    /// Encode a Describe message (get statement/portal metadata).
196    /// Wire format:
197    /// - 'D' (1 byte) - message type
198    /// - length (4 bytes)
199    /// - 'S' for statement or 'P' for portal
200    /// - name (null-terminated)
201    pub fn encode_describe(is_portal: bool, name: &str) -> BytesMut {
202        let mut buf = BytesMut::new();
203
204        // Message type 'D'
205        buf.extend_from_slice(b"D");
206
207        let mut content = Vec::new();
208
209        // Type: 'S' for statement, 'P' for portal
210        content.push(if is_portal { b'P' } else { b'S' });
211
212        // Name (null-terminated)
213        content.extend_from_slice(name.as_bytes());
214        content.push(0);
215
216        // Length
217        let len = (content.len() + 4) as i32;
218        buf.extend_from_slice(&len.to_be_bytes());
219        buf.extend_from_slice(&content);
220
221        buf
222    }
223
224    /// Encode a complete extended query pipeline (OPTIMIZED).
225    /// This combines Parse + Bind + Execute + Sync in a single buffer.
226    /// Zero intermediate allocations - writes directly to pre-sized BytesMut.
227    pub fn encode_extended_query(sql: &str, params: &[Option<Vec<u8>>]) -> Result<BytesMut, EncodeError> {
228        if params.len() > i16::MAX as usize {
229            return Err(EncodeError::TooManyParameters(params.len()));
230        }
231
232        // Calculate total size upfront to avoid reallocations
233        // Bind: 1 + 4 + 1 + 1 + 2 + 2 + params_data + 2 = 13 + params_data
234        // Execute: 1 + 4 + 1 + 4 = 10
235        // Sync: 5
236        let params_size: usize = params
237            .iter()
238            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
239            .sum();
240        let total_size = 9 + sql.len() + 13 + params_size + 10 + 5;
241
242        let mut buf = BytesMut::with_capacity(total_size);
243
244        // ===== PARSE =====
245        buf.extend_from_slice(b"P");
246        let parse_len = (1 + sql.len() + 1 + 2 + 4) as i32; // name + sql + param_count
247        buf.extend_from_slice(&parse_len.to_be_bytes());
248        buf.extend_from_slice(&[0]); // Unnamed statement
249        buf.extend_from_slice(sql.as_bytes());
250        buf.extend_from_slice(&[0]); // Null terminator
251        buf.extend_from_slice(&0i16.to_be_bytes()); // No param types (infer)
252
253        // ===== BIND =====
254        buf.extend_from_slice(b"B");
255        let bind_len = (1 + 1 + 2 + 2 + params_size + 2 + 4) as i32;
256        buf.extend_from_slice(&bind_len.to_be_bytes());
257        buf.extend_from_slice(&[0]); // Unnamed portal
258        buf.extend_from_slice(&[0]); // Unnamed statement
259        buf.extend_from_slice(&0i16.to_be_bytes()); // Format codes (default text)
260        buf.extend_from_slice(&(params.len() as i16).to_be_bytes());
261        for param in params {
262            match param {
263                None => buf.extend_from_slice(&(-1i32).to_be_bytes()),
264                Some(data) => {
265                    buf.extend_from_slice(&(data.len() as i32).to_be_bytes());
266                    buf.extend_from_slice(data);
267                }
268            }
269        }
270        buf.extend_from_slice(&0i16.to_be_bytes()); // Result format (default text)
271
272        // ===== EXECUTE =====
273        buf.extend_from_slice(b"E");
274        buf.extend_from_slice(&9i32.to_be_bytes()); // len = 4 + 1 + 4
275        buf.extend_from_slice(&[0]); // Unnamed portal
276        buf.extend_from_slice(&0i32.to_be_bytes()); // Unlimited rows
277
278        // ===== SYNC =====
279        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
280
281        Ok(buf)
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a buffer of back-to-back frontend messages into
    /// `(type byte, body)` pairs by following each message's length field.
    ///
    /// Panics (failing the test) if a length field is malformed or points
    /// past the end of the buffer.
    fn split_messages(mut bytes: &[u8]) -> Vec<(u8, Vec<u8>)> {
        let mut messages = Vec::new();
        while !bytes.is_empty() {
            let tag = bytes[0];
            let len = i32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
            assert!(len >= 4, "length field must include itself");
            let end = 1 + len as usize;
            messages.push((tag, bytes[5..end].to_vec()));
            bytes = &bytes[end..];
        }
        messages
    }

    // NOTE: test_encode_simple_query removed - use AstEncoder instead
    #[test]
    fn test_encode_query_string() {
        let sql = "SELECT 1";
        let bytes = PgEncoder::encode_query_string(sql);

        // Message type
        assert_eq!(bytes[0], b'Q');

        // Length: 4 (length field) + 8 (query) + 1 (null) = 13
        let len = i32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
        assert_eq!(len, 13);

        // Query content
        assert_eq!(&bytes[5..13], b"SELECT 1");

        // Null terminator, and nothing after it
        assert_eq!(bytes[13], 0);
        assert_eq!(bytes.len(), 14);
    }

    #[test]
    fn test_encode_terminate() {
        let bytes = PgEncoder::encode_terminate();
        assert_eq!(&bytes[..], &[b'X', 0, 0, 0, 4][..]);
    }

    #[test]
    fn test_encode_sync() {
        let bytes = PgEncoder::encode_sync();
        assert_eq!(&bytes[..], &[b'S', 0, 0, 0, 4][..]);
    }

    #[test]
    fn test_encode_parse() {
        let bytes = PgEncoder::encode_parse("", "SELECT $1", &[]);

        // 'P', length = 4 + 1 (name NUL) + 9 (sql) + 1 (sql NUL) + 2 (count) = 17
        let mut expected: Vec<u8> = vec![b'P', 0, 0, 0, 17, 0];
        expected.extend_from_slice(b"SELECT $1");
        expected.extend_from_slice(&[0, 0, 0]); // sql NUL + zero parameter types
        assert_eq!(&bytes[..], &expected[..]);
    }

    #[test]
    fn test_encode_parse_with_name_and_types() {
        let bytes = PgEncoder::encode_parse("s1", "SELECT $1", &[23]);

        // length = 4 + 3 ("s1" + NUL) + 10 (sql + NUL) + 2 (count) + 4 (one OID) = 23
        let mut expected: Vec<u8> = vec![b'P', 0, 0, 0, 23];
        expected.extend_from_slice(b"s1\0");
        expected.extend_from_slice(b"SELECT $1\0");
        expected.extend_from_slice(&[0, 1]); // one parameter type
        expected.extend_from_slice(&[0, 0, 0, 23]); // OID 23, big-endian
        assert_eq!(&bytes[..], &expected[..]);
    }

    #[test]
    fn test_encode_bind() {
        let params = vec![
            Some(b"42".to_vec()),
            None, // NULL
        ];
        let bytes = PgEncoder::encode_bind("", "", &params).unwrap();

        let expected: [u8; 23] = [
            b'B',
            0, 0, 0, 22, // length: 4 + 18 body bytes
            0, // unnamed portal
            0, // unnamed statement
            0, 0, // no parameter format codes
            0, 2, // two parameters
            0, 0, 0, 2, b'4', b'2', // "42"
            0xFF, 0xFF, 0xFF, 0xFF, // NULL (length -1)
            0, 0, // no result format codes
        ];
        assert_eq!(&bytes[..], &expected[..]);
    }

    #[test]
    fn test_encode_bind_parameter_limit() {
        // Exactly i16::MAX parameters is accepted...
        let at_limit: Vec<Option<Vec<u8>>> = vec![None; i16::MAX as usize];
        assert!(PgEncoder::encode_bind("", "", &at_limit).is_ok());

        // ...one more is rejected, and the error reports the offending count.
        let over_limit: Vec<Option<Vec<u8>>> = vec![None; i16::MAX as usize + 1];
        assert!(matches!(
            PgEncoder::encode_bind("", "", &over_limit),
            Err(EncodeError::TooManyParameters(n)) if n == over_limit.len()
        ));
        assert!(matches!(
            PgEncoder::encode_extended_query("SELECT 1", &over_limit),
            Err(EncodeError::TooManyParameters(n)) if n == over_limit.len()
        ));
    }

    #[test]
    fn test_encode_execute() {
        let bytes = PgEncoder::encode_execute("", 0);

        // 'E', length 9 = 4 + 1 (portal NUL) + 4 (max_rows), then the body
        assert_eq!(&bytes[..], &[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0][..]);
    }

    #[test]
    fn test_encode_describe() {
        // Statement target: 'S' + "s1" + NUL, length 4 + 4 = 8
        let bytes = PgEncoder::encode_describe(false, "s1");
        assert_eq!(&bytes[..], &[b'D', 0, 0, 0, 8, b'S', b's', b'1', 0][..]);

        // Portal target with the unnamed portal: 'P' + NUL, length 4 + 2 = 6
        let bytes = PgEncoder::encode_describe(true, "");
        assert_eq!(&bytes[..], &[b'D', 0, 0, 0, 6, b'P', 0][..]);
    }

    #[test]
    fn test_encode_extended_query() {
        let sql = "SELECT $1";
        let params = vec![Some(b"hello".to_vec())];
        let bytes = PgEncoder::encode_extended_query(sql, &params).unwrap();

        // The pipeline must be byte-identical to the four standalone messages.
        let mut expected: Vec<u8> = Vec::new();
        expected.extend_from_slice(&PgEncoder::encode_parse("", sql, &[]));
        expected.extend_from_slice(&PgEncoder::encode_bind("", "", &params).unwrap());
        expected.extend_from_slice(&PgEncoder::encode_execute("", 0));
        expected.extend_from_slice(&PgEncoder::encode_sync());
        assert_eq!(&bytes[..], &expected[..]);

        // Walking the length fields must yield exactly P, B, E, S in order.
        // (Searching for the tag bytes is not enough: "SELECT" itself
        // contains 'S' and 'E'.)
        let tags: Vec<u8> = split_messages(&bytes).iter().map(|(tag, _)| *tag).collect();
        assert_eq!(tags, vec![b'P', b'B', b'E', b'S']);
    }
}
373
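// ==================== Buffer-Reusing Hot Path Encoders ====================
//
// These encoders append to a caller-supplied `BytesMut` instead of returning
// a fresh buffer, so one buffer can carry a whole batch of messages:
// - Integers are written directly through `BufMut` (no temporary byte arrays)
// - Parameters may be passed as borrowed slices (no owned `Vec` per value)
// - Each Bind message reserves the space it needs once, before writing
//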
381
382use bytes::BufMut;
383
/// Borrowed bind parameter value.
///
/// Lets a caller hand parameter bytes to [`PgEncoder::encode_bind_ultra`]
/// without first moving them into an owned `Vec<u8>`. The bytes are still
/// copied once, into the output buffer, when the message is encoded.
///
/// The type only holds a slice reference, so it is cheap to copy and derives
/// the common comparison and debug traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Param<'a> {
    /// SQL NULL; encoded as a length of -1 with no payload.
    Null,
    /// A non-NULL value; encoded as a 4-byte length followed by these bytes.
    Bytes(&'a [u8]),
}
390
391impl PgEncoder {
392    /// Direct i32 write - no temp array, no bounds check.
393    /// LLVM emits a single store instruction.
394    #[inline(always)]
395    fn put_i32_be(buf: &mut BytesMut, v: i32) {
396        buf.put_i32(v);
397    }
398
399    #[inline(always)]
400    fn put_i16_be(buf: &mut BytesMut, v: i16) {
401        buf.put_i16(v);
402    }
403
404    /// Encode Bind message - ULTRA OPTIMIZED.
405    /// - Direct integer writes (no temp arrays)
406    /// - Borrowed params (zero-copy)
407    /// - Single allocation check
408    #[inline]
409    pub fn encode_bind_ultra<'a>(buf: &mut BytesMut, statement: &str, params: &[Param<'a>]) -> Result<(), EncodeError> {
410        if params.len() > i16::MAX as usize {
411            return Err(EncodeError::TooManyParameters(params.len()));
412        }
413
414        // Calculate content length upfront
415        let params_size: usize = params
416            .iter()
417            .map(|p| match p {
418                Param::Null => 4,
419                Param::Bytes(b) => 4 + b.len(),
420            })
421            .sum();
422        let content_len = 1 + statement.len() + 1 + 2 + 2 + params_size + 2;
423
424        // Single reserve - no more allocations
425        buf.reserve(1 + 4 + content_len);
426
427        // Message type 'B'
428        buf.put_u8(b'B');
429
430        // Length (includes itself) - DIRECT WRITE
431        Self::put_i32_be(buf, (content_len + 4) as i32);
432
433        // Portal name (empty, null-terminated)
434        buf.put_u8(0);
435
436        // Statement name (null-terminated)
437        buf.extend_from_slice(statement.as_bytes());
438        buf.put_u8(0);
439
440        // Format codes count (0 = default text)
441        Self::put_i16_be(buf, 0);
442
443        // Parameter count
444        Self::put_i16_be(buf, params.len() as i16);
445
446        // Parameters - ZERO COPY from borrowed slices
447        for param in params {
448            match param {
449                Param::Null => Self::put_i32_be(buf, -1),
450                Param::Bytes(data) => {
451                    Self::put_i32_be(buf, data.len() as i32);
452                    buf.extend_from_slice(data);
453                }
454            }
455        }
456
457        // Result format codes count (0 = default text)
458        Self::put_i16_be(buf, 0);
459        Ok(())
460    }
461
462    /// Encode Execute message - ULTRA OPTIMIZED.
463    #[inline(always)]
464    pub fn encode_execute_ultra(buf: &mut BytesMut) {
465        // Execute: 'E' + len(9) + portal("") + max_rows(0)
466        // = 'E' 00 00 00 09 00 00 00 00 00
467        buf.extend_from_slice(&[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);
468    }
469
470    /// Encode Sync message - ULTRA OPTIMIZED.
471    #[inline(always)]
472    pub fn encode_sync_ultra(buf: &mut BytesMut) {
473        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
474    }
475
476    // Keep the original methods for compatibility
477
478    /// Encode Bind message directly into existing buffer (ZERO ALLOCATION).
479    /// This is the hot path optimization - no intermediate Vec allocation.
480    #[inline]
481    pub fn encode_bind_to(buf: &mut BytesMut, statement: &str, params: &[Option<Vec<u8>>]) -> Result<(), EncodeError> {
482        if params.len() > i16::MAX as usize {
483            return Err(EncodeError::TooManyParameters(params.len()));
484        }
485
486        // Calculate content length upfront
487        // portal(1) + statement(len+1) + format_codes(2) + param_count(2) + params_data + result_format(2)
488        let params_size: usize = params
489            .iter()
490            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
491            .sum();
492        let content_len = 1 + statement.len() + 1 + 2 + 2 + params_size + 2;
493
494        buf.reserve(1 + 4 + content_len);
495
496        // Message type 'B'
497        buf.put_u8(b'B');
498
499        // Length (includes itself) - DIRECT WRITE
500        Self::put_i32_be(buf, (content_len + 4) as i32);
501
502        // Portal name (empty, null-terminated)
503        buf.put_u8(0);
504
505        // Statement name (null-terminated)
506        buf.extend_from_slice(statement.as_bytes());
507        buf.put_u8(0);
508
509        // Format codes count (0 = default text)
510        Self::put_i16_be(buf, 0);
511
512        // Parameter count
513        Self::put_i16_be(buf, params.len() as i16);
514
515        // Parameters
516        for param in params {
517            match param {
518                None => Self::put_i32_be(buf, -1),
519                Some(data) => {
520                    Self::put_i32_be(buf, data.len() as i32);
521                    buf.extend_from_slice(data);
522                }
523            }
524        }
525
526        // Result format codes count (0 = default text)
527        Self::put_i16_be(buf, 0);
528        Ok(())
529    }
530
531    /// Encode Execute message directly into existing buffer (ZERO ALLOCATION).
532    #[inline]
533    pub fn encode_execute_to(buf: &mut BytesMut) {
534        // Content: portal(1) + max_rows(4) = 5 bytes
535        buf.extend_from_slice(&[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);
536    }
537
538    /// Encode Sync message directly into existing buffer (ZERO ALLOCATION).
539    #[inline]
540    pub fn encode_sync_to(buf: &mut BytesMut) {
541        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
542    }
543}