Skip to main content

sqlmodel_postgres/protocol/
writer.rs

1//! PostgreSQL message encoder.
2//!
3//! This module handles encoding frontend messages into the wire protocol format.
4
5#![allow(clippy::cast_possible_truncation)]
6
7use super::messages::{
8    CANCEL_REQUEST_CODE, DescribeKind, FrontendMessage, SSL_REQUEST_CODE, frontend_type,
9};
10
/// Reusable byte buffer into which PostgreSQL frontend messages are encoded.
///
/// Every multi-byte integer is emitted in network (big-endian) byte order.
#[derive(Clone, Debug)]
pub struct MessageWriter {
    /// Scratch storage holding the encoded bytes.
    buf: Vec<u8>,
}
19
20impl Default for MessageWriter {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl MessageWriter {
27    /// Create a new message writer with default capacity.
28    pub fn new() -> Self {
29        Self::with_capacity(1024)
30    }
31
32    /// Create a new message writer with specified capacity.
33    pub fn with_capacity(capacity: usize) -> Self {
34        Self {
35            buf: Vec::with_capacity(capacity),
36        }
37    }
38
39    /// Clear the internal buffer.
40    pub fn clear(&mut self) {
41        self.buf.clear();
42    }
43
44    /// Get the current buffer contents.
45    pub fn as_bytes(&self) -> &[u8] {
46        &self.buf
47    }
48
49    /// Take ownership of the buffer, leaving an empty one in its place.
50    pub fn take(&mut self) -> Vec<u8> {
51        std::mem::take(&mut self.buf)
52    }
53
54    /// Encode a frontend message into the buffer.
55    ///
56    /// Returns a slice to the encoded message data.
57    pub fn write(&mut self, msg: &FrontendMessage) -> &[u8] {
58        self.buf.clear();
59
60        match msg {
61            FrontendMessage::Startup { version, params } => {
62                self.write_startup(*version, params);
63            }
64            FrontendMessage::PasswordMessage(password) => {
65                self.write_password(password);
66            }
67            FrontendMessage::SASLInitialResponse { mechanism, data } => {
68                self.write_sasl_initial(mechanism, data);
69            }
70            FrontendMessage::SASLResponse(data) => {
71                self.write_sasl_response(data);
72            }
73            FrontendMessage::Query(query) => {
74                self.write_query(query);
75            }
76            FrontendMessage::Parse {
77                name,
78                query,
79                param_types,
80            } => {
81                self.write_parse(name, query, param_types);
82            }
83            FrontendMessage::Bind {
84                portal,
85                statement,
86                param_formats,
87                params,
88                result_formats,
89            } => {
90                self.write_bind(portal, statement, param_formats, params, result_formats);
91            }
92            FrontendMessage::Describe { kind, name } => {
93                self.write_describe(*kind, name);
94            }
95            FrontendMessage::Execute { portal, max_rows } => {
96                self.write_execute(portal, *max_rows);
97            }
98            FrontendMessage::Close { kind, name } => {
99                self.write_close(*kind, name);
100            }
101            FrontendMessage::Sync => {
102                self.write_sync();
103            }
104            FrontendMessage::Flush => {
105                self.write_flush();
106            }
107            FrontendMessage::CopyData(data) => {
108                self.write_copy_data(data);
109            }
110            FrontendMessage::CopyDone => {
111                self.write_copy_done();
112            }
113            FrontendMessage::CopyFail(message) => {
114                self.write_copy_fail(message);
115            }
116            FrontendMessage::Terminate => {
117                self.write_terminate();
118            }
119            FrontendMessage::CancelRequest {
120                process_id,
121                secret_key,
122            } => {
123                self.write_cancel_request(*process_id, *secret_key);
124            }
125            FrontendMessage::SSLRequest => {
126                self.write_ssl_request();
127            }
128        }
129
130        &self.buf
131    }
132
133    // ==================== Message Encoders ====================
134
135    /// Write a startup message (no type byte).
136    fn write_startup(&mut self, version: i32, params: &[(String, String)]) {
137        // Calculate body length
138        let mut body_len = 4; // version
139        for (key, value) in params {
140            body_len += key.len() + 1 + value.len() + 1;
141        }
142        body_len += 1; // terminating null
143
144        // Write length (includes itself)
145        let total_len = (body_len + 4) as i32;
146        self.buf.extend_from_slice(&total_len.to_be_bytes());
147
148        // Write version
149        self.buf.extend_from_slice(&version.to_be_bytes());
150
151        // Write parameters
152        for (key, value) in params {
153            self.buf.extend_from_slice(key.as_bytes());
154            self.buf.push(0);
155            self.buf.extend_from_slice(value.as_bytes());
156            self.buf.push(0);
157        }
158
159        // Terminating null
160        self.buf.push(0);
161    }
162
163    /// Write a password message.
164    fn write_password(&mut self, password: &str) {
165        self.write_simple_string_message(frontend_type::PASSWORD, password);
166    }
167
168    /// Write SASL initial response.
169    fn write_sasl_initial(&mut self, mechanism: &str, data: &[u8]) {
170        // Type byte
171        self.buf.push(frontend_type::PASSWORD);
172
173        // Calculate length: 4 (length) + mechanism + null + 4 (data length) + data
174        let body_len = mechanism.len() + 1 + 4 + data.len();
175        let total_len = (body_len + 4) as i32;
176        self.buf.extend_from_slice(&total_len.to_be_bytes());
177
178        // Mechanism name
179        self.buf.extend_from_slice(mechanism.as_bytes());
180        self.buf.push(0);
181
182        // Data length (-1 if no data)
183        if data.is_empty() {
184            self.buf.extend_from_slice(&(-1_i32).to_be_bytes());
185        } else {
186            let data_len = data.len() as i32;
187            self.buf.extend_from_slice(&data_len.to_be_bytes());
188            self.buf.extend_from_slice(data);
189        }
190    }
191
192    /// Write SASL response.
193    fn write_sasl_response(&mut self, data: &[u8]) {
194        self.buf.push(frontend_type::PASSWORD);
195        let len = (data.len() + 4) as i32;
196        self.buf.extend_from_slice(&len.to_be_bytes());
197        self.buf.extend_from_slice(data);
198    }
199
200    /// Write a simple query message.
201    fn write_query(&mut self, query: &str) {
202        self.write_simple_string_message(frontend_type::QUERY, query);
203    }
204
205    /// Write a Parse message (prepare statement).
206    fn write_parse(&mut self, name: &str, query: &str, param_types: &[u32]) {
207        self.buf.push(frontend_type::PARSE);
208
209        // Calculate length
210        let body_len = name.len() + 1 + query.len() + 1 + 2 + (param_types.len() * 4);
211        let total_len = (body_len + 4) as i32;
212        self.buf.extend_from_slice(&total_len.to_be_bytes());
213
214        // Statement name
215        self.buf.extend_from_slice(name.as_bytes());
216        self.buf.push(0);
217
218        // Query string
219        self.buf.extend_from_slice(query.as_bytes());
220        self.buf.push(0);
221
222        // Parameter types
223        let num_params = param_types.len() as i16;
224        self.buf.extend_from_slice(&num_params.to_be_bytes());
225        for &oid in param_types {
226            self.buf.extend_from_slice(&oid.to_be_bytes());
227        }
228    }
229
230    /// Write a Bind message.
231    fn write_bind(
232        &mut self,
233        portal: &str,
234        statement: &str,
235        param_formats: &[i16],
236        params: &[Option<Vec<u8>>],
237        result_formats: &[i16],
238    ) {
239        self.buf.push(frontend_type::BIND);
240
241        // Calculate body length
242        let mut body_len = portal.len() + 1 + statement.len() + 1;
243        body_len += 2 + (param_formats.len() * 2); // format codes
244        body_len += 2; // num params
245
246        for param in params {
247            body_len += 4; // length
248            if let Some(data) = param {
249                body_len += data.len();
250            }
251        }
252
253        body_len += 2 + (result_formats.len() * 2); // result format codes
254
255        let total_len = (body_len + 4) as i32;
256        self.buf.extend_from_slice(&total_len.to_be_bytes());
257
258        // Portal name
259        self.buf.extend_from_slice(portal.as_bytes());
260        self.buf.push(0);
261
262        // Statement name
263        self.buf.extend_from_slice(statement.as_bytes());
264        self.buf.push(0);
265
266        // Parameter format codes
267        let num_formats = param_formats.len() as i16;
268        self.buf.extend_from_slice(&num_formats.to_be_bytes());
269        for &fmt in param_formats {
270            self.buf.extend_from_slice(&fmt.to_be_bytes());
271        }
272
273        // Parameter values
274        let num_params = params.len() as i16;
275        self.buf.extend_from_slice(&num_params.to_be_bytes());
276        for param in params {
277            match param {
278                Some(data) => {
279                    let len = data.len() as i32;
280                    self.buf.extend_from_slice(&len.to_be_bytes());
281                    self.buf.extend_from_slice(data);
282                }
283                None => {
284                    // NULL value
285                    self.buf.extend_from_slice(&(-1_i32).to_be_bytes());
286                }
287            }
288        }
289
290        // Result format codes
291        let num_result_formats = result_formats.len() as i16;
292        self.buf
293            .extend_from_slice(&num_result_formats.to_be_bytes());
294        for &fmt in result_formats {
295            self.buf.extend_from_slice(&fmt.to_be_bytes());
296        }
297    }
298
299    /// Write a Describe message.
300    fn write_describe(&mut self, kind: DescribeKind, name: &str) {
301        self.buf.push(frontend_type::DESCRIBE);
302        let body_len = 1 + name.len() + 1;
303        let total_len = (body_len + 4) as i32;
304        self.buf.extend_from_slice(&total_len.to_be_bytes());
305        self.buf.push(kind.as_byte());
306        self.buf.extend_from_slice(name.as_bytes());
307        self.buf.push(0);
308    }
309
310    /// Write an Execute message.
311    fn write_execute(&mut self, portal: &str, max_rows: i32) {
312        self.buf.push(frontend_type::EXECUTE);
313        let body_len = portal.len() + 1 + 4;
314        let total_len = (body_len + 4) as i32;
315        self.buf.extend_from_slice(&total_len.to_be_bytes());
316        self.buf.extend_from_slice(portal.as_bytes());
317        self.buf.push(0);
318        self.buf.extend_from_slice(&max_rows.to_be_bytes());
319    }
320
321    /// Write a Close message.
322    fn write_close(&mut self, kind: DescribeKind, name: &str) {
323        self.buf.push(frontend_type::CLOSE);
324        let body_len = 1 + name.len() + 1;
325        let total_len = (body_len + 4) as i32;
326        self.buf.extend_from_slice(&total_len.to_be_bytes());
327        self.buf.push(kind.as_byte());
328        self.buf.extend_from_slice(name.as_bytes());
329        self.buf.push(0);
330    }
331
332    /// Write a Sync message.
333    fn write_sync(&mut self) {
334        self.write_empty_message(frontend_type::SYNC);
335    }
336
337    /// Write a Flush message.
338    fn write_flush(&mut self) {
339        self.write_empty_message(frontend_type::FLUSH);
340    }
341
342    /// Write COPY data.
343    fn write_copy_data(&mut self, data: &[u8]) {
344        self.buf.push(frontend_type::COPY_DATA);
345        let len = (data.len() + 4) as i32;
346        self.buf.extend_from_slice(&len.to_be_bytes());
347        self.buf.extend_from_slice(data);
348    }
349
350    /// Write COPY done.
351    fn write_copy_done(&mut self) {
352        self.write_empty_message(frontend_type::COPY_DONE);
353    }
354
355    /// Write COPY fail.
356    fn write_copy_fail(&mut self, message: &str) {
357        self.write_simple_string_message(frontend_type::COPY_FAIL, message);
358    }
359
360    /// Write Terminate message.
361    fn write_terminate(&mut self) {
362        self.write_empty_message(frontend_type::TERMINATE);
363    }
364
365    /// Write cancel request (special format, no type byte).
366    fn write_cancel_request(&mut self, process_id: i32, secret_key: i32) {
367        // Length (16 bytes total)
368        self.buf.extend_from_slice(&16_i32.to_be_bytes());
369        // Cancel request code
370        self.buf
371            .extend_from_slice(&CANCEL_REQUEST_CODE.to_be_bytes());
372        // Process ID
373        self.buf.extend_from_slice(&process_id.to_be_bytes());
374        // Secret key
375        self.buf.extend_from_slice(&secret_key.to_be_bytes());
376    }
377
378    /// Write SSL request (special format, no type byte).
379    fn write_ssl_request(&mut self) {
380        // Length (8 bytes total)
381        self.buf.extend_from_slice(&8_i32.to_be_bytes());
382        // SSL request code
383        self.buf.extend_from_slice(&SSL_REQUEST_CODE.to_be_bytes());
384    }
385
386    // ==================== Helper Methods ====================
387
388    /// Write a message with just a type byte and length (no body).
389    fn write_empty_message(&mut self, type_byte: u8) {
390        self.buf.push(type_byte);
391        self.buf.extend_from_slice(&4_i32.to_be_bytes());
392    }
393
394    /// Write a message containing a single null-terminated string.
395    fn write_simple_string_message(&mut self, type_byte: u8, s: &str) {
396        self.buf.push(type_byte);
397        let len = (s.len() + 5) as i32; // 4 for length + string + null
398        self.buf.extend_from_slice(&len.to_be_bytes());
399        self.buf.extend_from_slice(s.as_bytes());
400        self.buf.push(0);
401    }
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::PROTOCOL_VERSION;

    /// Decode the big-endian `i32` stored at `data[offset..offset + 4]`.
    fn read_i32(data: &[u8], offset: usize) -> i32 {
        i32::from_be_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ])
    }

    /// Assert that the length field of a typed message (bytes 1..5) counts
    /// exactly the bytes that follow the type byte.
    fn assert_length_consistent(data: &[u8]) {
        let declared = read_i32(data, 1);
        assert!(declared >= 0, "negative message length {declared}");
        assert_eq!(declared as usize, data.len() - 1);
    }

    #[test]
    fn test_startup_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Startup {
            version: PROTOCOL_VERSION,
            params: vec![
                ("user".to_string(), "postgres".to_string()),
                ("database".to_string(), "test".to_string()),
            ],
        };

        let data = writer.write(&msg);

        // Untyped message: the length counts every byte, including itself.
        let mut expected = 37_i32.to_be_bytes().to_vec();
        expected.extend_from_slice(&PROTOCOL_VERSION.to_be_bytes());
        expected.extend_from_slice(b"user\0postgres\0database\0test\0\0");
        assert_eq!(data, expected.as_slice());
    }

    #[test]
    fn test_query_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Query("SELECT 1".to_string());

        let data = writer.write(&msg);

        // length field (4) + "SELECT 1" (8) + null (1) = 13
        assert_eq!(data, b"Q\x00\x00\x00\x0dSELECT 1\x00");
    }

    #[test]
    fn test_sync_message() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::Sync);
        assert_eq!(data, &[b'S', 0, 0, 0, 4]);
    }

    #[test]
    fn test_flush_message() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::Flush);
        assert_eq!(data, &[b'H', 0, 0, 0, 4]);
    }

    #[test]
    fn test_terminate_message() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::Terminate);
        assert_eq!(data, &[b'X', 0, 0, 0, 4]);
    }

    #[test]
    fn test_copy_done_message() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::CopyDone);
        assert_eq!(data, &[b'c', 0, 0, 0, 4]);
    }

    #[test]
    fn test_parse_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Parse {
            name: "stmt1".to_string(),
            query: "SELECT $1".to_string(),
            param_types: vec![23], // int4
        };

        let data = writer.write(&msg);

        // name, query, count = 1, OID 23
        assert_eq!(
            data,
            b"P\x00\x00\x00\x1astmt1\x00SELECT $1\x00\x00\x01\x00\x00\x00\x17"
        );
    }

    #[test]
    fn test_describe_statement() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Describe {
            kind: DescribeKind::Statement,
            name: "stmt1".to_string(),
        };

        let data = writer.write(&msg);

        assert_eq!(data, b"D\x00\x00\x00\x0bSstmt1\x00");
    }

    #[test]
    fn test_describe_portal() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Describe {
            kind: DescribeKind::Portal,
            name: "portal1".to_string(),
        };

        let data = writer.write(&msg);

        assert_eq!(data, b"D\x00\x00\x00\x0dPportal1\x00");
    }

    #[test]
    fn test_execute_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Execute {
            portal: String::new(),
            max_rows: 0,
        };

        let data = writer.write(&msg);

        // type + length (9) + empty portal name + max_rows 0 (no limit)
        assert_eq!(data, b"E\x00\x00\x00\x09\x00\x00\x00\x00\x00");
    }

    #[test]
    fn test_cancel_request() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::CancelRequest {
            process_id: 12345,
            secret_key: 67890,
        };

        let data = writer.write(&msg);

        assert_eq!(data.len(), 16);
        assert_eq!(read_i32(data, 0), 16);
        assert_eq!(read_i32(data, 4), CANCEL_REQUEST_CODE);
        assert_eq!(read_i32(data, 8), 12345);
        assert_eq!(read_i32(data, 12), 67890);
    }

    #[test]
    fn test_ssl_request() {
        let mut writer = MessageWriter::new();
        let data = writer.write(&FrontendMessage::SSLRequest);

        assert_eq!(data.len(), 8);
        assert_eq!(read_i32(data, 0), 8);
        assert_eq!(read_i32(data, 4), SSL_REQUEST_CODE);
    }

    #[test]
    fn test_bind_with_null_params() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Bind {
            portal: String::new(),
            statement: "stmt1".to_string(),
            param_formats: vec![0],
            params: vec![None], // NULL parameter
            result_formats: vec![],
        };

        let data = writer.write(&msg);

        // portal "", statement "stmt1", 1 format (text), 1 param = NULL (-1),
        // 0 result formats
        assert_eq!(
            data,
            b"B\x00\x00\x00\x17\x00stmt1\x00\x00\x01\x00\x00\x00\x01\xff\xff\xff\xff\x00\x00"
        );
    }

    #[test]
    fn test_bind_with_values() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Bind {
            portal: "p1".to_string(),
            statement: String::new(),
            param_formats: vec![1],
            params: vec![Some(vec![0, 0, 0, 42])],
            result_formats: vec![1],
        };

        let data = writer.write(&msg);

        assert_eq!(
            data,
            b"B\x00\x00\x00\x1ap1\x00\x00\x00\x01\x00\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\x2a\x00\x01\x00\x01"
        );
    }

    #[test]
    fn test_copy_data() {
        let mut writer = MessageWriter::new();
        let payload = b"hello\nworld\n";
        let msg = FrontendMessage::CopyData(payload.to_vec());

        let data = writer.write(&msg);

        assert_eq!(data[0], b'd');
        assert_eq!(read_i32(data, 1), (4 + payload.len()) as i32);
        assert_eq!(&data[5..], payload);
    }

    #[test]
    fn test_typed_length_matches_encoded_size() {
        let messages = vec![
            FrontendMessage::Query("SELECT now()".to_string()),
            FrontendMessage::Parse {
                name: String::new(),
                query: "SELECT $1, $2".to_string(),
                param_types: vec![23, 25],
            },
            FrontendMessage::Bind {
                portal: "p".to_string(),
                statement: "s".to_string(),
                param_formats: vec![0, 1],
                params: vec![Some(b"42".to_vec()), None],
                result_formats: vec![1],
            },
            FrontendMessage::Describe {
                kind: DescribeKind::Portal,
                name: "p".to_string(),
            },
            FrontendMessage::Execute {
                portal: "p".to_string(),
                max_rows: 10,
            },
            FrontendMessage::CopyData(vec![1, 2, 3]),
            FrontendMessage::CopyDone,
            FrontendMessage::Sync,
        ];

        let mut writer = MessageWriter::new();
        for msg in &messages {
            assert_length_consistent(writer.write(msg));
        }
    }

    #[test]
    fn test_writer_reuse() {
        let mut writer = MessageWriter::new();

        // First message
        writer.write(&FrontendMessage::Sync);
        assert_eq!(writer.as_bytes(), &[b'S', 0, 0, 0, 4]);

        // Second message - should replace first
        writer.write(&FrontendMessage::Flush);
        assert_eq!(writer.as_bytes(), &[b'H', 0, 0, 0, 4]);
    }

    #[test]
    fn test_take_and_clear() {
        let mut writer = MessageWriter::new();

        writer.write(&FrontendMessage::Sync);
        assert_eq!(writer.take(), vec![b'S', 0, 0, 0, 4]);
        assert!(writer.as_bytes().is_empty());

        writer.write(&FrontendMessage::Flush);
        writer.clear();
        assert!(writer.as_bytes().is_empty());
    }
}