datafusion_datasource/write/orchestration.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Module containing helper methods/traits related to
//! orchestrating file serialization, streaming to object store,
//! parallelization, and abort handling

use std::sync::Arc;

use super::demux::DemuxedStreamReceiver;
use super::{create_writer, BatchSerializer};
use crate::file_compression_type::FileCompressionType;
use datafusion_common::error::Result;

use arrow::array::RecordBatch;
use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::TaskContext;

use bytes::Bytes;
use futures::join;
use object_store::ObjectStore;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver};
use tokio::task::JoinSet;

type WriterType = Box<dyn AsyncWrite + Send + Unpin>;
type SerializerType = Arc<dyn BatchSerializer>;

/// Result of calling [`serialize_rb_stream_to_object_store`]
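///
/// An illustrative sketch of how a caller might consume the two variants
/// (see [`stateless_serialize_and_write_files`] for the actual call site):
///
/// ```ignore
/// match serialize_rb_stream_to_object_store(data_rx, serializer, writer).await {
///     SerializedRecordBatchResult::Success { writer, row_count } => {
///         // serialization finished; finalize `writer` and record `row_count`
///     }
///     SerializedRecordBatchResult::Failure { writer: Some(writer), err } => {
///         // the writer is still usable and can be shut down or aborted
///     }
///     SerializedRecordBatchResult::Failure { writer: None, err } => {
///         // an IO error consumed the writer; only the error can be reported
///     }
/// }
/// ```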
pub(crate) enum SerializedRecordBatchResult {
    Success {
        /// the writer
        writer: WriterType,

        /// the number of rows successfully written
        row_count: usize,
    },
    Failure {
        /// As explained in [`serialize_rb_stream_to_object_store`]:
        /// - If an IO error occurred that involved the ObjectStore writer, then the writer will not be returned to the caller
        /// - Otherwise, the writer is returned to the caller
        writer: Option<WriterType>,

        /// the actual error that occurred
        err: DataFusionError,
    },
}

impl SerializedRecordBatchResult {
    /// Create the success variant
    pub fn success(writer: WriterType, row_count: usize) -> Self {
        Self::Success { writer, row_count }
    }

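    /// Create the failure variant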
    pub fn failure(writer: Option<WriterType>, err: DataFusionError) -> Self {
        Self::Failure { writer, err }
    }
}

/// Serializes a single data stream in parallel and writes to an ObjectStore concurrently.
/// Data order is preserved.
///
/// In the event of a non-IO error which does not involve the ObjectStore writer,
/// the writer is returned to the caller in addition to the error,
/// so that failed writes may be aborted.
///
/// In the event of an IO error involving the ObjectStore writer,
/// the writer is dropped to avoid calling further methods on it which might panic.
pub(crate) async fn serialize_rb_stream_to_object_store(
    mut data_rx: Receiver<RecordBatch>,
    serializer: Arc<dyn BatchSerializer>,
    mut writer: WriterType,
) -> SerializedRecordBatchResult {
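    // Serialization and IO are overlapped: a dedicated task serializes incoming
    // batches (spawning one sub-task per batch) and forwards the sub-task handles
    // through a bounded channel, while the loop below drains those handles in
    // FIFO order and writes the resulting bytes to `writer`, preserving the
    // original batch order.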
    let (tx, mut rx) =
        mpsc::channel::<SpawnedTask<Result<(usize, Bytes), DataFusionError>>>(100);
    let serialize_task = SpawnedTask::spawn(async move {
        // Some serializers (like CSV) handle the first batch differently than
        // subsequent batches, so we track that here.
        let mut initial = true;
        while let Some(batch) = data_rx.recv().await {
            let serializer_clone = Arc::clone(&serializer);
            let task = SpawnedTask::spawn(async move {
                let num_rows = batch.num_rows();
                let bytes = serializer_clone.serialize(batch, initial)?;
                Ok((num_rows, bytes))
            });
            if initial {
                initial = false;
            }
            tx.send(task).await.map_err(|_| {
                internal_datafusion_err!("Unknown error writing to object store")
            })?;
        }
        Ok(())
    });

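    // Drain the per-batch serialization tasks in the order they were spawned and
    // stream their output to the object store writer.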
    let mut row_count = 0;
    while let Some(task) = rx.recv().await {
        match task.join().await {
            Ok(Ok((cnt, bytes))) => {
                match writer.write_all(&bytes).await {
                    Ok(_) => (),
                    Err(e) => {
                        return SerializedRecordBatchResult::failure(
                            None,
                            DataFusionError::Execution(format!(
                                "Error writing to object store: {e}"
                            )),
                        )
                    }
                };
                row_count += cnt;
            }
            Ok(Err(e)) => {
                // Return the writer along with the error
                return SerializedRecordBatchResult::failure(Some(writer), e);
            }
            Err(e) => {
                // Handle task panic or cancellation
                return SerializedRecordBatchResult::failure(
                    Some(writer),
                    DataFusionError::Execution(format!(
                        "Serialization task panicked or was cancelled: {e}"
                    )),
                );
            }
        }
    }

    match serialize_task.join().await {
        Ok(Ok(_)) => (),
        Ok(Err(e)) => return SerializedRecordBatchResult::failure(Some(writer), e),
        Err(_) => {
            return SerializedRecordBatchResult::failure(
                Some(writer),
                internal_datafusion_err!("Unknown error writing to object store"),
            )
        }
    }
    SerializedRecordBatchResult::success(writer, row_count)
}

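/// The unit of work consumed by [`stateless_serialize_and_write_files`]: the
/// batch receiver, serializer, and writer for a single output file.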
type FileWriteBundle = (Receiver<RecordBatch>, SerializerType, WriterType);

/// Contains the common logic for serializing RecordBatches and
/// writing the resulting bytes to an ObjectStore.
/// Serialization is assumed to be stateless, i.e.
/// each RecordBatch can be serialized without any
/// dependency on the RecordBatches before or after.
pub(crate) async fn stateless_serialize_and_write_files(
    mut rx: Receiver<FileWriteBundle>,
    tx: tokio::sync::oneshot::Sender<u64>,
) -> Result<()> {
    let mut row_count = 0;
    // tracks if any writers encountered an error triggering the need to abort
    let mut any_errors = false;
    // tracks the specific error triggering abort
    let mut triggering_error = None;
    // tracks if any errors were encountered in the process of aborting writers.
    // if true, we may not have a guarantee that all written data was cleaned up.
    let mut any_abort_errors = false;
    let mut join_set = JoinSet::new();
    while let Some((data_rx, serializer, writer)) = rx.recv().await {
        join_set.spawn(async move {
            serialize_rb_stream_to_object_store(data_rx, serializer, writer).await
        });
    }
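    // All file bundles have been received and spawned; collect each task's
    // result, keeping any returned writers so they can be finalized below.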
    let mut finished_writers = Vec::new();
    while let Some(result) = join_set.join_next().await {
        match result {
            Ok(res) => match res {
                SerializedRecordBatchResult::Success {
                    writer,
                    row_count: cnt,
                } => {
                    finished_writers.push(writer);
                    row_count += cnt;
                }
                SerializedRecordBatchResult::Failure { writer, err } => {
                    finished_writers.extend(writer);
                    any_errors = true;
                    triggering_error = Some(err);
                }
            },
            Err(e) => {
                // Don't panic, instead try to clean up as many writers as possible.
                // If we hit this code, ownership of a writer was not joined back to
                // this thread, so we cannot clean it up (hence any_abort_errors is true)
                any_errors = true;
                any_abort_errors = true;
                triggering_error = Some(internal_datafusion_err!(
                    "Unexpected join error while serializing file {e}"
                ));
            }
        }
    }

    // Finalize or abort writers as appropriate
    for mut writer in finished_writers.into_iter() {
        writer.shutdown().await.map_err(|_| {
            internal_datafusion_err!(
                "Error encountered while finalizing writes! Partial results may have been written to ObjectStore!"
            )
        })?;
    }

    if any_errors {
        match any_abort_errors {
            true => return internal_err!("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."),
            false => match triggering_error {
                Some(e) => return Err(e),
                None => return internal_err!("Unknown Error encountered during writing to ObjectStore. All writers successfully aborted.")
            }
        }
    }

    tx.send(row_count as u64).map_err(|_| {
        internal_datafusion_err!(
            "Error encountered while sending row count back to file sink!"
        )
    })?;
    Ok(())
}

/// Orchestrates multipart put of a dynamic number of output files from a single input stream
/// for any statelessly serialized file type. That is, any file type for which each [RecordBatch]
/// can be serialized independently of all other [RecordBatch]s.
pub async fn spawn_writer_tasks_and_join(
    context: &Arc<TaskContext>,
    serializer: Arc<dyn BatchSerializer>,
    compression: FileCompressionType,
    object_store: Arc<dyn ObjectStore>,
    demux_task: SpawnedTask<Result<()>>,
    mut file_stream_rx: DemuxedStreamReceiver,
) -> Result<u64> {
    let rb_buffer_size = &context
        .session_config()
        .options()
        .execution
        .max_buffered_batches_per_output_file;

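    // The bounded bundle channel applies backpressure between file creation in
    // this task and the write coordinator below; the oneshot channel reports the
    // total row count back once all writers have finished.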
    let (tx_file_bundle, rx_file_bundle) = mpsc::channel(rb_buffer_size / 2);
    let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
    let write_coordinator_task = SpawnedTask::spawn(async move {
        stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await
    });
    while let Some((location, rb_stream)) = file_stream_rx.recv().await {
        let writer =
            create_writer(compression, &location, Arc::clone(&object_store)).await?;

        if tx_file_bundle
            .send((rb_stream, Arc::clone(&serializer), writer))
            .await
            .is_err()
        {
            internal_datafusion_err!(
                "Writer receive file bundle channel closed unexpectedly!"
            );
        }
    }

    // Signal to the write coordinator that no more files are coming
    drop(tx_file_bundle);

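    // Wait for both the write coordinator and the demux task to finish; an error
    // from either one is propagated to the caller.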
    let (r1, r2) = join!(
        write_coordinator_task.join_unwind(),
        demux_task.join_unwind()
    );
    r1.map_err(DataFusionError::ExecutionJoin)??;
    r2.map_err(DataFusionError::ExecutionJoin)??;

    // Return total row count:
    rx_row_cnt.await.map_err(|_| {
        internal_datafusion_err!("Did not receive row count from write coordinator")
    })
}