Skip to main content

cloud_sql_connector/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod tls;
4
5use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
6use std::path::Path;
7use std::sync::Arc;
8
9use google_cloud_sql_v1::client::SqlConnectService;
10use google_cloud_sql_v1::model::{ConnectSettings, IpMapping, SqlIpAddressType, SslCert};
11use rsa::RsaPrivateKey;
12use rsa::pkcs8::EncodePrivateKey as _;
13use rsa::pkcs8::EncodePublicKey as _;
14use rustls::pki_types::{PrivateKeyDer, ServerName};
15use tokio::io::copy_bidirectional;
16use tokio::net::{TcpListener, TcpStream, UnixListener};
17use tokio_rustls::TlsConnector;
18use tokio_rustls::client::TlsStream;
19
/// TCP port on which the Cloud SQL server-side proxy accepts connections.
const CLOUD_SQL_PORT: u16 = 3_307;

/// Modulus size, in bits, of the RSA keypair whose public half is sent
/// when requesting ephemeral client certificates.
const RSA_KEY_BITS: usize = 2_048;
25
26/// Cloud SQL connector errors.
27#[derive(Debug, thiserror::Error)]
28pub enum Error {
29    /// Cloud SQL API client builder error.
30    #[error(transparent)]
31    ClientBuilder(#[from] google_cloud_gax::client_builder::Error),
32    /// Cloud SQL API request error.
33    #[error(transparent)]
34    CloudSqlApi(#[from] google_cloud_sql_v1::Error),
35    /// Ephemeral certificate PEM is empty.
36    #[error("ephemeral certificate PEM is empty")]
37    EphemeralCertEmpty,
38    /// Ephemeral certificate missing from API response.
39    #[error("ephemeral certificate missing from generateEphemeralCert response")]
40    EphemeralCertMissing,
41    /// Failed to parse IP address from API response.
42    #[error("invalid IP address from Cloud SQL API: {address}")]
43    InvalidIpAddress {
44        /// The invalid address string.
45        address: String,
46        /// The underlying parse error.
47        #[source]
48        source: core::net::AddrParseError,
49    },
50    /// IO error.
51    #[error(transparent)]
52    Io(#[from] std::io::Error),
53    /// No certificates found in PEM data.
54    #[error("no certificates found in PEM data")]
55    NoCertificatesInPem,
56    /// No primary IP address found for the instance.
57    #[error("no primary IP address found for Cloud SQL instance")]
58    NoPrimaryIp,
59    /// PKCS#8 encoding error.
60    #[error("failed to encode RSA key: {0}")]
61    Pkcs8(#[from] rsa::pkcs8::Error),
62    /// RSA key generation error.
63    #[error("failed to generate RSA key: {0}")]
64    RsaKeyGeneration(#[from] rsa::Error),
65    /// Server CA certificate PEM is empty.
66    #[error("server CA certificate PEM is empty")]
67    ServerCaCertEmpty,
68    /// Server CA certificate missing from API response.
69    #[error("server CA certificate missing from connectSettings response")]
70    ServerCaCertMissing,
71    /// SPKI encoding error.
72    #[error("failed to encode RSA public key: {0}")]
73    Spki(#[from] rsa::pkcs8::spki::Error),
74    /// TLS configuration error.
75    #[error("TLS configuration error: {0}")]
76    TlsConfig(#[from] rustls::Error),
77}
78
79/// Cloud SQL Auth Proxy dialer.
80///
81/// Manages ephemeral certificates and establishes TLS connections
82/// to a Cloud SQL instance through the Cloud SQL Admin API.
83#[derive(Debug)]
84pub struct Dialer {
85    /// Cloud SQL Admin API client.
86    client: SqlConnectService,
87    /// Cloud SQL instance name.
88    instance: String,
89    /// GCP project ID.
90    project: String,
91    /// RSA private key for ephemeral certificate requests.
92    rsa_private_key: RsaPrivateKey,
93}
94
95impl Dialer {
96    /// Retrieve connect settings for the Cloud SQL instance.
97    async fn connect_settings(&self) -> Result<ConnectSettings, Error> {
98        Ok(self
99            .client
100            .get_connect_settings()
101            .set_project(&self.project)
102            .set_instance(&self.instance)
103            .send()
104            .await?)
105    }
106
107    /// Dial the Cloud SQL instance, returning an authenticated TLS stream.
108    ///
109    /// Fetches connect settings and an ephemeral certificate from the Cloud SQL
110    /// Admin API, then establishes a TLS 1.3 connection with mutual
111    /// authentication to the instance.
112    pub async fn dial(&self) -> Result<TlsStream<TcpStream>, Error> {
113        let (settings, cert) = tokio::try_join!(self.connect_settings(), self.ephemeral_cert(),)?;
114
115        let primary_ip = extract_primary_ip(&settings.ip_addresses)?;
116        let server_ca = tls::extract_server_ca_cert(&settings)?;
117        if cert.cert.is_empty() {
118            return Err(Error::EphemeralCertEmpty);
119        }
120
121        let client_cert = tls::parse_pem_cert(&cert.cert)?;
122        let private_key_der = self.private_key_der()?;
123
124        let tls_config = tls::build_config(server_ca, client_cert, private_key_der)?;
125        let connector = TlsConnector::from(Arc::new(tls_config));
126
127        let tcp_stream = TcpStream::connect((primary_ip, CLOUD_SQL_PORT)).await?;
128
129        // The server name is not used for hostname verification — the custom
130        // CloudSqlCertVerifier validates the CA chain only. A value is required
131        // by rustls for the TLS handshake SNI extension.
132        let server_name = ServerName::IpAddress(primary_ip.into());
133
134        Ok(connector.connect(server_name, tcp_stream).await?)
135    }
136
137    /// Request an ephemeral client certificate for the Cloud SQL instance.
138    async fn ephemeral_cert(&self) -> Result<SslCert, Error> {
139        let response = self
140            .client
141            .generate_ephemeral_cert()
142            .set_project(&self.project)
143            .set_instance(&self.instance)
144            .set_public_key(&self.public_key_pem()?)
145            .send()
146            .await?;
147
148        response.ephemeral_cert.ok_or(Error::EphemeralCertMissing)
149    }
150
151    /// Create a new Cloud SQL Auth Proxy dialer for a specific instance.
152    ///
153    /// Builds the API client and generates an RSA 2048-bit keypair.
154    pub async fn new(
155        project: impl Into<String>,
156        instance: impl Into<String>,
157    ) -> Result<Self, Error> {
158        let client = SqlConnectService::builder().build().await?;
159
160        let rsa_private_key = RsaPrivateKey::new(&mut rsa::rand_core::OsRng, RSA_KEY_BITS)?;
161
162        Ok(Self {
163            client,
164            instance: instance.into(),
165            project: project.into(),
166            rsa_private_key,
167        })
168    }
169
170    /// Encode the RSA private key as PKCS#8 DER for TLS.
171    fn private_key_der(&self) -> Result<PrivateKeyDer<'static>, Error> {
172        let der = self.rsa_private_key.to_pkcs8_der()?;
173        Ok(PrivateKeyDer::Pkcs8(
174            rustls::pki_types::PrivatePkcs8KeyDer::from(der.as_bytes().to_vec()),
175        ))
176    }
177
178    /// Encode the RSA public key as PEM for the Cloud SQL Admin API.
179    fn public_key_pem(&self) -> Result<String, Error> {
180        Ok(self
181            .rsa_private_key
182            .to_public_key()
183            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)?)
184    }
185}
186
187/// Unix socket proxy server for a Cloud SQL instance.
188///
189/// Binds a Unix socket on construction, guaranteeing the socket is ready
190/// to accept connections once the struct is obtained.
191#[derive(Debug)]
192pub struct UnixSocketServer {
193    dialer: Arc<Dialer>,
194    listener: UnixListener,
195}
196
197impl UnixSocketServer {
198    /// Bind a Unix socket proxy for a Cloud SQL instance.
199    ///
200    /// The socket is bound immediately — if this returns `Ok`, the server
201    /// is ready to accept connections.
202    pub fn new(dialer: Arc<Dialer>, socket_path: &Path) -> Result<Self, Error> {
203        let listener = UnixListener::bind(socket_path)?;
204
205        log::info!("Cloud SQL proxy listening on {}", socket_path.display());
206
207        Ok(Self { dialer, listener })
208    }
209
210    /// Accept connections and proxy traffic to the Cloud SQL instance.
211    ///
212    /// Runs until the listener encounters an accept error.
213    pub async fn serve(&self) -> Result<(), Error> {
214        loop {
215            let (mut local_stream, _addr) = self.listener.accept().await?;
216
217            let dialer = Arc::clone(&self.dialer);
218
219            tokio::spawn(async move {
220                match dialer.dial().await {
221                    Ok(mut tls_stream) => {
222                        if let Err(error) =
223                            copy_bidirectional(&mut local_stream, &mut tls_stream).await
224                        {
225                            log::warn!("Cloud SQL proxy connection ended: {error}");
226                        }
227                    }
228                    Err(error) => {
229                        log::warn!("Cloud SQL proxy dial failed: {error}");
230                    }
231                }
232            });
233        }
234    }
235}
236
237/// TCP proxy server for a Cloud SQL instance.
238///
239/// Binds a TCP listener on construction, guaranteeing the socket is ready
240/// to accept connections once the struct is obtained.
241#[derive(Debug)]
242pub struct TcpServer {
243    dialer: Arc<Dialer>,
244    listener: TcpListener,
245}
246
247impl TcpServer {
248    /// Bind a TCP proxy for a Cloud SQL instance.
249    ///
250    /// The socket is bound immediately — if this returns `Ok`, the server
251    /// is ready to accept connections.
252    pub async fn new(dialer: Arc<Dialer>, address: SocketAddr) -> Result<Self, Error> {
253        let listener = TcpListener::bind(address).await?;
254
255        log::info!("Cloud SQL proxy listening on {address}");
256
257        Ok(Self { dialer, listener })
258    }
259
260    /// Bind a TCP proxy on `localhost` with an OS-assigned port.
261    ///
262    /// Use [`TcpServer::local_addr`] to discover the assigned port.
263    pub async fn new_localhost_v4(dialer: Arc<Dialer>) -> Result<Self, Error> {
264        Self::new(dialer, SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).await
265    }
266
267    /// Bind a TCP proxy on IPv6 `localhost` with an OS-assigned port.
268    ///
269    /// Use [`TcpServer::local_addr`] to discover the assigned port.
270    pub async fn new_localhost_v6(dialer: Arc<Dialer>) -> Result<Self, Error> {
271        Self::new(dialer, SocketAddr::from((Ipv6Addr::LOCALHOST, 0))).await
272    }
273
274    /// Return the local address this server is bound to.
275    pub fn local_addr(&self) -> Result<SocketAddr, Error> {
276        Ok(self.listener.local_addr()?)
277    }
278
279    /// Accept connections and proxy traffic to the Cloud SQL instance.
280    ///
281    /// Runs until the listener encounters an accept error.
282    pub async fn serve(&self) -> Result<(), Error> {
283        loop {
284            let (mut local_stream, _addr) = self.listener.accept().await?;
285
286            let dialer = Arc::clone(&self.dialer);
287
288            tokio::spawn(async move {
289                match dialer.dial().await {
290                    Ok(mut tls_stream) => {
291                        if let Err(error) =
292                            copy_bidirectional(&mut local_stream, &mut tls_stream).await
293                        {
294                            log::warn!("Cloud SQL proxy connection ended: {error}");
295                        }
296                    }
297                    Err(error) => {
298                        log::warn!("Cloud SQL proxy dial failed: {error}");
299                    }
300                }
301            });
302        }
303    }
304}
305
306/// Extract the primary IP address from instance IP mappings.
307fn extract_primary_ip(ip_addresses: &[IpMapping]) -> Result<IpAddr, Error> {
308    for mapping in ip_addresses {
309        if mapping.r#type == SqlIpAddressType::Primary {
310            return mapping.ip_address.parse::<IpAddr>().map_err(|source| {
311                Error::InvalidIpAddress {
312                    address: mapping.ip_address.clone(),
313                    source,
314                }
315            });
316        }
317    }
318
319    Err(Error::NoPrimaryIp)
320}