datafusion_datasource/write/
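//! Module containing helper methods related to orchestrating
//! file serialization, streaming to object store,
//! parallelization, and abort handling
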
use std::sync::Arc;

use super::demux::DemuxedStreamReceiver;
use super::{create_writer, BatchSerializer};
use crate::file_compression_type::FileCompressionType;
use datafusion_common::error::Result;

use arrow::array::RecordBatch;
use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::TaskContext;

use bytes::Bytes;
use futures::join;
use object_store::ObjectStore;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver};
use tokio::task::JoinSet;

type WriterType = Box<dyn AsyncWrite + Send + Unpin>;
type SerializerType = Arc<dyn BatchSerializer>;

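/// Result of calling [`serialize_rb_stream_to_object_store`]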
pub(crate) enum SerializedRecordBatchResult {
    /// All batches were serialized and written successfully
    Success {
        /// The writer, returned so the caller can finalize the file
        writer: WriterType,

        /// The number of rows successfully written
        row_count: usize,
    },
    /// An error occurred during serialization or writing
    Failure {
        /// The writer, returned so the caller can clean up the partially
        /// written file. `None` if the failure was an IO error on the writer
        /// itself, in which case the writer is dropped rather than returned.
        writer: Option<WriterType>,

        /// The error which interrupted the write
        err: DataFusionError,
    },
}

impl SerializedRecordBatchResult {
    /// Create the success variant
    pub fn success(writer: WriterType, row_count: usize) -> Self {
        Self::Success { writer, row_count }
    }

    /// Create the failure variant
    pub fn failure(writer: Option<WriterType>, err: DataFusionError) -> Self {
        Self::Failure { writer, err }
    }
}

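/// Serializes a single data stream in parallel and writes to an ObjectStore
/// concurrently. Data order is preserved.
///
/// In the event of a non-IO error which does not involve the ObjectStore
/// writer, the writer is returned to the caller in addition to the error,
/// so that the partially written file can be cleaned up. In the event of an
/// IO error involving the writer itself, the writer is dropped instead,
/// since calling further methods on it could mask the original failure.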
pub(crate) async fn serialize_rb_stream_to_object_store(
    mut data_rx: Receiver<RecordBatch>,
    serializer: Arc<dyn BatchSerializer>,
    mut writer: WriterType,
) -> SerializedRecordBatchResult {
    // Bounded channel of in-flight serialization tasks; it delivers the
    // serialized batches to the writer in their original order and provides
    // backpressure so serialization cannot run arbitrarily far ahead of the
    // writer below
    let (tx, mut rx) =
        mpsc::channel::<SpawnedTask<Result<(usize, Bytes), DataFusionError>>>(100);
    let serialize_task = SpawnedTask::spawn(async move {
        // Some serializers (e.g. CSV) handle the first batch differently
        // (such as writing a header), so track whether this is the first one
        let mut initial = true;
        while let Some(batch) = data_rx.recv().await {
            let serializer_clone = Arc::clone(&serializer);
            // Spawn each batch's serialization as its own task so that
            // batches are serialized in parallel
            let task = SpawnedTask::spawn(async move {
                let num_rows = batch.num_rows();
                let bytes = serializer_clone.serialize(batch, initial)?;
                Ok((num_rows, bytes))
            });
            if initial {
                initial = false;
            }
            tx.send(task).await.map_err(|_| {
                internal_datafusion_err!("Unknown error writing to object store")
            })?;
        }
        Ok(())
    });

    let mut row_count = 0;
    // Join the serialization tasks in order and stream the resulting bytes
    // to the writer, preserving the original batch order
    while let Some(task) = rx.recv().await {
        match task.join().await {
            Ok(Ok((cnt, bytes))) => {
                match writer.write_all(&bytes).await {
                    Ok(_) => (),
                    Err(e) => {
                        // The writer itself failed, so drop it rather than
                        // returning it to the caller
                        return SerializedRecordBatchResult::failure(
                            None,
                            DataFusionError::Execution(format!(
                                "Error writing to object store: {e}"
                            )),
                        )
                    }
                };
                row_count += cnt;
            }
            Ok(Err(e)) => {
                // Serialization failed; return the writer so the caller can
                // clean up the partially written file
                return SerializedRecordBatchResult::failure(Some(writer), e);
            }
            Err(e) => {
                return SerializedRecordBatchResult::failure(
                    Some(writer),
                    DataFusionError::Execution(format!(
                        "Serialization task panicked or was cancelled: {e}"
                    )),
                );
            }
        }
    }

    // Check whether the feeder task itself failed
    match serialize_task.join().await {
        Ok(Ok(_)) => (),
        Ok(Err(e)) => return SerializedRecordBatchResult::failure(Some(writer), e),
        Err(_) => {
            return SerializedRecordBatchResult::failure(
                Some(writer),
                internal_datafusion_err!("Unknown error writing to object store"),
            )
        }
    }
    SerializedRecordBatchResult::success(writer, row_count)
}

type FileWriteBundle = (Receiver<RecordBatch>, SerializerType, WriterType);
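/// Contains the common logic for serializing RecordBatches and
/// writing the resulting bytes to an ObjectStore.
/// Serialization is assumed to be stateless, i.e.
/// each RecordBatch can be serialized without any
/// dependency on the RecordBatches before or after.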
pub(crate) async fn stateless_serialize_and_write_files(
    mut rx: Receiver<FileWriteBundle>,
    tx: tokio::sync::oneshot::Sender<u64>,
) -> Result<()> {
    let mut row_count = 0;
    // Tracks if any writers encountered an error triggering the need to abort
    let mut any_errors = false;
    // Tracks the specific error triggering abort
    let mut triggering_error = None;
    // Tracks if any errors were encountered in the process of aborting writers.
    // If true, we may not have a guarantee that all written data was cleaned up.
    let mut any_abort_errors = false;
    let mut join_set = JoinSet::new();
    // Spawn one task per incoming file so that files are serialized and
    // written in parallel
    while let Some((data_rx, serializer, writer)) = rx.recv().await {
        join_set.spawn(async move {
            serialize_rb_stream_to_object_store(data_rx, serializer, writer).await
        });
    }
    let mut finished_writers = Vec::new();
    while let Some(result) = join_set.join_next().await {
        match result {
            Ok(res) => match res {
                SerializedRecordBatchResult::Success {
                    writer,
                    row_count: cnt,
                } => {
                    finished_writers.push(writer);
                    row_count += cnt;
                }
                SerializedRecordBatchResult::Failure { writer, err } => {
                    // `writer` is `None` when the writer itself failed;
                    // `extend` keeps it for cleanup only if it is `Some`
                    finished_writers.extend(writer);
                    any_errors = true;
                    triggering_error = Some(err);
                }
            },
            Err(e) => {
                // Don't panic; instead try to clean up as many writers as
                // possible. If we hit this code, ownership of a writer was
                // never joined back to this task, so it cannot be cleaned up
                // (hence any_abort_errors is set to true).
                any_errors = true;
                any_abort_errors = true;
                triggering_error = Some(internal_datafusion_err!(
                    "Unexpected join error while serializing file {e}"
                ));
            }
        }
    }

    // Finalize all writers, flushing any remaining buffered bytes to the
    // ObjectStore
    for mut writer in finished_writers.into_iter() {
        writer.shutdown().await.map_err(|_| {
            internal_datafusion_err!(
                "Error encountered while finalizing writes! Partial results may have been written to ObjectStore!"
            )
        })?;
    }

    if any_errors {
        match any_abort_errors {
            true => {
                return internal_err!(
                    "Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."
                )
            }
            false => match triggering_error {
                Some(e) => return Err(e),
                None => {
                    return internal_err!(
                        "Unknown error encountered during writing to ObjectStore. All writers successfully aborted."
                    )
                }
            },
        }
    }

    // Report the total row count back to the file sink via the oneshot channel
    tx.send(row_count as u64).map_err(|_| {
        internal_datafusion_err!(
            "Error encountered while sending row count back to file sink!"
        )
    })?;
    Ok(())
}

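/// Orchestrates multipart put of a dynamic number of output files from a
/// single input stream for any statelessly serialized file type. That is,
/// any file type for which each [`RecordBatch`] can be serialized
/// independently of all other [`RecordBatch`]s.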
pub async fn spawn_writer_tasks_and_join(
    context: &Arc<TaskContext>,
    serializer: Arc<dyn BatchSerializer>,
    compression: FileCompressionType,
    object_store: Arc<dyn ObjectStore>,
    demux_task: SpawnedTask<Result<()>>,
    mut file_stream_rx: DemuxedStreamReceiver,
) -> Result<u64> {
    let rb_buffer_size = &context
        .session_config()
        .options()
        .execution
        .max_buffered_batches_per_output_file;

    let (tx_file_bundle, rx_file_bundle) = mpsc::channel(rb_buffer_size / 2);
    let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
    // Spawn the write coordinator, which serializes and writes each file as
    // its bundle arrives
    let write_coordinator_task = SpawnedTask::spawn(async move {
        stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await
    });
    // For each output file location produced by the demuxer, open a writer
    // and hand the bundle off to the write coordinator
    while let Some((location, rb_stream)) = file_stream_rx.recv().await {
        let writer =
            create_writer(compression, &location, Arc::clone(&object_store)).await?;

        if tx_file_bundle
            .send((rb_stream, Arc::clone(&serializer), writer))
            .await
            .is_err()
        {
            // The write coordinator has shut down, almost certainly due to an
            // error; stop sending new bundles and let `join_unwind` below
            // surface the underlying error
            break;
        }
    }

    // Signal to the write coordinator that no more files are coming
    drop(tx_file_bundle);

    // Wait for the coordinator and demuxer to finish, propagating any error
    // (or panic) from either task
    let (r1, r2) = join!(
        write_coordinator_task.join_unwind(),
        demux_task.join_unwind()
    );
    r1.map_err(DataFusionError::ExecutionJoin)??;
    r2.map_err(DataFusionError::ExecutionJoin)??;

    // Return the total row count
    rx_row_cnt.await.map_err(|_| {
        internal_datafusion_err!("Did not receive row count from write coordinator")
    })
}