qail_pg/protocol/
encoder.rs

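//! PostgreSQL Encoder
//!
//! Encodes PostgreSQL frontend wire-protocol messages into bytes: Simple Query,
//! the Extended Query messages (Parse, Bind, Describe, Execute, Sync), and
//! Terminate. The encoders in this module take SQL text and already-encoded
//! parameter values. NOTE(review): compiling a QailCmd AST into SQL appears to
//! live in `AstEncoder` rather than here — confirm and keep this doc in sync.
//! This is pure, synchronous computation - no I/O, no async.
//!
//! # Architecture
//!
//! Layer 2 of the QAIL architecture:
//! - Input: SQL text plus encoded parameter values
//! - Output: BytesMut (ready to send over the wire)
//!
//! The async I/O layer (Layer 3) consumes these bytes.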
13
14use bytes::BytesMut;
15
/// PostgreSQL frontend-protocol message encoder.
///
/// A stateless namespace type: each encoder is an associated function that either returns a
/// fresh `BytesMut` or appends to a caller-supplied one. The encoders in this module take SQL
/// text and already-encoded parameter values rather than a QailCmd AST.
#[derive(Debug, Clone, Copy, Default)]
pub struct PgEncoder;
21
22impl PgEncoder {
23    /// Encode a raw SQL string as a Simple Query message.
24    ///
25    /// Wire format:
26    /// - 'Q' (1 byte) - message type
27    /// - length (4 bytes, big-endian, includes self)
28    /// - query string (null-terminated)
29    pub fn encode_query_string(sql: &str) -> BytesMut {
30        let mut buf = BytesMut::new();
31
32        // Message type 'Q' for Query
33        buf.extend_from_slice(b"Q");
34
35        // Content: query string + null terminator
36        let content_len = sql.len() + 1; // +1 for null terminator
37        let total_len = (content_len + 4) as i32; // +4 for length field itself
38
39        // Length (4 bytes, big-endian)
40        buf.extend_from_slice(&total_len.to_be_bytes());
41
42        // Query string
43        buf.extend_from_slice(sql.as_bytes());
44
45        // Null terminator
46        buf.extend_from_slice(&[0]);
47
48        buf
49    }
50
51    /// Encode a Terminate message to close the connection.
52    pub fn encode_terminate() -> BytesMut {
53        let mut buf = BytesMut::new();
54        buf.extend_from_slice(&[b'X', 0, 0, 0, 4]);
55        buf
56    }
57
58    /// Encode a Sync message (end of pipeline in extended query protocol).
59    pub fn encode_sync() -> BytesMut {
60        let mut buf = BytesMut::new();
61        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
62        buf
63    }
64
65    // ==================== Extended Query Protocol ====================
66
67    /// Encode a Parse message (prepare a statement).
68    ///
69    /// Wire format:
70    /// - 'P' (1 byte) - message type
71    /// - length (4 bytes)
72    /// - statement name (null-terminated, "" for unnamed)
73    /// - query string (null-terminated)
74    /// - parameter count (2 bytes)
75    /// - parameter OIDs (4 bytes each, 0 = infer type)
76    pub fn encode_parse(name: &str, sql: &str, param_types: &[u32]) -> BytesMut {
77        let mut buf = BytesMut::new();
78
79        // Message type 'P'
80        buf.extend_from_slice(b"P");
81
82        // Build content first to calculate length
83        let mut content = Vec::new();
84
85        // Statement name (null-terminated)
86        content.extend_from_slice(name.as_bytes());
87        content.push(0);
88
89        // Query string (null-terminated)
90        content.extend_from_slice(sql.as_bytes());
91        content.push(0);
92
93        // Parameter count
94        content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
95
96        // Parameter OIDs
97        for &oid in param_types {
98            content.extend_from_slice(&oid.to_be_bytes());
99        }
100
101        // Length (includes length field itself)
102        let len = (content.len() + 4) as i32;
103        buf.extend_from_slice(&len.to_be_bytes());
104        buf.extend_from_slice(&content);
105
106        buf
107    }
108
109    /// Encode a Bind message (bind parameters to a prepared statement).
110    ///
111    /// Wire format:
112    /// - 'B' (1 byte) - message type
113    /// - length (4 bytes)
114    /// - portal name (null-terminated)
115    /// - statement name (null-terminated)
116    /// - format code count (2 bytes) - we use 0 (all text)
117    /// - parameter count (2 bytes)
118    /// - for each parameter: length (4 bytes, -1 for NULL), data
119    /// - result format count (2 bytes) - we use 0 (all text)
120    pub fn encode_bind(portal: &str, statement: &str, params: &[Option<Vec<u8>>]) -> BytesMut {
121        let mut buf = BytesMut::new();
122
123        // Message type 'B'
124        buf.extend_from_slice(b"B");
125
126        // Build content
127        let mut content = Vec::new();
128
129        // Portal name (null-terminated)
130        content.extend_from_slice(portal.as_bytes());
131        content.push(0);
132
133        // Statement name (null-terminated)
134        content.extend_from_slice(statement.as_bytes());
135        content.push(0);
136
137        // Format codes count (0 = use default text format)
138        content.extend_from_slice(&0i16.to_be_bytes());
139
140        // Parameter count
141        content.extend_from_slice(&(params.len() as i16).to_be_bytes());
142
143        // Parameters
144        for param in params {
145            match param {
146                None => {
147                    // NULL: length = -1
148                    content.extend_from_slice(&(-1i32).to_be_bytes());
149                }
150                Some(data) => {
151                    content.extend_from_slice(&(data.len() as i32).to_be_bytes());
152                    content.extend_from_slice(data);
153                }
154            }
155        }
156
157        // Result format codes count (0 = use default text format)
158        content.extend_from_slice(&0i16.to_be_bytes());
159
160        // Length
161        let len = (content.len() + 4) as i32;
162        buf.extend_from_slice(&len.to_be_bytes());
163        buf.extend_from_slice(&content);
164
165        buf
166    }
167
168    /// Encode an Execute message (execute a bound portal).
169    ///
170    /// Wire format:
171    /// - 'E' (1 byte) - message type
172    /// - length (4 bytes)
173    /// - portal name (null-terminated)
174    /// - max rows (4 bytes, 0 = unlimited)
175    pub fn encode_execute(portal: &str, max_rows: i32) -> BytesMut {
176        let mut buf = BytesMut::new();
177
178        // Message type 'E'
179        buf.extend_from_slice(b"E");
180
181        // Build content
182        let mut content = Vec::new();
183
184        // Portal name (null-terminated)
185        content.extend_from_slice(portal.as_bytes());
186        content.push(0);
187
188        // Max rows
189        content.extend_from_slice(&max_rows.to_be_bytes());
190
191        // Length
192        let len = (content.len() + 4) as i32;
193        buf.extend_from_slice(&len.to_be_bytes());
194        buf.extend_from_slice(&content);
195
196        buf
197    }
198
199    /// Encode a Describe message (get statement/portal metadata).
200    ///
201    /// Wire format:
202    /// - 'D' (1 byte) - message type
203    /// - length (4 bytes)
204    /// - 'S' for statement or 'P' for portal
205    /// - name (null-terminated)
206    pub fn encode_describe(is_portal: bool, name: &str) -> BytesMut {
207        let mut buf = BytesMut::new();
208
209        // Message type 'D'
210        buf.extend_from_slice(b"D");
211
212        // Build content
213        let mut content = Vec::new();
214
215        // Type: 'S' for statement, 'P' for portal
216        content.push(if is_portal { b'P' } else { b'S' });
217
218        // Name (null-terminated)
219        content.extend_from_slice(name.as_bytes());
220        content.push(0);
221
222        // Length
223        let len = (content.len() + 4) as i32;
224        buf.extend_from_slice(&len.to_be_bytes());
225        buf.extend_from_slice(&content);
226
227        buf
228    }
229
230    /// Encode a complete extended query pipeline (OPTIMIZED).
231    ///
232    /// This combines Parse + Bind + Execute + Sync in a single buffer.
233    /// Zero intermediate allocations - writes directly to pre-sized BytesMut.
234    pub fn encode_extended_query(sql: &str, params: &[Option<Vec<u8>>]) -> BytesMut {
235        // Calculate total size upfront to avoid reallocations
236        // Parse: 1 + 4 + 1 + sql.len() + 1 + 2 = 9 + sql.len()
237        // Bind: 1 + 4 + 1 + 1 + 2 + 2 + params_data + 2 = 13 + params_data
238        // Execute: 1 + 4 + 1 + 4 = 10
239        // Sync: 5
240        let params_size: usize = params
241            .iter()
242            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
243            .sum();
244        let total_size = 9 + sql.len() + 13 + params_size + 10 + 5;
245
246        let mut buf = BytesMut::with_capacity(total_size);
247
248        // ===== PARSE =====
249        buf.extend_from_slice(b"P");
250        let parse_len = (1 + sql.len() + 1 + 2 + 4) as i32; // name + sql + param_count
251        buf.extend_from_slice(&parse_len.to_be_bytes());
252        buf.extend_from_slice(&[0]); // Unnamed statement
253        buf.extend_from_slice(sql.as_bytes());
254        buf.extend_from_slice(&[0]); // Null terminator
255        buf.extend_from_slice(&0i16.to_be_bytes()); // No param types (infer)
256
257        // ===== BIND =====
258        buf.extend_from_slice(b"B");
259        let bind_len = (1 + 1 + 2 + 2 + params_size + 2 + 4) as i32;
260        buf.extend_from_slice(&bind_len.to_be_bytes());
261        buf.extend_from_slice(&[0]); // Unnamed portal
262        buf.extend_from_slice(&[0]); // Unnamed statement
263        buf.extend_from_slice(&0i16.to_be_bytes()); // Format codes (default text)
264        buf.extend_from_slice(&(params.len() as i16).to_be_bytes());
265        for param in params {
266            match param {
267                None => buf.extend_from_slice(&(-1i32).to_be_bytes()),
268                Some(data) => {
269                    buf.extend_from_slice(&(data.len() as i32).to_be_bytes());
270                    buf.extend_from_slice(data);
271                }
272            }
273        }
274        buf.extend_from_slice(&0i16.to_be_bytes()); // Result format (default text)
275
276        // ===== EXECUTE =====
277        buf.extend_from_slice(b"E");
278        buf.extend_from_slice(&9i32.to_be_bytes()); // len = 4 + 1 + 4
279        buf.extend_from_slice(&[0]); // Unnamed portal
280        buf.extend_from_slice(&0i32.to_be_bytes()); // Unlimited rows
281
282        // ===== SYNC =====
283        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
284
285        buf
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Split a buffer of framed frontend messages into `(type byte, full message)` pairs.
    ///
    /// Asserts that every declared length is at least 4 and that the buffer ends exactly on a
    /// message boundary, so a wrong length field anywhere in the buffer fails the test.
    fn split_messages(buf: &[u8]) -> Vec<(u8, &[u8])> {
        let mut messages = Vec::new();
        let mut rest = buf;
        while !rest.is_empty() {
            assert!(rest.len() >= 5, "truncated message header");
            let len = i32::from_be_bytes([rest[1], rest[2], rest[3], rest[4]]);
            assert!(len >= 4, "declared length {} is below the minimum of 4", len);
            let total = 1 + len as usize;
            assert!(total <= rest.len(), "declared length runs past the end of the buffer");
            messages.push((rest[0], &rest[..total]));
            rest = &rest[total..];
        }
        messages
    }

    // NOTE: test_encode_simple_query removed - use AstEncoder instead
    #[test]
    fn test_encode_query_string() {
        let sql = "SELECT 1";
        let bytes = PgEncoder::encode_query_string(sql);

        // Message type
        assert_eq!(bytes[0], b'Q');

        // Length: 4 (length field) + 8 (query) + 1 (null) = 13
        let len = i32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
        assert_eq!(len, 13);

        // Query content
        assert_eq!(&bytes[5..13], b"SELECT 1");

        // Null terminator, and nothing after it
        assert_eq!(bytes[13], 0);
        assert_eq!(bytes.len(), 14);
    }

    #[test]
    fn test_encode_terminate() {
        let bytes = PgEncoder::encode_terminate();
        assert_eq!(&bytes[..], &[b'X', 0, 0, 0, 4]);
    }

    #[test]
    fn test_encode_sync() {
        let bytes = PgEncoder::encode_sync();
        assert_eq!(&bytes[..], &[b'S', 0, 0, 0, 4]);
    }

    #[test]
    fn test_encode_parse() {
        // Unnamed statement, no parameter types: length 4 + 1 + 10 + 2 = 17.
        let bytes = PgEncoder::encode_parse("", "SELECT $1", &[]);
        assert_eq!(&bytes[..], b"P\x00\x00\x00\x11\x00SELECT $1\x00\x00\x00");

        // Named statement with one explicit OID (23 = int4): length 4 + 2 + 10 + 2 + 4 = 22.
        let bytes = PgEncoder::encode_parse("s", "SELECT $1", &[23]);
        assert_eq!(
            &bytes[..],
            b"P\x00\x00\x00\x16s\x00SELECT $1\x00\x00\x01\x00\x00\x00\x17"
        );
    }

    #[test]
    fn test_encode_bind() {
        let params = vec![
            Some(b"42".to_vec()),
            None, // NULL
        ];
        let bytes = PgEncoder::encode_bind("p", "s", &params);

        // 'B', length 24, portal "p", statement "s", no format codes, 2 params
        // ("42" then NULL as length -1), no result format codes.
        assert_eq!(
            &bytes[..],
            &b"B\x00\x00\x00\x18p\x00s\x00\x00\x00\x00\x02\x00\x00\x00\x0242\xff\xff\xff\xff\x00\x00"[..]
        );
    }

    #[test]
    fn test_encode_execute() {
        let bytes = PgEncoder::encode_execute("", 0);

        // Message type 'E'
        assert_eq!(bytes[0], b'E');

        // Length: 4 + 1 (null) + 4 (max_rows) = 9
        let len = i32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
        assert_eq!(len, 9);
        assert_eq!(&bytes[..], b"E\x00\x00\x00\x09\x00\x00\x00\x00\x00");

        // Named portal with a row limit.
        let bytes = PgEncoder::encode_execute("c", 5);
        assert_eq!(&bytes[..], b"E\x00\x00\x00\x0ac\x00\x00\x00\x00\x05");
    }

    #[test]
    fn test_encode_describe() {
        assert_eq!(
            &PgEncoder::encode_describe(true, "c")[..],
            b"D\x00\x00\x00\x07Pc\x00"
        );
        assert_eq!(
            &PgEncoder::encode_describe(false, "")[..],
            b"D\x00\x00\x00\x06S\x00"
        );
    }

    #[test]
    fn test_single_message_encoders_are_well_framed() {
        let params = vec![Some(b"x".to_vec()), None];
        let cases = vec![
            (b'Q', PgEncoder::encode_query_string("SELECT 1")),
            (b'X', PgEncoder::encode_terminate()),
            (b'S', PgEncoder::encode_sync()),
            (b'P', PgEncoder::encode_parse("n", "SELECT $1", &[25])),
            (b'B', PgEncoder::encode_bind("p", "n", &params)),
            (b'E', PgEncoder::encode_execute("p", 10)),
            (b'D', PgEncoder::encode_describe(true, "p")),
        ];
        for (tag, bytes) in &cases {
            let messages = split_messages(bytes);
            assert_eq!(messages.len(), 1);
            assert_eq!(messages[0].0, *tag);
        }
    }

    #[test]
    fn test_encode_extended_query() {
        let sql = "SELECT $1";
        let params = vec![Some(b"hello".to_vec())];
        let bytes = PgEncoder::encode_extended_query(sql, &params);

        // Exactly four well-framed messages, in pipeline order: P, B, E, S.
        let tags: Vec<u8> = split_messages(&bytes).iter().map(|&(tag, _)| tag).collect();
        assert_eq!(tags, vec![b'P', b'B', b'E', b'S']);

        // Byte-identical to the standalone encoders run back to back.
        let mut expected: Vec<u8> = Vec::new();
        expected.extend_from_slice(&PgEncoder::encode_parse("", sql, &[]));
        expected.extend_from_slice(&PgEncoder::encode_bind("", "", &params));
        expected.extend_from_slice(&PgEncoder::encode_execute("", 0));
        expected.extend_from_slice(&PgEncoder::encode_sync());
        assert_eq!(&bytes[..], &expected[..]);
    }
}
377
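// ==================== ULTRA-OPTIMIZED Hot Path Encoders ====================
//
// These encoders append straight into a caller-owned buffer:
// - Integer fields are written with `BufMut::put_*` (no temporary arrays;
//   `BytesMut` still checks remaining capacity on each write)
// - Parameters are borrowed slices, copied once into the output buffer
//   (no intermediate owned copies)
// - Each Bind message reserves its full size up front, so the writes that
//   follow do not reallocate
//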
385
386use bytes::BufMut;
387
/// Borrowed Bind parameter value for the hot-path encoders.
///
/// `Bytes` borrows already-encoded parameter data, so building a parameter list needs no
/// allocation; the bytes are copied once, straight into the output buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Param<'a> {
    /// SQL NULL (written as a parameter length of -1).
    Null,
    /// Parameter value bytes. The hot-path Bind writer sends no format codes, so these are
    /// interpreted in text format.
    Bytes(&'a [u8]),
}
394
395impl PgEncoder {
396    /// Direct i32 write - no temp array, no bounds check.
397    /// LLVM emits a single store instruction.
398    #[inline(always)]
399    fn put_i32_be(buf: &mut BytesMut, v: i32) {
400        buf.put_i32(v);
401    }
402
403    /// Direct i16 write.
404    #[inline(always)]
405    fn put_i16_be(buf: &mut BytesMut, v: i16) {
406        buf.put_i16(v);
407    }
408
409    /// Encode Bind message - ULTRA OPTIMIZED.
410    ///
411    /// - Direct integer writes (no temp arrays)
412    /// - Borrowed params (zero-copy)
413    /// - Single allocation check
414    #[inline]
415    pub fn encode_bind_ultra<'a>(buf: &mut BytesMut, statement: &str, params: &[Param<'a>]) {
416        // Calculate content length upfront
417        let params_size: usize = params
418            .iter()
419            .map(|p| match p {
420                Param::Null => 4,
421                Param::Bytes(b) => 4 + b.len(),
422            })
423            .sum();
424        let content_len = 1 + statement.len() + 1 + 2 + 2 + params_size + 2;
425
426        // Single reserve - no more allocations
427        buf.reserve(1 + 4 + content_len);
428
429        // Message type 'B'
430        buf.put_u8(b'B');
431
432        // Length (includes itself) - DIRECT WRITE
433        Self::put_i32_be(buf, (content_len + 4) as i32);
434
435        // Portal name (empty, null-terminated)
436        buf.put_u8(0);
437
438        // Statement name (null-terminated)
439        buf.extend_from_slice(statement.as_bytes());
440        buf.put_u8(0);
441
442        // Format codes count (0 = default text)
443        Self::put_i16_be(buf, 0);
444
445        // Parameter count
446        Self::put_i16_be(buf, params.len() as i16);
447
448        // Parameters - ZERO COPY from borrowed slices
449        for param in params {
450            match param {
451                Param::Null => Self::put_i32_be(buf, -1),
452                Param::Bytes(data) => {
453                    Self::put_i32_be(buf, data.len() as i32);
454                    buf.extend_from_slice(data);
455                }
456            }
457        }
458
459        // Result format codes count (0 = default text)
460        Self::put_i16_be(buf, 0);
461    }
462
463    /// Encode Execute message - ULTRA OPTIMIZED.
464    /// Pre-computed constant bytes.
465    #[inline(always)]
466    pub fn encode_execute_ultra(buf: &mut BytesMut) {
467        // Execute: 'E' + len(9) + portal("") + max_rows(0)
468        // = 'E' 00 00 00 09 00 00 00 00 00
469        buf.extend_from_slice(&[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);
470    }
471
472    /// Encode Sync message - ULTRA OPTIMIZED.
473    /// Pre-computed constant bytes.
474    #[inline(always)]
475    pub fn encode_sync_ultra(buf: &mut BytesMut) {
476        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
477    }
478
479    // Keep the original methods for compatibility
480
481    /// Encode Bind message directly into existing buffer (ZERO ALLOCATION).
482    ///
483    /// This is the hot path optimization - no intermediate Vec allocation.
484    #[inline]
485    pub fn encode_bind_to(buf: &mut BytesMut, statement: &str, params: &[Option<Vec<u8>>]) {
486        // Calculate content length upfront
487        // portal(1) + statement(len+1) + format_codes(2) + param_count(2) + params_data + result_format(2)
488        let params_size: usize = params
489            .iter()
490            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
491            .sum();
492        let content_len = 1 + statement.len() + 1 + 2 + 2 + params_size + 2;
493
494        // Reserve space upfront
495        buf.reserve(1 + 4 + content_len);
496
497        // Message type 'B'
498        buf.put_u8(b'B');
499
500        // Length (includes itself) - DIRECT WRITE
501        Self::put_i32_be(buf, (content_len + 4) as i32);
502
503        // Portal name (empty, null-terminated)
504        buf.put_u8(0);
505
506        // Statement name (null-terminated)
507        buf.extend_from_slice(statement.as_bytes());
508        buf.put_u8(0);
509
510        // Format codes count (0 = default text)
511        Self::put_i16_be(buf, 0);
512
513        // Parameter count
514        Self::put_i16_be(buf, params.len() as i16);
515
516        // Parameters
517        for param in params {
518            match param {
519                None => Self::put_i32_be(buf, -1),
520                Some(data) => {
521                    Self::put_i32_be(buf, data.len() as i32);
522                    buf.extend_from_slice(data);
523                }
524            }
525        }
526
527        // Result format codes count (0 = default text)
528        Self::put_i16_be(buf, 0);
529    }
530
531    /// Encode Execute message directly into existing buffer (ZERO ALLOCATION).
532    #[inline]
533    pub fn encode_execute_to(buf: &mut BytesMut) {
534        // Execute with empty portal, max_rows=0 (unlimited)
535        // Content: portal(1) + max_rows(4) = 5 bytes
536        buf.extend_from_slice(&[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);
537    }
538
539    /// Encode Sync message directly into existing buffer (ZERO ALLOCATION).
540    #[inline]
541    pub fn encode_sync_to(buf: &mut BytesMut) {
542        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
543    }
544}