// oracle_rs/buffer/write.rs

//! Write buffer for encoding TNS protocol data.
//!
//! Provides methods for writing various data types to a byte buffer,
//! following the Oracle TNS wire format conventions.

use bytes::{BufMut, BytesMut};

use crate::constants::length;
use crate::error::{Error, Result};

11/// A buffer for writing TNS protocol data
12#[derive(Debug)]
13pub struct WriteBuffer {
14    /// The underlying byte buffer
15    data: BytesMut,
16    /// Maximum capacity (for packet size limits)
17    max_capacity: Option<usize>,
18}
19
20impl WriteBuffer {
21    /// Create a new WriteBuffer with default capacity
22    pub fn new() -> Self {
23        Self {
24            data: BytesMut::with_capacity(8192),
25            max_capacity: None,
26        }
27    }
28
29    /// Create a new WriteBuffer with specified capacity
30    pub fn with_capacity(capacity: usize) -> Self {
31        Self {
32            data: BytesMut::with_capacity(capacity),
33            max_capacity: None,
34        }
35    }
36
37    /// Create a new WriteBuffer with a maximum capacity limit
38    pub fn with_max_capacity(capacity: usize, max_capacity: usize) -> Self {
39        Self {
40            data: BytesMut::with_capacity(capacity),
41            max_capacity: Some(max_capacity),
42        }
43    }
44
45    /// Get the current length of data in the buffer
46    #[inline]
47    pub fn len(&self) -> usize {
48        self.data.len()
49    }
50
51    /// Check if the buffer is empty
52    #[inline]
53    pub fn is_empty(&self) -> bool {
54        self.data.is_empty()
55    }
56
57    /// Get the capacity of the buffer
58    #[inline]
59    pub fn capacity(&self) -> usize {
60        self.data.capacity()
61    }
62
63    /// Get the remaining writable space
64    #[inline]
65    pub fn remaining_capacity(&self) -> usize {
66        match self.max_capacity {
67            Some(max) => max.saturating_sub(self.data.len()),
68            None => usize::MAX - self.data.len(),
69        }
70    }
71
72    /// Clear the buffer
73    pub fn clear(&mut self) {
74        self.data.clear();
75    }
76
77    /// Reserve additional capacity
78    pub fn reserve(&mut self, additional: usize) {
79        self.data.reserve(additional);
80    }
81
82    /// Get the buffer contents as a byte slice
83    pub fn as_slice(&self) -> &[u8] {
84        &self.data
85    }
86
87    /// Consume the buffer and return the underlying BytesMut
88    pub fn into_inner(self) -> BytesMut {
89        self.data
90    }
91
92    /// Freeze the buffer into immutable Bytes
93    pub fn freeze(self) -> bytes::Bytes {
94        self.data.freeze()
95    }
96
97    /// Get mutable access to the underlying BytesMut
98    pub fn inner_mut(&mut self) -> &mut BytesMut {
99        &mut self.data
100    }
101
102    // =========================================================================
103    // Internal helpers
104    // =========================================================================
105
106    #[inline]
107    fn ensure_capacity(&self, n: usize) -> Result<()> {
108        if let Some(max) = self.max_capacity {
109            if self.data.len() + n > max {
110                return Err(Error::BufferOverflow {
111                    needed: n,
112                    available: max.saturating_sub(self.data.len()),
113                });
114            }
115        }
116        Ok(())
117    }
118
119    // =========================================================================
120    // Raw byte writes
121    // =========================================================================
122
123    /// Write a single byte
124    pub fn write_u8(&mut self, value: u8) -> Result<()> {
125        self.ensure_capacity(1)?;
126        self.data.put_u8(value);
127        Ok(())
128    }
129
130    /// Write raw bytes
131    pub fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
132        self.ensure_capacity(bytes.len())?;
133        self.data.put_slice(bytes);
134        Ok(())
135    }
136
137    /// Write zeros
138    pub fn write_zeros(&mut self, n: usize) -> Result<()> {
139        self.ensure_capacity(n)?;
140        for _ in 0..n {
141            self.data.put_u8(0);
142        }
143        Ok(())
144    }
145
146    // =========================================================================
147    // Big-endian integer writes (network byte order)
148    // =========================================================================
149
150    /// Write a 16-bit unsigned integer in big-endian format
151    pub fn write_u16_be(&mut self, value: u16) -> Result<()> {
152        self.ensure_capacity(2)?;
153        self.data.put_u16(value);
154        Ok(())
155    }
156
157    /// Write a 16-bit unsigned integer in little-endian format
158    pub fn write_u16_le(&mut self, value: u16) -> Result<()> {
159        self.ensure_capacity(2)?;
160        self.data.put_u16_le(value);
161        Ok(())
162    }
163
164    /// Write a 32-bit unsigned integer in big-endian format
165    pub fn write_u32_be(&mut self, value: u32) -> Result<()> {
166        self.ensure_capacity(4)?;
167        self.data.put_u32(value);
168        Ok(())
169    }
170
171    /// Write a 64-bit unsigned integer in big-endian format
172    pub fn write_u64_be(&mut self, value: u64) -> Result<()> {
173        self.ensure_capacity(8)?;
174        self.data.put_u64(value);
175        Ok(())
176    }
177
178    // =========================================================================
179    // TNS-specific writes (Oracle's variable-length encoding)
180    // =========================================================================
181
182    /// Write a TNS UB1 (unsigned byte)
183    #[inline]
184    pub fn write_ub1(&mut self, value: u8) -> Result<()> {
185        self.write_u8(value)
186    }
187
188    /// Write a TNS UB2 (unsigned 2-byte, variable length encoded)
189    ///
190    /// TNS UB2 uses a length-prefixed encoding where:
191    /// - 0: write 0x00
192    /// - 1-255: write 0x01 + 1 byte
193    /// - 256-65535: write 0x02 + 2 bytes (big-endian)
194    pub fn write_ub2(&mut self, value: u16) -> Result<()> {
195        match value {
196            0 => self.write_u8(0),
197            1..=255 => {
198                self.write_u8(1)?;
199                self.write_u8(value as u8)
200            }
201            _ => {
202                self.write_u8(2)?;
203                self.write_u16_be(value)
204            }
205        }
206    }
207
208    /// Write a TNS UB4 (unsigned 4-byte, variable length encoded)
209    ///
210    /// TNS UB4 uses a variable-length encoding where:
211    /// - 0: write 0x00
212    /// - 1-255: write 0x01 + 1 byte
213    /// - 256-65535: write 0x02 + 2 bytes (big-endian)
214    /// - > 65535: write 0x04 + 4 bytes (big-endian)
215    pub fn write_ub4(&mut self, value: u32) -> Result<()> {
216        match value {
217            0 => self.write_u8(0),
218            1..=255 => {
219                self.write_u8(1)?;
220                self.write_u8(value as u8)
221            }
222            256..=65535 => {
223                self.write_u8(2)?;
224                self.write_u16_be(value as u16)
225            }
226            _ => {
227                self.write_u8(4)?;
228                self.write_u32_be(value)
229            }
230        }
231    }
232
233    /// Write a TNS UB8 (unsigned 8-byte, variable length encoded)
234    ///
235    /// TNS UB8 uses a variable-length encoding where:
236    /// - 0: write 0x00
237    /// - 1-255: write 0x01 + 1 byte
238    /// - 256-65535: write 0x02 + 2 bytes (big-endian)
239    /// - 65536-4294967295: write 0x04 + 4 bytes (big-endian)
240    /// - > 4294967295: write 0x08 + 8 bytes (big-endian)
241    pub fn write_ub8(&mut self, value: u64) -> Result<()> {
242        match value {
243            0 => self.write_u8(0),
244            1..=255 => {
245                self.write_u8(1)?;
246                self.write_u8(value as u8)
247            }
248            256..=65535 => {
249                self.write_u8(2)?;
250                self.write_u16_be(value as u16)
251            }
252            65536..=4294967295 => {
253                self.write_u8(4)?;
254                self.write_u32_be(value as u32)
255            }
256            _ => {
257                self.write_u8(8)?;
258                self.write_u64_be(value)
259            }
260        }
261    }
262
263    /// Write a TNS length-prefixed byte sequence
264    ///
265    /// If bytes is None, writes NULL_INDICATOR (255).
266    /// For data > 252 bytes, uses chunked encoding:
267    /// - Write LONG_INDICATOR (254)
268    /// - For each chunk up to 32767 bytes: write ub4(chunk_len) + raw bytes
269    /// - Write ub4(0) to terminate
270    pub fn write_bytes_with_length(&mut self, bytes: Option<&[u8]>) -> Result<()> {
271        /// Maximum chunk size for long data (TNS_CHUNK_SIZE from Python)
272        const CHUNK_SIZE: usize = 32767;
273
274        match bytes {
275            None => self.write_u8(length::NULL_INDICATOR),
276            Some(data) => {
277                let len = data.len();
278                if len == 0 {
279                    self.write_u8(0)
280                } else if len <= length::MAX_SHORT as usize {
281                    self.write_u8(len as u8)?;
282                    self.write_bytes(data)
283                } else {
284                    // Chunked encoding for long data
285                    self.write_u8(length::LONG_INDICATOR)?;
286                    let mut offset = 0;
287                    while offset < len {
288                        let chunk_len = std::cmp::min(len - offset, CHUNK_SIZE);
289                        self.write_ub4(chunk_len as u32)?;
290                        self.write_bytes(&data[offset..offset + chunk_len])?;
291                        offset += chunk_len;
292                    }
293                    // Terminating zero
294                    self.write_ub4(0)
295                }
296            }
297        }
298    }
299
300    /// Write a TNS length-prefixed string (UTF-8)
301    pub fn write_string_with_length(&mut self, s: Option<&str>) -> Result<()> {
302        self.write_bytes_with_length(s.map(|s| s.as_bytes()))
303    }
304
305    /// Write an Oracle-encoded integer
306    ///
307    /// Oracle encodes integers with a length prefix indicating the number of bytes.
308    /// Negative numbers have the high bit set in the length byte.
309    pub fn write_oracle_int(&mut self, value: i64) -> Result<()> {
310        if value == 0 {
311            return self.write_u8(0);
312        }
313
314        let (is_negative, abs_value) = if value < 0 {
315            (true, (-value) as u64)
316        } else {
317            (false, value as u64)
318        };
319
320        // Calculate the number of bytes needed
321        let len = ((64 - abs_value.leading_zeros() + 7) / 8) as u8;
322
323        // Write length byte (with sign bit if negative)
324        let len_byte = if is_negative { len | 0x80 } else { len };
325        self.write_u8(len_byte)?;
326
327        // Write bytes in big-endian order
328        for i in (0..len).rev() {
329            self.write_u8((abs_value >> (i * 8)) as u8)?;
330        }
331
332        Ok(())
333    }
334
335    /// Write an Oracle-encoded unsigned integer
336    pub fn write_oracle_uint(&mut self, value: u64) -> Result<()> {
337        if value == 0 {
338            return self.write_u8(0);
339        }
340
341        // Calculate the number of bytes needed
342        let len = ((64 - value.leading_zeros() + 7) / 8) as u8;
343
344        self.write_u8(len)?;
345
346        // Write bytes in big-endian order
347        for i in (0..len).rev() {
348            self.write_u8((value >> (i * 8)) as u8)?;
349        }
350
351        Ok(())
352    }
353
354    /// Set the current position (for patching previously written data)
355    ///
356    /// Note: This truncates the buffer to the given position
357    pub fn truncate(&mut self, len: usize) {
358        self.data.truncate(len);
359    }
360
361    /// Patch a u16 at a specific position (big-endian)
362    ///
363    /// This allows writing a placeholder and then patching it later
364    /// (e.g., for packet length fields)
365    pub fn patch_u16_be(&mut self, pos: usize, value: u16) -> Result<()> {
366        if pos + 2 > self.data.len() {
367            return Err(Error::BufferOverflow {
368                needed: 2,
369                available: self.data.len().saturating_sub(pos),
370            });
371        }
372        let bytes = value.to_be_bytes();
373        self.data[pos] = bytes[0];
374        self.data[pos + 1] = bytes[1];
375        Ok(())
376    }
377
378    /// Patch a u32 at a specific position (big-endian)
379    pub fn patch_u32_be(&mut self, pos: usize, value: u32) -> Result<()> {
380        if pos + 4 > self.data.len() {
381            return Err(Error::BufferOverflow {
382                needed: 4,
383                available: self.data.len().saturating_sub(pos),
384            });
385        }
386        let bytes = value.to_be_bytes();
387        self.data[pos] = bytes[0];
388        self.data[pos + 1] = bytes[1];
389        self.data[pos + 2] = bytes[2];
390        self.data[pos + 3] = bytes[3];
391        Ok(())
392    }
393}
394
395impl Default for WriteBuffer {
396    fn default() -> Self {
397        Self::new()
398    }
399}
400
401impl AsRef<[u8]> for WriteBuffer {
402    fn as_ref(&self) -> &[u8] {
403        &self.data
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_write_u8() {
        let mut buf = WriteBuffer::new();
        buf.write_u8(0x42).unwrap();
        assert_eq!(buf.as_slice(), &[0x42]);
    }

    #[test]
    fn test_write_bytes() {
        let mut buf = WriteBuffer::new();
        buf.write_bytes(&[0x01, 0x02, 0x03]).unwrap();
        assert_eq!(buf.as_slice(), &[0x01, 0x02, 0x03]);
    }

    #[test]
    fn test_write_zeros() {
        let mut buf = WriteBuffer::new();
        buf.write_u8(0x42).unwrap();
        buf.write_zeros(3).unwrap();
        assert_eq!(buf.as_slice(), &[0x42, 0x00, 0x00, 0x00]);
    }

    #[test]
    fn test_write_u16_be() {
        let mut buf = WriteBuffer::new();
        buf.write_u16_be(0x0102).unwrap();
        assert_eq!(buf.as_slice(), &[0x01, 0x02]);
    }

    #[test]
    fn test_write_u32_be() {
        let mut buf = WriteBuffer::new();
        buf.write_u32_be(0x01020304).unwrap();
        assert_eq!(buf.as_slice(), &[0x01, 0x02, 0x03, 0x04]);
    }

    #[test]
    fn test_write_u64_be() {
        let mut buf = WriteBuffer::new();
        buf.write_u64_be(0x0102030405060708).unwrap();
        assert_eq!(buf.as_slice(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
    }

    #[test]
    fn test_write_ub2_zero() {
        let mut buf = WriteBuffer::new();
        buf.write_ub2(0).unwrap();
        assert_eq!(buf.as_slice(), &[0x00]);
    }

    #[test]
    fn test_write_ub2_short() {
        let mut buf = WriteBuffer::new();
        buf.write_ub2(0x42).unwrap();
        assert_eq!(buf.as_slice(), &[0x01, 0x42]); // Length 1, then value
    }

    #[test]
    fn test_write_ub2_long() {
        let mut buf = WriteBuffer::new();
        buf.write_ub2(0x0102).unwrap();
        assert_eq!(buf.as_slice(), &[0x02, 0x01, 0x02]); // Length 2, then big-endian u16
    }

    #[test]
    fn test_write_ub4_zero() {
        let mut buf = WriteBuffer::new();
        buf.write_ub4(0).unwrap();
        assert_eq!(buf.as_slice(), &[0x00]);
    }

    #[test]
    fn test_write_ub4_short() {
        let mut buf = WriteBuffer::new();
        buf.write_ub4(0x42).unwrap();
        assert_eq!(buf.as_slice(), &[0x01, 0x42]); // Length 1, then value
    }

    #[test]
    fn test_write_ub4_medium() {
        let mut buf = WriteBuffer::new();
        buf.write_ub4(0x0102).unwrap();
        assert_eq!(buf.as_slice(), &[0x02, 0x01, 0x02]); // Length 2, then big-endian u16
    }

    #[test]
    fn test_write_ub4_long() {
        let mut buf = WriteBuffer::new();
        buf.write_ub4(0x01020304).unwrap();
        assert_eq!(buf.as_slice(), &[0x04, 0x01, 0x02, 0x03, 0x04]); // Length 4, then big-endian u32
    }

    /// UB8 picks the smallest of the 0/1/2/4/8-byte widths that fits.
    #[test]
    fn test_write_ub8_all_widths() {
        let cases: [(u64, &[u8]); 5] = [
            (0, &[0x00]),
            (0x42, &[0x01, 0x42]),
            (0x0102, &[0x02, 0x01, 0x02]),
            (0x0102_0304, &[0x04, 0x01, 0x02, 0x03, 0x04]),
            (
                0x0102_0304_0506_0708,
                &[0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08],
            ),
        ];
        for (value, expected) in cases {
            let mut buf = WriteBuffer::new();
            buf.write_ub8(value).unwrap();
            assert_eq!(buf.as_slice(), expected, "UB8 encoding wrong for {:#x}", value);
        }
    }

    #[test]
    fn test_write_bytes_with_length_null() {
        let mut buf = WriteBuffer::new();
        buf.write_bytes_with_length(None).unwrap();
        assert_eq!(buf.as_slice(), &[0xff]);
    }

    #[test]
    fn test_write_bytes_with_length_empty() {
        let mut buf = WriteBuffer::new();
        buf.write_bytes_with_length(Some(&[])).unwrap();
        assert_eq!(buf.as_slice(), &[0x00]);
    }

    #[test]
    fn test_write_bytes_with_length_short() {
        let mut buf = WriteBuffer::new();
        buf.write_bytes_with_length(Some(&[0x41, 0x42, 0x43])).unwrap();
        assert_eq!(buf.as_slice(), &[0x03, 0x41, 0x42, 0x43]);
    }

    #[test]
    fn test_write_oracle_int_zero() {
        let mut buf = WriteBuffer::new();
        buf.write_oracle_int(0).unwrap();
        assert_eq!(buf.as_slice(), &[0x00]);
    }

    #[test]
    fn test_write_oracle_int_positive() {
        let mut buf = WriteBuffer::new();
        buf.write_oracle_int(258).unwrap();
        // 258 = 0x0102, needs 2 bytes
        assert_eq!(buf.as_slice(), &[0x02, 0x01, 0x02]);
    }

    #[test]
    fn test_write_oracle_int_negative() {
        let mut buf = WriteBuffer::new();
        buf.write_oracle_int(-258).unwrap();
        // -258, needs 2 bytes, sign bit set in length
        assert_eq!(buf.as_slice(), &[0x82, 0x01, 0x02]);
    }

    #[test]
    fn test_write_oracle_uint() {
        let mut buf = WriteBuffer::new();
        buf.write_oracle_uint(0).unwrap();
        buf.write_oracle_uint(0x01_0203).unwrap();
        // 0 -> single 0x00; 0x010203 -> length 3 + big-endian bytes
        assert_eq!(buf.as_slice(), &[0x00, 0x03, 0x01, 0x02, 0x03]);
    }

    #[test]
    fn test_patch_u16_be() {
        let mut buf = WriteBuffer::new();
        buf.write_u16_be(0x0000).unwrap(); // Placeholder
        buf.write_u8(0x42).unwrap();
        buf.patch_u16_be(0, 0x1234).unwrap();
        assert_eq!(buf.as_slice(), &[0x12, 0x34, 0x42]);
    }

    #[test]
    fn test_patch_u32_be() {
        let mut buf = WriteBuffer::new();
        buf.write_u32_be(0x00000000).unwrap(); // Placeholder
        buf.write_u8(0x42).unwrap();
        buf.patch_u32_be(0, 0x12345678).unwrap();
        assert_eq!(buf.as_slice(), &[0x12, 0x34, 0x56, 0x78, 0x42]);
    }

    /// Patching past the written data is an error and leaves the buffer intact.
    #[test]
    fn test_patch_out_of_range_is_error() {
        let mut buf = WriteBuffer::new();
        buf.write_bytes(&[0x01, 0x02, 0x03]).unwrap();
        assert!(buf.patch_u16_be(2, 0xFFFF).is_err());
        assert!(buf.patch_u32_be(0, 0xFFFF_FFFF).is_err());
        assert_eq!(buf.as_slice(), &[0x01, 0x02, 0x03]);
    }

    #[test]
    fn test_max_capacity() {
        let mut buf = WriteBuffer::with_max_capacity(10, 5);
        buf.write_bytes(&[0x01, 0x02, 0x03, 0x04, 0x05]).unwrap();
        assert!(buf.write_u8(0x06).is_err());
    }

    #[test]
    fn test_roundtrip_ub2() {
        use crate::buffer::ReadBuffer;

        for value in [0u16, 1, 100, 253, 254, 255, 1000, 10000, 65535] {
            let mut write_buf = WriteBuffer::new();
            write_buf.write_ub2(value).unwrap();

            let mut read_buf = ReadBuffer::from_slice(write_buf.as_slice());
            let read_value = read_buf.read_ub2().unwrap();

            assert_eq!(value, read_value, "UB2 roundtrip failed for {}", value);
        }
    }

    #[test]
    fn test_roundtrip_ub4() {
        use crate::buffer::ReadBuffer;

        for value in [0u32, 1, 100, 253, 254, 255, 1000, 100000, 0xFFFFFFFF] {
            let mut write_buf = WriteBuffer::new();
            write_buf.write_ub4(value).unwrap();

            let mut read_buf = ReadBuffer::from_slice(write_buf.as_slice());
            let read_value = read_buf.read_ub4().unwrap();

            assert_eq!(value, read_value, "UB4 roundtrip failed for {}", value);
        }
    }

    // =========================================================================
    // WIRE-LEVEL PROTOCOL TESTS
    // These tests document specific protocol details learned during development.
    // They serve as reference for anyone implementing Oracle/TNS protocols.
    // =========================================================================

    /// TNS length-prefixed data format:
    ///
    /// Short data (length <= 252 bytes):
    ///   [length: u8] [data: bytes]
    ///
    /// Long data (length > 252 bytes):
    ///   [0xFE] { [chunk_len: ub4] [chunk_data] }... [ub4(0) = 0x00]
    ///
    /// Each chunk holds at most 32767 bytes. `ub4` is the variable-length
    /// integer encoding: one byte giving the width (0, 1, 2 or 4), followed by
    /// that many big-endian bytes. The terminator ub4(0) is therefore the
    /// single byte 0x00.
    ///
    /// The threshold is 252 (0xFC), NOT 254 or 255, because:
    ///   - 253 (0xFD) is reserved
    ///   - 254 (0xFE) is TNS_LONG_LENGTH_INDICATOR
    ///   - 255 (0xFF) is TNS_NULL_LENGTH_INDICATOR
    ///
    /// For pickle format (collections), the threshold is 245 (TNS_OBJ_MAX_SHORT_LENGTH).
    ///
    /// Reference: Python python-oracledb buffer.pyx _write_raw_bytes_and_length
    #[test]
    fn test_wire_long_data_chunked_format() {
        let mut buf = WriteBuffer::new();

        // Create data that exceeds short format (>252 bytes)
        let long_data: Vec<u8> = (0..300u16).map(|i| (i % 256) as u8).collect();
        buf.write_bytes_with_length(Some(&long_data)).unwrap();

        let result = buf.as_slice();

        // First byte must be LONG_INDICATOR (0xFE = 254)
        assert_eq!(result[0], 0xFE,
            "Long data must start with TNS_LONG_LENGTH_INDICATOR (0xFE)");

        // Chunk length uses write_ub4 (variable-length encoding):
        // - 300 decimal = 0x012C
        // - write_ub4(300) writes: [0x02, 0x01, 0x2C] (prefix 2 = "2 bytes follow", then BE value)
        assert_eq!(result[1], 2, "ub4(300) prefix: 2 bytes follow");
        let chunk_len = u16::from_be_bytes([result[2], result[3]]);
        assert_eq!(chunk_len as usize, long_data.len(),
            "Chunk length must match data length");

        // Followed by data starting at byte 4
        assert_eq!(&result[4..4 + long_data.len()], &long_data[..]);

        // Ends with terminating zero using write_ub4(0) = single 0x00 byte
        let term_pos = 4 + long_data.len();
        assert_eq!(result[term_pos], 0x00,
            "Chunked data must end with ub4(0) terminator");
        assert_eq!(result.len(), term_pos + 1,
            "Total length: 1 (0xFE) + 3 (ub4(300)) + 300 (data) + 1 (ub4(0))");
    }

    /// Data longer than one chunk (32767 bytes) is split into several
    /// ub4-length-prefixed chunks ahead of the single ub4(0) terminator.
    #[test]
    fn test_wire_long_data_multiple_chunks() {
        let long_data = vec![0x5Au8; 32767 + 10];
        let mut buf = WriteBuffer::new();
        buf.write_bytes_with_length(Some(&long_data)).unwrap();
        let result = buf.as_slice();

        // 0xFE, then ub4(32767) = [0x02, 0x7F, 0xFF]
        assert_eq!(&result[..4], &[0xFE, 0x02, 0x7F, 0xFF]);
        // Second chunk header ub4(10) = [0x01, 0x0A]
        let second = 4 + 32767;
        assert_eq!(&result[second..second + 2], &[0x01, 0x0A]);
        assert_eq!(&result[second + 2..second + 12], &long_data[..10]);
        // Terminator ub4(0)
        assert_eq!(result[second + 12], 0x00);
        assert_eq!(result.len(), second + 13);
    }

    /// Short data format (<=252 bytes) uses single-byte length prefix
    #[test]
    fn test_wire_short_data_single_byte_length() {
        let mut buf = WriteBuffer::new();

        // Data exactly at threshold (252 bytes)
        let short_data: Vec<u8> = (0..252u16).map(|i| (i % 256) as u8).collect();
        buf.write_bytes_with_length(Some(&short_data)).unwrap();

        let result = buf.as_slice();

        // First byte is length (252 = 0xFC)
        assert_eq!(result[0], 252,
            "Short data length must be single byte");

        // Total length: 1 (length byte) + 252 (data)
        assert_eq!(result.len(), 253);
    }

    /// TNS_MAX_SHORT_LENGTH is 252, NOT 253/254/255
    #[test]
    fn test_wire_max_short_length_is_252() {
        // 252 bytes - should use short format
        let mut buf252 = WriteBuffer::new();
        let data252: Vec<u8> = vec![0xAA; 252];
        buf252.write_bytes_with_length(Some(&data252)).unwrap();
        assert_eq!(buf252.as_slice()[0], 252, "252 bytes uses short format");

        // 253 bytes - should use long format
        let mut buf253 = WriteBuffer::new();
        let data253: Vec<u8> = vec![0xAA; 253];
        buf253.write_bytes_with_length(Some(&data253)).unwrap();
        assert_eq!(buf253.as_slice()[0], 0xFE, "253 bytes uses long format");

        // This catches the common mistake of using >= 254 as threshold
    }
}