fluvio_socket/
sink.rs

1use std::fmt;
2use std::fmt::Debug;
3use std::sync::Arc;
4
5use tracing::{trace, instrument};
6use futures_util::SinkExt;
7use async_lock::Mutex;
8use async_lock::MutexGuard;
9use tokio_util::compat::{Compat, FuturesAsyncWriteCompatExt};
10use tokio_util::codec::FramedWrite;
11
12use fluvio_protocol::api::{RequestMessage, ResponseMessage};
13use fluvio_protocol::codec::FluvioCodec;
14use fluvio_protocol::Encoder as FlvEncoder;
15use fluvio_protocol::Version;
16use fluvio_future::net::{BoxWriteConnection, ConnectionFd};
17
18use crate::SocketError;
19
/// Framed writer used by [`FluvioSink`]: a `FramedWrite` over the tokio-compat
/// wrapped boxed write connection, encoding outgoing items with `FluvioCodec`.
type SinkFrame = FramedWrite<Compat<BoxWriteConnection>, FluvioCodec>;
21
/// Sending side of a socket.
///
/// Requests and responses are sent through the framed writer, while the file
/// related helpers write directly to the connection underneath it.
pub struct FluvioSink {
    /// framed writer used by `send_request` / `send_response`
    inner: SinkFrame,
    /// connection identifier, returned by `id()`
    fd: ConnectionFd,
    /// `true` after `new`; cleared by `disable_zerocopy`
    enable_zero_copy: bool,
}
27
28impl fmt::Debug for FluvioSink {
29    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30        write!(f, "fd({})", self.id())
31    }
32}
33
34impl FluvioSink {
35    pub fn get_mut_tcp_sink(&mut self) -> &mut SinkFrame {
36        &mut self.inner
37    }
38
39    pub fn id(&self) -> ConnectionFd {
40        #[allow(clippy::clone_on_copy)]
41        self.fd.clone()
42    }
43
44    /// convert to shared sink
45    #[allow(clippy::wrong_self_convention)]
46    pub fn as_shared(self) -> ExclusiveFlvSink {
47        ExclusiveFlvSink::new(self)
48    }
49
50    pub fn new(sink: BoxWriteConnection, fd: ConnectionFd) -> Self {
51        Self {
52            fd,
53            enable_zero_copy: true,
54            inner: SinkFrame::new(sink.compat_write(), FluvioCodec::new()),
55        }
56    }
57
58    /// don't use zero copy
59    pub fn disable_zerocopy(&mut self) {
60        self.enable_zero_copy = false;
61    }
62
63    /// as client, send request to server
64    #[instrument(level = "trace",skip(req_msg),fields(req=?req_msg))]
65    pub async fn send_request<R>(&mut self, req_msg: &RequestMessage<R>) -> Result<(), SocketError>
66    where
67        RequestMessage<R>: FlvEncoder + Debug,
68    {
69        self.inner.send((req_msg, 0)).await?;
70        Ok(())
71    }
72
73    #[instrument(level = "trace", skip(resp_msg))]
74    /// as server, send back response
75    pub async fn send_response<P>(
76        &mut self,
77        resp_msg: &ResponseMessage<P>,
78        version: Version,
79    ) -> Result<(), SocketError>
80    where
81        ResponseMessage<P>: FlvEncoder + Debug,
82    {
83        trace!("sending response {:#?}", &resp_msg);
84        self.inner.send((resp_msg, version)).await?;
85        Ok(())
86    }
87}
88
89#[cfg(unix)]
90mod fd {
91
92    use std::os::unix::io::AsRawFd;
93    use std::os::unix::io::RawFd;
94
95    use super::FluvioSink;
96
97    impl AsRawFd for FluvioSink {
98        fn as_raw_fd(&self) -> RawFd {
99            self.fd
100        }
101    }
102}
103
104#[cfg(feature = "file")]
105mod file {
106    use std::os::fd::BorrowedFd;
107
108    use bytes::BytesMut;
109    use fluvio_future::task::spawn_blocking;
110    use futures_util::AsyncWriteExt;
111    use nix::sys::uio::pread;
112
113    use fluvio_protocol::store::{FileWrite, StoreValue};
114    use fluvio_future::zero_copy::ZeroCopy;
115
116    use super::*;
117
118    impl FluvioSink {
119        /// write
120        pub async fn encode_file_slices<T>(
121            &mut self,
122            msg: &T,
123            version: Version,
124        ) -> Result<usize, SocketError>
125        where
126            T: FileWrite,
127        {
128            trace!("encoding file slices version: {}", version);
129            let mut buf = BytesMut::with_capacity(1000);
130            let mut data: Vec<StoreValue> = vec![];
131            msg.file_encode(&mut buf, &mut data, version)?;
132            trace!("encoded buffer len: {}", buf.len());
133            // add remainder
134            data.push(StoreValue::Bytes(buf.freeze()));
135            self.write_store_values(data).await
136        }
137
138        /// write store values to socket
139        async fn write_store_values(
140            &mut self,
141            values: Vec<StoreValue>,
142        ) -> Result<usize, SocketError> {
143            trace!("writing store values to socket values: {}", values.len());
144
145            let mut total_bytes_written = 0usize;
146
147            for value in values {
148                match value {
149                    StoreValue::Bytes(bytes) => {
150                        trace!("writing store bytes to socket len: {}", bytes.len());
151                        // These bytes should be already encoded so don't need to pass
152                        // through the FluvioCodec
153                        self.get_mut_tcp_sink()
154                            .get_mut()
155                            .get_mut()
156                            .write_all(&bytes)
157                            .await?;
158                        total_bytes_written += bytes.len();
159                    }
160                    StoreValue::FileSlice(f_slice) => {
161                        if f_slice.is_empty() {
162                            trace!("empty slice, skipping");
163                        } else {
164                            trace!(
165                                "writing file slice pos: {} len: {} to socket",
166                                f_slice.position(),
167                                f_slice.len()
168                            );
169                            if self.enable_zero_copy {
170                                let writer = ZeroCopy::raw(self.fd);
171                                let bytes_written =
172                                    writer.copy_slice(&f_slice).await.map_err(|err| {
173                                        std::io::Error::other(format!("zero copy failed: {err}"))
174                                    })?;
175                                trace!("finish writing file slice with {bytes_written} bytes");
176                                total_bytes_written += bytes_written;
177                            } else {
178                                let offset = f_slice.position() as i64;
179
180                                #[cfg(all(target_pointer_width = "32", target_env = "gnu"))]
181                                let offset: i32 = offset.try_into().unwrap();
182
183                                let in_fd = f_slice.fd();
184                                trace!(
185                                    in_fd,
186                                    offset,
187                                    len = f_slice.len(),
188                                    "reading from file slice"
189                                );
190                                let (read_result, mut buf) = spawn_blocking(move || {
191                                    let mut buf = BytesMut::with_capacity(f_slice.len() as usize);
192                                    buf.resize(f_slice.len() as usize, 0);
193                                    let fd = unsafe { BorrowedFd::borrow_raw(in_fd) };
194                                    let read_size = pread(fd, &mut buf, offset).map_err(|err| {
195                                        std::io::Error::other(format!("pread failed: {err}"))
196                                    });
197                                    (read_size, buf)
198                                })
199                                .await;
200
201                                let read = read_result?;
202                                buf.resize(read, 0);
203
204                                trace!(read, in_fd, buf_len = buf.len(), "status from file slice");
205
206                                // write to socket
207                                self.get_mut_tcp_sink()
208                                    .get_mut()
209                                    .get_mut()
210                                    .write_all(&buf)
211                                    .await?;
212
213                                total_bytes_written += read;
214                            }
215                        }
216                    }
217                }
218            }
219
220            trace!(total_bytes_written, "finish writing store values");
221            Ok(total_bytes_written)
222        }
223    }
224
225    #[cfg(test)]
226    mod tests {
227
228        use std::io::Cursor;
229        use std::io::ErrorKind;
230        use std::sync::Arc;
231        use std::time::Duration;
232        use std::io::Error as IoError;
233
234        use bytes::Buf;
235        use bytes::BufMut;
236        use bytes::BytesMut;
237        use futures_util::AsyncWriteExt;
238        use futures_util::future::join;
239        use futures_util::StreamExt;
240        use tracing::debug;
241
242        use fluvio_future::file_slice::AsyncFileSlice;
243        use fluvio_future::net::TcpListener;
244        use fluvio_protocol::Version;
245        use fluvio_protocol::store::FileWrite;
246        use fluvio_protocol::store::StoreValue;
247        use fluvio_future::fs::util;
248        use fluvio_future::fs::AsyncFileExtension;
249        use fluvio_future::timer::sleep;
250        use fluvio_protocol::{Decoder, Encoder};
251        use fluvio_types::event::StickyEvent;
252
253        use crate::FluvioSocket;
254        use crate::SocketError;
255
        /// Test helper wrapping an `AsyncFileSlice` so that it can be handed to
        /// `encode_file_slices`; its `FileWrite` impl pushes the slice itself
        /// instead of encoding any bytes.
        #[derive(Debug, Default)]
        struct SliceWrapper(AsyncFileSlice);
259
260        impl SliceWrapper {
261            pub fn len(&self) -> usize {
262                self.0.len() as usize
263            }
264
265            pub fn raw_slice(&self) -> AsyncFileSlice {
266                self.0.clone()
267            }
268        }
269
270        impl Encoder for SliceWrapper {
271            fn write_size(&self, _version: Version) -> usize {
272                self.len() + 4 // include header
273            }
274
275            fn encode<T>(&self, src: &mut T, version: Version) -> Result<(), IoError>
276            where
277                T: BufMut,
278            {
279                // can only encode zero length
280                if self.len() == 0 {
281                    let len: u32 = 0;
282                    len.encode(src, version)
283                } else {
284                    Err(IoError::new(
285                        ErrorKind::InvalidInput,
286                        format!("len {} is not zeo", self.len()),
287                    ))
288                }
289            }
290        }
291
292        impl Decoder for SliceWrapper {
293            fn decode<T>(&mut self, _src: &mut T, _version: Version) -> Result<(), IoError>
294            where
295                T: Buf,
296            {
297                unimplemented!("file slice cannot be decoded in the ButMut")
298            }
299        }
300
301        impl FileWrite for SliceWrapper {
302            fn file_encode(
303                &self,
304                _dest: &mut BytesMut,
305                data: &mut Vec<StoreValue>,
306                _version: Version,
307            ) -> Result<(), IoError> {
308                // just push slice
309                data.push(StoreValue::FileSlice(self.raw_slice()));
310                Ok(())
311            }
312        }
313
314        async fn test_server(
315            addr: &str,
316            end: Arc<StickyEvent>,
317            disable_zc: bool,
318        ) -> Result<(), SocketError> {
319            let listener = TcpListener::bind(&addr).await.expect("bind");
320            debug!("server is running");
321            let mut incoming = listener.incoming();
322
323            end.notify();
324            let incoming_stream = incoming.next().await;
325            debug!("server: got connection");
326            let incoming_stream = incoming_stream.expect("next").expect("unwrap again");
327            let mut socket: FluvioSocket = incoming_stream.into();
328
329            let raw_tcp_sink = socket.get_mut_sink().get_mut_tcp_sink();
330
331            const TEXT_LEN: u16 = 5;
332
333            // directly encode total buffer with is 4 + 2 + string
334            let mut out = vec![];
335            let len: i32 = TEXT_LEN as i32 + 2; // msg plus file
336            len.encode(&mut out, 0).expect("encode"); // codec len
337            out.put_u16(TEXT_LEN); // string message len
338
339            raw_tcp_sink.get_mut().get_mut().write_all(&out).await?;
340
341            // send out file
342            debug!("server: sending out file contents");
343            let data_file = util::open("tests/test.txt").await.expect("open file");
344            let fslice = data_file.as_slice(0, None).await.expect("slice");
345            assert_eq!(fslice.len(), 5);
346            let wrapper = SliceWrapper(fslice);
347
348            let (mut sink, _stream) = socket.split();
349            // output file slice
350            if disable_zc {
351                sink.disable_zerocopy();
352            }
353            sink.encode_file_slices(&wrapper, 0).await.expect("encode");
354
355            debug!("server: hanging on client to test");
356            // just in case if we need to keep it on
357            sleep(Duration::from_millis(500)).await;
358            debug!("server: finish");
359            Ok(())
360        }
361
362        async fn setup_client(addr: &str, end: Arc<StickyEvent>) -> Result<(), SocketError> {
363            debug!("waiting for server to start");
364            while !end.is_set() {
365                end.listen().await;
366            }
367            debug!("client: trying to connect");
368            let mut socket = FluvioSocket::connect(addr).await.expect("connect");
369            debug!("client: connect to test server and waiting for server to send out");
370            let stream = socket.get_mut_stream();
371            debug!("client: waiting for bytes");
372            let next_value = stream.get_mut_tcp_stream().next().await;
373            debug!("client: got bytes");
374            let bytes = next_value.expect("next").expect("bytes");
375            assert_eq!(bytes.len(), 7);
376            debug!("decoding values");
377            let mut src = Cursor::new(&bytes);
378            let mut msg1 = String::new();
379            msg1.decode(&mut src, 0).expect("decode should work");
380            assert_eq!(msg1, "hello");
381
382            Ok(())
383        }
384
385        #[fluvio_future::test]
386        async fn test_sink_zero_copy() {
387            let port = portpicker::pick_unused_port().expect("No free ports left");
388            let addr = format!("127.0.0.1:{port}");
389
390            let send_event = StickyEvent::shared();
391            let _r = join(
392                setup_client(&addr, send_event.clone()),
393                test_server(&addr, send_event, false),
394            )
395            .await;
396        }
397
398        #[fluvio_future::test]
399        async fn test_sink_buffer_copy() {
400            let port = portpicker::pick_unused_port().expect("No free ports left");
401            let addr = format!("127.0.0.1:{port}");
402
403            let send_event = StickyEvent::shared();
404            let _r = join(
405                setup_client(&addr, send_event.clone()),
406                test_server(&addr, send_event, true),
407            )
408            .await;
409        }
410    }
411}
412
/// Multi-thread aware Sink.  Only allows sending one request at a time.
///
/// Clones share the same underlying [`FluvioSink`] through an `Arc<Mutex<_>>`.
pub struct ExclusiveFlvSink {
    /// shared sink; every send locks the mutex first
    inner: Arc<Mutex<FluvioSink>>,
    /// identifier taken from the sink in `new`, so `id()` needs no lock
    fd: ConnectionFd,
}
418
419impl ExclusiveFlvSink {
420    pub fn new(sink: FluvioSink) -> Self {
421        let fd = sink.id();
422        ExclusiveFlvSink {
423            inner: Arc::new(Mutex::new(sink)),
424            fd,
425        }
426    }
427}
428
429impl ExclusiveFlvSink {
430    pub async fn lock(&self) -> MutexGuard<'_, FluvioSink> {
431        self.inner.lock().await
432    }
433
434    pub async fn send_request<R>(&self, req_msg: &RequestMessage<R>) -> Result<(), SocketError>
435    where
436        RequestMessage<R>: FlvEncoder + Debug,
437    {
438        let mut inner_sink = self.inner.lock().await;
439        inner_sink.send_request(req_msg).await
440    }
441
442    /// helper method to send back response
443    pub async fn send_response<P>(
444        &mut self,
445        resp_msg: &ResponseMessage<P>,
446        version: Version,
447    ) -> Result<(), SocketError>
448    where
449        ResponseMessage<P>: FlvEncoder + Debug,
450    {
451        let mut inner_sink = self.inner.lock().await;
452        inner_sink.send_response(resp_msg, version).await
453    }
454
455    pub fn id(&self) -> ConnectionFd {
456        #[allow(clippy::clone_on_copy)]
457        self.fd.clone()
458    }
459}
460
461impl Clone for ExclusiveFlvSink {
462    fn clone(&self) -> Self {
463        #[allow(clippy::clone_on_copy)]
464        Self {
465            inner: self.inner.clone(),
466            fd: self.fd.clone(),
467        }
468    }
469}