use super::Embedder;
use crate::vcr::{VcrClient, VcrMode};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Embedding provider that calls the OpenAI `/v1/embeddings` HTTP endpoint.
pub struct OpenAIEmbedder {
/// Model name sent as the `model` field of every embeddings request.
pub model: String,
/// API key, sent as a `Bearer` token in the `Authorization` header.
pub api_key: String,
/// HTTP client used for direct requests when no VCR client is attached.
pub client: reqwest::Client,
/// Optional shared VCR client; when `Some`, requests are routed through it instead of `client`.
vcr: Option<Arc<Mutex<VcrClient>>>,
}
impl OpenAIEmbedder {
    /// Builds an embedder that sends requests directly with its own HTTP client (no VCR).
    pub fn new(model: String, api_key: String) -> Self {
        Self::assemble(model, api_key, None)
    }

    /// Builds an embedder whose requests are routed through the given shared VCR client.
    pub fn with_vcr(model: String, api_key: String, vcr: Arc<Mutex<VcrClient>>) -> Self {
        Self::assemble(model, api_key, Some(vcr))
    }

    /// Builds an embedder whose VCR usage is decided by `VcrMode::from_env()`.
    ///
    /// When that mode is `VcrMode::Off` this behaves like [`OpenAIEmbedder::new`];
    /// otherwise a `VcrClient::from_env()` instance is wrapped in `Arc<Mutex<_>>`
    /// and attached as with [`OpenAIEmbedder::with_vcr`].
    pub fn from_env(model: String, api_key: String) -> Self {
        if VcrMode::from_env() == VcrMode::Off {
            return Self::new(model, api_key);
        }
        let shared_vcr = Arc::new(Mutex::new(VcrClient::from_env()));
        Self::with_vcr(model, api_key, shared_vcr)
    }

    /// Common constructor: creates a fresh `reqwest::Client` and stores the optional VCR handle.
    fn assemble(model: String, api_key: String, vcr: Option<Arc<Mutex<VcrClient>>>) -> Self {
        Self {
            model,
            api_key,
            client: reqwest::Client::new(),
            vcr,
        }
    }
}
#[async_trait]
impl Embedder for OpenAIEmbedder {
    /// Requests an embedding for `text` from the OpenAI `/v1/embeddings` endpoint.
    ///
    /// When a VCR client is attached, the request is sent through it; otherwise the
    /// outbound network policy is checked and the request is sent with `self.client`.
    /// The embedding is read from `/data/0/embedding` of the JSON response.
    ///
    /// # Errors
    ///
    /// Returns an error if the outbound call is blocked by the network policy, the
    /// request fails, the API answers with a non-success status (the status and body
    /// are included in the message), the body is not valid JSON, `/data/0/embedding`
    /// is missing or not an array, or any element of that array is not a number.
    async fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
        let url = "https://api.openai.com/v1/embeddings";
        let body = json!({
            "input": text,
            "model": self.model,
            "encoding_format": "float"
        });
        // Both transport paths authenticate with the same bearer token.
        let auth = format!("Bearer {}", self.api_key);

        let json: serde_json::Value = if let Some(vcr) = &self.vcr {
            // NOTE(review): this branch does not call `check_outbound` itself; confirm
            // that `VcrClient::post_json` enforces the network policy when it records.
            let mut vcr_guard = vcr.lock().await;
            let resp = vcr_guard.post_json(url, &body, Some(&auth)).await?;
            if !resp.is_success() {
                anyhow::bail!(
                    "OpenAI embeddings API error (status {}): {}",
                    resp.status,
                    resp.body
                );
            }
            resp.body
        } else {
            crate::providers::network::check_outbound(url)?;
            let resp = self
                .client
                .post(url)
                .header("Authorization", auth)
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await?;
            // Capture the status before `text()` consumes the response, so the error
            // message carries it just like the VCR path does.
            let status = resp.status();
            if !status.is_success() {
                // Best effort: an unreadable error body should not mask the status.
                let error_text = resp.text().await.unwrap_or_default();
                anyhow::bail!(
                    "OpenAI embeddings API error (status {}): {}",
                    status,
                    error_text
                );
            }
            resp.json().await?
        };

        let values = json
            .pointer("/data/0/embedding")
            .and_then(|v| v.as_array())
            .ok_or_else(|| anyhow::anyhow!("OpenAI API response missing embedding field"))?;

        // Reject non-numeric entries instead of substituting 0.0: a silently zeroed
        // component would corrupt every similarity score computed from this vector.
        values
            .iter()
            .enumerate()
            .map(|(index, value)| {
                value.as_f64().map(|f| f as f32).ok_or_else(|| {
                    anyhow::anyhow!(
                        "OpenAI API response has non-numeric embedding value at index {}: {}",
                        index,
                        value
                    )
                })
            })
            .collect::<anyhow::Result<Vec<f32>>>()
    }

    /// Provider identifier for this embedder: always `"openai"`.
    fn name(&self) -> &'static str {
        "openai"
    }

    /// Model identifier: a copy of the configured `model` string.
    fn model_id(&self) -> String {
        self.model.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::network::{lock_test_serial_async, NetworkPolicyGuard};

    /// With a deny policy installed, `embed` (no VCR attached) must fail before any
    /// request leaves the process, and the error must name the blocked host.
    #[tokio::test]
    async fn openai_embedder_respects_network_deny_policy() {
        // Serialize with other tests that mutate the global network policy, then install
        // the deny policy; both guards are held until the end of the test.
        let _serial = lock_test_serial_async().await;
        let _guard = NetworkPolicyGuard::deny("unit test");

        let embedder = OpenAIEmbedder::new("text-embedding-3-small".into(), "test-key".into());
        let message = embedder
            .embed("hello")
            .await
            .expect_err("network deny policy should block outbound call")
            .to_string();

        for expected in ["outbound network blocked by policy", "api.openai.com"] {
            assert!(
                message.contains(expected),
                "error {:?} should mention {:?}",
                message,
                expected
            );
        }
    }
}