1use anyhow::Result;
2use ocular_protocol::{Protocol, mysql::mysql_response_complete, postgres::postgres_response_complete, parse_request, parse_response, extract_full_command, format_response_detail, parse_amqp_frame, parse_amqp_request_full, is_async_method, amqp_frame_len};
3use std::sync::{Arc, Mutex};
4use std::time::{Duration, Instant, SystemTime};
5use std::sync::atomic::{AtomicUsize, Ordering};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, AsyncRead, AsyncWrite};
7use tokio::net::{TcpListener, TcpStream};
8use tokio::sync::broadcast;
9use tracing::{info, warn, error, debug};
10
11pub use ocular_protocol::ProxyEvent;
12pub use ocular_protocol::{ConnectionState, StatusMap};
13
/// A request seen on the client→server direction that is waiting for its response.
///
/// The server→client half takes it once a response is recognised and turns it into a
/// [`ProxyEvent`].
struct PendingRequest {
    /// Short one-line summary of the request, copied into the event's `command`.
    command: String,
    /// Fuller rendering of the request, copied into the event's `full_command`.
    full_command: String,
    /// Wall-clock time the request was observed; becomes the event timestamp.
    timestamp: SystemTime,
    /// Monotonic start point; `instant.elapsed()` is reported as the latency.
    instant: Instant,
}
21
22pub async fn run_proxy(
23 listen_addr: String,
24 remote_addr: String,
25 name: String,
26 protocol: Protocol,
27 tx: broadcast::Sender<ProxyEvent>,
28 mut shutdown: tokio::sync::watch::Receiver<bool>,
29 status: StatusMap,
30) -> Result<()> {
31 let listener = match TcpListener::bind(&listen_addr).await {
32 Ok(l) => l,
33 Err(e) => {
34 let msg = format!("bind failed on {}: {}", listen_addr, e);
35 let _ = tx.send(ProxyEvent::system_event(&name, msg));
36 status.lock().unwrap().entry(name.clone()).or_default().last_error = Some(format!("bind failed: {}", e));
37 return Err(e.into());
38 }
39 };
40 let conn_count = Arc::new(AtomicUsize::new(0));
41 {
42 let mut map = status.lock().unwrap();
43 map.entry(name.clone()).or_default().has_connector = true;
44 }
45 info!(component = %name, listen = %listen_addr, remote = %remote_addr, ?protocol, "proxy listening");
46
47 loop {
48 tokio::select! {
49 result = listener.accept() => {
50 let (client, peer) = result?;
51 debug!(component = %name, peer = %peer, "new client connection");
52 let remote = remote_addr.clone();
53 let name = name.clone();
54 let tx = tx.clone();
55 let process = resolve_peer_process(peer.port());
56 let peer_addr = peer.to_string();
57 let remote_for_conn = remote.clone();
58 let conn_count = conn_count.clone();
59 let status = status.clone();
60 let protocol_for_conn = protocol;
61 tokio::spawn(async move {
62 conn_count.fetch_add(1, Ordering::Relaxed);
63 {
64 let mut map = status.lock().unwrap();
65 let s = map.entry(name.clone()).or_default();
66 s.active_connections = conn_count.load(Ordering::Relaxed);
67 s.last_active_at = Some(SystemTime::now());
68 }
69 if let Err(e) = handle_conn(client, &remote, &name, protocol_for_conn, &tx, process, peer_addr, remote_for_conn).await {
70 warn!(component = %name, remote = %remote, error = %e, "connection ended with error");
71 let _ = tx.send(ProxyEvent::system_event(&name, format!("connection error: {}", e)));
72 status.lock().unwrap().entry(name.clone()).or_default().last_error = Some(e.to_string());
73 }
74 let remaining = conn_count.fetch_sub(1, Ordering::Relaxed).saturating_sub(1);
75 status.lock().unwrap().entry(name.clone()).or_default().active_connections = remaining;
76 });
77 }
78 _ = shutdown.changed() => {
79 info!(component = %name, "proxy shutting down");
80 break;
81 }
82 }
83 }
84 Ok(())
85}
86
87#[allow(clippy::too_many_arguments)]
88async fn handle_conn(
89 mut client: TcpStream,
90 remote_addr: &str,
91 name: &str,
92 protocol: Protocol,
93 tx: &broadcast::Sender<ProxyEvent>,
94 process: Option<String>,
95 src: String,
96 dest: String,
97) -> Result<()> {
98 let (actual_addr, use_tls, tls_host) = if remote_addr.starts_with("https://") {
100 let stripped = remote_addr.strip_prefix("https://").unwrap();
101 let host = stripped.split(':').next().unwrap_or(stripped).to_string();
102 (stripped.to_string(), true, host)
103 } else {
104 let stripped = remote_addr.strip_prefix("http://").unwrap_or(remote_addr);
105 (stripped.to_string(), false, String::new())
106 };
107
108 let tcp_stream = match TcpStream::connect(&actual_addr).await {
109 Ok(s) => {
110 debug!(component = %name, remote = %actual_addr, "connected to remote");
111 s
112 }
113 Err(e) => {
114 error!(component = %name, remote = %actual_addr, error = %e,
115 "failed to connect to remote — is the service running?");
116 let _ = tx.send(ProxyEvent::system_event(name, format!("cannot reach {} ({})", actual_addr, e)));
117 if protocol == Protocol::Redis {
118 let err_msg = format!("-ERR ocular proxy: cannot reach {} ({})\r\n", actual_addr, e);
119 let _ = client.write_all(err_msg.as_bytes()).await;
120 }
121 return Err(e.into());
122 }
123 };
124
125 let (sr, sw): (Box<dyn AsyncRead + Unpin + Send>, Box<dyn AsyncWrite + Unpin + Send>) = if use_tls {
126 let config = rustls::ClientConfig::builder()
127 .dangerous()
128 .with_custom_certificate_verifier(Arc::new(NoVerify))
129 .with_no_client_auth();
130 let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
131 let domain = rustls::pki_types::ServerName::try_from(tls_host)
132 .map_err(|e| anyhow::anyhow!("invalid TLS hostname: {}", e))?;
133 let tls_stream = connector.connect(domain, tcp_stream).await?;
134 let (r, w) = tokio::io::split(tls_stream);
135 (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
136 } else {
137 let (r, w) = tokio::io::split(tcp_stream);
138 (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
139 };
140
141 let mut sr = sr;
142 let mut sw = sw;
143
144 if protocol == Protocol::Mysql {
146 let mut greeting_buf = [0u8; 65536];
147 let n = sr.read(&mut greeting_buf).await?;
148 if n == 0 { return Ok(()); }
149 let mut greeting = greeting_buf[..n].to_vec();
150 strip_mysql_ssl_flag(&mut greeting);
151 client.write_all(&greeting).await?;
152 debug!(component = %name, "forwarded MySQL greeting with SSL stripped");
153 }
154
155 if protocol == Protocol::Postgres {
158 let mut buf = [0u8; 256];
159 let n = client.read(&mut buf).await?;
160 if n == 0 { return Ok(()); }
161 let data = &buf[..n];
162 let neg_code = if n >= 8 {
163 u32::from_be_bytes([data[4], data[5], data[6], data[7]])
164 } else { 0 };
165 if neg_code == 80877103 || neg_code == 80877104 {
166 sw.write_all(data).await?;
168 let mut resp = [0u8; 1];
170 let rn = sr.read(&mut resp).await?;
171 if rn == 0 { return Ok(()); }
172 client.write_all(b"N").await?;
174 } else {
175 sw.write_all(data).await?;
177 }
178 }
179
180 let (mut cr, mut cw) = client.split();
181
182 let pending: Arc<Mutex<Option<PendingRequest>>> = Arc::new(Mutex::new(None));
183
184 let name_req = name.to_string();
185 let name_resp = name.to_string();
186 let tx_req = tx.clone();
187 let tx_resp = tx.clone();
188 let pending_w = pending.clone();
189 let pending_final = pending.clone();
190 let pending_r = pending;
191 let process_info = process;
192
193 let process_req = process_info.clone();
194 let src_req = src.clone();
195 let dest_req = dest.clone();
196 let src_resp = src.clone();
197 let dest_resp = dest;
198 let client_to_server = async move {
199 let mut buf = [0u8; 65536];
200 let mut http_req_buf: Vec<u8> = Vec::with_capacity(4096);
201 let mut memcached_req_buf: Vec<u8> = Vec::with_capacity(4096);
202 let mut kafka_req_buf: Vec<u8> = Vec::with_capacity(4096);
203 loop {
204 let n = cr.read(&mut buf).await?;
205 if n == 0 { break; }
206 let data = &buf[..n];
207
208 if protocol == Protocol::Amqp {
209 let mut pos = 0;
211 while pos < data.len() {
212 let frame_data = &data[pos..];
213 let Some(flen) = amqp_frame_len(frame_data) else { break };
214 if let Some(frame) = parse_amqp_frame(frame_data) {
215 if frame.frame_type == 8 {
217 pos += flen;
218 continue;
219 }
220 if let Some(ref method) = frame.method {
221 if is_async_method(method.class_id, method.method_id) {
222 let (summary, detail) = parse_amqp_request_full(frame_data)
223 .unwrap_or_else(|| (method.summary.clone(), method.detail.clone()));
224 let _ = tx_req.send(ProxyEvent {
225 timestamp: SystemTime::now(),
226 component: name_req.clone(),
227 protocol,
228 command: summary,
229 full_command: detail.clone(),
230 response: String::new(),
231 response_detail: detail,
232 latency: std::time::Duration::ZERO,
233 process: process_req.clone(),
234 src: Some(src_req.clone()),
235 dest: Some(dest_req.clone()),
236 system: false,
237 });
238 } else {
239 debug!(component = %name_req, command = %method.summary);
240 *pending_w.lock().unwrap() = Some(PendingRequest {
241 timestamp: SystemTime::now(),
242 instant: Instant::now(),
243 command: method.summary.clone(),
244 full_command: method.detail.clone(),
245 });
246 }
247 }
248 }
249 pos += flen;
250 }
251 } else if protocol == Protocol::Http {
252 http_req_buf.extend_from_slice(data);
253 if ocular_protocol::http::http_request_complete(&http_req_buf) {
254 if let Some(command) = parse_request(protocol, &http_req_buf) {
255 let full_command = extract_full_command(protocol, &http_req_buf).unwrap_or_else(|| command.clone());
256 *pending_w.lock().unwrap() = Some(PendingRequest {
257 timestamp: SystemTime::now(),
258 instant: Instant::now(),
259 command,
260 full_command,
261 });
262 }
263 http_req_buf.clear();
264 }
265 } else if protocol == Protocol::Memcached {
266 memcached_req_buf.extend_from_slice(data);
267 while ocular_protocol::memcached::memcached_request_complete(&memcached_req_buf) {
268 if let Some(prev) = pending_w.lock().unwrap().take() {
271 let _ = tx_req.send(ProxyEvent {
272 timestamp: prev.timestamp,
273 component: name_req.clone(),
274 protocol,
275 command: prev.command,
276 full_command: prev.full_command,
277 response: String::new(),
278 response_detail: String::new(),
279 latency: Duration::ZERO,
280 process: process_req.clone(),
281 src: Some(src_req.clone()),
282 dest: Some(dest_req.clone()),
283 system: false,
284 });
285 }
286 if let Some(command) = parse_request(protocol, &memcached_req_buf) {
287 let full_command = extract_full_command(protocol, &memcached_req_buf).unwrap_or_else(|| command.clone());
288 *pending_w.lock().unwrap() = Some(PendingRequest {
289 timestamp: SystemTime::now(),
290 instant: Instant::now(),
291 command,
292 full_command,
293 });
294 }
295 let s = std::str::from_utf8(&memcached_req_buf).unwrap_or("");
297 let first_crlf = s.find("\r\n").unwrap_or(0);
298 let line = &s[..first_crlf];
299 let parts: Vec<&str> = line.split_whitespace().collect();
300 let cmd = parts.first().map(|c| c.to_uppercase()).unwrap_or_default();
301 let consumed = match cmd.as_str() {
302 "SET" | "ADD" | "REPLACE" | "APPEND" | "PREPEND" | "CAS" => {
303 let bytes: usize = parts.get(4).and_then(|b| b.parse().ok()).unwrap_or(0);
304 first_crlf + 2 + bytes + 2
305 }
306 _ => first_crlf + 2,
307 };
308 memcached_req_buf = memcached_req_buf[consumed..].to_vec();
309 }
310 } else if protocol == Protocol::Kafka {
311 kafka_req_buf.extend_from_slice(data);
312 while ocular_protocol::kafka::kafka_frame_complete(&kafka_req_buf) {
313 let frame_len = i32::from_be_bytes([kafka_req_buf[0], kafka_req_buf[1], kafka_req_buf[2], kafka_req_buf[3]]) as usize + 4;
314 let frame = &kafka_req_buf[..frame_len];
315 if let Some(command) = parse_request(protocol, frame) {
316 let full_command = extract_full_command(protocol, frame).unwrap_or_else(|| command.clone());
317 *pending_w.lock().unwrap() = Some(PendingRequest {
318 timestamp: SystemTime::now(),
319 instant: Instant::now(),
320 command,
321 full_command,
322 });
323 }
324 kafka_req_buf = kafka_req_buf[frame_len..].to_vec();
325 }
326 } else if protocol == Protocol::Postgres {
327 let mut pos = 0;
329 while pos < data.len() {
330 let first = data[pos];
331 let is_typed = matches!(first, b'Q' | b'P' | b'B' | b'E' | b'D' | b'S' | b'X' | b'C' | b'p' | b'H' | b'F' | b'd' | b'c' | b'f');
332 if !is_typed { break; }
333 if pos + 5 > data.len() { break; }
334 let len = u32::from_be_bytes([data[pos+1], data[pos+2], data[pos+3], data[pos+4]]) as usize;
335 let end = pos + 1 + len;
336 if end > data.len() { break; }
337 if first == b'Q' || first == b'P' {
339 let msg = &data[pos..end];
340 if let Some(command) = parse_request(protocol, msg) {
341 let full_command = extract_full_command(protocol, msg).unwrap_or_else(|| command.clone());
342 *pending_w.lock().unwrap() = Some(PendingRequest {
343 timestamp: SystemTime::now(),
344 instant: Instant::now(),
345 command,
346 full_command,
347 });
348 }
349 }
350 pos = end;
351 }
352 } else if let Some(command) = parse_request(protocol, data) {
353 let full_command = extract_full_command(protocol, data).unwrap_or_else(|| command.clone());
354 debug!(component = %name_req, %command);
355 *pending_w.lock().unwrap() = Some(PendingRequest {
356 timestamp: SystemTime::now(),
357 instant: Instant::now(),
358 command,
359 full_command,
360 });
361 }
362
363 sw.write_all(data).await?;
364 }
365 Ok::<_, anyhow::Error>(())
366 };
367
368 let process_mysql = process_info.clone();
369 let server_to_client = async move {
370 let mut buf = [0u8; 65536];
371 let mut mysql_buf: Vec<u8> = Vec::with_capacity(4096);
372 let mut http_resp_buf: Vec<u8> = Vec::with_capacity(4096);
373 let mut memcached_resp_buf: Vec<u8> = Vec::with_capacity(4096);
374 let mut kafka_resp_buf: Vec<u8> = Vec::with_capacity(4096);
375 let mut pg_resp_buf: Vec<u8> = Vec::with_capacity(4096);
376 let mut awaiting_response = false;
377 let mut memcached_awaiting = false;
378 loop {
379 let n = sr.read(&mut buf).await?;
380 if n == 0 { break; }
381 let data = &buf[..n];
382 cw.write_all(data).await?;
383
384 if protocol == Protocol::Mysql {
385 let has_pending = pending_r.lock().unwrap().is_some();
386 if has_pending || awaiting_response {
387 awaiting_response = true;
388 mysql_buf.extend_from_slice(data);
389 if mysql_response_complete(&mysql_buf) {
390 if let Some(req) = pending_r.lock().unwrap().take() {
391 let latency = req.instant.elapsed();
392 let response = parse_response(protocol, &mysql_buf).unwrap_or_default();
393 let response_detail = format_response_detail(protocol, &mysql_buf).unwrap_or_default();
394 let _ = tx_resp.send(ProxyEvent {
395 timestamp: req.timestamp,
396 component: name_resp.clone(),
397 protocol,
398 command: req.command,
399 full_command: req.full_command,
400 response,
401 response_detail,
402 latency,
403 process: process_mysql.clone(),
404 src: Some(src_resp.clone()),
405 dest: Some(dest_resp.clone()),
406 system: false,
407 });
408 }
409 mysql_buf.clear();
410 awaiting_response = false;
411 }
412 }
413 } else if protocol == Protocol::Http {
414 http_resp_buf.extend_from_slice(data);
415 if ocular_protocol::http::http_response_complete(&http_resp_buf) {
416 if let Some(req) = pending_r.lock().unwrap().take() {
417 let latency = req.instant.elapsed();
418 let response = parse_response(protocol, &http_resp_buf).unwrap_or_default();
419 let response_detail = format_response_detail(protocol, &http_resp_buf).unwrap_or_else(|| response.clone());
420 let _ = tx_resp.send(ProxyEvent {
421 timestamp: req.timestamp,
422 component: name_resp.clone(),
423 protocol,
424 command: req.command,
425 full_command: req.full_command,
426 response,
427 response_detail,
428 latency,
429 process: process_info.clone(),
430 src: Some(src_resp.clone()),
431 dest: Some(dest_resp.clone()),
432 system: false,
433 });
434 }
435 http_resp_buf.clear();
436 }
437 } else if protocol == Protocol::Amqp {
438 let mut pos = 0;
440 while pos < data.len() {
441 let frame_data = &data[pos..];
442 let Some(flen) = amqp_frame_len(frame_data) else { break };
443 if let Some(frame) = parse_amqp_frame(frame_data) {
444 if frame.frame_type == 2 || frame.frame_type == 3 {
446 pos += flen;
447 continue;
448 }
449 if frame.frame_type == 8 {
451 pos += flen;
452 continue;
453 }
454 }
455
456 let mut body_text = String::new();
458 let mut peek = pos + flen;
459 while peek < data.len() {
460 let peek_data = &data[peek..];
461 let Some(plen) = amqp_frame_len(peek_data) else { break };
462 if let Some(pf) = parse_amqp_frame(peek_data) {
463 if pf.frame_type == 2 {
464 } else if pf.frame_type == 3 {
466 if let Some(body) = &pf.body {
468 body_text = String::from_utf8_lossy(body).to_string();
469 }
470 } else {
471 break;
472 }
473 } else {
474 break;
475 }
476 peek += plen;
477 }
478
479 if let Some(req) = pending_r.lock().unwrap().take() {
480 let latency = req.instant.elapsed();
481 let mut response = parse_response(protocol, frame_data).unwrap_or_default();
482 let mut response_detail = format_response_detail(protocol, frame_data).unwrap_or_else(|| response.clone());
483 if !body_text.is_empty() {
484 response = format!("{} | {}", response, body_text);
485 response_detail = format!("{}\nBody: {}", response_detail, body_text);
486 }
487 let _ = tx_resp.send(ProxyEvent {
488 timestamp: req.timestamp,
489 component: name_resp.clone(),
490 protocol,
491 command: req.command,
492 full_command: req.full_command,
493 response,
494 response_detail,
495 latency,
496 process: process_info.clone(),
497 src: Some(src_resp.clone()),
498 dest: Some(dest_resp.clone()),
499 system: false,
500 });
501 } else if let Some(frame) = parse_amqp_frame(frame_data) {
502 if let Some(ref method) = frame.method {
504 let response = if body_text.is_empty() { String::new() } else { body_text.clone() };
505 let response_detail = if body_text.is_empty() { String::new() } else { body_text.clone() };
506 let command = method.summary.clone();
507 let _ = tx_resp.send(ProxyEvent {
508 timestamp: SystemTime::now(),
509 component: name_resp.clone(),
510 protocol,
511 command,
512 full_command: method.detail.clone(),
513 response,
514 response_detail,
515 latency: std::time::Duration::ZERO,
516 process: process_info.clone(),
517 src: Some(dest_resp.clone()),
518 dest: Some(src_resp.clone()),
519 system: false,
520 });
521 }
522 }
523 pos = peek;
525 }
526 } else if protocol == Protocol::Postgres {
527 pg_resp_buf.extend_from_slice(data);
529 if postgres_response_complete(&pg_resp_buf) {
530 if let Some(req) = pending_r.lock().unwrap().take() {
531 let latency = req.instant.elapsed();
532 let response = parse_response(protocol, &pg_resp_buf).unwrap_or_default();
533 let response_detail = format_response_detail(protocol, &pg_resp_buf).unwrap_or_else(|| response.clone());
534 let _ = tx_resp.send(ProxyEvent {
535 timestamp: req.timestamp,
536 component: name_resp.clone(),
537 protocol,
538 command: req.command,
539 full_command: req.full_command,
540 response,
541 response_detail,
542 latency,
543 process: process_info.clone(),
544 src: Some(src_resp.clone()),
545 dest: Some(dest_resp.clone()),
546 system: false,
547 });
548 }
549 pg_resp_buf.clear();
550 }
551 } else if protocol == Protocol::Memcached {
552 let has_pending = pending_r.lock().unwrap().is_some();
553 if has_pending || memcached_awaiting {
554 memcached_awaiting = true;
555 memcached_resp_buf.extend_from_slice(data);
556 if ocular_protocol::memcached::memcached_response_complete(&memcached_resp_buf) {
557 if let Some(req) = pending_r.lock().unwrap().take() {
558 let latency = req.instant.elapsed();
559 let response = parse_response(protocol, &memcached_resp_buf).unwrap_or_default();
560 let response_detail = format_response_detail(protocol, &memcached_resp_buf).unwrap_or_else(|| response.clone());
561 let _ = tx_resp.send(ProxyEvent {
562 timestamp: req.timestamp,
563 component: name_resp.clone(),
564 protocol,
565 command: req.command,
566 full_command: req.full_command,
567 response,
568 response_detail,
569 latency,
570 process: process_info.clone(),
571 src: Some(src_resp.clone()),
572 dest: Some(dest_resp.clone()),
573 system: false,
574 });
575 }
576 memcached_resp_buf.clear();
577 memcached_awaiting = false;
578 }
579 }
580 } else if protocol == Protocol::Kafka {
581 kafka_resp_buf.extend_from_slice(data);
582 while ocular_protocol::kafka::kafka_frame_complete(&kafka_resp_buf) {
583 let frame_len = i32::from_be_bytes([kafka_resp_buf[0], kafka_resp_buf[1], kafka_resp_buf[2], kafka_resp_buf[3]]) as usize + 4;
584 if let Some(req) = pending_r.lock().unwrap().take() {
585 let latency = req.instant.elapsed();
586 let response = parse_response(protocol, &kafka_resp_buf[..frame_len]).unwrap_or_default();
587 let response_detail = format_response_detail(protocol, &kafka_resp_buf[..frame_len]).unwrap_or_else(|| response.clone());
588 let _ = tx_resp.send(ProxyEvent {
589 timestamp: req.timestamp,
590 component: name_resp.clone(),
591 protocol,
592 command: req.command,
593 full_command: req.full_command,
594 response,
595 response_detail,
596 latency,
597 process: process_info.clone(),
598 src: Some(src_resp.clone()),
599 dest: Some(dest_resp.clone()),
600 system: false,
601 });
602 }
603 kafka_resp_buf = kafka_resp_buf[frame_len..].to_vec();
604 }
605 } else {
606 if let Some(req) = pending_r.lock().unwrap().take() {
608 let latency = req.instant.elapsed();
609 let response = parse_response(protocol, data).unwrap_or_default();
610 let response_detail = format_response_detail(protocol, data).unwrap_or_else(|| response.clone());
611 let _ = tx_resp.send(ProxyEvent {
612 timestamp: req.timestamp,
613 component: name_resp.clone(),
614 protocol,
615 command: req.command,
616 full_command: req.full_command,
617 response,
618 response_detail,
619 latency,
620 process: process_info.clone(),
621 src: Some(src_resp.clone()),
622 dest: Some(dest_resp.clone()),
623 system: false,
624 });
625 }
626 }
627 }
628 Ok::<_, anyhow::Error>(())
629 };
630
631 tokio::pin!(client_to_server);
632 tokio::pin!(server_to_client);
633
634 tokio::select! {
635 r = &mut client_to_server => {
636 if r.is_ok() && pending_final.lock().unwrap().is_some() {
638 let _ = tokio::time::timeout(
639 Duration::from_millis(500),
640 &mut server_to_client,
641 ).await;
642 }
643 },
644 r = &mut server_to_client => r?,
645 }
646 Ok(())
647}
648
649fn strip_mysql_ssl_flag(packet: &mut [u8]) {
650 if packet.len() < 5 { return; }
651 let payload = &mut packet[4..];
652 if payload.is_empty() || payload[0] != 10 { return; }
653 let mut pos = 1;
654 while pos < payload.len() && payload[pos] != 0 { pos += 1; }
655 pos += 1;
656 pos += 4;
657 pos += 8;
658 pos += 1;
659 if pos + 2 > payload.len() { return; }
660 let cap_lower = u16::from_le_bytes([payload[pos], payload[pos + 1]]);
661 let cap_lower_new = cap_lower & !0x0800;
662 payload[pos] = (cap_lower_new & 0xff) as u8;
663 payload[pos + 1] = ((cap_lower_new >> 8) & 0xff) as u8;
664}
665
/// Best-effort lookup of the local process owning a TCP connection on `port`, formatted
/// as `"[pid] command"`.
///
/// Runs `lsof` on macOS and `ss -tnp` elsewhere, so it blocks for the duration of the
/// command; call it off the async worker threads. Returns `None` when the tool is
/// missing or fails, or when no matching process other than this one is reported (for
/// example a client on another host; NOTE(review): `ss` may also hide other users'
/// processes without sufficient privileges — confirm for your deployment).
fn resolve_peer_process(port: u16) -> Option<String> {
    use std::process::Command;
    // Both ends of a loopback connection can match the port filter; never report ourselves.
    let my_pid = std::process::id().to_string();
    let describe = |pid: &str, cmd: &str| {
        (!pid.is_empty() && pid != my_pid).then(|| format!("[{}] {}", pid, cmd))
    };

    if cfg!(target_os = "macos") {
        // `-n`/`-P` skip host and service-name lookups, which can make lsof very slow.
        // `-Fp -Fc` print machine-readable `p<pid>` and `c<command>` lines per process.
        let output = Command::new("lsof")
            .args(["-n", "-P", "-i", &format!("tcp:{}", port), "-sTCP:ESTABLISHED", "-Fp", "-Fc"])
            .output()
            .ok()?;
        let text = String::from_utf8_lossy(&output.stdout);
        // (pid, command) of the process record currently being read.
        let mut current: Option<(&str, &str)> = None;
        for line in text.lines() {
            if let Some(pid) = line.strip_prefix('p') {
                // A new `p` line closes the previous record.
                if let Some(found) = current.and_then(|(p, c)| describe(p, c)) {
                    return Some(found);
                }
                current = Some((pid, ""));
            } else if let Some(cmd) = line.strip_prefix('c') {
                if let Some((_, c)) = current.as_mut() {
                    *c = cmd;
                }
            }
        }
        current.and_then(|(p, c)| describe(p, c))
    } else {
        let output = Command::new("ss")
            .args(["-tnp", &format!("sport = :{}", port)])
            .output()
            .ok()?;
        let text = String::from_utf8_lossy(&output.stdout);
        // Owner appears as `users:(("name",pid=1234,fd=3))`.
        for line in text.lines() {
            let Some(start) = line.find("users:((\"") else { continue };
            let rest = &line[start + 9..];
            let Some(end) = rest.find('"') else { continue };
            let proc_name = &rest[..end];
            let pid = rest
                .find("pid=")
                .and_then(|i| rest[i + 4..].split(|c: char| !c.is_ascii_digit()).next())
                .unwrap_or("?");
            if pid == my_pid {
                continue;
            }
            return Some(format!("[{}] {}", pid, proc_name));
        }
        None
    }
}
722
723#[derive(Debug)]
725struct NoVerify;
726
727impl rustls::client::danger::ServerCertVerifier for NoVerify {
728 fn verify_server_cert(
729 &self, _: &rustls::pki_types::CertificateDer<'_>, _: &[rustls::pki_types::CertificateDer<'_>],
730 _: &rustls::pki_types::ServerName<'_>, _: &[u8], _: rustls::pki_types::UnixTime,
731 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
732 Ok(rustls::client::danger::ServerCertVerified::assertion())
733 }
734 fn verify_tls12_signature(
735 &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
736 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
737 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
738 }
739 fn verify_tls13_signature(
740 &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
741 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
742 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
743 }
744 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
745 vec![
746 rustls::SignatureScheme::RSA_PKCS1_SHA256,
747 rustls::SignatureScheme::RSA_PKCS1_SHA384,
748 rustls::SignatureScheme::RSA_PKCS1_SHA512,
749 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
750 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
751 rustls::SignatureScheme::RSA_PSS_SHA256,
752 rustls::SignatureScheme::RSA_PSS_SHA384,
753 rustls::SignatureScheme::RSA_PSS_SHA512,
754 rustls::SignatureScheme::ED25519,
755 ]
756 }
757}
758
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a MySQL protocol-v10 greeting packet (4-byte header included) whose lower
    /// capability flags are `caps`.
    fn greeting_with_caps(caps: u16) -> Vec<u8> {
        let mut payload = vec![10u8];
        payload.extend_from_slice(b"5.7.0\0");
        payload.extend_from_slice(&[0u8; 4]); // connection id
        payload.extend_from_slice(&[0u8; 8]); // auth-plugin-data part 1
        payload.push(0); // filler
        payload.extend_from_slice(&caps.to_le_bytes());
        payload.extend_from_slice(&[0u8; 13]);

        // 3-byte little-endian payload length, then sequence id 0.
        let mut packet = (payload.len() as u32).to_le_bytes()[..3].to_vec();
        packet.push(0);
        packet.extend_from_slice(&payload);
        packet
    }

    /// Offset of the lower capability flags within a full greeting packet, walking the
    /// same layout as `strip_mysql_ssl_flag`; `None` if the packet is too short.
    fn caps_offset(pkt: &[u8]) -> Option<usize> {
        let after_protocol = pkt.get(5..)?;
        let nul = after_protocol
            .iter()
            .position(|&b| b == 0)
            .map_or(pkt.len(), |i| i + 5);
        let offset = nul + 1 + 13;
        (offset <= pkt.len()).then_some(offset)
    }

    #[test]
    fn test_strip_mysql_ssl_flag_short_packet() {
        let mut buf = [0u8; 3];
        strip_mysql_ssl_flag(&mut buf);
        assert_eq!(buf, [0u8; 3]);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_not_greeting() {
        let mut buf = [0u8; 10];
        buf[4] = 9;
        strip_mysql_ssl_flag(&mut buf);
        assert_eq!(buf[4], 9);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_clears_ssl_bit() {
        let mut pkt = greeting_with_caps(0x0800);
        let off = caps_offset(&pkt).expect("greeting should contain capability flags");
        let caps_in = |p: &[u8]| u16::from_le_bytes([p[off], p[off + 1]]);

        assert_ne!(caps_in(&pkt) & 0x0800, 0);
        strip_mysql_ssl_flag(&mut pkt);
        assert_eq!(caps_in(&pkt) & 0x0800, 0);
    }

    #[test]
    fn test_resolve_peer_process_does_not_panic() {
        assert!(std::panic::catch_unwind(|| resolve_peer_process(0)).is_ok());
    }
}