1use std::{
2 io::{BufReader, BufWriter, Read},
3 net::{SocketAddr, TcpStream},
4 thread::spawn,
5};
6
7use crate::serializer::Serializer;
8use nom::{IResult, Needed};
9
10use crate::{
11 parser::{parse_command, parse_reply},
12 serializer::CommandSerializer,
13 serializer::ReplySerializer,
14};
15
/// Size in bytes (16 KiB) of the fixed read buffer that each relay direction allocates.
/// A single protocol message has to fit inside this buffer to be proxied.
const BUFFER_SIZE: usize = 16 * 1024;
17
18pub fn proxy_connection(downstream: TcpStream, target: &SocketAddr) -> std::io::Result<()> {
19 let upstream = TcpStream::connect(target)?;
20
21 let (downstream_reader, downstream_writer) = tcp_stream_pair(downstream)?;
22 let (upstream_reader, upstream_writer) = tcp_stream_pair(upstream)?;
23
24 let forward = spawn(move || {
26 pipe(
27 downstream_reader,
28 parse_command,
29 |command| command,
30 CommandSerializer::new(upstream_writer),
31 )
32 });
33 let backward = spawn(move || {
34 pipe(
35 upstream_reader,
36 parse_reply,
37 |reply| reply,
38 ReplySerializer::new(downstream_writer),
39 )
40 });
41
42 forward.join().unwrap()?;
43 backward.join().unwrap()?;
44
45 Ok(())
46}
47
/// Splits a TCP stream into a buffered reader and a buffered writer over the same socket.
///
/// The reader wraps `stream` itself. The writer wraps a duplicate handle obtained with
/// `TcpStream::try_clone`.
///
/// # Errors
/// Returns the error from `try_clone` if the socket handle cannot be duplicated.
fn tcp_stream_pair(
    stream: TcpStream,
) -> std::io::Result<(BufReader<TcpStream>, BufWriter<TcpStream>)> {
    let writer = stream.try_clone().map(BufWriter::new)?;
    let reader = BufReader::new(stream);
    Ok((reader, writer))
}
54
55fn pipe<O, R, P, H, S>(mut reader: R, parser: P, hook: H, mut serializer: S) -> std::io::Result<()>
56where
57 R: Read,
58 P: Fn(&[u8]) -> IResult<&[u8], O>,
59 H: Fn(O) -> O,
60 S: Serializer<O>,
61{
62 let mut buffer = [0u8; BUFFER_SIZE];
63 let mut buffer_index = 0;
64 loop {
65 parse_stream(&mut buffer, &mut buffer_index, &mut reader, &parser)
66 .and_then(|parsed| Ok(hook(parsed)))
67 .and_then(|parsed| serializer.serialize(&parsed))?;
68 }
69}
70
71fn parse_stream<O, R, P>(
72 buffer: &mut [u8],
73 buffer_index: &mut usize,
74 reader: &mut R,
75 parser: &P,
76) -> std::io::Result<O>
77where
78 R: Read,
79 P: Fn(&[u8]) -> IResult<&[u8], O>,
80{
81 loop {
82 let read_buffer = &buffer[..*buffer_index];
83 let parse_result = parser(read_buffer);
84 match parse_result {
85 Ok((unparsed, parsed)) => {
86 let unparsed_amount = unparsed.len();
88 if unparsed_amount > 0 {
89 let unparsed_index = (unparsed.as_ptr() as usize) - (buffer.as_ptr() as usize);
91
92 let unparsed_range = unparsed_index..unparsed_index + unparsed_amount;
94
95 buffer.copy_within(unparsed_range, 0);
97 *buffer_index = unparsed_amount;
98 } else {
99 *buffer_index = 0;
101 }
102
103 return Ok(parsed);
104 }
105 Err(nom::Err::Incomplete(Needed::Size(amount))) => {
106 let amount = amount.get();
107 reader.read_exact(&mut buffer[*buffer_index..(*buffer_index + amount)])?;
108 *buffer_index += amount;
109 }
110 Err(nom::Err::Incomplete(Needed::Unknown)) => {
111 let read_amount = reader.read(&mut buffer[*buffer_index..])?;
112 if read_amount == 0 {
113 return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof));
114 }
115
116 *buffer_index += read_amount;
117 }
118 Err(error) => {
119 dbg!(error);
120 panic!();
121 }
122 };
123 }
124}