1use anyhow::Result;
2use ocular_protocol::{Protocol, mysql::mysql_response_complete, postgres::postgres_response_complete, parse_request, parse_response, extract_full_command, format_response_detail, parse_amqp_frame, parse_amqp_request_full, is_async_method, amqp_frame_len};
3use std::sync::{Arc, Mutex};
4use std::time::{Duration, Instant, SystemTime};
5use std::sync::atomic::{AtomicUsize, Ordering};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, AsyncRead, AsyncWrite};
7use tokio::net::{TcpListener, TcpStream};
8use tokio::sync::broadcast;
9use tracing::{info, warn, error, debug};
10
11pub use ocular_protocol::ProxyEvent;
12
/// Live status of one named proxy component.
#[derive(Clone, Debug, Default)]
pub struct ConnectionState {
    /// Number of client connections currently being relayed.
    pub active_connections: usize,
    /// Set to `true` once the component's listener has been bound.
    pub has_connector: bool,
    /// Most recent error recorded by the proxy (bind failure or a connection error), if any.
    pub last_error: Option<String>,
    /// Wall-clock time at which the most recent client connection started being handled.
    pub last_active_at: Option<SystemTime>,
}

/// Shared map from component name to its [`ConnectionState`].
pub type StatusMap = Arc<Mutex<std::collections::HashMap<String, ConnectionState>>>;
24
/// A request seen on the client-to-server direction that is waiting to be paired with the
/// server's response.
struct PendingRequest {
    /// Summary text produced by `parse_request`.
    command: String,
    /// Full request text from `extract_full_command` (falls back to `command`).
    full_command: String,
    /// Wall-clock time the request was observed; becomes the emitted event's timestamp.
    timestamp: SystemTime,
    /// Monotonic reading used to compute the request-to-response latency.
    instant: Instant,
}
32
33pub async fn run_proxy(
34 listen_addr: String,
35 remote_addr: String,
36 name: String,
37 protocol: Protocol,
38 tx: broadcast::Sender<ProxyEvent>,
39 mut shutdown: tokio::sync::watch::Receiver<bool>,
40 status: StatusMap,
41) -> Result<()> {
42 let listener = match TcpListener::bind(&listen_addr).await {
43 Ok(l) => l,
44 Err(e) => {
45 let msg = format!("bind failed on {}: {}", listen_addr, e);
46 let _ = tx.send(ProxyEvent::system_event(&name, msg));
47 status.lock().unwrap().entry(name.clone()).or_default().last_error = Some(format!("bind failed: {}", e));
48 return Err(e.into());
49 }
50 };
51 let conn_count = Arc::new(AtomicUsize::new(0));
52 {
53 let mut map = status.lock().unwrap();
54 map.entry(name.clone()).or_default().has_connector = true;
55 }
56 info!(component = %name, listen = %listen_addr, remote = %remote_addr, ?protocol, "proxy listening");
57
58 loop {
59 tokio::select! {
60 result = listener.accept() => {
61 let (client, peer) = result?;
62 debug!(component = %name, peer = %peer, "new client connection");
63 let remote = remote_addr.clone();
64 let name = name.clone();
65 let tx = tx.clone();
66 let process = resolve_peer_process(peer.port());
67 let peer_addr = peer.to_string();
68 let remote_for_conn = remote.clone();
69 let conn_count = conn_count.clone();
70 let status = status.clone();
71 let protocol_for_conn = protocol;
72 tokio::spawn(async move {
73 conn_count.fetch_add(1, Ordering::Relaxed);
74 {
75 let mut map = status.lock().unwrap();
76 let s = map.entry(name.clone()).or_default();
77 s.active_connections = conn_count.load(Ordering::Relaxed);
78 s.last_active_at = Some(SystemTime::now());
79 }
80 if let Err(e) = handle_conn(client, &remote, &name, protocol_for_conn, &tx, process, peer_addr, remote_for_conn).await {
81 warn!(component = %name, remote = %remote, error = %e, "connection ended with error");
82 let _ = tx.send(ProxyEvent::system_event(&name, format!("connection error: {}", e)));
83 status.lock().unwrap().entry(name.clone()).or_default().last_error = Some(e.to_string());
84 }
85 let remaining = conn_count.fetch_sub(1, Ordering::Relaxed).saturating_sub(1);
86 status.lock().unwrap().entry(name.clone()).or_default().active_connections = remaining;
87 });
88 }
89 _ = shutdown.changed() => {
90 info!(component = %name, "proxy shutting down");
91 break;
92 }
93 }
94 }
95 Ok(())
96}
97
98#[allow(clippy::too_many_arguments)]
99async fn handle_conn(
100 mut client: TcpStream,
101 remote_addr: &str,
102 name: &str,
103 protocol: Protocol,
104 tx: &broadcast::Sender<ProxyEvent>,
105 process: Option<String>,
106 src: String,
107 dest: String,
108) -> Result<()> {
109 let (actual_addr, use_tls, tls_host) = if remote_addr.starts_with("https://") {
111 let stripped = remote_addr.strip_prefix("https://").unwrap();
112 let host = stripped.split(':').next().unwrap_or(stripped).to_string();
113 (stripped.to_string(), true, host)
114 } else {
115 let stripped = remote_addr.strip_prefix("http://").unwrap_or(remote_addr);
116 (stripped.to_string(), false, String::new())
117 };
118
119 let tcp_stream = match TcpStream::connect(&actual_addr).await {
120 Ok(s) => {
121 debug!(component = %name, remote = %actual_addr, "connected to remote");
122 s
123 }
124 Err(e) => {
125 error!(component = %name, remote = %actual_addr, error = %e,
126 "failed to connect to remote — is the service running?");
127 let _ = tx.send(ProxyEvent::system_event(name, format!("cannot reach {} ({})", actual_addr, e)));
128 if protocol == Protocol::Redis {
129 let err_msg = format!("-ERR ocular proxy: cannot reach {} ({})\r\n", actual_addr, e);
130 let _ = client.write_all(err_msg.as_bytes()).await;
131 }
132 return Err(e.into());
133 }
134 };
135
136 let (sr, sw): (Box<dyn AsyncRead + Unpin + Send>, Box<dyn AsyncWrite + Unpin + Send>) = if use_tls {
137 let config = rustls::ClientConfig::builder()
138 .dangerous()
139 .with_custom_certificate_verifier(Arc::new(NoVerify))
140 .with_no_client_auth();
141 let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
142 let domain = rustls::pki_types::ServerName::try_from(tls_host)
143 .map_err(|e| anyhow::anyhow!("invalid TLS hostname: {}", e))?;
144 let tls_stream = connector.connect(domain, tcp_stream).await?;
145 let (r, w) = tokio::io::split(tls_stream);
146 (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
147 } else {
148 let (r, w) = tokio::io::split(tcp_stream);
149 (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
150 };
151
152 let mut sr = sr;
153 let mut sw = sw;
154
155 if protocol == Protocol::Mysql {
157 let mut greeting_buf = [0u8; 65536];
158 let n = sr.read(&mut greeting_buf).await?;
159 if n == 0 { return Ok(()); }
160 let mut greeting = greeting_buf[..n].to_vec();
161 strip_mysql_ssl_flag(&mut greeting);
162 client.write_all(&greeting).await?;
163 debug!(component = %name, "forwarded MySQL greeting with SSL stripped");
164 }
165
166 if protocol == Protocol::Postgres {
169 let mut buf = [0u8; 256];
170 let n = client.read(&mut buf).await?;
171 if n == 0 { return Ok(()); }
172 let data = &buf[..n];
173 let neg_code = if n >= 8 {
174 u32::from_be_bytes([data[4], data[5], data[6], data[7]])
175 } else { 0 };
176 if neg_code == 80877103 || neg_code == 80877104 {
177 sw.write_all(data).await?;
179 let mut resp = [0u8; 1];
181 let rn = sr.read(&mut resp).await?;
182 if rn == 0 { return Ok(()); }
183 client.write_all(b"N").await?;
185 } else {
186 sw.write_all(data).await?;
188 }
189 }
190
191 let (mut cr, mut cw) = client.split();
192
193 let pending: Arc<Mutex<Option<PendingRequest>>> = Arc::new(Mutex::new(None));
194
195 let name_req = name.to_string();
196 let name_resp = name.to_string();
197 let tx_req = tx.clone();
198 let tx_resp = tx.clone();
199 let pending_w = pending.clone();
200 let pending_final = pending.clone();
201 let pending_r = pending;
202 let process_info = process;
203
204 let process_req = process_info.clone();
205 let src_req = src.clone();
206 let dest_req = dest.clone();
207 let src_resp = src.clone();
208 let dest_resp = dest;
209 let client_to_server = async move {
210 let mut buf = [0u8; 65536];
211 let mut http_req_buf: Vec<u8> = Vec::with_capacity(4096);
212 let mut memcached_req_buf: Vec<u8> = Vec::with_capacity(4096);
213 let mut kafka_req_buf: Vec<u8> = Vec::with_capacity(4096);
214 loop {
215 let n = cr.read(&mut buf).await?;
216 if n == 0 { break; }
217 let data = &buf[..n];
218
219 if protocol == Protocol::Amqp {
220 let mut pos = 0;
222 while pos < data.len() {
223 let frame_data = &data[pos..];
224 let Some(flen) = amqp_frame_len(frame_data) else { break };
225 if let Some(frame) = parse_amqp_frame(frame_data) {
226 if frame.frame_type == 8 {
228 pos += flen;
229 continue;
230 }
231 if let Some(ref method) = frame.method {
232 if is_async_method(method.class_id, method.method_id) {
233 let (summary, detail) = parse_amqp_request_full(frame_data)
234 .unwrap_or_else(|| (method.summary.clone(), method.detail.clone()));
235 let _ = tx_req.send(ProxyEvent {
236 timestamp: SystemTime::now(),
237 component: name_req.clone(),
238 protocol,
239 command: summary,
240 full_command: detail.clone(),
241 response: String::new(),
242 response_detail: detail,
243 latency: std::time::Duration::ZERO,
244 process: process_req.clone(),
245 src: Some(src_req.clone()),
246 dest: Some(dest_req.clone()),
247 system: false,
248 });
249 } else {
250 debug!(component = %name_req, command = %method.summary);
251 *pending_w.lock().unwrap() = Some(PendingRequest {
252 timestamp: SystemTime::now(),
253 instant: Instant::now(),
254 command: method.summary.clone(),
255 full_command: method.detail.clone(),
256 });
257 }
258 }
259 }
260 pos += flen;
261 }
262 } else if protocol == Protocol::Http {
263 http_req_buf.extend_from_slice(data);
264 if ocular_protocol::http::http_request_complete(&http_req_buf) {
265 if let Some(command) = parse_request(protocol, &http_req_buf) {
266 let full_command = extract_full_command(protocol, &http_req_buf).unwrap_or_else(|| command.clone());
267 *pending_w.lock().unwrap() = Some(PendingRequest {
268 timestamp: SystemTime::now(),
269 instant: Instant::now(),
270 command,
271 full_command,
272 });
273 }
274 http_req_buf.clear();
275 }
276 } else if protocol == Protocol::Memcached {
277 memcached_req_buf.extend_from_slice(data);
278 while ocular_protocol::memcached::memcached_request_complete(&memcached_req_buf) {
279 if let Some(prev) = pending_w.lock().unwrap().take() {
282 let _ = tx_req.send(ProxyEvent {
283 timestamp: prev.timestamp,
284 component: name_req.clone(),
285 protocol,
286 command: prev.command,
287 full_command: prev.full_command,
288 response: String::new(),
289 response_detail: String::new(),
290 latency: Duration::ZERO,
291 process: process_req.clone(),
292 src: Some(src_req.clone()),
293 dest: Some(dest_req.clone()),
294 system: false,
295 });
296 }
297 if let Some(command) = parse_request(protocol, &memcached_req_buf) {
298 let full_command = extract_full_command(protocol, &memcached_req_buf).unwrap_or_else(|| command.clone());
299 *pending_w.lock().unwrap() = Some(PendingRequest {
300 timestamp: SystemTime::now(),
301 instant: Instant::now(),
302 command,
303 full_command,
304 });
305 }
306 let s = std::str::from_utf8(&memcached_req_buf).unwrap_or("");
308 let first_crlf = s.find("\r\n").unwrap_or(0);
309 let line = &s[..first_crlf];
310 let parts: Vec<&str> = line.split_whitespace().collect();
311 let cmd = parts.first().map(|c| c.to_uppercase()).unwrap_or_default();
312 let consumed = match cmd.as_str() {
313 "SET" | "ADD" | "REPLACE" | "APPEND" | "PREPEND" | "CAS" => {
314 let bytes: usize = parts.get(4).and_then(|b| b.parse().ok()).unwrap_or(0);
315 first_crlf + 2 + bytes + 2
316 }
317 _ => first_crlf + 2,
318 };
319 memcached_req_buf = memcached_req_buf[consumed..].to_vec();
320 }
321 } else if protocol == Protocol::Kafka {
322 kafka_req_buf.extend_from_slice(data);
323 while ocular_protocol::kafka::kafka_frame_complete(&kafka_req_buf) {
324 let frame_len = i32::from_be_bytes([kafka_req_buf[0], kafka_req_buf[1], kafka_req_buf[2], kafka_req_buf[3]]) as usize + 4;
325 let frame = &kafka_req_buf[..frame_len];
326 if let Some(command) = parse_request(protocol, frame) {
327 let full_command = extract_full_command(protocol, frame).unwrap_or_else(|| command.clone());
328 *pending_w.lock().unwrap() = Some(PendingRequest {
329 timestamp: SystemTime::now(),
330 instant: Instant::now(),
331 command,
332 full_command,
333 });
334 }
335 kafka_req_buf = kafka_req_buf[frame_len..].to_vec();
336 }
337 } else if protocol == Protocol::Postgres {
338 let mut pos = 0;
340 while pos < data.len() {
341 let first = data[pos];
342 let is_typed = matches!(first, b'Q' | b'P' | b'B' | b'E' | b'D' | b'S' | b'X' | b'C' | b'p' | b'H' | b'F' | b'd' | b'c' | b'f');
343 if !is_typed { break; }
344 if pos + 5 > data.len() { break; }
345 let len = u32::from_be_bytes([data[pos+1], data[pos+2], data[pos+3], data[pos+4]]) as usize;
346 let end = pos + 1 + len;
347 if end > data.len() { break; }
348 if first == b'Q' || first == b'P' {
350 let msg = &data[pos..end];
351 if let Some(command) = parse_request(protocol, msg) {
352 let full_command = extract_full_command(protocol, msg).unwrap_or_else(|| command.clone());
353 *pending_w.lock().unwrap() = Some(PendingRequest {
354 timestamp: SystemTime::now(),
355 instant: Instant::now(),
356 command,
357 full_command,
358 });
359 }
360 }
361 pos = end;
362 }
363 } else if let Some(command) = parse_request(protocol, data) {
364 let full_command = extract_full_command(protocol, data).unwrap_or_else(|| command.clone());
365 debug!(component = %name_req, %command);
366 *pending_w.lock().unwrap() = Some(PendingRequest {
367 timestamp: SystemTime::now(),
368 instant: Instant::now(),
369 command,
370 full_command,
371 });
372 }
373
374 sw.write_all(data).await?;
375 }
376 Ok::<_, anyhow::Error>(())
377 };
378
379 let process_mysql = process_info.clone();
380 let server_to_client = async move {
381 let mut buf = [0u8; 65536];
382 let mut mysql_buf: Vec<u8> = Vec::with_capacity(4096);
383 let mut http_resp_buf: Vec<u8> = Vec::with_capacity(4096);
384 let mut memcached_resp_buf: Vec<u8> = Vec::with_capacity(4096);
385 let mut kafka_resp_buf: Vec<u8> = Vec::with_capacity(4096);
386 let mut pg_resp_buf: Vec<u8> = Vec::with_capacity(4096);
387 let mut awaiting_response = false;
388 let mut memcached_awaiting = false;
389 loop {
390 let n = sr.read(&mut buf).await?;
391 if n == 0 { break; }
392 let data = &buf[..n];
393 cw.write_all(data).await?;
394
395 if protocol == Protocol::Mysql {
396 let has_pending = pending_r.lock().unwrap().is_some();
397 if has_pending || awaiting_response {
398 awaiting_response = true;
399 mysql_buf.extend_from_slice(data);
400 if mysql_response_complete(&mysql_buf) {
401 if let Some(req) = pending_r.lock().unwrap().take() {
402 let latency = req.instant.elapsed();
403 let response = parse_response(protocol, &mysql_buf).unwrap_or_default();
404 let response_detail = format_response_detail(protocol, &mysql_buf).unwrap_or_default();
405 let _ = tx_resp.send(ProxyEvent {
406 timestamp: req.timestamp,
407 component: name_resp.clone(),
408 protocol,
409 command: req.command,
410 full_command: req.full_command,
411 response,
412 response_detail,
413 latency,
414 process: process_mysql.clone(),
415 src: Some(src_resp.clone()),
416 dest: Some(dest_resp.clone()),
417 system: false,
418 });
419 }
420 mysql_buf.clear();
421 awaiting_response = false;
422 }
423 }
424 } else if protocol == Protocol::Http {
425 http_resp_buf.extend_from_slice(data);
426 if ocular_protocol::http::http_response_complete(&http_resp_buf) {
427 if let Some(req) = pending_r.lock().unwrap().take() {
428 let latency = req.instant.elapsed();
429 let response = parse_response(protocol, &http_resp_buf).unwrap_or_default();
430 let response_detail = format_response_detail(protocol, &http_resp_buf).unwrap_or_else(|| response.clone());
431 let _ = tx_resp.send(ProxyEvent {
432 timestamp: req.timestamp,
433 component: name_resp.clone(),
434 protocol,
435 command: req.command,
436 full_command: req.full_command,
437 response,
438 response_detail,
439 latency,
440 process: process_info.clone(),
441 src: Some(src_resp.clone()),
442 dest: Some(dest_resp.clone()),
443 system: false,
444 });
445 }
446 http_resp_buf.clear();
447 }
448 } else if protocol == Protocol::Amqp {
449 let mut pos = 0;
451 while pos < data.len() {
452 let frame_data = &data[pos..];
453 let Some(flen) = amqp_frame_len(frame_data) else { break };
454 if let Some(frame) = parse_amqp_frame(frame_data) {
455 if frame.frame_type == 2 || frame.frame_type == 3 {
457 pos += flen;
458 continue;
459 }
460 if frame.frame_type == 8 {
462 pos += flen;
463 continue;
464 }
465 }
466
467 let mut body_text = String::new();
469 let mut peek = pos + flen;
470 while peek < data.len() {
471 let peek_data = &data[peek..];
472 let Some(plen) = amqp_frame_len(peek_data) else { break };
473 if let Some(pf) = parse_amqp_frame(peek_data) {
474 if pf.frame_type == 2 {
475 } else if pf.frame_type == 3 {
477 if let Some(body) = &pf.body {
479 body_text = String::from_utf8_lossy(body).to_string();
480 }
481 } else {
482 break;
483 }
484 } else {
485 break;
486 }
487 peek += plen;
488 }
489
490 if let Some(req) = pending_r.lock().unwrap().take() {
491 let latency = req.instant.elapsed();
492 let mut response = parse_response(protocol, frame_data).unwrap_or_default();
493 let mut response_detail = format_response_detail(protocol, frame_data).unwrap_or_else(|| response.clone());
494 if !body_text.is_empty() {
495 response = format!("{} | {}", response, body_text);
496 response_detail = format!("{}\nBody: {}", response_detail, body_text);
497 }
498 let _ = tx_resp.send(ProxyEvent {
499 timestamp: req.timestamp,
500 component: name_resp.clone(),
501 protocol,
502 command: req.command,
503 full_command: req.full_command,
504 response,
505 response_detail,
506 latency,
507 process: process_info.clone(),
508 src: Some(src_resp.clone()),
509 dest: Some(dest_resp.clone()),
510 system: false,
511 });
512 } else if let Some(frame) = parse_amqp_frame(frame_data) {
513 if let Some(ref method) = frame.method {
515 let response = if body_text.is_empty() { String::new() } else { body_text.clone() };
516 let response_detail = if body_text.is_empty() { String::new() } else { body_text.clone() };
517 let command = method.summary.clone();
518 let _ = tx_resp.send(ProxyEvent {
519 timestamp: SystemTime::now(),
520 component: name_resp.clone(),
521 protocol,
522 command,
523 full_command: method.detail.clone(),
524 response,
525 response_detail,
526 latency: std::time::Duration::ZERO,
527 process: process_info.clone(),
528 src: Some(dest_resp.clone()),
529 dest: Some(src_resp.clone()),
530 system: false,
531 });
532 }
533 }
534 pos = peek;
536 }
537 } else if protocol == Protocol::Postgres {
538 pg_resp_buf.extend_from_slice(data);
540 if postgres_response_complete(&pg_resp_buf) {
541 if let Some(req) = pending_r.lock().unwrap().take() {
542 let latency = req.instant.elapsed();
543 let response = parse_response(protocol, &pg_resp_buf).unwrap_or_default();
544 let response_detail = format_response_detail(protocol, &pg_resp_buf).unwrap_or_else(|| response.clone());
545 let _ = tx_resp.send(ProxyEvent {
546 timestamp: req.timestamp,
547 component: name_resp.clone(),
548 protocol,
549 command: req.command,
550 full_command: req.full_command,
551 response,
552 response_detail,
553 latency,
554 process: process_info.clone(),
555 src: Some(src_resp.clone()),
556 dest: Some(dest_resp.clone()),
557 system: false,
558 });
559 }
560 pg_resp_buf.clear();
561 }
562 } else if protocol == Protocol::Memcached {
563 let has_pending = pending_r.lock().unwrap().is_some();
564 if has_pending || memcached_awaiting {
565 memcached_awaiting = true;
566 memcached_resp_buf.extend_from_slice(data);
567 if ocular_protocol::memcached::memcached_response_complete(&memcached_resp_buf) {
568 if let Some(req) = pending_r.lock().unwrap().take() {
569 let latency = req.instant.elapsed();
570 let response = parse_response(protocol, &memcached_resp_buf).unwrap_or_default();
571 let response_detail = format_response_detail(protocol, &memcached_resp_buf).unwrap_or_else(|| response.clone());
572 let _ = tx_resp.send(ProxyEvent {
573 timestamp: req.timestamp,
574 component: name_resp.clone(),
575 protocol,
576 command: req.command,
577 full_command: req.full_command,
578 response,
579 response_detail,
580 latency,
581 process: process_info.clone(),
582 src: Some(src_resp.clone()),
583 dest: Some(dest_resp.clone()),
584 system: false,
585 });
586 }
587 memcached_resp_buf.clear();
588 memcached_awaiting = false;
589 }
590 }
591 } else if protocol == Protocol::Kafka {
592 kafka_resp_buf.extend_from_slice(data);
593 while ocular_protocol::kafka::kafka_frame_complete(&kafka_resp_buf) {
594 let frame_len = i32::from_be_bytes([kafka_resp_buf[0], kafka_resp_buf[1], kafka_resp_buf[2], kafka_resp_buf[3]]) as usize + 4;
595 if let Some(req) = pending_r.lock().unwrap().take() {
596 let latency = req.instant.elapsed();
597 let response = parse_response(protocol, &kafka_resp_buf[..frame_len]).unwrap_or_default();
598 let response_detail = format_response_detail(protocol, &kafka_resp_buf[..frame_len]).unwrap_or_else(|| response.clone());
599 let _ = tx_resp.send(ProxyEvent {
600 timestamp: req.timestamp,
601 component: name_resp.clone(),
602 protocol,
603 command: req.command,
604 full_command: req.full_command,
605 response,
606 response_detail,
607 latency,
608 process: process_info.clone(),
609 src: Some(src_resp.clone()),
610 dest: Some(dest_resp.clone()),
611 system: false,
612 });
613 }
614 kafka_resp_buf = kafka_resp_buf[frame_len..].to_vec();
615 }
616 } else {
617 if let Some(req) = pending_r.lock().unwrap().take() {
619 let latency = req.instant.elapsed();
620 let response = parse_response(protocol, data).unwrap_or_default();
621 let response_detail = format_response_detail(protocol, data).unwrap_or_else(|| response.clone());
622 let _ = tx_resp.send(ProxyEvent {
623 timestamp: req.timestamp,
624 component: name_resp.clone(),
625 protocol,
626 command: req.command,
627 full_command: req.full_command,
628 response,
629 response_detail,
630 latency,
631 process: process_info.clone(),
632 src: Some(src_resp.clone()),
633 dest: Some(dest_resp.clone()),
634 system: false,
635 });
636 }
637 }
638 }
639 Ok::<_, anyhow::Error>(())
640 };
641
642 tokio::pin!(client_to_server);
643 tokio::pin!(server_to_client);
644
645 tokio::select! {
646 r = &mut client_to_server => {
647 if r.is_ok() && pending_final.lock().unwrap().is_some() {
649 let _ = tokio::time::timeout(
650 Duration::from_millis(500),
651 &mut server_to_client,
652 ).await;
653 }
654 },
655 r = &mut server_to_client => r?,
656 }
657 Ok(())
658}
659
/// Clears the `CLIENT_SSL` capability bit (0x0800) in a MySQL protocol-v10 server greeting,
/// in place.
///
/// `packet` is the raw packet including its 4-byte header. The packet is left untouched when
/// any of these holds:
/// - it is shorter than a header plus one payload byte;
/// - its payload does not start with protocol version 10;
/// - it is truncated before the lower capability flags.
fn strip_mysql_ssl_flag(packet: &mut [u8]) {
    const HEADER_LEN: usize = 4;
    const CLIENT_SSL: u16 = 0x0800;
    // Fixed-width fields between the server-version NUL and the lower capability flags:
    // connection id (4) + auth-plugin-data part 1 (8) + filler (1).
    const FIXED_FIELDS: usize = 4 + 8 + 1;

    let Some(payload) = packet.get_mut(HEADER_LEN..) else { return };
    if payload.first() != Some(&10) {
        return;
    }

    // Index of the NUL ending the server-version string, or the payload end if there is none.
    let version_end = payload[1..]
        .iter()
        .position(|&b| b == 0)
        .map_or(payload.len(), |i| i + 1);
    let caps_at = version_end + 1 + FIXED_FIELDS;
    if caps_at + 2 > payload.len() {
        return;
    }

    let caps = u16::from_le_bytes([payload[caps_at], payload[caps_at + 1]]) & !CLIENT_SSL;
    payload[caps_at..caps_at + 2].copy_from_slice(&caps.to_le_bytes());
}
676
/// Best-effort lookup of the local process that owns a TCP connection whose source port is
/// `port`.
///
/// The result is formatted as `"[pid] command"`. On macOS this runs `lsof`; on other systems
/// it runs `ss`. Returns `None` if any of these holds:
/// - the tool cannot be spawned;
/// - nothing matches;
/// - the only match is this process itself.
///
/// This spawns a subprocess and blocks until it exits. From async code, call it via a blocking
/// context such as `tokio::task::spawn_blocking`.
fn resolve_peer_process(port: u16) -> Option<String> {
    use std::process::Command;
    let my_pid = std::process::id().to_string();

    if cfg!(target_os = "macos") {
        // `-n` / `-P` skip host-name and port-name conversion, which otherwise slows every
        // lookup; numeric output is all we parse.
        let output = Command::new("lsof")
            .args(["-n", "-P", "-i", &format!("tcp:{}", port), "-sTCP:ESTABLISHED", "-Fp", "-Fc"])
            .output()
            .ok()?;
        let text = String::from_utf8_lossy(&output.stdout);
        // `-F` output: a `p<pid>` line starts each process record; a `c<command>` line
        // carries its command name.
        let mut current_pid = String::new();
        let mut current_cmd = String::new();
        for line in text.lines() {
            if let Some(p) = line.strip_prefix('p') {
                if !current_pid.is_empty() && current_pid != my_pid {
                    return Some(format!("[{}] {}", current_pid, current_cmd));
                }
                current_pid = p.to_string();
                current_cmd.clear();
            }
            if let Some(c) = line.strip_prefix('c') {
                current_cmd = c.to_string();
            }
        }
        if !current_pid.is_empty() && current_pid != my_pid {
            return Some(format!("[{}] {}", current_pid, current_cmd));
        }
        None
    } else {
        let output = Command::new("ss")
            .args(["-tnp", &format!("sport = :{}", port)])
            .output()
            .ok()?;
        let text = String::from_utf8_lossy(&output.stdout);
        for line in text.lines() {
            // The process column looks like: users:(("name",pid=1234,fd=5))
            let Some(start) = line.find("users:((\"") else { continue };
            let rest = &line[start + 9..];
            let Some(end) = rest.find('"') else { continue };
            let proc_name = &rest[..end];
            let pid = rest
                .find("pid=")
                .and_then(|i| rest[i + 4..].split(|c: char| !c.is_ascii_digit()).next())
                .unwrap_or("?");
            // Skip our own sockets, as the macOS branch does.
            if pid == my_pid {
                continue;
            }
            return Some(format!("[{}] {}", pid, proc_name));
        }
        None
    }
}
733
734#[derive(Debug)]
736struct NoVerify;
737
738impl rustls::client::danger::ServerCertVerifier for NoVerify {
739 fn verify_server_cert(
740 &self, _: &rustls::pki_types::CertificateDer<'_>, _: &[rustls::pki_types::CertificateDer<'_>],
741 _: &rustls::pki_types::ServerName<'_>, _: &[u8], _: rustls::pki_types::UnixTime,
742 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
743 Ok(rustls::client::danger::ServerCertVerified::assertion())
744 }
745 fn verify_tls12_signature(
746 &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
747 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
748 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
749 }
750 fn verify_tls13_signature(
751 &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
752 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
753 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
754 }
755 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
756 vec![
757 rustls::SignatureScheme::RSA_PKCS1_SHA256,
758 rustls::SignatureScheme::RSA_PKCS1_SHA384,
759 rustls::SignatureScheme::RSA_PKCS1_SHA512,
760 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
761 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
762 rustls::SignatureScheme::RSA_PSS_SHA256,
763 rustls::SignatureScheme::RSA_PSS_SHA384,
764 rustls::SignatureScheme::RSA_PSS_SHA512,
765 rustls::SignatureScheme::ED25519,
766 ]
767 }
768}
769
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a framed MySQL v10 greeting whose lower capability flags are `caps`.
    /// Returns the packet and the offset of the two capability bytes within it.
    fn greeting_with_caps(caps: u16) -> (Vec<u8>, usize) {
        let mut payload = vec![10u8];
        payload.extend_from_slice(b"5.7.0\0");
        payload.extend_from_slice(&[0u8; 4]); // connection id
        payload.extend_from_slice(&[0u8; 8]); // auth-plugin-data part 1
        payload.push(0); // filler
        let caps_offset = 4 + payload.len();
        payload.extend_from_slice(&caps.to_le_bytes());
        payload.extend_from_slice(&[0u8; 13]);

        let len = payload.len();
        let mut pkt = vec![
            (len & 0xff) as u8,
            ((len >> 8) & 0xff) as u8,
            ((len >> 16) & 0xff) as u8,
            0,
        ];
        pkt.extend_from_slice(&payload);
        (pkt, caps_offset)
    }

    fn caps_at(pkt: &[u8], off: usize) -> u16 {
        u16::from_le_bytes([pkt[off], pkt[off + 1]])
    }

    #[test]
    fn test_strip_mysql_ssl_flag_short_packet() {
        let mut buf = vec![0u8; 3];
        strip_mysql_ssl_flag(&mut buf);
        assert_eq!(buf, vec![0u8; 3]);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_not_greeting() {
        let mut buf = vec![0u8; 10];
        buf[4] = 9;
        let before = buf.clone();
        strip_mysql_ssl_flag(&mut buf);
        assert_eq!(buf, before);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_clears_ssl_bit() {
        let (mut pkt, off) = greeting_with_caps(0x0800);
        assert_ne!(caps_at(&pkt, off) & 0x0800, 0);

        strip_mysql_ssl_flag(&mut pkt);

        assert_eq!(caps_at(&pkt, off) & 0x0800, 0);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_preserves_other_bits() {
        let (mut pkt, off) = greeting_with_caps(0xFFFF);
        strip_mysql_ssl_flag(&mut pkt);
        assert_eq!(caps_at(&pkt, off), 0xF7FF);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_truncated_greeting() {
        let (pkt, off) = greeting_with_caps(0x0800);
        // Cut the packet inside the capability field: it must be left untouched.
        let mut truncated = pkt[..off + 1].to_vec();
        let before = truncated.clone();
        strip_mysql_ssl_flag(&mut truncated);
        assert_eq!(truncated, before);
    }

    #[test]
    fn test_resolve_peer_process_does_not_panic() {
        let result = std::panic::catch_unwind(|| resolve_peer_process(0));
        assert!(result.is_ok());
    }
}
837}