aruna_file/readwrite.rs

1use std::mem;
2
3use crate::notifications::{FileMessage, Message};
4use crate::transformer::{FileContext, ReadWriter, Sink, Transformer, TransformerType};
5use crate::transformers::writer_sink::WriterSink;
6use anyhow::{bail, Result};
7use async_channel::{Receiver, Sender};
8use bytes::BytesMut;
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader, BufWriter};
10use tracing::{debug, error};
11
12pub struct ArunaReadWriter<'a, R: AsyncRead + Unpin> {
13    reader: BufReader<R>,
14    transformers: Vec<(TransformerType, Box<dyn Transformer + Send + Sync + 'a>)>,
15    sink: Box<dyn Transformer + Send + Sync + 'a>,
16    receiver: Receiver<Message>,
17    sender: Sender<Message>,
18    size_counter: usize,
19    current_file_context: Option<(FileContext, bool)>,
20    file_ctx_rx: Option<Receiver<(FileContext, bool)>>,
21}
22
23impl<'a, R: AsyncRead + Unpin> ArunaReadWriter<'a, R> {
24    #[tracing::instrument(level = "trace", skip(reader, writer))]
25    pub fn new_with_writer<W: AsyncWrite + Unpin + Send + Sync + 'a>(
26        reader: R,
27        writer: W,
28    ) -> ArunaReadWriter<'a, R> {
29        let (sx, rx) = async_channel::unbounded();
30        ArunaReadWriter {
31            reader: BufReader::new(reader),
32            sink: Box::new(WriterSink::new(BufWriter::new(writer))),
33            transformers: Vec::new(),
34            sender: sx,
35            receiver: rx,
36            size_counter: 0,
37            current_file_context: None,
38            file_ctx_rx: None,
39        }
40    }
41
42    #[tracing::instrument(level = "trace", skip(reader, transformer))]
43    pub fn new_with_sink<T: Transformer + Sink + Send + Sync + 'a>(
44        reader: R,
45        transformer: T,
46    ) -> ArunaReadWriter<'a, R> {
47        let (sx, rx) = async_channel::unbounded();
48
49        ArunaReadWriter {
50            reader: BufReader::new(reader),
51            sink: Box::new(transformer),
52            transformers: Vec::new(),
53            sender: sx,
54            receiver: rx,
55            size_counter: 0,
56            current_file_context: None,
57            file_ctx_rx: None,
58        }
59    }
60
61    #[tracing::instrument(level = "trace", skip(self, transformer))]
62    pub fn add_transformer<T: Transformer + Send + Sync + 'a>(
63        mut self,
64        mut transformer: T,
65    ) -> Self {
66        transformer.add_sender(self.sender.clone());
67
68        self.transformers
69            .push((transformer.get_type(), Box::new(transformer)));
70        self
71    }
72}
73
74#[async_trait::async_trait]
75impl<'a, R: AsyncRead + Unpin + Send + Sync> ReadWriter for ArunaReadWriter<'a, R> {
76    #[tracing::instrument(err, level = "trace", skip(self))]
77    async fn process(&mut self) -> Result<()> {
78        // The buffer that accumulates the "actual" data
79        let mut read_buf = BytesMut::with_capacity(65_536 * 2);
80        let mut hold_buffer = BytesMut::with_capacity(65536);
81        let mut finished;
82        let mut maybe_msg: Option<Message> = None;
83        let mut read_bytes: usize = 0;
84        let mut next_file: bool = false;
85
86        if let Some(rx) = &self.file_ctx_rx {
87            let (context, is_last) = rx.try_recv()?;
88            debug!(?context, ?is_last, "received file context");
89            self.current_file_context = Some((context.clone(), is_last));
90            self.announce_all(Message {
91                target: TransformerType::All,
92                data: crate::notifications::MessageData::NextFile(FileMessage { context, is_last }),
93            })
94            .await?;
95        }
96
97        loop {
98            if hold_buffer.is_empty() {
99                read_bytes = self.reader.read_buf(&mut read_buf).await?;
100            } else if read_buf.is_empty() {
101                mem::swap(&mut hold_buffer, &mut read_buf);
102            }
103
104            if let Some((context, is_last)) = &self.current_file_context {
105                self.size_counter += read_bytes;
106                if self.size_counter > context.input_size as usize {
107                    let mut diff = read_bytes - (self.size_counter - context.input_size as usize);
108                    if diff >= context.input_size as usize {
109                        diff = context.input_size as usize
110                    }
111                    hold_buffer = read_buf.split_to(diff);
112                    mem::swap(&mut read_buf, &mut hold_buffer);
113                    self.size_counter -= context.input_size as usize;
114                    next_file = !is_last;
115                }
116                finished = read_buf.is_empty() && read_bytes == 0 && *is_last;
117            } else {
118                finished = read_buf.is_empty() && read_bytes == 0;
119            }
120            for (ttype, trans) in self.transformers.iter_mut() {
121                if let Some(m) = &maybe_msg {
122                    if m.target == *ttype {
123                        trans.notify(m).await?;
124                    }
125                } else {
126                    maybe_msg = self.receiver.try_recv().ok();
127                }
128
129                match trans.process_bytes(&mut read_buf, finished, false).await? {
130                    true => {}
131                    false => finished = false,
132                };
133            }
134
135            match self
136                .sink
137                .process_bytes(&mut read_buf, finished, false)
138                .await?
139            {
140                true => {}
141                false => finished = false,
142            };
143
144            // Anounce next file
145            if next_file {
146                if let Some(rx) = &self.file_ctx_rx {
147                    // Perform a flush through all transformers!
148                    for (_, trans) in self.transformers.iter_mut() {
149                        trans.process_bytes(&mut read_buf, finished, true).await?;
150                    }
151                    self.sink
152                        .process_bytes(&mut read_buf, finished, true)
153                        .await?;
154                    let (context, is_last) = rx.recv().await?;
155                    self.current_file_context = Some((context.clone(), is_last));
156                    self.announce_all(Message {
157                        target: TransformerType::All,
158                        data: crate::notifications::MessageData::NextFile(FileMessage {
159                            context,
160                            is_last,
161                        }),
162                    })
163                    .await?;
164                    next_file = false;
165                }
166            }
167
168            if read_buf.is_empty() && finished {
169                break;
170            }
171            read_bytes = 0;
172        }
173        Ok(())
174    }
175
176    #[tracing::instrument(level = "trace", skip(self, message))]
177    async fn announce_all(&mut self, mut message: Message) -> Result<()> {
178        message.target = TransformerType::All;
179        for (_, trans) in self.transformers.iter_mut() {
180            trans.notify(&message).await?;
181        }
182        Ok(())
183    }
184
185    #[tracing::instrument(level = "trace", skip(self, rx))]
186    async fn add_file_context_receiver(&mut self, rx: Receiver<(FileContext, bool)>) -> Result<()> {
187        if self.file_ctx_rx.is_none() {
188            self.file_ctx_rx = Some(rx);
189            Ok(())
190        } else {
191            error!("Overwriting existing receivers is not allowed!");
192            bail!("[READ_WRITER] Overwriting existing receivers is not allowed!")
193        }
194    }
195}