// ocular_proxy/lib.rs

1use anyhow::Result;
2use ocular_protocol::{Protocol, mysql::mysql_response_complete, postgres::postgres_response_complete, parse_request, parse_response, extract_full_command, format_response_detail, parse_amqp_frame, parse_amqp_request_full, is_async_method, amqp_frame_len};
3use std::sync::{Arc, Mutex};
4use std::time::{Duration, Instant, SystemTime};
5use std::sync::atomic::{AtomicUsize, Ordering};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, AsyncRead, AsyncWrite};
7use tokio::net::{TcpListener, TcpStream};
8use tokio::sync::broadcast;
9use tracing::{info, warn, error, debug};
10
11pub use ocular_protocol::ProxyEvent;
12
/// Connection state for a proxy component, shared between proxy and TUI.
///
/// One value exists per component name inside a [`StatusMap`]. `Default` yields the
/// "nothing has happened yet" state: no connections, no connector, no error.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConnectionState {
    /// Number of client connections currently being handled for this component.
    pub active_connections: usize,
    /// Set to `true` by `run_proxy` once its listener has been bound successfully.
    pub has_connector: bool,
    /// Most recent error message recorded for this component (for example a bind
    /// failure or a connection error), if any.
    pub last_error: Option<String>,
    /// Wall-clock time at which the most recent connection task started, if any.
    pub last_active_at: Option<SystemTime>,
}
21
22/// Shared map from component name to connection state
23pub type StatusMap = Arc<Mutex<std::collections::HashMap<String, ConnectionState>>>;
24
/// Pending request info.
///
/// Captured when a request is parsed on the client-to-server path and consumed when the
/// matching response is seen, so that the two can be reported as a single event.
#[derive(Debug, Clone)]
struct PendingRequest {
    /// Wall-clock time at which the request was observed; used as the event timestamp.
    timestamp: SystemTime,
    /// Monotonic instant at which the request was observed; used to compute latency.
    instant: Instant,
    /// Short, summarised form of the request.
    command: String,
    /// Full form of the request (falls back to `command` when no fuller form is available).
    full_command: String,
}
32
33pub async fn run_proxy(
34    listen_addr: String,
35    remote_addr: String,
36    name: String,
37    protocol: Protocol,
38    tx: broadcast::Sender<ProxyEvent>,
39    mut shutdown: tokio::sync::watch::Receiver<bool>,
40    status: StatusMap,
41) -> Result<()> {
42    let listener = match TcpListener::bind(&listen_addr).await {
43        Ok(l) => l,
44        Err(e) => {
45            let msg = format!("bind failed on {}: {}", listen_addr, e);
46            let _ = tx.send(ProxyEvent::system_event(&name, msg));
47            status.lock().unwrap().entry(name.clone()).or_default().last_error = Some(format!("bind failed: {}", e));
48            return Err(e.into());
49        }
50    };
51    let conn_count = Arc::new(AtomicUsize::new(0));
52    {
53        let mut map = status.lock().unwrap();
54        map.entry(name.clone()).or_default().has_connector = true;
55    }
56    info!(component = %name, listen = %listen_addr, remote = %remote_addr, ?protocol, "proxy listening");
57
58    loop {
59        tokio::select! {
60            result = listener.accept() => {
61                let (client, peer) = result?;
62                debug!(component = %name, peer = %peer, "new client connection");
63                let remote = remote_addr.clone();
64                let name = name.clone();
65                let tx = tx.clone();
66                let process = resolve_peer_process(peer.port());
67                let peer_addr = peer.to_string();
68                let remote_for_conn = remote.clone();
69                let conn_count = conn_count.clone();
70                let status = status.clone();
71                let protocol_for_conn = protocol;
72                tokio::spawn(async move {
73                    conn_count.fetch_add(1, Ordering::Relaxed);
74                    {
75                        let mut map = status.lock().unwrap();
76                        let s = map.entry(name.clone()).or_default();
77                        s.active_connections = conn_count.load(Ordering::Relaxed);
78                        s.last_active_at = Some(SystemTime::now());
79                    }
80                    if let Err(e) = handle_conn(client, &remote, &name, protocol_for_conn, &tx, process, peer_addr, remote_for_conn).await {
81                        warn!(component = %name, remote = %remote, error = %e, "connection ended with error");
82                        let _ = tx.send(ProxyEvent::system_event(&name, format!("connection error: {}", e)));
83                        status.lock().unwrap().entry(name.clone()).or_default().last_error = Some(e.to_string());
84                    }
85                    let remaining = conn_count.fetch_sub(1, Ordering::Relaxed).saturating_sub(1);
86                    status.lock().unwrap().entry(name.clone()).or_default().active_connections = remaining;
87                });
88            }
89            _ = shutdown.changed() => {
90                info!(component = %name, "proxy shutting down");
91                break;
92            }
93        }
94    }
95    Ok(())
96}
97
98#[allow(clippy::too_many_arguments)]
99async fn handle_conn(
100    mut client: TcpStream,
101    remote_addr: &str,
102    name: &str,
103    protocol: Protocol,
104    tx: &broadcast::Sender<ProxyEvent>,
105    process: Option<String>,
106    src: String,
107    dest: String,
108) -> Result<()> {
109    // Parse remote address: detect https:// for TLS outbound
110    let (actual_addr, use_tls, tls_host) = if remote_addr.starts_with("https://") {
111        let stripped = remote_addr.strip_prefix("https://").unwrap();
112        let host = stripped.split(':').next().unwrap_or(stripped).to_string();
113        (stripped.to_string(), true, host)
114    } else {
115        let stripped = remote_addr.strip_prefix("http://").unwrap_or(remote_addr);
116        (stripped.to_string(), false, String::new())
117    };
118
119    let tcp_stream = match TcpStream::connect(&actual_addr).await {
120        Ok(s) => {
121            debug!(component = %name, remote = %actual_addr, "connected to remote");
122            s
123        }
124        Err(e) => {
125            error!(component = %name, remote = %actual_addr, error = %e,
126                "failed to connect to remote — is the service running?");
127            let _ = tx.send(ProxyEvent::system_event(name, format!("cannot reach {} ({})", actual_addr, e)));
128            if protocol == Protocol::Redis {
129                let err_msg = format!("-ERR ocular proxy: cannot reach {} ({})\r\n", actual_addr, e);
130                let _ = client.write_all(err_msg.as_bytes()).await;
131            }
132            return Err(e.into());
133        }
134    };
135
136    let (sr, sw): (Box<dyn AsyncRead + Unpin + Send>, Box<dyn AsyncWrite + Unpin + Send>) = if use_tls {
137        let config = rustls::ClientConfig::builder()
138            .dangerous()
139            .with_custom_certificate_verifier(Arc::new(NoVerify))
140            .with_no_client_auth();
141        let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
142        let domain = rustls::pki_types::ServerName::try_from(tls_host)
143            .map_err(|e| anyhow::anyhow!("invalid TLS hostname: {}", e))?;
144        let tls_stream = connector.connect(domain, tcp_stream).await?;
145        let (r, w) = tokio::io::split(tls_stream);
146        (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
147    } else {
148        let (r, w) = tokio::io::split(tcp_stream);
149        (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
150    };
151
152    let mut sr = sr;
153    let mut sw = sw;
154
155    // For MySQL: strip SSL from greeting
156    if protocol == Protocol::Mysql {
157        let mut greeting_buf = [0u8; 65536];
158        let n = sr.read(&mut greeting_buf).await?;
159        if n == 0 { return Ok(()); }
160        let mut greeting = greeting_buf[..n].to_vec();
161        strip_mysql_ssl_flag(&mut greeting);
162        client.write_all(&greeting).await?;
163        debug!(component = %name, "forwarded MySQL greeting with SSL stripped");
164    }
165
166    // For PostgreSQL: strip SSL by forwarding negotiation to server but replying N to client.
167    // This lets the server know the connection won't be encrypted (may affect auth requirements).
168    if protocol == Protocol::Postgres {
169        let mut buf = [0u8; 256];
170        let n = client.read(&mut buf).await?;
171        if n == 0 { return Ok(()); }
172        let data = &buf[..n];
173        let neg_code = if n >= 8 {
174            u32::from_be_bytes([data[4], data[5], data[6], data[7]])
175        } else { 0 };
176        if neg_code == 80877103 || neg_code == 80877104 {
177            // Forward negotiation to server so it knows the connection state
178            sw.write_all(data).await?;
179            // Read server's response (single byte: N or S)
180            let mut resp = [0u8; 1];
181            let rn = sr.read(&mut resp).await?;
182            if rn == 0 { return Ok(()); }
183            // Always tell client: no SSL/GSS (force plaintext for proxy to parse)
184            client.write_all(b"N").await?;
185        } else {
186            // Not a negotiation request, forward as Startup
187            sw.write_all(data).await?;
188        }
189    }
190
191    let (mut cr, mut cw) = client.split();
192
193    let pending: Arc<Mutex<Option<PendingRequest>>> = Arc::new(Mutex::new(None));
194
195    let name_req = name.to_string();
196    let name_resp = name.to_string();
197    let tx_req = tx.clone();
198    let tx_resp = tx.clone();
199    let pending_w = pending.clone();
200    let pending_final = pending.clone();
201    let pending_r = pending;
202    let process_info = process;
203
204    let process_req = process_info.clone();
205    let src_req = src.clone();
206    let dest_req = dest.clone();
207    let src_resp = src.clone();
208    let dest_resp = dest;
209    let client_to_server = async move {
210        let mut buf = [0u8; 65536];
211        let mut http_req_buf: Vec<u8> = Vec::with_capacity(4096);
212        let mut memcached_req_buf: Vec<u8> = Vec::with_capacity(4096);
213        let mut kafka_req_buf: Vec<u8> = Vec::with_capacity(4096);
214        loop {
215            let n = cr.read(&mut buf).await?;
216            if n == 0 { break; }
217            let data = &buf[..n];
218
219            if protocol == Protocol::Amqp {
220                // AMQP: loop through all frames in this read
221                let mut pos = 0;
222                while pos < data.len() {
223                    let frame_data = &data[pos..];
224                    let Some(flen) = amqp_frame_len(frame_data) else { break };
225                    if let Some(frame) = parse_amqp_frame(frame_data) {
226                        // Skip heartbeat — not a real request
227                        if frame.frame_type == 8 {
228                            pos += flen;
229                            continue;
230                        }
231                        if let Some(ref method) = frame.method {
232                            if is_async_method(method.class_id, method.method_id) {
233                                let (summary, detail) = parse_amqp_request_full(frame_data)
234                                    .unwrap_or_else(|| (method.summary.clone(), method.detail.clone()));
235                                let _ = tx_req.send(ProxyEvent {
236                                    timestamp: SystemTime::now(),
237                                    component: name_req.clone(),
238                                    protocol,
239                                    command: summary,
240                                    full_command: detail.clone(),
241                                    response: String::new(),
242                                    response_detail: detail,
243                                    latency: std::time::Duration::ZERO,
244                                    process: process_req.clone(),
245                                    src: Some(src_req.clone()),
246                                    dest: Some(dest_req.clone()),
247                    system: false,
248                                });
249                            } else {
250                                debug!(component = %name_req, command = %method.summary);
251                                *pending_w.lock().unwrap() = Some(PendingRequest {
252                                    timestamp: SystemTime::now(),
253                                    instant: Instant::now(),
254                                    command: method.summary.clone(),
255                                    full_command: method.detail.clone(),
256                                });
257                            }
258                        }
259                    }
260                    pos += flen;
261                }
262            } else if protocol == Protocol::Http {
263                http_req_buf.extend_from_slice(data);
264                if ocular_protocol::http::http_request_complete(&http_req_buf) {
265                    if let Some(command) = parse_request(protocol, &http_req_buf) {
266                        let full_command = extract_full_command(protocol, &http_req_buf).unwrap_or_else(|| command.clone());
267                        *pending_w.lock().unwrap() = Some(PendingRequest {
268                            timestamp: SystemTime::now(),
269                            instant: Instant::now(),
270                            command,
271                            full_command,
272                        });
273                    }
274                    http_req_buf.clear();
275                }
276            } else if protocol == Protocol::Memcached {
277                memcached_req_buf.extend_from_slice(data);
278                while ocular_protocol::memcached::memcached_request_complete(&memcached_req_buf) {
279                    // If there's already a pending request that won't get a response pairing,
280                    // emit it as a standalone event
281                    if let Some(prev) = pending_w.lock().unwrap().take() {
282                        let _ = tx_req.send(ProxyEvent {
283                            timestamp: prev.timestamp,
284                            component: name_req.clone(),
285                            protocol,
286                            command: prev.command,
287                            full_command: prev.full_command,
288                            response: String::new(),
289                            response_detail: String::new(),
290                            latency: Duration::ZERO,
291                            process: process_req.clone(),
292                            src: Some(src_req.clone()),
293                            dest: Some(dest_req.clone()),
294                            system: false,
295                        });
296                    }
297                    if let Some(command) = parse_request(protocol, &memcached_req_buf) {
298                        let full_command = extract_full_command(protocol, &memcached_req_buf).unwrap_or_else(|| command.clone());
299                        *pending_w.lock().unwrap() = Some(PendingRequest {
300                            timestamp: SystemTime::now(),
301                            instant: Instant::now(),
302                            command,
303                            full_command,
304                        });
305                    }
306                    // Advance past this request
307                    let s = std::str::from_utf8(&memcached_req_buf).unwrap_or("");
308                    let first_crlf = s.find("\r\n").unwrap_or(0);
309                    let line = &s[..first_crlf];
310                    let parts: Vec<&str> = line.split_whitespace().collect();
311                    let cmd = parts.first().map(|c| c.to_uppercase()).unwrap_or_default();
312                    let consumed = match cmd.as_str() {
313                        "SET" | "ADD" | "REPLACE" | "APPEND" | "PREPEND" | "CAS" => {
314                            let bytes: usize = parts.get(4).and_then(|b| b.parse().ok()).unwrap_or(0);
315                            first_crlf + 2 + bytes + 2
316                        }
317                        _ => first_crlf + 2,
318                    };
319                    memcached_req_buf = memcached_req_buf[consumed..].to_vec();
320                }
321            } else if protocol == Protocol::Kafka {
322                kafka_req_buf.extend_from_slice(data);
323                while ocular_protocol::kafka::kafka_frame_complete(&kafka_req_buf) {
324                    let frame_len = i32::from_be_bytes([kafka_req_buf[0], kafka_req_buf[1], kafka_req_buf[2], kafka_req_buf[3]]) as usize + 4;
325                    let frame = &kafka_req_buf[..frame_len];
326                    if let Some(command) = parse_request(protocol, frame) {
327                        let full_command = extract_full_command(protocol, frame).unwrap_or_else(|| command.clone());
328                        *pending_w.lock().unwrap() = Some(PendingRequest {
329                            timestamp: SystemTime::now(),
330                            instant: Instant::now(),
331                            command,
332                            full_command,
333                        });
334                    }
335                    kafka_req_buf = kafka_req_buf[frame_len..].to_vec();
336                }
337            } else if protocol == Protocol::Postgres {
338                // Postgres: scan all messages in this read, keep SQL from Q/P only
339                let mut pos = 0;
340                while pos < data.len() {
341                    let first = data[pos];
342                    let is_typed = matches!(first, b'Q' | b'P' | b'B' | b'E' | b'D' | b'S' | b'X' | b'C' | b'p' | b'H' | b'F' | b'd' | b'c' | b'f');
343                    if !is_typed { break; }
344                    if pos + 5 > data.len() { break; }
345                    let len = u32::from_be_bytes([data[pos+1], data[pos+2], data[pos+3], data[pos+4]]) as usize;
346                    let end = pos + 1 + len;
347                    if end > data.len() { break; }
348                    // Only set pending for Q (simple query) or P (Parse with SQL)
349                    if first == b'Q' || first == b'P' {
350                        let msg = &data[pos..end];
351                        if let Some(command) = parse_request(protocol, msg) {
352                            let full_command = extract_full_command(protocol, msg).unwrap_or_else(|| command.clone());
353                            *pending_w.lock().unwrap() = Some(PendingRequest {
354                                timestamp: SystemTime::now(),
355                                instant: Instant::now(),
356                                command,
357                                full_command,
358                            });
359                        }
360                    }
361                    pos = end;
362                }
363            } else if let Some(command) = parse_request(protocol, data) {
364                let full_command = extract_full_command(protocol, data).unwrap_or_else(|| command.clone());
365                debug!(component = %name_req, %command);
366                *pending_w.lock().unwrap() = Some(PendingRequest {
367                    timestamp: SystemTime::now(),
368                    instant: Instant::now(),
369                    command,
370                    full_command,
371                });
372            }
373
374            sw.write_all(data).await?;
375        }
376        Ok::<_, anyhow::Error>(())
377    };
378
379    let process_mysql = process_info.clone();
380    let server_to_client = async move {
381        let mut buf = [0u8; 65536];
382        let mut mysql_buf: Vec<u8> = Vec::with_capacity(4096);
383        let mut http_resp_buf: Vec<u8> = Vec::with_capacity(4096);
384        let mut memcached_resp_buf: Vec<u8> = Vec::with_capacity(4096);
385        let mut kafka_resp_buf: Vec<u8> = Vec::with_capacity(4096);
386        let mut pg_resp_buf: Vec<u8> = Vec::with_capacity(4096);
387        let mut awaiting_response = false;
388        let mut memcached_awaiting = false;
389        loop {
390            let n = sr.read(&mut buf).await?;
391            if n == 0 { break; }
392            let data = &buf[..n];
393            cw.write_all(data).await?;
394
395            if protocol == Protocol::Mysql {
396                let has_pending = pending_r.lock().unwrap().is_some();
397                if has_pending || awaiting_response {
398                    awaiting_response = true;
399                    mysql_buf.extend_from_slice(data);
400                    if mysql_response_complete(&mysql_buf) {
401                        if let Some(req) = pending_r.lock().unwrap().take() {
402                            let latency = req.instant.elapsed();
403                            let response = parse_response(protocol, &mysql_buf).unwrap_or_default();
404                            let response_detail = format_response_detail(protocol, &mysql_buf).unwrap_or_default();
405                            let _ = tx_resp.send(ProxyEvent {
406                                timestamp: req.timestamp,
407                                component: name_resp.clone(),
408                                protocol,
409                                command: req.command,
410                                full_command: req.full_command,
411                                response,
412                                response_detail,
413                                latency,
414                                process: process_mysql.clone(),
415                                src: Some(src_resp.clone()),
416                                dest: Some(dest_resp.clone()),
417                    system: false,
418                            });
419                        }
420                        mysql_buf.clear();
421                        awaiting_response = false;
422                    }
423                }
424            } else if protocol == Protocol::Http {
425                http_resp_buf.extend_from_slice(data);
426                if ocular_protocol::http::http_response_complete(&http_resp_buf) {
427                    if let Some(req) = pending_r.lock().unwrap().take() {
428                        let latency = req.instant.elapsed();
429                        let response = parse_response(protocol, &http_resp_buf).unwrap_or_default();
430                        let response_detail = format_response_detail(protocol, &http_resp_buf).unwrap_or_else(|| response.clone());
431                        let _ = tx_resp.send(ProxyEvent {
432                            timestamp: req.timestamp,
433                            component: name_resp.clone(),
434                            protocol,
435                            command: req.command,
436                            full_command: req.full_command,
437                            response,
438                            response_detail,
439                            latency,
440                            process: process_info.clone(),
441                                    src: Some(src_resp.clone()),
442                                    dest: Some(dest_resp.clone()),
443                    system: false,
444                        });
445                    }
446                    http_resp_buf.clear();
447                }
448            } else if protocol == Protocol::Amqp {
449                // AMQP: loop through all server frames
450                let mut pos = 0;
451                while pos < data.len() {
452                    let frame_data = &data[pos..];
453                    let Some(flen) = amqp_frame_len(frame_data) else { break };
454                    if let Some(frame) = parse_amqp_frame(frame_data) {
455                        // Skip content header and body frames — handled below with method
456                        if frame.frame_type == 2 || frame.frame_type == 3 {
457                            pos += flen;
458                            continue;
459                        }
460                        // Heartbeat: skip
461                        if frame.frame_type == 8 {
462                            pos += flen;
463                            continue;
464                        }
465                    }
466
467                    // Extract body from subsequent Header+Body frames
468                    let mut body_text = String::new();
469                    let mut peek = pos + flen;
470                    while peek < data.len() {
471                        let peek_data = &data[peek..];
472                        let Some(plen) = amqp_frame_len(peek_data) else { break };
473                        if let Some(pf) = parse_amqp_frame(peek_data) {
474                            if pf.frame_type == 2 {
475                                // Header frame, skip
476                            } else if pf.frame_type == 3 {
477                                // Body frame
478                                if let Some(body) = &pf.body {
479                                    body_text = String::from_utf8_lossy(body).to_string();
480                                }
481                            } else {
482                                break;
483                            }
484                        } else {
485                            break;
486                        }
487                        peek += plen;
488                    }
489
490                    if let Some(req) = pending_r.lock().unwrap().take() {
491                        let latency = req.instant.elapsed();
492                        let mut response = parse_response(protocol, frame_data).unwrap_or_default();
493                        let mut response_detail = format_response_detail(protocol, frame_data).unwrap_or_else(|| response.clone());
494                        if !body_text.is_empty() {
495                            response = format!("{} | {}", response, body_text);
496                            response_detail = format!("{}\nBody: {}", response_detail, body_text);
497                        }
498                        let _ = tx_resp.send(ProxyEvent {
499                            timestamp: req.timestamp,
500                            component: name_resp.clone(),
501                            protocol,
502                            command: req.command,
503                            full_command: req.full_command,
504                            response,
505                            response_detail,
506                            latency,
507                            process: process_info.clone(),
508                                    src: Some(src_resp.clone()),
509                                    dest: Some(dest_resp.clone()),
510                    system: false,
511                        });
512                    } else if let Some(frame) = parse_amqp_frame(frame_data) {
513                        // Server-initiated method (e.g. Basic.Deliver) — emit as standalone
514                        if let Some(ref method) = frame.method {
515                            let response = if body_text.is_empty() { String::new() } else { body_text.clone() };
516                            let response_detail = if body_text.is_empty() { String::new() } else { body_text.clone() };
517                            let command = method.summary.clone();
518                            let _ = tx_resp.send(ProxyEvent {
519                                timestamp: SystemTime::now(),
520                                component: name_resp.clone(),
521                                protocol,
522                                command,
523                                full_command: method.detail.clone(),
524                                response,
525                                response_detail,
526                                latency: std::time::Duration::ZERO,
527                                process: process_info.clone(),
528                                    src: Some(dest_resp.clone()),
529                                    dest: Some(src_resp.clone()),
530                    system: false,
531                            });
532                        }
533                    }
534                    // Advance past the method frame + any header/body frames we consumed
535                    pos = peek;
536                }
537            } else if protocol == Protocol::Postgres {
538                // Buffer until ReadyForQuery, then emit single event
539                pg_resp_buf.extend_from_slice(data);
540                if postgres_response_complete(&pg_resp_buf) {
541                    if let Some(req) = pending_r.lock().unwrap().take() {
542                        let latency = req.instant.elapsed();
543                        let response = parse_response(protocol, &pg_resp_buf).unwrap_or_default();
544                        let response_detail = format_response_detail(protocol, &pg_resp_buf).unwrap_or_else(|| response.clone());
545                        let _ = tx_resp.send(ProxyEvent {
546                            timestamp: req.timestamp,
547                            component: name_resp.clone(),
548                            protocol,
549                            command: req.command,
550                            full_command: req.full_command,
551                            response,
552                            response_detail,
553                            latency,
554                            process: process_info.clone(),
555                            src: Some(src_resp.clone()),
556                            dest: Some(dest_resp.clone()),
557                            system: false,
558                        });
559                    }
560                    pg_resp_buf.clear();
561                }
562            } else if protocol == Protocol::Memcached {
563                let has_pending = pending_r.lock().unwrap().is_some();
564                if has_pending || memcached_awaiting {
565                    memcached_awaiting = true;
566                    memcached_resp_buf.extend_from_slice(data);
567                    if ocular_protocol::memcached::memcached_response_complete(&memcached_resp_buf) {
568                        if let Some(req) = pending_r.lock().unwrap().take() {
569                            let latency = req.instant.elapsed();
570                            let response = parse_response(protocol, &memcached_resp_buf).unwrap_or_default();
571                            let response_detail = format_response_detail(protocol, &memcached_resp_buf).unwrap_or_else(|| response.clone());
572                            let _ = tx_resp.send(ProxyEvent {
573                                timestamp: req.timestamp,
574                                component: name_resp.clone(),
575                                protocol,
576                                command: req.command,
577                                full_command: req.full_command,
578                                response,
579                                response_detail,
580                                latency,
581                                process: process_info.clone(),
582                                src: Some(src_resp.clone()),
583                                dest: Some(dest_resp.clone()),
584                    system: false,
585                            });
586                        }
587                        memcached_resp_buf.clear();
588                        memcached_awaiting = false;
589                    }
590                }
591            } else if protocol == Protocol::Kafka {
592                kafka_resp_buf.extend_from_slice(data);
593                while ocular_protocol::kafka::kafka_frame_complete(&kafka_resp_buf) {
594                    let frame_len = i32::from_be_bytes([kafka_resp_buf[0], kafka_resp_buf[1], kafka_resp_buf[2], kafka_resp_buf[3]]) as usize + 4;
595                    if let Some(req) = pending_r.lock().unwrap().take() {
596                        let latency = req.instant.elapsed();
597                        let response = parse_response(protocol, &kafka_resp_buf[..frame_len]).unwrap_or_default();
598                        let response_detail = format_response_detail(protocol, &kafka_resp_buf[..frame_len]).unwrap_or_else(|| response.clone());
599                        let _ = tx_resp.send(ProxyEvent {
600                            timestamp: req.timestamp,
601                            component: name_resp.clone(),
602                            protocol,
603                            command: req.command,
604                            full_command: req.full_command,
605                            response,
606                            response_detail,
607                            latency,
608                            process: process_info.clone(),
609                            src: Some(src_resp.clone()),
610                            dest: Some(dest_resp.clone()),
611                    system: false,
612                        });
613                    }
614                    kafka_resp_buf = kafka_resp_buf[frame_len..].to_vec();
615                }
616            } else {
617                // Redis/MongoDB: single request/response per read
618                if let Some(req) = pending_r.lock().unwrap().take() {
619                    let latency = req.instant.elapsed();
620                    let response = parse_response(protocol, data).unwrap_or_default();
621                    let response_detail = format_response_detail(protocol, data).unwrap_or_else(|| response.clone());
622                    let _ = tx_resp.send(ProxyEvent {
623                        timestamp: req.timestamp,
624                        component: name_resp.clone(),
625                        protocol,
626                        command: req.command,
627                        full_command: req.full_command,
628                        response,
629                        response_detail,
630                        latency,
631                        process: process_info.clone(),
632                        src: Some(src_resp.clone()),
633                        dest: Some(dest_resp.clone()),
634                    system: false,
635                    });
636                }
637            }
638        }
639        Ok::<_, anyhow::Error>(())
640    };
641
642    tokio::pin!(client_to_server);
643    tokio::pin!(server_to_client);
644
645    tokio::select! {
646        r = &mut client_to_server => {
647            // Client closed write end; give server time to send final response
648            if r.is_ok() && pending_final.lock().unwrap().is_some() {
649                let _ = tokio::time::timeout(
650                    Duration::from_millis(500),
651                    &mut server_to_client,
652                ).await;
653            }
654        },
655        r = &mut server_to_client => r?,
656    }
657    Ok(())
658}
659
/// Clears the SSL capability bit (`0x0800`) of a MySQL server greeting in place.
///
/// `packet` is treated as a 4-byte packet header followed by the payload. The
/// payload is walked as: one protocol-version byte (must equal `10`), a
/// NUL-terminated server version string, 4 + 8 + 1 fixed bytes, and then the
/// little-endian lower 16 capability bits. Input that is too short, or whose
/// payload does not start with protocol version 10, is left untouched.
fn strip_mysql_ssl_flag(packet: &mut [u8]) {
    const HEADER_LEN: usize = 4;
    const GREETING_VERSION: u8 = 10;
    // Bytes sitting between the version string's NUL and the capability flags.
    const FIXED_GAP: usize = 4 + 8 + 1;
    const SSL_BIT: u16 = 0x0800;

    let body = match packet.get_mut(HEADER_LEN..) {
        Some(body) => body,
        None => return,
    };
    if body.first() != Some(&GREETING_VERSION) {
        return;
    }

    // Index of the version string's terminating NUL, or the payload length
    // when the string is unterminated.
    let nul_at = body[1..]
        .iter()
        .position(|&byte| byte == 0)
        .map_or(body.len(), |offset| offset + 1);
    let flags_at = nul_at + 1 + FIXED_GAP;

    // Only rewrite when both capability bytes are actually present.
    if let Some(flags) = body.get_mut(flags_at..flags_at + 2) {
        let cleared = u16::from_le_bytes([flags[0], flags[1]]) & !SSL_BIT;
        flags.copy_from_slice(&cleared.to_le_bytes());
    }
}
676
/// Looks up the process that owns a local TCP port (the client's ephemeral port).
///
/// Shells out to `lsof` on macOS and to `ss` on every other target, and
/// returns `"[pid] command"` for the first usable entry in the tool's output.
/// Returns `None` when the tool cannot be spawned or no usable entry is found.
fn resolve_peer_process(port: u16) -> Option<String> {
    use std::process::Command;

    let own_pid = std::process::id().to_string();

    if !cfg!(target_os = "macos") {
        // Linux: ss -tnp sport = :PORT
        // Each matching row carries: users:(("process_name",pid=1234,fd=5))
        const USERS_MARKER: &str = "users:((\"";

        let filter = format!("sport = :{}", port);
        let raw = Command::new("ss")
            .args(["-tnp", filter.as_str()])
            .output()
            .ok()?;
        let stdout = String::from_utf8_lossy(&raw.stdout);

        return stdout.lines().find_map(|row| {
            let (_, tail) = row.split_once(USERS_MARKER)?;
            let (proc_name, _) = tail.split_once('"')?;
            // The pid is the run of ASCII digits right after the first "pid=".
            let pid = match tail.split_once("pid=") {
                Some((_, after)) => {
                    let end = after
                        .find(|c: char| !c.is_ascii_digit())
                        .unwrap_or(after.len());
                    &after[..end]
                }
                None => "?",
            };
            Some(format!("[{}] {}", pid, proc_name))
        });
    }

    // macOS: lsof -i tcp:PORT -sTCP:ESTABLISHED -Fp -Fc
    // The output may list several processes; entries with our own PID are skipped.
    let target = format!("tcp:{}", port);
    let raw = Command::new("lsof")
        .args(["-i", target.as_str(), "-sTCP:ESTABLISHED", "-Fp", "-Fc"])
        .output()
        .ok()?;
    let stdout = String::from_utf8_lossy(&raw.stdout);

    let is_foreign = |pid: &str| !pid.is_empty() && pid != own_pid.as_str();
    let describe = |pid: &str, cmd: &str| format!("[{}] {}", pid, cmd);

    let mut pid = String::new();
    let mut cmd = String::new();
    for field in stdout.lines() {
        if let Some(next_pid) = field.strip_prefix('p') {
            // A new `p` field closes the previous entry; report it unless it was us.
            if is_foreign(&pid) {
                return Some(describe(&pid, &cmd));
            }
            pid = next_pid.to_string();
            cmd.clear();
        } else if let Some(name) = field.strip_prefix('c') {
            cmd = name.to_string();
        }
    }

    // The final entry has no following `p` field to close it.
    is_foreign(&pid).then(|| describe(&pid, &cmd))
}
733
734/// TLS certificate verifier that accepts any certificate (for proxying to known backends).
735#[derive(Debug)]
736struct NoVerify;
737
738impl rustls::client::danger::ServerCertVerifier for NoVerify {
739    fn verify_server_cert(
740        &self, _: &rustls::pki_types::CertificateDer<'_>, _: &[rustls::pki_types::CertificateDer<'_>],
741        _: &rustls::pki_types::ServerName<'_>, _: &[u8], _: rustls::pki_types::UnixTime,
742    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
743        Ok(rustls::client::danger::ServerCertVerified::assertion())
744    }
745    fn verify_tls12_signature(
746        &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
747    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
748        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
749    }
750    fn verify_tls13_signature(
751        &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
752    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
753        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
754    }
755    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
756        vec![
757            rustls::SignatureScheme::RSA_PKCS1_SHA256,
758            rustls::SignatureScheme::RSA_PKCS1_SHA384,
759            rustls::SignatureScheme::RSA_PKCS1_SHA512,
760            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
761            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
762            rustls::SignatureScheme::RSA_PSS_SHA256,
763            rustls::SignatureScheme::RSA_PSS_SHA384,
764            rustls::SignatureScheme::RSA_PSS_SHA512,
765            rustls::SignatureScheme::ED25519,
766        ]
767    }
768}
769
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a MySQL greeting packet (4-byte header + payload) whose lower
    /// capability word is `caps`. Filler values are non-zero where possible so
    /// that accidental writes outside the capability bytes are detectable.
    fn greeting_packet(caps: u16) -> Vec<u8> {
        let mut payload = vec![10]; // protocol version
        payload.extend_from_slice(b"5.7.0\0"); // null-terminated server version
        payload.extend_from_slice(&[0x11; 4]); // thread id
        payload.extend_from_slice(&[0x22; 8]); // salt part 1
        payload.push(0); // filler
        payload.extend_from_slice(&caps.to_le_bytes());
        payload.extend_from_slice(&[0x33; 13]);

        let pkt_len = payload.len();
        let mut pkt = vec![
            (pkt_len & 0xff) as u8,
            ((pkt_len >> 8) & 0xff) as u8,
            ((pkt_len >> 16) & 0xff) as u8,
            0,
        ];
        pkt.extend_from_slice(&payload);
        pkt
    }

    /// Find the capability flags offset in a MySQL greeting packet.
    ///
    /// Returns `None` unless both capability bytes lie inside `pkt`, so callers
    /// may index `pkt[off]` and `pkt[off + 1]` without further checks.
    fn caps_offset(pkt: &[u8]) -> Option<usize> {
        if pkt.len() < 5 { return None; }
        let mut pos = 5;
        // Skip null-terminated server version
        while pos < pkt.len() && pkt[pos] != 0 { pos += 1; }
        pos += 1; // null
        // 13 fixed bytes plus the 2 capability bytes must all be present.
        if pos + 13 + 2 > pkt.len() { return None; }
        pos += 4; // thread id
        pos += 8; // salt part 1
        pos += 1; // filler
        Some(pos)
    }

    #[test]
    fn test_strip_mysql_ssl_flag_short_packet() {
        // Anything shorter than header + one payload byte must be left alone,
        // even when its bytes look like a protocol-version marker.
        for len in 0..5 {
            let mut buf = vec![10u8; len];
            strip_mysql_ssl_flag(&mut buf);
            assert_eq!(buf, vec![10u8; len]);
        }
    }

    #[test]
    fn test_strip_mysql_ssl_flag_not_greeting() {
        // All bits set everywhere: any write at all would be visible.
        let mut buf = vec![0xffu8; 40];
        buf[4] = 9;
        let before = buf.clone();
        strip_mysql_ssl_flag(&mut buf);
        assert_eq!(buf, before);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_clears_ssl_bit() {
        // Start with every capability bit set so clobbering a neighbour shows up.
        let mut pkt = greeting_packet(0xffff);
        let off = caps_offset(&pkt).unwrap();
        assert!(u16::from_le_bytes([pkt[off], pkt[off + 1]]) & 0x0800 != 0);

        let mut expected = pkt.clone();
        expected[off..off + 2].copy_from_slice(&0xf7ffu16.to_le_bytes());

        strip_mysql_ssl_flag(&mut pkt);

        assert_eq!(u16::from_le_bytes([pkt[off], pkt[off + 1]]) & 0x0800, 0);
        // Only the SSL bit may differ from the input.
        assert_eq!(pkt, expected);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_without_ssl_bit_is_noop() {
        let mut pkt = greeting_packet(0xf7ff);
        let before = pkt.clone();
        strip_mysql_ssl_flag(&mut pkt);
        assert_eq!(pkt, before);
    }

    #[test]
    fn test_strip_mysql_ssl_flag_truncated_greeting() {
        // Server version string never terminated.
        let mut unterminated = vec![4, 0, 0, 0, 10, b'5', b'.', b'7'];
        let before = unterminated.clone();
        strip_mysql_ssl_flag(&mut unterminated);
        assert_eq!(unterminated, before);

        // Packet cut so that only one of the two capability bytes survives.
        let full = greeting_packet(0xffff);
        let off = caps_offset(&full).unwrap();
        let mut cut = full[..off + 1].to_vec();
        let before = cut.clone();
        strip_mysql_ssl_flag(&mut cut);
        assert_eq!(cut, before);
    }

    #[test]
    fn test_resolve_peer_process_does_not_panic() {
        let result = std::panic::catch_unwind(|| resolve_peer_process(0));
        assert!(result.is_ok());
    }
}