Skip to main content

fluss/client/table/
upsert.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::client::{RowBytes, WriteFormat, WriteRecord, WriteResultFuture, WriterClient};
19use crate::error::Error::{IllegalArgument, UnexpectedError};
20use crate::error::Result;
21use crate::metadata::{RowType, TableInfo, TablePath};
22use crate::row::InternalRow;
23use crate::row::encode::{KeyEncoder, KeyEncoderFactory, RowEncoder, RowEncoderFactory};
24use crate::row::field_getter::FieldGetter;
25use std::sync::{Arc, Mutex};
26
27use crate::client::table::partition_getter::{PartitionGetter, get_physical_path};
28use bitvec::prelude::bitvec;
29use bytes::Bytes;
30
31#[allow(dead_code)]
32pub struct TableUpsert {
33    table_path: TablePath,
34    table_info: TableInfo,
35    writer_client: Arc<WriterClient>,
36    target_columns: Option<Arc<Vec<usize>>>,
37}
38
39#[allow(dead_code)]
40impl TableUpsert {
41    pub fn new(
42        table_path: TablePath,
43        table_info: TableInfo,
44        writer_client: Arc<WriterClient>,
45    ) -> Self {
46        Self {
47            table_path,
48            table_info,
49            writer_client,
50            target_columns: None,
51        }
52    }
53
54    pub fn partial_update(&self, target_columns: Option<Vec<usize>>) -> Result<Self> {
55        if let Some(columns) = &target_columns {
56            let num_columns = self.table_info.row_type().fields().len();
57
58            if let Some(&invalid_column) = columns.iter().find(|&&col| col >= num_columns) {
59                return Err(IllegalArgument {
60                    message: format!(
61                        "Invalid target column index: {invalid_column} for table {}. The table only has {num_columns} columns.",
62                        self.table_path
63                    ),
64                });
65            }
66        }
67
68        Ok(Self {
69            table_path: self.table_path.clone(),
70            table_info: self.table_info.clone(),
71            writer_client: self.writer_client.clone(),
72            target_columns: target_columns.map(Arc::new),
73        })
74    }
75
76    pub fn partial_update_with_column_names(&self, target_column_names: &[&str]) -> Result<Self> {
77        let row_type = self.table_info.row_type();
78        let col_indices: Vec<(&str, Option<usize>)> = target_column_names
79            .iter()
80            .map(|col_name| (*col_name, row_type.get_field_index(col_name)))
81            .collect();
82
83        if let Some((missing_name, _)) = col_indices.iter().find(|(_, ix)| ix.is_none()) {
84            return Err(IllegalArgument {
85                message: format!(
86                    "Cannot find target column `{}` for table {}.",
87                    missing_name, self.table_path
88                ),
89            });
90        }
91
92        let valid_col_indices: Vec<usize> = col_indices
93            .into_iter()
94            .map(|(_, index)| index.unwrap())
95            .collect();
96
97        self.partial_update(Some(valid_col_indices))
98    }
99
100    pub fn create_writer(&self) -> Result<UpsertWriter> {
101        UpsertWriterFactory::create(
102            Arc::new(self.table_path.clone()),
103            Arc::new(self.table_info.clone()),
104            self.target_columns.clone(),
105            Arc::clone(&self.writer_client),
106        )
107    }
108}
109
/// Sends upserts and deletes for one table to a [`WriterClient`].
///
/// Built by `UpsertWriterFactory::create` (reached through [`TableUpsert::create_writer`]).
/// The encoders are stored behind [`Mutex`]es and locked on each call from the `&self`
/// write methods.
pub struct UpsertWriter {
    table_path: Arc<TablePath>,
    writer_client: Arc<WriterClient>,
    /// Derives the partition of a row; `None` when the table is not partitioned.
    partition_field_getter: Option<PartitionGetter>,
    /// Encodes the physical primary key of a row.
    primary_key_encoder: Mutex<Box<dyn KeyEncoder>>,
    /// Column indices written by a partial update; `None` writes the whole row.
    target_columns: Option<Arc<Vec<usize>>>,
    /// Encoder for the bucket key, present only when the table does not use the default bucket
    /// key. When `None`, the primary key bytes are reused as the bucket key.
    bucket_key_encoder: Option<Mutex<Box<dyn KeyEncoder>>>,
    /// Row format sent to the server, derived from the table's KV format.
    write_format: WriteFormat,
    /// Encodes rows that do not already expose their bytes in `write_format`.
    row_encoder: Mutex<Box<dyn RowEncoder>>,
    /// Field getters created from the table's row type; used when re-encoding a row.
    field_getters: Box<[FieldGetter]>,
    /// Table metadata; also supplies the schema id attached to every record.
    table_info: Arc<TableInfo>,
}
123
124struct UpsertWriterFactory;
125
126impl UpsertWriterFactory {
127    pub fn create(
128        table_path: Arc<TablePath>,
129        table_info: Arc<TableInfo>,
130        partial_update_columns: Option<Arc<Vec<usize>>>,
131        writer_client: Arc<WriterClient>,
132    ) -> Result<UpsertWriter> {
133        let data_lake_format = &table_info.table_config.get_datalake_format()?;
134        let row_type = table_info.row_type();
135        let physical_pks = table_info.get_physical_primary_keys();
136
137        let names = table_info.get_schema().auto_increment_col_names();
138
139        Self::sanity_check(
140            row_type,
141            &table_info.primary_keys,
142            names,
143            &partial_update_columns,
144        )?;
145
146        let primary_key_encoder = KeyEncoderFactory::of(row_type, physical_pks, data_lake_format)?;
147        let bucket_key_encoder = if !table_info.is_default_bucket_key() {
148            Some(KeyEncoderFactory::of(
149                row_type,
150                table_info.get_bucket_keys(),
151                data_lake_format,
152            )?)
153        } else {
154            // Defaults to using primary key encoder when None for bucket key
155            None
156        };
157
158        let kv_format = table_info.get_table_config().get_kv_format()?;
159        let write_format = WriteFormat::from_kv_format(&kv_format)?;
160
161        let field_getters = FieldGetter::create_field_getters(row_type);
162
163        let partition_field_getter = if table_info.is_partitioned() {
164            Some(PartitionGetter::new(
165                row_type,
166                Arc::clone(table_info.get_partition_keys()),
167            )?)
168        } else {
169            None
170        };
171
172        Ok(UpsertWriter {
173            table_path,
174            partition_field_getter,
175            writer_client,
176            primary_key_encoder: Mutex::new(primary_key_encoder),
177            target_columns: partial_update_columns,
178            bucket_key_encoder: bucket_key_encoder.map(Mutex::new),
179            write_format,
180            row_encoder: Mutex::new(Box::new(RowEncoderFactory::create(
181                kv_format,
182                row_type.clone(),
183            )?)),
184            field_getters,
185            table_info: table_info.clone(),
186        })
187    }
188
189    #[allow(dead_code)]
190    fn sanity_check(
191        row_type: &RowType,
192        primary_keys: &Vec<String>,
193        auto_increment_col_names: &Vec<String>,
194        target_columns: &Option<Arc<Vec<usize>>>,
195    ) -> Result<()> {
196        if target_columns.is_none() {
197            if !auto_increment_col_names.is_empty() {
198                return Err(IllegalArgument {
199                    message: format!(
200                        "This table has auto increment column {}. Explicitly specifying values for an auto increment column is not allowed. Please Specify non-auto-increment columns as target columns using partialUpdate first.",
201                        auto_increment_col_names.join(", ")
202                    ),
203                });
204            }
205            return Ok(());
206        }
207
208        let field_count = row_type.fields().len();
209
210        let mut target_column_set = bitvec![0; field_count];
211
212        let columns = target_columns.as_ref().unwrap().as_ref();
213
214        for &target_index in columns {
215            target_column_set.set(target_index, true);
216        }
217
218        let mut pk_column_set = bitvec![0; field_count];
219
220        // check the target columns contains the primary key
221        for primary_key in primary_keys {
222            let pk_index = row_type.get_field_index(primary_key.as_str());
223            match pk_index {
224                Some(pk_index) => {
225                    if !target_column_set[pk_index] {
226                        return Err(IllegalArgument {
227                            message: format!(
228                                "The target write columns {} must contain the primary key columns {}",
229                                row_type.project(columns)?.get_field_names().join(", "),
230                                primary_keys.join(", ")
231                            ),
232                        });
233                    }
234                    pk_column_set.set(pk_index, true);
235                }
236                None => {
237                    return Err(IllegalArgument {
238                        message: format!(
239                            "The specified primary key {primary_key} is not in row type {row_type}"
240                        ),
241                    });
242                }
243            }
244        }
245
246        let mut auto_increment_column_set = bitvec![0; field_count];
247        // explicitly specifying values for an auto increment column is not allowed
248        for auto_increment_col_name in auto_increment_col_names {
249            let auto_increment_field_index =
250                row_type.get_field_index(auto_increment_col_name.as_str());
251
252            if let Some(index) = auto_increment_field_index {
253                if target_column_set[index] {
254                    return Err(IllegalArgument {
255                        message: format!(
256                            "Explicitly specifying values for the auto increment column {auto_increment_col_name} is not allowed."
257                        ),
258                    });
259                }
260
261                auto_increment_column_set.set(index, true);
262            }
263        }
264
265        // check the columns not in targetColumns should be nullable
266        for i in 0..field_count {
267            // column not in primary key and not in auto increment column
268            if !pk_column_set[i] && !auto_increment_column_set[i] {
269                // the column should be nullable
270                if !row_type.fields().get(i).unwrap().data_type.is_nullable() {
271                    return Err(IllegalArgument {
272                        message: format!(
273                            "Partial Update requires all columns except primary key to be nullable, but column {} is NOT NULL.",
274                            row_type.fields().get(i).unwrap().name()
275                        ),
276                    });
277                }
278            }
279        }
280
281        Ok(())
282    }
283}
284
285impl UpsertWriter {
286    fn check_field_count<R: InternalRow>(&self, row: &R) -> Result<()> {
287        let expected = self.table_info.get_row_type().fields().len();
288        if row.get_field_count() != expected {
289            return Err(IllegalArgument {
290                message: format!(
291                    "The field count of the row does not match the table schema. Expected: {}, Actual: {}",
292                    expected,
293                    row.get_field_count()
294                ),
295            });
296        }
297        Ok(())
298    }
299
300    fn get_keys(&self, row: &dyn InternalRow) -> Result<(Bytes, Option<Bytes>)> {
301        let key = self
302            .primary_key_encoder
303            .lock()
304            .map_err(|e| UnexpectedError {
305                message: format!("primary_key_encoder lock poisoned: {e}"),
306                source: None,
307            })?
308            .encode_key(row)?;
309        let bucket_key = match &self.bucket_key_encoder {
310            Some(encoder) => Some(
311                encoder
312                    .lock()
313                    .map_err(|e| UnexpectedError {
314                        message: format!("bucket_key_encoder lock poisoned: {e}"),
315                        source: None,
316                    })?
317                    .encode_key(row)?,
318            ),
319            None => Some(key.clone()),
320        };
321        Ok((key, bucket_key))
322    }
323
324    fn encode_row<R: InternalRow>(&self, row: &R) -> Result<Bytes> {
325        let mut encoder = self.row_encoder.lock().map_err(|e| UnexpectedError {
326            message: format!("row_encoder lock poisoned: {e}"),
327            source: None,
328        })?;
329        encoder.start_new_row()?;
330        for (pos, field_getter) in self.field_getters.iter().enumerate() {
331            let datum = field_getter.get_field(row)?;
332            encoder.encode_field(pos, datum)?;
333        }
334        encoder.finish_row()
335    }
336
337    /// Flush data written that have not yet been sent to the server, forcing the client to send the
338    /// requests to server and blocks on the completion of the requests associated with these
339    /// records. A request is considered completed when it is successfully acknowledged according to
340    /// the CLIENT_WRITER_ACKS configuration option you have specified or else it
341    /// results in an error.
342    pub async fn flush(&self) -> Result<()> {
343        self.writer_client.flush().await
344    }
345
346    /// Inserts row into Fluss table if they do not already exist, or updates them if they do exist.
347    ///
348    /// This method returns a [`WriteResultFuture`] immediately after queueing the write,
349    /// enabling fire-and-forget semantics for efficient batching.
350    ///
351    /// # Arguments
352    /// * row - the row to upsert.
353    ///
354    /// # Returns
355    /// A [`WriteResultFuture`] that can be awaited to wait for server acknowledgment,
356    /// or dropped for fire-and-forget behavior (use `flush()` to ensure delivery).
357    pub fn upsert<R: InternalRow>(&self, row: &R) -> Result<WriteResultFuture> {
358        self.check_field_count(row)?;
359
360        let (key, bucket_key) = self.get_keys(row)?;
361
362        let row_bytes: RowBytes<'_> = match row.as_encoded_bytes(self.write_format) {
363            Some(bytes) => RowBytes::Borrowed(bytes),
364            None => RowBytes::Owned(self.encode_row(row)?),
365        };
366
367        let write_record = WriteRecord::for_upsert(
368            Arc::clone(&self.table_info),
369            Arc::new(get_physical_path(
370                &self.table_path,
371                self.partition_field_getter.as_ref(),
372                row,
373            )?),
374            self.table_info.schema_id,
375            key,
376            bucket_key,
377            self.write_format,
378            self.target_columns.clone(),
379            Some(row_bytes),
380        );
381
382        let result_handle = self.writer_client.send(&write_record)?;
383        Ok(WriteResultFuture::new(result_handle))
384    }
385
386    /// Delete certain row by the input row in Fluss table, the input row must contain the primary
387    /// key.
388    ///
389    /// This method returns a [`WriteResultFuture`] immediately after queueing the delete,
390    /// enabling fire-and-forget semantics for efficient batching.
391    ///
392    /// # Arguments
393    /// * row - the row to delete (must contain the primary key fields).
394    ///
395    /// # Returns
396    /// A [`WriteResultFuture`] that can be awaited to wait for server acknowledgment,
397    /// or dropped for fire-and-forget behavior (use `flush()` to ensure delivery).
398    pub fn delete<R: InternalRow>(&self, row: &R) -> Result<WriteResultFuture> {
399        self.check_field_count(row)?;
400
401        let (key, bucket_key) = self.get_keys(row)?;
402
403        let write_record = WriteRecord::for_upsert(
404            Arc::clone(&self.table_info),
405            Arc::new(get_physical_path(
406                &self.table_path,
407                self.partition_field_getter.as_ref(),
408                row,
409            )?),
410            self.table_info.schema_id,
411            key,
412            bucket_key,
413            self.write_format,
414            self.target_columns.clone(),
415            None,
416        );
417
418        let result_handle = self.writer_client.send(&write_record)?;
419        Ok(WriteResultFuture::new(result_handle))
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::{DataField, DataTypes};

    /// Builds a row type from `fields` and runs `UpsertWriterFactory::sanity_check` on it.
    fn run_sanity_check(
        fields: Vec<DataField>,
        primary_keys: &[&str],
        auto_increment_col_names: &[&str],
        target_columns: Option<Vec<usize>>,
    ) -> Result<()> {
        let row_type = RowType::new(fields);
        let primary_keys: Vec<String> = primary_keys.iter().map(|s| s.to_string()).collect();
        let auto_increment_col_names: Vec<String> = auto_increment_col_names
            .iter()
            .map(|s| s.to_string())
            .collect();
        let target_columns = target_columns.map(Arc::new);

        UpsertWriterFactory::sanity_check(
            &row_type,
            &primary_keys,
            &auto_increment_col_names,
            &target_columns,
        )
    }

    /// Asserts that `result` is an error whose message contains `expected`.
    fn assert_error_contains(result: Result<()>, expected: &str) {
        let message = result
            .expect_err("sanity_check should have failed")
            .to_string();
        assert!(
            message.contains(expected),
            "unexpected error message: {message}"
        );
    }

    /// Asserts that `result` is `Ok`, printing the error otherwise.
    fn assert_ok(result: Result<()>) {
        if let Err(e) = result {
            panic!("sanity_check unexpectedly failed: {e}");
        }
    }

    #[test]
    fn full_row_upsert_rejects_auto_increment_column() {
        let fields = vec![
            DataField::new("id", DataTypes::int().as_non_nullable(), None),
            DataField::new("name", DataTypes::string(), None),
        ];

        assert_error_contains(
            run_sanity_check(fields, &["id"], &["id"], None),
            "This table has auto increment column id. Explicitly specifying values for an auto increment column is not allowed. Please Specify non-auto-increment columns as target columns using partialUpdate first.",
        );
    }

    #[test]
    fn full_row_upsert_without_auto_increment_is_accepted() {
        let fields = vec![
            DataField::new("id", DataTypes::int().as_non_nullable(), None),
            DataField::new("name", DataTypes::string(), None),
        ];

        assert_ok(run_sanity_check(fields, &["id"], &[], None));
    }

    #[test]
    fn target_columns_must_contain_primary_key() {
        let fields = vec![
            DataField::new("id", DataTypes::int().as_non_nullable(), None),
            DataField::new("name", DataTypes::string(), None),
            DataField::new("value", DataTypes::int(), None),
        ];

        assert_error_contains(
            run_sanity_check(fields, &["id"], &[], Some(vec![1])),
            "The target write columns name must contain the primary key columns id",
        );
    }

    #[test]
    fn primary_key_must_exist_in_row_type() {
        let fields = vec![
            DataField::new("id", DataTypes::int().as_non_nullable(), None),
            DataField::new("name", DataTypes::string(), None),
        ];

        assert_error_contains(
            run_sanity_check(fields, &["nonexistent_pk"], &[], Some(vec![0, 1])),
            "The specified primary key nonexistent_pk is not in row type",
        );
    }

    #[test]
    fn target_columns_must_not_include_auto_increment_column() {
        let fields = vec![
            DataField::new("id", DataTypes::int().as_non_nullable(), None),
            DataField::new("seq", DataTypes::bigint().as_non_nullable(), None),
            DataField::new("name", DataTypes::string(), None),
        ];

        assert_error_contains(
            run_sanity_check(fields, &["id"], &["seq"], Some(vec![0, 1, 2])),
            "Explicitly specifying values for the auto increment column seq is not allowed.",
        );
    }

    #[test]
    fn partial_update_rejects_non_nullable_non_key_column() {
        let fields = vec![
            DataField::new("id", DataTypes::int().as_non_nullable(), None),
            DataField::new(
                "required_field",
                DataTypes::string().as_non_nullable(),
                None,
            ),
            DataField::new("optional_field", DataTypes::int(), None),
        ];

        assert_error_contains(
            run_sanity_check(fields, &["id"], &[], Some(vec![0])),
            "Partial Update requires all columns except primary key to be nullable, but column required_field is NOT NULL.",
        );
    }

    #[test]
    fn valid_partial_update_is_accepted() {
        let fields = vec![
            DataField::new("id", DataTypes::int().as_non_nullable(), None),
            DataField::new("name", DataTypes::string(), None),
            DataField::new("value", DataTypes::int(), None),
        ];

        assert_ok(run_sanity_check(fields, &["id"], &[], Some(vec![0, 1])));
    }

    #[test]
    fn non_nullable_auto_increment_column_outside_targets_is_accepted() {
        let fields = vec![
            DataField::new("id", DataTypes::int().as_non_nullable(), None),
            DataField::new("seq", DataTypes::bigint().as_non_nullable(), None),
            DataField::new("name", DataTypes::string(), None),
        ];

        assert_ok(run_sanity_check(fields, &["id"], &["seq"], Some(vec![0, 2])));
    }
}
549
/// The result of upserting a record.
///
/// Currently an empty struct, kept so that the result can evolve compatibly in the future.
#[derive(Debug, Default, Clone)]
#[allow(dead_code)]
pub struct UpsertResult;
555
/// The result of deleting a record.
///
/// Currently an empty struct, kept so that the result can evolve compatibly in the future.
#[derive(Debug, Default, Clone)]
#[allow(dead_code)]
pub struct DeleteResult;