Skip to main content

duckdb/appender/
arrow.rs

1use super::{ffi, Appender, Result};
2use crate::{
3    core::DataChunkHandle,
4    error::result_from_duckdb_appender,
5    vtab::{record_batch_to_duckdb_data_chunk, to_duckdb_logical_type},
6    Error,
7};
8use arrow::record_batch::RecordBatch;
9use ffi::{duckdb_append_data_chunk, duckdb_vector_size};
10
11impl Appender<'_> {
12    /// Append one record batch
13    ///
14    /// ## Example
15    ///
16    /// ```rust,no_run
17    /// # use duckdb::{Connection, Result, params};
18    ///   use arrow::record_batch::RecordBatch;
19    /// fn insert_record_batch(conn: &Connection,record_batch:RecordBatch) -> Result<()> {
20    ///     let mut app = conn.appender("foo")?;
21    ///     app.append_record_batch(record_batch)?;
22    ///     Ok(())
23    /// }
24    /// ```
25    ///
26    /// # Failure
27    ///
28    /// Will return `Err` if append column count not the same with the table schema
29    #[inline]
30    pub fn append_record_batch(&mut self, record_batch: RecordBatch) -> Result<()> {
31        let schema = record_batch.schema();
32        let fields = schema.fields();
33        let capacity = fields.len();
34        let mut logical_types = Vec::with_capacity(capacity);
35        for field in fields.iter() {
36            logical_types.push(
37                to_duckdb_logical_type(field.data_type())
38                    .map_err(|_op| Error::ArrowTypeToDuckdbType(field.to_string(), field.data_type().clone()))?,
39            );
40        }
41
42        let vector_size = unsafe { duckdb_vector_size() } as usize;
43        let num_rows = record_batch.num_rows();
44
45        // Process record batch in chunks that fit within DuckDB's vector size
46        let mut offset = 0;
47        while offset < num_rows {
48            let slice_len = std::cmp::min(vector_size, num_rows - offset);
49            let slice = record_batch.slice(offset, slice_len);
50
51            let mut data_chunk = DataChunkHandle::new(&logical_types);
52            record_batch_to_duckdb_data_chunk(&slice, &mut data_chunk).map_err(|_op| Error::AppendError)?;
53
54            let rc = unsafe { duckdb_append_data_chunk(self.app, data_chunk.get_ptr()) };
55            result_from_duckdb_appender(rc, &mut self.app)?;
56
57            offset += slice_len;
58        }
59
60        Ok(())
61    }
62}
63
#[cfg(test)]
mod test {
    use crate::{Connection, Result};
    use arrow::{
        array::{Int32Array, Int8Array, StringArray},
        datatypes::{DataType, Field, Schema},
        record_batch::RecordBatch,
    };
    use std::sync::Arc;

    /// Appends a small mixed-type batch (including NULL strings) and checks
    /// both the row count and the exact values that land in the table.
    #[test]
    fn test_append_record_batch() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo(id TINYINT not null,area TINYINT not null,name Varchar)")?;
        {
            let id_array = Int8Array::from(vec![1, 2, 3, 4, 5]);
            let area_array = Int8Array::from(vec![11, 22, 33, 44, 55]);
            let name_array = StringArray::from(vec![Some("11"), None, None, Some("44"), None]);
            let schema = Schema::new(vec![
                Field::new("id", DataType::Int8, true),
                Field::new("area", DataType::Int8, true),
                Field::new("name", DataType::Utf8, true),
            ]);
            let record_batch = RecordBatch::try_new(
                Arc::new(schema),
                vec![Arc::new(id_array), Arc::new(area_array), Arc::new(name_array)],
            )
            .unwrap();
            let mut app = db.appender("foo")?;
            app.append_record_batch(record_batch)?;
        }
        let mut stmt = db.prepare("SELECT id, area, name FROM foo")?;
        let rbs: Vec<RecordBatch> = stmt.query_arrow([])?.collect();
        assert_eq!(rbs.iter().map(|op| op.num_rows()).sum::<usize>(), 5);

        // A row count alone would not catch swapped columns or lost NULLs, so
        // verify the appended values as well.
        let mut stmt = db.prepare("SELECT id, area, name FROM foo ORDER BY id")?;
        let rows = stmt
            .query_map([], |row| {
                Ok((
                    row.get::<_, i8>(0)?,
                    row.get::<_, i8>(1)?,
                    row.get::<_, Option<String>>(2)?,
                ))
            })?
            .collect::<Result<Vec<_>>>()?;
        assert_eq!(
            rows,
            vec![
                (1, 11, Some("11".to_string())),
                (2, 22, None),
                (3, 33, None),
                (4, 44, Some("44".to_string())),
                (5, 55, None),
            ]
        );
        Ok(())
    }

    /// A zero-row batch must succeed without appending anything (the chunking
    /// loop body never runs).
    #[test]
    fn test_append_empty_record_batch() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo(id INT)")?;
        {
            let id_array = Int32Array::from(Vec::<i32>::new());
            let schema = Schema::new(vec![Field::new("id", DataType::Int32, true)]);
            let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
            let mut app = db.appender("foo")?;
            app.append_record_batch(record_batch)?;
        }
        let count: usize = db.query_row("SELECT COUNT(*) FROM foo", [], |row| row.get(0))?;
        assert_eq!(count, 0);
        Ok(())
    }

    /// Appends more rows than fit in a single DuckDB vector (2^16 + 1 rows)
    /// so the batch must be split into several chunks, including a short
    /// trailing one, and checks that every row arrives intact.
    #[test]
    fn test_append_record_batch_large() -> Result<()> {
        let record_count = usize::pow(2, 16) + 1;
        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo(id INT)")?;
        {
            let id_array = Int32Array::from((0..record_count as i32).collect::<Vec<_>>());
            let schema = Schema::new(vec![Field::new("id", DataType::Int32, true)]);
            let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
            let mut app = db.appender("foo")?;
            app.append_record_batch(record_batch)?;
        }
        let count: usize = db.query_row("SELECT COUNT(*) FROM foo", [], |row| row.get(0))?;
        assert_eq!(count, record_count);

        // Verify the data is correct
        let sum: i64 = db.query_row("SELECT SUM(id) FROM foo", [], |row| row.get(0))?;
        let expected_sum: i64 = (0..record_count as i64).sum();
        assert_eq!(sum, expected_sum);

        Ok(())
    }
}