Skip to main content

hyperdb_api_core/client/
async_stream.rs

1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Async stream abstraction for multiple transport types.
5//!
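//! This module provides [`AsyncStream`], an enum that can hold different
//! async stream types (TCP on all platforms, Unix Domain Sockets on Unix,
//! Named Pipes on Windows) while implementing tokio's
//! [`AsyncRead`](tokio::io::AsyncRead) and [`AsyncWrite`](tokio::io::AsyncWrite) traits.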
9
10use std::io;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15use tokio::net::TcpStream;
16
17#[cfg(unix)]
18use tokio::net::UnixStream;
19
20#[cfg(windows)]
21use tokio::net::windows::named_pipe::NamedPipeClient;
22
23/// An async stream that can be either TCP or Unix Domain Socket.
24///
25/// This enum provides a unified interface for different transport mechanisms,
26/// allowing [`AsyncClient`](crate::client::AsyncClient) to work with both TCP and
27/// Unix Domain Sockets transparently.
28#[derive(Debug)]
29pub enum AsyncStream {
30    /// TCP stream for network connections.
31    Tcp(TcpStream),
32
33    /// Unix Domain Socket stream for local IPC (Unix only).
34    #[cfg(unix)]
35    Unix(UnixStream),
36
37    /// Windows Named Pipe stream for local IPC (Windows only).
38    #[cfg(windows)]
39    NamedPipe(NamedPipeClient),
40}
41
42impl AsyncStream {
43    /// Creates a new TCP stream wrapper.
44    pub fn tcp(stream: TcpStream) -> Self {
45        AsyncStream::Tcp(stream)
46    }
47
48    /// Creates a new Unix Domain Socket stream wrapper.
49    #[cfg(unix)]
50    pub fn unix(stream: UnixStream) -> Self {
51        AsyncStream::Unix(stream)
52    }
53
54    /// Returns true if this is a TCP stream.
55    pub fn is_tcp(&self) -> bool {
56        matches!(self, AsyncStream::Tcp(_))
57    }
58
59    /// Returns true if this is a Unix Domain Socket stream.
60    #[cfg(unix)]
61    pub fn is_unix(&self) -> bool {
62        matches!(self, AsyncStream::Unix(_))
63    }
64
65    /// Creates a new Windows Named Pipe stream wrapper.
66    #[cfg(windows)]
67    pub fn named_pipe(client: NamedPipeClient) -> Self {
68        AsyncStream::NamedPipe(client)
69    }
70
71    /// Returns true if this is a Windows Named Pipe stream.
72    #[cfg(windows)]
73    pub fn is_named_pipe(&self) -> bool {
74        matches!(self, AsyncStream::NamedPipe(_))
75    }
76
77    /// Sets `TCP_NODELAY` option (only applicable for TCP streams).
78    ///
79    /// # Errors
80    ///
81    /// Returns an [`io::Error`] from the underlying
82    /// [`tokio::net::TcpStream::set_nodelay`] when the socket option
83    /// cannot be applied. Unix-domain and named-pipe variants are
84    /// no-ops that always return `Ok(())`.
85    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
86        match self {
87            AsyncStream::Tcp(stream) => stream.set_nodelay(nodelay),
88            #[cfg(unix)]
89            AsyncStream::Unix(_) => Ok(()), // No-op for Unix sockets
90            #[cfg(windows)]
91            AsyncStream::NamedPipe(_) => Ok(()), // No-op for Named Pipes
92        }
93    }
94
95    /// Returns the local address for TCP streams, or a placeholder for Unix sockets.
96    pub fn local_addr_string(&self) -> String {
97        match self {
98            AsyncStream::Tcp(stream) => stream
99                .local_addr()
100                .map_or_else(|_| "unknown".to_string(), |a| a.to_string()),
101            #[cfg(unix)]
102            AsyncStream::Unix(stream) => stream
103                .local_addr()
104                .ok()
105                .and_then(|a| a.as_pathname().map(|p| p.display().to_string()))
106                .unwrap_or_else(|| "unix-socket".to_string()),
107            #[cfg(windows)]
108            AsyncStream::NamedPipe(_) => "named-pipe".to_string(),
109        }
110    }
111
112    /// Returns the peer address for TCP streams, or a placeholder for Unix sockets.
113    pub fn peer_addr_string(&self) -> String {
114        match self {
115            AsyncStream::Tcp(stream) => stream
116                .peer_addr()
117                .map_or_else(|_| "unknown".to_string(), |a| a.to_string()),
118            #[cfg(unix)]
119            AsyncStream::Unix(stream) => stream
120                .peer_addr()
121                .ok()
122                .and_then(|a| a.as_pathname().map(|p| p.display().to_string()))
123                .unwrap_or_else(|| "unix-socket".to_string()),
124            #[cfg(windows)]
125            AsyncStream::NamedPipe(_) => "named-pipe".to_string(),
126        }
127    }
128}
129
130impl AsyncRead for AsyncStream {
131    fn poll_read(
132        self: Pin<&mut Self>,
133        cx: &mut Context<'_>,
134        buf: &mut ReadBuf<'_>,
135    ) -> Poll<io::Result<()>> {
136        match self.get_mut() {
137            AsyncStream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
138            #[cfg(unix)]
139            AsyncStream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
140            #[cfg(windows)]
141            AsyncStream::NamedPipe(pipe) => Pin::new(pipe).poll_read(cx, buf),
142        }
143    }
144}
145
146impl AsyncWrite for AsyncStream {
147    fn poll_write(
148        self: Pin<&mut Self>,
149        cx: &mut Context<'_>,
150        buf: &[u8],
151    ) -> Poll<io::Result<usize>> {
152        match self.get_mut() {
153            AsyncStream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
154            #[cfg(unix)]
155            AsyncStream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
156            #[cfg(windows)]
157            AsyncStream::NamedPipe(pipe) => Pin::new(pipe).poll_write(cx, buf),
158        }
159    }
160
161    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
162        match self.get_mut() {
163            AsyncStream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
164            #[cfg(unix)]
165            AsyncStream::Unix(stream) => Pin::new(stream).poll_flush(cx),
166            #[cfg(windows)]
167            AsyncStream::NamedPipe(pipe) => Pin::new(pipe).poll_flush(cx),
168        }
169    }
170
171    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
172        match self.get_mut() {
173            AsyncStream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
174            #[cfg(unix)]
175            AsyncStream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
176            #[cfg(windows)]
177            AsyncStream::NamedPipe(pipe) => Pin::new(pipe).poll_shutdown(cx),
178        }
179    }
180
181    fn poll_write_vectored(
182        self: Pin<&mut Self>,
183        cx: &mut Context<'_>,
184        bufs: &[io::IoSlice<'_>],
185    ) -> Poll<io::Result<usize>> {
186        match self.get_mut() {
187            AsyncStream::Tcp(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
188            #[cfg(unix)]
189            AsyncStream::Unix(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
190            #[cfg(windows)]
191            AsyncStream::NamedPipe(pipe) => Pin::new(pipe).poll_write_vectored(cx, bufs),
192        }
193    }
194
195    fn is_write_vectored(&self) -> bool {
196        match self {
197            AsyncStream::Tcp(stream) => stream.is_write_vectored(),
198            #[cfg(unix)]
199            AsyncStream::Unix(stream) => stream.is_write_vectored(),
200            #[cfg(windows)]
201            AsyncStream::NamedPipe(pipe) => pipe.is_write_vectored(),
202        }
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::AsyncStream;
    use tokio::io::{AsyncRead, AsyncWrite};

    /// Compiles only if `T` implements the listed I/O traits and auto-traits.
    fn assert_stream_bounds<T>()
    where
        T: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,
    {
    }

    #[test]
    fn test_async_stream_variants_exist() {
        // Constructing real streams needs a runtime and a live peer, so this
        // checks at compile time that the enum (on the current target's set
        // of variants) still implements the async I/O traits. `Unpin` in
        // particular is required by the `Pin::get_mut` calls in those impls.
        assert_stream_bounds::<AsyncStream>();
    }
}