1use bytes::{BufMut, BytesMut};
7
8use crate::constants::length;
9use crate::error::{Error, Result};
10
/// Append-only byte buffer with an optional upper bound on its length.
///
/// The `write_*` methods check the bound before appending and return
/// `Error::BufferOverflow` instead of growing past it.
#[derive(Debug)]
pub struct WriteBuffer {
    /// Bytes written so far.
    data: BytesMut,
    /// Maximum number of bytes the `write_*` methods will allow in `data`;
    /// `None` means no limit is enforced.
    max_capacity: Option<usize>,
}
19
20impl WriteBuffer {
21 pub fn new() -> Self {
23 Self {
24 data: BytesMut::with_capacity(8192),
25 max_capacity: None,
26 }
27 }
28
29 pub fn with_capacity(capacity: usize) -> Self {
31 Self {
32 data: BytesMut::with_capacity(capacity),
33 max_capacity: None,
34 }
35 }
36
37 pub fn with_max_capacity(capacity: usize, max_capacity: usize) -> Self {
39 Self {
40 data: BytesMut::with_capacity(capacity),
41 max_capacity: Some(max_capacity),
42 }
43 }
44
45 #[inline]
47 pub fn len(&self) -> usize {
48 self.data.len()
49 }
50
51 #[inline]
53 pub fn is_empty(&self) -> bool {
54 self.data.is_empty()
55 }
56
57 #[inline]
59 pub fn capacity(&self) -> usize {
60 self.data.capacity()
61 }
62
63 #[inline]
65 pub fn remaining_capacity(&self) -> usize {
66 match self.max_capacity {
67 Some(max) => max.saturating_sub(self.data.len()),
68 None => usize::MAX - self.data.len(),
69 }
70 }
71
72 pub fn clear(&mut self) {
74 self.data.clear();
75 }
76
77 pub fn reserve(&mut self, additional: usize) {
79 self.data.reserve(additional);
80 }
81
82 pub fn as_slice(&self) -> &[u8] {
84 &self.data
85 }
86
87 pub fn into_inner(self) -> BytesMut {
89 self.data
90 }
91
92 pub fn freeze(self) -> bytes::Bytes {
94 self.data.freeze()
95 }
96
97 pub fn inner_mut(&mut self) -> &mut BytesMut {
99 &mut self.data
100 }
101
102 #[inline]
107 fn ensure_capacity(&self, n: usize) -> Result<()> {
108 if let Some(max) = self.max_capacity {
109 if self.data.len() + n > max {
110 return Err(Error::BufferOverflow {
111 needed: n,
112 available: max.saturating_sub(self.data.len()),
113 });
114 }
115 }
116 Ok(())
117 }
118
119 pub fn write_u8(&mut self, value: u8) -> Result<()> {
125 self.ensure_capacity(1)?;
126 self.data.put_u8(value);
127 Ok(())
128 }
129
130 pub fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
132 self.ensure_capacity(bytes.len())?;
133 self.data.put_slice(bytes);
134 Ok(())
135 }
136
137 pub fn write_zeros(&mut self, n: usize) -> Result<()> {
139 self.ensure_capacity(n)?;
140 for _ in 0..n {
141 self.data.put_u8(0);
142 }
143 Ok(())
144 }
145
146 pub fn write_u16_be(&mut self, value: u16) -> Result<()> {
152 self.ensure_capacity(2)?;
153 self.data.put_u16(value);
154 Ok(())
155 }
156
157 pub fn write_u16_le(&mut self, value: u16) -> Result<()> {
159 self.ensure_capacity(2)?;
160 self.data.put_u16_le(value);
161 Ok(())
162 }
163
164 pub fn write_u32_be(&mut self, value: u32) -> Result<()> {
166 self.ensure_capacity(4)?;
167 self.data.put_u32(value);
168 Ok(())
169 }
170
171 pub fn write_u64_be(&mut self, value: u64) -> Result<()> {
173 self.ensure_capacity(8)?;
174 self.data.put_u64(value);
175 Ok(())
176 }
177
178 #[inline]
184 pub fn write_ub1(&mut self, value: u8) -> Result<()> {
185 self.write_u8(value)
186 }
187
188 pub fn write_ub2(&mut self, value: u16) -> Result<()> {
195 match value {
196 0 => self.write_u8(0),
197 1..=255 => {
198 self.write_u8(1)?;
199 self.write_u8(value as u8)
200 }
201 _ => {
202 self.write_u8(2)?;
203 self.write_u16_be(value)
204 }
205 }
206 }
207
208 pub fn write_ub4(&mut self, value: u32) -> Result<()> {
216 match value {
217 0 => self.write_u8(0),
218 1..=255 => {
219 self.write_u8(1)?;
220 self.write_u8(value as u8)
221 }
222 256..=65535 => {
223 self.write_u8(2)?;
224 self.write_u16_be(value as u16)
225 }
226 _ => {
227 self.write_u8(4)?;
228 self.write_u32_be(value)
229 }
230 }
231 }
232
233 pub fn write_ub8(&mut self, value: u64) -> Result<()> {
242 match value {
243 0 => self.write_u8(0),
244 1..=255 => {
245 self.write_u8(1)?;
246 self.write_u8(value as u8)
247 }
248 256..=65535 => {
249 self.write_u8(2)?;
250 self.write_u16_be(value as u16)
251 }
252 65536..=4294967295 => {
253 self.write_u8(4)?;
254 self.write_u32_be(value as u32)
255 }
256 _ => {
257 self.write_u8(8)?;
258 self.write_u64_be(value)
259 }
260 }
261 }
262
263 pub fn write_bytes_with_length(&mut self, bytes: Option<&[u8]>) -> Result<()> {
271 const CHUNK_SIZE: usize = 32767;
273
274 match bytes {
275 None => self.write_u8(length::NULL_INDICATOR),
276 Some(data) => {
277 let len = data.len();
278 if len == 0 {
279 self.write_u8(0)
280 } else if len <= length::MAX_SHORT as usize {
281 self.write_u8(len as u8)?;
282 self.write_bytes(data)
283 } else {
284 self.write_u8(length::LONG_INDICATOR)?;
286 let mut offset = 0;
287 while offset < len {
288 let chunk_len = std::cmp::min(len - offset, CHUNK_SIZE);
289 self.write_ub4(chunk_len as u32)?;
290 self.write_bytes(&data[offset..offset + chunk_len])?;
291 offset += chunk_len;
292 }
293 self.write_ub4(0)
295 }
296 }
297 }
298 }
299
300 pub fn write_string_with_length(&mut self, s: Option<&str>) -> Result<()> {
302 self.write_bytes_with_length(s.map(|s| s.as_bytes()))
303 }
304
305 pub fn write_oracle_int(&mut self, value: i64) -> Result<()> {
310 if value == 0 {
311 return self.write_u8(0);
312 }
313
314 let (is_negative, abs_value) = if value < 0 {
315 (true, (-value) as u64)
316 } else {
317 (false, value as u64)
318 };
319
320 let len = ((64 - abs_value.leading_zeros() + 7) / 8) as u8;
322
323 let len_byte = if is_negative { len | 0x80 } else { len };
325 self.write_u8(len_byte)?;
326
327 for i in (0..len).rev() {
329 self.write_u8((abs_value >> (i * 8)) as u8)?;
330 }
331
332 Ok(())
333 }
334
335 pub fn write_oracle_uint(&mut self, value: u64) -> Result<()> {
337 if value == 0 {
338 return self.write_u8(0);
339 }
340
341 let len = ((64 - value.leading_zeros() + 7) / 8) as u8;
343
344 self.write_u8(len)?;
345
346 for i in (0..len).rev() {
348 self.write_u8((value >> (i * 8)) as u8)?;
349 }
350
351 Ok(())
352 }
353
354 pub fn truncate(&mut self, len: usize) {
358 self.data.truncate(len);
359 }
360
361 pub fn patch_u16_be(&mut self, pos: usize, value: u16) -> Result<()> {
366 if pos + 2 > self.data.len() {
367 return Err(Error::BufferOverflow {
368 needed: 2,
369 available: self.data.len().saturating_sub(pos),
370 });
371 }
372 let bytes = value.to_be_bytes();
373 self.data[pos] = bytes[0];
374 self.data[pos + 1] = bytes[1];
375 Ok(())
376 }
377
378 pub fn patch_u32_be(&mut self, pos: usize, value: u32) -> Result<()> {
380 if pos + 4 > self.data.len() {
381 return Err(Error::BufferOverflow {
382 needed: 4,
383 available: self.data.len().saturating_sub(pos),
384 });
385 }
386 let bytes = value.to_be_bytes();
387 self.data[pos] = bytes[0];
388 self.data[pos + 1] = bytes[1];
389 self.data[pos + 2] = bytes[2];
390 self.data[pos + 3] = bytes[3];
391 Ok(())
392 }
393}
394
395impl Default for WriteBuffer {
396 fn default() -> Self {
397 Self::new()
398 }
399}
400
401impl AsRef<[u8]> for WriteBuffer {
402 fn as_ref(&self) -> &[u8] {
403 &self.data
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `fill` against a fresh `WriteBuffer::new()` and returns a copy of
    /// the bytes it produced.
    fn written(fill: impl FnOnce(&mut WriteBuffer)) -> Vec<u8> {
        let mut buf = WriteBuffer::new();
        fill(&mut buf);
        buf.as_slice().to_vec()
    }

    // ---- fixed-width writes -------------------------------------------------

    #[test]
    fn test_write_u8() {
        assert_eq!(written(|b| b.write_u8(0x42).unwrap()), [0x42]);
    }

    #[test]
    fn test_write_bytes() {
        let out = written(|b| b.write_bytes(&[0x01, 0x02, 0x03]).unwrap());
        assert_eq!(out, [0x01, 0x02, 0x03]);
    }

    #[test]
    fn test_write_u16_be() {
        assert_eq!(written(|b| b.write_u16_be(0x0102).unwrap()), [0x01, 0x02]);
    }

    #[test]
    fn test_write_u32_be() {
        let out = written(|b| b.write_u32_be(0x01020304).unwrap());
        assert_eq!(out, [0x01, 0x02, 0x03, 0x04]);
    }

    #[test]
    fn test_write_u64_be() {
        let out = written(|b| b.write_u64_be(0x0102030405060708).unwrap());
        assert_eq!(out, [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
    }

    // ---- length-prefixed integers -------------------------------------------

    #[test]
    fn test_write_ub2_zero() {
        assert_eq!(written(|b| b.write_ub2(0).unwrap()), [0x00]);
    }

    #[test]
    fn test_write_ub2_short() {
        assert_eq!(written(|b| b.write_ub2(0x42).unwrap()), [0x01, 0x42]);
    }

    #[test]
    fn test_write_ub2_long() {
        assert_eq!(written(|b| b.write_ub2(0x0102).unwrap()), [0x02, 0x01, 0x02]);
    }

    #[test]
    fn test_write_ub4_zero() {
        assert_eq!(written(|b| b.write_ub4(0).unwrap()), [0x00]);
    }

    #[test]
    fn test_write_ub4_short() {
        assert_eq!(written(|b| b.write_ub4(0x42).unwrap()), [0x01, 0x42]);
    }

    #[test]
    fn test_write_ub4_medium() {
        assert_eq!(written(|b| b.write_ub4(0x0102).unwrap()), [0x02, 0x01, 0x02]);
    }

    #[test]
    fn test_write_ub4_long() {
        let out = written(|b| b.write_ub4(0x01020304).unwrap());
        assert_eq!(out, [0x04, 0x01, 0x02, 0x03, 0x04]);
    }

    // ---- length-framed byte strings -----------------------------------------

    #[test]
    fn test_write_bytes_with_length_null() {
        let out = written(|b| b.write_bytes_with_length(None).unwrap());
        assert_eq!(out, [0xff]);
    }

    #[test]
    fn test_write_bytes_with_length_empty() {
        let out = written(|b| b.write_bytes_with_length(Some(&[])).unwrap());
        assert_eq!(out, [0x00]);
    }

    #[test]
    fn test_write_bytes_with_length_short() {
        let out = written(|b| {
            b.write_bytes_with_length(Some(&[0x41, 0x42, 0x43])).unwrap()
        });
        assert_eq!(out, [0x03, 0x41, 0x42, 0x43]);
    }

    // ---- sign-and-magnitude integers ----------------------------------------

    #[test]
    fn test_write_oracle_int_zero() {
        assert_eq!(written(|b| b.write_oracle_int(0).unwrap()), [0x00]);
    }

    #[test]
    fn test_write_oracle_int_positive() {
        // 258 == 0x0102: two magnitude bytes.
        let out = written(|b| b.write_oracle_int(258).unwrap());
        assert_eq!(out, [0x02, 0x01, 0x02]);
    }

    #[test]
    fn test_write_oracle_int_negative() {
        // Same magnitude as above with the 0x80 sign bit set in the length byte.
        let out = written(|b| b.write_oracle_int(-258).unwrap());
        assert_eq!(out, [0x82, 0x01, 0x02]);
    }

    // ---- in-place patching --------------------------------------------------

    #[test]
    fn test_patch_u16_be() {
        let out = written(|b| {
            b.write_u16_be(0x0000).unwrap();
            b.write_u8(0x42).unwrap();
            b.patch_u16_be(0, 0x1234).unwrap();
        });
        assert_eq!(out, [0x12, 0x34, 0x42]);
    }

    #[test]
    fn test_patch_u32_be() {
        let out = written(|b| {
            b.write_u32_be(0x00000000).unwrap();
            b.write_u8(0x42).unwrap();
            b.patch_u32_be(0, 0x12345678).unwrap();
        });
        assert_eq!(out, [0x12, 0x34, 0x56, 0x78, 0x42]);
    }

    // ---- capacity limit -----------------------------------------------------

    #[test]
    fn test_max_capacity() {
        let mut buf = WriteBuffer::with_max_capacity(10, 5);
        buf.write_bytes(&[0x01, 0x02, 0x03, 0x04, 0x05]).unwrap();

        // The buffer is now exactly at its limit; one more byte must be refused.
        let one_more = buf.write_u8(0x06);
        assert!(one_more.is_err());
    }

    // ---- round trips through ReadBuffer -------------------------------------

    #[test]
    fn test_roundtrip_ub2() {
        use crate::buffer::ReadBuffer;

        const SAMPLES: [u16; 9] = [0, 1, 100, 253, 254, 255, 1000, 10000, 65535];

        for value in SAMPLES {
            let encoded = written(|b| b.write_ub2(value).unwrap());

            let mut read_buf = ReadBuffer::from_slice(encoded.as_slice());
            let read_value = read_buf.read_ub2().unwrap();

            assert_eq!(value, read_value, "UB2 roundtrip failed for {}", value);
        }
    }

    #[test]
    fn test_roundtrip_ub4() {
        use crate::buffer::ReadBuffer;

        const SAMPLES: [u32; 9] = [0, 1, 100, 253, 254, 255, 1000, 100000, 0xFFFFFFFF];

        for value in SAMPLES {
            let encoded = written(|b| b.write_ub4(value).unwrap());

            let mut read_buf = ReadBuffer::from_slice(encoded.as_slice());
            let read_value = read_buf.read_ub4().unwrap();

            assert_eq!(value, read_value, "UB4 roundtrip failed for {}", value);
        }
    }

    // ---- exact wire layout --------------------------------------------------

    #[test]
    fn test_wire_long_data_chunked_format() {
        let long_data: Vec<u8> = (0..300u16).map(|i| (i % 256) as u8).collect();
        let result = written(|b| b.write_bytes_with_length(Some(&long_data)).unwrap());

        // Layout under test: indicator | ub4(chunk length) | chunk | ub4(0).
        assert_eq!(
            result[0], 0xFE,
            "Long data must start with TNS_LONG_LENGTH_INDICATOR (0xFE)"
        );

        assert_eq!(result[1], 2, "ub4(300) prefix: 2 bytes follow");
        let chunk_len = u16::from_be_bytes([result[2], result[3]]);
        assert_eq!(
            chunk_len as usize,
            long_data.len(),
            "Chunk length must match data length"
        );

        let payload_start = 4;
        let term_pos = payload_start + long_data.len();
        assert_eq!(&result[payload_start..term_pos], &long_data[..]);

        assert_eq!(
            result[term_pos], 0x00,
            "Chunked data must end with ub4(0) terminator"
        );
        assert_eq!(
            result.len(),
            term_pos + 1,
            "Total length: 1 (0xFE) + 3 (ub4(300)) + 300 (data) + 1 (ub4(0))"
        );
    }

    #[test]
    fn test_wire_short_data_single_byte_length() {
        let short_data: Vec<u8> = (0..252u16).map(|i| (i % 256) as u8).collect();
        let result = written(|b| b.write_bytes_with_length(Some(&short_data)).unwrap());

        assert_eq!(result[0], 252, "Short data length must be single byte");

        // One length byte plus the 252 payload bytes.
        assert_eq!(result.len(), 253);
    }

    #[test]
    fn test_wire_max_short_length_is_252() {
        // First byte emitted for a payload of `size` filler bytes.
        let leading_byte = |size: usize| -> u8 {
            let payload: Vec<u8> = vec![0xAA; size];
            written(|b| b.write_bytes_with_length(Some(&payload)).unwrap())[0]
        };

        assert_eq!(leading_byte(252), 252, "252 bytes uses short format");
        assert_eq!(leading_byte(253), 0xFE, "253 bytes uses long format");
    }
}