Skip to main content

sentinel_driver/copy/
binary.rs

1use bytes::{BufMut, BytesMut};
2
3use crate::error::{Error, Result};
4
/// Signature that opens every binary COPY stream: ASCII `PGCOPY` followed by
/// `\n`, `0xFF`, `\r`, `\n` and a NUL byte (11 bytes). On the wire it is
/// followed by a 32-bit flags word and a 32-bit header-extension length.
const BINARY_HEADER: &[u8] = &[
    b'P', b'G', b'C', b'O', b'P', b'Y', b'\n', 0xFF, b'\r', b'\n', 0x00,
];
/// Flags word written after the signature (no flag bits set).
const HEADER_FLAGS: i32 = 0;
/// Length of the header extension area written after the flags (none).
const HEADER_EXTENSION_LEN: i32 = 0;

/// Field-count value that marks the end of the tuple stream (the trailer).
const BINARY_TRAILER_FIELD_COUNT: i16 = -1;
12
13/// Encoder for binary COPY IN format.
14///
15/// Builds a buffer containing the binary header, tuple data, and trailer.
16///
17/// # Example
18///
19/// ```rust
20/// use sentinel_driver::copy::binary::BinaryCopyEncoder;
21///
22/// let mut encoder = BinaryCopyEncoder::new();
23///
24/// // Write a row with two columns: int4(42) and text("hello")
25/// encoder.begin_row(2);
26/// encoder.write_field(&42i32.to_be_bytes());
27/// encoder.write_field(b"hello");
28///
29/// // Write another row with a NULL second column
30/// encoder.begin_row(2);
31/// encoder.write_field(&7i32.to_be_bytes());
32/// encoder.write_null();
33///
34/// let data = encoder.finish();
35/// // data can be sent via CopyIn::write_raw()
36/// ```
37pub struct BinaryCopyEncoder {
38    buf: BytesMut,
39    header_written: bool,
40}
41
42impl BinaryCopyEncoder {
43    pub fn new() -> Self {
44        Self {
45            buf: BytesMut::with_capacity(8192),
46            header_written: false,
47        }
48    }
49
50    fn ensure_header(&mut self) {
51        if !self.header_written {
52            self.buf.put_slice(BINARY_HEADER);
53            self.buf.put_i32(HEADER_FLAGS);
54            self.buf.put_i32(HEADER_EXTENSION_LEN);
55            self.header_written = true;
56        }
57    }
58
59    /// Begin a new row with the given number of fields.
60    pub fn begin_row(&mut self, field_count: i16) {
61        self.ensure_header();
62        self.buf.put_i16(field_count);
63    }
64
65    /// Write a non-NULL field value (already in binary PG format).
66    pub fn write_field(&mut self, data: &[u8]) {
67        self.buf.put_i32(data.len() as i32);
68        self.buf.put_slice(data);
69    }
70
71    /// Write a NULL field.
72    pub fn write_null(&mut self) {
73        self.buf.put_i32(-1);
74    }
75
76    /// Finish encoding and return the complete binary COPY data.
77    ///
78    /// Appends the trailer (field_count = -1).
79    pub fn finish(mut self) -> Vec<u8> {
80        self.ensure_header();
81        self.buf.put_i16(BINARY_TRAILER_FIELD_COUNT);
82        self.buf.to_vec()
83    }
84
85    /// Get the current buffer size.
86    pub fn len(&self) -> usize {
87        self.buf.len()
88    }
89
90    pub fn is_empty(&self) -> bool {
91        self.buf.is_empty()
92    }
93}
94
95impl Default for BinaryCopyEncoder {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
/// Decoder for binary COPY OUT format.
///
/// Parses binary COPY data received from the server. The decoder borrows the
/// whole payload, and the field slices it returns point into that payload.
#[derive(Debug)]
pub struct BinaryCopyDecoder<'a> {
    /// The complete binary COPY payload being decoded.
    data: &'a [u8],
    /// Byte offset of the next unread byte in `data`.
    pos: usize,
    /// Whether the header has been successfully parsed.
    header_parsed: bool,
}
109
110impl<'a> BinaryCopyDecoder<'a> {
111    pub fn new(data: &'a [u8]) -> Self {
112        Self {
113            data,
114            pos: 0,
115            header_parsed: false,
116        }
117    }
118
119    /// Parse the binary COPY header. Must be called before reading rows.
120    pub fn parse_header(&mut self) -> Result<()> {
121        if self.data.len() < BINARY_HEADER.len() + 8 {
122            return Err(Error::Copy("binary COPY data too short for header".into()));
123        }
124
125        if &self.data[..BINARY_HEADER.len()] != BINARY_HEADER {
126            return Err(Error::Copy("invalid binary COPY header".into()));
127        }
128
129        self.pos = BINARY_HEADER.len();
130
131        // flags
132        let _flags = read_i32(self.data, self.pos);
133        self.pos += 4;
134
135        // extension area length
136        let ext_len = read_i32(self.data, self.pos) as usize;
137        self.pos += 4;
138
139        // skip extension area
140        self.pos += ext_len;
141
142        self.header_parsed = true;
143        Ok(())
144    }
145
146    /// Read the next row. Returns `None` at the trailer.
147    ///
148    /// Each field is returned as `Option<&[u8]>` (None = NULL).
149    pub fn next_row(&mut self) -> Result<Option<Vec<Option<&'a [u8]>>>> {
150        if !self.header_parsed {
151            self.parse_header()?;
152        }
153
154        if self.pos + 2 > self.data.len() {
155            return Ok(None);
156        }
157
158        let field_count = read_i16(self.data, self.pos);
159        self.pos += 2;
160
161        // Trailer: field_count == -1
162        if field_count == BINARY_TRAILER_FIELD_COUNT {
163            return Ok(None);
164        }
165
166        if field_count < 0 {
167            return Err(Error::Copy(format!("invalid field count: {field_count}")));
168        }
169
170        let mut fields = Vec::with_capacity(field_count as usize);
171
172        for _ in 0..field_count {
173            if self.pos + 4 > self.data.len() {
174                return Err(Error::Copy("truncated binary COPY row".into()));
175            }
176
177            let len = read_i32(self.data, self.pos);
178            self.pos += 4;
179
180            if len == -1 {
181                fields.push(None); // NULL
182            } else if len < 0 {
183                return Err(Error::Copy(format!("invalid field length: {len}")));
184            } else {
185                let len = len as usize;
186                if self.pos + len > self.data.len() {
187                    return Err(Error::Copy("truncated binary COPY field".into()));
188                }
189                fields.push(Some(&self.data[self.pos..self.pos + len]));
190                self.pos += len;
191            }
192        }
193
194        Ok(Some(fields))
195    }
196}
197
/// Decode a big-endian (network order) `i32` from the four bytes at `offset`.
///
/// # Panics
///
/// Panics if `buf` holds fewer than `offset + 4` bytes. The decoder checks
/// bounds before calling this.
fn read_i32(buf: &[u8], offset: usize) -> i32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&buf[offset..offset + 4]);
    i32::from_be_bytes(word)
}
206
/// Decode a big-endian (network order) `i16` from the two bytes at `offset`.
///
/// # Panics
///
/// Panics if `buf` holds fewer than `offset + 2` bytes. The decoder checks
/// bounds before calling this.
fn read_i16(buf: &[u8], offset: usize) -> i16 {
    let mut word = [0u8; 2];
    word.copy_from_slice(&buf[offset..offset + 2]);
    i16::from_be_bytes(word)
}