Skip to main content

ocular_proxy/
lib.rs

1use anyhow::Result;
2use ocular_protocol::{Protocol, parse_request, parse_response, extract_full_command, format_response_detail, parse_amqp_frame, parse_amqp_request_full, is_async_method, amqp_frame_len};
3use std::sync::Arc;
4use std::time::{Instant, SystemTime};
5use tokio::io::{AsyncReadExt, AsyncWriteExt, AsyncRead, AsyncWrite};
6use tokio::net::{TcpListener, TcpStream};
7use tokio::sync::{broadcast, Mutex};
8use tracing::{info, warn, error, debug};
9
10pub use ocular_protocol::ProxyEvent;
11
/// A client request that has been seen on the client→server leg and is waiting to be
/// paired with the server's reply.
struct PendingRequest {
    /// Summary of the request as produced by the protocol parser.
    command: String,
    /// Full request text as produced by the protocol parser.
    full_command: String,
    /// Wall-clock time the request was observed; reported as the event timestamp.
    timestamp: SystemTime,
    /// Monotonic time the request was observed; latency is measured from here.
    instant: Instant,
}
19
20pub async fn run_proxy(
21    listen_addr: String,
22    remote_addr: String,
23    name: String,
24    protocol: Protocol,
25    tx: broadcast::Sender<ProxyEvent>,
26) -> Result<()> {
27    let listener = TcpListener::bind(&listen_addr).await?;
28    info!(component = %name, listen = %listen_addr, remote = %remote_addr, ?protocol, "proxy listening");
29
30    loop {
31        let (client, peer) = listener.accept().await?;
32        debug!(component = %name, peer = %peer, "new client connection");
33        let remote = remote_addr.clone();
34        let name = name.clone();
35        let tx = tx.clone();
36        let process = resolve_peer_process(peer.port());
37        let peer_addr = peer.to_string();
38        let remote_for_conn = remote.clone();
39        tokio::spawn(async move {
40            if let Err(e) = handle_conn(client, &remote, &name, protocol, &tx, process, peer_addr, remote_for_conn).await {
41                warn!(component = %name, remote = %remote, error = %e, "connection ended with error");
42            }
43        });
44    }
45}
46
47#[allow(clippy::too_many_arguments)]
48async fn handle_conn(
49    mut client: TcpStream,
50    remote_addr: &str,
51    name: &str,
52    protocol: Protocol,
53    tx: &broadcast::Sender<ProxyEvent>,
54    process: Option<String>,
55    src: String,
56    dest: String,
57) -> Result<()> {
58    // Parse remote address: detect https:// for TLS outbound
59    let (actual_addr, use_tls, tls_host) = if remote_addr.starts_with("https://") {
60        let stripped = remote_addr.strip_prefix("https://").unwrap();
61        let host = stripped.split(':').next().unwrap_or(stripped).to_string();
62        (stripped.to_string(), true, host)
63    } else {
64        let stripped = remote_addr.strip_prefix("http://").unwrap_or(remote_addr);
65        (stripped.to_string(), false, String::new())
66    };
67
68    let tcp_stream = match TcpStream::connect(&actual_addr).await {
69        Ok(s) => {
70            debug!(component = %name, remote = %actual_addr, "connected to remote");
71            s
72        }
73        Err(e) => {
74            error!(component = %name, remote = %actual_addr, error = %e,
75                "failed to connect to remote — is the service running?");
76            if protocol == Protocol::Redis {
77                let err_msg = format!("-ERR ocular proxy: cannot reach {} ({})\r\n", actual_addr, e);
78                let _ = client.write_all(err_msg.as_bytes()).await;
79            }
80            return Err(e.into());
81        }
82    };
83
84    let (sr, sw): (Box<dyn AsyncRead + Unpin + Send>, Box<dyn AsyncWrite + Unpin + Send>) = if use_tls {
85        let config = rustls::ClientConfig::builder()
86            .dangerous()
87            .with_custom_certificate_verifier(Arc::new(NoVerify))
88            .with_no_client_auth();
89        let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
90        let domain = rustls::pki_types::ServerName::try_from(tls_host)
91            .map_err(|e| anyhow::anyhow!("invalid TLS hostname: {}", e))?;
92        let tls_stream = connector.connect(domain, tcp_stream).await?;
93        let (r, w) = tokio::io::split(tls_stream);
94        (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
95    } else {
96        let (r, w) = tokio::io::split(tcp_stream);
97        (Box::new(r) as Box<dyn AsyncRead + Unpin + Send>, Box::new(w) as Box<dyn AsyncWrite + Unpin + Send>)
98    };
99
100    let mut sr = sr;
101    let mut sw = sw;
102
103    // For MySQL: strip SSL from greeting
104    if protocol == Protocol::Mysql {
105        let mut greeting_buf = [0u8; 65536];
106        let n = sr.read(&mut greeting_buf).await?;
107        if n == 0 { return Ok(()); }
108        let mut greeting = greeting_buf[..n].to_vec();
109        strip_mysql_ssl_flag(&mut greeting);
110        client.write_all(&greeting).await?;
111        debug!(component = %name, "forwarded MySQL greeting with SSL stripped");
112    }
113
114    // For PostgreSQL: handle SSL negotiation before normal flow
115    if protocol == Protocol::Postgres {
116        let mut buf = [0u8; 256];
117        let n = client.read(&mut buf).await?;
118        if n == 0 { return Ok(()); }
119        let data = &buf[..n];
120        // Forward SSLRequest to server
121        sw.write_all(data).await?;
122        // Read server's SSL response (single byte N or S)
123        let mut resp = [0u8; 1];
124        let rn = sr.read(&mut resp).await?;
125        if rn == 0 { return Ok(()); }
126        // Forward to client
127        client.write_all(&resp[..rn]).await?;
128        // Emit SSLRequest event
129        let command = parse_request(protocol, data).unwrap_or_else(|| "SSLRequest".into());
130        let response = if resp[0] == b'N' { "SSLResponse: No" } else { "SSLResponse: Yes" };
131        let _ = tx.send(ProxyEvent {
132            timestamp: SystemTime::now(),
133            component: name.to_string(),
134            protocol,
135            command: command.clone(),
136            full_command: command,
137            response: response.into(),
138            response_detail: response.into(),
139            latency: std::time::Duration::ZERO,
140            process: process.clone(),
141            src: Some(src.clone()),
142            dest: Some(dest.clone()),
143        });
144        // If server said 'S' (SSL), we'd need to upgrade — but we don't support that
145        // Most local setups respond 'N'
146    }
147
148    let (mut cr, mut cw) = client.split();
149
150    let pending: Arc<Mutex<Option<PendingRequest>>> = Arc::new(Mutex::new(None));
151
152    let name_req = name.to_string();
153    let name_resp = name.to_string();
154    let tx_req = tx.clone();
155    let tx_resp = tx.clone();
156    let pending_w = pending.clone();
157    let pending_r = pending;
158    let process_info = process;
159
160    let process_req = process_info.clone();
161    let src_req = src.clone();
162    let dest_req = dest.clone();
163    let src_resp = src.clone();
164    let dest_resp = dest;
165    let client_to_server = async move {
166        let mut buf = [0u8; 65536];
167        let mut http_req_buf: Vec<u8> = Vec::new();
168        loop {
169            let n = cr.read(&mut buf).await?;
170            if n == 0 { break; }
171            let data = &buf[..n];
172
173            if protocol == Protocol::Amqp {
174                // AMQP: loop through all frames in this read
175                let mut pos = 0;
176                while pos < data.len() {
177                    let frame_data = &data[pos..];
178                    let Some(flen) = amqp_frame_len(frame_data) else { break };
179                    if let Some(frame) = parse_amqp_frame(frame_data) {
180                        // Skip heartbeat — not a real request
181                        if frame.frame_type == 8 {
182                            pos += flen;
183                            continue;
184                        }
185                        if let Some(ref method) = frame.method {
186                            if is_async_method(method.class_id, method.method_id) {
187                                let (summary, detail) = parse_amqp_request_full(frame_data)
188                                    .unwrap_or_else(|| (method.summary.clone(), method.detail.clone()));
189                                let _ = tx_req.send(ProxyEvent {
190                                    timestamp: SystemTime::now(),
191                                    component: name_req.clone(),
192                                    protocol,
193                                    command: summary,
194                                    full_command: detail.clone(),
195                                    response: String::new(),
196                                    response_detail: detail,
197                                    latency: std::time::Duration::ZERO,
198                                    process: process_req.clone(),
199                                    src: Some(src_req.clone()),
200                                    dest: Some(dest_req.clone()),
201                                });
202                            } else {
203                                debug!(component = %name_req, command = %method.summary);
204                                *pending_w.lock().await = Some(PendingRequest {
205                                    timestamp: SystemTime::now(),
206                                    instant: Instant::now(),
207                                    command: method.summary.clone(),
208                                    full_command: method.detail.clone(),
209                                });
210                            }
211                        }
212                    }
213                    pos += flen;
214                }
215            } else if protocol == Protocol::Http {
216                http_req_buf.extend_from_slice(data);
217                if ocular_protocol::http::http_request_complete(&http_req_buf) {
218                    if let Some(command) = parse_request(protocol, &http_req_buf) {
219                        let full_command = extract_full_command(protocol, &http_req_buf).unwrap_or_else(|| command.clone());
220                        *pending_w.lock().await = Some(PendingRequest {
221                            timestamp: SystemTime::now(),
222                            instant: Instant::now(),
223                            command,
224                            full_command,
225                        });
226                    }
227                    http_req_buf.clear();
228                }
229            } else if let Some(command) = parse_request(protocol, data) {
230                let full_command = extract_full_command(protocol, data).unwrap_or_else(|| command.clone());
231                debug!(component = %name_req, %command);
232                *pending_w.lock().await = Some(PendingRequest {
233                    timestamp: SystemTime::now(),
234                    instant: Instant::now(),
235                    command,
236                    full_command,
237                });
238            } else if protocol == Protocol::Postgres && n > 0 {
239                info!(component = %name_req, bytes = n, first_byte = format!("0x{:02x}", data[0]), "pg client→server UNPARSED");
240            }
241
242            sw.write_all(data).await?;
243        }
244        Ok::<_, anyhow::Error>(())
245    };
246
247    let process_mysql = process_info.clone();
248    let server_to_client = async move {
249        let mut buf = [0u8; 65536];
250        let mut mysql_buf: Vec<u8> = Vec::new();
251        let mut http_resp_buf: Vec<u8> = Vec::new();
252        let mut awaiting_response = false;
253        loop {
254            let n = sr.read(&mut buf).await?;
255            if n == 0 { break; }
256            let data = &buf[..n];
257            cw.write_all(data).await?;
258
259            if protocol == Protocol::Mysql {
260                let has_pending = pending_r.lock().await.is_some();
261                if has_pending || awaiting_response {
262                    awaiting_response = true;
263                    mysql_buf.extend_from_slice(data);
264                    if mysql_response_complete(&mysql_buf) {
265                        if let Some(req) = pending_r.lock().await.take() {
266                            let latency = req.instant.elapsed();
267                            let response = parse_response(protocol, &mysql_buf).unwrap_or_default();
268                            let response_detail = format_response_detail(protocol, &mysql_buf).unwrap_or_default();
269                            let _ = tx_resp.send(ProxyEvent {
270                                timestamp: req.timestamp,
271                                component: name_resp.clone(),
272                                protocol,
273                                command: req.command,
274                                full_command: req.full_command,
275                                response,
276                                response_detail,
277                                latency,
278                                process: process_mysql.clone(),
279                                src: Some(src_resp.clone()),
280                                dest: Some(dest_resp.clone()),
281                            });
282                        }
283                        mysql_buf.clear();
284                        awaiting_response = false;
285                    }
286                }
287            } else if protocol == Protocol::Http {
288                http_resp_buf.extend_from_slice(data);
289                if ocular_protocol::http::http_response_complete(&http_resp_buf) {
290                    if let Some(req) = pending_r.lock().await.take() {
291                        let latency = req.instant.elapsed();
292                        let response = parse_response(protocol, &http_resp_buf).unwrap_or_default();
293                        let response_detail = format_response_detail(protocol, &http_resp_buf).unwrap_or_else(|| response.clone());
294                        let _ = tx_resp.send(ProxyEvent {
295                            timestamp: req.timestamp,
296                            component: name_resp.clone(),
297                            protocol,
298                            command: req.command,
299                            full_command: req.full_command,
300                            response,
301                            response_detail,
302                            latency,
303                            process: process_info.clone(),
304                                    src: Some(src_resp.clone()),
305                                    dest: Some(dest_resp.clone()),
306                        });
307                    }
308                    http_resp_buf.clear();
309                }
310            } else if protocol == Protocol::Amqp {
311                // AMQP: loop through all server frames
312                let mut pos = 0;
313                while pos < data.len() {
314                    let frame_data = &data[pos..];
315                    let Some(flen) = amqp_frame_len(frame_data) else { break };
316                    if let Some(frame) = parse_amqp_frame(frame_data) {
317                        // Skip content header and body frames — handled below with method
318                        if frame.frame_type == 2 || frame.frame_type == 3 {
319                            pos += flen;
320                            continue;
321                        }
322                        // Heartbeat: skip
323                        if frame.frame_type == 8 {
324                            pos += flen;
325                            continue;
326                        }
327                    }
328
329                    // Extract body from subsequent Header+Body frames
330                    let mut body_text = String::new();
331                    let mut peek = pos + flen;
332                    while peek < data.len() {
333                        let peek_data = &data[peek..];
334                        let Some(plen) = amqp_frame_len(peek_data) else { break };
335                        if let Some(pf) = parse_amqp_frame(peek_data) {
336                            if pf.frame_type == 2 {
337                                // Header frame, skip
338                            } else if pf.frame_type == 3 {
339                                // Body frame
340                                if let Some(body) = &pf.body {
341                                    body_text = String::from_utf8_lossy(body).to_string();
342                                }
343                            } else {
344                                break;
345                            }
346                        } else {
347                            break;
348                        }
349                        peek += plen;
350                    }
351
352                    if let Some(req) = pending_r.lock().await.take() {
353                        let latency = req.instant.elapsed();
354                        let mut response = parse_response(protocol, frame_data).unwrap_or_default();
355                        let mut response_detail = format_response_detail(protocol, frame_data).unwrap_or_else(|| response.clone());
356                        if !body_text.is_empty() {
357                            response = format!("{} | {}", response, body_text);
358                            response_detail = format!("{}\nBody: {}", response_detail, body_text);
359                        }
360                        let _ = tx_resp.send(ProxyEvent {
361                            timestamp: req.timestamp,
362                            component: name_resp.clone(),
363                            protocol,
364                            command: req.command,
365                            full_command: req.full_command,
366                            response,
367                            response_detail,
368                            latency,
369                            process: process_info.clone(),
370                                    src: Some(src_resp.clone()),
371                                    dest: Some(dest_resp.clone()),
372                        });
373                    } else if let Some(frame) = parse_amqp_frame(frame_data) {
374                        // Server-initiated method (e.g. Basic.Deliver) — emit as standalone
375                        if let Some(ref method) = frame.method {
376                            let response = if body_text.is_empty() { String::new() } else { body_text.clone() };
377                            let response_detail = if body_text.is_empty() { String::new() } else { body_text.clone() };
378                            let command = method.summary.clone();
379                            let _ = tx_resp.send(ProxyEvent {
380                                timestamp: SystemTime::now(),
381                                component: name_resp.clone(),
382                                protocol,
383                                command,
384                                full_command: method.detail.clone(),
385                                response,
386                                response_detail,
387                                latency: std::time::Duration::ZERO,
388                                process: process_info.clone(),
389                                    src: Some(dest_resp.clone()),
390                                    dest: Some(src_resp.clone()),
391                            });
392                        }
393                    }
394                    // Advance past the method frame + any header/body frames we consumed
395                    pos = peek;
396                }
397            } else if protocol == Protocol::Postgres {
398                // Postgres: only pair with meaningful responses, skip setup noise
399                let first = data[0];
400                info!(component = %name_resp, bytes = n, first_byte = format!("0x{:02x}", first),
401                    hex_head = format!("{:02x?}", &data[..n.min(20)]), "pg server→client");
402                // Use parse_postgres_response which scans all messages and prioritizes errors
403                let is_meaningful = matches!(first, b'C' | b'E' | b'T' | b'Z' | b'I' | b'D' | b'R');
404                if is_meaningful {
405                    if let Some(req) = pending_r.lock().await.take() {
406                        let latency = req.instant.elapsed();
407                        let response = parse_response(protocol, data).unwrap_or_default();
408                        let response_detail = format_response_detail(protocol, data).unwrap_or_else(|| response.clone());
409                        let _ = tx_resp.send(ProxyEvent {
410                            timestamp: req.timestamp,
411                            component: name_resp.clone(),
412                            protocol,
413                            command: req.command,
414                            full_command: req.full_command,
415                            response,
416                            response_detail,
417                            latency,
418                            process: process_info.clone(),
419                                    src: Some(src_resp.clone()),
420                                    dest: Some(dest_resp.clone()),
421                        });
422                    }
423                }
424                // ParameterStatus (S), BackendKeyData (K), etc. are silently skipped
425            } else {
426                // Redis/MongoDB: single request/response per read
427                if let Some(req) = pending_r.lock().await.take() {
428                    let latency = req.instant.elapsed();
429                    let response = parse_response(protocol, data).unwrap_or_default();
430                    let response_detail = format_response_detail(protocol, data).unwrap_or_else(|| response.clone());
431                    let _ = tx_resp.send(ProxyEvent {
432                        timestamp: req.timestamp,
433                        component: name_resp.clone(),
434                        protocol,
435                        command: req.command,
436                        full_command: req.full_command,
437                        response,
438                        response_detail,
439                        latency,
440                        process: process_info.clone(),
441                        src: Some(src_resp.clone()),
442                        dest: Some(dest_resp.clone()),
443                    });
444                }
445            }
446        }
447        Ok::<_, anyhow::Error>(())
448    };
449
450    tokio::select! {
451        r = client_to_server => r?,
452        r = server_to_client => r?,
453    }
454    Ok(())
455}
456
457fn mysql_response_complete(buf: &[u8]) -> bool {
458    if buf.len() < 5 { return false; }
459    let first_marker = buf[4];
460    match first_marker {
461        0x00 | 0xff => return true,
462        _ => {}
463    }
464    let mut pos = 0;
465    let mut last_marker = 0u8;
466    let mut last_pkt_len = 0usize;
467    while pos + 4 <= buf.len() {
468        let pkt_len = (buf[pos] as usize) | (buf[pos+1] as usize) << 8 | (buf[pos+2] as usize) << 16;
469        let end = pos + 4 + pkt_len;
470        if end > buf.len() { break; }
471        if pkt_len > 0 {
472            last_marker = buf[pos + 4];
473            last_pkt_len = pkt_len;
474        }
475        pos = end;
476    }
477    (last_marker == 0xfe && last_pkt_len < 9) || (last_marker == 0x00 && last_pkt_len < 16 && pos == buf.len())
478}
479
480fn strip_mysql_ssl_flag(packet: &mut [u8]) {
481    if packet.len() < 5 { return; }
482    let payload = &mut packet[4..];
483    if payload.is_empty() || payload[0] != 10 { return; }
484    let mut pos = 1;
485    while pos < payload.len() && payload[pos] != 0 { pos += 1; }
486    pos += 1;
487    pos += 4;
488    pos += 8;
489    pos += 1;
490    if pos + 2 > payload.len() { return; }
491    let cap_lower = u16::from_le_bytes([payload[pos], payload[pos + 1]]);
492    let cap_lower_new = cap_lower & !0x0800;
493    payload[pos] = (cap_lower_new & 0xff) as u8;
494    payload[pos + 1] = ((cap_lower_new >> 8) & 0xff) as u8;
495}
496
497/// Resolve which process owns a local TCP port (the client's ephemeral port).
498fn resolve_peer_process(port: u16) -> Option<String> {
499    use std::process::Command;
500    let my_pid = std::process::id().to_string();
501
502    if cfg!(target_os = "macos") {
503        // lsof -i tcp:PORT -sTCP:ESTABLISHED -Fp -Fc
504        // Returns multiple process entries; skip our own PID
505        let output = Command::new("lsof")
506            .args(["-i", &format!("tcp:{}", port), "-sTCP:ESTABLISHED", "-Fp", "-Fc"])
507            .output()
508            .ok()?;
509        let text = String::from_utf8_lossy(&output.stdout);
510        let mut current_pid = String::new();
511        let mut current_cmd = String::new();
512        for line in text.lines() {
513            if let Some(p) = line.strip_prefix('p') {
514                // Save previous entry if it wasn't us
515                if !current_pid.is_empty() && current_pid != my_pid {
516                    return Some(format!("[{}] {}", current_pid, current_cmd));
517                }
518                current_pid = p.to_string();
519                current_cmd.clear();
520            }
521            if let Some(c) = line.strip_prefix('c') {
522                current_cmd = c.to_string();
523            }
524        }
525        // Check last entry
526        if !current_pid.is_empty() && current_pid != my_pid {
527            return Some(format!("[{}] {}", current_pid, current_cmd));
528        }
529        None
530    } else {
531        // Linux: ss -tnp sport = :PORT
532        let output = Command::new("ss")
533            .args(["-tnp", &format!("sport = :{}", port)])
534            .output()
535            .ok()?;
536        let text = String::from_utf8_lossy(&output.stdout);
537        // Parse: users:(("process_name",pid=1234,fd=5))
538        for line in text.lines() {
539            if let Some(start) = line.find("users:((\"") {
540                let rest = &line[start + 9..];
541                if let Some(end) = rest.find('"') {
542                    let proc_name = &rest[..end];
543                    let pid = rest.find("pid=")
544                        .and_then(|i| rest[i+4..].split(|c: char| !c.is_ascii_digit()).next())
545                        .unwrap_or("?");
546                    return Some(format!("[{}] {}", pid, proc_name));
547                }
548            }
549        }
550        None
551    }
552}
553
554/// TLS certificate verifier that accepts any certificate (for proxying to known backends).
555#[derive(Debug)]
556struct NoVerify;
557
558impl rustls::client::danger::ServerCertVerifier for NoVerify {
559    fn verify_server_cert(
560        &self, _: &rustls::pki_types::CertificateDer<'_>, _: &[rustls::pki_types::CertificateDer<'_>],
561        _: &rustls::pki_types::ServerName<'_>, _: &[u8], _: rustls::pki_types::UnixTime,
562    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
563        Ok(rustls::client::danger::ServerCertVerified::assertion())
564    }
565    fn verify_tls12_signature(
566        &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
567    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
568        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
569    }
570    fn verify_tls13_signature(
571        &self, _: &[u8], _: &rustls::pki_types::CertificateDer<'_>, _: &rustls::DigitallySignedStruct,
572    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
573        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
574    }
575    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
576        vec![
577            rustls::SignatureScheme::RSA_PKCS1_SHA256,
578            rustls::SignatureScheme::RSA_PKCS1_SHA384,
579            rustls::SignatureScheme::RSA_PKCS1_SHA512,
580            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
581            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
582            rustls::SignatureScheme::RSA_PSS_SHA256,
583            rustls::SignatureScheme::RSA_PSS_SHA384,
584            rustls::SignatureScheme::RSA_PSS_SHA512,
585            rustls::SignatureScheme::ED25519,
586        ]
587    }
588}