Skip to main content

bollard/
read.rs

1use bytes::Buf;
2use bytes::BytesMut;
3use futures_core::Stream;
4use hyper::body::Body;
5use hyper::body::Bytes;
6use hyper::body::Incoming;
7use hyper::upgrade::Upgraded;
8use log::debug;
9use log::trace;
10use pin_project_lite::pin_project;
11use serde::de::DeserializeOwned;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::{cmp, io, marker::PhantomData};
15
16use tokio::io::AsyncWrite;
17use tokio::io::{AsyncRead, ReadBuf};
18use tokio_util::codec::Decoder;
19
20use crate::container::LogOutput;
21
22use crate::errors::Error;
23use crate::errors::Error::JsonDataError;
24
/// Position of [`NewlineLogOutputDecoder`] within the current frame.
#[derive(Clone, Copy, Debug)]
enum NewlineLogOutputDecoderState {
    /// Expecting an 8-byte frame header, or a header-less chunk of output.
    WaitingHeader,
    /// A header has been read; waiting for the announced number of payload bytes.
    WaitingPayload(u8, usize), // StreamType, Length
}

/// Splits container output into [`LogOutput`] items.
///
/// Accepts both the multiplexed framing (an 8-byte header carrying the stream
/// type and payload length, followed by the payload) and raw header-less output.
#[derive(Clone, Copy, Debug)]
pub(crate) struct NewlineLogOutputDecoder {
    state: NewlineLogOutputDecoderState,
    /// When set, header-less data is emitted as soon as it is buffered instead of
    /// being split on newlines.
    is_tcp: bool,
}

impl NewlineLogOutputDecoder {
    /// Creates a decoder that starts out waiting for a frame header.
    pub(crate) fn new(is_tcp: bool) -> Self {
        Self {
            is_tcp,
            state: NewlineLogOutputDecoderState::WaitingHeader,
        }
    }
}
45
46impl Decoder for NewlineLogOutputDecoder {
47    type Item = LogOutput;
48    type Error = io::Error;
49
50    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
51        loop {
52            match self.state {
53                NewlineLogOutputDecoderState::WaitingHeader => {
54                    // `start_exec` API on unix socket will emit values without a header
55                    if !src.is_empty() && src[0] > 2 {
56                        if self.is_tcp {
57                            return Ok(Some(LogOutput::Console {
58                                message: src.split().freeze(),
59                            }));
60                        }
61                        let nl_index = src.iter().position(|b| *b == b'\n');
62                        if let Some(pos) = nl_index {
63                            return Ok(Some(LogOutput::Console {
64                                message: src.split_to(pos + 1).freeze(),
65                            }));
66                        } else {
67                            return Ok(None);
68                        }
69                    }
70
71                    if src.len() < 8 {
72                        return Ok(None);
73                    }
74
75                    let header = src.split_to(8);
76                    let length =
77                        u32::from_be_bytes([header[4], header[5], header[6], header[7]]) as usize;
78                    self.state = NewlineLogOutputDecoderState::WaitingPayload(header[0], length);
79                }
80                NewlineLogOutputDecoderState::WaitingPayload(typ, length) => {
81                    if src.len() < length {
82                        return Ok(None);
83                    } else {
84                        trace!("NewlineLogOutputDecoder: Reading payload");
85                        let message = src.split_to(length).freeze();
86                        let item = match typ {
87                            0 => LogOutput::StdIn { message },
88                            1 => LogOutput::StdOut { message },
89                            2 => LogOutput::StdErr { message },
90                            _ => unreachable!(),
91                        };
92
93                        self.state = NewlineLogOutputDecoderState::WaitingHeader;
94                        return Ok(Some(item));
95                    }
96                }
97            }
98        }
99    }
100}
101
102pin_project! {
103    #[derive(Debug)]
104    pub(crate) struct JsonLineDecoder<T> {
105        ty: PhantomData<T>,
106    }
107}
108
109impl<T> JsonLineDecoder<T> {
110    #[inline]
111    pub(crate) fn new() -> JsonLineDecoder<T> {
112        JsonLineDecoder { ty: PhantomData }
113    }
114}
115
116fn decode_json_from_slice<T: DeserializeOwned>(slice: &[u8]) -> Result<Option<T>, Error> {
117    debug!(
118        "Decoding JSON line from stream: {}",
119        String::from_utf8_lossy(slice)
120    );
121
122    match serde_json::from_slice(slice) {
123        Ok(json) => Ok(json),
124        Err(ref e) if e.is_data() => Err(JsonDataError {
125            message: e.to_string(),
126            column: e.column(),
127            #[cfg(feature = "json_data_content")]
128            contents: String::from_utf8_lossy(slice).to_string(),
129        }),
130        Err(e) if e.is_eof() => Ok(None),
131        Err(e) => Err(e.into()),
132    }
133}
134
135impl<T> Decoder for JsonLineDecoder<T>
136where
137    T: DeserializeOwned,
138{
139    type Item = T;
140    type Error = Error;
141    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
142        let nl_index = src.iter().position(|b| *b == b'\n');
143
144        if !src.is_empty() {
145            if let Some(pos) = nl_index {
146                let remainder = src.split_off(pos + 1);
147                let slice = &src[..src.len() - 1];
148
149                match decode_json_from_slice(slice) {
150                    Ok(None) => {
151                        // Unescaped newline inside the json structure
152                        src.truncate(src.len() - 1); // Remove the newline
153                        src.unsplit(remainder);
154                        Ok(None)
155                    }
156                    Ok(json) => {
157                        // Newline delimited json
158                        src.unsplit(remainder);
159                        src.advance(pos + 1);
160                        Ok(json)
161                    }
162                    Err(e) => Err(e),
163                }
164            } else {
165                // No newline delimited json.
166                match decode_json_from_slice(src) {
167                    Ok(None) => Ok(None),
168                    Ok(json) => {
169                        src.clear();
170                        Ok(json)
171                    }
172                    Err(e) => Err(e),
173                }
174            }
175        } else {
176            Ok(None)
177        }
178    }
179}
180
181#[derive(Debug)]
182enum ReadState {
183    Ready(Bytes, usize),
184    NotReady,
185}
186
187pin_project! {
188    #[derive(Debug)]
189    pub(crate) struct StreamReader {
190        #[pin]
191        stream: Incoming,
192        state: ReadState,
193    }
194}
195
196impl StreamReader {
197    #[inline]
198    pub(crate) fn new(stream: Incoming) -> StreamReader {
199        StreamReader {
200            stream,
201            state: ReadState::NotReady,
202        }
203    }
204}
205
206impl AsyncRead for StreamReader {
207    fn poll_read(
208        mut self: Pin<&mut Self>,
209        cx: &mut Context<'_>,
210        read_buf: &mut ReadBuf<'_>,
211    ) -> Poll<io::Result<()>> {
212        loop {
213            match self.as_mut().project().state {
214                ReadState::Ready(ref mut chunk, ref mut pos) => {
215                    let chunk_start = *pos;
216                    let buf = read_buf.initialize_unfilled();
217                    let len = cmp::min(buf.len(), chunk.len() - chunk_start);
218                    let chunk_end = chunk_start + len;
219
220                    buf[..len].copy_from_slice(&chunk[chunk_start..chunk_end]);
221                    *pos += len;
222                    read_buf.advance(len);
223
224                    if *pos != chunk.len() {
225                        return Poll::Ready(Ok(()));
226                    }
227                }
228
229                ReadState::NotReady => match self.as_mut().project().stream.poll_frame(cx) {
230                    Poll::Ready(Some(Ok(frame))) if frame.is_data() => {
231                        *self.as_mut().project().state =
232                            ReadState::Ready(frame.into_data().unwrap(), 0);
233
234                        continue;
235                    }
236                    Poll::Ready(Some(Ok(_frame))) => return Poll::Ready(Ok(())),
237                    Poll::Ready(None) => return Poll::Ready(Ok(())),
238                    Poll::Pending => {
239                        return Poll::Pending;
240                    }
241                    Poll::Ready(Some(Err(e))) => {
242                        return Poll::Ready(Err(io::Error::other(e.to_string())));
243                    }
244                },
245            }
246
247            *self.as_mut().project().state = ReadState::NotReady;
248
249            return Poll::Ready(Ok(()));
250        }
251    }
252}
253
254pin_project! {
255    #[derive(Debug)]
256    pub(crate) struct AsyncUpgraded {
257        #[pin]
258        inner: Upgraded,
259    }
260}
261
262impl AsyncUpgraded {
263    pub(crate) fn new(upgraded: Upgraded) -> Self {
264        Self { inner: upgraded }
265    }
266}
267
268impl AsyncRead for AsyncUpgraded {
269    fn poll_read(
270        self: Pin<&mut Self>,
271        cx: &mut Context<'_>,
272        read_buf: &mut ReadBuf<'_>,
273    ) -> Poll<io::Result<()>> {
274        let n = {
275            let mut hbuf = hyper::rt::ReadBuf::new(read_buf.initialize_unfilled());
276            match hyper::rt::Read::poll_read(self.project().inner, cx, hbuf.unfilled()) {
277                Poll::Ready(Ok(())) => hbuf.filled().len(),
278                other => return other,
279            }
280        };
281        read_buf.advance(n);
282
283        Poll::Ready(Ok(()))
284    }
285}
286
287impl AsyncWrite for AsyncUpgraded {
288    fn poll_write(
289        self: Pin<&mut Self>,
290        cx: &mut Context<'_>,
291        buf: &[u8],
292    ) -> Poll<Result<usize, io::Error>> {
293        hyper::rt::Write::poll_write(self.project().inner, cx, buf)
294    }
295
296    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
297        hyper::rt::Write::poll_flush(self.project().inner, cx)
298    }
299
300    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
301        hyper::rt::Write::poll_shutdown(self.project().inner, cx)
302    }
303}
304
305pin_project! {
306    #[derive(Debug)]
307    pub(crate) struct IncomingStream {
308        #[pin]
309        inner: Incoming,
310    }
311}
312
313impl IncomingStream {
314    pub(crate) fn new(incoming: Incoming) -> Self {
315        Self { inner: incoming }
316    }
317}
318
319impl Stream for IncomingStream {
320    type Item = Result<Bytes, Error>;
321
322    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
323        match futures_util::ready!(self.as_mut().project().inner.poll_frame(cx)?) {
324            Some(frame) => match frame.into_data() {
325                Ok(data) => Poll::Ready(Some(Ok(data))),
326                Err(_) => Poll::Ready(None),
327            },
328            None => Poll::Ready(None),
329        }
330    }
331}
332
#[cfg(feature = "websocket")]
pub(crate) mod websocket {
    use bytes::{Bytes, BytesMut};
    use futures_core::Stream;
    use futures_util::stream::{SplitSink, SplitStream};
    use pin_project_lite::pin_project;
    use std::cmp;
    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_tungstenite::tungstenite::Message;
    use tokio_tungstenite::WebSocketStream;

    #[derive(Debug)]
    enum ReaderState {
        /// Ready to read from the current chunk at the given position.
        Ready(Bytes, usize),
        /// Waiting for the next WebSocket message.
        Waiting,
        /// The WebSocket stream has been closed.
        Closed,
    }

    pin_project! {
        /// Wraps a WebSocket read stream to implement [`AsyncRead`].
        ///
        /// Reads binary and text WebSocket messages and provides their payloads
        /// as a contiguous byte stream suitable for use with [`FramedRead`](tokio_util::codec::FramedRead).
        #[derive(Debug)]
        pub struct WebSocketReader<S> {
            #[pin]
            stream: SplitStream<WebSocketStream<S>>,
            state: ReaderState,
        }
    }

    impl<S> WebSocketReader<S> {
        /// Create a new `WebSocketReader` from a WebSocket split stream.
        pub fn new(stream: SplitStream<WebSocketStream<S>>) -> Self {
            Self {
                stream,
                state: ReaderState::Waiting,
            }
        }
    }

    impl<S> AsyncRead for WebSocketReader<S>
    where
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    {
        /// Copies the payload of the current message into `read_buf`, waiting for
        /// the next message once the current one has been fully read.
        ///
        /// Ping/pong and raw frames are skipped. Messages with an empty payload are
        /// skipped too, because completing a read without adding bytes tells the
        /// caller the stream has ended. A close message, or the end of the
        /// underlying stream, switches to the closed state; every later read then
        /// reports end of stream.
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            read_buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            loop {
                match self.as_mut().project().state {
                    ReaderState::Ready(ref chunk, ref mut pos) => {
                        let start = *pos;
                        let len = cmp::min(read_buf.remaining(), chunk.len() - start);
                        read_buf.put_slice(&chunk[start..start + len]);
                        *pos += len;

                        if *pos >= chunk.len() {
                            *self.as_mut().project().state = ReaderState::Waiting;
                        }
                        return Poll::Ready(Ok(()));
                    }
                    ReaderState::Waiting => {
                        let payload = match self.as_mut().project().stream.poll_next(cx) {
                            Poll::Pending => return Poll::Pending,
                            Poll::Ready(Some(Err(e))) => {
                                return Poll::Ready(Err(io::Error::other(e)));
                            }
                            Poll::Ready(Some(Ok(Message::Binary(data)))) => data,
                            Poll::Ready(Some(Ok(Message::Text(text)))) => {
                                Bytes::copy_from_slice(text.as_bytes())
                            }
                            // Ping/Pong frames are handled by tungstenite automatically
                            Poll::Ready(Some(Ok(
                                Message::Ping(_) | Message::Pong(_) | Message::Frame(_),
                            ))) => continue,
                            Poll::Ready(Some(Ok(Message::Close(_)))) | Poll::Ready(None) => {
                                *self.as_mut().project().state = ReaderState::Closed;
                                return Poll::Ready(Ok(()));
                            }
                        };

                        // An empty payload would complete a read with no bytes, which
                        // reads as end of stream; stay in `Waiting` and poll again.
                        if !payload.is_empty() {
                            *self.as_mut().project().state = ReaderState::Ready(payload, 0);
                        }
                    }
                    ReaderState::Closed => {
                        return Poll::Ready(Ok(()));
                    }
                }
            }
        }
    }

    pin_project! {
        /// Wraps a WebSocket write sink to implement [`AsyncWrite`].
        ///
        /// Buffers writes and sends the accumulated data as a single binary
        /// WebSocket message when flushed.
        #[derive(Debug)]
        pub struct WebSocketWriter<S> {
            #[pin]
            sink: SplitSink<WebSocketStream<S>, Message>,
            buffer: BytesMut,
        }
    }

    impl<S> WebSocketWriter<S>
    where
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    {
        /// Create a new `WebSocketWriter` from a WebSocket split sink.
        pub fn new(sink: SplitSink<WebSocketStream<S>, Message>) -> Self {
            Self {
                sink,
                buffer: BytesMut::new(),
            }
        }

        /// Queues everything in `buffer` on `sink` as one binary message.
        ///
        /// Does nothing when `buffer` is empty. Returns `Poll::Pending`, with the
        /// bytes still buffered, while the sink cannot accept a message yet. The
        /// message is only queued; the caller still has to flush or close the sink.
        fn poll_send_buffered(
            mut sink: Pin<&mut SplitSink<WebSocketStream<S>, Message>>,
            buffer: &mut BytesMut,
            cx: &mut Context<'_>,
        ) -> Poll<io::Result<()>> {
            use futures_util::Sink;

            if buffer.is_empty() {
                return Poll::Ready(Ok(()));
            }

            futures_util::ready!(sink.as_mut().poll_ready(cx)).map_err(io::Error::other)?;

            let data = buffer.split().freeze();
            Poll::Ready(
                sink.start_send(Message::Binary(data))
                    .map_err(io::Error::other),
            )
        }
    }

    impl<S> AsyncWrite for WebSocketWriter<S>
    where
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    {
        /// Appends `buf` to the in-memory buffer and reports it as fully written.
        ///
        /// Nothing reaches the WebSocket until the writer is flushed or shut down.
        /// NOTE(review): the buffer is unbounded, so callers are expected to flush
        /// regularly.
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let this = self.project();
            this.buffer.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        /// Sends any buffered bytes as one binary message, then flushes the sink.
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            use futures_util::Sink;

            let mut this = self.project();
            futures_util::ready!(Self::poll_send_buffered(
                this.sink.as_mut(),
                this.buffer,
                cx
            ))?;
            this.sink.poll_flush(cx).map_err(io::Error::other)
        }

        /// Sends any buffered bytes as one binary message, then closes the WebSocket.
        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            use futures_util::Sink;

            let mut this = self.project();

            // Flush any remaining buffered data
            futures_util::ready!(Self::poll_send_buffered(
                this.sink.as_mut(),
                this.buffer,
                cx
            ))?;

            // Close the WebSocket connection
            this.sink.poll_close(cx).map_err(io::Error::other)
        }
    }
}
549
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use bytes::{BufMut, BytesMut};
    use tokio_util::codec::Decoder;

    use crate::container::LogOutput;

    use super::{JsonLineDecoder, NewlineLogOutputDecoder};

    /// A `CreateImageInfo` progress document that is still missing its final `}`.
    const PARTIAL_PROGRESS: &[u8] =
        b"{\"status\":\"Extracting\",\"progressDetail\":{\"current\":33980416,\"total\":102266715}";

    /// The value `PARTIAL_PROGRESS` decodes to once its closing brace has arrived.
    fn extracting_progress() -> crate::models::CreateImageInfo {
        crate::models::CreateImageInfo {
            status: Some(String::from("Extracting")),
            progress_detail: Some(crate::models::ProgressDetail {
                current: Some(33980416),
                total: Some(102266715),
            }),
            ..Default::default()
        }
    }

    /// Decoder for the empty `{}` objects used by the simple line tests.
    fn map_codec() -> JsonLineDecoder<HashMap<(), ()>> {
        JsonLineDecoder::new()
    }

    #[test]
    fn json_decode_empty() {
        let mut src = BytesMut::new();
        let mut codec: JsonLineDecoder<()> = JsonLineDecoder::new();

        assert_eq!(codec.decode(&mut src).unwrap(), None);
    }

    #[test]
    fn json_decode() {
        let mut src = BytesMut::from(&b"{}\n{}\n\n{}\n"[..]);
        let mut codec = map_codec();
        let empty = Some(HashMap::new());

        assert_eq!(codec.decode(&mut src).unwrap(), empty);
        assert_eq!(codec.decode(&mut src).unwrap(), empty);
        // The blank line is consumed on its own call and yields nothing.
        assert_eq!(codec.decode(&mut src).unwrap(), None);
        assert_eq!(codec.decode(&mut src).unwrap(), empty);
        assert_eq!(codec.decode(&mut src).unwrap(), None);
        assert!(src.is_empty());
    }

    #[test]
    fn json_partial_decode() {
        let mut src = BytesMut::from(&b"{}\n{}\n\n{"[..]);
        let mut codec = map_codec();

        assert_eq!(codec.decode(&mut src).unwrap(), Some(HashMap::new()));
        assert_eq!(src, &b"{}\n\n{"[..]);
        assert_eq!(codec.decode(&mut src).unwrap(), Some(HashMap::new()));
        assert_eq!(codec.decode(&mut src).unwrap(), None);
        assert_eq!(codec.decode(&mut src).unwrap(), None);
        assert_eq!(src, &b"{"[..]);

        // Completing the object makes it decodable even without a newline.
        src.put(&b"}"[..]);
        assert_eq!(codec.decode(&mut src).unwrap(), Some(HashMap::new()));
        assert!(src.is_empty());
    }

    #[test]
    fn json_partial_decode_no_newline() {
        let mut src = BytesMut::from(PARTIAL_PROGRESS);
        let mut codec: JsonLineDecoder<crate::models::CreateImageInfo> = JsonLineDecoder::new();

        assert_eq!(codec.decode(&mut src).unwrap(), None);
        assert_eq!(src, PARTIAL_PROGRESS);

        src.put(&b"}"[..]);
        assert_eq!(codec.decode(&mut src).unwrap(), Some(extracting_progress()));
        assert!(src.is_empty());
    }

    #[test]
    fn json_partial_decode_newline() {
        let mut src = BytesMut::from(PARTIAL_PROGRESS);
        src.put_u8(b'\n');
        let mut codec: JsonLineDecoder<crate::models::CreateImageInfo> = JsonLineDecoder::new();

        // The newline ends an unfinished document, so it is dropped and nothing is yielded.
        assert_eq!(codec.decode(&mut src).unwrap(), None);
        assert_eq!(src, PARTIAL_PROGRESS);

        src.put(&b"}"[..]);
        assert_eq!(codec.decode(&mut src).unwrap(), Some(extracting_progress()));
        assert!(src.is_empty());
    }

    #[test]
    fn json_decode_escaped_newline() {
        let mut src = BytesMut::from(&b"\"foo\\nbar\""[..]);
        let mut codec: JsonLineDecoder<String> = JsonLineDecoder::new();

        let decoded = codec.decode(&mut src).unwrap();
        assert_eq!(decoded, Some(String::from("foo\nbar")));
    }

    #[test]
    fn json_decode_lacking_newline() {
        let mut src = BytesMut::from(&b"{}"[..]);
        let mut codec = map_codec();

        assert_eq!(codec.decode(&mut src).unwrap(), Some(HashMap::new()));
        assert!(src.is_empty());
    }

    #[test]
    fn newline_decode_no_header() {
        const PREFIX: &[u8] = b"2023-01-14T23:17:27.496421984-05:00 [lighttpd] 2023/01/14 23";

        // In TCP mode header-less data is emitted as soon as it is buffered.
        let mut src = BytesMut::from(PREFIX);
        let mut tcp_codec = NewlineLogOutputDecoder::new(true);
        assert_eq!(
            tcp_codec.decode(&mut src).unwrap(),
            Some(LogOutput::Console {
                message: bytes::Bytes::from(PREFIX)
            })
        );

        // Otherwise the decoder waits for a complete newline-terminated line.
        let mut src = BytesMut::from(PREFIX);
        let mut socket_codec = NewlineLogOutputDecoder::new(false);
        assert_eq!(socket_codec.decode(&mut src).unwrap(), None);

        src.put(
            &b":17:27 2023-01-14 23:17:26: server.c.1513) server started (lighttpd/1.4.59)\r\n"[..],
        );

        let full_line = &b"2023-01-14T23:17:27.496421984-05:00 [lighttpd] 2023/01/14 23:17:27 2023-01-14 23:17:26: server.c.1513) server started (lighttpd/1.4.59)\r\n"[..];
        assert_eq!(
            socket_codec.decode(&mut src).unwrap(),
            Some(LogOutput::Console {
                message: bytes::Bytes::from(full_line)
            })
        );
    }
}