Skip to main content

cloud_sql_connector/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod peer_filter;
4mod tls;
5
6pub use peer_filter::PeerFilter;
7
8use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
9use std::path::Path;
10use std::sync::Arc;
11
12use google_cloud_sql_v1::client::SqlConnectService;
13use google_cloud_sql_v1::model::{ConnectSettings, IpMapping, SqlIpAddressType, SslCert};
14use rsa::RsaPrivateKey;
15use rsa::pkcs8::EncodePrivateKey as _;
16use rsa::pkcs8::EncodePublicKey as _;
17use rustls::pki_types::{PrivateKeyDer, ServerName};
18use tokio::io::copy_bidirectional;
19use tokio::net::{TcpListener, TcpStream, UnixListener};
20use tokio_rustls::TlsConnector;
21use tokio_rustls::client::TlsStream;
22
/// TCP port on which a Cloud SQL instance's server-side proxy accepts the
/// mutually-authenticated TLS connections opened by [`Dialer::dial`].
const CLOUD_SQL_PORT: u16 = 3307;

/// Modulus size, in bits, of the RSA keypair generated by [`Dialer::new`];
/// its public half is submitted when requesting ephemeral client certificates.
const RSA_KEY_BITS: usize = 2048;
28
29/// Cloud SQL connector errors.
30#[derive(Debug, thiserror::Error)]
31pub enum Error {
32    /// Cloud SQL API client builder error.
33    #[error(transparent)]
34    ClientBuilder(#[from] google_cloud_gax::client_builder::Error),
35    /// Cloud SQL API request error.
36    #[error(transparent)]
37    CloudSqlApi(#[from] google_cloud_sql_v1::Error),
38    /// Ephemeral certificate PEM is empty.
39    #[error("ephemeral certificate PEM is empty")]
40    EphemeralCertEmpty,
41    /// Ephemeral certificate missing from API response.
42    #[error("ephemeral certificate missing from generateEphemeralCert response")]
43    EphemeralCertMissing,
44    /// Failed to parse IP address from API response.
45    #[error("invalid IP address from Cloud SQL API: {address}")]
46    InvalidIpAddress {
47        /// The invalid address string.
48        address: String,
49        /// The underlying parse error.
50        #[source]
51        source: core::net::AddrParseError,
52    },
53    /// IO error.
54    #[error(transparent)]
55    Io(#[from] std::io::Error),
56    /// No certificates found in PEM data.
57    #[error("no certificates found in PEM data")]
58    NoCertificatesInPem,
59    /// No primary IP address found for the instance.
60    #[error("no primary IP address found for Cloud SQL instance")]
61    NoPrimaryIp,
62    /// PKCS#8 encoding error.
63    #[error("failed to encode RSA key: {0}")]
64    Pkcs8(#[from] rsa::pkcs8::Error),
65    /// RSA key generation error.
66    #[error("failed to generate RSA key: {0}")]
67    RsaKeyGeneration(#[from] rsa::Error),
68    /// Server CA certificate PEM is empty.
69    #[error("server CA certificate PEM is empty")]
70    ServerCaCertEmpty,
71    /// Server CA certificate missing from API response.
72    #[error("server CA certificate missing from connectSettings response")]
73    ServerCaCertMissing,
74    /// SPKI encoding error.
75    #[error("failed to encode RSA public key: {0}")]
76    Spki(#[from] rsa::pkcs8::spki::Error),
77    /// TLS configuration error.
78    #[error("TLS configuration error: {0}")]
79    TlsConfig(#[from] rustls::Error),
80}
81
82/// Cloud SQL Auth Proxy dialer.
83///
84/// Manages ephemeral certificates and establishes TLS connections
85/// to a Cloud SQL instance through the Cloud SQL Admin API.
86#[derive(Debug)]
87pub struct Dialer {
88    /// Cloud SQL Admin API client.
89    client: SqlConnectService,
90    /// Cloud SQL instance name.
91    instance: String,
92    /// GCP project ID.
93    project: String,
94    /// RSA private key for ephemeral certificate requests.
95    rsa_private_key: RsaPrivateKey,
96}
97
98impl Dialer {
99    /// Retrieve connect settings for the Cloud SQL instance.
100    async fn connect_settings(&self) -> Result<ConnectSettings, Error> {
101        Ok(self
102            .client
103            .get_connect_settings()
104            .set_project(&self.project)
105            .set_instance(&self.instance)
106            .send()
107            .await?)
108    }
109
110    /// Dial the Cloud SQL instance, returning an authenticated TLS stream.
111    ///
112    /// Fetches connect settings and an ephemeral certificate from the Cloud SQL
113    /// Admin API, then establishes a TLS 1.3 connection with mutual
114    /// authentication to the instance.
115    pub async fn dial(&self) -> Result<TlsStream<TcpStream>, Error> {
116        let (settings, cert) = tokio::try_join!(self.connect_settings(), self.ephemeral_cert(),)?;
117
118        let primary_ip = extract_primary_ip(&settings.ip_addresses)?;
119        let server_ca = tls::extract_server_ca_cert(&settings)?;
120        if cert.cert.is_empty() {
121            return Err(Error::EphemeralCertEmpty);
122        }
123
124        let client_cert = tls::parse_pem_cert(&cert.cert)?;
125        let private_key_der = self.private_key_der()?;
126
127        let tls_config = tls::build_config(server_ca, client_cert, private_key_der)?;
128        let connector = TlsConnector::from(Arc::new(tls_config));
129
130        let tcp_stream = TcpStream::connect((primary_ip, CLOUD_SQL_PORT)).await?;
131
132        // The server name is not used for hostname verification — the custom
133        // CloudSqlCertVerifier validates the CA chain only. A value is required
134        // by rustls for the TLS handshake SNI extension.
135        let server_name = ServerName::IpAddress(primary_ip.into());
136
137        Ok(connector.connect(server_name, tcp_stream).await?)
138    }
139
140    /// Request an ephemeral client certificate for the Cloud SQL instance.
141    async fn ephemeral_cert(&self) -> Result<SslCert, Error> {
142        let response = self
143            .client
144            .generate_ephemeral_cert()
145            .set_project(&self.project)
146            .set_instance(&self.instance)
147            .set_public_key(&self.public_key_pem()?)
148            .send()
149            .await?;
150
151        response.ephemeral_cert.ok_or(Error::EphemeralCertMissing)
152    }
153
154    /// Create a new Cloud SQL Auth Proxy dialer for a specific instance.
155    ///
156    /// Builds the API client and generates an RSA 2048-bit keypair.
157    pub async fn new(
158        project: impl Into<String>,
159        instance: impl Into<String>,
160    ) -> Result<Self, Error> {
161        let client = SqlConnectService::builder().build().await?;
162
163        let rsa_private_key = RsaPrivateKey::new(&mut rsa::rand_core::OsRng, RSA_KEY_BITS)?;
164
165        Ok(Self {
166            client,
167            instance: instance.into(),
168            project: project.into(),
169            rsa_private_key,
170        })
171    }
172
173    /// Encode the RSA private key as PKCS#8 DER for TLS.
174    fn private_key_der(&self) -> Result<PrivateKeyDer<'static>, Error> {
175        let der = self.rsa_private_key.to_pkcs8_der()?;
176        Ok(PrivateKeyDer::Pkcs8(
177            rustls::pki_types::PrivatePkcs8KeyDer::from(der.as_bytes().to_vec()),
178        ))
179    }
180
181    /// Encode the RSA public key as PEM for the Cloud SQL Admin API.
182    fn public_key_pem(&self) -> Result<String, Error> {
183        Ok(self
184            .rsa_private_key
185            .to_public_key()
186            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)?)
187    }
188}
189
190/// Unix socket proxy server for a Cloud SQL instance.
191///
192/// Binds a Unix socket on construction, guaranteeing the socket is ready
193/// to accept connections once the struct is obtained.
194#[derive(Debug)]
195pub struct UnixSocketServer {
196    dialer: Arc<Dialer>,
197    listener: UnixListener,
198}
199
200impl UnixSocketServer {
201    /// Bind a Unix socket proxy for a Cloud SQL instance.
202    ///
203    /// The socket is bound immediately — if this returns `Ok`, the server
204    /// is ready to accept connections.
205    pub fn new(dialer: Arc<Dialer>, socket_path: &Path) -> Result<Self, Error> {
206        let listener = UnixListener::bind(socket_path)?;
207
208        log::info!("Cloud SQL proxy listening on {}", socket_path.display());
209
210        Ok(Self { dialer, listener })
211    }
212
213    /// Accept connections and proxy traffic to the Cloud SQL instance.
214    ///
215    /// Runs until the listener encounters an accept error.
216    pub async fn serve(&self) -> Result<(), Error> {
217        loop {
218            let (mut local_stream, _addr) = self.listener.accept().await?;
219
220            let dialer = Arc::clone(&self.dialer);
221
222            tokio::spawn(async move {
223                match dialer.dial().await {
224                    Ok(mut tls_stream) => {
225                        if let Err(error) =
226                            copy_bidirectional(&mut local_stream, &mut tls_stream).await
227                        {
228                            log::warn!("Cloud SQL proxy connection ended: {error}");
229                        }
230                    }
231                    Err(error) => {
232                        log::warn!("Cloud SQL proxy dial failed: {error}");
233                    }
234                }
235            });
236        }
237    }
238}
239
240/// TCP proxy server for a Cloud SQL instance.
241///
242/// Binds a TCP listener on construction, guaranteeing the socket is ready
243/// to accept connections once the struct is obtained.
244#[derive(Debug)]
245pub struct TcpServer {
246    dialer: Arc<Dialer>,
247    listener: TcpListener,
248    local_addr: SocketAddr,
249    peer_filter: PeerFilter,
250}
251
252impl TcpServer {
253    /// Bind a TCP proxy for a Cloud SQL instance.
254    ///
255    /// The socket is bound immediately — if this returns `Ok`, the server
256    /// is ready to accept connections.
257    pub async fn new(
258        dialer: Arc<Dialer>,
259        address: SocketAddr,
260        peer_filter: PeerFilter,
261    ) -> Result<Self, Error> {
262        let listener = TcpListener::bind(address).await?;
263        let local_addr = listener.local_addr()?;
264
265        log::info!("Cloud SQL proxy listening on {local_addr}");
266
267        Ok(Self {
268            dialer,
269            listener,
270            local_addr,
271            peer_filter,
272        })
273    }
274
275    /// Bind a TCP proxy on `localhost` with an OS-assigned port.
276    ///
277    /// Accepts all incoming connections since the listener is bound to
278    /// loopback and only reachable from the local machine. Use [`TcpServer::new`]
279    /// if you need to filter on loopback.
280    ///
281    /// Use [`TcpServer::local_addr`] to discover the assigned port.
282    pub async fn new_localhost_v4(dialer: Arc<Dialer>) -> Result<Self, Error> {
283        Self::new(
284            dialer,
285            SocketAddr::from((Ipv4Addr::LOCALHOST, 0)),
286            PeerFilter::All,
287        )
288        .await
289    }
290
291    /// Bind a TCP proxy on IPv6 `localhost` with an OS-assigned port.
292    ///
293    /// Accepts all incoming connections since the listener is bound to
294    /// loopback and only reachable from the local machine. Use [`TcpServer::new`]
295    /// if you need to filter on loopback.
296    ///
297    /// Use [`TcpServer::local_addr`] to discover the assigned port.
298    pub async fn new_localhost_v6(dialer: Arc<Dialer>) -> Result<Self, Error> {
299        Self::new(
300            dialer,
301            SocketAddr::from((Ipv6Addr::LOCALHOST, 0)),
302            PeerFilter::All,
303        )
304        .await
305    }
306
307    /// Return the local address this server is bound to.
308    #[must_use]
309    pub fn local_addr(&self) -> SocketAddr {
310        self.local_addr
311    }
312
313    /// Accept connections and proxy traffic to the Cloud SQL instance.
314    ///
315    /// Runs until the listener encounters an accept error.
316    pub async fn serve(&self) -> Result<(), Error> {
317        loop {
318            let (mut local_stream, peer_addr) = self.listener.accept().await?;
319
320            if !self.peer_filter.is_allowed(peer_addr) {
321                log::warn!("Cloud SQL proxy rejected connection from {peer_addr}");
322                continue;
323            }
324
325            let dialer = Arc::clone(&self.dialer);
326
327            tokio::spawn(async move {
328                match dialer.dial().await {
329                    Ok(mut tls_stream) => {
330                        if let Err(error) =
331                            copy_bidirectional(&mut local_stream, &mut tls_stream).await
332                        {
333                            log::warn!("Cloud SQL proxy connection ended: {error}");
334                        }
335                    }
336                    Err(error) => {
337                        log::warn!("Cloud SQL proxy dial failed: {error}");
338                    }
339                }
340            });
341        }
342    }
343}
344
345/// Extract the primary IP address from instance IP mappings.
346fn extract_primary_ip(ip_addresses: &[IpMapping]) -> Result<IpAddr, Error> {
347    for mapping in ip_addresses {
348        if mapping.r#type == SqlIpAddressType::Primary {
349            return mapping.ip_address.parse::<IpAddr>().map_err(|source| {
350                Error::InvalidIpAddress {
351                    address: mapping.ip_address.clone(),
352                    source,
353                }
354            });
355        }
356    }
357
358    Err(Error::NoPrimaryIp)
359}