// datafusion_datasource/write/orchestration.rs

//! Helper methods for orchestrating parallel serialization of
//! [`RecordBatch`]es and streaming the resulting bytes to an object store.

use std::sync::Arc;

use super::demux::DemuxedStreamReceiver;
use super::{BatchSerializer, ObjectWriterBuilder};
use crate::file_compression_type::FileCompressionType;
use datafusion_common::error::Result;

use arrow::array::RecordBatch;
use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError};
use datafusion_common_runtime::{JoinSet, SpawnedTask};
use datafusion_execution::TaskContext;

use bytes::Bytes;
use futures::join;
use object_store::ObjectStore;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver};

type WriterType = Box<dyn AsyncWrite + Send + Unpin>;
type SerializerType = Arc<dyn BatchSerializer>;
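
/// Result of serializing a single [`RecordBatch`] stream to one output file:
/// on success the writer is returned so it can be finalized, and on failure
/// it is returned whenever it is still usable so partially written data can
/// be cleaned up.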
pub(crate) enum SerializedRecordBatchResult {
    Success {
        /// the writer, handed back so the caller can finalize it
        writer: WriterType,

        /// the number of rows successfully written
        row_count: usize,
    },
    Failure {
        /// the writer, if it is still available for cleanup; `None` if the
        /// failure occurred while writing, leaving the writer unusable
        writer: Option<WriterType>,

        /// the error that caused the failure
        err: DataFusionError,
    },
}

impl SerializedRecordBatchResult {
    /// Create the success variant
    pub fn success(writer: WriterType, row_count: usize) -> Self {
        Self::Success { writer, row_count }
    }

    /// Create the failure variant
    pub fn failure(writer: Option<WriterType>, err: DataFusionError) -> Self {
        Self::Failure { writer, err }
    }
}
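
/// Serializes a stream of [`RecordBatch`]es in parallel and writes the bytes
/// to `writer` concurrently, preserving the order of the batches. On failure,
/// the writer is handed back whenever possible so the caller can abort the
/// partially written file.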
pub(crate) async fn serialize_rb_stream_to_object_store(
    mut data_rx: Receiver<RecordBatch>,
    serializer: Arc<dyn BatchSerializer>,
    mut writer: WriterType,
) -> SerializedRecordBatchResult {
    let (tx, mut rx) =
        mpsc::channel::<SpawnedTask<Result<(usize, Bytes), DataFusionError>>>(100);
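    // The channel carries spawned serialization tasks in submission order, so
    // batches are serialized in parallel but written sequentially and in
    // order; the bound of 100 in-flight tasks provides backpressure.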
    let serialize_task = SpawnedTask::spawn(async move {
        // Some serializers (e.g. CSV) must treat the first batch specially,
        // such as by emitting a header, so track whether a batch is the first
        let mut initial = true;
        while let Some(batch) = data_rx.recv().await {
            let serializer_clone = Arc::clone(&serializer);
            let task = SpawnedTask::spawn(async move {
                let num_rows = batch.num_rows();
                let bytes = serializer_clone.serialize(batch, initial)?;
                Ok((num_rows, bytes))
            });
            if initial {
                initial = false;
            }
            tx.send(task).await.map_err(|_| {
                internal_datafusion_err!("Unknown error writing to object store")
            })?;
        }
        Ok(())
    });
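
    // Drain the queued tasks in order, writing each serialized chunk to the
    // object store writer as soon as it is ready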
    let mut row_count = 0;
    while let Some(task) = rx.recv().await {
        match task.join().await {
            Ok(Ok((cnt, bytes))) => {
                match writer.write_all(&bytes).await {
                    Ok(_) => (),
                    Err(e) => {
                        return SerializedRecordBatchResult::failure(
                            None,
                            DataFusionError::Execution(format!(
                                "Error writing to object store: {e}"
                            )),
                        )
                    }
                };
                row_count += cnt;
            }
            Ok(Err(e)) => {
                // Serialization failed; the writer is still usable, so return it
                return SerializedRecordBatchResult::failure(Some(writer), e);
            }
            Err(e) => {
                // The serialization task panicked or was cancelled
                return SerializedRecordBatchResult::failure(
                    Some(writer),
                    DataFusionError::Execution(format!(
                        "Serialization task panicked or was cancelled: {e}"
                    )),
                );
            }
        }
    }
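
    // Make sure the feeding task itself finished cleanly before declaring success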
    match serialize_task.join().await {
        Ok(Ok(_)) => (),
        Ok(Err(e)) => return SerializedRecordBatchResult::failure(Some(writer), e),
        Err(_) => {
            return SerializedRecordBatchResult::failure(
                Some(writer),
                internal_datafusion_err!("Unknown error writing to object store"),
            )
        }
    }
    SerializedRecordBatchResult::success(writer, row_count)
}

type FileWriteBundle = (Receiver<RecordBatch>, SerializerType, WriterType);
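/// Drives the serialization and writing of an arbitrary number of files,
/// assuming stateless serialization: each [`RecordBatch`] can be serialized
/// independently of the ones before and after it. Reports the total number of
/// rows written over the `tx` channel once every file has been finalized.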
pub(crate) async fn stateless_serialize_and_write_files(
    mut rx: Receiver<FileWriteBundle>,
    tx: tokio::sync::oneshot::Sender<u64>,
) -> Result<()> {
    let mut row_count = 0;
    // tracks if any writers encountered an error triggering the need to abort
    let mut any_errors = false;
    // tracks the specific error triggering abort
    let mut triggering_error = None;
    // tracks if any errors were encountered in the process of aborting writers.
    // if true, we may not have a guarantee that all written data was cleaned up.
    let mut any_abort_errors = false;
    let mut join_set = JoinSet::new();
    while let Some((data_rx, serializer, writer)) = rx.recv().await {
        join_set.spawn(async move {
            serialize_rb_stream_to_object_store(data_rx, serializer, writer).await
        });
    }
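    // Collect the per-file results, keeping each writer so it can be
    // finalized (or cleaned up) once all files have been processed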
    let mut finished_writers = Vec::new();
    while let Some(result) = join_set.join_next().await {
        match result {
            Ok(res) => match res {
                SerializedRecordBatchResult::Success {
                    writer,
                    row_count: cnt,
                } => {
                    finished_writers.push(writer);
                    row_count += cnt;
                }
                SerializedRecordBatchResult::Failure { writer, err } => {
                    finished_writers.extend(writer);
                    any_errors = true;
                    triggering_error = Some(err);
                }
            },
            Err(e) => {
                // Don't panic; instead try to clean up as many writers as
                // possible. If we hit this branch, ownership of a writer was
                // never joined back to this task, so that writer cannot be
                // cleaned up (hence any_abort_errors is set)
                any_errors = true;
                any_abort_errors = true;
                triggering_error = Some(internal_datafusion_err!(
                    "Unexpected join error while serializing file {e}"
                ));
            }
        }
    }
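
    // Finalize all writers, closing their byte streams; a failure here may
    // leave partial results in the object store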
    for mut writer in finished_writers.into_iter() {
        writer.shutdown()
            .await
            .map_err(|_| internal_datafusion_err!("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!"))?;
    }
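
    // Propagate the first error encountered, distinguishing the case where
    // cleanup itself may have failed and partial data may remain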
    if any_errors {
        match any_abort_errors {
            true => return internal_err!("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."),
            false => match triggering_error {
                Some(e) => return Err(e),
                None => return internal_err!("Unknown Error encountered during writing to ObjectStore. All writers successfully aborted.")
            }
        }
    }
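
    // Report the total number of rows written back to the file sink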
    tx.send(row_count as u64).map_err(|_| {
        internal_datafusion_err!(
            "Error encountered while sending row count back to file sink!"
        )
    })?;
    Ok(())
}
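
/// Orchestrates the multipart write of output files from a single
/// demultiplexed input stream for any statelessly serialized file type:
/// spawns the write coordinator, builds one object store writer per output
/// file, and returns the total number of rows written once both the demux
/// task and the write coordinator have finished.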
pub async fn spawn_writer_tasks_and_join(
    context: &Arc<TaskContext>,
    serializer: Arc<dyn BatchSerializer>,
    compression: FileCompressionType,
    object_store: Arc<dyn ObjectStore>,
    demux_task: SpawnedTask<Result<()>>,
    mut file_stream_rx: DemuxedStreamReceiver,
) -> Result<u64> {
    let rb_buffer_size = &context
        .session_config()
        .options()
        .execution
        .max_buffered_batches_per_output_file;
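
    // Bound the number of in-flight file bundles to half the per-file batch
    // buffer size, applying backpressure to the demultiplexer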
    let (tx_file_bundle, rx_file_bundle) = mpsc::channel(rb_buffer_size / 2);
    let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
    let write_coordinator_task = SpawnedTask::spawn(async move {
        stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await
    });
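    // For each output file produced by the demultiplexer, build an object
    // store writer and hand it off to the write coordinator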
    while let Some((location, rb_stream)) = file_stream_rx.recv().await {
        let writer =
            ObjectWriterBuilder::new(compression, &location, Arc::clone(&object_store))
                .with_buffer_size(Some(
                    context
                        .session_config()
                        .options()
                        .execution
                        .objectstore_writer_buffer_size,
                ))
                .build()?;

        if tx_file_bundle
            .send((rb_stream, Arc::clone(&serializer), writer))
            .await
            .is_err()
        {
            // The write coordinator dropped its receiver, most likely because
            // it hit an error; stop sending and let join_unwind below surface
            // that error
            break;
        }
    }
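
    // Signal to the write coordinator that no more files are coming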
    drop(tx_file_bundle);

    // Wait for both the coordinator and the demux task, propagating panics
    // and errors from either
    let (r1, r2) = join!(
        write_coordinator_task.join_unwind(),
        demux_task.join_unwind()
    );
    r1.map_err(DataFusionError::ExecutionJoin)??;
    r2.map_err(DataFusionError::ExecutionJoin)??;

    // Return the total row count reported by the write coordinator
    rx_row_cnt.await.map_err(|_| {
        internal_datafusion_err!("Did not receive row count from write coordinator")
    })
}
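
// A minimal sketch of a custom `BatchSerializer`, included only to illustrate
// the `initial` flag contract used by `serialize_rb_stream_to_object_store`
// above: the first batch of each file is serialized with `initial == true`,
// letting a format emit its header exactly once. `HeaderOnceSerializer` is
// hypothetical and not part of DataFusion; the trait signature is inferred
// from the `serializer_clone.serialize(batch, initial)` call site above.
#[cfg(test)]
mod batch_serializer_example {
    use super::*;

    /// Toy line-oriented format: a one-line header, then one line per batch
    /// recording how many rows that batch contained.
    struct HeaderOnceSerializer;

    impl BatchSerializer for HeaderOnceSerializer {
        fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
            let mut out = Vec::new();
            // Emit the header only for the first batch of the file
            if initial {
                out.extend_from_slice(b"row_counts\n");
            }
            out.extend_from_slice(format!("{}\n", batch.num_rows()).as_bytes());
            Ok(Bytes::from(out))
        }
    }
}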