datafusion_datasource/write/orchestration.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Module containing helper methods/traits related to
//! orchestrating file serialization, streaming to object store,
//! parallelization, and abort handling
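//!
//! At a high level, a demux task splits the input record batch stream into
//! one stream per output file, [`spawn_writer_tasks_and_join`] opens an
//! object store writer for each file, and
//! `stateless_serialize_and_write_files` serializes and uploads each
//! stream in parallel while preserving batch order.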

use std::sync::Arc;

use super::demux::DemuxedStreamReceiver;
use super::{BatchSerializer, ObjectWriterBuilder};
use crate::file_compression_type::FileCompressionType;
use datafusion_common::error::Result;

use arrow::array::RecordBatch;
use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError};
use datafusion_common_runtime::{JoinSet, SpawnedTask};
use datafusion_execution::TaskContext;

use bytes::Bytes;
use futures::join;
use object_store::ObjectStore;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver};

/// Dynamically dispatched async writer for a single output object
type WriterType = Box<dyn AsyncWrite + Send + Unpin>;
/// Shared handle to the [`BatchSerializer`] used for each output file
type SerializerType = Arc<dyn BatchSerializer>;

/// Result of calling [`serialize_rb_stream_to_object_store`]
pub(crate) enum SerializedRecordBatchResult {
    Success {
        /// the writer
        writer: WriterType,

        /// the number of rows successfully written
        row_count: usize,
    },
    Failure {
        /// As explained in [`serialize_rb_stream_to_object_store`]:
        /// - If an IO error occurred that involved the ObjectStore writer, then the writer will not be returned to the caller
        /// - Otherwise, the writer is returned to the caller
        writer: Option<WriterType>,

        /// the actual error that occurred
        err: DataFusionError,
    },
}

impl SerializedRecordBatchResult {
    /// Create the success variant
    pub fn success(writer: WriterType, row_count: usize) -> Self {
        Self::Success { writer, row_count }
    }

    /// Create the failure variant
    pub fn failure(writer: Option<WriterType>, err: DataFusionError) -> Self {
        Self::Failure { writer, err }
    }
}

/// Serializes a single data stream in parallel and writes to an ObjectStore
/// concurrently. Data order is preserved.
///
/// In the event of a non-IO error which does not involve the ObjectStore writer,
/// the writer is returned to the caller in addition to the error,
/// so that failed writes may be aborted.
///
/// In the event of an IO error involving the ObjectStore writer,
/// the writer is dropped to avoid calling further methods on it which might panic.
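///
/// A sketch of how a caller might handle each outcome; the `data_rx`,
/// `serializer`, and `writer` bindings are hypothetical placeholders:
///
/// ```ignore
/// match serialize_rb_stream_to_object_store(data_rx, serializer, writer).await {
///     SerializedRecordBatchResult::Success { mut writer, row_count } => {
///         // Everything was serialized and written; finalize the upload
///         writer.shutdown().await?;
///         println!("wrote {row_count} rows");
///     }
///     SerializedRecordBatchResult::Failure { writer: Some(writer), err } => {
///         // The writer is intact; the caller may clean up the partial
///         // upload before propagating the error
///         drop(writer);
///         return Err(err);
///     }
///     SerializedRecordBatchResult::Failure { writer: None, err } => {
///         // An IO error made the writer unusable; it was already dropped
///         return Err(err);
///     }
/// }
/// ```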
pub(crate) async fn serialize_rb_stream_to_object_store(
    mut data_rx: Receiver<RecordBatch>,
    serializer: Arc<dyn BatchSerializer>,
    mut writer: WriterType,
) -> SerializedRecordBatchResult {
    // Bounded channel of in-flight serialization tasks. Tasks are received
    // in the order they were spawned, so output order is preserved even
    // though batches are serialized in parallel.
    let (tx, mut rx) =
        mpsc::channel::<SpawnedTask<Result<(usize, Bytes), DataFusionError>>>(100);
    let serialize_task = SpawnedTask::spawn(async move {
        // Some serializers (like CSV) handle the first batch differently than
        // subsequent batches, so we track that here.
        let mut initial = true;
        while let Some(batch) = data_rx.recv().await {
            let serializer_clone = Arc::clone(&serializer);
            let task = SpawnedTask::spawn(async move {
                let num_rows = batch.num_rows();
                let bytes = serializer_clone.serialize(batch, initial)?;
                Ok((num_rows, bytes))
            });
            if initial {
                initial = false;
            }
            tx.send(task).await.map_err(|_| {
                internal_datafusion_err!("Unknown error writing to object store")
            })?;
        }
        Ok(())
    });

    let mut row_count = 0;
    while let Some(task) = rx.recv().await {
        match task.join().await {
            Ok(Ok((cnt, bytes))) => {
                match writer.write_all(&bytes).await {
                    Ok(_) => (),
                    Err(e) => {
                        // The IO error may have left the writer in an
                        // invalid state, so drop it rather than return it
                        return SerializedRecordBatchResult::failure(
                            None,
                            DataFusionError::Execution(format!(
                                "Error writing to object store: {e}"
                            )),
                        )
                    }
                };
                row_count += cnt;
            }
            Ok(Err(e)) => {
                // Return the writer along with the error
                return SerializedRecordBatchResult::failure(Some(writer), e);
            }
            Err(e) => {
                // Handle task panic or cancellation
                return SerializedRecordBatchResult::failure(
                    Some(writer),
                    DataFusionError::Execution(format!(
                        "Serialization task panicked or was cancelled: {e}"
                    )),
                );
            }
        }
    }

    match serialize_task.join().await {
        Ok(Ok(_)) => (),
        Ok(Err(e)) => return SerializedRecordBatchResult::failure(Some(writer), e),
        Err(_) => {
            return SerializedRecordBatchResult::failure(
                Some(writer),
                internal_datafusion_err!("Unknown error writing to object store"),
            )
        }
    }
    SerializedRecordBatchResult::success(writer, row_count)
}

type FileWriteBundle = (Receiver<RecordBatch>, SerializerType, WriterType);

/// Contains the common logic for serializing RecordBatches and
/// writing the resulting bytes to an ObjectStore.
///
/// Serialization is assumed to be stateless, i.e. each RecordBatch
/// can be serialized without any dependency on the RecordBatches
/// before or after.
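///
/// A minimal sketch of how this coordinator is driven; the channel names
/// here are hypothetical (see [`spawn_writer_tasks_and_join`] for the real
/// caller):
///
/// ```ignore
/// let (tx_bundle, rx_bundle) = mpsc::channel::<FileWriteBundle>(8);
/// let (tx_rows, rx_rows) = tokio::sync::oneshot::channel();
/// let coordinator = SpawnedTask::spawn(async move {
///     stateless_serialize_and_write_files(rx_bundle, tx_rows).await
/// });
/// // ... send one (stream, serializer, writer) bundle per output file ...
/// drop(tx_bundle); // closing the channel lets the coordinator finish
/// coordinator
///     .join_unwind()
///     .await
///     .map_err(DataFusionError::ExecutionJoin)??;
/// let total_rows = rx_rows.await?;
/// ```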
pub(crate) async fn stateless_serialize_and_write_files(
    mut rx: Receiver<FileWriteBundle>,
    tx: tokio::sync::oneshot::Sender<u64>,
) -> Result<()> {
    let mut row_count = 0;
    // tracks if any writers encountered an error triggering the need to abort
    let mut any_errors = false;
    // tracks the specific error triggering abort
    let mut triggering_error = None;
    // tracks if any errors were encountered in the process of aborting writers.
    // if true, we may not have a guarantee that all written data was cleaned up.
    let mut any_abort_errors = false;
    let mut join_set = JoinSet::new();
    while let Some((data_rx, serializer, writer)) = rx.recv().await {
        join_set.spawn(async move {
            serialize_rb_stream_to_object_store(data_rx, serializer, writer).await
        });
    }
    let mut finished_writers = Vec::new();
    while let Some(result) = join_set.join_next().await {
        match result {
            Ok(res) => match res {
                SerializedRecordBatchResult::Success {
                    writer,
                    row_count: cnt,
                } => {
                    finished_writers.push(writer);
                    row_count += cnt;
                }
                SerializedRecordBatchResult::Failure { writer, err } => {
                    finished_writers.extend(writer);
                    any_errors = true;
                    triggering_error = Some(err);
                }
            },
            Err(e) => {
                // Don't panic, instead try to clean up as many writers as possible.
                // If we hit this code, ownership of a writer was not joined back to
                // this thread, so we cannot clean it up (hence any_abort_errors is true)
                any_errors = true;
                any_abort_errors = true;
                triggering_error = Some(internal_datafusion_err!(
                    "Unexpected join error while serializing file {e}"
                ));
            }
        }
    }

    // Finalize or abort writers as appropriate
    for mut writer in finished_writers.into_iter() {
        writer.shutdown().await.map_err(|_| {
            internal_datafusion_err!(
                "Error encountered while finalizing writes! Partial results may have been written to ObjectStore!"
            )
        })?;
    }

    if any_errors {
        match any_abort_errors {
            true => {
                return internal_err!(
                    "Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."
                )
            }
            false => match triggering_error {
                Some(e) => return Err(e),
                None => {
                    return internal_err!(
                        "Unknown Error encountered during writing to ObjectStore. All writers successfully aborted."
                    )
                }
            },
        }
    }

    tx.send(row_count as u64).map_err(|_| {
        internal_datafusion_err!(
            "Error encountered while sending row count back to file sink!"
        )
    })?;
    Ok(())
}

/// Orchestrates multipart put of a dynamic number of output files from a single input stream
/// for any statelessly serialized file type. That is, any file type for which each [`RecordBatch`]
/// can be serialized independently of all other [`RecordBatch`]es.
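///
/// A sketch of a typical call site inside a file sink implementation; the
/// `start_demuxer_task` helper and the `self.config` fields shown are
/// assumptions about the surrounding sink, not part of this function's API:
///
/// ```ignore
/// let object_store = context
///     .runtime_env()
///     .object_store(&self.config.object_store_url)?;
/// let (demux_task, file_stream_rx) = start_demuxer_task(&self.config, data, context);
/// let rows_written = spawn_writer_tasks_and_join(
///     context,
///     self.serializer(),
///     self.config.file_compression_type,
///     object_store,
///     demux_task,
///     file_stream_rx,
/// )
/// .await?;
/// ```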
pub async fn spawn_writer_tasks_and_join(
    context: &Arc<TaskContext>,
    serializer: Arc<dyn BatchSerializer>,
    compression: FileCompressionType,
    object_store: Arc<dyn ObjectStore>,
    demux_task: SpawnedTask<Result<()>>,
    mut file_stream_rx: DemuxedStreamReceiver,
) -> Result<u64> {
    let rb_buffer_size = &context
        .session_config()
        .options()
        .execution
        .max_buffered_batches_per_output_file;

    let (tx_file_bundle, rx_file_bundle) = mpsc::channel(rb_buffer_size / 2);
    let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
    let write_coordinator_task = SpawnedTask::spawn(async move {
        stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await
    });
    while let Some((location, rb_stream)) = file_stream_rx.recv().await {
        let writer =
            ObjectWriterBuilder::new(compression, &location, Arc::clone(&object_store))
                .with_buffer_size(Some(
                    context
                        .session_config()
                        .options()
                        .execution
                        .objectstore_writer_buffer_size,
                ))
                .build()?;

        // A send failure means the write coordinator has shut down and
        // dropped its receiver, most likely due to a write error. Keep
        // draining `file_stream_rx` so the demux task is not blocked on a
        // full channel; the coordinator's error surfaces via the joins below.
        let _ = tx_file_bundle
            .send((rb_stream, Arc::clone(&serializer), writer))
            .await;
    }

    // Signal to the write coordinator that no more files are coming
    drop(tx_file_bundle);

    let (r1, r2) = join!(
        write_coordinator_task.join_unwind(),
        demux_task.join_unwind()
    );
    r1.map_err(DataFusionError::ExecutionJoin)??;
    r2.map_err(DataFusionError::ExecutionJoin)??;

    // Return total row count:
    rx_row_cnt.await.map_err(|_| {
        internal_datafusion_err!("Did not receive row count from write coordinator")
    })
}