1use std::io::{stdout, BufWriter, Read, Write};
78use std::net::{SocketAddr, TcpStream, UdpSocket};
79use std::thread::{spawn, Builder, JoinHandle};
80
81use mproxy_client::target_socket_interface;
82use mproxy_server::upstream_socket_interface;
83
/// Size, in bytes, of the scratch buffer each proxy thread reads into.
///
/// A UDP datagram longer than this does not fit in the buffer handed to
/// `recv_from` (std documents that excess bytes may be discarded), and each TCP
/// `read` yields at most this many bytes.
// NOTE(review): 8096 looks like it may have been intended as 8192 — confirm.
const BUFSIZE: usize = 8_096;
85
86pub fn forward_udp(listen_addr: String, downstream_addrs: &[String], tee: bool) -> JoinHandle<()> {
89 let (_addr, listen_socket) =
90 upstream_socket_interface(listen_addr).expect("binding server socket listener");
91 let mut output_buffer = BufWriter::new(stdout());
92 let targets: Vec<(SocketAddr, UdpSocket)> = downstream_addrs
93 .iter()
94 .map(|t| target_socket_interface(t).expect("binding client socket sender"))
95 .collect();
96 let mut buf = [0u8; BUFSIZE]; Builder::new()
98 .name(format!("{:#?}", listen_socket))
99 .spawn(move || {
100 listen_socket.set_broadcast(true).unwrap();
102 loop {
103 match listen_socket.recv_from(&mut buf[0..]) {
104 Ok((c, _remote_addr)) => {
105 for (target_addr, target_socket) in &targets {
106 if !(target_addr.is_ipv6() && target_addr.ip().is_multicast()) {
107 target_socket
108 .send_to(&buf[0..c], target_addr)
109 .unwrap_or_else(|e| panic!("sending to server socket: {}", e));
110 } else {
111 target_socket
112 .send(&buf[0..c])
113 .unwrap_or_else(|e| panic!("sending to server socket: {}", e));
114 }
115 }
116 if tee {
117 let _o = output_buffer
118 .write(&buf[0..c])
119 .expect("writing to output buffer");
120 #[cfg(debug_assertions)]
121 assert!(c == _o);
122 }
123 }
124 Err(err) => {
125 eprintln!("forward_udp: got an error: {}", err);
127 #[cfg(debug_assertions)]
128 panic!("forward_udp: got an error: {}", err);
129 }
130 }
131 output_buffer.flush().unwrap();
132 }
133 })
134 .unwrap()
135}
136
137pub fn proxy_gateway(
139 downstream_addrs: &[String],
140 listen_addrs: &[String],
141 tee: bool,
142) -> Vec<JoinHandle<()>> {
143 let mut threads: Vec<JoinHandle<()>> = vec![];
144 for listen_addr in listen_addrs {
145 #[cfg(debug_assertions)]
146 println!(
147 "proxy: forwarding {:?} -> {:?}",
148 listen_addr, downstream_addrs
149 );
150 threads.push(forward_udp(listen_addr.to_string(), downstream_addrs, tee));
151 }
152 threads
153}
154
155pub fn proxy_tcp_udp(upstream_tcp: String, downstream_udp: String) -> JoinHandle<()> {
159 let mut buf = [0u8; BUFSIZE];
160
161 #[cfg(debug_assertions)]
162 println!(
163 "proxy: forwarding TCP {:?} -> UDP {:?}",
164 upstream_tcp, downstream_udp
165 );
166
167 spawn(move || loop {
168 let target = target_socket_interface(&downstream_udp);
169
170 let (target_addr, target_socket) = if let Ok((target_addr, target_socket)) = target {
171 (target_addr, target_socket)
172 } else {
173 println!("Retrying...");
174 std::thread::sleep(std::time::Duration::from_secs(5));
175 continue;
176 };
177
178 #[cfg(feature = "tls")]
179 let (mut conn, mut stream) =
180 if let Ok((conn, stream)) = tls_connection(upstream_tcp.clone()) {
181 (conn, stream)
182 } else {
183 println!("Retrying...");
184 std::thread::sleep(std::time::Duration::from_secs(5));
185 continue;
186 };
187 #[cfg(feature = "tls")]
188 let mut stream = TlsStream::new(&mut conn, &mut stream);
189 #[cfg(not(feature = "tls"))]
190 let stream = TcpStream::connect(upstream_tcp.clone());
191 #[cfg(not(feature = "tls"))]
192 let mut stream = if let Ok(s) = stream {
193 s
194 } else {
195 println!("Retrying...");
196 std::thread::sleep(std::time::Duration::from_secs(5));
197 continue;
198 };
199
200 loop {
201 match stream.read(&mut buf[0..]) {
202 Ok(c) => {
203 if c == 0 {
204 eprintln!("encountered EOF, disconnecting TCP proxy thread...");
205 break;
206 }
207 if !(target_addr.is_ipv6() && target_addr.ip().is_multicast()) {
208 target_socket
209 .send_to(&buf[0..c], target_addr)
210 .expect("sending to UDP socket");
211 } else {
212 target_socket
213 .send(&buf[0..c])
214 .expect("sending to UDP socket");
215 }
216 }
217 Err(e) => {
218 eprintln!("err: {}", e);
219 break;
220 }
221 }
222 }
223 println!("Retrying...");
224 std::thread::sleep(std::time::Duration::from_secs(5))
225 })
226}
227
228#[cfg(feature = "tls")]
229use rustls::client::{ClientConfig, ClientConnection, ServerName};
230#[cfg(feature = "tls")]
231use rustls::Stream as TlsStream;
232#[cfg(feature = "tls")]
233use std::sync::Arc;
234#[cfg(feature = "tls")]
235use webpki_roots::TLS_SERVER_ROOTS;
236
/// Prepare a TLS client connection to `tls_connect_addr` (`"host:port"`).
///
/// Steps:
/// 1. Build a client config that trusts the `webpki_roots` server roots and
///    uses no client authentication.
/// 2. Derive the TLS server name from the host part of the address.
/// 3. Open the TCP socket and disable Nagle's algorithm on it.
///
/// No TLS records are exchanged on the socket here. The caller drives the
/// handshake and data transfer, for example by wrapping both returned values in
/// `rustls::Stream`.
///
/// # Errors
///
/// Returns an error if:
/// - the host is not a valid server name;
/// - the rustls client connection cannot be created;
/// - the TCP connect fails;
/// - setting `TCP_NODELAY` fails;
/// - buffering early data fails.
#[cfg(feature = "tls")]
pub fn tls_connection(
    tls_connect_addr: String,
) -> Result<(ClientConnection, TcpStream), Box<dyn std::error::Error>> {
    let mut root_store = rustls::RootCertStore::empty();
    root_store.add_server_trust_anchors(TLS_SERVER_ROOTS.0.iter().map(|ta| {
        rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
            ta.subject,
            ta.spki,
            ta.name_constraints,
        )
    }));
    let config = rustls::ClientConfig::builder()
        .with_safe_defaults()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    let rc_config: Arc<ClientConfig> = Arc::new(config);

    let dns_name = server_host(&tls_connect_addr);
    let server_name = ServerName::try_from(dns_name)
        .map_err(|e| format!("Invalid server name {}: {}", dns_name, e))?;
    let mut conn = rustls::ClientConnection::new(rc_config, server_name)
        .map_err(|e| format!("Creating TLS client connection: {}", e))?;
    let sock = TcpStream::connect(tls_connect_addr.as_str())
        .map_err(|e| format!("Connecting to {}: {}", tls_connect_addr, e))?;
    sock.set_nodelay(true)?;

    let request = format!(
        "GET / HTTP/1.1\r\n\
         Host: {}\r\n\
         Connection: close\r\n\
         Accept-Encoding: identity\r\n\
         \r\n",
        tls_connect_addr
    );
    // NOTE(review): `enable_early_data` is never set on this freshly built
    // config, so `early_data()` presumably always returns `None` and this request
    // is never sent — confirm whether it is still needed.
    if let Some(mut early_data) = conn.early_data() {
        early_data.write_all(request.as_bytes())?;
    }
    Ok((conn, sock))
}

/// Return the host part of a `"host:port"` address, used as the TLS server name.
///
/// A bracketed IPv6 literal such as `"[::1]:443"` yields `"::1"`. Otherwise
/// everything before the first `':'` is returned, or the whole string when
/// there is no colon.
// Only `tls_connection` calls this; keep non-TLS builds free of dead-code warnings.
#[cfg_attr(not(feature = "tls"), allow(dead_code))]
fn server_host(addr: &str) -> &str {
    if let Some(rest) = addr.strip_prefix('[') {
        if let Some((host, _)) = rest.split_once(']') {
            return host;
        }
    }
    addr.split_once(':').map_or(addr, |(host, _port)| host)
}