// cloudpub_common/proxy_protocol.rs

//! PROXY protocol v2 implementation
//!
//! Implements HAProxy's PROXY protocol version 2 for passing client connection
//! information to backend servers in a binary header format.
//!
//! Reference: https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt

use std::net::{IpAddr, SocketAddr};

/// PROXY protocol v2 signature (12 bytes): `\r\n\r\n\0\r\nQUIT\n`.
const PROXY_V2_SIGNATURE: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";

// The byte after the signature packs the protocol version into its high
// nibble and the command into its low nibble.

/// Protocol version 2, placed in the high nibble of the version/command byte.
const PROXY_V2_VERSION: u8 = 0x2 << 4;
/// `PROXY` command (low nibble): the connection is relayed on behalf of a client.
const PROXY_CMD_PROXY: u8 = 0x1;

// The next byte packs the address family into its high nibble and the
// transport protocol into its low nibble.

/// Unspecified address family (no address block follows).
const AF_UNSPEC: u8 = 0x00;
/// IPv4 address family, placed in the high nibble of the family/protocol byte.
const AF_INET: u8 = 0x1 << 4;
/// IPv6 address family, placed in the high nibble of the family/protocol byte.
const AF_INET6: u8 = 0x2 << 4;

/// Stream transport (TCP), placed in the low nibble of the family/protocol byte.
const STREAM: u8 = 0x1;

27/// Builds a PROXY protocol v2 header for TCP connections
28///
29/// # Arguments
30/// * `client_addr` - The original client's socket address
31/// * `server_addr` - The server's socket address (destination)
32///
33/// # Returns
34/// A byte vector containing the complete PROXY v2 header
35pub fn build_proxy_v2_header(client_addr: &SocketAddr, server_addr: &SocketAddr) -> Vec<u8> {
36    let mut header = Vec::with_capacity(52); // Max size for IPv6
37
38    // 1. Signature (12 bytes)
39    header.extend_from_slice(&PROXY_V2_SIGNATURE);
40
41    // Determine address family and protocol
42    let (af_proto, addr_len, addr_data) = match (client_addr.ip(), server_addr.ip()) {
43        (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => {
44            // IPv4 + TCP
45            let af_proto = AF_INET | STREAM;
46            let mut data = Vec::with_capacity(12);
47            data.extend_from_slice(&src_ip.octets()); // 4 bytes src addr
48            data.extend_from_slice(&dst_ip.octets()); // 4 bytes dst addr
49            data.extend_from_slice(&client_addr.port().to_be_bytes()); // 2 bytes src port
50            data.extend_from_slice(&server_addr.port().to_be_bytes()); // 2 bytes dst port
51            (af_proto, 12u16, data)
52        }
53        (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => {
54            // IPv6 + TCP
55            let af_proto = AF_INET6 | STREAM;
56            let mut data = Vec::with_capacity(36);
57            data.extend_from_slice(&src_ip.octets()); // 16 bytes src addr
58            data.extend_from_slice(&dst_ip.octets()); // 16 bytes dst addr
59            data.extend_from_slice(&client_addr.port().to_be_bytes()); // 2 bytes src port
60            data.extend_from_slice(&server_addr.port().to_be_bytes()); // 2 bytes dst port
61            (af_proto, 36u16, data)
62        }
63        (IpAddr::V4(src_ip), IpAddr::V6(_)) => {
64            // Mixed: convert IPv4 to IPv6-mapped address
65            let src_v6 = src_ip.to_ipv6_mapped();
66            let dst_v6 = match server_addr.ip() {
67                IpAddr::V6(ip) => ip,
68                _ => unreachable!(),
69            };
70            let af_proto = AF_INET6 | STREAM;
71            let mut data = Vec::with_capacity(36);
72            data.extend_from_slice(&src_v6.octets());
73            data.extend_from_slice(&dst_v6.octets());
74            data.extend_from_slice(&client_addr.port().to_be_bytes());
75            data.extend_from_slice(&server_addr.port().to_be_bytes());
76            (af_proto, 36u16, data)
77        }
78        (IpAddr::V6(_), IpAddr::V4(dst_ip)) => {
79            // Mixed: convert IPv4 to IPv6-mapped address
80            let src_v6 = match client_addr.ip() {
81                IpAddr::V6(ip) => ip,
82                _ => unreachable!(),
83            };
84            let dst_v6 = dst_ip.to_ipv6_mapped();
85            let af_proto = AF_INET6 | STREAM;
86            let mut data = Vec::with_capacity(36);
87            data.extend_from_slice(&src_v6.octets());
88            data.extend_from_slice(&dst_v6.octets());
89            data.extend_from_slice(&client_addr.port().to_be_bytes());
90            data.extend_from_slice(&server_addr.port().to_be_bytes());
91            (af_proto, 36u16, data)
92        }
93    };
94
95    // 2. Version and command (1 byte)
96    header.push(PROXY_V2_VERSION | PROXY_CMD_PROXY);
97
98    // 3. Address family and protocol (1 byte)
99    header.push(af_proto);
100
101    // 4. Address length (2 bytes, big-endian)
102    header.extend_from_slice(&addr_len.to_be_bytes());
103
104    // 5. Address data (variable length)
105    header.extend_from_slice(&addr_data);
106
107    header
108}
109
110/// Builds a minimal PROXY protocol v2 header for LOCAL command (no address info)
111///
112/// This is used when the connection is health-check or internal and doesn't
113/// need client address information passed to the backend.
114pub fn build_proxy_v2_local_header() -> Vec<u8> {
115    let mut header = Vec::with_capacity(16);
116
117    // Signature
118    header.extend_from_slice(&PROXY_V2_SIGNATURE);
119
120    // Version (2) + Command (LOCAL = 0)
121    header.push(PROXY_V2_VERSION); // 0x20 = version 2, command LOCAL
122
123    // Address family UNSPEC + protocol UNSPEC
124    header.push(AF_UNSPEC);
125
126    // Address length = 0
127    header.extend_from_slice(&0u16.to_be_bytes());
128
129    header
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

    #[test]
    fn test_proxy_v2_header_ipv4() {
        let client = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 100), 54321));
        let server = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 25565));

        let header = build_proxy_v2_header(&client, &server);

        // Check signature
        assert_eq!(&header[0..12], &PROXY_V2_SIGNATURE);

        // Check version + command
        assert_eq!(header[12], 0x21); // v2 + PROXY

        // Check address family + protocol
        assert_eq!(header[13], 0x11); // AF_INET + STREAM

        // Check address length
        assert_eq!(&header[14..16], &12u16.to_be_bytes());

        // Total header size for IPv4: 16 + 12 = 28 bytes
        assert_eq!(header.len(), 28);

        // Verify source IP
        assert_eq!(&header[16..20], &[192, 168, 1, 100]);

        // Verify destination IP
        assert_eq!(&header[20..24], &[10, 0, 0, 1]);

        // Verify source port (54321 = 0xD431)
        assert_eq!(&header[24..26], &54321u16.to_be_bytes());

        // Verify destination port (25565 = 0x63DD)
        assert_eq!(&header[26..28], &25565u16.to_be_bytes());
    }

    #[test]
    fn test_proxy_v2_header_ipv6() {
        let client_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let server_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);
        let client = SocketAddr::V6(SocketAddrV6::new(client_ip, 54321, 0, 0));
        let server = SocketAddr::V6(SocketAddrV6::new(server_ip, 25565, 0, 0));

        let header = build_proxy_v2_header(&client, &server);

        // Check signature
        assert_eq!(&header[0..12], &PROXY_V2_SIGNATURE);

        // Check version + command
        assert_eq!(header[12], 0x21); // v2 + PROXY

        // Check address family + protocol
        assert_eq!(header[13], 0x21); // AF_INET6 + STREAM

        // Check address length
        assert_eq!(&header[14..16], &36u16.to_be_bytes());

        // Total header size for IPv6: 16 + 36 = 52 bytes
        assert_eq!(header.len(), 52);

        // Verify source and destination IPs
        assert_eq!(&header[16..32], &client_ip.octets());
        assert_eq!(&header[32..48], &server_ip.octets());

        // Verify source and destination ports
        assert_eq!(&header[48..50], &54321u16.to_be_bytes());
        assert_eq!(&header[50..52], &25565u16.to_be_bytes());
    }

    #[test]
    fn test_proxy_v2_header_mixed_v4_client_v6_server() {
        let client_ip = Ipv4Addr::new(192, 168, 1, 100);
        let server_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);
        let client = SocketAddr::V4(SocketAddrV4::new(client_ip, 54321));
        let server = SocketAddr::V6(SocketAddrV6::new(server_ip, 25565, 0, 0));

        let header = build_proxy_v2_header(&client, &server);

        // Mixed families fall back to AF_INET6 with the IPv4 side mapped
        assert_eq!(&header[0..12], &PROXY_V2_SIGNATURE);
        assert_eq!(header[12], 0x21); // v2 + PROXY
        assert_eq!(header[13], 0x21); // AF_INET6 + STREAM
        assert_eq!(&header[14..16], &36u16.to_be_bytes());
        assert_eq!(header.len(), 52);
        assert_eq!(&header[16..32], &client_ip.to_ipv6_mapped().octets());
        assert_eq!(&header[32..48], &server_ip.octets());
        assert_eq!(&header[48..50], &54321u16.to_be_bytes());
        assert_eq!(&header[50..52], &25565u16.to_be_bytes());
    }

    #[test]
    fn test_proxy_v2_header_mixed_v6_client_v4_server() {
        let client_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let server_ip = Ipv4Addr::new(10, 0, 0, 1);
        let client = SocketAddr::V6(SocketAddrV6::new(client_ip, 54321, 0, 0));
        let server = SocketAddr::V4(SocketAddrV4::new(server_ip, 25565));

        let header = build_proxy_v2_header(&client, &server);

        // Mixed families fall back to AF_INET6 with the IPv4 side mapped
        assert_eq!(&header[0..12], &PROXY_V2_SIGNATURE);
        assert_eq!(header[12], 0x21); // v2 + PROXY
        assert_eq!(header[13], 0x21); // AF_INET6 + STREAM
        assert_eq!(&header[14..16], &36u16.to_be_bytes());
        assert_eq!(header.len(), 52);
        assert_eq!(&header[16..32], &client_ip.octets());
        assert_eq!(&header[32..48], &server_ip.to_ipv6_mapped().octets());
        assert_eq!(&header[48..50], &54321u16.to_be_bytes());
        assert_eq!(&header[50..52], &25565u16.to_be_bytes());
    }

    #[test]
    fn test_proxy_v2_local_header() {
        let header = build_proxy_v2_local_header();

        // Check signature
        assert_eq!(&header[0..12], &PROXY_V2_SIGNATURE);

        // Check version + command (LOCAL = 0)
        assert_eq!(header[12], 0x20); // v2 + LOCAL

        // Check address family + protocol
        assert_eq!(header[13], 0x00); // UNSPEC

        // Check address length = 0
        assert_eq!(&header[14..16], &0u16.to_be_bytes());

        // Total size: 16 bytes
        assert_eq!(header.len(), 16);
    }
}