redis_vl/vectorizers/
azure_openai.rs1use async_trait::async_trait;
6
7use super::{AsyncVectorizer, EmbeddingRequest, EmbeddingResponse, Vectorizer};
8use crate::error::Result;
9
10#[derive(Debug, Clone)]
12pub struct AzureOpenAIConfig {
13 pub azure_endpoint: url::Url,
15 pub api_key: String,
17 pub deployment: String,
19 pub api_version: String,
21}
22
23impl AzureOpenAIConfig {
24 pub fn new(
26 azure_endpoint: impl AsRef<str>,
27 api_key: impl Into<String>,
28 deployment: impl Into<String>,
29 api_version: impl Into<String>,
30 ) -> Result<Self> {
31 Ok(Self {
32 azure_endpoint: url::Url::parse(azure_endpoint.as_ref())?,
33 api_key: api_key.into(),
34 deployment: deployment.into(),
35 api_version: api_version.into(),
36 })
37 }
38
39 pub fn from_env(deployment: impl Into<String>) -> Result<Self> {
42 let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT").map_err(|_| {
43 crate::error::Error::InvalidInput("AZURE_OPENAI_ENDPOINT not set".into())
44 })?;
45 let api_key = std::env::var("AZURE_OPENAI_API_KEY").map_err(|_| {
46 crate::error::Error::InvalidInput("AZURE_OPENAI_API_KEY not set".into())
47 })?;
48 let api_version =
49 std::env::var("OPENAI_API_VERSION").unwrap_or_else(|_| "2024-02-01".to_string());
50 Self::new(endpoint, api_key, deployment, api_version)
51 }
52
53 fn embeddings_url(&self) -> Result<url::Url> {
54 let path = format!(
55 "openai/deployments/{}/embeddings?api-version={}",
56 self.deployment, self.api_version
57 );
58 Ok(self.azure_endpoint.join(&path)?)
59 }
60}
61
62#[derive(Debug, Clone)]
66pub struct AzureOpenAITextVectorizer {
67 config: AzureOpenAIConfig,
68 client: reqwest::Client,
69 blocking_client: reqwest::blocking::Client,
70}
71
72impl AzureOpenAITextVectorizer {
73 pub fn new(config: AzureOpenAIConfig) -> Self {
75 Self {
76 config,
77 client: reqwest::Client::new(),
78 blocking_client: reqwest::blocking::Client::new(),
79 }
80 }
81
82 async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
83 let response: EmbeddingResponse = self
84 .client
85 .post(self.config.embeddings_url()?)
86 .header("api-key", &self.config.api_key)
87 .json(&EmbeddingRequest {
88 model: &self.config.deployment,
89 input: texts.to_vec(),
90 })
91 .send()
92 .await?
93 .error_for_status()?
94 .json()
95 .await?;
96 Ok(response.data.into_iter().map(|d| d.embedding).collect())
97 }
98}
99
100impl Vectorizer for AzureOpenAITextVectorizer {
101 fn embed(&self, text: &str) -> Result<Vec<f32>> {
102 let response: EmbeddingResponse = self
103 .blocking_client
104 .post(self.config.embeddings_url()?)
105 .header("api-key", &self.config.api_key)
106 .json(&EmbeddingRequest {
107 model: &self.config.deployment,
108 input: vec![text],
109 })
110 .send()?
111 .error_for_status()?
112 .json()?;
113 Ok(response
114 .data
115 .into_iter()
116 .next()
117 .map_or_else(Vec::new, |d| d.embedding))
118 }
119}
120
121#[async_trait]
122impl AsyncVectorizer for AzureOpenAITextVectorizer {
123 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
124 let mut v = self.embed_many_inner(&[text]).await?;
125 Ok(v.pop().unwrap_or_default())
126 }
127
128 async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
129 self.embed_many_inner(texts).await
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// The generated URL carries both the deployment path and the API version.
    #[test]
    fn azure_config_builds_embeddings_url() {
        let endpoint = "https://myresource.openai.azure.com/";
        let cfg =
            AzureOpenAIConfig::new(endpoint, "test-key", "my-deployment", "2024-02-01").unwrap();

        let url = cfg.embeddings_url().unwrap();
        let rendered = url.as_str();

        assert!(
            rendered.contains("openai/deployments/my-deployment/embeddings"),
            "URL was: {url}"
        );
        assert!(rendered.contains("api-version=2024-02-01"), "URL was: {url}");
    }

    /// An endpoint that is not a URL is rejected at construction time.
    #[test]
    fn azure_config_rejects_bad_url() {
        let outcome = AzureOpenAIConfig::new("not a url", "key", "dep", "v1");
        assert!(outcome.is_err());
    }

    /// Compile-time check that the vectorizer can be shared across threads.
    #[test]
    fn azure_vectorizer_is_send_sync() {
        fn requires_send_sync<T: Send + Sync>() {}
        requires_send_sync::<AzureOpenAITextVectorizer>();
    }
}