Skip to main content

ocular_proxy/
lib.rs

1use anyhow::Result;
2use ocular_protocol::{Protocol, mysql::mysql_response_complete, postgres::postgres_response_complete, parse_request, parse_response, extract_full_command, format_response_detail, parse_amqp_frame, parse_amqp_request_full, is_async_method, amqp_frame_len};
3use std::sync::{Arc, Mutex};
4use std::time::{Duration, Instant, SystemTime};
5use std::sync::atomic::{AtomicUsize, Ordering};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, AsyncRead, AsyncWrite};
7use tokio::net::{TcpListener, TcpStream};
8use tokio::sync::broadcast;
9use tracing::{info, warn, error, debug};
10
11pub use ocular_protocol::ProxyEvent;
12pub use ocular_protocol::{ConnectionState, StatusMap};
13
/// A client request that has been parsed but not yet paired with a response.
///
/// The client→server reader stores one of these in a shared
/// `Arc<Mutex<Option<PendingRequest>>>` slot; the server→client reader
/// `take()`s it and combines it with the parsed response into a `ProxyEvent`.
struct PendingRequest {
    /// Short, one-line summary of the request (e.g. the output of `parse_request`).
    command: String,
    /// Detailed rendering of the request; most call sites fall back to a copy
    /// of `command` when no richer form can be extracted.
    full_command: String,
    /// Wall-clock time the request was seen; reused as the emitted event's timestamp.
    timestamp: SystemTime,
    /// Monotonic capture time; `instant.elapsed()` yields the request→response latency.
    instant: Instant,
}
21
22pub async fn run_proxy(
23    listen_addr: String,
24    remote_addr: String,
25    name: String,
26    protocol: Protocol,
27    tx: broadcast::Sender<ProxyEvent>,
28    mut shutdown: tokio::sync::watch::Receiver<bool>,
29    status: StatusMap,
30) -> Result<()> {
31    let listener = match TcpListener::bind(&listen_addr).await {
32        Ok(l) => l,
33        Err(e) => {
34            let msg = format!("bind failed on {}: {}", listen_addr, e);
35            let _ = tx.send(ProxyEvent::system_event(&name, msg));
36            status.lock().unwrap().entry(name.clone()).or_default().last_error = Some(format!("bind failed: {}", e));
37            return Err(e.into());
38        }
39    };
40    let conn_count = Arc::new(AtomicUsize::new(0));
41    {
42        let mut map = status.lock().unwrap();
43        map.entry(name.clone()).or_default().has_connector = true;
44    }
45    info!(component = %name, listen = %listen_addr, remote = %remote_addr, ?protocol, "proxy listening");
46
47    loop {
48        tokio::select! {
49            result = listener.accept() => {
50                let (client, peer) = result?;
51                debug!(component = %name, peer = %peer, "new client connection");
52                let remote = remote_addr.clone();
53                let name = name.clone();
54                let tx = tx.clone();
55                let process = resolve_peer_process(peer.port());
56                let peer_addr = peer.to_string();
57                let remote_for_conn = remote.clone();
58                let conn_count = conn_count.clone();
59                let status = status.clone();
60                let protocol_for_conn = protocol;
61                tokio::spawn(async move {
62                    conn_count.fetch_add(1, Ordering::Relaxed);
63                    {
64                        let mut map = status.lock().unwrap();
65                        let s = map.entry(name.clone()).or_default();
66                        s.active_connections = conn_count.load(Ordering::Relaxed);
67                        s.last_active_at = Some(SystemTime::now());
68                    }
69                    if let Err(e) = handle_conn(client, &remote, &name, protocol_for_conn, &tx, process, peer_addr, remote_for_conn).await {
70                        warn!(component = %name, remote = %remote, error = %e, "connection ended with error");
71                        let _ = tx.send(ProxyEvent::system_event(&name, format!("connection error: {}", e)));
72                        status.lock().unwrap().entry(name.clone()).or_default().last_error = Some(e.to_string());
73                    }
74                    let remaining = conn_count.fetch_sub(1, Ordering::Relaxed).saturating_sub(1);
75                    status.lock().unwrap().entry(name.clone()).or_default().active_connections = remaining;
76                });
77            }
78            _ = shutdown.changed() => {
79                info!(component = %name, "proxy shutting down");
80                break;
81            }
82        }
83    }
84    Ok(())
85}
86
87#[allow(clippy::too_many_arguments)]
88async fn handle_conn(
89    mut client: TcpStream,
90    remote_addr: &str,
91    name: &str,
92    protocol: Protocol,
93    tx: &broadcast::Sender<ProxyEvent>,
94    process: Option<String>,
95    src: String,
96    dest: String,
97) -> Result<()> {
98    // Parse remote address: detect https:// for TLS outbound
99    let (actual_addr, use_tls, tls_host) = if remote_addr.starts_with("https://") {
100        let stripped = remote_addr.strip_prefix("https://").unwrap();
101        let host = stripped.split(':').next().unwrap_or(stripped).to_string();
102        (stripped.to_string(), true, host)
103    } else {
104        let stripped = remote_addr.strip_prefix("http://").unwrap_or(remote_addr);
105        (stripped.to_string(), false, String::new())
106    };
107
108    let tcp_stream = match TcpStream::connect(&actual_addr).await {
109        Ok(s) => {
110            debug!(component = %name, remote = %actual_addr, "connected to remote");
111            s
112        }
113        Err(e) => {
114            error!(component = %name, remote = %actual_addr, error = %e,
115                "failed to connect to remote — is the service running?");
116            let _ = tx.send(ProxyEvent::system_event(name, format!("cannot reach {} ({})", actual_addr, e)));
117            if protocol == Protocol::Redis {
118                let err_msg = format!("-ERR ocular proxy: cannot reach {} ({})\r\n", actual_addr, e);
119                let _ = client.write_all(err_msg.as_bytes()).await;
120            }
121            return Err(e.into());
122        }
123    };
124
125    let (sr, sw): (Box<dyn AsyncRead + Unpin + Send>, Box<dyn AsyncWrite + Unpin + Send>) = if use_tls {
126        let config = rustls::ClientConfig::builder()
127            .dangerous()
128            .with_custom_certificate_verifier(Arc::new(NoVerify))
129            .with_no_client_auth();
130        let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
131        let domain = rustls::pki_types::ServerName::try_from(tls_host)
132            .map_err(|e| anyhow::anyhow!("invalid TLS hostname: {}", e))?;
133        let tls_stream = connector.connect(domain, tcp_stream).await?;
134        let (r, w) = tokio::io::split(tls_stream);
135        (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
136    } else {
137        let (r, w) = tokio::io::split(tcp_stream);
138        (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
139    };
140
141    let mut sr = sr;
142    let mut sw = sw;
143
144    // For MySQL: strip SSL from greeting
145    if protocol == Protocol::Mysql {
146        let mut greeting_buf = [0u8; 65536];
147        let n = sr.read(&mut greeting_buf).await?;
148        if n == 0 { return Ok(()); }
149        let mut greeting = greeting_buf[..n].to_vec();
150        strip_mysql_ssl_flag(&mut greeting);
151        client.write_all(&greeting).await?;
152        debug!(component = %name, "forwarded MySQL greeting with SSL stripped");
153    }
154
155    // For PostgreSQL: strip SSL by forwarding negotiation to server but replying N to client.
156    // This lets the server know the connection won't be encrypted (may affect auth requirements).
157    if protocol == Protocol::Postgres {
158        let mut buf = [0u8; 256];
159        let n = client.read(&mut buf).await?;
160        if n == 0 { return Ok(()); }
161        let data = &buf[..n];
162        let neg_code = if n >= 8 {
163            u32::from_be_bytes([data[4], data[5], data[6], data[7]])
164        } else { 0 };
165        if neg_code == 80877103 || neg_code == 80877104 {
166            // Forward negotiation to server so it knows the connection state
167            sw.write_all(data).await?;
168            // Read server's response (single byte: N or S)
169            let mut resp = [0u8; 1];
170            let rn = sr.read(&mut resp).await?;
171            if rn == 0 { return Ok(()); }
172            // Always tell client: no SSL/GSS (force plaintext for proxy to parse)
173            client.write_all(b"N").await?;
174        } else {
175            // Not a negotiation request, forward as Startup
176            sw.write_all(data).await?;
177        }
178    }
179
180    let (mut cr, mut cw) = client.split();
181
182    let pending: Arc<Mutex<Option<PendingRequest>>> = Arc::new(Mutex::new(None));
183
184    let name_req = name.to_string();
185    let name_resp = name.to_string();
186    let tx_req = tx.clone();
187    let tx_resp = tx.clone();
188    let pending_w = pending.clone();
189    let pending_final = pending.clone();
190    let pending_r = pending;
191    let process_info = process;
192
193    let process_req = process_info.clone();
194    let src_req = src.clone();
195    let dest_req = dest.clone();
196    let src_resp = src.clone();
197    let dest_resp = dest;
198    let client_to_server = async move {
199        let mut buf = [0u8; 65536];
200        let mut http_req_buf: Vec<u8> = Vec::with_capacity(4096);
201        let mut memcached_req_buf: Vec<u8> = Vec::with_capacity(4096);
202        let mut kafka_req_buf: Vec<u8> = Vec::with_capacity(4096);
203        loop {
204            let n = cr.read(&mut buf).await?;
205            if n == 0 { break; }
206            let data = &buf[..n];
207
208            if protocol == Protocol::Amqp {
209                // AMQP: loop through all frames in this read
210                let mut pos = 0;
211                while pos < data.len() {
212                    let frame_data = &data[pos..];
213                    let Some(flen) = amqp_frame_len(frame_data) else { break };
214                    if let Some(frame) = parse_amqp_frame(frame_data) {
215                        // Skip heartbeat — not a real request
216                        if frame.frame_type == 8 {
217                            pos += flen;
218                            continue;
219                        }
220                        if let Some(ref method) = frame.method {
221                            if is_async_method(method.class_id, method.method_id) {
222                                let (summary, detail) = parse_amqp_request_full(frame_data)
223                                    .unwrap_or_else(|| (method.summary.clone(), method.detail.clone()));
224                                let _ = tx_req.send(ProxyEvent {
225                                    timestamp: SystemTime::now(),
226                                    component: name_req.clone(),
227                                    protocol,
228                                    command: summary,
229                                    full_command: detail.clone(),
230                                    response: String::new(),
231                                    response_detail: detail,
232                                    latency: std::time::Duration::ZERO,
233                                    process: process_req.clone(),
234                                    src: Some(src_req.clone()),
235                                    dest: Some(dest_req.clone()),
236                    system: false,
237                                });
238                            } else {
239                                debug!(component = %name_req, command = %method.summary);
240                                *pending_w.lock().unwrap() = Some(PendingRequest {
241                                    timestamp: SystemTime::now(),
242                                    instant: Instant::now(),
243                                    command: method.summary.clone(),
244                                    full_command: method.detail.clone(),
245                                });
246                            }
247                        }
248                    }
249                    pos += flen;
250                }
251            } else if protocol == Protocol::Http {
252                http_req_buf.extend_from_slice(data);
253                if ocular_protocol::http::http_request_complete(&http_req_buf) {
254                    if let Some(command) = parse_request(protocol, &http_req_buf) {
255                        let full_command = extract_full_command(protocol, &http_req_buf).unwrap_or_else(|| command.clone());
256                        *pending_w.lock().unwrap() = Some(PendingRequest {
257                            timestamp: SystemTime::now(),
258                            instant: Instant::now(),
259                            command,
260                            full_command,
261                        });
262                    }
263                    http_req_buf.clear();
264                }
265            } else if protocol == Protocol::Memcached {
266                memcached_req_buf.extend_from_slice(data);
267                while ocular_protocol::memcached::memcached_request_complete(&memcached_req_buf) {
268                    // If there's already a pending request that won't get a response pairing,
269                    // emit it as a standalone event
270                    if let Some(prev) = pending_w.lock().unwrap().take() {
271                        let _ = tx_req.send(ProxyEvent {
272                            timestamp: prev.timestamp,
273                            component: name_req.clone(),
274                            protocol,
275                            command: prev.command,
276                            full_command: prev.full_command,
277                            response: String::new(),
278                            response_detail: String::new(),
279                            latency: Duration::ZERO,
280                            process: process_req.clone(),
281                            src: Some(src_req.clone()),
282                            dest: Some(dest_req.clone()),
283                            system: false,
284                        });
285                    }
286                    if let Some(command) = parse_request(protocol, &memcached_req_buf) {
287                        let full_command = extract_full_command(protocol, &memcached_req_buf).unwrap_or_else(|| command.clone());
288                        *pending_w.lock().unwrap() = Some(PendingRequest {
289                            timestamp: SystemTime::now(),
290                            instant: Instant::now(),
291                            command,
292                            full_command,
293                        });
294                    }
295                    // Advance past this request
296                    let s = std::str::from_utf8(&memcached_req_buf).unwrap_or("");
297                    let first_crlf = s.find("\r\n").unwrap_or(0);
298                    let line = &s[..first_crlf];
299                    let parts: Vec<&str> = line.split_whitespace().collect();
300                    let cmd = parts.first().map(|c| c.to_uppercase()).unwrap_or_default();
301                    let consumed = match cmd.as_str() {
302                        "SET" | "ADD" | "REPLACE" | "APPEND" | "PREPEND" | "CAS" => {
303                            let bytes: usize = parts.get(4).and_then(|b| b.parse().ok()).unwrap_or(0);
304                            first_crlf + 2 + bytes + 2
305                        }
306                        _ => first_crlf + 2,
307                    };
308                    memcached_req_buf = memcached_req_buf[consumed..].to_vec();
309                }
310            } else if protocol == Protocol::Kafka {
311                kafka_req_buf.extend_from_slice(data);
312                while ocular_protocol::kafka::kafka_frame_complete(&kafka_req_buf) {
313                    let frame_len = i32::from_be_bytes([kafka_req_buf[0], kafka_req_buf[1], kafka_req_buf[2], kafka_req_buf[3]]) as usize + 4;
314                    let frame = &kafka_req_buf[..frame_len];
315                    if let Some(command) = parse_request(protocol, frame) {
316                        let full_command = extract_full_command(protocol, frame).unwrap_or_else(|| command.clone());
317                        *pending_w.lock().unwrap() = Some(PendingRequest {
318                            timestamp: SystemTime::now(),
319                            instant: Instant::now(),
320                            command,
321                            full_command,
322                        });
323                    }
324                    kafka_req_buf = kafka_req_buf[frame_len..].to_vec();
325                }
326            } else if protocol == Protocol::Postgres {
327                // Postgres: scan all messages in this read, keep SQL from Q/P only
328                let mut pos = 0;
329                while pos < data.len() {
330                    let first = data[pos];
331                    let is_typed = matches!(first, b'Q' | b'P' | b'B' | b'E' | b'D' | b'S' | b'X' | b'C' | b'p' | b'H' | b'F' | b'd' | b'c' | b'f');
332                    if !is_typed { break; }
333                    if pos + 5 > data.len() { break; }
334                    let len = u32::from_be_bytes([data[pos+1], data[pos+2], data[pos+3], data[pos+4]]) as usize;
335                    let end = pos + 1 + len;
336                    if end > data.len() { break; }
337                    // Only set pending for Q (simple query) or P (Parse with SQL)
338                    if first == b'Q' || first == b'P' {
339                        let msg = &data[pos..end];
340                        if let Some(command) = parse_request(protocol, msg) {
341                            let full_command = extract_full_command(protocol, msg).unwrap_or_else(|| command.clone());
342                            *pending_w.lock().unwrap() = Some(PendingRequest {
343                                timestamp: SystemTime::now(),
344                                instant: Instant::now(),
345                                command,
346                                full_command,
347                            });
348                        }
349                    }
350                    pos = end;
351                }
352            } else if let Some(command) = parse_request(protocol, data) {
353                let full_command = extract_full_command(protocol, data).unwrap_or_else(|| command.clone());
354                debug!(component = %name_req, %command);
355                *pending_w.lock().unwrap() = Some(PendingRequest {
356                    timestamp: SystemTime::now(),
357                    instant: Instant::now(),
358                    command,
359                    full_command,
360                });
361            }
362
363            sw.write_all(data).await?;
364        }
365        Ok::<_, anyhow::Error>(())
366    };
367
368    let process_mysql = process_info.clone();
369    let server_to_client = async move {
370        let mut buf = [0u8; 65536];
371        let mut mysql_buf: Vec<u8> = Vec::with_capacity(4096);
372        let mut http_resp_buf: Vec<u8> = Vec::with_capacity(4096);
373        let mut memcached_resp_buf: Vec<u8> = Vec::with_capacity(4096);
374        let mut kafka_resp_buf: Vec<u8> = Vec::with_capacity(4096);
375        let mut pg_resp_buf: Vec<u8> = Vec::with_capacity(4096);
376        let mut awaiting_response = false;
377        let mut memcached_awaiting = false;
378        loop {
379            let n = sr.read(&mut buf).await?;
380            if n == 0 { break; }
381            let data = &buf[..n];
382            cw.write_all(data).await?;
383
384            if protocol == Protocol::Mysql {
385                let has_pending = pending_r.lock().unwrap().is_some();
386                if has_pending || awaiting_response {
387                    awaiting_response = true;
388                    mysql_buf.extend_from_slice(data);
389                    if mysql_response_complete(&mysql_buf) {
390                        if let Some(req) = pending_r.lock().unwrap().take() {
391                            let latency = req.instant.elapsed();
392                            let response = parse_response(protocol, &mysql_buf).unwrap_or_default();
393                            let response_detail = format_response_detail(protocol, &mysql_buf).unwrap_or_default();
394                            let _ = tx_resp.send(ProxyEvent {
395                                timestamp: req.timestamp,
396                                component: name_resp.clone(),
397                                protocol,
398                                command: req.command,
399                                full_command: req.full_command,
400                                response,
401                                response_detail,
402                                latency,
403                                process: process_mysql.clone(),
404                                src: Some(src_resp.clone()),
405                                dest: Some(dest_resp.clone()),
406                    system: false,
407                            });
408                        }
409                        mysql_buf.clear();
410                        awaiting_response = false;
411                    }
412                }
413            } else if protocol == Protocol::Http {
414                http_resp_buf.extend_from_slice(data);
415                if ocular_protocol::http::http_response_complete(&http_resp_buf) {
416                    if let Some(req) = pending_r.lock().unwrap().take() {
417                        let latency = req.instant.elapsed();
418                        let response = parse_response(protocol, &http_resp_buf).unwrap_or_default();
419                        let response_detail = format_response_detail(protocol, &http_resp_buf).unwrap_or_else(|| response.clone());
420                        let _ = tx_resp.send(ProxyEvent {
421                            timestamp: req.timestamp,
422                            component: name_resp.clone(),
423                            protocol,
424                            command: req.command,
425                            full_command: req.full_command,
426                            response,
427                            response_detail,
428                            latency,
429                            process: process_info.clone(),
430                                    src: Some(src_resp.clone()),
431                                    dest: Some(dest_resp.clone()),
432                    system: false,
433                        });
434                    }
435                    http_resp_buf.clear();
436                }
437            } else if protocol == Protocol::Amqp {
438                // AMQP: loop through all server frames
439                let mut pos = 0;
440                while pos < data.len() {
441                    let frame_data = &data[pos..];
442                    let Some(flen) = amqp_frame_len(frame_data) else { break };
443                    if let Some(frame) = parse_amqp_frame(frame_data) {
444                        // Skip content header and body frames — handled below with method
445                        if frame.frame_type == 2 || frame.frame_type == 3 {
446                            pos += flen;
447                            continue;
448                        }
449                        // Heartbeat: skip
450                        if frame.frame_type == 8 {
451                            pos += flen;
452                            continue;
453                        }
454                    }
455
456                    // Extract body from subsequent Header+Body frames
457                    let mut body_text = String::new();
458                    let mut peek = pos + flen;
459                    while peek < data.len() {
460                        let peek_data = &data[peek..];
461                        let Some(plen) = amqp_frame_len(peek_data) else { break };
462                        if let Some(pf) = parse_amqp_frame(peek_data) {
463                            if pf.frame_type == 2 {
464                                // Header frame, skip
465                            } else if pf.frame_type == 3 {
466                                // Body frame
467                                if let Some(body) = &pf.body {
468                                    body_text = String::from_utf8_lossy(body).to_string();
469                                }
470                            } else {
471                                break;
472                            }
473                        } else {
474                            break;
475                        }
476                        peek += plen;
477                    }
478
479                    if let Some(req) = pending_r.lock().unwrap().take() {
480                        let latency = req.instant.elapsed();
481                        let mut response = parse_response(protocol, frame_data).unwrap_or_default();
482                        let mut response_detail = format_response_detail(protocol, frame_data).unwrap_or_else(|| response.clone());
483                        if !body_text.is_empty() {
484                            response = format!("{} | {}", response, body_text);
485                            response_detail = format!("{}\nBody: {}", response_detail, body_text);
486                        }
487                        let _ = tx_resp.send(ProxyEvent {
488                            timestamp: req.timestamp,
489                            component: name_resp.clone(),
490                            protocol,
491                            command: req.command,
492                            full_command: req.full_command,
493                            response,
494                            response_detail,
495                            latency,
496                            process: process_info.clone(),
497                                    src: Some(src_resp.clone()),
498                                    dest: Some(dest_resp.clone()),
499                    system: false,
500                        });
501                    } else if let Some(frame) = parse_amqp_frame(frame_data) {
502                        // Server-initiated method (e.g. Basic.Deliver) — emit as standalone
503                        if let Some(ref method) = frame.method {
504                            let response = if body_text.is_empty() { String::new() } else { body_text.clone() };
505                            let response_detail = if body_text.is_empty() { String::new() } else { body_text.clone() };
506                            let command = method.summary.clone();
507                            let _ = tx_resp.send(ProxyEvent {
508                                timestamp: SystemTime::now(),
509                                component: name_resp.clone(),
510                                protocol,
511                                command,
512                                full_command: method.detail.clone(),
513                                response,
514                                response_detail,
515                                latency: std::time::Duration::ZERO,
516                                process: process_info.clone(),
517                                    src: Some(dest_resp.clone()),
518                                    dest: Some(src_resp.clone()),
519                    system: false,
520                            });
521                        }
522                    }
523                    // Advance past the method frame + any header/body frames we consumed
524                    pos = peek;
525                }
526            } else if protocol == Protocol::Postgres {
527                // Buffer until ReadyForQuery, then emit single event
528                pg_resp_buf.extend_from_slice(data);
529                if postgres_response_complete(&pg_resp_buf) {
530                    if let Some(req) = pending_r.lock().unwrap().take() {
531                        let latency = req.instant.elapsed();
532                        let response = parse_response(protocol, &pg_resp_buf).unwrap_or_default();
533                        let response_detail = format_response_detail(protocol, &pg_resp_buf).unwrap_or_else(|| response.clone());
534                        let _ = tx_resp.send(ProxyEvent {
535                            timestamp: req.timestamp,
536                            component: name_resp.clone(),
537                            protocol,
538                            command: req.command,
539                            full_command: req.full_command,
540                            response,
541                            response_detail,
542                            latency,
543                            process: process_info.clone(),
544                            src: Some(src_resp.clone()),
545                            dest: Some(dest_resp.clone()),
546                            system: false,
547                        });
548                    }
549                    pg_resp_buf.clear();
550                }
551            } else if protocol == Protocol::Memcached {
552                let has_pending = pending_r.lock().unwrap().is_some();
553                if has_pending || memcached_awaiting {
554                    memcached_awaiting = true;
555                    memcached_resp_buf.extend_from_slice(data);
556                    if ocular_protocol::memcached::memcached_response_complete(&memcached_resp_buf) {
557                        if let Some(req) = pending_r.lock().unwrap().take() {
558                            let latency = req.instant.elapsed();
559                            let response = parse_response(protocol, &memcached_resp_buf).unwrap_or_default();
560                            let response_detail = format_response_detail(protocol, &memcached_resp_buf).unwrap_or_else(|| response.clone());
561                            let _ = tx_resp.send(ProxyEvent {
562                                timestamp: req.timestamp,
563                                component: name_resp.clone(),
564                                protocol,
565                                command: req.command,
566                                full_command: req.full_command,
567                                response,
568                                response_detail,
569                                latency,
570                                process: process_info.clone(),
571                                src: Some(src_resp.clone()),
572                                dest: Some(dest_resp.clone()),
573                    system: false,
574                            });
575                        }
576                        memcached_resp_buf.clear();
577                        memcached_awaiting = false;
578                    }
579                }
580            } else if protocol == Protocol::Kafka {
581                kafka_resp_buf.extend_from_slice(data);
582                while ocular_protocol::kafka::kafka_frame_complete(&kafka_resp_buf) {
583                    let frame_len = i32::from_be_bytes([kafka_resp_buf[0], kafka_resp_buf[1], kafka_resp_buf[2], kafka_resp_buf[3]]) as usize + 4;
584                    if let Some(req) = pending_r.lock().unwrap().take() {
585                        let latency = req.instant.elapsed();
586                        let response = parse_response(protocol, &kafka_resp_buf[..frame_len]).unwrap_or_default();
587                        let response_detail = format_response_detail(protocol, &kafka_resp_buf[..frame_len]).unwrap_or_else(|| response.clone());
588                        let _ = tx_resp.send(ProxyEvent {
589                            timestamp: req.timestamp,
590                            component: name_resp.clone(),
591                            protocol,
592                            command: req.command,
593                            full_command: req.full_command,
594                            response,
595                            response_detail,
596                            latency,
597                            process: process_info.clone(),
598                            src: Some(src_resp.clone()),
599                            dest: Some(dest_resp.clone()),
600                    system: false,
601                        });
602                    }
603                    kafka_resp_buf = kafka_resp_buf[frame_len..].to_vec();
604                }
605            } else {
606                // Redis/MongoDB: single request/response per read
607                if let Some(req) = pending_r.lock().unwrap().take() {
608                    let latency = req.instant.elapsed();
609                    let response = parse_response(protocol, data).unwrap_or_default();
610                    let response_detail = format_response_detail(protocol, data).unwrap_or_else(|| response.clone());
611                    let _ = tx_resp.send(ProxyEvent {
612                        timestamp: req.timestamp,
613                        component: name_resp.clone(),
614                        protocol,
615                        command: req.command,
616                        full_command: req.full_command,
617                        response,
618                        response_detail,
619                        latency,
620                        process: process_info.clone(),
621                        src: Some(src_resp.clone()),
622                        dest: Some(dest_resp.clone()),
623                    system: false,
624                    });
625                }
626            }
627        }
628        Ok::<_, anyhow::Error>(())
629    };
630
631    tokio::pin!(client_to_server);
632    tokio::pin!(server_to_client);
633
634    tokio::select! {
635        r = &mut client_to_server => {
636            // Client closed write end; give server time to send final response
637            if r.is_ok() && pending_final.lock().unwrap().is_some() {
638                let _ = tokio::time::timeout(
639                    Duration::from_millis(500),
640                    &mut server_to_client,
641                ).await;
642            }
643        },
644        r = &mut server_to_client => r?,
645    }
646    Ok(())
647}
648
/// Clear the `CLIENT_SSL` capability bit (0x0800) in a MySQL server greeting.
///
/// `packet` is the complete framed packet: a 4-byte header followed by the payload.
/// Only a payload whose first byte is protocol version 10 is touched. Anything else,
/// or a payload too short to contain the lower 16-bit capability field, is left as-is.
fn strip_mysql_ssl_flag(packet: &mut [u8]) {
    const HEADER_LEN: usize = 4;
    const PROTOCOL_V10: u8 = 10;
    const CLIENT_SSL: u16 = 0x0800;
    // Fixed-size fields between the version string's terminator and the capability
    // flags: the NUL itself, thread id (4), salt part 1 (8), filler (1).
    const FIXED_BEFORE_CAPS: usize = 1 + 4 + 8 + 1;

    let payload = match packet.get_mut(HEADER_LEN..) {
        Some(p) => p,
        None => return,
    };
    if payload.first() != Some(&PROTOCOL_V10) {
        return;
    }

    // The server version starts right after the protocol byte and is NUL-terminated.
    // With no NUL present it is treated as running to the end of the payload, which
    // makes the bounds check below bail out.
    let version_end = match payload[1..].iter().position(|&b| b == 0) {
        Some(i) => i + 1,
        None => payload.len(),
    };
    let caps_at = version_end + FIXED_BEFORE_CAPS;

    if let Some(field) = payload.get_mut(caps_at..caps_at + 2) {
        let caps = u16::from_le_bytes([field[0], field[1]]) & !CLIENT_SSL;
        field.copy_from_slice(&caps.to_le_bytes());
    }
}
665
/// Resolve which process owns a local TCP port (the client's ephemeral port).
///
/// Shells out to `lsof` on macOS and `ss` on other platforms. The owner is returned
/// formatted as `"[PID] COMMAND"`. Sockets owned by this proxy process are skipped on
/// both platforms. The proxy's own upstream connection can end up with the same local
/// port number as the client's socket, and must not be reported as the client.
///
/// Returns `None` in any of these cases:
/// - the tool cannot be spawned;
/// - its output contains no matching socket;
/// - the only owner is this process.
///
/// Note: this runs a child process synchronously and blocks until it exits.
fn resolve_peer_process(port: u16) -> Option<String> {
    use std::process::Command;
    let my_pid = std::process::id().to_string();

    if cfg!(target_os = "macos") {
        // lsof -i tcp:PORT -sTCP:ESTABLISHED -Fp -Fc
        // Returns multiple process entries; our own PID is skipped by the parser.
        let output = Command::new("lsof")
            .args(["-i", &format!("tcp:{}", port), "-sTCP:ESTABLISHED", "-Fp", "-Fc"])
            .output()
            .ok()?;
        parse_lsof_owner(&String::from_utf8_lossy(&output.stdout), &my_pid)
    } else {
        // Linux: ss -tnp sport = :PORT
        let output = Command::new("ss")
            .args(["-tnp", &format!("sport = :{}", port)])
            .output()
            .ok()?;
        parse_ss_owner(&String::from_utf8_lossy(&output.stdout), &my_pid)
    }
}

/// Pick the first process in `lsof -F` output (`p<pid>` / `c<command>` lines) whose
/// PID differs from `my_pid`, formatted as `"[PID] COMMAND"`.
fn parse_lsof_owner(text: &str, my_pid: &str) -> Option<String> {
    let mut current_pid = "";
    let mut current_cmd = "";
    for line in text.lines() {
        if let Some(p) = line.strip_prefix('p') {
            // A new process record starts: report the previous one if it wasn't us.
            if !current_pid.is_empty() && current_pid != my_pid {
                return Some(format!("[{}] {}", current_pid, current_cmd));
            }
            current_pid = p;
            current_cmd = "";
        }
        if let Some(c) = line.strip_prefix('c') {
            current_cmd = c;
        }
    }
    // The last record is never followed by another `p` line, so check it here.
    if !current_pid.is_empty() && current_pid != my_pid {
        Some(format!("[{}] {}", current_pid, current_cmd))
    } else {
        None
    }
}

/// Pick the first `users:(("name",pid=N,...))` entry in `ss -p` output whose PID
/// differs from `my_pid`, formatted as `"[PID] name"`. The PID is `"?"` when the
/// entry has no `pid=` field.
fn parse_ss_owner(text: &str, my_pid: &str) -> Option<String> {
    const USERS_PREFIX: &str = "users:((\"";
    for line in text.lines() {
        if let Some(start) = line.find(USERS_PREFIX) {
            let rest = &line[start + USERS_PREFIX.len()..];
            if let Some(end) = rest.find('"') {
                let proc_name = &rest[..end];
                let pid = rest
                    .find("pid=")
                    .and_then(|i| rest[i + 4..].split(|c: char| !c.is_ascii_digit()).next())
                    .unwrap_or("?");
                // Same-numbered local port on one of the proxy's own sockets: keep looking.
                if pid == my_pid {
                    continue;
                }
                return Some(format!("[{}] {}", pid, proc_name));
            }
        }
    }
    None
}
722
723/// TLS certificate verifier that accepts any certificate (for proxying to known backends).
724#[derive(Debug)]
725struct NoVerify;
726
727impl rustls::client::danger::ServerCertVerifier for NoVerify {
728    fn verify_server_cert(
729        &self, _: &rustls::pki_types::CertificateDer<'_>, _: &[rustls::pki_types::CertificateDer<'_>],
730        _: &rustls::pki_types::ServerName<'_>, _: &[u8], _: rustls::pki_types::UnixTime,
731    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
732        Ok(rustls::client::danger::ServerCertVerified::assertion())
733    }
734    fn verify_tls12_signature(
735        &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
736    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
737        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
738    }
739    fn verify_tls13_signature(
740        &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
741    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
742        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
743    }
744    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
745        vec![
746            rustls::SignatureScheme::RSA_PKCS1_SHA256,
747            rustls::SignatureScheme::RSA_PKCS1_SHA384,
748            rustls::SignatureScheme::RSA_PKCS1_SHA512,
749            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
750            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
751            rustls::SignatureScheme::RSA_PSS_SHA256,
752            rustls::SignatureScheme::RSA_PSS_SHA384,
753            rustls::SignatureScheme::RSA_PSS_SHA512,
754            rustls::SignatureScheme::ED25519,
755        ]
756    }
757}
758
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a framed MySQL protocol-v10 greeting whose lower capability field is `caps`.
    /// `version` is written without its NUL terminator; the helper appends it.
    fn greeting(version: &[u8], caps: u16) -> Vec<u8> {
        let mut payload = vec![10]; // protocol version
        payload.extend_from_slice(version);
        payload.push(0); // version terminator
        payload.extend_from_slice(&[0u8; 4]); // thread id
        payload.extend_from_slice(&[0u8; 8]); // salt part 1
        payload.push(0); // filler
        payload.extend_from_slice(&caps.to_le_bytes());
        payload.extend_from_slice(&[0u8; 13]);

        let pkt_len = payload.len();
        let mut pkt = vec![
            (pkt_len & 0xff) as u8,
            ((pkt_len >> 8) & 0xff) as u8,
            ((pkt_len >> 16) & 0xff) as u8,
            0,
        ];
        pkt.extend_from_slice(&payload);
        pkt
    }

    /// Find the capability flags offset in a MySQL greeting packet
    fn caps_offset(pkt: &[u8]) -> Option<usize> {
        if pkt.len() < 5 { return None; }
        let mut pos = 5;
        // Skip null-terminated server version
        while pos < pkt.len() && pkt[pos] != 0 { pos += 1; }
        pos += 1; // null
        if pos + 13 > pkt.len() { return None; }
        pos += 4; // thread id
        pos += 8; // salt part 1
        pos += 1; // filler
        Some(pos)
    }

    #[test]
    fn test_strip_mysql_ssl_flag_short_packet() {
        let mut buf = vec![0u8; 3];
        strip_mysql_ssl_flag(&mut buf);
        assert_eq!(buf, vec![0u8; 3]);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_not_greeting() {
        let mut buf = vec![0u8; 10];
        buf[4] = 9;
        strip_mysql_ssl_flag(&mut buf);
        assert_eq!(buf[4], 9);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_clears_ssl_bit() {
        let mut pkt = greeting(b"5.7.0", 0x0800); // SSL flag set
        let off = caps_offset(&pkt).unwrap();
        assert!(u16::from_le_bytes([pkt[off], pkt[off + 1]]) & 0x0800 != 0);

        strip_mysql_ssl_flag(&mut pkt);

        assert_eq!(u16::from_le_bytes([pkt[off], pkt[off + 1]]) & 0x0800, 0);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_preserves_other_bits() {
        // With every bit set, only 0x0800 may be cleared; this catches an
        // implementation that wipes the whole capability field.
        let mut pkt = greeting(b"8.0.36", 0xFFFF);
        let off = caps_offset(&pkt).unwrap();

        strip_mysql_ssl_flag(&mut pkt);

        assert_eq!(u16::from_le_bytes([pkt[off], pkt[off + 1]]), 0xF7FF);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_unterminated_version() {
        // No NUL after the version string: the capability field can't be located.
        let mut pkt = vec![21u8, 0, 0, 0, 10];
        pkt.extend_from_slice(&[b'5'; 20]);
        let before = pkt.clone();

        strip_mysql_ssl_flag(&mut pkt);

        assert_eq!(pkt, before);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_truncated_before_caps() {
        // Only one of the two capability bytes is present: leave the packet alone.
        let mut pkt = greeting(b"5.7.0", 0x0800);
        let off = caps_offset(&pkt).unwrap();
        pkt.truncate(off + 1);
        let before = pkt.clone();

        strip_mysql_ssl_flag(&mut pkt);

        assert_eq!(pkt, before);
    }

    #[test]
    fn test_resolve_peer_process_does_not_panic() {
        let result = std::panic::catch_unwind(|| resolve_peer_process(0));
        assert!(result.is_ok());
    }
}