// datafusion_datasource/write/orchestration.rs
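//! Helpers for orchestrating file writes: serializing [`RecordBatch`]
//! streams in parallel, streaming the resulting bytes to an [`ObjectStore`],
//! and joining the demux and write tasks while handling errors and aborts.
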
use std::sync::Arc;

use super::demux::DemuxedStreamReceiver;
use super::{BatchSerializer, ObjectWriterBuilder};
use crate::file_compression_type::FileCompressionType;
use datafusion_common::error::Result;

use arrow::array::RecordBatch;
use datafusion_common::{
    DataFusionError, exec_datafusion_err, internal_datafusion_err, internal_err,
};
use datafusion_common_runtime::{JoinSet, SpawnedTask};
use datafusion_execution::TaskContext;

use bytes::Bytes;
use futures::join;
use object_store::ObjectStore;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver};

type WriterType = Box<dyn AsyncWrite + Send + Unpin>;
type SerializerType = Arc<dyn BatchSerializer>;

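/// Result of calling [`serialize_rb_stream_to_object_store`]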
pub(crate) enum SerializedRecordBatchResult {
    Success {
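        /// The writer, returned so the caller can finalize the file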
        writer: WriterType,

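        /// The number of rows successfully written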
        row_count: usize,
    },
    Failure {
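        /// The writer, if it is still safe to use: `None` means an IO error
        /// occurred on the writer itself, so it is dropped rather than
        /// returned to the caller for abort handling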
        writer: Option<WriterType>,

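        /// The error that caused the failure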
        err: DataFusionError,
    },
}

impl SerializedRecordBatchResult {
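    /// Create the success variant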
    pub fn success(writer: WriterType, row_count: usize) -> Self {
        Self::Success { writer, row_count }
    }

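    /// Create the failure variant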
    pub fn failure(writer: Option<WriterType>, err: DataFusionError) -> Self {
        Self::Failure { writer, err }
    }
}

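/// Serializes a single data stream in parallel and writes to an ObjectStore
/// concurrently. Data order is preserved.
///
/// In the event of a non-IO error which does not involve the ObjectStore
/// writer, the writer is returned to the caller in addition to the error,
/// so that failed writes may be aborted.
///
/// In the event of an IO error involving the ObjectStore writer, the writer
/// is dropped to avoid calling further methods on it which might panic.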
pub(crate) async fn serialize_rb_stream_to_object_store(
    mut data_rx: Receiver<RecordBatch>,
    serializer: Arc<dyn BatchSerializer>,
    mut writer: WriterType,
) -> SerializedRecordBatchResult {
    let (tx, mut rx) =
        mpsc::channel::<SpawnedTask<Result<(usize, Bytes), DataFusionError>>>(100);
    let serialize_task = SpawnedTask::spawn(async move {
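        // Some serializers (e.g. CSV) handle the first batch differently
        // than subsequent batches, so track whether this is the first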
        let mut initial = true;
        while let Some(batch) = data_rx.recv().await {
            let serializer_clone = Arc::clone(&serializer);
            let task = SpawnedTask::spawn(async move {
                let num_rows = batch.num_rows();
                let bytes = serializer_clone.serialize(batch, initial)?;
                Ok((num_rows, bytes))
            });
            if initial {
                initial = false;
            }
            tx.send(task).await.map_err(|_| {
                internal_datafusion_err!("Unknown error writing to object store")
            })?;
        }
        Ok(())
    });

    let mut row_count = 0;
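    // Drain the serialization tasks in submission order, so bytes are
    // written to the object store in the original batch order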
    while let Some(task) = rx.recv().await {
        match task.join().await {
            Ok(Ok((cnt, bytes))) => {
                match writer.write_all(&bytes).await {
                    Ok(_) => (),
                    Err(e) => {
                        return SerializedRecordBatchResult::failure(
                            None,
                            exec_datafusion_err!("Error writing to object store: {e}"),
                        );
                    }
                };
                row_count += cnt;
            }
            Ok(Err(e)) => {
                // Return the writer along with the error
                return SerializedRecordBatchResult::failure(Some(writer), e);
            }
            Err(e) => {
                // Handle task panic or cancellation
                return SerializedRecordBatchResult::failure(
                    Some(writer),
                    exec_datafusion_err!(
                        "Serialization task panicked or was cancelled: {e}"
                    ),
                );
            }
        }
    }

    match serialize_task.join().await {
        Ok(Ok(_)) => (),
        Ok(Err(e)) => return SerializedRecordBatchResult::failure(Some(writer), e),
        Err(_) => {
            return SerializedRecordBatchResult::failure(
                Some(writer),
                internal_datafusion_err!("Unknown error writing to object store"),
            );
        }
    }
    SerializedRecordBatchResult::success(writer, row_count)
}

type FileWriteBundle = (Receiver<RecordBatch>, SerializerType, WriterType);
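/// Contains the common logic for serializing [`RecordBatch`]es and
/// writing the resulting bytes to an ObjectStore.
/// Serialization is assumed to be stateless, i.e.
/// each [`RecordBatch`] can be serialized without any
/// dependency on the [`RecordBatch`]es before or after.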
pub(crate) async fn stateless_serialize_and_write_files(
    mut rx: Receiver<FileWriteBundle>,
    tx: tokio::sync::oneshot::Sender<u64>,
) -> Result<()> {
    let mut row_count = 0;
    // tracks if any writers encountered an error triggering the need to abort
    let mut any_errors = false;
    // tracks the specific error triggering abort
    let mut triggering_error = None;
    // tracks if any errors were encountered in the process of aborting writers.
    // if true, we may not have a guarantee that all written data was cleaned up.
    let mut any_abort_errors = false;
    let mut join_set = JoinSet::new();
    while let Some((data_rx, serializer, writer)) = rx.recv().await {
        join_set.spawn(async move {
            serialize_rb_stream_to_object_store(data_rx, serializer, writer).await
        });
    }
    let mut finished_writers = Vec::new();
    while let Some(result) = join_set.join_next().await {
        match result {
            Ok(res) => match res {
                SerializedRecordBatchResult::Success {
                    writer,
                    row_count: cnt,
                } => {
                    finished_writers.push(writer);
                    row_count += cnt;
                }
                SerializedRecordBatchResult::Failure { writer, err } => {
                    finished_writers.extend(writer);
                    any_errors = true;
                    triggering_error = Some(err);
                }
            },
            Err(e) => {
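                // Don't panic, instead try to clean up as many writers as possible.
                // If we hit this code, ownership of a writer was not joined back to
                // this thread, so we cannot clean it up (hence any_abort_errors is true)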
                any_errors = true;
                any_abort_errors = true;
                triggering_error = Some(internal_datafusion_err!(
                    "Unexpected join error while serializing file {e}"
                ));
            }
        }
    }

    // Finalize all writers, flushing any remaining buffered bytes to the store
    for mut writer in finished_writers.into_iter() {
        writer.shutdown().await.map_err(|_| {
            internal_datafusion_err!(
                "Error encountered while finalizing writes! Partial results may have been written to ObjectStore!"
            )
        })?;
    }

    if any_errors {
        match any_abort_errors {
            true => {
                return internal_err!(
                    "Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."
                );
            }
            false => match triggering_error {
                Some(e) => return Err(e),
                None => {
                    return internal_err!(
                        "Unknown Error encountered during writing to ObjectStore. All writers successfully aborted."
                    );
                }
            },
        }
    }

    tx.send(row_count as u64).map_err(|_| {
        internal_datafusion_err!(
            "Error encountered while sending row count back to file sink!"
        )
    })?;
    Ok(())
}

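/// Orchestrates the writing of a dynamic number of output files from a
/// single input stream, for any statelessly serialized file type. That is,
/// any file type for which each [`RecordBatch`] can be serialized
/// independently of all other [`RecordBatch`]es.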
pub async fn spawn_writer_tasks_and_join(
    context: &Arc<TaskContext>,
    serializer: Arc<dyn BatchSerializer>,
    compression: FileCompressionType,
    compression_level: Option<u32>,
    object_store: Arc<dyn ObjectStore>,
    demux_task: SpawnedTask<Result<()>>,
    mut file_stream_rx: DemuxedStreamReceiver,
) -> Result<u64> {
    let rb_buffer_size = &context
        .session_config()
        .options()
        .execution
        .max_buffered_batches_per_output_file;

    let (tx_file_bundle, rx_file_bundle) = mpsc::channel(rb_buffer_size / 2);
    let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
    let write_coordinator_task = SpawnedTask::spawn(async move {
        stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await
    });
    while let Some((location, rb_stream)) = file_stream_rx.recv().await {
        let writer =
            ObjectWriterBuilder::new(compression, &location, Arc::clone(&object_store))
                .with_buffer_size(Some(
                    context
                        .session_config()
                        .options()
                        .execution
                        .objectstore_writer_buffer_size,
                ))
                .with_compression_level(compression_level)
                .build()?;

        if tx_file_bundle
            .send((rb_stream, Arc::clone(&serializer), writer))
            .await
            .is_err()
        {
            // The receiving end was closed, meaning the write coordinator
            // exited; surface the error instead of silently discarding it
            return internal_err!(
                "Writer receive file bundle channel closed unexpectedly!"
            );
        }
    }

    // Signal to the write coordinator that no more files are coming
    drop(tx_file_bundle);

    let (r1, r2) = join!(
        write_coordinator_task.join_unwind(),
        demux_task.join_unwind()
    );
    r1.map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??;
    r2.map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??;

    // Return the total row count reported back by the write coordinator
    rx_row_cnt.await.map_err(|_| {
        internal_datafusion_err!("Did not receive row count from write coordinator")
    })
}