use async_trait::async_trait;
use super::{AsyncVectorizer, EmbeddingRequest, EmbeddingResponse, Vectorizer};
use crate::error::Result;
/// Connection settings for an Azure OpenAI embeddings deployment.
#[derive(Debug, Clone)]
pub struct AzureOpenAIConfig {
    /// Base resource endpoint, e.g. `https://<resource>.openai.azure.com/`.
    pub azure_endpoint: url::Url,
    /// API key sent in the `api-key` request header.
    pub api_key: String,
    /// Name of the model deployment whose embeddings endpoint is called.
    pub deployment: String,
    /// Azure OpenAI REST API version string, e.g. `2024-02-01`.
    pub api_version: String,
}
impl AzureOpenAIConfig {
pub fn new(
azure_endpoint: impl AsRef<str>,
api_key: impl Into<String>,
deployment: impl Into<String>,
api_version: impl Into<String>,
) -> Result<Self> {
Ok(Self {
azure_endpoint: url::Url::parse(azure_endpoint.as_ref())?,
api_key: api_key.into(),
deployment: deployment.into(),
api_version: api_version.into(),
})
}
pub fn from_env(deployment: impl Into<String>) -> Result<Self> {
let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT").map_err(|_| {
crate::error::Error::InvalidInput("AZURE_OPENAI_ENDPOINT not set".into())
})?;
let api_key = std::env::var("AZURE_OPENAI_API_KEY").map_err(|_| {
crate::error::Error::InvalidInput("AZURE_OPENAI_API_KEY not set".into())
})?;
let api_version =
std::env::var("OPENAI_API_VERSION").unwrap_or_else(|_| "2024-02-01".to_string());
Self::new(endpoint, api_key, deployment, api_version)
}
fn embeddings_url(&self) -> Result<url::Url> {
let path = format!(
"openai/deployments/{}/embeddings?api-version={}",
self.deployment, self.api_version
);
Ok(self.azure_endpoint.join(&path)?)
}
}
/// Text embedding client for Azure OpenAI, usable from both sync and async code.
#[derive(Debug, Clone)]
pub struct AzureOpenAITextVectorizer {
    config: AzureOpenAIConfig,
    // Async client, used by the `AsyncVectorizer` impl.
    client: reqwest::Client,
    // Blocking client, used by the sync `Vectorizer` impl.
    blocking_client: reqwest::blocking::Client,
}
impl AzureOpenAITextVectorizer {
    /// Creates a vectorizer with default async and blocking HTTP clients.
    pub fn new(config: AzureOpenAIConfig) -> Self {
        let client = reqwest::Client::new();
        let blocking_client = reqwest::blocking::Client::new();
        Self {
            config,
            client,
            blocking_client,
        }
    }

    /// Sends one embeddings request for `texts` and returns one vector per
    /// input, in the order the service reports them.
    async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let body = EmbeddingRequest {
            model: &self.config.deployment,
            input: texts.to_vec(),
        };
        let request = self
            .client
            .post(self.config.embeddings_url()?)
            .header("api-key", &self.config.api_key)
            .json(&body);
        let response = request.send().await?.error_for_status()?;
        let parsed: EmbeddingResponse = response.json().await?;
        Ok(parsed.data.into_iter().map(|item| item.embedding).collect())
    }
}
impl Vectorizer for AzureOpenAITextVectorizer {
    /// Embeds a single text synchronously; yields an empty vector when the
    /// service response contains no embedding data.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let body = EmbeddingRequest {
            model: &self.config.deployment,
            input: vec![text],
        };
        let response = self
            .blocking_client
            .post(self.config.embeddings_url()?)
            .header("api-key", &self.config.api_key)
            .json(&body)
            .send()?
            .error_for_status()?;
        let parsed: EmbeddingResponse = response.json()?;
        match parsed.data.into_iter().next() {
            Some(item) => Ok(item.embedding),
            None => Ok(Vec::new()),
        }
    }
}
#[async_trait]
impl AsyncVectorizer for AzureOpenAITextVectorizer {
    /// Embeds one text by delegating to the batch path; an empty service
    /// response yields an empty vector.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let mut embeddings = self.embed_many_inner(&[text]).await?;
        Ok(embeddings.pop().unwrap_or_default())
    }

    /// Embeds every text in `texts` with a single batched request.
    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        self.embed_many_inner(texts).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The generated URL must contain the deployment path and the API version.
    #[test]
    fn azure_config_builds_embeddings_url() {
        let config = AzureOpenAIConfig::new(
            "https://myresource.openai.azure.com/",
            "test-key",
            "my-deployment",
            "2024-02-01",
        )
        .unwrap();
        let url = config.embeddings_url().unwrap();
        let rendered = url.as_str();
        assert!(
            rendered.contains("openai/deployments/my-deployment/embeddings"),
            "URL was: {url}"
        );
        assert!(
            rendered.contains("api-version=2024-02-01"),
            "URL was: {url}"
        );
    }

    /// A string that is not a URL must be rejected at construction time.
    #[test]
    fn azure_config_rejects_bad_url() {
        assert!(AzureOpenAIConfig::new("not a url", "key", "dep", "v1").is_err());
    }

    /// The vectorizer must be usable across threads (compile-time check).
    #[test]
    fn azure_vectorizer_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<AzureOpenAITextVectorizer>();
    }
}