1use std::fmt::Debug;
13
14use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
15#[cfg(feature = "image")]
16use crate::client::Nothing;
17use crate::client::{
18 self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
19};
20use crate::completion::GetTokenUsage;
21use crate::http_client::{self, HttpClientExt, bearer_auth_header};
22use crate::streaming::StreamingCompletionResponse;
23use crate::transcription::TranscriptionError;
24use crate::{
25 completion::{self, CompletionError, CompletionRequest},
26 embeddings::{self, EmbeddingError},
27 json_utils,
28 providers::openai,
29 telemetry::SpanCombinator,
30 transcription::{self},
31};
32use bytes::Bytes;
33use reqwest::multipart::Part;
34use serde::{Deserialize, Serialize};
35use serde_json::json;
/// Default Azure OpenAI data-plane `api-version` used when the builder is not
/// given an explicit one via `api_version`.
const DEFAULT_API_VERSION: &str = "2024-10-21";

/// Provider extension carrying the Azure-specific connection settings.
#[derive(Debug, Clone)]
pub struct AzureExt {
    // Base URL of the Azure OpenAI resource,
    // e.g. `https://{resource}.openai.azure.com`.
    endpoint: String,
    // Value sent as the `api-version` query parameter on every request.
    api_version: String,
}
47
48impl DebugExt for AzureExt {
49 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
50 [
51 ("endpoint", (&self.endpoint as &dyn Debug)),
52 ("api_version", (&self.api_version as &dyn Debug)),
53 ]
54 .into_iter()
55 }
56}
57
/// Builder-side counterpart of [`AzureExt`]: the endpoint starts out unset and
/// must be supplied (see `azure_endpoint`) before the client can be built.
#[derive(Debug, Clone)]
pub struct AzureExtBuilder {
    // `None` until the caller supplies an endpoint; `Provider::build` fails
    // while this is `None`.
    endpoint: Option<String>,
    // Defaults to `DEFAULT_API_VERSION`; overridable via `api_version`.
    api_version: String,
}
67
68impl Default for AzureExtBuilder {
69 fn default() -> Self {
70 Self {
71 endpoint: None,
72 api_version: DEFAULT_API_VERSION.into(),
73 }
74 }
75}
76
/// Convenience aliases for an Azure client/builder over any HTTP backend,
/// defaulting to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;
80
81impl Provider for AzureExt {
82 type Builder = AzureExtBuilder;
83
84 const VERIFY_PATH: &'static str = "";
86
87 fn build<H>(
88 builder: &client::ClientBuilder<
89 Self::Builder,
90 <Self::Builder as ProviderBuilder>::ApiKey,
91 H,
92 >,
93 ) -> http_client::Result<Self> {
94 let AzureExtBuilder {
95 endpoint,
96 api_version,
97 ..
98 } = builder.ext().clone();
99
100 match endpoint {
101 Some(endpoint) => Ok(Self {
102 endpoint,
103 api_version,
104 }),
105 None => Err(http_client::Error::Instance(
106 "Azure client must be provided an endpoint prior to building".into(),
107 )),
108 }
109 }
110}
111
/// Capability map for the Azure provider: completion, embeddings and
/// transcription are always available; audio generation sits behind the
/// `audio` feature.
// NOTE(review): an `ImageGenerationModel` exists below, yet image generation
// is declared `Nothing` (unsupported) here — confirm this is intentional.
impl<H> Capabilities<H> for AzureExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
121
impl ProviderBuilder for AzureExtBuilder {
    type Output = AzureExt;
    type ApiKey = AzureOpenAIAuth;

    // Azure URLs are derived from the user-supplied endpoint, so there is no
    // fixed base URL for this provider.
    const BASE_URL: &'static str = "";

    /// Installs the authentication header on the underlying HTTP builder:
    /// a standard `Bearer` header for AAD tokens, or Azure's non-standard
    /// `api-key` header for API keys.
    fn finish<H>(
        &self,
        mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
        use AzureOpenAIAuth::*;

        let auth = builder.get_api_key().clone();

        match auth {
            Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
            ApiKey(key) => {
                let k = http::HeaderName::from_static("api-key");
                // `from_str` fails on non-ASCII / control characters; the
                // error is propagated rather than panicking.
                let v = http::HeaderValue::from_str(key.as_str())?;

                builder.headers_mut().insert(k, v);
            }
        }

        Ok(builder)
    }
}
149
150impl<H> ClientBuilder<H> {
151 pub fn api_version(mut self, api_version: &str) -> Self {
153 self.ext_mut().api_version = api_version.into();
154
155 self
156 }
157}
158
159impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
160 pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
162 self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
163 endpoint: Some(endpoint),
164 api_version,
165 })
166 }
167}
168
/// Authentication material accepted by the Azure OpenAI provider: either a
/// static API key (sent via the `api-key` header) or a bearer token (sent via
/// the `Authorization` header).
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    ApiKey(String),
    Token(String),
}

impl ApiKey for AzureOpenAIAuth {}
176
177impl std::fmt::Debug for AzureOpenAIAuth {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 match self {
180 Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
181 Self::Token(_) => write!(f, "Token <REDACTED>"),
182 }
183 }
184}
185
186impl<S> From<S> for AzureOpenAIAuth
187where
188 S: Into<String>,
189{
190 fn from(token: S) -> Self {
191 AzureOpenAIAuth::Token(token.into())
192 }
193}
194
195impl<T> Client<T>
196where
197 T: HttpClientExt,
198{
199 fn endpoint(&self) -> &str {
200 &self.ext().endpoint
201 }
202
203 fn api_version(&self) -> &str {
204 &self.ext().api_version
205 }
206
207 fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
208 let url = format!(
209 "{}/openai/deployments/{}/embeddings?api-version={}",
210 self.endpoint(),
211 deployment_id.trim_start_matches('/'),
212 self.api_version()
213 );
214
215 self.post(&url)
216 }
217
218 #[cfg(feature = "audio")]
219 fn post_audio_generation(
220 &self,
221 deployment_id: &str,
222 ) -> http_client::Result<http_client::Builder> {
223 let url = format!(
224 "{}/openai/deployments/{}/audio/speech?api-version={}",
225 self.endpoint(),
226 deployment_id.trim_start_matches('/'),
227 self.api_version()
228 );
229
230 self.post(url)
231 }
232
233 fn post_chat_completion(
234 &self,
235 deployment_id: &str,
236 ) -> http_client::Result<http_client::Builder> {
237 let url = format!(
238 "{}/openai/deployments/{}/chat/completions?api-version={}",
239 self.endpoint(),
240 deployment_id.trim_start_matches('/'),
241 self.api_version()
242 );
243
244 self.post(&url)
245 }
246
247 fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
248 let url = format!(
249 "{}/openai/deployments/{}/audio/translations?api-version={}",
250 self.endpoint(),
251 deployment_id.trim_start_matches('/'),
252 self.api_version()
253 );
254
255 self.post(&url)
256 }
257
258 #[cfg(feature = "image")]
259 fn post_image_generation(
260 &self,
261 deployment_id: &str,
262 ) -> http_client::Result<http_client::Builder> {
263 let url = format!(
264 "{}/openai/deployments/{}/images/generations?api-version={}",
265 self.endpoint(),
266 deployment_id.trim_start_matches('/'),
267 self.api_version()
268 );
269
270 self.post(&url)
271 }
272}
273
/// Parameters accepted by `ProviderClient::from_val` for the Azure client.
pub struct AzureOpenAIClientParams {
    // Azure OpenAI API key (sent via the `api-key` header).
    api_key: String,
    // Value for the `api-version` query parameter.
    version: String,
    // NOTE(review): despite the name, this field is passed to
    // `azure_endpoint` in `from_val`, i.e. it is the resource endpoint URL —
    // confirm the field name against callers.
    header: String,
}
279
280impl ProviderClient for Client {
281 type Input = AzureOpenAIClientParams;
282
283 fn from_env() -> Self {
285 let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
286 AzureOpenAIAuth::ApiKey(api_key)
287 } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
288 AzureOpenAIAuth::Token(token)
289 } else {
290 panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
291 };
292
293 let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
294 let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
295
296 Self::builder()
297 .api_key(auth)
298 .azure_endpoint(azure_endpoint)
299 .api_version(&api_version)
300 .build()
301 .unwrap()
302 }
303
304 fn from_val(
305 AzureOpenAIClientParams {
306 api_key,
307 version,
308 header,
309 }: Self::Input,
310 ) -> Self {
311 let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
312
313 Self::builder()
314 .api_key(auth)
315 .azure_endpoint(header)
316 .api_version(&version)
317 .build()
318 .unwrap()
319 }
320}
321
/// Error payload shape returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged wrapper: a body deserializes as the success payload `T` when it
/// matches, otherwise as the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
333
/// `text-embedding-3-large` embedding model (3072 dimensions).
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model (1536 dimensions).
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model (1536 dimensions).
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

/// Returns the native embedding dimension for a known model identifier, or
/// `None` for deployments the library does not recognize.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    if identifier == TEXT_EMBEDDING_3_LARGE {
        Some(3_072)
    } else if identifier == TEXT_EMBEDDING_3_SMALL || identifier == TEXT_EMBEDDING_ADA_002 {
        Some(1_536)
    } else {
        None
    }
}
352
/// Azure OpenAI embeddings response payload.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    /// One entry per input document, in request order.
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
360
361impl From<ApiErrorResponse> for EmbeddingError {
362 fn from(err: ApiErrorResponse) -> Self {
363 EmbeddingError::ProviderError(err.message)
364 }
365}
366
367impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
368 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
369 match value {
370 ApiResponse::Ok(response) => Ok(response),
371 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
372 }
373 }
374}
375
/// A single embedding vector and its position in the request batch.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

/// Token accounting returned alongside Azure responses.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
388
389impl GetTokenUsage for Usage {
390 fn token_usage(&self) -> Option<crate::completion::Usage> {
391 let mut usage = crate::completion::Usage::new();
392
393 usage.input_tokens = self.prompt_tokens as u64;
394 usage.total_tokens = self.total_tokens as u64;
395 usage.output_tokens = usage.total_tokens - usage.input_tokens;
396
397 Some(usage)
398 }
399}
400
401impl std::fmt::Display for Usage {
402 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403 write!(
404 f,
405 "Prompt tokens: {} Total tokens: {}",
406 self.prompt_tokens, self.total_tokens
407 )
408 }
409}
410
/// Embedding model bound to an Azure OpenAI deployment.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    /// Azure deployment (model) name used to build the request URL.
    pub model: String,
    // Reported vector dimension; `0` when unknown.
    ndims: usize,
}
417
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds a batch of documents with the configured Azure deployment.
    ///
    /// # Errors
    /// Fails when the request cannot be built or sent, the response status is
    /// not success, the payload fails to deserialize, or the API returns a
    /// different number of embeddings than documents submitted.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        // The deployment (model) is addressed in the URL, so the body only
        // carries the inputs.
        let body = serde_json::to_vec(&json!({
            "input": documents,
        }))?;

        let req = self
            .client
            .post_embedding(self.model.as_str())?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    // Guard against a short response so the zip below cannot
                    // silently drop documents.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    // Pair each returned vector with its source document.
                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}
488
489impl<T> EmbeddingModel<T> {
490 pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
491 let model = model.into();
492 let ndims = ndims
493 .or(model_dimensions_from_identifier(&model))
494 .unwrap_or_default();
495
496 Self {
497 client,
498 model,
499 ndims,
500 }
501 }
502
503 pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
504 let ndims = ndims.unwrap_or_default();
505
506 Self {
507 client,
508 model: model.into(),
509 ndims,
510 }
511 }
512}
513
// Azure OpenAI deployment/model identifiers.
pub const O1: &str = "o1";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_MINI: &str = "o1-mini";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
// NOTE(review): `GPT_4_TURBO` resolves to "gpt-4", identical to `GPT_4` —
// confirm whether this should be "gpt-4-turbo".
pub const GPT_4_TURBO: &str = "gpt-4";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_32K: &str = "gpt-4-32k";
// NOTE(review): `GPT_4_32K_0613` resolves to "gpt-4-32k", identical to
// `GPT_4_32K` — confirm whether it should carry the "-0613" suffix.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
544
545#[derive(Debug, Serialize, Deserialize)]
546pub(super) struct AzureOpenAICompletionRequest {
547 model: String,
548 pub messages: Vec<openai::Message>,
549 #[serde(flatten, skip_serializing_if = "Option::is_none")]
550 temperature: Option<f64>,
551 #[serde(skip_serializing_if = "Vec::is_empty")]
552 tools: Vec<openai::ToolDefinition>,
553 #[serde(flatten, skip_serializing_if = "Option::is_none")]
554 tool_choice: Option<crate::providers::openrouter::ToolChoice>,
555 #[serde(flatten, skip_serializing_if = "Option::is_none")]
556 pub additional_params: Option<serde_json::Value>,
557}
558
impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
    type Error = CompletionError;

    /// Lowers a generic [`CompletionRequest`] into the Azure OpenAI wire
    /// format for the deployment named by `model`.
    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.tool_choice.is_some() {
            tracing::warn!(
                "Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25."
            );
        }

        // Message order: system preamble, then normalized documents, then
        // the chat history.
        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        // Each history entry may expand into several OpenAI messages, hence
        // the nested Vec followed by flatten.
        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        // NOTE(review): tool_choice is converted and forwarded despite the
        // "not supported" warning above — confirm the intended behavior.
        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openrouter::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}
613
/// Chat-completion model bound to an Azure OpenAI deployment.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Azure deployment (model) name used to build the request URL.
    pub model: String,
}
620
621impl<T> CompletionModel<T> {
622 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
623 Self {
624 client,
625 model: model.into(),
626 }
627 }
628}
629
630impl<T> completion::CompletionModel for CompletionModel<T>
631where
632 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
633{
634 type Response = openai::CompletionResponse;
635 type StreamingResponse = openai::StreamingCompletionResponse;
636 type Client = Client<T>;
637
638 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
639 Self::new(client.clone(), model.into())
640 }
641
642 #[cfg_attr(feature = "worker", worker::send)]
643 async fn completion(
644 &self,
645 completion_request: CompletionRequest,
646 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
647 let span = if tracing::Span::current().is_disabled() {
648 info_span!(
649 target: "rig::completions",
650 "chat",
651 gen_ai.operation.name = "chat",
652 gen_ai.provider.name = "azure.openai",
653 gen_ai.request.model = self.model,
654 gen_ai.system_instructions = &completion_request.preamble,
655 gen_ai.response.id = tracing::field::Empty,
656 gen_ai.response.model = tracing::field::Empty,
657 gen_ai.usage.output_tokens = tracing::field::Empty,
658 gen_ai.usage.input_tokens = tracing::field::Empty,
659 )
660 } else {
661 tracing::Span::current()
662 };
663
664 let request =
665 AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
666
667 if enabled!(Level::TRACE) {
668 tracing::trace!(target: "rig::completions",
669 "Azure OpenAI completion request: {}",
670 serde_json::to_string_pretty(&request)?
671 );
672 }
673
674 let body = serde_json::to_vec(&request)?;
675
676 let req = self
677 .client
678 .post_chat_completion(&self.model)?
679 .body(body)
680 .map_err(http_client::Error::from)?;
681
682 async move {
683 let response = self.client.send::<_, Bytes>(req).await.unwrap();
684
685 let status = response.status();
686 let response_body = response.into_body().into_future().await?.to_vec();
687
688 if status.is_success() {
689 match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
690 &response_body,
691 )? {
692 ApiResponse::Ok(response) => {
693 let span = tracing::Span::current();
694 span.record_response_metadata(&response);
695 span.record_token_usage(&response.usage);
696 if enabled!(Level::TRACE) {
697 tracing::trace!(target: "rig::completions",
698 "Azure OpenAI completion response: {}",
699 serde_json::to_string_pretty(&response)?
700 );
701 }
702 response.try_into()
703 }
704 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
705 }
706 } else {
707 Err(CompletionError::ProviderError(
708 String::from_utf8_lossy(&response_body).to_string(),
709 ))
710 }
711 }
712 .instrument(span)
713 .await
714 }
715
716 #[cfg_attr(feature = "worker", worker::send)]
717 async fn stream(
718 &self,
719 completion_request: CompletionRequest,
720 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
721 let preamble = completion_request.preamble.clone();
722 let mut request =
723 AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
724
725 let params = json_utils::merge(
726 request.additional_params.unwrap_or(serde_json::json!({})),
727 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
728 );
729
730 request.additional_params = Some(params);
731
732 if enabled!(Level::TRACE) {
733 tracing::trace!(target: "rig::completions",
734 "Azure OpenAI completion request: {}",
735 serde_json::to_string_pretty(&request)?
736 );
737 }
738
739 let body = serde_json::to_vec(&request)?;
740
741 let req = self
742 .client
743 .post_chat_completion(&self.model)?
744 .body(body)
745 .map_err(http_client::Error::from)?;
746
747 let span = if tracing::Span::current().is_disabled() {
748 info_span!(
749 target: "rig::completions",
750 "chat_streaming",
751 gen_ai.operation.name = "chat_streaming",
752 gen_ai.provider.name = "azure.openai",
753 gen_ai.request.model = self.model,
754 gen_ai.system_instructions = &preamble,
755 gen_ai.response.id = tracing::field::Empty,
756 gen_ai.response.model = tracing::field::Empty,
757 gen_ai.usage.output_tokens = tracing::field::Empty,
758 gen_ai.usage.input_tokens = tracing::field::Empty,
759 )
760 } else {
761 tracing::Span::current()
762 };
763
764 tracing_futures::Instrument::instrument(
765 send_compatible_streaming_request(self.client.clone(), req),
766 span,
767 )
768 .await
769 }
770}
771
/// Speech-to-text model backed by an Azure OpenAI deployment.
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Azure deployment (model) name used to build the request URL.
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    /// Creates a transcription model for the given deployment.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
791
impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = TranscriptionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Uploads the audio as a multipart form and returns the transcript.
    ///
    /// Optional `prompt`, `temperature`, and any `additional_params` entries
    /// are forwarded as extra form fields.
    ///
    /// # Panics
    /// Panics when `additional_params` is present but is not a JSON object.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new().part(
            "file",
            Part::bytes(data).file_name(request.filename.clone()),
        );

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to OpenAI Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string())
            }
        }

        let req = self
            .client
            .post_transcription(&self.model)?
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}
859
860#[cfg(feature = "image")]
864pub use image_generation::*;
865use tracing::{Instrument, Level, enabled, info_span};
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

    /// Image-generation model backed by an Azure OpenAI deployment.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        // Azure deployment (model) name; also sent as `model` in the body.
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                model: model.into(),
            }
        }

        /// Generates an image for the prompt; the result is requested as
        /// base64 (`b64_json`) rather than a URL.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_image_generation(&self.model)?
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }
}
937#[cfg(feature = "audio")]
942pub use audio_generation::*;
943
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

    /// Text-to-speech model backed by an Azure OpenAI deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        // Azure deployment name; also sent as `model` in the request body.
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        /// Creates an audio-generation model for the given deployment.
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Synthesizes speech for `request.text` and returns the raw audio
        /// bytes.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            // BUG FIX: the deployment name must be passed here. The previous
            // literal "/audio/speech" produced the malformed URL
            // `.../deployments/audio/speech/audio/speech?...`.
            let req = self
                .client
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}
1019
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    // Live-API smoke test: `#[ignore]`d because it needs the AZURE_* env
    // vars and network access. Run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    // Live-API smoke test for a minimal chat completion; same env-var and
    // network requirements as above.
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}