1use crate::client::{RowBytes, WriteFormat, WriteRecord, WriteResultFuture, WriterClient};
19use crate::error::Error::{IllegalArgument, UnexpectedError};
20use crate::error::Result;
21use crate::metadata::{RowType, TableInfo, TablePath};
22use crate::row::InternalRow;
23use crate::row::encode::{KeyEncoder, KeyEncoderFactory, RowEncoder, RowEncoderFactory};
24use crate::row::field_getter::FieldGetter;
25use std::sync::{Arc, Mutex};
26
27use crate::client::table::partition_getter::{PartitionGetter, get_physical_path};
28use bitvec::prelude::bitvec;
29use bytes::Bytes;
30
31#[allow(dead_code)]
32pub struct TableUpsert {
33 table_path: TablePath,
34 table_info: TableInfo,
35 writer_client: Arc<WriterClient>,
36 target_columns: Option<Arc<Vec<usize>>>,
37}
38
39#[allow(dead_code)]
40impl TableUpsert {
41 pub fn new(
42 table_path: TablePath,
43 table_info: TableInfo,
44 writer_client: Arc<WriterClient>,
45 ) -> Self {
46 Self {
47 table_path,
48 table_info,
49 writer_client,
50 target_columns: None,
51 }
52 }
53
54 pub fn partial_update(&self, target_columns: Option<Vec<usize>>) -> Result<Self> {
55 if let Some(columns) = &target_columns {
56 let num_columns = self.table_info.row_type().fields().len();
57
58 if let Some(&invalid_column) = columns.iter().find(|&&col| col >= num_columns) {
59 return Err(IllegalArgument {
60 message: format!(
61 "Invalid target column index: {invalid_column} for table {}. The table only has {num_columns} columns.",
62 self.table_path
63 ),
64 });
65 }
66 }
67
68 Ok(Self {
69 table_path: self.table_path.clone(),
70 table_info: self.table_info.clone(),
71 writer_client: self.writer_client.clone(),
72 target_columns: target_columns.map(Arc::new),
73 })
74 }
75
76 pub fn partial_update_with_column_names(&self, target_column_names: &[&str]) -> Result<Self> {
77 let row_type = self.table_info.row_type();
78 let col_indices: Vec<(&str, Option<usize>)> = target_column_names
79 .iter()
80 .map(|col_name| (*col_name, row_type.get_field_index(col_name)))
81 .collect();
82
83 if let Some((missing_name, _)) = col_indices.iter().find(|(_, ix)| ix.is_none()) {
84 return Err(IllegalArgument {
85 message: format!(
86 "Cannot find target column `{}` for table {}.",
87 missing_name, self.table_path
88 ),
89 });
90 }
91
92 let valid_col_indices: Vec<usize> = col_indices
93 .into_iter()
94 .map(|(_, index)| index.unwrap())
95 .collect();
96
97 self.partial_update(Some(valid_col_indices))
98 }
99
100 pub fn create_writer(&self) -> Result<UpsertWriter> {
101 UpsertWriterFactory::create(
102 Arc::new(self.table_path.clone()),
103 Arc::new(self.table_info.clone()),
104 self.target_columns.clone(),
105 Arc::clone(&self.writer_client),
106 )
107 }
108}
109
/// Writer that encodes rows and sends upsert/delete records for one table through a
/// [`WriterClient`].
///
/// Built by `UpsertWriterFactory::create`. The encoders sit behind [`Mutex`]es so the `&self`
/// methods can use them mutably.
pub struct UpsertWriter {
    /// Path of the target table; combined with a row's partition values to get the physical
    /// path each record is written to.
    table_path: Arc<TablePath>,
    /// Client that records are sent through and flushed by.
    writer_client: Arc<WriterClient>,
    /// Extracts partition values from a row; `None` when the table is not partitioned.
    partition_field_getter: Option<PartitionGetter>,
    /// Encodes the table's physical primary key columns of a row.
    primary_key_encoder: Mutex<Box<dyn KeyEncoder>>,
    /// Column indexes written by a partial update; `None` writes the full row.
    target_columns: Option<Arc<Vec<usize>>>,
    /// Encodes the bucket key when the table does not use the default bucket key; `None`
    /// means the encoded primary key is reused as the bucket key.
    bucket_key_encoder: Option<Mutex<Box<dyn KeyEncoder>>>,
    /// Row value encoding, derived from the table's KV format.
    write_format: WriteFormat,
    /// Encodes rows that cannot provide their bytes in `write_format` directly.
    row_encoder: Mutex<Box<dyn RowEncoder>>,
    /// Getters created from the table's row type; used to read field values when encoding.
    field_getters: Box<[FieldGetter]>,
    /// Schema and configuration of the target table.
    table_info: Arc<TableInfo>,
}
123
124struct UpsertWriterFactory;
125
126impl UpsertWriterFactory {
127 pub fn create(
128 table_path: Arc<TablePath>,
129 table_info: Arc<TableInfo>,
130 partial_update_columns: Option<Arc<Vec<usize>>>,
131 writer_client: Arc<WriterClient>,
132 ) -> Result<UpsertWriter> {
133 let data_lake_format = &table_info.table_config.get_datalake_format()?;
134 let row_type = table_info.row_type();
135 let physical_pks = table_info.get_physical_primary_keys();
136
137 let names = table_info.get_schema().auto_increment_col_names();
138
139 Self::sanity_check(
140 row_type,
141 &table_info.primary_keys,
142 names,
143 &partial_update_columns,
144 )?;
145
146 let primary_key_encoder = KeyEncoderFactory::of(row_type, physical_pks, data_lake_format)?;
147 let bucket_key_encoder = if !table_info.is_default_bucket_key() {
148 Some(KeyEncoderFactory::of(
149 row_type,
150 table_info.get_bucket_keys(),
151 data_lake_format,
152 )?)
153 } else {
154 None
156 };
157
158 let kv_format = table_info.get_table_config().get_kv_format()?;
159 let write_format = WriteFormat::from_kv_format(&kv_format)?;
160
161 let field_getters = FieldGetter::create_field_getters(row_type);
162
163 let partition_field_getter = if table_info.is_partitioned() {
164 Some(PartitionGetter::new(
165 row_type,
166 Arc::clone(table_info.get_partition_keys()),
167 )?)
168 } else {
169 None
170 };
171
172 Ok(UpsertWriter {
173 table_path,
174 partition_field_getter,
175 writer_client,
176 primary_key_encoder: Mutex::new(primary_key_encoder),
177 target_columns: partial_update_columns,
178 bucket_key_encoder: bucket_key_encoder.map(Mutex::new),
179 write_format,
180 row_encoder: Mutex::new(Box::new(RowEncoderFactory::create(
181 kv_format,
182 row_type.clone(),
183 )?)),
184 field_getters,
185 table_info: table_info.clone(),
186 })
187 }
188
189 #[allow(dead_code)]
190 fn sanity_check(
191 row_type: &RowType,
192 primary_keys: &Vec<String>,
193 auto_increment_col_names: &Vec<String>,
194 target_columns: &Option<Arc<Vec<usize>>>,
195 ) -> Result<()> {
196 if target_columns.is_none() {
197 if !auto_increment_col_names.is_empty() {
198 return Err(IllegalArgument {
199 message: format!(
200 "This table has auto increment column {}. Explicitly specifying values for an auto increment column is not allowed. Please Specify non-auto-increment columns as target columns using partialUpdate first.",
201 auto_increment_col_names.join(", ")
202 ),
203 });
204 }
205 return Ok(());
206 }
207
208 let field_count = row_type.fields().len();
209
210 let mut target_column_set = bitvec![0; field_count];
211
212 let columns = target_columns.as_ref().unwrap().as_ref();
213
214 for &target_index in columns {
215 target_column_set.set(target_index, true);
216 }
217
218 let mut pk_column_set = bitvec![0; field_count];
219
220 for primary_key in primary_keys {
222 let pk_index = row_type.get_field_index(primary_key.as_str());
223 match pk_index {
224 Some(pk_index) => {
225 if !target_column_set[pk_index] {
226 return Err(IllegalArgument {
227 message: format!(
228 "The target write columns {} must contain the primary key columns {}",
229 row_type.project(columns)?.get_field_names().join(", "),
230 primary_keys.join(", ")
231 ),
232 });
233 }
234 pk_column_set.set(pk_index, true);
235 }
236 None => {
237 return Err(IllegalArgument {
238 message: format!(
239 "The specified primary key {primary_key} is not in row type {row_type}"
240 ),
241 });
242 }
243 }
244 }
245
246 let mut auto_increment_column_set = bitvec![0; field_count];
247 for auto_increment_col_name in auto_increment_col_names {
249 let auto_increment_field_index =
250 row_type.get_field_index(auto_increment_col_name.as_str());
251
252 if let Some(index) = auto_increment_field_index {
253 if target_column_set[index] {
254 return Err(IllegalArgument {
255 message: format!(
256 "Explicitly specifying values for the auto increment column {auto_increment_col_name} is not allowed."
257 ),
258 });
259 }
260
261 auto_increment_column_set.set(index, true);
262 }
263 }
264
265 for i in 0..field_count {
267 if !pk_column_set[i] && !auto_increment_column_set[i] {
269 if !row_type.fields().get(i).unwrap().data_type.is_nullable() {
271 return Err(IllegalArgument {
272 message: format!(
273 "Partial Update requires all columns except primary key to be nullable, but column {} is NOT NULL.",
274 row_type.fields().get(i).unwrap().name()
275 ),
276 });
277 }
278 }
279 }
280
281 Ok(())
282 }
283}
284
285impl UpsertWriter {
286 fn check_field_count<R: InternalRow>(&self, row: &R) -> Result<()> {
287 let expected = self.table_info.get_row_type().fields().len();
288 if row.get_field_count() != expected {
289 return Err(IllegalArgument {
290 message: format!(
291 "The field count of the row does not match the table schema. Expected: {}, Actual: {}",
292 expected,
293 row.get_field_count()
294 ),
295 });
296 }
297 Ok(())
298 }
299
300 fn get_keys(&self, row: &dyn InternalRow) -> Result<(Bytes, Option<Bytes>)> {
301 let key = self
302 .primary_key_encoder
303 .lock()
304 .map_err(|e| UnexpectedError {
305 message: format!("primary_key_encoder lock poisoned: {e}"),
306 source: None,
307 })?
308 .encode_key(row)?;
309 let bucket_key = match &self.bucket_key_encoder {
310 Some(encoder) => Some(
311 encoder
312 .lock()
313 .map_err(|e| UnexpectedError {
314 message: format!("bucket_key_encoder lock poisoned: {e}"),
315 source: None,
316 })?
317 .encode_key(row)?,
318 ),
319 None => Some(key.clone()),
320 };
321 Ok((key, bucket_key))
322 }
323
324 fn encode_row<R: InternalRow>(&self, row: &R) -> Result<Bytes> {
325 let mut encoder = self.row_encoder.lock().map_err(|e| UnexpectedError {
326 message: format!("row_encoder lock poisoned: {e}"),
327 source: None,
328 })?;
329 encoder.start_new_row()?;
330 for (pos, field_getter) in self.field_getters.iter().enumerate() {
331 let datum = field_getter.get_field(row)?;
332 encoder.encode_field(pos, datum)?;
333 }
334 encoder.finish_row()
335 }
336
337 pub async fn flush(&self) -> Result<()> {
343 self.writer_client.flush().await
344 }
345
346 pub fn upsert<R: InternalRow>(&self, row: &R) -> Result<WriteResultFuture> {
358 self.check_field_count(row)?;
359
360 let (key, bucket_key) = self.get_keys(row)?;
361
362 let row_bytes: RowBytes<'_> = match row.as_encoded_bytes(self.write_format) {
363 Some(bytes) => RowBytes::Borrowed(bytes),
364 None => RowBytes::Owned(self.encode_row(row)?),
365 };
366
367 let write_record = WriteRecord::for_upsert(
368 Arc::clone(&self.table_info),
369 Arc::new(get_physical_path(
370 &self.table_path,
371 self.partition_field_getter.as_ref(),
372 row,
373 )?),
374 self.table_info.schema_id,
375 key,
376 bucket_key,
377 self.write_format,
378 self.target_columns.clone(),
379 Some(row_bytes),
380 );
381
382 let result_handle = self.writer_client.send(&write_record)?;
383 Ok(WriteResultFuture::new(result_handle))
384 }
385
386 pub fn delete<R: InternalRow>(&self, row: &R) -> Result<WriteResultFuture> {
399 self.check_field_count(row)?;
400
401 let (key, bucket_key) = self.get_keys(row)?;
402
403 let write_record = WriteRecord::for_upsert(
404 Arc::clone(&self.table_info),
405 Arc::new(get_physical_path(
406 &self.table_path,
407 self.partition_field_getter.as_ref(),
408 row,
409 )?),
410 self.table_info.schema_id,
411 key,
412 bucket_key,
413 self.write_format,
414 self.target_columns.clone(),
415 None,
416 );
417
418 let result_handle = self.writer_client.send(&write_record)?;
419 Ok(WriteResultFuture::new(result_handle))
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::{DataField, DataTypes};

    /// Runs `UpsertWriterFactory::sanity_check` on a schema built from `fields`. Asserts that it
    /// fails with an error whose message contains `expected`.
    fn assert_rejected(
        fields: Vec<DataField>,
        primary_keys: &[&str],
        auto_increment_col_names: &[&str],
        target_columns: Option<Vec<usize>>,
        expected: &str,
    ) {
        let row_type = RowType::new(fields);
        let primary_keys: Vec<String> = primary_keys.iter().map(|s| s.to_string()).collect();
        let auto_increment_col_names: Vec<String> = auto_increment_col_names
            .iter()
            .map(|s| s.to_string())
            .collect();
        let target_columns = target_columns.map(Arc::new);

        let message = UpsertWriterFactory::sanity_check(
            &row_type,
            &primary_keys,
            &auto_increment_col_names,
            &target_columns,
        )
        .unwrap_err()
        .to_string();

        assert!(
            message.contains(expected),
            "unexpected error message: {message}"
        );
    }

    #[test]
    fn sanity_check() {
        // Full-row writes are rejected when the table has an auto-increment column.
        assert_rejected(
            vec![
                DataField::new("id", DataTypes::int().as_non_nullable(), None),
                DataField::new("name", DataTypes::string(), None),
            ],
            &["id"],
            &["id"],
            None,
            "This table has auto increment column id. Explicitly specifying values for an auto increment column is not allowed. Please Specify non-auto-increment columns as target columns using partialUpdate first.",
        );

        // The target columns must include the primary key.
        assert_rejected(
            vec![
                DataField::new("id", DataTypes::int().as_non_nullable(), None),
                DataField::new("name", DataTypes::string(), None),
                DataField::new("value", DataTypes::int(), None),
            ],
            &["id"],
            &[],
            Some(vec![1usize]),
            "The target write columns name must contain the primary key columns id",
        );

        // Primary keys must exist in the row type.
        assert_rejected(
            vec![
                DataField::new("id", DataTypes::int().as_non_nullable(), None),
                DataField::new("name", DataTypes::string(), None),
            ],
            &["nonexistent_pk"],
            &[],
            Some(vec![0usize, 1]),
            "The specified primary key nonexistent_pk is not in row type",
        );

        // Auto-increment columns may not be targeted.
        assert_rejected(
            vec![
                DataField::new("id", DataTypes::int().as_non_nullable(), None),
                DataField::new("seq", DataTypes::bigint().as_non_nullable(), None),
                DataField::new("name", DataTypes::string(), None),
            ],
            &["id"],
            &["seq"],
            Some(vec![0usize, 1, 2]),
            "Explicitly specifying values for the auto increment column seq is not allowed.",
        );

        // Non-key columns must be nullable for partial updates.
        assert_rejected(
            vec![
                DataField::new("id", DataTypes::int().as_non_nullable(), None),
                DataField::new(
                    "required_field",
                    DataTypes::string().as_non_nullable(),
                    None,
                ),
                DataField::new("optional_field", DataTypes::int(), None),
            ],
            &["id"],
            &[],
            Some(vec![0usize]),
            "Partial Update requires all columns except primary key to be nullable, but column required_field is NOT NULL.",
        );
    }
}
549
/// Unit result type for an upsert; it carries no data.
///
/// NOTE(review): not referenced by the visible code (hence `allow(dead_code)`); presumably
/// reserved for a future upsert API. Confirm before relying on it.
#[derive(Default)]
#[allow(dead_code)]
pub struct UpsertResult;
555
/// Unit result type for a delete; it carries no data.
///
/// NOTE(review): not referenced by the visible code (hence `allow(dead_code)`); presumably
/// reserved for a future delete API. Confirm before relying on it.
#[derive(Default)]
#[allow(dead_code)]
pub struct DeleteResult;