use anyhow::{anyhow, Result};
use async_trait::async_trait;
use smooth_operator::embedding::{Embedder, InputType};
/// Vector length produced by OpenAI's `text-embedding-3-small` model; this is the
/// dimension [`GatewayEmbedder::from_env`] configures for that model.
pub const OPENAI_SMALL_EMBEDDING_DIM: usize = 1_536;
/// Embedder that calls an OpenAI-compatible `/v1/embeddings` endpoint under a
/// configurable base URL.
///
/// `Debug` is implemented by hand (not derived) so the API key is never printed.
#[derive(Clone)]
pub struct GatewayEmbedder {
    /// HTTP client used for every embeddings request.
    client: reqwest::Client,
    /// Root URL of the gateway; `/v1/embeddings` is appended per request.
    base_url: String,
    /// Bearer token attached to each request.
    api_key: String,
    /// Model name forwarded in the request body.
    model: String,
    /// Expected length of each returned embedding; other lengths are rejected.
    dim: usize,
}

impl std::fmt::Debug for GatewayEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GatewayEmbedder")
            .field("base_url", &self.base_url)
            // Never expose the credential in logs or panic messages.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("dim", &self.dim)
            .finish_non_exhaustive()
    }
}
impl GatewayEmbedder {
    /// Creates an embedder for an OpenAI-compatible embeddings endpoint.
    ///
    /// * `base_url` — gateway root URL; a trailing `/` is tolerated.
    /// * `api_key` — sent as a bearer token on every request.
    /// * `model` — model name forwarded in the request body.
    /// * `dim` — expected embedding length; responses of any other length are rejected.
    #[must_use]
    pub fn new(
        base_url: impl Into<String>,
        api_key: impl Into<String>,
        model: impl Into<String>,
        dim: usize,
    ) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: base_url.into(),
            api_key: api_key.into(),
            model: model.into(),
            dim,
        }
    }

    /// Builds an embedder from `SMOOAI_GATEWAY_URL` and `SMOOAI_GATEWAY_KEY`, using the
    /// `text-embedding-3-small` model with [`OPENAI_SMALL_EMBEDDING_DIM`] dimensions.
    ///
    /// # Errors
    ///
    /// Returns an error if either variable is unset, is not valid Unicode, or is
    /// empty/whitespace-only. A blank value would otherwise fail only later, at
    /// request time.
    pub fn from_env() -> Result<Self> {
        let base_url = Self::required_env("SMOOAI_GATEWAY_URL")?;
        let api_key = Self::required_env("SMOOAI_GATEWAY_KEY")?;
        Ok(Self::new(
            base_url,
            api_key,
            "text-embedding-3-small",
            OPENAI_SMALL_EMBEDDING_DIM,
        ))
    }

    /// Reads a required environment variable and rejects blank values, with an
    /// error message specific to each failure mode.
    fn required_env(name: &str) -> Result<String> {
        match std::env::var(name) {
            Ok(value) if value.trim().is_empty() => Err(anyhow!("{name} is set but empty")),
            Ok(value) => Ok(value),
            Err(std::env::VarError::NotPresent) => Err(anyhow!("{name} is not set")),
            Err(std::env::VarError::NotUnicode(_)) => {
                Err(anyhow!("{name} is not valid unicode"))
            }
        }
    }
}
#[async_trait]
impl Embedder for GatewayEmbedder {
fn dim(&self) -> usize {
self.dim
}
async fn embed(&self, texts: &[String], _input_type: InputType) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
let body = serde_json::json!({ "model": self.model, "input": texts });
let resp = self
.client
.post(&url)
.bearer_auth(&self.api_key)
.json(&body)
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
return Err(anyhow!("embeddings request failed ({status}): {text}"));
}
#[derive(serde::Deserialize)]
struct EmbeddingData {
embedding: Vec<f32>,
index: usize,
}
#[derive(serde::Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
}
let mut parsed: EmbeddingResponse = resp.json().await?;
parsed.data.sort_by_key(|d| d.index);
let out: Vec<Vec<f32>> = parsed.data.into_iter().map(|d| d.embedding).collect();
if out.len() != texts.len() {
return Err(anyhow!(
"embeddings count mismatch: got {} for {} inputs",
out.len(),
texts.len()
));
}
for (i, v) in out.iter().enumerate() {
if v.len() != self.dim {
return Err(anyhow!(
"embedding {i} has dim {} but adapter expects {}",
v.len(),
self.dim
));
}
}
Ok(out)
}
}