1use anyhow::Result;
2use ocular_protocol::{Protocol, parse_request, parse_response, extract_full_command, format_response_detail, parse_amqp_frame, parse_amqp_request_full, is_async_method, amqp_frame_len};
3use std::sync::Arc;
4use std::time::{Instant, SystemTime};
5use tokio::io::{AsyncReadExt, AsyncWriteExt, AsyncRead, AsyncWrite};
6use tokio::net::{TcpListener, TcpStream};
7use tokio::sync::{broadcast, Mutex};
8use tracing::{info, warn, error, debug};
9
10pub use ocular_protocol::ProxyEvent;
11
/// A request observed on the client → server direction that is still waiting
/// for the matching server reply.
struct PendingRequest {
    /// Summary of the request; reported as the event's `command`.
    command: String,
    /// Detailed form of the request; reported as the event's `full_command`.
    full_command: String,
    /// Wall-clock time the request was seen; becomes the event timestamp.
    timestamp: SystemTime,
    /// Monotonic time the request was seen; latency is measured from here.
    instant: Instant,
}
19
20pub async fn run_proxy(
21 listen_addr: String,
22 remote_addr: String,
23 name: String,
24 protocol: Protocol,
25 tx: broadcast::Sender<ProxyEvent>,
26) -> Result<()> {
27 let listener = TcpListener::bind(&listen_addr).await?;
28 info!(component = %name, listen = %listen_addr, remote = %remote_addr, ?protocol, "proxy listening");
29
30 loop {
31 let (client, peer) = listener.accept().await?;
32 debug!(component = %name, peer = %peer, "new client connection");
33 let remote = remote_addr.clone();
34 let name = name.clone();
35 let tx = tx.clone();
36 let process = resolve_peer_process(peer.port());
37 let peer_addr = peer.to_string();
38 let remote_for_conn = remote.clone();
39 tokio::spawn(async move {
40 if let Err(e) = handle_conn(client, &remote, &name, protocol, &tx, process, peer_addr, remote_for_conn).await {
41 warn!(component = %name, remote = %remote, error = %e, "connection ended with error");
42 }
43 });
44 }
45}
46
47#[allow(clippy::too_many_arguments)]
48async fn handle_conn(
49 mut client: TcpStream,
50 remote_addr: &str,
51 name: &str,
52 protocol: Protocol,
53 tx: &broadcast::Sender<ProxyEvent>,
54 process: Option<String>,
55 src: String,
56 dest: String,
57) -> Result<()> {
58 let (actual_addr, use_tls, tls_host) = if remote_addr.starts_with("https://") {
60 let stripped = remote_addr.strip_prefix("https://").unwrap();
61 let host = stripped.split(':').next().unwrap_or(stripped).to_string();
62 (stripped.to_string(), true, host)
63 } else {
64 let stripped = remote_addr.strip_prefix("http://").unwrap_or(remote_addr);
65 (stripped.to_string(), false, String::new())
66 };
67
68 let tcp_stream = match TcpStream::connect(&actual_addr).await {
69 Ok(s) => {
70 debug!(component = %name, remote = %actual_addr, "connected to remote");
71 s
72 }
73 Err(e) => {
74 error!(component = %name, remote = %actual_addr, error = %e,
75 "failed to connect to remote — is the service running?");
76 if protocol == Protocol::Redis {
77 let err_msg = format!("-ERR ocular proxy: cannot reach {} ({})\r\n", actual_addr, e);
78 let _ = client.write_all(err_msg.as_bytes()).await;
79 }
80 return Err(e.into());
81 }
82 };
83
84 let (sr, sw): (Box<dyn AsyncRead + Unpin + Send>, Box<dyn AsyncWrite + Unpin + Send>) = if use_tls {
85 let config = rustls::ClientConfig::builder()
86 .dangerous()
87 .with_custom_certificate_verifier(Arc::new(NoVerify))
88 .with_no_client_auth();
89 let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
90 let domain = rustls::pki_types::ServerName::try_from(tls_host)
91 .map_err(|e| anyhow::anyhow!("invalid TLS hostname: {}", e))?;
92 let tls_stream = connector.connect(domain, tcp_stream).await?;
93 let (r, w) = tokio::io::split(tls_stream);
94 (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
95 } else {
96 let (r, w) = tokio::io::split(tcp_stream);
97 (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
98 };
99
100 let mut sr = sr;
101 let mut sw = sw;
102
103 if protocol == Protocol::Mysql {
105 let mut greeting_buf = [0u8; 65536];
106 let n = sr.read(&mut greeting_buf).await?;
107 if n == 0 { return Ok(()); }
108 let mut greeting = greeting_buf[..n].to_vec();
109 strip_mysql_ssl_flag(&mut greeting);
110 client.write_all(&greeting).await?;
111 debug!(component = %name, "forwarded MySQL greeting with SSL stripped");
112 }
113
114 if protocol == Protocol::Postgres {
116 let mut buf = [0u8; 256];
117 let n = client.read(&mut buf).await?;
118 if n == 0 { return Ok(()); }
119 let data = &buf[..n];
120 sw.write_all(data).await?;
122 let mut resp = [0u8; 1];
124 let rn = sr.read(&mut resp).await?;
125 if rn == 0 { return Ok(()); }
126 client.write_all(&resp[..rn]).await?;
128 let command = parse_request(protocol, data).unwrap_or_else(|| "SSLRequest".into());
130 let response = if resp[0] == b'N' { "SSLResponse: No" } else { "SSLResponse: Yes" };
131 let _ = tx.send(ProxyEvent {
132 timestamp: SystemTime::now(),
133 component: name.to_string(),
134 protocol,
135 command: command.clone(),
136 full_command: command,
137 response: response.into(),
138 response_detail: response.into(),
139 latency: std::time::Duration::ZERO,
140 process: process.clone(),
141 src: Some(src.clone()),
142 dest: Some(dest.clone()),
143 });
144 }
147
148 let (mut cr, mut cw) = client.split();
149
150 let pending: Arc<Mutex<Option<PendingRequest>>> = Arc::new(Mutex::new(None));
151
152 let name_req = name.to_string();
153 let name_resp = name.to_string();
154 let tx_req = tx.clone();
155 let tx_resp = tx.clone();
156 let pending_w = pending.clone();
157 let pending_r = pending;
158 let process_info = process;
159
160 let process_req = process_info.clone();
161 let src_req = src.clone();
162 let dest_req = dest.clone();
163 let src_resp = src.clone();
164 let dest_resp = dest;
165 let client_to_server = async move {
166 let mut buf = [0u8; 65536];
167 let mut http_req_buf: Vec<u8> = Vec::new();
168 loop {
169 let n = cr.read(&mut buf).await?;
170 if n == 0 { break; }
171 let data = &buf[..n];
172
173 if protocol == Protocol::Amqp {
174 let mut pos = 0;
176 while pos < data.len() {
177 let frame_data = &data[pos..];
178 let Some(flen) = amqp_frame_len(frame_data) else { break };
179 if let Some(frame) = parse_amqp_frame(frame_data) {
180 if frame.frame_type == 8 {
182 pos += flen;
183 continue;
184 }
185 if let Some(ref method) = frame.method {
186 if is_async_method(method.class_id, method.method_id) {
187 let (summary, detail) = parse_amqp_request_full(frame_data)
188 .unwrap_or_else(|| (method.summary.clone(), method.detail.clone()));
189 let _ = tx_req.send(ProxyEvent {
190 timestamp: SystemTime::now(),
191 component: name_req.clone(),
192 protocol,
193 command: summary,
194 full_command: detail.clone(),
195 response: String::new(),
196 response_detail: detail,
197 latency: std::time::Duration::ZERO,
198 process: process_req.clone(),
199 src: Some(src_req.clone()),
200 dest: Some(dest_req.clone()),
201 });
202 } else {
203 debug!(component = %name_req, command = %method.summary);
204 *pending_w.lock().await = Some(PendingRequest {
205 timestamp: SystemTime::now(),
206 instant: Instant::now(),
207 command: method.summary.clone(),
208 full_command: method.detail.clone(),
209 });
210 }
211 }
212 }
213 pos += flen;
214 }
215 } else if protocol == Protocol::Http {
216 http_req_buf.extend_from_slice(data);
217 if ocular_protocol::http::http_request_complete(&http_req_buf) {
218 if let Some(command) = parse_request(protocol, &http_req_buf) {
219 let full_command = extract_full_command(protocol, &http_req_buf).unwrap_or_else(|| command.clone());
220 *pending_w.lock().await = Some(PendingRequest {
221 timestamp: SystemTime::now(),
222 instant: Instant::now(),
223 command,
224 full_command,
225 });
226 }
227 http_req_buf.clear();
228 }
229 } else if let Some(command) = parse_request(protocol, data) {
230 let full_command = extract_full_command(protocol, data).unwrap_or_else(|| command.clone());
231 debug!(component = %name_req, %command);
232 *pending_w.lock().await = Some(PendingRequest {
233 timestamp: SystemTime::now(),
234 instant: Instant::now(),
235 command,
236 full_command,
237 });
238 } else if protocol == Protocol::Postgres && n > 0 {
239 info!(component = %name_req, bytes = n, first_byte = format!("0x{:02x}", data[0]), "pg client→server UNPARSED");
240 }
241
242 sw.write_all(data).await?;
243 }
244 Ok::<_, anyhow::Error>(())
245 };
246
247 let process_mysql = process_info.clone();
248 let server_to_client = async move {
249 let mut buf = [0u8; 65536];
250 let mut mysql_buf: Vec<u8> = Vec::new();
251 let mut http_resp_buf: Vec<u8> = Vec::new();
252 let mut awaiting_response = false;
253 loop {
254 let n = sr.read(&mut buf).await?;
255 if n == 0 { break; }
256 let data = &buf[..n];
257 cw.write_all(data).await?;
258
259 if protocol == Protocol::Mysql {
260 let has_pending = pending_r.lock().await.is_some();
261 if has_pending || awaiting_response {
262 awaiting_response = true;
263 mysql_buf.extend_from_slice(data);
264 if mysql_response_complete(&mysql_buf) {
265 if let Some(req) = pending_r.lock().await.take() {
266 let latency = req.instant.elapsed();
267 let response = parse_response(protocol, &mysql_buf).unwrap_or_default();
268 let response_detail = format_response_detail(protocol, &mysql_buf).unwrap_or_default();
269 let _ = tx_resp.send(ProxyEvent {
270 timestamp: req.timestamp,
271 component: name_resp.clone(),
272 protocol,
273 command: req.command,
274 full_command: req.full_command,
275 response,
276 response_detail,
277 latency,
278 process: process_mysql.clone(),
279 src: Some(src_resp.clone()),
280 dest: Some(dest_resp.clone()),
281 });
282 }
283 mysql_buf.clear();
284 awaiting_response = false;
285 }
286 }
287 } else if protocol == Protocol::Http {
288 http_resp_buf.extend_from_slice(data);
289 if ocular_protocol::http::http_response_complete(&http_resp_buf) {
290 if let Some(req) = pending_r.lock().await.take() {
291 let latency = req.instant.elapsed();
292 let response = parse_response(protocol, &http_resp_buf).unwrap_or_default();
293 let response_detail = format_response_detail(protocol, &http_resp_buf).unwrap_or_else(|| response.clone());
294 let _ = tx_resp.send(ProxyEvent {
295 timestamp: req.timestamp,
296 component: name_resp.clone(),
297 protocol,
298 command: req.command,
299 full_command: req.full_command,
300 response,
301 response_detail,
302 latency,
303 process: process_info.clone(),
304 src: Some(src_resp.clone()),
305 dest: Some(dest_resp.clone()),
306 });
307 }
308 http_resp_buf.clear();
309 }
310 } else if protocol == Protocol::Amqp {
311 let mut pos = 0;
313 while pos < data.len() {
314 let frame_data = &data[pos..];
315 let Some(flen) = amqp_frame_len(frame_data) else { break };
316 if let Some(frame) = parse_amqp_frame(frame_data) {
317 if frame.frame_type == 2 || frame.frame_type == 3 {
319 pos += flen;
320 continue;
321 }
322 if frame.frame_type == 8 {
324 pos += flen;
325 continue;
326 }
327 }
328
329 let mut body_text = String::new();
331 let mut peek = pos + flen;
332 while peek < data.len() {
333 let peek_data = &data[peek..];
334 let Some(plen) = amqp_frame_len(peek_data) else { break };
335 if let Some(pf) = parse_amqp_frame(peek_data) {
336 if pf.frame_type == 2 {
337 } else if pf.frame_type == 3 {
339 if let Some(body) = &pf.body {
341 body_text = String::from_utf8_lossy(body).to_string();
342 }
343 } else {
344 break;
345 }
346 } else {
347 break;
348 }
349 peek += plen;
350 }
351
352 if let Some(req) = pending_r.lock().await.take() {
353 let latency = req.instant.elapsed();
354 let mut response = parse_response(protocol, frame_data).unwrap_or_default();
355 let mut response_detail = format_response_detail(protocol, frame_data).unwrap_or_else(|| response.clone());
356 if !body_text.is_empty() {
357 response = format!("{} | {}", response, body_text);
358 response_detail = format!("{}\nBody: {}", response_detail, body_text);
359 }
360 let _ = tx_resp.send(ProxyEvent {
361 timestamp: req.timestamp,
362 component: name_resp.clone(),
363 protocol,
364 command: req.command,
365 full_command: req.full_command,
366 response,
367 response_detail,
368 latency,
369 process: process_info.clone(),
370 src: Some(src_resp.clone()),
371 dest: Some(dest_resp.clone()),
372 });
373 } else if let Some(frame) = parse_amqp_frame(frame_data) {
374 if let Some(ref method) = frame.method {
376 let response = if body_text.is_empty() { String::new() } else { body_text.clone() };
377 let response_detail = if body_text.is_empty() { String::new() } else { body_text.clone() };
378 let command = method.summary.clone();
379 let _ = tx_resp.send(ProxyEvent {
380 timestamp: SystemTime::now(),
381 component: name_resp.clone(),
382 protocol,
383 command,
384 full_command: method.detail.clone(),
385 response,
386 response_detail,
387 latency: std::time::Duration::ZERO,
388 process: process_info.clone(),
389 src: Some(dest_resp.clone()),
390 dest: Some(src_resp.clone()),
391 });
392 }
393 }
394 pos = peek;
396 }
397 } else if protocol == Protocol::Postgres {
398 let first = data[0];
400 info!(component = %name_resp, bytes = n, first_byte = format!("0x{:02x}", first),
401 hex_head = format!("{:02x?}", &data[..n.min(20)]), "pg server→client");
402 let is_meaningful = matches!(first, b'C' | b'E' | b'T' | b'Z' | b'I' | b'D' | b'R');
404 if is_meaningful {
405 if let Some(req) = pending_r.lock().await.take() {
406 let latency = req.instant.elapsed();
407 let response = parse_response(protocol, data).unwrap_or_default();
408 let response_detail = format_response_detail(protocol, data).unwrap_or_else(|| response.clone());
409 let _ = tx_resp.send(ProxyEvent {
410 timestamp: req.timestamp,
411 component: name_resp.clone(),
412 protocol,
413 command: req.command,
414 full_command: req.full_command,
415 response,
416 response_detail,
417 latency,
418 process: process_info.clone(),
419 src: Some(src_resp.clone()),
420 dest: Some(dest_resp.clone()),
421 });
422 }
423 }
424 } else {
426 if let Some(req) = pending_r.lock().await.take() {
428 let latency = req.instant.elapsed();
429 let response = parse_response(protocol, data).unwrap_or_default();
430 let response_detail = format_response_detail(protocol, data).unwrap_or_else(|| response.clone());
431 let _ = tx_resp.send(ProxyEvent {
432 timestamp: req.timestamp,
433 component: name_resp.clone(),
434 protocol,
435 command: req.command,
436 full_command: req.full_command,
437 response,
438 response_detail,
439 latency,
440 process: process_info.clone(),
441 src: Some(src_resp.clone()),
442 dest: Some(dest_resp.clone()),
443 });
444 }
445 }
446 }
447 Ok::<_, anyhow::Error>(())
448 };
449
450 tokio::select! {
451 r = client_to_server => r?,
452 r = server_to_client => r?,
453 }
454 Ok(())
455}
456
/// Heuristically decides whether `buf` holds the complete reply to one MySQL
/// command. `buf` contains the server bytes received since the command was sent.
///
/// * A first packet starting with `0x00` (OK) or `0xff` (ERR) is a complete reply
///   once that packet has been fully received.
/// * Otherwise the reply is treated as a result set. It is complete when the
///   buffer ends exactly on a packet boundary and the last non-empty packet is
///   either an EOF/OK terminator (`0xfe`, shorter than 9 bytes) or a short
///   (`< 16` bytes) `0x00` packet.
/// * The classic protocol places a 5-byte EOF packet between the column
///   definitions and the rows. That separator is skipped, so a read that stops
///   right after the column definitions is not mistaken for the end.
fn mysql_response_complete(buf: &[u8]) -> bool {
    /// Splits one complete packet off the front of `buf`: `(payload, rest)`.
    fn split_packet(buf: &[u8]) -> Option<(&[u8], &[u8])> {
        let header = buf.get(..4)?;
        // 3-byte little-endian payload length, then a 1-byte sequence id.
        let len = usize::from(header[0]) | (usize::from(header[1]) << 8) | (usize::from(header[2]) << 16);
        let payload = buf.get(4..4 + len)?;
        Some((payload, &buf[4 + len..]))
    }

    /// Decodes the MySQL length-encoded integer at the start of `bytes`.
    fn lenenc_int(bytes: &[u8]) -> Option<u64> {
        let (&tag, rest) = bytes.split_first()?;
        let width = match tag {
            0x00..=0xfa => return Some(u64::from(tag)),
            0xfc => 2,
            0xfd => 3,
            0xfe => 8,
            // 0xfb (NULL) and 0xff are not integers.
            _ => return None,
        };
        let raw = rest.get(..width)?;
        Some(raw.iter().rev().fold(0, |acc, &b| (acc << 8) | u64::from(b)))
    }

    let Some((first, mut rest)) = split_packet(buf) else {
        return false;
    };
    if matches!(first.first().copied(), Some(0x00 | 0xff)) {
        return true;
    }

    // Presumed result set: packet 0 holds the column count N and packets 1..=N
    // are column definitions. In the classic protocol only, packet N+1 is a
    // 5-byte EOF separator.
    let separator_index = lenenc_int(first)
        .and_then(|n| usize::try_from(n).ok())
        .and_then(|n| n.checked_add(1));

    let mut last = (!first.is_empty()).then_some(first);
    let mut index = 0usize;
    while let Some((payload, tail)) = split_packet(rest) {
        index += 1;
        rest = tail;
        let is_separator =
            Some(index) == separator_index && payload.len() == 5 && payload[0] == 0xfe;
        if !payload.is_empty() && !is_separator {
            last = Some(payload);
        }
    }

    let Some(last) = last else {
        return false;
    };
    // A trailing partial packet means more of the reply is still in flight.
    rest.is_empty()
        && match last[0] {
            0xfe => last.len() < 9,
            0x00 => last.len() < 16,
            _ => false,
        }
}
479
/// Clears the SSL capability bit (0x0800) in a MySQL protocol-v10 server
/// greeting, in place.
///
/// `packet` is the raw packet including its 4-byte header. Anything that is not
/// a v10 greeting, or is too short to reach the lower capability flags, is left
/// untouched.
fn strip_mysql_ssl_flag(packet: &mut [u8]) {
    let Some(payload) = packet.get_mut(4..) else { return };
    if payload.first() != Some(&10) {
        return;
    }
    // After the version byte: NUL-terminated server version, 4-byte connection
    // id, 8 bytes of auth data, 1 filler byte, then the lower 2 capability bytes.
    let version_end = payload[1..]
        .iter()
        .position(|&b| b == 0)
        .map_or(payload.len(), |offset| offset + 1);
    let caps = version_end + 1 + 4 + 8 + 1;
    if caps + 2 > payload.len() {
        return;
    }
    // The flags are little-endian, so 0x0800 is bit 0x08 of the second byte.
    payload[caps + 1] &= !0x08;
}
496
/// Best-effort lookup of the local process that owns a client TCP socket.
///
/// `port` is the client-side port of the accepted connection. Returns
/// `"[<pid>] <command>"`, or `None` when the helper tool cannot be run or
/// reports no match. Spawns an external command (`lsof` on macOS, `ss`
/// elsewhere) and waits for it, so it blocks the calling thread.
fn resolve_peer_process(port: u16) -> Option<String> {
    use std::process::Command;

    let own_pid = std::process::id().to_string();

    if !cfg!(target_os = "macos") {
        // `ss -p` tags a socket with e.g. `users:(("curl",pid=4242,fd=3))`.
        const USERS_MARKER: &str = "users:((\"";
        let filter = format!("sport = :{}", port);
        let output = Command::new("ss").args(["-tnp", &filter]).output().ok()?;
        let listing = String::from_utf8_lossy(&output.stdout);
        return listing.lines().find_map(|line| {
            let after = &line[line.find(USERS_MARKER)? + USERS_MARKER.len()..];
            let proc_name = &after[..after.find('"')?];
            let pid = after
                .find("pid=")
                .and_then(|at| after[at + 4..].split(|c: char| !c.is_ascii_digit()).next())
                .unwrap_or("?");
            Some(format!("[{}] {}", pid, proc_name))
        });
    }

    let output = Command::new("lsof")
        .args(["-i", &format!("tcp:{}", port), "-sTCP:ESTABLISHED", "-Fp", "-Fc"])
        .output()
        .ok()?;
    let listing = String::from_utf8_lossy(&output.stdout);

    // `lsof -F` prints one field per line: `p<pid>` starts a process record and
    // `c<command>` names it. Records for our own pid are skipped (presumably
    // lsof also lists this proxy's end of the connection).
    let describe = |pid: &str, cmd: &str| {
        (!pid.is_empty() && pid != own_pid.as_str()).then(|| format!("[{}] {}", pid, cmd))
    };
    let mut pid = String::new();
    let mut cmd = String::new();
    for line in listing.lines() {
        if let Some(next_pid) = line.strip_prefix('p') {
            if let Some(found) = describe(&pid, &cmd) {
                return Some(found);
            }
            pid = next_pid.to_string();
            cmd.clear();
        } else if let Some(next_cmd) = line.strip_prefix('c') {
            cmd = next_cmd.to_string();
        }
    }
    describe(&pid, &cmd)
}
553
554#[derive(Debug)]
556struct NoVerify;
557
558impl rustls::client::danger::ServerCertVerifier for NoVerify {
559 fn verify_server_cert(
560 &self, _: &rustls::pki_types::CertificateDer<'_>, _: &[rustls::pki_types::CertificateDer<'_>],
561 _: &rustls::pki_types::ServerName<'_>, _: &[u8], _: rustls::pki_types::UnixTime,
562 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
563 Ok(rustls::client::danger::ServerCertVerified::assertion())
564 }
565 fn verify_tls12_signature(
566 &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
567 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
568 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
569 }
570 fn verify_tls13_signature(
571 &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
572 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
573 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
574 }
575 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
576 vec![
577 rustls::SignatureScheme::RSA_PKCS1_SHA256,
578 rustls::SignatureScheme::RSA_PKCS1_SHA384,
579 rustls::SignatureScheme::RSA_PKCS1_SHA512,
580 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
581 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
582 rustls::SignatureScheme::RSA_PSS_SHA256,
583 rustls::SignatureScheme::RSA_PSS_SHA384,
584 rustls::SignatureScheme::RSA_PSS_SHA512,
585 rustls::SignatureScheme::ED25519,
586 ]
587 }
588}