Skip to main content

ferrule_sql/
backend.rs

1#![allow(unused_imports, unused_variables)]
2
3use crate::backends;
4use crate::connection::{AsyncConnection, ConnectOptions, Connection};
5use crate::error::SqlError;
6use crate::sync::SyncConnection;
7use crate::url::DatabaseUrl;
8
/// Supported database backends.
///
/// Each variant is compiled in only when its Cargo feature is enabled, so
/// the set of available variants depends on the build configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Backend {
    /// PostgreSQL (`postgres://`, `postgresql://`).
    #[cfg(feature = "postgres")]
    Postgres,
    /// MySQL / MariaDB (`mysql://`, `mariadb://`).
    #[cfg(feature = "mysql")]
    MySql,
    /// Microsoft SQL Server (`mssql://`, `sqlserver://`, `tds://`).
    #[cfg(feature = "mssql")]
    MsSql,
    /// SQLite (`sqlite://`), a local-file backend.
    #[cfg(feature = "sqlite")]
    Sqlite,
    /// Oracle (`oracle://`).
    #[cfg(feature = "oracle")]
    Oracle,
}

impl Backend {
    /// Resolve a backend from a URL scheme.
    ///
    /// Matching is ASCII case-insensitive because URL schemes are
    /// case-insensitive (RFC 3986 §3.1), so `"PostgreSQL"` resolves the same
    /// as `"postgresql"`. Returns `None` for an unknown scheme, or for a
    /// scheme whose backend feature is not compiled in.
    pub fn from_scheme(scheme: &str) -> Option<Self> {
        // Normalize once; the canonical scheme form is lowercase.
        let scheme = scheme.to_ascii_lowercase();
        match scheme.as_str() {
            #[cfg(feature = "postgres")]
            "postgres" | "postgresql" => Some(Self::Postgres),
            #[cfg(feature = "mysql")]
            "mysql" | "mariadb" => Some(Self::MySql),
            #[cfg(feature = "mssql")]
            "mssql" | "sqlserver" | "tds" => Some(Self::MsSql),
            #[cfg(feature = "sqlite")]
            "sqlite" => Some(Self::Sqlite),
            #[cfg(feature = "oracle")]
            "oracle" => Some(Self::Oracle),
            _ => None,
        }
    }

    /// Human-readable name of the backend, suitable for messages and logs.
    pub fn name(&self) -> &'static str {
        match *self {
            #[cfg(feature = "postgres")]
            Self::Postgres => "PostgreSQL",
            #[cfg(feature = "mysql")]
            Self::MySql => "MySQL",
            #[cfg(feature = "mssql")]
            Self::MsSql => "Microsoft SQL Server",
            #[cfg(feature = "sqlite")]
            Self::Sqlite => "SQLite",
            #[cfg(feature = "oracle")]
            Self::Oracle => "Oracle",
        }
    }
}
58
59/// Establish a direct (non-proxied) connection to the given URL.
60///
61/// Internal async helper. Returns the inner [`AsyncConnection`]; the
62/// public [`connect`] wraps it in a [`SyncConnection`] that owns the
63/// driving runtime.
64async fn connect_direct(
65    url: &DatabaseUrl,
66    opts: &ConnectOptions,
67) -> Result<Box<dyn AsyncConnection>, SqlError> {
68    let backend = Backend::from_scheme(url.scheme())
69        .ok_or_else(|| SqlError::UnsupportedScheme(url.scheme().to_string()))?;
70
71    match backend {
72        #[cfg(feature = "postgres")]
73        Backend::Postgres => {
74            let conn = backends::postgres::connect(url, opts).await?;
75            Ok(Box::new(conn))
76        }
77        #[cfg(feature = "mysql")]
78        Backend::MySql => {
79            let conn = backends::mysql::connect(url, opts).await?;
80            Ok(Box::new(conn))
81        }
82        #[cfg(feature = "mssql")]
83        Backend::MsSql => {
84            let conn = backends::mssql::connect(url, opts).await?;
85            Ok(Box::new(conn))
86        }
87        #[cfg(feature = "sqlite")]
88        Backend::Sqlite => {
89            let conn = backends::sqlite::connect(url, opts).await?;
90            Ok(Box::new(conn))
91        }
92        #[cfg(feature = "oracle")]
93        Backend::Oracle => {
94            let conn = backends::oracle::connect(url, opts).await?;
95            Ok(Box::new(conn))
96        }
97    }
98}
99
100/// Build the private current-thread runtime that drives one
101/// connection handle's driver futures for its whole lifetime.
102///
103/// Current-thread (not multi-thread) by design: a single embedded
104/// connection needs exactly one driver task plus the in-flight
105/// statement future, both polled cooperatively on the calling thread
106/// during `block_on`. `enable_all` turns on the I/O and time drivers
107/// the network backends need.
108fn build_runtime() -> Result<tokio::runtime::Runtime, SqlError> {
109    tokio::runtime::Builder::new_current_thread()
110        .enable_all()
111        .build()
112        .map_err(|e| SqlError::ConnectionFailed(format!("failed to build connection runtime: {e}")))
113}
114
115/// Establish a blocking connection to the given URL.
116///
117/// When `proxy` is `Some`, the connection is routed through the proxy
118/// via HTTP CONNECT. For Postgres this means `connect_raw` with a
119/// pre-built stream; for MySQL/MSSQL/Oracle a local TCP listener is
120/// bound and bytes are pumped through the proxy; for SQLite the proxy is
121/// ignored (local file, no network).
122///
123/// **Blocking:** this call blocks until the connection is established.
124/// The returned [`Connection`] owns a private current-thread runtime
125/// that drives every later blocking call; it never exposes a `Future`.
126/// Do not call from inside another `block_on` on the same thread.
127#[must_use = "a connection handle must be used or the connection is dropped immediately"]
128pub fn connect(
129    url: &DatabaseUrl,
130    opts: &ConnectOptions,
131    proxy: Option<&crate::proxy::ProxyConfig>,
132) -> Result<Box<dyn Connection>, SqlError> {
133    let rt = build_runtime()?;
134    let inner = rt.block_on(connect_inner(url, opts, proxy))?;
135    Ok(Box::new(SyncConnection::new(rt, inner)))
136}
137
138/// Async core of [`connect`]: resolve the backend and establish the
139/// inner [`AsyncConnection`], honoring an optional HTTP CONNECT proxy.
140/// Must be driven on the same runtime that [`connect`] later moves into
141/// the returned [`SyncConnection`].
142async fn connect_inner(
143    url: &DatabaseUrl,
144    opts: &ConnectOptions,
145    proxy: Option<&crate::proxy::ProxyConfig>,
146) -> Result<Box<dyn AsyncConnection>, SqlError> {
147    let backend = Backend::from_scheme(url.scheme())
148        .ok_or_else(|| SqlError::UnsupportedScheme(url.scheme().to_string()))?;
149
150    match backend {
151        #[cfg(feature = "postgres")]
152        Backend::Postgres => {
153            if let Some(proxy) = proxy {
154                let target_host = url.host().ok_or_else(|| {
155                    SqlError::ConnectionFailed(
156                        "URL has no host — proxy requires a network target".to_string(),
157                    )
158                })?;
159                let target_port = url.port().unwrap_or(5432);
160                let stream = crate::proxy::http_connect(proxy, target_host, target_port).await?;
161                let conn = backends::postgres::connect_with_stream(url, opts, stream).await?;
162                Ok(Box::new(conn))
163            } else {
164                connect_direct(url, opts).await
165            }
166        }
167        #[cfg(feature = "mysql")]
168        Backend::MySql => {
169            if let Some(proxy) = proxy {
170                connect_via_proxy_listener(url, opts, proxy, backend).await
171            } else {
172                connect_direct(url, opts).await
173            }
174        }
175        #[cfg(feature = "mssql")]
176        Backend::MsSql => {
177            if let Some(proxy) = proxy {
178                connect_via_proxy_listener(url, opts, proxy, backend).await
179            } else {
180                connect_direct(url, opts).await
181            }
182        }
183        #[cfg(feature = "sqlite")]
184        Backend::Sqlite => {
185            // SQLite is a local-file backend; proxy is irrelevant.
186            connect_direct(url, opts).await
187        }
188        #[cfg(feature = "oracle")]
189        Backend::Oracle => {
190            if let Some(proxy) = proxy {
191                connect_via_proxy_listener(url, opts, proxy, backend).await
192            } else {
193                connect_direct(url, opts).await
194            }
195        }
196    }
197}
198
/// Connect a MySQL, MSSQL or Oracle backend through an HTTP CONNECT proxy.
///
/// These drivers are given a URL to dial rather than a pre-built stream.
/// So a listener is bound on an ephemeral `127.0.0.1` port, and a forwarder
/// task pumps each accepted connection through a fresh HTTP CONNECT tunnel
/// to the original host:port (the URL's port, or the backend's default).
/// The driver is then connected to the URL rewritten to that loopback
/// port. The forwarder handle is moved into the returned
/// [`crate::proxy::ProxiedConnection`].
///
/// A failed proxy handshake for one accepted connection is logged to
/// stderr and only that connection is dropped. An accept error is also
/// logged, and it ends the accept loop.
///
/// NOTE(review): the accept loop does not check the peer address, so any
/// local process can reach the tunnel through the loopback port while the
/// forwarder runs.
///
/// # Errors
///
/// [`SqlError::ConnectionFailed`] if the URL has no host or the listener
/// cannot be bound. Errors from resolving the listener address, rewriting
/// the URL, or the backend connect are propagated. If the backend connect
/// fails, the forwarder task is aborted before the error is returned.
#[cfg(any(feature = "mysql", feature = "mssql", feature = "oracle"))]
async fn connect_via_proxy_listener(
    url: &DatabaseUrl,
    opts: &ConnectOptions,
    proxy: &crate::proxy::ProxyConfig,
    backend: Backend,
) -> Result<Box<dyn AsyncConnection>, SqlError> {
    let target_host = url
        .host()
        .ok_or_else(|| {
            SqlError::ConnectionFailed(
                "URL has no host — proxy requires a network target".to_string(),
            )
        })?
        .to_string();
    let target_port = url.port().unwrap_or_else(|| default_port_for(backend));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
        .await
        .map_err(|e| SqlError::ConnectionFailed(format!("proxy listener bind: {e}")))?;
    let port = listener.local_addr()?.port();
    // Rewrite the URL before spawning anything, so a URL that cannot be
    // rewritten fails without leaving a forwarder task behind.
    let local_url = rewrite_url_to_local(url, port)?;

    let proxy = proxy.clone();
    let forwarder = tokio::spawn(async move {
        loop {
            let (mut tcp, _addr) = match listener.accept().await {
                Ok(pair) => pair,
                Err(e) => {
                    eprintln!("[ferrule] proxy listener accept failed: {e}");
                    return;
                }
            };
            let target_host = target_host.clone();
            let proxy = proxy.clone();
            tokio::spawn(async move {
                let mut proxy_stream =
                    match crate::proxy::http_connect(&proxy, &target_host, target_port).await {
                        Ok(s) => s,
                        Err(e) => {
                            eprintln!("[ferrule] proxy connect failed: {e}");
                            return;
                        }
                    };
                // An error here is usually just one side closing or resetting
                // the connection. It is deliberately not reported, to avoid
                // stderr spam.
                let _ = tokio::io::copy_bidirectional(&mut tcp, &mut proxy_stream).await;
            });
        }
    });

    match connect_direct(&local_url, opts).await {
        Ok(inner) => Ok(Box::new(crate::proxy::ProxiedConnection {
            inner,
            forwarder: Some(forwarder),
        })),
        Err(e) => {
            // Dropping a `JoinHandle` only detaches the task. Abort it
            // explicitly so the accept loop stops and the listener is
            // released.
            forwarder.abort();
            Err(e)
        }
    }
}
259
/// Async core of [`connect_with_tunnel`]: establish an [`AsyncConnection`]
/// to `url` through an SSH tunnel.
///
/// An optional `proxy` routes the SSH session itself through an
/// HTTP CONNECT proxy (corporate firewall scenario: proxy → bastion
/// → SSH → direct-tcpip → DB).
///
/// Picks the transport based on the backend:
///
/// - **Postgres** — `Stream` transport. The russh channel is fed
///   directly into `tokio_postgres::Config::connect_raw`, so there
///   is no extra TCP hop. TLS (if requested via `?sslmode=...`) is
///   negotiated end-to-end inside the SSH stream.
/// - **MySQL, MSSQL, Oracle** — `LocalListener` transport. A
///   `127.0.0.1:<random>` listener is bound; bytes are pumped through
///   the SSH channel by a forwarder task. The driver opens a regular
///   TCP connection to that local port (the URL is rewritten to
///   `127.0.0.1:<port>`).
/// - **SQLite** — rejected before any tunnel is set up. SQLite is a
///   local-file backend, so SSH tunneling is not applicable.
///
/// The returned `Box<dyn AsyncConnection>` is a
/// [`crate::tunnel::TunneledConnection`] wrapping the inner backend
/// connection. It owns the SSH session (and, for the LocalListener path,
/// the forwarder task) for the connection's lifetime. The public
/// [`connect_with_tunnel`] wraps it in a [`SyncConnection`].
///
/// # Errors
///
/// - [`SqlError::UnsupportedScheme`] for an unknown or disabled scheme.
/// - [`SqlError::ConnectionFailed`] for SQLite, for a URL without a host,
///   for any tunnel-setup failure other than host-key problems, and for a
///   `Stream` transport with no matching backend handler.
/// - [`SqlError::SshHostKeyMismatch`] / [`SqlError::SshUnknownHost`] when
///   the SSH host key does not match or is not yet known.
/// - Errors from rewriting the URL or from the backend connect are
///   propagated unchanged.
#[cfg(feature = "ssh")]
async fn connect_with_tunnel_inner(
    url: &DatabaseUrl,
    opts: &ConnectOptions,
    ssh_config: &crate::tunnel::SshConfig,
    key_source: &crate::tunnel::KeySource,
    proxy: Option<&crate::proxy::ProxyConfig>,
) -> Result<Box<dyn AsyncConnection>, SqlError> {
    use crate::tunnel::{
        TunnelError, TunnelTransport, TunnelTransportResult, TunneledConnection, setup_tunnel,
    };

    let backend = Backend::from_scheme(url.scheme())
        .ok_or_else(|| SqlError::UnsupportedScheme(url.scheme().to_string()))?;

    // Reject SQLite before doing any network work.
    #[cfg(feature = "sqlite")]
    if matches!(backend, Backend::Sqlite) {
        return Err(SqlError::ConnectionFailed(
            "SSH tunneling is not applicable to SQLite (local-file backend)".to_string(),
        ));
    }

    let target_host = url
        .host()
        .ok_or_else(|| {
            SqlError::ConnectionFailed(
                "URL has no host — SSH tunneling requires a network-based backend".to_string(),
            )
        })?
        .to_string();
    let target_port = url.port().unwrap_or_else(|| default_port_for(backend));

    // Only Postgres can consume a raw stream; everything else dials a port.
    let transport = match backend {
        #[cfg(feature = "postgres")]
        Backend::Postgres => TunnelTransport::Stream,
        _ => TunnelTransport::LocalListener,
    };

    let tunnel = setup_tunnel(
        ssh_config,
        key_source,
        &target_host,
        target_port,
        transport,
        proxy,
    )
    .await
    // Host-key problems get dedicated variants so callers can prompt or
    // refuse; everything else is flattened into ConnectionFailed.
    .map_err(|e| match e {
        TunnelError::HostKeyMismatch { host, port, .. } => {
            SqlError::SshHostKeyMismatch { host, port }
        }
        TunnelError::UnknownHost {
            host,
            port,
            algorithm,
            fingerprint,
            key,
            ..
        } => SqlError::SshUnknownHost {
            host,
            port,
            algorithm,
            fingerprint,
            key,
        },
        other => SqlError::ConnectionFailed(format!("SSH tunnel setup: {other}")),
    })?;

    let session = tunnel.session;
    match tunnel.transport {
        TunnelTransportResult::Stream { stream } => {
            #[cfg(feature = "postgres")]
            if matches!(backend, Backend::Postgres) {
                let pg = backends::postgres::connect_with_stream(url, opts, *stream).await?;
                return Ok(Box::new(TunneledConnection {
                    inner: Box::new(pg),
                    session,
                    forwarder: None,
                }));
            }
            Err(SqlError::ConnectionFailed(
                "Stream transport selected but no backend handler is registered \
                 (this is a ferrule bug — please report)"
                    .to_string(),
            ))
        }
        TunnelTransportResult::LocalPort { port, forwarder } => {
            // NOTE(review): if either call below fails, `session` and
            // `forwarder` are simply dropped here. Whether dropping the
            // forwarder stops it depends on its type — confirm. On the public
            // path, the private runtime is dropped on error as well.
            let local_url = rewrite_url_to_local(url, port)?;
            let inner = connect_direct(&local_url, opts).await?;
            Ok(Box::new(TunneledConnection {
                inner,
                session,
                forwarder: Some(forwarder),
            }))
        }
    }
}
380
/// Well-known TCP port for `backend`, used when the URL does not specify
/// one.
///
/// SQLite maps to `0` because it has no network port. The SSH path rejects
/// SQLite before asking, and the proxy-listener path is only reached for
/// MySQL, MSSQL and Oracle.
#[cfg(any(
    feature = "ssh",
    feature = "mysql",
    feature = "mssql",
    feature = "oracle"
))]
fn default_port_for(backend: Backend) -> u16 {
    match backend {
        #[cfg(feature = "oracle")]
        Backend::Oracle => 1521,
        #[cfg(feature = "mssql")]
        Backend::MsSql => 1433,
        #[cfg(feature = "mysql")]
        Backend::MySql => 3306,
        #[cfg(feature = "postgres")]
        Backend::Postgres => 5432,
        // Placeholder only; see the note above.
        #[cfg(feature = "sqlite")]
        Backend::Sqlite => 0,
    }
}
401
/// Return a copy of `url` whose host is `127.0.0.1` and whose port is
/// `port`.
///
/// Used once a loopback forwarder (SSH or HTTP CONNECT) is listening on
/// `port`, so the driver dials the local end instead of the real server.
/// Only the host and port are replaced; the rest of the re-parsed URL is
/// kept as-is.
///
/// # Errors
///
/// [`SqlError::InvalidUrl`] if the raw URL cannot be re-parsed or its host
/// or port cannot be replaced. Otherwise, whatever [`DatabaseUrl::parse`]
/// reports for the rewritten string.
#[cfg(any(
    feature = "ssh",
    feature = "mysql",
    feature = "mssql",
    feature = "oracle"
))]
fn rewrite_url_to_local(url: &DatabaseUrl, port: u16) -> Result<DatabaseUrl, SqlError> {
    let mut local = match ::url::Url::parse(url.raw()) {
        Ok(reparsed) => reparsed,
        Err(err) => {
            return Err(SqlError::InvalidUrl(format!("re-parse for tunnel rewrite: {err}")));
        }
    };
    if let Err(err) = local.set_host(Some("127.0.0.1")) {
        return Err(SqlError::InvalidUrl(format!("set_host(127.0.0.1): {err}")));
    }
    if local.set_port(Some(port)).is_err() {
        return Err(SqlError::InvalidUrl("set_port failed".to_string()));
    }
    DatabaseUrl::parse(local.as_str())
}
419
/// Open a blocking connection to `url` through an SSH tunnel.
///
/// An optional `proxy` carries the SSH session itself over HTTP CONNECT.
/// The runtime and blocking contract is the same as [`connect`]. The
/// returned handle owns the private runtime, which also hosts the SSH
/// session and (for the LocalListener transport) the byte-forwarder task.
/// All of them are therefore driven on every later blocking call and torn
/// down together when the handle is dropped.
#[cfg(feature = "ssh")]
#[must_use = "a connection handle must be used or the connection is dropped immediately"]
pub fn connect_with_tunnel(
    url: &DatabaseUrl,
    opts: &ConnectOptions,
    ssh_config: &crate::tunnel::SshConfig,
    key_source: &crate::tunnel::KeySource,
    proxy: Option<&crate::proxy::ProxyConfig>,
) -> Result<Box<dyn Connection>, SqlError> {
    let runtime = build_runtime()?;
    // Futures are lazy: nothing runs until `block_on` polls it below.
    let establish = connect_with_tunnel_inner(url, opts, ssh_config, key_source, proxy);
    let inner = runtime.block_on(establish)?;
    Ok(Box::new(SyncConnection::new(runtime, inner)))
}