1use std::fmt;
2use std::fmt::Debug;
3use std::sync::Arc;
4
5use tracing::{trace, instrument};
6use futures_util::SinkExt;
7use async_lock::Mutex;
8use async_lock::MutexGuard;
9use tokio_util::compat::{Compat, FuturesAsyncWriteCompatExt};
10use tokio_util::codec::FramedWrite;
11
12use fluvio_protocol::api::{RequestMessage, ResponseMessage};
13use fluvio_protocol::codec::FluvioCodec;
14use fluvio_protocol::Encoder as FlvEncoder;
15use fluvio_protocol::Version;
16use fluvio_future::net::{BoxWriteConnection, ConnectionFd};
17
18use crate::SocketError;
19
20type SinkFrame = FramedWrite<Compat<BoxWriteConnection>, FluvioCodec>;
21
22pub struct FluvioSink {
23 inner: SinkFrame,
24 fd: ConnectionFd,
25 enable_zero_copy: bool,
26}
27
28impl fmt::Debug for FluvioSink {
29 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30 write!(f, "fd({})", self.id())
31 }
32}
33
34impl FluvioSink {
35 pub fn get_mut_tcp_sink(&mut self) -> &mut SinkFrame {
36 &mut self.inner
37 }
38
39 pub fn id(&self) -> ConnectionFd {
40 #[allow(clippy::clone_on_copy)]
41 self.fd.clone()
42 }
43
44 #[allow(clippy::wrong_self_convention)]
46 pub fn as_shared(self) -> ExclusiveFlvSink {
47 ExclusiveFlvSink::new(self)
48 }
49
50 pub fn new(sink: BoxWriteConnection, fd: ConnectionFd) -> Self {
51 Self {
52 fd,
53 enable_zero_copy: true,
54 inner: SinkFrame::new(sink.compat_write(), FluvioCodec::new()),
55 }
56 }
57
58 pub fn disable_zerocopy(&mut self) {
60 self.enable_zero_copy = false;
61 }
62
63 #[instrument(level = "trace",skip(req_msg),fields(req=?req_msg))]
65 pub async fn send_request<R>(&mut self, req_msg: &RequestMessage<R>) -> Result<(), SocketError>
66 where
67 RequestMessage<R>: FlvEncoder + Debug,
68 {
69 self.inner.send((req_msg, 0)).await?;
70 Ok(())
71 }
72
73 #[instrument(level = "trace", skip(resp_msg))]
74 pub async fn send_response<P>(
76 &mut self,
77 resp_msg: &ResponseMessage<P>,
78 version: Version,
79 ) -> Result<(), SocketError>
80 where
81 ResponseMessage<P>: FlvEncoder + Debug,
82 {
83 trace!("sending response {:#?}", &resp_msg);
84 self.inner.send((resp_msg, version)).await?;
85 Ok(())
86 }
87}
88
89#[cfg(unix)]
90mod fd {
91
92 use std::os::unix::io::AsRawFd;
93 use std::os::unix::io::RawFd;
94
95 use super::FluvioSink;
96
97 impl AsRawFd for FluvioSink {
98 fn as_raw_fd(&self) -> RawFd {
99 self.fd
100 }
101 }
102}
103
104#[cfg(feature = "file")]
105mod file {
106 use std::os::fd::BorrowedFd;
107
108 use bytes::BytesMut;
109 use fluvio_future::task::spawn_blocking;
110 use futures_util::AsyncWriteExt;
111 use nix::sys::uio::pread;
112
113 use fluvio_protocol::store::{FileWrite, StoreValue};
114 use fluvio_future::zero_copy::ZeroCopy;
115
116 use super::*;
117
118 impl FluvioSink {
119 pub async fn encode_file_slices<T>(
121 &mut self,
122 msg: &T,
123 version: Version,
124 ) -> Result<usize, SocketError>
125 where
126 T: FileWrite,
127 {
128 trace!("encoding file slices version: {}", version);
129 let mut buf = BytesMut::with_capacity(1000);
130 let mut data: Vec<StoreValue> = vec![];
131 msg.file_encode(&mut buf, &mut data, version)?;
132 trace!("encoded buffer len: {}", buf.len());
133 data.push(StoreValue::Bytes(buf.freeze()));
135 self.write_store_values(data).await
136 }
137
138 async fn write_store_values(
140 &mut self,
141 values: Vec<StoreValue>,
142 ) -> Result<usize, SocketError> {
143 trace!("writing store values to socket values: {}", values.len());
144
145 let mut total_bytes_written = 0usize;
146
147 for value in values {
148 match value {
149 StoreValue::Bytes(bytes) => {
150 trace!("writing store bytes to socket len: {}", bytes.len());
151 self.get_mut_tcp_sink()
154 .get_mut()
155 .get_mut()
156 .write_all(&bytes)
157 .await?;
158 total_bytes_written += bytes.len();
159 }
160 StoreValue::FileSlice(f_slice) => {
161 if f_slice.is_empty() {
162 trace!("empty slice, skipping");
163 } else {
164 trace!(
165 "writing file slice pos: {} len: {} to socket",
166 f_slice.position(),
167 f_slice.len()
168 );
169 if self.enable_zero_copy {
170 let writer = ZeroCopy::raw(self.fd);
171 let bytes_written =
172 writer.copy_slice(&f_slice).await.map_err(|err| {
173 std::io::Error::other(format!("zero copy failed: {err}"))
174 })?;
175 trace!("finish writing file slice with {bytes_written} bytes");
176 total_bytes_written += bytes_written;
177 } else {
178 let offset = f_slice.position() as i64;
179
180 #[cfg(all(target_pointer_width = "32", target_env = "gnu"))]
181 let offset: i32 = offset.try_into().unwrap();
182
183 let in_fd = f_slice.fd();
184 trace!(
185 in_fd,
186 offset,
187 len = f_slice.len(),
188 "reading from file slice"
189 );
190 let (read_result, mut buf) = spawn_blocking(move || {
191 let mut buf = BytesMut::with_capacity(f_slice.len() as usize);
192 buf.resize(f_slice.len() as usize, 0);
193 let fd = unsafe { BorrowedFd::borrow_raw(in_fd) };
194 let read_size = pread(fd, &mut buf, offset).map_err(|err| {
195 std::io::Error::other(format!("pread failed: {err}"))
196 });
197 (read_size, buf)
198 })
199 .await;
200
201 let read = read_result?;
202 buf.resize(read, 0);
203
204 trace!(read, in_fd, buf_len = buf.len(), "status from file slice");
205
206 self.get_mut_tcp_sink()
208 .get_mut()
209 .get_mut()
210 .write_all(&buf)
211 .await?;
212
213 total_bytes_written += read;
214 }
215 }
216 }
217 }
218 }
219
220 trace!(total_bytes_written, "finish writing store values");
221 Ok(total_bytes_written)
222 }
223 }
224
225 #[cfg(test)]
226 mod tests {
227
228 use std::io::Cursor;
229 use std::io::ErrorKind;
230 use std::sync::Arc;
231 use std::time::Duration;
232 use std::io::Error as IoError;
233
234 use bytes::Buf;
235 use bytes::BufMut;
236 use bytes::BytesMut;
237 use futures_util::AsyncWriteExt;
238 use futures_util::future::join;
239 use futures_util::StreamExt;
240 use tracing::debug;
241
242 use fluvio_future::file_slice::AsyncFileSlice;
243 use fluvio_future::net::TcpListener;
244 use fluvio_protocol::Version;
245 use fluvio_protocol::store::FileWrite;
246 use fluvio_protocol::store::StoreValue;
247 use fluvio_future::fs::util;
248 use fluvio_future::fs::AsyncFileExtension;
249 use fluvio_future::timer::sleep;
250 use fluvio_protocol::{Decoder, Encoder};
251 use fluvio_types::event::StickyEvent;
252
253 use crate::FluvioSocket;
254 use crate::SocketError;
255
256 #[derive(Debug, Default)]
258 struct SliceWrapper(AsyncFileSlice);
259
260 impl SliceWrapper {
261 pub fn len(&self) -> usize {
262 self.0.len() as usize
263 }
264
265 pub fn raw_slice(&self) -> AsyncFileSlice {
266 self.0.clone()
267 }
268 }
269
270 impl Encoder for SliceWrapper {
271 fn write_size(&self, _version: Version) -> usize {
272 self.len() + 4 }
274
275 fn encode<T>(&self, src: &mut T, version: Version) -> Result<(), IoError>
276 where
277 T: BufMut,
278 {
279 if self.len() == 0 {
281 let len: u32 = 0;
282 len.encode(src, version)
283 } else {
284 Err(IoError::new(
285 ErrorKind::InvalidInput,
286 format!("len {} is not zeo", self.len()),
287 ))
288 }
289 }
290 }
291
292 impl Decoder for SliceWrapper {
293 fn decode<T>(&mut self, _src: &mut T, _version: Version) -> Result<(), IoError>
294 where
295 T: Buf,
296 {
297 unimplemented!("file slice cannot be decoded in the ButMut")
298 }
299 }
300
301 impl FileWrite for SliceWrapper {
302 fn file_encode(
303 &self,
304 _dest: &mut BytesMut,
305 data: &mut Vec<StoreValue>,
306 _version: Version,
307 ) -> Result<(), IoError> {
308 data.push(StoreValue::FileSlice(self.raw_slice()));
310 Ok(())
311 }
312 }
313
314 async fn test_server(
315 addr: &str,
316 end: Arc<StickyEvent>,
317 disable_zc: bool,
318 ) -> Result<(), SocketError> {
319 let listener = TcpListener::bind(&addr).await.expect("bind");
320 debug!("server is running");
321 let mut incoming = listener.incoming();
322
323 end.notify();
324 let incoming_stream = incoming.next().await;
325 debug!("server: got connection");
326 let incoming_stream = incoming_stream.expect("next").expect("unwrap again");
327 let mut socket: FluvioSocket = incoming_stream.into();
328
329 let raw_tcp_sink = socket.get_mut_sink().get_mut_tcp_sink();
330
331 const TEXT_LEN: u16 = 5;
332
333 let mut out = vec![];
335 let len: i32 = TEXT_LEN as i32 + 2; len.encode(&mut out, 0).expect("encode"); out.put_u16(TEXT_LEN); raw_tcp_sink.get_mut().get_mut().write_all(&out).await?;
340
341 debug!("server: sending out file contents");
343 let data_file = util::open("tests/test.txt").await.expect("open file");
344 let fslice = data_file.as_slice(0, None).await.expect("slice");
345 assert_eq!(fslice.len(), 5);
346 let wrapper = SliceWrapper(fslice);
347
348 let (mut sink, _stream) = socket.split();
349 if disable_zc {
351 sink.disable_zerocopy();
352 }
353 sink.encode_file_slices(&wrapper, 0).await.expect("encode");
354
355 debug!("server: hanging on client to test");
356 sleep(Duration::from_millis(500)).await;
358 debug!("server: finish");
359 Ok(())
360 }
361
362 async fn setup_client(addr: &str, end: Arc<StickyEvent>) -> Result<(), SocketError> {
363 debug!("waiting for server to start");
364 while !end.is_set() {
365 end.listen().await;
366 }
367 debug!("client: trying to connect");
368 let mut socket = FluvioSocket::connect(addr).await.expect("connect");
369 debug!("client: connect to test server and waiting for server to send out");
370 let stream = socket.get_mut_stream();
371 debug!("client: waiting for bytes");
372 let next_value = stream.get_mut_tcp_stream().next().await;
373 debug!("client: got bytes");
374 let bytes = next_value.expect("next").expect("bytes");
375 assert_eq!(bytes.len(), 7);
376 debug!("decoding values");
377 let mut src = Cursor::new(&bytes);
378 let mut msg1 = String::new();
379 msg1.decode(&mut src, 0).expect("decode should work");
380 assert_eq!(msg1, "hello");
381
382 Ok(())
383 }
384
385 #[fluvio_future::test]
386 async fn test_sink_zero_copy() {
387 let port = portpicker::pick_unused_port().expect("No free ports left");
388 let addr = format!("127.0.0.1:{port}");
389
390 let send_event = StickyEvent::shared();
391 let _r = join(
392 setup_client(&addr, send_event.clone()),
393 test_server(&addr, send_event, false),
394 )
395 .await;
396 }
397
398 #[fluvio_future::test]
399 async fn test_sink_buffer_copy() {
400 let port = portpicker::pick_unused_port().expect("No free ports left");
401 let addr = format!("127.0.0.1:{port}");
402
403 let send_event = StickyEvent::shared();
404 let _r = join(
405 setup_client(&addr, send_event.clone()),
406 test_server(&addr, send_event, true),
407 )
408 .await;
409 }
410 }
411}
412
413pub struct ExclusiveFlvSink {
415 inner: Arc<Mutex<FluvioSink>>,
416 fd: ConnectionFd,
417}
418
419impl ExclusiveFlvSink {
420 pub fn new(sink: FluvioSink) -> Self {
421 let fd = sink.id();
422 ExclusiveFlvSink {
423 inner: Arc::new(Mutex::new(sink)),
424 fd,
425 }
426 }
427}
428
429impl ExclusiveFlvSink {
430 pub async fn lock(&self) -> MutexGuard<'_, FluvioSink> {
431 self.inner.lock().await
432 }
433
434 pub async fn send_request<R>(&self, req_msg: &RequestMessage<R>) -> Result<(), SocketError>
435 where
436 RequestMessage<R>: FlvEncoder + Debug,
437 {
438 let mut inner_sink = self.inner.lock().await;
439 inner_sink.send_request(req_msg).await
440 }
441
442 pub async fn send_response<P>(
444 &mut self,
445 resp_msg: &ResponseMessage<P>,
446 version: Version,
447 ) -> Result<(), SocketError>
448 where
449 ResponseMessage<P>: FlvEncoder + Debug,
450 {
451 let mut inner_sink = self.inner.lock().await;
452 inner_sink.send_response(resp_msg, version).await
453 }
454
455 pub fn id(&self) -> ConnectionFd {
456 #[allow(clippy::clone_on_copy)]
457 self.fd.clone()
458 }
459}
460
461impl Clone for ExclusiveFlvSink {
462 fn clone(&self) -> Self {
463 #[allow(clippy::clone_on_copy)]
464 Self {
465 inner: self.inner.clone(),
466 fd: self.fd.clone(),
467 }
468 }
469}