1use std::fmt;
10#[cfg(unix)]
11use std::path::PathBuf;
12
13use super::error::{Error, Result};
14
/// A server endpoint that can be described by, and parsed from, a descriptor string.
///
/// The set of variants depends on the target platform: `DomainSocket` only
/// exists on Unix and `NamedPipe` only exists on Windows.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionEndpoint {
    /// A TCP endpoint, written `tab.tcp://<host>:<port>` in descriptor form.
    Tcp {
        /// Host name or address. IPv6 addresses are stored without the
        /// surrounding brackets; the brackets are re-added when formatting
        /// any host that contains a `:`.
        host: String,
        /// Port number. `0` is what the descriptor keyword `auto` parses to,
        /// and is rendered back as `auto` by `to_descriptor`.
        port: u16,
    },

    /// A Unix domain socket, written `tab.domain://<dir>/domain/<name>` in
    /// descriptor form.
    #[cfg(unix)]
    DomainSocket {
        /// Directory that the socket name is joined onto by `socket_path`.
        directory: PathBuf,
        /// Socket name; joined onto `directory` to form the socket path.
        name: String,
    },
    /// A Windows named pipe, written `tab.pipe://<host>/pipe/<name>` in
    /// descriptor form.
    #[cfg(windows)]
    NamedPipe {
        /// Host part of the pipe path `\\<host>\pipe\<name>`.
        host: String,
        /// Pipe name part of the pipe path `\\<host>\pipe\<name>`.
        name: String,
    },
}
53
54impl ConnectionEndpoint {
55 pub fn tcp(host: impl Into<String>, port: u16) -> Self {
57 ConnectionEndpoint::Tcp {
58 host: host.into(),
59 port,
60 }
61 }
62
63 #[cfg(unix)]
65 pub fn domain_socket(directory: impl Into<PathBuf>, name: impl Into<String>) -> Self {
66 ConnectionEndpoint::DomainSocket {
67 directory: directory.into(),
68 name: name.into(),
69 }
70 }
71
72 #[cfg(windows)]
78 pub fn named_pipe(host: impl Into<String>, name: impl Into<String>) -> Self {
79 ConnectionEndpoint::NamedPipe {
80 host: host.into(),
81 name: name.into(),
82 }
83 }
84
85 pub fn parse(descriptor: &str) -> Result<Self> {
104 #[allow(unused_variables, reason = "`rest` is unused on non-unix platforms")]
106 if let Some(rest) = descriptor.strip_prefix("tab.domain://") {
107 #[cfg(unix)]
108 {
109 let idx = rest.find("/domain/").ok_or_else(|| {
110 Error::connection(format!(
111 "Invalid domain socket format: '{descriptor}'. Expected 'tab.domain://<dir>/domain/<name>'"
112 ))
113 })?;
114 let directory = &rest[..idx];
115 let name = &rest[idx + 8..]; if name.is_empty() {
118 return Err(Error::connection("Domain socket name cannot be empty"));
119 }
120
121 return Ok(ConnectionEndpoint::DomainSocket {
122 directory: PathBuf::from(directory),
123 name: name.to_string(),
124 });
125 }
126 #[cfg(not(unix))]
127 {
128 return Err(Error::connection(
129 "Unix domain sockets are not supported on this platform",
130 ));
131 }
132 }
133
134 #[allow(unused_variables, reason = "`rest` is unused on non-windows platforms")]
136 if let Some(rest) = descriptor.strip_prefix("tab.pipe://") {
137 #[cfg(windows)]
138 {
139 let idx = rest.find("/pipe/").ok_or_else(|| {
142 Error::connection(format!(
143 "Invalid named pipe format: '{descriptor}'. Expected 'tab.pipe://<host>/pipe/<name>'"
144 ))
145 })?;
146 let host = &rest[..idx];
147 let name = &rest[idx + 6..]; if name.is_empty() {
150 return Err(Error::connection("Named pipe name cannot be empty"));
151 }
152
153 return Ok(ConnectionEndpoint::NamedPipe {
154 host: host.to_string(),
155 name: name.to_string(),
156 });
157 }
158 #[cfg(not(windows))]
159 {
160 return Err(Error::connection(
161 "Named pipes are not supported on this platform",
162 ));
163 }
164 }
165
166 let tcp_part = descriptor
168 .strip_prefix("tab.tcp://")
169 .or_else(|| descriptor.strip_prefix("tcp.libpq://"))
170 .unwrap_or(descriptor);
171
172 Self::parse_tcp(tcp_part)
173 }
174
175 fn parse_tcp(s: &str) -> Result<Self> {
177 if s.starts_with('[') {
179 let end_bracket = s
180 .find(']')
181 .ok_or_else(|| Error::connection(format!("Invalid IPv6 address format: '{s}'")))?;
182 let host = &s[1..end_bracket];
183 let port_str = s[end_bracket + 1..]
184 .strip_prefix(':')
185 .ok_or_else(|| Error::connection(format!("Missing port in: '{s}'")))?;
186
187 let port = Self::parse_port(port_str)?;
188 return Ok(ConnectionEndpoint::Tcp {
189 host: host.to_string(),
190 port,
191 });
192 }
193
194 let colon_idx = s.rfind(':').ok_or_else(|| {
196 Error::connection(format!(
197 "Invalid endpoint format: '{s}'. Expected 'host:port'"
198 ))
199 })?;
200
201 let host = &s[..colon_idx];
202 let port_str = &s[colon_idx + 1..];
203
204 if host.is_empty() {
205 return Err(Error::connection("Host cannot be empty"));
206 }
207
208 let port = Self::parse_port(port_str)?;
209
210 Ok(ConnectionEndpoint::Tcp {
211 host: host.to_string(),
212 port,
213 })
214 }
215
216 fn parse_port(s: &str) -> Result<u16> {
218 if s == "auto" {
219 return Ok(0);
220 }
221 s.parse::<u16>()
222 .map_err(|_| Error::connection(format!("Invalid port number: '{s}'")))
223 }
224
225 #[must_use]
229 pub fn to_descriptor(&self) -> String {
230 match self {
231 ConnectionEndpoint::Tcp { host, port } => {
232 let port_str = if *port == 0 {
233 "auto".to_string()
234 } else {
235 port.to_string()
236 };
237 if host.contains(':') {
239 format!("tab.tcp://[{host}]:{port_str}")
240 } else {
241 format!("tab.tcp://{host}:{port_str}")
242 }
243 }
244 #[cfg(unix)]
245 ConnectionEndpoint::DomainSocket { directory, name } => {
246 format!("tab.domain://{}/domain/{}", directory.display(), name)
247 }
248 #[cfg(windows)]
249 ConnectionEndpoint::NamedPipe { host, name } => {
250 format!("tab.pipe://{host}/pipe/{name}")
251 }
252 }
253 }
254
255 #[cfg(unix)]
257 #[must_use]
258 pub fn socket_path(&self) -> Option<PathBuf> {
259 match self {
260 ConnectionEndpoint::DomainSocket { directory, name } => Some(directory.join(name)),
261 ConnectionEndpoint::Tcp { .. } => None,
262 }
263 }
264
265 #[cfg(windows)]
267 pub fn pipe_path(&self) -> Option<String> {
268 match self {
269 ConnectionEndpoint::NamedPipe { host, name } => {
270 Some(format!("\\\\{host}\\pipe\\{name}"))
271 }
272 ConnectionEndpoint::Tcp { .. } => None,
273 }
274 }
275
276 #[must_use]
278 pub fn is_tcp(&self) -> bool {
279 matches!(self, ConnectionEndpoint::Tcp { .. })
280 }
281
282 #[cfg(unix)]
284 #[must_use]
285 pub fn is_domain_socket(&self) -> bool {
286 matches!(self, ConnectionEndpoint::DomainSocket { .. })
287 }
288
289 #[cfg(windows)]
291 pub fn is_named_pipe(&self) -> bool {
292 matches!(self, ConnectionEndpoint::NamedPipe { .. })
293 }
294
295 #[must_use]
297 pub fn tcp_addr(&self) -> Option<(&str, u16)> {
298 match self {
299 ConnectionEndpoint::Tcp { host, port } => Some((host, *port)),
300 #[cfg(unix)]
301 ConnectionEndpoint::DomainSocket { .. } => None,
302 #[cfg(windows)]
303 ConnectionEndpoint::NamedPipe { .. } => None,
304 }
305 }
306}
307
308impl fmt::Display for ConnectionEndpoint {
309 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
310 match self {
311 ConnectionEndpoint::Tcp { host, port } => {
312 if host.contains(':') {
313 write!(f, "[{host}]:{port}")
314 } else {
315 write!(f, "{host}:{port}")
316 }
317 }
318 #[cfg(unix)]
319 ConnectionEndpoint::DomainSocket { directory, name } => {
320 write!(f, "{}/{}", directory.display(), name)
321 }
322 #[cfg(windows)]
323 ConnectionEndpoint::NamedPipe { host, name } => {
324 write!(f, "\\\\{host}\\pipe\\{name}")
325 }
326 }
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `descriptor`, panicking (at the caller's location) if it is rejected.
    #[track_caller]
    fn parsed(descriptor: &str) -> ConnectionEndpoint {
        ConnectionEndpoint::parse(descriptor).unwrap()
    }

    // ---- TCP parsing ----

    #[test]
    fn test_parse_tcp_simple() {
        let expected = ConnectionEndpoint::Tcp {
            host: String::from("localhost"),
            port: 7483,
        };
        assert_eq!(parsed("localhost:7483"), expected);
    }

    #[test]
    fn test_parse_tcp_with_scheme() {
        let expected = ConnectionEndpoint::Tcp {
            host: String::from("127.0.0.1"),
            port: 7483,
        };
        assert_eq!(parsed("tab.tcp://127.0.0.1:7483"), expected);
    }

    #[test]
    fn test_parse_tcp_auto_port() {
        // `auto` is represented as port 0.
        let expected = ConnectionEndpoint::Tcp {
            host: String::from("localhost"),
            port: 0,
        };
        assert_eq!(parsed("tab.tcp://localhost:auto"), expected);
    }

    #[test]
    fn test_parse_tcp_ipv6() {
        // Brackets are stripped from the stored host.
        let expected = ConnectionEndpoint::Tcp {
            host: String::from("::1"),
            port: 7483,
        };
        assert_eq!(parsed("tab.tcp://[::1]:7483"), expected);
    }

    // ---- Unix domain sockets ----

    #[cfg(unix)]
    #[test]
    fn test_parse_domain_socket() {
        let expected = ConnectionEndpoint::DomainSocket {
            directory: PathBuf::from("/tmp/hyper"),
            name: String::from(".s.PGSQL.12345"),
        };
        assert_eq!(
            parsed("tab.domain:///tmp/hyper/domain/.s.PGSQL.12345"),
            expected
        );
    }

    // ---- Descriptor rendering ----

    #[test]
    fn test_to_descriptor_tcp() {
        let descriptor = ConnectionEndpoint::tcp("localhost", 7483).to_descriptor();
        assert_eq!(descriptor, "tab.tcp://localhost:7483");
    }

    #[test]
    fn test_to_descriptor_tcp_auto() {
        let descriptor = ConnectionEndpoint::tcp("localhost", 0).to_descriptor();
        assert_eq!(descriptor, "tab.tcp://localhost:auto");
    }

    #[cfg(unix)]
    #[test]
    fn test_to_descriptor_domain_socket() {
        let descriptor =
            ConnectionEndpoint::domain_socket("/tmp/hyper", ".s.PGSQL.12345").to_descriptor();
        assert_eq!(descriptor, "tab.domain:///tmp/hyper/domain/.s.PGSQL.12345");
    }

    #[cfg(unix)]
    #[test]
    fn test_socket_path() {
        let path = ConnectionEndpoint::domain_socket("/tmp/hyper", ".s.PGSQL.12345").socket_path();
        assert_eq!(path, Some(PathBuf::from("/tmp/hyper/.s.PGSQL.12345")));
    }

    // ---- Display ----

    #[test]
    fn test_display_tcp() {
        let shown = format!("{}", ConnectionEndpoint::tcp("localhost", 7483));
        assert_eq!(shown, "localhost:7483");
    }

    #[cfg(unix)]
    #[test]
    fn test_display_domain_socket() {
        let shown = format!(
            "{}",
            ConnectionEndpoint::domain_socket("/tmp/hyper", ".s.PGSQL.12345")
        );
        assert_eq!(shown, "/tmp/hyper/.s.PGSQL.12345");
    }

    // ---- Windows named pipes ----

    #[cfg(windows)]
    #[test]
    fn test_parse_named_pipe() {
        let expected = ConnectionEndpoint::NamedPipe {
            host: String::from("."),
            name: String::from("hyper-12345"),
        };
        assert_eq!(parsed("tab.pipe://./pipe/hyper-12345"), expected);
    }

    #[cfg(windows)]
    #[test]
    fn test_parse_named_pipe_remote() {
        let expected = ConnectionEndpoint::NamedPipe {
            host: String::from("server1"),
            name: String::from("hyper-db"),
        };
        assert_eq!(parsed("tab.pipe://server1/pipe/hyper-db"), expected);
    }

    #[cfg(windows)]
    #[test]
    fn test_to_descriptor_named_pipe() {
        let descriptor = ConnectionEndpoint::named_pipe(".", "hyper-12345").to_descriptor();
        assert_eq!(descriptor, "tab.pipe://./pipe/hyper-12345");
    }

    #[cfg(windows)]
    #[test]
    fn test_pipe_path() {
        let path = ConnectionEndpoint::named_pipe(".", "hyper-12345").pipe_path();
        assert_eq!(path, Some(r"\\.\pipe\hyper-12345".to_string()));
    }

    #[cfg(windows)]
    #[test]
    fn test_display_named_pipe() {
        let shown = format!("{}", ConnectionEndpoint::named_pipe(".", "hyper-12345"));
        assert_eq!(shown, r"\\.\pipe\hyper-12345");
    }

    #[cfg(windows)]
    #[test]
    fn test_named_pipe_is_methods() {
        let endpoint = ConnectionEndpoint::named_pipe(".", "test");
        assert!(!endpoint.is_tcp());
        assert!(endpoint.is_named_pipe());
    }
}