cloud_sql_connector/
lib.rs1#![doc = include_str!("../README.md")]
2
3mod tls;
4
5use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
6use std::path::Path;
7use std::sync::Arc;
8
9use google_cloud_sql_v1::client::SqlConnectService;
10use google_cloud_sql_v1::model::{ConnectSettings, IpMapping, SqlIpAddressType, SslCert};
11use rsa::RsaPrivateKey;
12use rsa::pkcs8::EncodePrivateKey as _;
13use rsa::pkcs8::EncodePublicKey as _;
14use rustls::pki_types::{PrivateKeyDer, ServerName};
15use tokio::io::copy_bidirectional;
16use tokio::net::{TcpListener, TcpStream, UnixListener};
17use tokio_rustls::TlsConnector;
18use tokio_rustls::client::TlsStream;
19
20const CLOUD_SQL_PORT: u16 = 3307;
22
23const RSA_KEY_BITS: usize = 2048;
25
26#[derive(Debug, thiserror::Error)]
28pub enum Error {
29 #[error(transparent)]
31 ClientBuilder(#[from] google_cloud_gax::client_builder::Error),
32 #[error(transparent)]
34 CloudSqlApi(#[from] google_cloud_sql_v1::Error),
35 #[error("ephemeral certificate PEM is empty")]
37 EphemeralCertEmpty,
38 #[error("ephemeral certificate missing from generateEphemeralCert response")]
40 EphemeralCertMissing,
41 #[error("invalid IP address from Cloud SQL API: {address}")]
43 InvalidIpAddress {
44 address: String,
46 #[source]
48 source: core::net::AddrParseError,
49 },
50 #[error(transparent)]
52 Io(#[from] std::io::Error),
53 #[error("no certificates found in PEM data")]
55 NoCertificatesInPem,
56 #[error("no primary IP address found for Cloud SQL instance")]
58 NoPrimaryIp,
59 #[error("failed to encode RSA key: {0}")]
61 Pkcs8(#[from] rsa::pkcs8::Error),
62 #[error("failed to generate RSA key: {0}")]
64 RsaKeyGeneration(#[from] rsa::Error),
65 #[error("server CA certificate PEM is empty")]
67 ServerCaCertEmpty,
68 #[error("server CA certificate missing from connectSettings response")]
70 ServerCaCertMissing,
71 #[error("failed to encode RSA public key: {0}")]
73 Spki(#[from] rsa::pkcs8::spki::Error),
74 #[error("TLS configuration error: {0}")]
76 TlsConfig(#[from] rustls::Error),
77}
78
79#[derive(Debug)]
84pub struct Dialer {
85 client: SqlConnectService,
87 instance: String,
89 project: String,
91 rsa_private_key: RsaPrivateKey,
93}
94
95impl Dialer {
96 async fn connect_settings(&self) -> Result<ConnectSettings, Error> {
98 Ok(self
99 .client
100 .get_connect_settings()
101 .set_project(&self.project)
102 .set_instance(&self.instance)
103 .send()
104 .await?)
105 }
106
107 pub async fn dial(&self) -> Result<TlsStream<TcpStream>, Error> {
113 let (settings, cert) = tokio::try_join!(self.connect_settings(), self.ephemeral_cert(),)?;
114
115 let primary_ip = extract_primary_ip(&settings.ip_addresses)?;
116 let server_ca = tls::extract_server_ca_cert(&settings)?;
117 if cert.cert.is_empty() {
118 return Err(Error::EphemeralCertEmpty);
119 }
120
121 let client_cert = tls::parse_pem_cert(&cert.cert)?;
122 let private_key_der = self.private_key_der()?;
123
124 let tls_config = tls::build_config(server_ca, client_cert, private_key_der)?;
125 let connector = TlsConnector::from(Arc::new(tls_config));
126
127 let tcp_stream = TcpStream::connect((primary_ip, CLOUD_SQL_PORT)).await?;
128
129 let server_name = ServerName::IpAddress(primary_ip.into());
133
134 Ok(connector.connect(server_name, tcp_stream).await?)
135 }
136
137 async fn ephemeral_cert(&self) -> Result<SslCert, Error> {
139 let response = self
140 .client
141 .generate_ephemeral_cert()
142 .set_project(&self.project)
143 .set_instance(&self.instance)
144 .set_public_key(&self.public_key_pem()?)
145 .send()
146 .await?;
147
148 response.ephemeral_cert.ok_or(Error::EphemeralCertMissing)
149 }
150
151 pub async fn new(
155 project: impl Into<String>,
156 instance: impl Into<String>,
157 ) -> Result<Self, Error> {
158 let client = SqlConnectService::builder().build().await?;
159
160 let rsa_private_key = RsaPrivateKey::new(&mut rsa::rand_core::OsRng, RSA_KEY_BITS)?;
161
162 Ok(Self {
163 client,
164 instance: instance.into(),
165 project: project.into(),
166 rsa_private_key,
167 })
168 }
169
170 fn private_key_der(&self) -> Result<PrivateKeyDer<'static>, Error> {
172 let der = self.rsa_private_key.to_pkcs8_der()?;
173 Ok(PrivateKeyDer::Pkcs8(
174 rustls::pki_types::PrivatePkcs8KeyDer::from(der.as_bytes().to_vec()),
175 ))
176 }
177
178 fn public_key_pem(&self) -> Result<String, Error> {
180 Ok(self
181 .rsa_private_key
182 .to_public_key()
183 .to_public_key_pem(rsa::pkcs8::LineEnding::LF)?)
184 }
185}
186
187#[derive(Debug)]
192pub struct UnixSocketServer {
193 dialer: Arc<Dialer>,
194 listener: UnixListener,
195}
196
197impl UnixSocketServer {
198 pub fn new(dialer: Arc<Dialer>, socket_path: &Path) -> Result<Self, Error> {
203 let listener = UnixListener::bind(socket_path)?;
204
205 log::info!("Cloud SQL proxy listening on {}", socket_path.display());
206
207 Ok(Self { dialer, listener })
208 }
209
210 pub async fn serve(&self) -> Result<(), Error> {
214 loop {
215 let (mut local_stream, _addr) = self.listener.accept().await?;
216
217 let dialer = Arc::clone(&self.dialer);
218
219 tokio::spawn(async move {
220 match dialer.dial().await {
221 Ok(mut tls_stream) => {
222 if let Err(error) =
223 copy_bidirectional(&mut local_stream, &mut tls_stream).await
224 {
225 log::warn!("Cloud SQL proxy connection ended: {error}");
226 }
227 }
228 Err(error) => {
229 log::warn!("Cloud SQL proxy dial failed: {error}");
230 }
231 }
232 });
233 }
234 }
235}
236
237#[derive(Debug)]
242pub struct TcpServer {
243 dialer: Arc<Dialer>,
244 listener: TcpListener,
245}
246
247impl TcpServer {
248 pub async fn new(dialer: Arc<Dialer>, address: SocketAddr) -> Result<Self, Error> {
253 let listener = TcpListener::bind(address).await?;
254
255 log::info!("Cloud SQL proxy listening on {address}");
256
257 Ok(Self { dialer, listener })
258 }
259
260 pub async fn new_localhost_v4(dialer: Arc<Dialer>) -> Result<Self, Error> {
264 Self::new(dialer, SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).await
265 }
266
267 pub async fn new_localhost_v6(dialer: Arc<Dialer>) -> Result<Self, Error> {
271 Self::new(dialer, SocketAddr::from((Ipv6Addr::LOCALHOST, 0))).await
272 }
273
274 pub fn local_addr(&self) -> Result<SocketAddr, Error> {
276 Ok(self.listener.local_addr()?)
277 }
278
279 pub async fn serve(&self) -> Result<(), Error> {
283 loop {
284 let (mut local_stream, _addr) = self.listener.accept().await?;
285
286 let dialer = Arc::clone(&self.dialer);
287
288 tokio::spawn(async move {
289 match dialer.dial().await {
290 Ok(mut tls_stream) => {
291 if let Err(error) =
292 copy_bidirectional(&mut local_stream, &mut tls_stream).await
293 {
294 log::warn!("Cloud SQL proxy connection ended: {error}");
295 }
296 }
297 Err(error) => {
298 log::warn!("Cloud SQL proxy dial failed: {error}");
299 }
300 }
301 });
302 }
303 }
304}
305
306fn extract_primary_ip(ip_addresses: &[IpMapping]) -> Result<IpAddr, Error> {
308 for mapping in ip_addresses {
309 if mapping.r#type == SqlIpAddressType::Primary {
310 return mapping.ip_address.parse::<IpAddr>().map_err(|source| {
311 Error::InvalidIpAddress {
312 address: mapping.ip_address.clone(),
313 source,
314 }
315 });
316 }
317 }
318
319 Err(Error::NoPrimaryIp)
320}