1use std::mem;
2
3use crate::notifications::{FileMessage, Message};
4use crate::transformer::{FileContext, ReadWriter, Sink, Transformer, TransformerType};
5use crate::transformers::writer_sink::WriterSink;
6use anyhow::{bail, Result};
7use async_channel::{Receiver, Sender};
8use bytes::BytesMut;
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader, BufWriter};
10use tracing::{debug, error};
11
/// Streaming pipeline: reads bytes from `reader`, passes them through every entry of
/// `transformers` in insertion order and finally through `sink` (driven by `process`).
pub struct ArunaReadWriter<'a, R: AsyncRead + Unpin> {
    /// Buffered input source.
    reader: BufReader<R>,
    /// Intermediate stages in insertion order, each paired with the type reported by its
    /// `get_type()`; that type is compared against `Message::target` when routing messages.
    transformers: Vec<(TransformerType, Box<dyn Transformer + Send + Sync + 'a>)>,
    /// Final stage; receives the bytes after all transformers have run.
    sink: Box<dyn Transformer + Send + Sync + 'a>,
    /// Receiving end of the internal message channel, polled while processing.
    receiver: Receiver<Message>,
    /// Sending end of the internal message channel; a clone is handed to each added transformer.
    sender: Sender<Message>,
    /// Running byte count for the current file; bytes beyond the file's `input_size`
    /// are carried over and counted towards the next file.
    size_counter: usize,
    /// Context of the file currently being processed, plus its "is last file" flag.
    current_file_context: Option<(FileContext, bool)>,
    /// Optional channel supplying `(FileContext, is_last)` pairs for multi-file input.
    file_ctx_rx: Option<Receiver<(FileContext, bool)>>,
}
22
23impl<'a, R: AsyncRead + Unpin> ArunaReadWriter<'a, R> {
24 #[tracing::instrument(level = "trace", skip(reader, writer))]
25 pub fn new_with_writer<W: AsyncWrite + Unpin + Send + Sync + 'a>(
26 reader: R,
27 writer: W,
28 ) -> ArunaReadWriter<'a, R> {
29 let (sx, rx) = async_channel::unbounded();
30 ArunaReadWriter {
31 reader: BufReader::new(reader),
32 sink: Box::new(WriterSink::new(BufWriter::new(writer))),
33 transformers: Vec::new(),
34 sender: sx,
35 receiver: rx,
36 size_counter: 0,
37 current_file_context: None,
38 file_ctx_rx: None,
39 }
40 }
41
42 #[tracing::instrument(level = "trace", skip(reader, transformer))]
43 pub fn new_with_sink<T: Transformer + Sink + Send + Sync + 'a>(
44 reader: R,
45 transformer: T,
46 ) -> ArunaReadWriter<'a, R> {
47 let (sx, rx) = async_channel::unbounded();
48
49 ArunaReadWriter {
50 reader: BufReader::new(reader),
51 sink: Box::new(transformer),
52 transformers: Vec::new(),
53 sender: sx,
54 receiver: rx,
55 size_counter: 0,
56 current_file_context: None,
57 file_ctx_rx: None,
58 }
59 }
60
61 #[tracing::instrument(level = "trace", skip(self, transformer))]
62 pub fn add_transformer<T: Transformer + Send + Sync + 'a>(
63 mut self,
64 mut transformer: T,
65 ) -> Self {
66 transformer.add_sender(self.sender.clone());
67
68 self.transformers
69 .push((transformer.get_type(), Box::new(transformer)));
70 self
71 }
72}
73
74#[async_trait::async_trait]
75impl<'a, R: AsyncRead + Unpin + Send + Sync> ReadWriter for ArunaReadWriter<'a, R> {
76 #[tracing::instrument(err, level = "trace", skip(self))]
77 async fn process(&mut self) -> Result<()> {
78 let mut read_buf = BytesMut::with_capacity(65_536 * 2);
80 let mut hold_buffer = BytesMut::with_capacity(65536);
81 let mut finished;
82 let mut maybe_msg: Option<Message> = None;
83 let mut read_bytes: usize = 0;
84 let mut next_file: bool = false;
85
86 if let Some(rx) = &self.file_ctx_rx {
87 let (context, is_last) = rx.try_recv()?;
88 debug!(?context, ?is_last, "received file context");
89 self.current_file_context = Some((context.clone(), is_last));
90 self.announce_all(Message {
91 target: TransformerType::All,
92 data: crate::notifications::MessageData::NextFile(FileMessage { context, is_last }),
93 })
94 .await?;
95 }
96
97 loop {
98 if hold_buffer.is_empty() {
99 read_bytes = self.reader.read_buf(&mut read_buf).await?;
100 } else if read_buf.is_empty() {
101 mem::swap(&mut hold_buffer, &mut read_buf);
102 }
103
104 if let Some((context, is_last)) = &self.current_file_context {
105 self.size_counter += read_bytes;
106 if self.size_counter > context.input_size as usize {
107 let mut diff = read_bytes - (self.size_counter - context.input_size as usize);
108 if diff >= context.input_size as usize {
109 diff = context.input_size as usize
110 }
111 hold_buffer = read_buf.split_to(diff);
112 mem::swap(&mut read_buf, &mut hold_buffer);
113 self.size_counter -= context.input_size as usize;
114 next_file = !is_last;
115 }
116 finished = read_buf.is_empty() && read_bytes == 0 && *is_last;
117 } else {
118 finished = read_buf.is_empty() && read_bytes == 0;
119 }
120 for (ttype, trans) in self.transformers.iter_mut() {
121 if let Some(m) = &maybe_msg {
122 if m.target == *ttype {
123 trans.notify(m).await?;
124 }
125 } else {
126 maybe_msg = self.receiver.try_recv().ok();
127 }
128
129 match trans.process_bytes(&mut read_buf, finished, false).await? {
130 true => {}
131 false => finished = false,
132 };
133 }
134
135 match self
136 .sink
137 .process_bytes(&mut read_buf, finished, false)
138 .await?
139 {
140 true => {}
141 false => finished = false,
142 };
143
144 if next_file {
146 if let Some(rx) = &self.file_ctx_rx {
147 for (_, trans) in self.transformers.iter_mut() {
149 trans.process_bytes(&mut read_buf, finished, true).await?;
150 }
151 self.sink
152 .process_bytes(&mut read_buf, finished, true)
153 .await?;
154 let (context, is_last) = rx.recv().await?;
155 self.current_file_context = Some((context.clone(), is_last));
156 self.announce_all(Message {
157 target: TransformerType::All,
158 data: crate::notifications::MessageData::NextFile(FileMessage {
159 context,
160 is_last,
161 }),
162 })
163 .await?;
164 next_file = false;
165 }
166 }
167
168 if read_buf.is_empty() && finished {
169 break;
170 }
171 read_bytes = 0;
172 }
173 Ok(())
174 }
175
176 #[tracing::instrument(level = "trace", skip(self, message))]
177 async fn announce_all(&mut self, mut message: Message) -> Result<()> {
178 message.target = TransformerType::All;
179 for (_, trans) in self.transformers.iter_mut() {
180 trans.notify(&message).await?;
181 }
182 Ok(())
183 }
184
185 #[tracing::instrument(level = "trace", skip(self, rx))]
186 async fn add_file_context_receiver(&mut self, rx: Receiver<(FileContext, bool)>) -> Result<()> {
187 if self.file_ctx_rx.is_none() {
188 self.file_ctx_rx = Some(rx);
189 Ok(())
190 } else {
191 error!("Overwriting existing receivers is not allowed!");
192 bail!("[READ_WRITER] Overwriting existing receivers is not allowed!")
193 }
194 }
195}