use std::path::{Path, PathBuf};
use anyhow::{Context, anyhow, bail};
use quiver_embed::{Database, Descriptor, DistanceMetric, Dtype, FieldType, FilterableField};
use quiver_import::{
ChromaSource, ParseOptions, PgvectorSource, QdrantSource, Source, fetch_chroma, fetch_pgvector,
fetch_qdrant, import_into, infer_dim, parse, plaintext_credential_warning,
};
/// Everything `import` needs to run one import: where the points come from, how to parse
/// them, and which local collection/database receives them.
///
/// `import` picks the point source from the first of `qdrant_url`, `chroma_url`,
/// `postgres_url`, `input` that is set, in that order; the others are then ignored.
pub(crate) struct ImportArgs {
/// Source kind; seeds `ParseOptions` and must match whichever live-URL field is used.
pub source: Source,
/// Path of an export file, read as UTF-8 text and parsed when no live URL is set.
pub input: Option<PathBuf>,
/// URL of a live Qdrant instance to fetch from; rejected unless `source` is `Source::Qdrant`.
pub qdrant_url: Option<String>,
/// URL of a live Chroma instance to fetch from; rejected unless `source` is `Source::Chroma`.
pub chroma_url: Option<String>,
/// When set, replaces the tenant that `ChromaSource::new` starts with.
pub chroma_tenant: Option<String>,
/// When set, replaces the database that `ChromaSource::new` starts with.
pub chroma_database: Option<String>,
/// Postgres connection URL to fetch from; rejected unless `source` is `Source::Pgvector`.
pub postgres_url: Option<String>,
/// Table to read for pgvector imports; falls back to `collection` when absent.
pub table: Option<String>,
/// API key forwarded to the Qdrant or Chroma source; unused for pgvector and file imports.
pub api_key: Option<String>,
/// Collection name: requested from the live Qdrant/Chroma source, used as the default
/// pgvector table, and used as the destination collection in the local database.
pub collection: String,
/// Directory of the local database that receives the points.
pub data_dir: PathBuf,
/// Distance metric recorded in the destination collection's descriptor.
pub metric: DistanceMetric,
/// Vector dimension; inferred from the loaded points via `infer_dim` when `None`.
pub dim: Option<usize>,
/// Fields declared filterable on the destination collection's descriptor.
pub filterable: Vec<FilterableField>,
/// When set, overrides `ParseOptions::id_field`.
pub id_field: Option<String>,
/// When set, overrides `ParseOptions::vector_field`.
pub vector_field: Option<String>,
/// Copied as-is into `ParseOptions::vector_name`.
// NOTE(review): presumably selects one named vector per point — confirm in quiver_import.
pub vector_name: Option<String>,
/// Key handed to `quiver_crypto::open_keyring` when opening the database.
// NOTE(review): the tests pass 64 hex characters; the exact format is defined by quiver_crypto.
pub encryption_key: Option<String>,
/// Handed to `quiver_crypto::open_keyring`; the tests expect an import with no key to fail
/// unless this is `true`.
pub insecure: bool,
}
/// Loads points from an export file or a live source and writes them into a local collection.
///
/// The point source is the first of these that is set, checked in this order:
/// `qdrant_url`, `chroma_url`, `postgres_url`, `input`. Later ones are ignored once an
/// earlier one matches. Each live URL is only accepted together with its matching
/// `args.source`.
///
/// For live sources, any message returned by `plaintext_credential_warning` is printed to
/// stderr before fetching; it never aborts the import.
///
/// Returns the value reported by `import_into` (the tests expect this to be the number of
/// points imported).
///
/// # Errors
///
/// Fails when a live URL is combined with a different `--source`, when no source is given
/// at all, when the export file cannot be read, when fetching/parsing fails, when no points
/// were found, when the dimension cannot be inferred or does not fit in a `u32`, or when the
/// database cannot be opened or written.
pub(crate) fn import(args: ImportArgs) -> anyhow::Result<usize> {
    let mut opts = ParseOptions::new(args.source);
    if let Some(field) = args.id_field {
        opts.id_field = field;
    }
    if let Some(field) = args.vector_field {
        opts.vector_field = field;
    }
    opts.vector_name = args.vector_name;

    let points = if let Some(url) = &args.qdrant_url {
        if args.source != Source::Qdrant {
            bail!("--qdrant-url is only supported with --source qdrant");
        }
        if let Some(warning) =
            plaintext_credential_warning(Source::Qdrant, url, args.api_key.is_some())
        {
            eprintln!("warning: {warning}");
        }
        let src = QdrantSource {
            api_key: args.api_key.clone(),
            ..QdrantSource::new(url.clone(), args.collection.clone())
        };
        fetch_qdrant(&src, &opts)?
    } else if let Some(url) = &args.chroma_url {
        if args.source != Source::Chroma {
            bail!("--chroma-url is only supported with --source chroma");
        }
        if let Some(warning) =
            plaintext_credential_warning(Source::Chroma, url, args.api_key.is_some())
        {
            eprintln!("warning: {warning}");
        }
        let mut src = ChromaSource::new(url.clone(), args.collection.clone());
        if let Some(tenant) = &args.chroma_tenant {
            src.tenant = tenant.clone();
        }
        if let Some(database) = &args.chroma_database {
            src.database = database.clone();
        }
        src.api_key = args.api_key.clone();
        fetch_chroma(&src)?
    } else if let Some(url) = &args.postgres_url {
        if args.source != Source::Pgvector {
            bail!("--postgres-url is only supported with --source pgvector");
        }
        // The table defaults to the collection name when --table is not given.
        let table = args
            .table
            .clone()
            .unwrap_or_else(|| args.collection.clone());
        // No separate API key exists for pgvector, hence the hard-coded `false`.
        if let Some(warning) = plaintext_credential_warning(Source::Pgvector, url, false) {
            eprintln!("warning: {warning}");
        }
        let src = PgvectorSource::new(url.clone(), table);
        fetch_pgvector(&src, &opts)?
    } else if let Some(input) = &args.input {
        let text = std::fs::read_to_string(input)
            .with_context(|| format!("reading export {}", input.display()))?;
        parse(&opts, &text)?
    } else {
        bail!(
            "provide an export file (--input) or a live source \
             (--qdrant-url / --chroma-url / --postgres-url)"
        );
    };

    if points.is_empty() {
        bail!("no points found");
    }

    let dim = match args.dim {
        Some(d) => d,
        None => infer_dim(&points)?,
    };
    // A plain `as u32` would silently wrap an oversized dimension and create a collection
    // with the wrong shape; fail loudly instead. The fully-qualified path keeps this
    // independent of whether `TryFrom` is in the edition's prelude.
    let dim = <u32 as std::convert::TryFrom<usize>>::try_from(dim)
        .with_context(|| format!("vector dimension {dim} does not fit in a u32"))?;

    let descriptor = Descriptor::new(dim, Dtype::F32, args.metric).with_filterable(args.filterable);
    let mut db = open_database(
        &args.data_dir,
        args.encryption_key.as_deref(),
        args.insecure,
    )?;
    Ok(import_into(&mut db, &args.collection, descriptor, &points)?)
}
pub(crate) fn parse_metric(name: &str) -> anyhow::Result<DistanceMetric> {
match name.to_ascii_lowercase().as_str() {
"l2" | "euclidean" => Ok(DistanceMetric::L2),
"cosine" => Ok(DistanceMetric::Cosine),
"dot" | "ip" => Ok(DistanceMetric::Dot),
other => bail!("unknown metric '{other}' (expected l2, cosine, or dot)"),
}
}
/// Parses `path:type` specs into [`FilterableField`]s, preserving input order.
///
/// Each spec is split at its *last* `:`, so the path part may itself contain colons.
/// The type is matched ignoring ASCII case: `keyword`/`string` or `numeric`/`number`.
///
/// # Errors
///
/// Stops at the first spec that has no `:` or names an unknown type and returns an error
/// that quotes the offending spec.
pub(crate) fn parse_filterable(specs: &[String]) -> anyhow::Result<Vec<FilterableField>> {
    let mut fields = Vec::with_capacity(specs.len());
    for spec in specs {
        let (path, ty) = match spec.rsplit_once(':') {
            Some(parts) => parts,
            None => return Err(anyhow!("filterable '{spec}' must be path:type")),
        };
        let lowered = ty.to_ascii_lowercase();
        let field_type = match lowered.as_str() {
            "keyword" | "string" => FieldType::Keyword,
            "numeric" | "number" => FieldType::Numeric,
            _ => bail!("filterable '{spec}': unknown type '{lowered}' (keyword|numeric)"),
        };
        fields.push(FilterableField {
            path: path.to_string(),
            field_type,
        });
    }
    Ok(fields)
}
/// Opens the database under `data_dir`, using a keyring when one is available.
///
/// `quiver_crypto::open_keyring` decides, from `key` and `insecure`, whether a keyring is
/// returned; with a keyring the database is opened through `Database::open_with_keyring`,
/// otherwise through plain `Database::open`.
///
/// # Errors
///
/// Propagates any error from resolving the keyring or from opening the database.
fn open_database(data_dir: &Path, key: Option<&str>, insecure: bool) -> anyhow::Result<Database> {
    let keyring = quiver_crypto::open_keyring(data_dir, key, insecure)?;
    if let Some(keyring) = keyring {
        return Ok(Database::open_with_keyring(data_dir, keyring)?);
    }
    Ok(Database::open(data_dir)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use quiver_embed::SearchParams;
    use std::io::{Read, Write};
    use std::net::{SocketAddr, TcpListener, TcpStream};

    /// Baseline arguments for a test: no point source selected, collection `"c"`, L2 metric,
    /// inferred dimension, no filterable fields, no encryption key, `insecure` enabled.
    /// Tests override only the fields they care about via struct-update syntax.
    fn base_args(source: Source, data_dir: PathBuf) -> ImportArgs {
        ImportArgs {
            source,
            input: None,
            qdrant_url: None,
            chroma_url: None,
            chroma_tenant: None,
            chroma_database: None,
            postgres_url: None,
            table: None,
            api_key: None,
            collection: "c".to_string(),
            data_dir,
            metric: DistanceMetric::L2,
            dim: None,
            filterable: Vec::new(),
            id_field: None,
            vector_field: None,
            vector_name: None,
            encryption_key: None,
            insecure: true,
        }
    }

    /// Consumes one HTTP request from `stream`: reads until the header terminator has been
    /// seen and at least `Content-Length` body bytes (0 when the header is absent) have
    /// arrived, or until the peer closes the connection.
    fn read_request(stream: &mut TcpStream) {
        let mut buf = Vec::new();
        let mut chunk = [0u8; 2048];
        loop {
            let n = stream.read(&mut chunk).unwrap();
            if n == 0 {
                break;
            }
            buf.extend_from_slice(&chunk[..n]);
            if let Some(end) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
                let len = String::from_utf8_lossy(&buf[..end + 4])
                    .lines()
                    .find_map(|l| {
                        l.to_ascii_lowercase()
                            .strip_prefix("content-length:")
                            .and_then(|v| v.trim().parse::<usize>().ok())
                    })
                    .unwrap_or(0);
                if buf.len() - (end + 4) >= len {
                    break;
                }
            }
        }
    }

    /// Starts a loopback HTTP server on an OS-assigned port and returns its address.
    ///
    /// A background thread accepts exactly one connection per entry in `bodies`, in order,
    /// reads the request, and answers `200 OK` with that entry as the response body.
    fn spawn_mock_server(bodies: Vec<String>) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        std::thread::spawn(move || {
            for body in bodies {
                let (mut stream, _) = listener.accept().unwrap();
                read_request(&mut stream);
                let resp = format!(
                    "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                    body.len(),
                    body
                );
                stream.write_all(resp.as_bytes()).unwrap();
            }
        });
        addr
    }

    #[test]
    fn metric_and_filterable_parsing() {
        assert_eq!(parse_metric("Cosine").unwrap(), DistanceMetric::Cosine);
        assert_eq!(parse_metric("l2").unwrap(), DistanceMetric::L2);
        assert!(parse_metric("nope").is_err());

        let f = parse_filterable(&["user.age:numeric".to_string(), "city:keyword".to_string()])
            .unwrap();
        assert_eq!(f.len(), 2);
        assert_eq!(f[0].path, "user.age");
        assert_eq!(f[0].field_type, FieldType::Numeric);
        assert_eq!(f[1].field_type, FieldType::Keyword);

        assert!(parse_filterable(&["bad".to_string()]).is_err());
        assert!(parse_filterable(&["x:weird".to_string()]).is_err());
    }

    #[test]
    fn imports_qdrant_jsonl_into_a_local_db() {
        let dir = tempfile::tempdir().unwrap();
        let export = dir.path().join("qdrant.jsonl");
        let lines = [
            r#"{"id": 1, "vector": [1.0, 0.0, 0.0], "payload": {"city": "paris"}}"#,
            r#"{"id": 2, "vector": [0.0, 1.0, 0.0], "payload": {"city": "rome"}}"#,
        ];
        // One JSON object per line, each line newline-terminated.
        std::fs::write(&export, format!("{}\n", lines.join("\n"))).unwrap();

        let data_dir = dir.path().join("data");
        let args = ImportArgs {
            input: Some(export),
            collection: "places".to_string(),
            filterable: vec![FilterableField {
                path: "city".to_string(),
                field_type: FieldType::Keyword,
            }],
            ..base_args(Source::Qdrant, data_dir.clone())
        };
        assert_eq!(import(args).unwrap(), 2);

        let mut db = Database::open(&data_dir).unwrap();
        assert_eq!(db.len("places").unwrap(), 2);
        let res = db
            .search("places", &[1.0, 0.0, 0.0], &SearchParams::default())
            .unwrap();
        assert_eq!(res[0].id, "1");
    }

    #[test]
    fn encrypted_import_reopens_via_the_serve_path() {
        let dir = tempfile::tempdir().unwrap();
        let export = dir.path().join("q.jsonl");
        std::fs::write(
            &export,
            "{\"id\": 1, \"vector\": [1.0, 2.0], \"payload\": {}}\n",
        )
        .unwrap();

        let data_dir = dir.path().join("data");
        let key = "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff";
        let args = ImportArgs {
            input: Some(export),
            encryption_key: Some(key.to_string()),
            insecure: false,
            ..base_args(Source::Qdrant, data_dir.clone())
        };
        assert_eq!(import(args).unwrap(), 1);

        // Reopen with the same key through the shared helper.
        let db = open_database(&data_dir, Some(key), false).unwrap();
        assert_eq!(db.len("c").unwrap(), 1);
    }

    #[test]
    fn refuses_to_import_without_a_key_unless_insecure() {
        let dir = tempfile::tempdir().unwrap();
        let export = dir.path().join("q.jsonl");
        std::fs::write(&export, "{\"id\": 1, \"vector\": [1.0], \"payload\": {}}\n").unwrap();

        let args = ImportArgs {
            input: Some(export),
            insecure: false,
            ..base_args(Source::Qdrant, dir.path().join("data"))
        };
        assert!(import(args).is_err());
    }

    #[test]
    fn live_qdrant_import_loads_a_running_collection() {
        let body = r#"{"result":{"points":[{"id":1,"vector":[1.0,0.0],"payload":{"city":"paris"}}],"next_page_offset":null},"status":"ok"}"#;
        let addr = spawn_mock_server(vec![body.to_string()]);

        let dir = tempfile::tempdir().unwrap();
        let args = ImportArgs {
            qdrant_url: Some(format!("http://{addr}")),
            collection: "places".to_string(),
            ..base_args(Source::Qdrant, dir.path().join("data"))
        };
        assert_eq!(import(args).unwrap(), 1);

        let db = Database::open(&dir.path().join("data")).unwrap();
        assert_eq!(db.len("places").unwrap(), 1);
    }

    #[test]
    fn live_import_rejects_a_non_qdrant_source() {
        let dir = tempfile::tempdir().unwrap();
        let args = ImportArgs {
            qdrant_url: Some("http://localhost:6333".to_string()),
            ..base_args(Source::Chroma, dir.path().join("data"))
        };
        assert!(import(args).is_err());
    }

    #[test]
    fn live_chroma_import_loads_a_running_collection() {
        // Two responses, one connection each: first the collection listing, then the records.
        let addr = spawn_mock_server(vec![
            r#"[{"id":"col-1","name":"places"}]"#.to_string(),
            r#"{"ids":["1"],"embeddings":[[1.0,0.0]],"metadatas":[{"city":"paris"}],"documents":[null]}"#.to_string(),
        ]);

        let dir = tempfile::tempdir().unwrap();
        let args = ImportArgs {
            chroma_url: Some(format!("http://{addr}")),
            collection: "places".to_string(),
            metric: DistanceMetric::Cosine,
            ..base_args(Source::Chroma, dir.path().join("data"))
        };
        assert_eq!(import(args).unwrap(), 1);

        let db = Database::open(&dir.path().join("data")).unwrap();
        assert_eq!(db.len("places").unwrap(), 1);
    }

    #[test]
    fn live_chroma_import_rejects_a_non_chroma_source() {
        let dir = tempfile::tempdir().unwrap();
        let args = ImportArgs {
            chroma_url: Some("http://localhost:8000".to_string()),
            ..base_args(Source::Qdrant, dir.path().join("data"))
        };
        assert!(import(args).is_err());
    }

    #[test]
    fn live_postgres_import_rejects_a_non_pgvector_source() {
        let dir = tempfile::tempdir().unwrap();
        let args = ImportArgs {
            postgres_url: Some("postgresql://localhost/db".to_string()),
            ..base_args(Source::Chroma, dir.path().join("data"))
        };
        assert!(import(args).is_err());
    }
}