Skip to main content

redis_vl/vectorizers/
azure_openai.rs

1//! Azure OpenAI embedding adapter.
2//!
3//! Enabled by the `azure-openai` feature flag.
4
5use async_trait::async_trait;
6
7use super::{AsyncVectorizer, EmbeddingRequest, EmbeddingResponse, Vectorizer};
8use crate::error::Result;
9
10/// Configuration for connecting to an Azure OpenAI deployment.
11#[derive(Debug, Clone)]
12pub struct AzureOpenAIConfig {
13    /// Azure OpenAI resource endpoint (e.g. `https://myresource.openai.azure.com/`).
14    pub azure_endpoint: url::Url,
15    /// API key for authentication.
16    pub api_key: String,
17    /// Deployment name (not the model name).
18    pub deployment: String,
19    /// API version, e.g. `"2024-02-01"`.
20    pub api_version: String,
21}
22
23impl AzureOpenAIConfig {
24    /// Creates a new Azure OpenAI configuration.
25    pub fn new(
26        azure_endpoint: impl AsRef<str>,
27        api_key: impl Into<String>,
28        deployment: impl Into<String>,
29        api_version: impl Into<String>,
30    ) -> Result<Self> {
31        Ok(Self {
32            azure_endpoint: url::Url::parse(azure_endpoint.as_ref())?,
33            api_key: api_key.into(),
34            deployment: deployment.into(),
35            api_version: api_version.into(),
36        })
37    }
38
39    /// Constructs from environment variables:
40    /// `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_API_KEY`, `OPENAI_API_VERSION`.
41    pub fn from_env(deployment: impl Into<String>) -> Result<Self> {
42        let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT").map_err(|_| {
43            crate::error::Error::InvalidInput("AZURE_OPENAI_ENDPOINT not set".into())
44        })?;
45        let api_key = std::env::var("AZURE_OPENAI_API_KEY").map_err(|_| {
46            crate::error::Error::InvalidInput("AZURE_OPENAI_API_KEY not set".into())
47        })?;
48        let api_version =
49            std::env::var("OPENAI_API_VERSION").unwrap_or_else(|_| "2024-02-01".to_string());
50        Self::new(endpoint, api_key, deployment, api_version)
51    }
52
53    fn embeddings_url(&self) -> Result<url::Url> {
54        let path = format!(
55            "openai/deployments/{}/embeddings?api-version={}",
56            self.deployment, self.api_version
57        );
58        Ok(self.azure_endpoint.join(&path)?)
59    }
60}
61
/// Azure OpenAI embedding adapter.
///
/// Uses the Azure-specific endpoint format with `api-key` header authentication.
/// Holds one asynchronous and one blocking HTTP client so the same value can
/// serve both the synchronous and the asynchronous embedding entry points.
#[derive(Debug, Clone)]
pub struct AzureOpenAITextVectorizer {
    /// Endpoint, credentials, deployment name and API version for requests.
    config: AzureOpenAIConfig,
    /// HTTP client used by the asynchronous embedding path.
    client: reqwest::Client,
    /// HTTP client used by the synchronous embedding path.
    blocking_client: reqwest::blocking::Client,
}
71
72impl AzureOpenAITextVectorizer {
73    /// Creates a new Azure OpenAI adapter.
74    pub fn new(config: AzureOpenAIConfig) -> Self {
75        Self {
76            config,
77            client: reqwest::Client::new(),
78            blocking_client: reqwest::blocking::Client::new(),
79        }
80    }
81
82    async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
83        let response: EmbeddingResponse = self
84            .client
85            .post(self.config.embeddings_url()?)
86            .header("api-key", &self.config.api_key)
87            .json(&EmbeddingRequest {
88                model: &self.config.deployment,
89                input: texts.to_vec(),
90            })
91            .send()
92            .await?
93            .error_for_status()?
94            .json()
95            .await?;
96        Ok(response.data.into_iter().map(|d| d.embedding).collect())
97    }
98}
99
100impl Vectorizer for AzureOpenAITextVectorizer {
101    fn embed(&self, text: &str) -> Result<Vec<f32>> {
102        let response: EmbeddingResponse = self
103            .blocking_client
104            .post(self.config.embeddings_url()?)
105            .header("api-key", &self.config.api_key)
106            .json(&EmbeddingRequest {
107                model: &self.config.deployment,
108                input: vec![text],
109            })
110            .send()?
111            .error_for_status()?
112            .json()?;
113        Ok(response
114            .data
115            .into_iter()
116            .next()
117            .map_or_else(Vec::new, |d| d.embedding))
118    }
119}
120
121#[async_trait]
122impl AsyncVectorizer for AzureOpenAITextVectorizer {
123    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
124        let mut v = self.embed_many_inner(&[text]).await?;
125        Ok(v.pop().unwrap_or_default())
126    }
127
128    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
129        self.embed_many_inner(texts).await
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// The embeddings URL must contain the deployment path and the
    /// `api-version` query parameter.
    #[test]
    fn azure_config_builds_embeddings_url() {
        let url = AzureOpenAIConfig::new(
            "https://myresource.openai.azure.com/",
            "test-key",
            "my-deployment",
            "2024-02-01",
        )
        .unwrap()
        .embeddings_url()
        .unwrap();

        let rendered = url.as_str();
        let has_deployment_path = rendered.contains("openai/deployments/my-deployment/embeddings");
        let has_api_version = rendered.contains("api-version=2024-02-01");

        assert!(has_deployment_path, "URL was: {url}");
        assert!(has_api_version, "URL was: {url}");
    }

    /// An endpoint that is not a URL must be rejected at construction time.
    #[test]
    fn azure_config_rejects_bad_url() {
        assert!(AzureOpenAIConfig::new("not a url", "key", "dep", "v1").is_err());
    }

    /// Compile-time check that the adapter can be shared across threads.
    #[test]
    fn azure_vectorizer_is_send_sync() {
        fn require_send_sync<T>()
        where
            T: Send + Sync,
        {
        }
        require_send_sync::<AzureOpenAITextVectorizer>();
    }
}