1#![doc = include_str!("../README.md")]
2
3mod peer_filter;
4mod tls;
5
6pub use peer_filter::PeerFilter;
7
8use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
9use std::path::Path;
10use std::sync::Arc;
11
12use google_cloud_sql_v1::client::SqlConnectService;
13use google_cloud_sql_v1::model::{ConnectSettings, IpMapping, SqlIpAddressType, SslCert};
14use rsa::RsaPrivateKey;
15use rsa::pkcs8::EncodePrivateKey as _;
16use rsa::pkcs8::EncodePublicKey as _;
17use rustls::pki_types::{PrivateKeyDer, ServerName};
18use tokio::io::copy_bidirectional;
19use tokio::net::{TcpListener, TcpStream, UnixListener};
20use tokio_rustls::TlsConnector;
21use tokio_rustls::client::TlsStream;
22
/// TCP port dialed on the instance's primary IP address by [`Dialer::dial`].
const CLOUD_SQL_PORT: u16 = 3307;

/// Size, in bits, of the RSA key generated by [`Dialer::new`]. Its public half
/// is sent when requesting ephemeral client certificates, and its private half
/// is used as the TLS client key.
const RSA_KEY_BITS: usize = 2048;
28
/// Errors returned by [`Dialer`], [`UnixSocketServer`] and [`TcpServer`].
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Building the Cloud SQL API client in `Dialer::new` failed.
    #[error(transparent)]
    ClientBuilder(#[from] google_cloud_gax::client_builder::Error),
    /// A Cloud SQL API request (connection settings or ephemeral certificate)
    /// failed.
    #[error(transparent)]
    CloudSqlApi(#[from] google_cloud_sql_v1::Error),
    /// The ephemeral certificate returned by the API has an empty PEM string.
    #[error("ephemeral certificate PEM is empty")]
    EphemeralCertEmpty,
    /// The `generateEphemeralCert` response did not include a certificate.
    #[error("ephemeral certificate missing from generateEphemeralCert response")]
    EphemeralCertMissing,
    /// The primary IP address reported by the API could not be parsed.
    #[error("invalid IP address from Cloud SQL API: {address}")]
    InvalidIpAddress {
        /// The address string exactly as returned by the API.
        address: String,
        /// Why the address failed to parse.
        #[source]
        source: core::net::AddrParseError,
    },
    /// I/O failure: binding or accepting on a local listener, connecting to
    /// the instance, or the TLS handshake in `Dialer::dial`.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// PEM data contained no certificate. NOTE(review): presumably raised by
    /// `tls::parse_pem_cert` (the `tls` module is not visible here).
    #[error("no certificates found in PEM data")]
    NoCertificatesInPem,
    /// The connection settings list no IP address of type `Primary`.
    #[error("no primary IP address found for Cloud SQL instance")]
    NoPrimaryIp,
    /// Encoding the RSA private key as PKCS#8 DER failed.
    #[error("failed to encode RSA key: {0}")]
    Pkcs8(#[from] rsa::pkcs8::Error),
    /// Generating the RSA key pair in `Dialer::new` failed.
    #[error("failed to generate RSA key: {0}")]
    RsaKeyGeneration(#[from] rsa::Error),
    /// The server CA certificate PEM is empty. NOTE(review): presumably raised
    /// by `tls::extract_server_ca_cert`.
    #[error("server CA certificate PEM is empty")]
    ServerCaCertEmpty,
    /// The connection settings lack a server CA certificate. NOTE(review):
    /// presumably raised by `tls::extract_server_ca_cert`.
    #[error("server CA certificate missing from connectSettings response")]
    ServerCaCertMissing,
    /// Encoding the RSA public key as PEM (sent with the ephemeral certificate
    /// request) failed.
    #[error("failed to encode RSA public key: {0}")]
    Spki(#[from] rsa::pkcs8::spki::Error),
    /// The rustls configuration was rejected. NOTE(review): presumably raised
    /// by `tls::build_config`.
    #[error("TLS configuration error: {0}")]
    TlsConfig(#[from] rustls::Error),
}
81
/// Opens TLS connections to a single Cloud SQL instance.
///
/// Holds the Cloud SQL API client used to fetch connection settings and to
/// request ephemeral client certificates, plus the RSA key whose public half is
/// submitted with each certificate request.
// NOTE(review): the derived `Debug` includes `rsa_private_key`; confirm the rsa
// crate's `Debug` for `RsaPrivateKey` does not print secret material.
#[derive(Debug)]
pub struct Dialer {
    /// Client for the `SqlConnectService` API.
    client: SqlConnectService,
    /// Instance name passed to every API request.
    instance: String,
    /// Project passed to every API request.
    project: String,
    /// Key generated once in `Dialer::new` and reused for every dial.
    rsa_private_key: RsaPrivateKey,
}
97
98impl Dialer {
99 async fn connect_settings(&self) -> Result<ConnectSettings, Error> {
101 Ok(self
102 .client
103 .get_connect_settings()
104 .set_project(&self.project)
105 .set_instance(&self.instance)
106 .send()
107 .await?)
108 }
109
110 pub async fn dial(&self) -> Result<TlsStream<TcpStream>, Error> {
116 let (settings, cert) = tokio::try_join!(self.connect_settings(), self.ephemeral_cert(),)?;
117
118 let primary_ip = extract_primary_ip(&settings.ip_addresses)?;
119 let server_ca = tls::extract_server_ca_cert(&settings)?;
120 if cert.cert.is_empty() {
121 return Err(Error::EphemeralCertEmpty);
122 }
123
124 let client_cert = tls::parse_pem_cert(&cert.cert)?;
125 let private_key_der = self.private_key_der()?;
126
127 let tls_config = tls::build_config(server_ca, client_cert, private_key_der)?;
128 let connector = TlsConnector::from(Arc::new(tls_config));
129
130 let tcp_stream = TcpStream::connect((primary_ip, CLOUD_SQL_PORT)).await?;
131
132 let server_name = ServerName::IpAddress(primary_ip.into());
136
137 Ok(connector.connect(server_name, tcp_stream).await?)
138 }
139
140 async fn ephemeral_cert(&self) -> Result<SslCert, Error> {
142 let response = self
143 .client
144 .generate_ephemeral_cert()
145 .set_project(&self.project)
146 .set_instance(&self.instance)
147 .set_public_key(&self.public_key_pem()?)
148 .send()
149 .await?;
150
151 response.ephemeral_cert.ok_or(Error::EphemeralCertMissing)
152 }
153
154 pub async fn new(
158 project: impl Into<String>,
159 instance: impl Into<String>,
160 ) -> Result<Self, Error> {
161 let client = SqlConnectService::builder().build().await?;
162
163 let rsa_private_key = RsaPrivateKey::new(&mut rsa::rand_core::OsRng, RSA_KEY_BITS)?;
164
165 Ok(Self {
166 client,
167 instance: instance.into(),
168 project: project.into(),
169 rsa_private_key,
170 })
171 }
172
173 fn private_key_der(&self) -> Result<PrivateKeyDer<'static>, Error> {
175 let der = self.rsa_private_key.to_pkcs8_der()?;
176 Ok(PrivateKeyDer::Pkcs8(
177 rustls::pki_types::PrivatePkcs8KeyDer::from(der.as_bytes().to_vec()),
178 ))
179 }
180
181 fn public_key_pem(&self) -> Result<String, Error> {
183 Ok(self
184 .rsa_private_key
185 .to_public_key()
186 .to_public_key_pem(rsa::pkcs8::LineEnding::LF)?)
187 }
188}
189
/// Unix domain socket listener that proxies each accepted connection to a
/// Cloud SQL instance through a shared [`Dialer`].
#[derive(Debug)]
pub struct UnixSocketServer {
    /// Dialer shared (via `Arc`) with the per-connection tasks spawned by `serve`.
    dialer: Arc<Dialer>,
    /// Socket bound in `UnixSocketServer::new`.
    listener: UnixListener,
}
199
200impl UnixSocketServer {
201 pub fn new(dialer: Arc<Dialer>, socket_path: &Path) -> Result<Self, Error> {
206 let listener = UnixListener::bind(socket_path)?;
207
208 log::info!("Cloud SQL proxy listening on {}", socket_path.display());
209
210 Ok(Self { dialer, listener })
211 }
212
213 pub async fn serve(&self) -> Result<(), Error> {
217 loop {
218 let (mut local_stream, _addr) = self.listener.accept().await?;
219
220 let dialer = Arc::clone(&self.dialer);
221
222 tokio::spawn(async move {
223 match dialer.dial().await {
224 Ok(mut tls_stream) => {
225 if let Err(error) =
226 copy_bidirectional(&mut local_stream, &mut tls_stream).await
227 {
228 log::warn!("Cloud SQL proxy connection ended: {error}");
229 }
230 }
231 Err(error) => {
232 log::warn!("Cloud SQL proxy dial failed: {error}");
233 }
234 }
235 });
236 }
237 }
238}
239
/// TCP listener that proxies each accepted connection to a Cloud SQL instance
/// through a shared [`Dialer`], after checking the peer address against a
/// [`PeerFilter`].
#[derive(Debug)]
pub struct TcpServer {
    /// Dialer shared (via `Arc`) with the per-connection tasks spawned by `serve`.
    dialer: Arc<Dialer>,
    /// Listener bound in `TcpServer::new`.
    listener: TcpListener,
    /// Address the listener actually bound to, read once after binding (so it
    /// carries the OS-assigned port when port 0 was requested).
    local_addr: SocketAddr,
    /// Decides which peer addresses `serve` will proxy.
    peer_filter: PeerFilter,
}
251
252impl TcpServer {
253 pub async fn new(
258 dialer: Arc<Dialer>,
259 address: SocketAddr,
260 peer_filter: PeerFilter,
261 ) -> Result<Self, Error> {
262 let listener = TcpListener::bind(address).await?;
263 let local_addr = listener.local_addr()?;
264
265 log::info!("Cloud SQL proxy listening on {local_addr}");
266
267 Ok(Self {
268 dialer,
269 listener,
270 local_addr,
271 peer_filter,
272 })
273 }
274
275 pub async fn new_localhost_v4(dialer: Arc<Dialer>) -> Result<Self, Error> {
283 Self::new(
284 dialer,
285 SocketAddr::from((Ipv4Addr::LOCALHOST, 0)),
286 PeerFilter::All,
287 )
288 .await
289 }
290
291 pub async fn new_localhost_v6(dialer: Arc<Dialer>) -> Result<Self, Error> {
299 Self::new(
300 dialer,
301 SocketAddr::from((Ipv6Addr::LOCALHOST, 0)),
302 PeerFilter::All,
303 )
304 .await
305 }
306
307 #[must_use]
309 pub fn local_addr(&self) -> SocketAddr {
310 self.local_addr
311 }
312
313 pub async fn serve(&self) -> Result<(), Error> {
317 loop {
318 let (mut local_stream, peer_addr) = self.listener.accept().await?;
319
320 if !self.peer_filter.is_allowed(peer_addr) {
321 log::warn!("Cloud SQL proxy rejected connection from {peer_addr}");
322 continue;
323 }
324
325 let dialer = Arc::clone(&self.dialer);
326
327 tokio::spawn(async move {
328 match dialer.dial().await {
329 Ok(mut tls_stream) => {
330 if let Err(error) =
331 copy_bidirectional(&mut local_stream, &mut tls_stream).await
332 {
333 log::warn!("Cloud SQL proxy connection ended: {error}");
334 }
335 }
336 Err(error) => {
337 log::warn!("Cloud SQL proxy dial failed: {error}");
338 }
339 }
340 });
341 }
342 }
343}
344
345fn extract_primary_ip(ip_addresses: &[IpMapping]) -> Result<IpAddr, Error> {
347 for mapping in ip_addresses {
348 if mapping.r#type == SqlIpAddressType::Primary {
349 return mapping.ip_address.parse::<IpAddr>().map_err(|source| {
350 Error::InvalidIpAddress {
351 address: mapping.ip_address.clone(),
352 source,
353 }
354 });
355 }
356 }
357
358 Err(Error::NoPrimaryIp)
359}