use swiftide_core::document::Document;
use swiftide_integrations::treesitter::metadata_qa_code;
use temp_dir::TempDir;
use anyhow::{Result, anyhow};
use sqlx::{prelude::FromRow, types::Uuid};
use swiftide::{
indexing::{
self, EmbeddedField, Pipeline, loaders,
transformers::{
self, ChunkCode, MetadataQACode, metadata_qa_code::NAME as METADATA_QA_CODE_NAME,
},
},
integrations::{
self,
pgvector::{FieldConfig, PgVector, PgVectorBuilder, VectorConfig},
},
query::{self, Query, answers, query_transformers, response_transformers, states},
};
use swiftide_test_utils::{mock_chat_completions, openai_client};
use wiremock::MockServer;
/// Row returned by the raw vector-search SQL used in `test_pgvector_indexing`.
///
/// The field names mirror the selected column names (`id`, `chunk`) so the derived
/// `FromRow` implementation can map each row onto this struct.
// `dead_code` is allowed because the visible tests only read `chunk`; `id` is populated by
// the query but never inspected.
#[allow(dead_code)]
#[derive(Debug, Clone, FromRow)]
struct VectorSearchResult {
/// Value of the `id` column of the matched row.
id: Uuid,
/// Value of the `chunk` column, i.e. the stored text of the indexed chunk.
chunk: String,
}
/// Indexes one Rust source file into pgvector with a hand-written embedding and checks that
/// the stored chunk can be read back through a raw vector-distance SQL query.
///
/// # Panics
///
/// Panics (failing the test) when the fixture cannot be written, the connection pool cannot
/// be acquired, the indexing pipeline fails, the SQL query fails, or the first retrieved
/// chunk differs from the fixture source.
#[test_log::test(tokio::test)]
async fn test_pgvector_indexing() {
    // Fixture: a temporary directory containing a single small Rust file.
    let tempdir = TempDir::new().unwrap();
    let codefile = tempdir.child("main.rs");
    let code = "fn main() { println!(\"Hello, World!\"); }";
    std::fs::write(&codefile, code).unwrap();

    // The container handle stays bound for the whole test so it is not dropped early.
    let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;

    // NOTE(review): nothing further down in this test references the mock server; it looks
    // like this setup is kept for parity with the retrieval tests — confirm before removing.
    let mock_server = MockServer::start().await;
    mock_chat_completions(&mock_server).await;

    let pgv_storage = PgVector::builder()
        .db_url(pgv_db_url)
        .vector_size(384)
        .with_vector(EmbeddedField::Combined)
        .table_name("swiftide_test")
        .build()
        .unwrap();

    println!("Dropping existing test table & index if it exists");
    let drop_table_sql = "DROP TABLE IF EXISTS swiftide_test";
    let drop_index_sql = "DROP INDEX IF EXISTS swiftide_test_embedding_idx";

    // `expect` keeps the underlying error in the panic message, which an
    // `if let Ok(..) else { panic!(..) }` would throw away. The same pool reference is
    // reused for the cleanup statements and for the verification query below.
    let pool = pgv_storage
        .get_pool()
        .await
        .expect("Unable to acquire database connection pool");

    sqlx::query(drop_table_sql)
        .execute(pool)
        .await
        .expect("Failed to execute SQL query for dropping the table");
    sqlx::query(drop_index_sql)
        .execute(pool)
        .await
        .expect("Failed to execute SQL query for dropping the index");

    // Index the fixture. The embedding step is replaced by a closure that attaches a constant
    // 384-dimensional vector, matching the `vector_size` configured on the storage.
    let result =
        Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"]))
            .then_chunk(ChunkCode::try_for_language("rust").unwrap())
            .then(|mut node: indexing::TextNode| {
                node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]);
                Ok(node)
            })
            .then_store_with(pgv_storage.clone())
            .run()
            .await;
    result.expect("PgVector Named vectors test indexing pipeline failed");

    // Read the data back with plain SQL, ordering by distance to the same constant vector.
    let sql_vector_query =
        "SELECT id, chunk FROM swiftide_test ORDER BY vector_combined <=> $1::VECTOR LIMIT $2";
    println!("Running retrieve with SQL: {sql_vector_query}");

    let top_k: i32 = 10;
    let embedding = vec![1.0; 384];
    let data: Vec<VectorSearchResult> = sqlx::query_as(sql_vector_query)
        .bind(embedding)
        .bind(top_k)
        .fetch_all(pool)
        .await
        .expect("Sql named vector query failed");

    let docs: Vec<_> = data.into_iter().map(|r| r.chunk).collect();
    println!("Retrieved documents for debugging: {docs:#?}");

    // Comparing through `first()` reports `None` vs `Some(..)` when no rows come back, instead
    // of an index-out-of-bounds panic that hides what was expected.
    assert_eq!(docs.first().map(String::as_str), Some(code));
}
/// End-to-end check of indexing into pgvector followed by a filtered similarity query.
///
/// The indexing pipeline chunks one Rust file, adds Q&A metadata through the mocked chat
/// client, tags each node with `filter = "true"`, embeds with FastEmbed and stores the result.
/// The query pipeline then retrieves with a `SimilaritySingleEmbedding` strategy restricted
/// by that filter and asserts on both the final answer and the first returned document.
///
/// # Panics
///
/// Panics (failing the test) when any setup step, pipeline run or assertion fails.
#[test_log::test(tokio::test)]
async fn test_pgvector_retrieve() {
    // Fixture: a temporary directory containing a single small Rust file.
    let tempdir = TempDir::new().unwrap();
    let codefile = tempdir.child("main.rs");
    let code = "fn main() { println!(\"Hello, World!\"); }";
    std::fs::write(&codefile, code).unwrap();

    // The container handle stays bound for the whole test so it is not dropped early.
    let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;

    // Chat completions go to a local mock server; embeddings are produced by FastEmbed.
    let mock_server = MockServer::start().await;
    mock_chat_completions(&mock_server).await;
    let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o");
    let fastembed =
        integrations::fastembed::FastEmbed::try_default().expect("Could not create FastEmbed");

    // Presumably the canned completion served by `mock_chat_completions`; it is expected both
    // as the final answer and as the generated Q&A metadata value — verify against the helper.
    let expected_completion = "\n\nHello there, how may I assist you today?";

    let pgv_storage = PgVector::builder()
        .db_url(pgv_db_url)
        .vector_size(384)
        .with_vector(EmbeddedField::Combined)
        .with_metadata(METADATA_QA_CODE_NAME)
        .with_metadata("filter")
        .table_name("swiftide_test")
        .build()
        .unwrap();

    println!("Dropping existing test table & index if it exists");
    let drop_table_sql = "DROP TABLE IF EXISTS swiftide_test";
    let drop_index_sql = "DROP INDEX IF EXISTS swiftide_test_embedding_idx";

    // `expect` keeps the underlying error in the panic message, which an
    // `if let Ok(..) else { panic!(..) }` would throw away.
    let pool = pgv_storage
        .get_pool()
        .await
        .expect("Unable to acquire database connection pool");

    sqlx::query(drop_table_sql)
        .execute(pool)
        .await
        .expect("Failed to execute SQL query for dropping the table");
    sqlx::query(drop_index_sql)
        .execute(pool)
        .await
        .expect("Failed to execute SQL query for dropping the index");

    // Indexing: chunk -> Q&A metadata -> add the `filter` metadata entry -> embed -> store.
    Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"]))
        .then_chunk(ChunkCode::try_for_language("rust").unwrap())
        .then(MetadataQACode::new(openai_client.clone()))
        .then(|mut node: indexing::TextNode| {
            node.metadata
                .insert("filter".to_string(), "true".to_string());
            Ok(node)
        })
        .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20))
        .log_nodes()
        .then_store_with(pgv_storage.clone())
        .run()
        .await
        .unwrap();

    // Querying: similarity search restricted to nodes whose `filter` metadata equals "true".
    let strategy = query::search_strategies::SimilaritySingleEmbedding::from_filter(
        "filter = \"true\"".to_string(),
    );

    let query_pipeline = query::Pipeline::from_search_strategy(strategy)
        .then_transform_query(query_transformers::GenerateSubquestions::from_client(
            openai_client.clone(),
        ))
        .then_transform_query(query_transformers::Embed::from_client(fastembed.clone()))
        .then_retrieve(pgv_storage.clone())
        .then_transform_response(response_transformers::Summary::from_client(
            openai_client.clone(),
        ))
        .then_answer(answers::Simple::from_client(openai_client.clone()));

    let result: Query<states::Answered> = query_pipeline.query("What is swiftide?").await.unwrap();

    assert_eq!(result.answer(), expected_completion);

    // The first retrieved document must carry the original source plus both metadata entries
    // written during indexing.
    let first_document = result.documents().first().unwrap();
    let expected = Document::builder()
        .content(code)
        .metadata([
            (metadata_qa_code::NAME, expected_completion),
            ("filter", "true"),
        ])
        .build()
        .unwrap();

    assert_eq!(first_document, &expected);
}
/// End-to-end check of indexing into pgvector followed by retrieval through a `CustomStrategy`
/// whose SQL is assembled at query time with `sqlx::QueryBuilder`.
///
/// Indexing is the same as in `test_pgvector_retrieve`. The custom query selects the default
/// columns from the storage table, keeps rows whose `filter` metadata column contains
/// `{"filter": "true"}`, orders them by `<=>` distance to the query embedding and limits the
/// result to five rows.
///
/// # Panics
///
/// Panics (failing the test) when any setup step, pipeline run or assertion fails.
#[test_log::test(tokio::test)]
async fn test_pgvector_retrieve_dynamic_search() {
    // Fixture: a temporary directory containing a single small Rust file.
    let tempdir = TempDir::new().unwrap();
    let codefile = tempdir.child("main.rs");
    let code = "fn main() { println!(\"Hello, World!\"); }";
    std::fs::write(&codefile, code).unwrap();

    // The container handle stays bound for the whole test so it is not dropped early.
    let (_pgv_db_container, pgv_db_url) = swiftide_test_utils::start_postgres().await;

    // Chat completions go to a local mock server; embeddings are produced by FastEmbed.
    let mock_server = MockServer::start().await;
    mock_chat_completions(&mock_server).await;
    let openai_client = openai_client(&mock_server.uri(), "text-embedding-3-small", "gpt-4o");
    let fastembed =
        integrations::fastembed::FastEmbed::try_default().expect("Could not create FastEmbed");

    // Presumably the canned completion served by `mock_chat_completions` — verify against
    // the helper.
    let expected_completion = "\n\nHello there, how may I assist you today?";

    let pgv_storage = PgVector::builder()
        .db_url(pgv_db_url)
        .vector_size(384)
        .with_vector(EmbeddedField::Combined)
        .with_metadata(METADATA_QA_CODE_NAME)
        .with_metadata("filter")
        .table_name("swiftide_test")
        .build()
        .unwrap();

    println!("Dropping existing test table & index if it exists");
    let drop_table_sql = "DROP TABLE IF EXISTS swiftide_test";
    let drop_index_sql = "DROP INDEX IF EXISTS swiftide_test_embedding_idx";

    // `expect` keeps the underlying error in the panic message, which an
    // `if let Ok(..) else { panic!(..) }` would throw away.
    let pool = pgv_storage
        .get_pool()
        .await
        .expect("Unable to acquire database connection pool");

    sqlx::query(drop_table_sql)
        .execute(pool)
        .await
        .expect("Failed to execute SQL query for dropping the table");
    sqlx::query(drop_index_sql)
        .execute(pool)
        .await
        .expect("Failed to execute SQL query for dropping the index");

    // Indexing: chunk -> Q&A metadata -> add the `filter` metadata entry -> embed -> store.
    Pipeline::from_loader(loaders::FileLoader::new(tempdir.path()).with_extensions(&["rs"]))
        .then_chunk(ChunkCode::try_for_language("rust").unwrap())
        .then(MetadataQACode::new(openai_client.clone()))
        .then(|mut node: indexing::TextNode| {
            node.metadata
                .insert("filter".to_string(), "true".to_string());
            Ok(node)
        })
        .then_in_batch(transformers::Embed::new(fastembed.clone()).with_batch_size(20))
        .log_nodes()
        .then_store_with(pgv_storage.clone())
        .run()
        .await
        .unwrap();

    // The closure is `move`, so it gets its own clone of the storage to read the table name.
    let pgv_storage_for_closure = pgv_storage.clone();

    let custom_strategy = query::search_strategies::CustomStrategy::from_query(
        move |query_node| -> Result<sqlx::QueryBuilder<'static, sqlx::Postgres>> {
            const CUSTOM_STRATEGY_MAX_RESULTS: i64 = 5;

            let mut builder = sqlx::QueryBuilder::new("");
            let table: &str = pgv_storage_for_closure.get_table_name();

            // SELECT <default columns> FROM <table>
            let default_fields: Vec<_> = PgVectorBuilder::default_fields();
            let default_columns: Vec<&str> =
                default_fields.iter().map(FieldConfig::field_name).collect();
            builder.push("SELECT ");
            builder.push(default_columns.join(", "));
            builder.push(" FROM ");
            builder.push(table);

            // WHERE meta_<filter> @> '{"filter": "true"}'::jsonb
            builder.push(" WHERE meta_");
            builder.push(PgVector::normalize_field_name("filter"));
            builder.push(" @> ");
            builder.push("'{\"filter\": \"true\"}'::jsonb");

            // ORDER BY <vector column> <=> $embedding::vector
            // The embedding is passed as a bind parameter; a query without one is an error.
            let vector_field = VectorConfig::from(EmbeddedField::Combined).field;
            builder.push(" ORDER BY ");
            builder.push(vector_field);
            builder.push(" <=> ");
            builder.push_bind(
                query_node
                    .embedding
                    .as_ref()
                    .ok_or_else(|| anyhow!("Missing embedding in query state"))?
                    .clone(),
            );
            builder.push("::vector");

            // LIMIT $max_results
            builder.push(" LIMIT ");
            builder.push_bind(CUSTOM_STRATEGY_MAX_RESULTS);

            Ok(builder)
        },
    );

    let query_pipeline = query::Pipeline::from_search_strategy(custom_strategy)
        .then_transform_query(query_transformers::GenerateSubquestions::from_client(
            openai_client.clone(),
        ))
        .then_transform_query(query_transformers::Embed::from_client(fastembed.clone()))
        .then_retrieve(pgv_storage.clone())
        .then_transform_response(response_transformers::Summary::from_client(
            openai_client.clone(),
        ))
        .then_answer(answers::Simple::from_client(openai_client.clone()));

    let result: Query<states::Answered> = query_pipeline.query("What is swiftide?").await.unwrap();

    assert_eq!(result.answer(), expected_completion);

    // Only the content is expected on the returned document. Presumably this is because the
    // custom query selects just the default columns and no metadata columns — confirm against
    // `PgVectorBuilder::default_fields`.
    let first_document = result.documents().first().unwrap();
    let expected = Document::builder().content(code).build().unwrap();

    assert_eq!(first_document, &expected);
}