#![allow(clippy::doc_markdown)]
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::config::OpenAiConfig;
use crate::embedder::Embedder;
use crate::error::EmbedError;
use crate::http::{classify_ureq_error, decode_json};
use crate::manifest::EmbedderManifest;
/// Maximum number of input texts sent in a single embeddings request.
/// `embed_batch` splits longer inputs into chunks of at most this size.
const MAX_BATCH: usize = 96;
/// Model names paired with the embedding vector length assumed for each.
///
/// Consulted when the configuration does not supply an explicit dimension.
const KNOWN_MODELS: &[(&str, u32)] = &[
    ("text-embedding-3-small", 1536),
    ("text-embedding-3-large", 3072),
    ("text-embedding-ada-002", 1536),
    ("BAAI/bge-large-en-v1.5", 1024),
    ("BAAI/bge-base-en-v1.5", 768),
    ("BAAI/bge-small-en-v1.5", 384),
    ("BAAI/bge-m3", 1024),
    ("Qwen/Qwen3-Embedding-0.6B", 1024),
    ("Qwen/Qwen3-Embedding-4B", 2560),
    ("Qwen/Qwen3-Embedding-8B", 4096),
    ("mixedbread-ai/mxbai-embed-large-v1", 1024),
];

/// Returns the vector length registered for `model` in [`KNOWN_MODELS`].
///
/// The comparison is an exact, case-sensitive string match; `None` means the
/// model is not in the table.
fn known_dim(model: &str) -> Option<u32> {
    KNOWN_MODELS
        .iter()
        .find(|(name, _)| *name == model)
        .map(|&(_, dim)| dim)
}
/// Embedder that obtains vectors from an HTTP embeddings endpoint.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written to logs or panic messages.
pub struct OpenAiEmbedder {
    /// Model name exactly as configured; sent as the `model` field of each request.
    model_bare: String,
    /// Model name prefixed with `openai:`; this is what `Embedder::model` reports.
    model_fq: String,
    /// Expected length of every returned embedding vector.
    dim: u32,
    /// Secret sent as the bearer token in the `Authorization` header.
    api_key: String,
    /// Full URL that requests are posted to.
    endpoint: String,
    /// HTTP agent carrying the configured timeout.
    agent: ureq::Agent,
}

impl std::fmt::Debug for OpenAiEmbedder {
    /// Formats every field except `api_key`, which is replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiEmbedder")
            .field("model_bare", &self.model_bare)
            .field("model_fq", &self.model_fq)
            .field("dim", &self.dim)
            // Never print the credential itself.
            .field("api_key", &"<redacted>")
            .field("endpoint", &self.endpoint)
            .field("agent", &self.agent)
            .finish()
    }
}
impl OpenAiEmbedder {
    /// Builds an embedder from `config`.
    ///
    /// * The API key is read from the environment variable named by
    ///   `config.api_key_env`.
    /// * The vector dimension is `config.dim_override` when set, otherwise it is
    ///   looked up in `KNOWN_MODELS`.
    /// * Requests are posted to `{base_url}/v1/embeddings`, with any trailing
    ///   `/` characters on `base_url` trimmed first.
    /// * The HTTP agent is configured with a timeout of `config.timeout_secs`
    ///   seconds.
    ///
    /// # Errors
    ///
    /// * [`EmbedError::MissingApiKey`] if the environment variable cannot be read.
    /// * [`EmbedError::Config`] if the model is not in `KNOWN_MODELS` and no
    ///   `dim_override` is given.
    pub fn from_config(config: &OpenAiConfig) -> Result<Self, EmbedError> {
        let api_key =
            std::env::var(&config.api_key_env).map_err(|_| EmbedError::MissingApiKey {
                var: config.api_key_env.clone(),
            })?;
        // An explicit override always wins, which lets callers use models that
        // are not in the built-in table.
        let dim = match config.dim_override {
            Some(d) => d,
            None => known_dim(&config.model).ok_or_else(|| {
                EmbedError::Config(format!(
                    "unknown OpenAI embedding model '{}'; expected one of {:?}, \
                     or set `dim_override` in the config to pass through unknown models",
                    config.model,
                    KNOWN_MODELS.iter().map(|(m, _)| *m).collect::<Vec<_>>(),
                ))
            })?,
        };
        let endpoint = format!("{}/v1/embeddings", config.base_url.trim_end_matches('/'));
        let agent = ureq::AgentBuilder::new()
            .timeout(Duration::from_secs(config.timeout_secs))
            .build();
        Ok(Self {
            model_bare: config.model.clone(),
            model_fq: format!("openai:{}", config.model),
            dim,
            api_key,
            endpoint,
            agent,
        })
    }

    /// Sends one embeddings request for `texts` and returns one vector per
    /// input, ordered to match `texts`.
    ///
    /// The response items are reordered by their `index` field. The indices
    /// must be exactly `0..texts.len()`, each appearing once; anything else is
    /// rejected so that a vector can never be paired with the wrong input.
    ///
    /// # Errors
    ///
    /// * Whatever `classify_ureq_error` produces for a failed HTTP exchange.
    /// * Whatever `decode_json` produces for an undecodable body.
    /// * [`EmbedError::Decode`] if the number of returned items differs from the
    ///   number of inputs, or if the returned indices are duplicated or out of
    ///   range.
    /// * [`EmbedError::DimMismatch`] if any vector's length differs from the
    ///   configured dimension.
    fn post_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedError> {
        #[derive(Serialize)]
        struct Req<'a> {
            model: &'a str,
            input: &'a [&'a str],
        }
        #[derive(Deserialize)]
        struct Resp {
            data: Vec<Datum>,
        }
        #[derive(Deserialize)]
        struct Datum {
            embedding: Vec<f32>,
            index: usize,
        }

        let body = Req {
            model: &self.model_bare,
            input: texts,
        };
        let resp = self
            .agent
            .post(&self.endpoint)
            .set("Authorization", &format!("Bearer {}", self.api_key))
            .set("Content-Type", "application/json")
            .send_json(&body)
            .map_err(classify_ureq_error)?;
        let parsed: Resp = decode_json(resp)?;

        if parsed.data.len() != texts.len() {
            return Err(EmbedError::Decode(format!(
                "OpenAI returned {} embeddings for {} inputs",
                parsed.data.len(),
                texts.len(),
            )));
        }

        let mut data = parsed.data;
        // Valid indices are unique, so an unstable sort yields the same order as
        // a stable one; invalid index sets are rejected in the loop below.
        data.sort_unstable_by_key(|d| d.index);

        let mut out = Vec::with_capacity(data.len());
        for (position, d) in data.into_iter().enumerate() {
            // After sorting, a correct response has index == position for every
            // item. A mismatch means a duplicated, missing, or out-of-range index.
            if d.index != position {
                return Err(EmbedError::Decode(format!(
                    "OpenAI returned embedding index {} where {} was expected; \
                     indices must cover 0..{} exactly once",
                    d.index,
                    position,
                    texts.len(),
                )));
            }
            // Saturate instead of truncating so an oversized vector can never
            // wrap around to a length that happens to equal `self.dim`.
            let got = u32::try_from(d.embedding.len()).unwrap_or(u32::MAX);
            if got != self.dim {
                return Err(EmbedError::DimMismatch {
                    expected: self.dim,
                    got,
                });
            }
            out.push(d.embedding);
        }
        Ok(out)
    }
}
impl Embedder for OpenAiEmbedder {
    /// Returns the model name with its `openai:` prefix.
    fn model(&self) -> &str {
        self.model_fq.as_str()
    }

    /// Returns the length every produced vector is required to have.
    fn dim(&self) -> u32 {
        self.dim
    }

    /// Embeds a single text by issuing a one-element batch request.
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedError> {
        let single = [text];
        self.post_batch(&single).map(|mut vectors| {
            // `post_batch` rejects responses whose item count differs from the
            // input count, so exactly one vector is present here.
            vectors.pop().expect("post_batch returned 1 for 1 input")
        })
    }

    /// Embeds `texts` in order, issuing one request per chunk of at most
    /// `MAX_BATCH` inputs and stopping at the first failing request.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedError> {
        texts
            .chunks(MAX_BATCH)
            .try_fold(Vec::with_capacity(texts.len()), |mut gathered, chunk| {
                gathered.extend(self.post_batch(chunk)?);
                Ok(gathered)
            })
    }

    /// Builds the manifest describing this embedder.
    fn manifest(&self) -> EmbedderManifest {
        let model = self.model_fq.clone();
        // NOTE(review): the meaning of the literal 0.27 is defined by
        // `EmbedderManifest::new` and is not visible from this file — confirm.
        EmbedderManifest::new(model, self.dim, 0.27)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `from_config` must fail with `MissingApiKey` naming the variable it
    /// tried to read when that variable is not set.
    #[test]
    fn missing_api_key_is_surfaced() {
        // The random suffix makes an accidental collision with a real
        // environment variable implausible.
        let unset_var = "MNEM_TEST_OPENAI_KEY_NEVER_SET_bc6a78c1fd3e4a9f";
        let cfg = OpenAiConfig {
            model: "text-embedding-3-small".into(),
            api_key_env: unset_var.into(),
            ..Default::default()
        };
        match OpenAiEmbedder::from_config(&cfg).unwrap_err() {
            EmbedError::MissingApiKey { var } => assert_eq!(var, unset_var),
            other => panic!("expected MissingApiKey, got {other:?}"),
        }
    }

    /// The lookup table returns the registered dimension for known models and
    /// `None` for a name that is not registered.
    #[test]
    fn known_dim_maps_shipped_models() {
        let cases = [
            ("text-embedding-3-small", Some(1536)),
            ("text-embedding-3-large", Some(3072)),
            ("text-embedding-ada-002", Some(1536)),
            ("text-embedding-3-superhuge", None),
        ];
        for (model, expected) in cases {
            assert_eq!(known_dim(model), expected);
        }
    }
}