use rusqlite::{Connection, OptionalExtension, params};
use solo_core::{Error, Result};
/// Identity of an embedder as recorded in the `embedders` table.
///
/// [`get_or_insert_embedder_id`] looks rows up by `(name, version)` and
/// rejects a caller whose `dim` or `dtype` differs from the stored row.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbedderIdentity {
/// Embedder name; together with `version` it forms the lookup key.
pub name: String,
/// Embedder version; together with `name` it forms the lookup key.
pub version: String,
/// Embedding dimension reported by the embedder.
pub dim: u32,
/// Embedding dtype as a string. When built via `from_embedder` this is one
/// of `"f32"`, `"f16"`, `"i8"` or `"binary"` (see `dtype_to_str`).
pub dtype: String,
}
impl EmbedderIdentity {
pub fn from_embedder(e: &dyn solo_core::Embedder) -> Self {
Self {
name: e.name().to_string(),
version: e.version().to_string(),
dim: e.dim() as u32,
dtype: dtype_to_str(e.dtype()).to_string(),
}
}
}
/// Returns the short lowercase string used to represent an embedding dtype.
///
/// `EmbedderIdentity::from_embedder` stores this string in the identity's
/// `dtype` field.
fn dtype_to_str(d: solo_core::EmbeddingDtype) -> &'static str {
    // No wildcard arm: a newly added variant must fail to compile here
    // rather than silently map to some default string.
    match d {
        solo_core::EmbeddingDtype::Binary => "binary",
        solo_core::EmbeddingDtype::I8 => "i8",
        solo_core::EmbeddingDtype::F16 => "f16",
        solo_core::EmbeddingDtype::F32 => "f32",
    }
}
pub fn get_or_insert_embedder_id(
conn: &Connection,
identity: &EmbedderIdentity,
) -> Result<i64> {
let existing: Option<(i64, u32, String)> = conn
.query_row(
"SELECT embedder_id, dim, dtype
FROM embedders
WHERE name = ? AND version = ?",
params![&identity.name, &identity.version],
|r| Ok((r.get::<_, i64>(0)?, r.get::<_, u32>(1)?, r.get::<_, String>(2)?)),
)
.optional()
.map_err(|e| Error::storage(format!("lookup embedder_id: {e}")))?;
if let Some((id, dim, dtype)) = existing {
if dim != identity.dim || dtype != identity.dtype {
return Err(Error::conflict(format!(
"embedder ({}, {}) already registered with dim={dim}/dtype={dtype}; \
caller provided dim={}/dtype={}. Bump the embedder version + run \
`solo reembed` to regenerate vectors.",
identity.name, identity.version, identity.dim, identity.dtype
)));
}
return Ok(id);
}
let now_ms = chrono::Utc::now().timestamp_millis();
conn.execute(
"INSERT INTO embedders (name, version, dim, dtype, first_seen_ms)
VALUES (?, ?, ?, ?, ?)",
params![
&identity.name,
&identity.version,
identity.dim,
&identity.dtype,
now_ms,
],
)
.map_err(|e| Error::storage(format!("INSERT embedders row: {e}")))?;
let id = conn.last_insert_rowid();
tracing::info!(
embedder_id = id,
name = %identity.name,
version = %identity.version,
dim = identity.dim,
dtype = %identity.dtype,
"registered embedder in `embedders` table"
);
Ok(id)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migration::run_migrations;
    use rusqlite::Connection;

    /// Opens an in-memory database and runs `run_migrations` on it.
    fn fresh_conn() -> Connection {
        let mut conn = Connection::open_in_memory().unwrap();
        run_migrations(&mut conn).unwrap();
        conn
    }

    /// Baseline identity shared by every test; variants override one field.
    fn id() -> EmbedderIdentity {
        EmbedderIdentity {
            name: "test-embedder".into(),
            version: "v1".into(),
            dim: 1024,
            dtype: "f32".into(),
        }
    }

    /// Number of rows currently in the `embedders` table.
    fn embedder_row_count(conn: &Connection) -> i64 {
        conn.query_row("SELECT COUNT(*) FROM embedders", [], |r| r.get(0))
            .unwrap()
    }

    #[test]
    fn first_call_inserts_row_returns_id() {
        let conn = fresh_conn();
        let id1 = get_or_insert_embedder_id(&conn, &id()).unwrap();
        assert!(id1 > 0);
        let n = embedder_row_count(&conn);
        assert_eq!(n, 1);
    }

    #[test]
    fn second_call_with_same_identity_returns_same_id() {
        let conn = fresh_conn();
        let id1 = get_or_insert_embedder_id(&conn, &id()).unwrap();
        let id2 = get_or_insert_embedder_id(&conn, &id()).unwrap();
        assert_eq!(id1, id2);
        let n = embedder_row_count(&conn);
        assert_eq!(n, 1, "must NOT insert a duplicate row");
    }

    #[test]
    fn different_version_inserts_new_row() {
        let conn = fresh_conn();
        let id_v1 = get_or_insert_embedder_id(&conn, &id()).unwrap();
        let v2_identity = EmbedderIdentity {
            version: "v2".into(),
            ..id()
        };
        let id_v2 = get_or_insert_embedder_id(&conn, &v2_identity).unwrap();
        assert_ne!(id_v1, id_v2);
    }

    #[test]
    fn dim_mismatch_for_same_identity_rejected() {
        let conn = fresh_conn();
        get_or_insert_embedder_id(&conn, &id()).unwrap();
        let bad = EmbedderIdentity {
            dim: 2048,
            ..id()
        };
        let err = get_or_insert_embedder_id(&conn, &bad).unwrap_err();
        assert!(matches!(err, Error::Conflict(_)), "got: {err:?}");
    }

    #[test]
    fn dtype_mismatch_for_same_identity_rejected() {
        let conn = fresh_conn();
        get_or_insert_embedder_id(&conn, &id()).unwrap();
        let bad = EmbedderIdentity {
            dtype: "f16".into(),
            ..id()
        };
        let err = get_or_insert_embedder_id(&conn, &bad).unwrap_err();
        assert!(matches!(err, Error::Conflict(_)), "got: {err:?}");
    }
}