use core::fmt;
use core::time::Duration;
use std::sync::Arc;
use crate::error::{Error, Result};
use crate::semantic::Embedding;
use crate::semantic::provider::EmbeddingProvider;
/// Model used when the caller does not override it via `with_model`.
pub const DEFAULT_MODEL: &str = "text-embedding-3-small";

/// Returns the embedding vector length for a given OpenAI model name.
///
/// Only `text-embedding-3-large` produces 3072-dimensional vectors;
/// `text-embedding-3-small`, `text-embedding-ada-002`, and any
/// unrecognized model name all report 1536.
fn dimension_for(model: &str) -> usize {
    if model == "text-embedding-3-large" {
        3072
    } else {
        // small, ada-002, and unknown models share the 1536 dimension.
        1536
    }
}
/// Blocking client for the OpenAI embeddings API.
///
/// Cheap to clone: all configuration and the HTTP client live behind a
/// shared [`Arc`], so clones reference the same state.
#[derive(Clone)]
pub struct OpenAiProvider {
    // Shared config + HTTP client; see `Inner` for the fields.
    inner: Arc<Inner>,
}
/// State shared by all clones of an [`OpenAiProvider`].
struct Inner {
    // Bearer token sent with every request; redacted in `Debug` output.
    api_key: String,
    // API root without a trailing slash, e.g. "https://api.openai.com/v1";
    // overridable via `with_base_url` (proxies, test servers).
    base_url: String,
    // Embedding model name; defaults to `DEFAULT_MODEL`.
    model: String,
    // Blocking HTTP client, built in `new` with a 30 s timeout.
    client: reqwest::blocking::Client,
}
impl OpenAiProvider {
pub fn new(api_key: impl Into<String>) -> Result<Self> {
let client = reqwest::blocking::Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| Error::Http(e.to_string()))?;
Ok(Self {
inner: Arc::new(Inner {
api_key: api_key.into(),
base_url: "https://api.openai.com/v1".to_string(),
model: DEFAULT_MODEL.to_string(),
client,
}),
})
}
#[must_use]
pub fn with_model(mut self, model: impl Into<String>) -> Self {
let prev = Arc::try_unwrap(self.inner).unwrap_or_else(|arc| {
Inner {
api_key: arc.api_key.clone(),
base_url: arc.base_url.clone(),
model: arc.model.clone(),
client: arc.client.clone(),
}
});
self.inner = Arc::new(Inner {
model: model.into(),
..prev
});
self
}
#[must_use]
pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
let prev = Arc::try_unwrap(self.inner).unwrap_or_else(|arc| Inner {
api_key: arc.api_key.clone(),
base_url: arc.base_url.clone(),
model: arc.model.clone(),
client: arc.client.clone(),
});
self.inner = Arc::new(Inner {
base_url: url.into(),
..prev
});
self
}
pub fn embed_batch(&self, inputs: &[&str]) -> Result<alloc::vec::Vec<Embedding>> {
if inputs.is_empty() {
return Ok(alloc::vec::Vec::new());
}
let url = alloc::format!("{}/embeddings", self.inner.base_url);
let body = serde_json::json!({
"model": self.inner.model,
"input": inputs,
});
let inner = self.inner.clone();
let url_owned = url;
let body_owned = body;
let resp = super::retry::send_with_retry(
&inner.client,
|| {
inner
.client
.post(&url_owned)
.bearer_auth(&inner.api_key)
.json(&body_owned)
},
"OpenAI",
)?;
let json: serde_json::Value = resp.json().map_err(|e| Error::Http(e.to_string()))?;
let data = json
.get("data")
.and_then(|v| v.as_array())
.ok_or_else(|| Error::Http("missing `data` array in response".into()))?;
if data.is_empty() {
return Err(Error::EmptyEmbedding);
}
let mut out = alloc::vec::Vec::with_capacity(data.len());
for item in data {
let vec_field = item
.get("embedding")
.and_then(|v| v.as_array())
.ok_or_else(|| Error::Http("missing `embedding` field".into()))?;
let vector: alloc::vec::Vec<f32> = vec_field
.iter()
.map(|v| v.as_f64().unwrap_or(0.0) as f32)
.collect();
out.push(Embedding::with_model(
vector,
Some(self.inner.model.clone()),
)?);
}
Ok(out)
}
}
impl fmt::Debug for OpenAiProvider {
    /// Renders the provider's configuration, redacting the API key so the
    /// secret can never leak through `{:?}` in logs or error messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let state = &self.inner;
        let mut builder = f.debug_struct("OpenAiProvider");
        builder.field("model", &state.model);
        builder.field("base_url", &state.base_url);
        builder.field("api_key", &"<redacted>");
        builder.finish()
    }
}
impl EmbeddingProvider for OpenAiProvider {
type Input = str;
fn embed(&self, input: &str) -> Result<Embedding> {
let mut batch = self.embed_batch(&[input])?;
batch.pop().ok_or(Error::EmptyEmbedding)
}
fn model_id(&self) -> &str {
&self.inner.model
}
fn dimension(&self) -> usize {
dimension_for(&self.inner.model)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn debug_redacts_api_key() {
        let provider = OpenAiProvider::new("sk-secret-do-not-leak").unwrap();
        let rendered = alloc::format!("{provider:?}");
        // The placeholder must appear and no fragment of the secret may.
        assert!(rendered.contains("<redacted>"));
        assert!(!rendered.contains("sk-secret"));
    }

    #[test]
    fn dimension_table_lookups() {
        // Known models resolve to their documented sizes; unknown names
        // fall back to 1536.
        for (model, expected) in [
            ("text-embedding-3-small", 1536),
            ("text-embedding-3-large", 3072),
            ("unknown-model", 1536),
        ] {
            assert_eq!(dimension_for(model), expected);
        }
    }

    #[test]
    fn with_model_changes_model_id() {
        let provider = OpenAiProvider::new("sk-test").unwrap();
        let provider = provider.with_model("text-embedding-3-large");
        assert_eq!(provider.model_id(), "text-embedding-3-large");
        assert_eq!(provider.dimension(), 3072);
    }
}