swiftide_integrations/lancedb/
persist.rs

1use std::sync::Arc;
2
3use anyhow::Context as _;
4use anyhow::Result;
5use arrow_array::Array;
6use arrow_array::FixedSizeListArray;
7use arrow_array::GenericByteArray;
8use arrow_array::RecordBatch;
9use arrow_array::RecordBatchIterator;
10use arrow_array::types::Float32Type;
11use arrow_array::types::UInt8Type;
12use arrow_array::types::Utf8Type;
13use async_trait::async_trait;
14use swiftide_core::Persist;
15use swiftide_core::indexing::IndexingStream;
16use swiftide_core::indexing::TextNode;
17
18use super::FieldConfig;
19use super::LanceDB;
20
21#[async_trait]
22impl Persist for LanceDB {
23    type Input = String;
24    type Output = String;
25
26    #[tracing::instrument(skip_all)]
27    async fn setup(&self) -> Result<()> {
28        let conn = self.get_connection().await?;
29        let schema = self.schema.clone();
30
31        if let Err(err) = conn.open_table(&self.table_name).execute().await {
32            if matches!(err, lancedb::Error::TableNotFound { .. }) {
33                conn.create_empty_table(&self.table_name, schema)
34                    .execute()
35                    .await
36                    .map(|_| ())
37                    .map_err(anyhow::Error::from)?;
38            } else {
39                return Err(err.into());
40            }
41        }
42
43        Ok(())
44    }
45
46    #[tracing::instrument(skip_all)]
47    async fn store(&self, node: TextNode) -> Result<TextNode> {
48        let mut nodes = vec![node; 1];
49        self.store_nodes(&nodes).await?;
50
51        let node = nodes.swap_remove(0);
52
53        Ok(node)
54    }
55
56    #[tracing::instrument(skip_all)]
57    async fn batch_store(&self, nodes: Vec<TextNode>) -> IndexingStream<String> {
58        self.store_nodes(&nodes).await.map(|()| nodes).into()
59    }
60
61    fn batch_size(&self) -> Option<usize> {
62        Some(self.batch_size)
63    }
64}
65
66impl LanceDB {
67    async fn store_nodes(&self, nodes: &[TextNode]) -> Result<()> {
68        let schema = self.schema.clone();
69
70        let batches = self.extract_arrow_batches_from_nodes(nodes)?;
71
72        let data = RecordBatchIterator::new(
73            vec![
74                RecordBatch::try_new(schema.clone(), batches)
75                    .context("Could not create batches")?,
76            ]
77            .into_iter()
78            .map(Ok),
79            schema.clone(),
80        );
81
82        let conn = self.get_connection().await?;
83        let table = conn.open_table(&self.table_name).execute().await?;
84        let mut merge_insert = table.merge_insert(&["id"]);
85
86        merge_insert
87            .when_matched_update_all(None)
88            .when_not_matched_insert_all();
89
90        merge_insert.execute(Box::new(data)).await?;
91
92        Ok(())
93    }
94
95    fn extract_arrow_batches_from_nodes(
96        &self,
97        nodes: &[TextNode],
98    ) -> core::result::Result<Vec<Arc<dyn Array>>, anyhow::Error> {
99        let fields = self.fields.as_slice();
100        let mut batches: Vec<Arc<dyn Array>> = Vec::with_capacity(fields.len());
101
102        for field in fields {
103            match field {
104                FieldConfig::Vector(config) => {
105                    let mut row = Vec::with_capacity(nodes.len());
106                    let vector_size = config
107                        .vector_size
108                        .or(self.vector_size)
109                        .context("Expected vector size to be set for field")?;
110
111                    for node in nodes {
112                        let data = node
113                            .vectors
114                            .as_ref()
115                            // TODO: verify compiler optimizes the double loops away
116                            .and_then(|v| v.get(&config.embedded_field))
117                            .map(|v| v.iter().map(|f| Some(*f)));
118
119                        row.push(data);
120                    }
121                    batches.push(Arc::new(FixedSizeListArray::from_iter_primitive::<
122                        Float32Type,
123                        _,
124                        _,
125                    >(row, vector_size)));
126                }
127                FieldConfig::Metadata(config) => {
128                    let mut row = Vec::with_capacity(nodes.len());
129
130                    for node in nodes {
131                        let data = node
132                            .metadata
133                            .get(&config.original_field)
134                            // TODO: Verify this gives the correct data
135                            .and_then(|v| v.as_str());
136
137                        row.push(data);
138                    }
139                    batches.push(Arc::new(GenericByteArray::<Utf8Type>::from_iter(row)));
140                }
141                FieldConfig::Chunk => {
142                    let mut row = Vec::with_capacity(nodes.len());
143
144                    for node in nodes {
145                        let data = Some(node.chunk.as_str());
146                        row.push(data);
147                    }
148                    batches.push(Arc::new(GenericByteArray::<Utf8Type>::from_iter(row)));
149                }
150                FieldConfig::ID => {
151                    let mut row = Vec::with_capacity(nodes.len());
152                    for node in nodes {
153                        let data = Some(node.id().as_bytes().map(Some));
154                        row.push(data);
155                    }
156                    batches.push(Arc::new(FixedSizeListArray::from_iter_primitive::<
157                        UInt8Type,
158                        _,
159                        _,
160                    >(row, 16)));
161                }
162            }
163        }
164        Ok(batches)
165    }
166}
167
#[cfg(test)]
mod test {
    use swiftide_core::{Persist as _, indexing::EmbeddedField};
    use temp_dir::TempDir;

    use super::*;

    /// Builds a `LanceDB` store under a fresh temporary directory and runs `setup` on it once.
    ///
    /// The store uses 384-dimensional vectors for `EmbeddedField::Combined`, a `filter`
    /// metadata column, and the table `swiftide_test`. The `TempDir` is returned alongside
    /// the store so the caller can keep it alive for as long as the store is in use.
    async fn setup() -> (TempDir, LanceDB) {
        let dir = TempDir::new().unwrap();
        let db_path = dir.child("lancedb");

        let store = LanceDB::builder()
            .uri(db_path.to_str().unwrap())
            .vector_size(384)
            .with_metadata("filter")
            .with_vector(EmbeddedField::Combined)
            .table_name("swiftide_test")
            .build()
            .unwrap();
        store.setup().await.unwrap();

        (dir, store)
    }

    /// A second `setup` against a table that already exists must succeed.
    #[tokio::test]
    async fn test_no_error_when_table_exists() {
        let (_guard, store) = setup().await;

        let second_setup = store.setup().await;
        second_setup.expect("Should not error if table exists");
    }
}