use std::{
borrow::Cow,
collections::{HashMap, HashSet},
sync::Arc,
};
use arrow::buffer::NullBuffer;
use arrow_array::{
Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
Error, Result, connect,
embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry},
query::ExecutableQuery,
};
#[tokio::test]
async fn test_custom_func() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
let embed_fun = MockEmbed::new("embed_fun".to_string(), 1);
db.embedding_registry()
.register("embed_fun", Arc::new(embed_fun.clone()))?;
let tbl = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
&embed_fun.name,
Some("embeddings"),
))?
.execute()
.await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
}
tbl.add(create_some_records()?).execute().await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
}
Ok(())
}
#[tokio::test]
async fn test_custom_registry() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir)
.embedding_registry(Arc::new(MyRegistry::default()))
.execute()
.await?;
let tbl = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
"func_1",
Some("embeddings"),
))?
.execute()
.await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(
embeddings.data_type(),
MockEmbed::new("func_1".to_string(), 1)
.dest_type()?
.as_ref()
);
}
Ok(())
}
#[tokio::test]
async fn test_multiple_embeddings() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
let func_1 = MockEmbed::new("func_1".to_string(), 1);
let func_2 = MockEmbed::new("func_2".to_string(), 10);
db.embedding_registry()
.register(&func_1.name, Arc::new(func_1.clone()))?;
db.embedding_registry()
.register(&func_2.name, Arc::new(func_2.clone()))?;
let tbl = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
&func_1.name,
Some("first_embeddings"),
))?
.add_embedding(EmbeddingDefinition::new(
"text",
&func_2.name,
Some("second_embeddings"),
))?
.execute()
.await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("first_embeddings");
assert!(embeddings.is_some());
let second_embeddings = batch.column_by_name("second_embeddings");
assert!(second_embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());
let second_embeddings = second_embeddings.unwrap();
assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
}
tbl.add(create_some_records()?).execute().await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("first_embeddings");
assert!(embeddings.is_some());
let second_embeddings = batch.column_by_name("second_embeddings");
assert!(second_embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());
let second_embeddings = second_embeddings.unwrap();
assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
}
Ok(())
}
#[tokio::test]
async fn test_open_table_embeddings() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
let embed_fun = MockEmbed::new("embed_fun".to_string(), 1);
db.embedding_registry()
.register("embed_fun", Arc::new(embed_fun.clone()))?;
db.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
&embed_fun.name,
Some("embeddings"),
))?
.execute()
.await?;
let tbl = db.open_table("test").execute().await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
}
tbl.add(create_some_records()?).execute().await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
}
Ok(())
}
#[tokio::test]
async fn test_no_func_in_registry() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
let res = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
"some_func",
Some("first_embeddings"),
));
assert!(res.is_err());
assert!(matches!(
res.err().unwrap(),
Error::EmbeddingFunctionNotFound { .. }
));
Ok(())
}
#[tokio::test]
async fn test_no_func_in_registry_on_add() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
db.embedding_registry().register(
"some_func",
Arc::new(MockEmbed::new("some_func".to_string(), 1)),
)?;
db.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
"some_func",
Some("first_embeddings"),
))?
.execute()
.await?;
let db = connect(tempdir).execute().await?;
let tbl = db.open_table("test").execute().await?;
let res = tbl.add(create_some_records()?).execute().await;
assert!(res.is_err());
assert!(matches!(
res.unwrap_err(),
crate::Error::EmbeddingFunctionNotFound { .. }
));
Ok(())
}
fn create_some_records() -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
const TOTAL: usize = 2;
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("text", DataType::Utf8, true),
]));
let batches = RecordBatchIterator::new(
vec![
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(StringArray::from_iter(std::iter::repeat_n(
Some("hello world".to_string()),
TOTAL,
))),
],
)
.unwrap(),
]
.into_iter()
.map(Ok),
schema.clone(),
);
Ok(Box::new(batches))
}
#[derive(Debug)]
struct MyRegistry {
functions: HashMap<String, Arc<dyn EmbeddingFunction>>,
}
impl Default for MyRegistry {
fn default() -> Self {
let funcs: Vec<Arc<dyn EmbeddingFunction>> = vec![
Arc::new(MockEmbed::new("func_1".to_string(), 1)),
Arc::new(MockEmbed::new("func_2".to_string(), 10)),
];
Self {
functions: funcs
.into_iter()
.map(|f| (f.name().to_string(), f))
.collect(),
}
}
}
impl EmbeddingRegistry for MyRegistry {
fn functions(&self) -> HashSet<String> {
self.functions.keys().cloned().collect()
}
fn register(&self, _name: &str, _function: Arc<dyn EmbeddingFunction>) -> Result<()> {
Err(Error::Other {
message: "MyRegistry is read-only".to_string(),
source: None,
})
}
fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
self.functions.get(name).cloned()
}
}
#[derive(Debug, Clone)]
struct MockEmbed {
source_type: DataType,
dest_type: DataType,
name: String,
dim: usize,
}
impl MockEmbed {
pub fn new(name: String, dim: usize) -> Self {
Self {
source_type: DataType::Utf8,
dest_type: DataType::new_fixed_size_list(DataType::Float32, dim as _, true),
name,
dim,
}
}
}
impl EmbeddingFunction for MockEmbed {
fn name(&self) -> &str {
&self.name
}
fn source_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Borrowed(&self.source_type))
}
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Borrowed(&self.dest_type))
}
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
let len = source.len();
let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
let field = Field::new("item", inner.data_type().clone(), false);
let arr = FixedSizeListArray::new(
Arc::new(field),
self.dim as _,
inner,
Some(NullBuffer::new_valid(len)),
);
Ok(Arc::new(arr))
}
#[allow(unused_variables)]
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}