use std::fmt::Debug;

use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
#[cfg(feature = "image")]
use crate::client::Nothing;
use crate::client::{
    self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::multipart::Part;
use crate::http_client::{self, HttpClientExt, MultipartForm, bearer_auth_header};
use crate::streaming::StreamingCompletionResponse;
use crate::transcription::TranscriptionError;
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
    providers::openai,
    telemetry::SpanCombinator,
    transcription::{self},
};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{Instrument, Level, enabled, info_span};
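
/// Default Azure OpenAI REST API version, used unless overridden via
/// [`ClientBuilder::api_version`].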
const DEFAULT_API_VERSION: &str = "2024-10-21";

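/// Azure OpenAI provider extension: the resource endpoint plus the
/// `api-version` query parameter appended to every request URL.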
#[derive(Debug, Clone)]
pub struct AzureExt {
    endpoint: String,
    api_version: String,
}

impl DebugExt for AzureExt {
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
        [
            ("endpoint", &self.endpoint as &dyn Debug),
            ("api_version", &self.api_version as &dyn Debug),
        ]
        .into_iter()
    }
}

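/// Builder state for [`AzureExt`]: an endpoint is required, while the API
/// version defaults to [`DEFAULT_API_VERSION`].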
#[derive(Debug, Clone)]
pub struct AzureExtBuilder {
    endpoint: Option<String>,
    api_version: String,
}

impl Default for AzureExtBuilder {
    fn default() -> Self {
        Self {
            endpoint: None,
            api_version: DEFAULT_API_VERSION.into(),
        }
    }
}

pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;

impl Provider for AzureExt {
    type Builder = AzureExtBuilder;

    const VERIFY_PATH: &'static str = "";

    fn build<H>(
        builder: &client::ClientBuilder<
            Self::Builder,
            <Self::Builder as ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        let AzureExtBuilder {
            endpoint,
            api_version,
            ..
        } = builder.ext().clone();

        match endpoint {
            Some(endpoint) => Ok(Self {
                endpoint,
                api_version,
            }),
            None => Err(http_client::Error::Instance(
                "Azure client must be provided an endpoint prior to building".into(),
            )),
        }
    }
}

impl<H> Capabilities<H> for AzureExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}

impl ProviderBuilder for AzureExtBuilder {
    type Output = AzureExt;
    type ApiKey = AzureOpenAIAuth;

    const BASE_URL: &'static str = "";

    fn finish<H>(
        &self,
        mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
        use AzureOpenAIAuth::*;

        let auth = builder.get_api_key().clone();

        match auth {
            Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
            ApiKey(key) => {
                let k = http::HeaderName::from_static("api-key");
                let v = http::HeaderValue::from_str(key.as_str())?;

                builder.headers_mut().insert(k, v);
            }
        }

        Ok(builder)
    }
}

impl<H> ClientBuilder<H> {
    pub fn api_version(mut self, api_version: &str) -> Self {
        self.ext_mut().api_version = api_version.into();

        self
    }
}

impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
    pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
        self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
            endpoint: Some(endpoint),
            api_version,
        })
    }
}
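
// A minimal construction sketch (the endpoint, key, and version below are
// placeholders, not working credentials):
//
//     let client: Client = Client::builder()
//         .api_key(AzureOpenAIAuth::ApiKey("<azure-api-key>".into()))
//         .azure_endpoint("https://my-resource.openai.azure.com".to_string())
//         .api_version("2024-10-21")
//         .build()?;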
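/// Credentials for Azure OpenAI: either a static key sent as an `api-key`
/// header, or a bearer token sent via the `Authorization` header.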
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    ApiKey(String),
    Token(String),
}

impl ApiKey for AzureOpenAIAuth {}

impl std::fmt::Debug for AzureOpenAIAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
            Self::Token(_) => write!(f, "Token <REDACTED>"),
        }
    }
}

impl<S> From<S> for AzureOpenAIAuth
where
    S: Into<String>,
{
    fn from(token: S) -> Self {
        AzureOpenAIAuth::Token(token.into())
    }
}

impl<T> Client<T>
where
    T: HttpClientExt,
{
    fn endpoint(&self) -> &str {
        &self.ext().endpoint
    }

    fn api_version(&self) -> &str {
        &self.ext().api_version
    }

    fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/embeddings?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    #[cfg(feature = "audio")]
    fn post_audio_generation(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/audio/speech?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    fn post_chat_completion(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }
    fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/audio/transcriptions?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    #[cfg(feature = "image")]
    fn post_image_generation(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/images/generations?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }
}
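
// With endpoint "https://my-resource.openai.azure.com" (a placeholder),
// deployment "gpt-4o", and the default API version, `post_chat_completion`
// produces a URL of the shape:
//
//     https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-10-21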

pub struct AzureOpenAIClientParams {
    api_key: String,
    version: String,
    endpoint: String,
}

impl ProviderClient for Client {
    type Input = AzureOpenAIClientParams;

    fn from_env() -> Self {
        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
            AzureOpenAIAuth::ApiKey(api_key)
        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
            AzureOpenAIAuth::Token(token)
        } else {
            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
        };

        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");

        Self::builder()
            .api_key(auth)
            .azure_endpoint(azure_endpoint)
            .api_version(&api_version)
            .build()
            .unwrap()
    }

    fn from_val(
        AzureOpenAIClientParams {
            api_key,
            version,
            endpoint,
        }: Self::Input,
    ) -> Self {
        let auth = AzureOpenAIAuth::ApiKey(api_key);

        Self::builder()
            .api_key(auth)
            .azure_endpoint(endpoint)
            .api_version(&version)
            .build()
            .unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

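/// Embedding dimensions for known first-party models; returns `None` for
/// unrecognized deployment identifiers so callers can supply their own.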
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    match identifier {
        TEXT_EMBEDDING_3_LARGE => Some(3_072),
        TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
        _ => None,
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl GetTokenUsage for Usage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

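        // The embeddings API reports only prompt and total token counts, so
        // output tokens are derived as the difference.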
        usage.input_tokens = self.prompt_tokens as u64;
        usage.total_tokens = self.total_tokens as u64;
        usage.output_tokens = usage.total_tokens.saturating_sub(usage.input_tokens);

        Some(usage)
    }
}

impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}

#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::to_vec(&json!({
            "input": documents,
        }))?;

        let req = self
            .client
            .post_embedding(self.model.as_str())?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
        let model = model.into();
        let ndims = ndims
            .or_else(|| model_dimensions_from_identifier(&model))
            .unwrap_or_default();

        Self {
            client,
            model,
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
        let ndims = ndims.unwrap_or_default();

        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

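// Note: Azure OpenAI routes requests by *deployment* name rather than model
// name. The constants below match common OpenAI model identifiers, but a
// deployment in your resource may be named differently.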
pub const O1: &str = "o1";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_MINI: &str = "o1-mini";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AzureOpenAICompletionRequest {
    model: String,
    pub messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
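
// An illustrative serialized body (assuming a preamble, one user message,
// no tools, and no additional params; exact field order may differ):
//
//     {"model":"gpt-4o",
//      "messages":[{"role":"system","content":"You are a helpful assistant."},
//                  {"role":"user","content":"Hello!"}],
//      "temperature":0.0}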

impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.tool_choice.is_some() {
            tracing::warn!(
                "Tool choice is not currently supported by Azure OpenAI. This should be fixed in Rig 0.25."
            );
        }

        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openrouter::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Azure OpenAI completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let mut request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

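        // Force streaming and ask the service to attach token usage to the
        // final chunk, merging with any caller-supplied additional params.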
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}

#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = TranscriptionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body =
            MultipartForm::new().part(Part::bytes("file", data).filename(request.filename.clone()));

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt);
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("additional parameters to Azure OpenAI transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post_transcription(&self.model)?
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

#[cfg(feature = "image")]
pub use image_generation::*;

#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                model: model.into(),
            }
        }

        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_image_generation(&self.model)?
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }
}

#[cfg(feature = "audio")]
pub use audio_generation::*;

#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}

#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}