use std::fmt::Debug;

use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
use crate::client::{
    self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::multipart::Part;
use crate::http_client::{self, HttpClientExt, MultipartForm, bearer_auth_header};
use crate::streaming::StreamingCompletionResponse;
use crate::transcription::TranscriptionError;
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
    providers::openai,
    telemetry::SpanCombinator,
    transcription::{self},
};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_json::json;

/// The default Azure OpenAI API version sent when none is configured.
const DEFAULT_API_VERSION: &str = "2024-10-21";

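/// Provider extension carrying the Azure-specific connection state: the
/// resource endpoint and the `api-version` query parameter appended to every
/// request URL.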
#[derive(Debug, Clone)]
pub struct AzureExt {
    endpoint: String,
    api_version: String,
}

impl DebugExt for AzureExt {
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
        [
            ("endpoint", (&self.endpoint as &dyn Debug)),
            ("api_version", (&self.api_version as &dyn Debug)),
        ]
        .into_iter()
    }
}

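/// Builder state for `AzureExt`: the endpoint is mandatory, while the API
/// version defaults to `DEFAULT_API_VERSION`.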
#[derive(Debug, Clone)]
pub struct AzureExtBuilder {
    endpoint: Option<String>,
    api_version: String,
}

impl Default for AzureExtBuilder {
    fn default() -> Self {
        Self {
            endpoint: None,
            api_version: DEFAULT_API_VERSION.into(),
        }
    }
}

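/// Azure OpenAI client and builder aliases over a pluggable HTTP backend
/// (defaulting to `reqwest`).
///
/// A minimal construction sketch, mirroring what `from_env` does below. It
/// assumes this module is exposed as `rig::providers::azure` and uses
/// placeholder credentials, so it is not compiled as a doctest:
///
/// ```rust,ignore
/// use rig::providers::azure::{AzureOpenAIAuth, Client};
///
/// let client = Client::builder()
///     .api_key(AzureOpenAIAuth::ApiKey("my-azure-key".to_string()))
///     .azure_endpoint("https://my-resource.openai.azure.com".to_string())
///     .api_version("2024-10-21")
///     .build()
///     .expect("an endpoint was provided, so building should succeed");
/// ```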
pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;

impl Provider for AzureExt {
    type Builder = AzureExtBuilder;

    const VERIFY_PATH: &'static str = "";

    fn build<H>(
        builder: &client::ClientBuilder<
            Self::Builder,
            <Self::Builder as ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        let AzureExtBuilder {
            endpoint,
            api_version,
            ..
        } = builder.ext().clone();

        match endpoint {
            Some(endpoint) => Ok(Self {
                endpoint,
                api_version,
            }),
            None => Err(http_client::Error::Instance(
                "Azure client must be provided an endpoint prior to building".into(),
            )),
        }
    }
}

impl<H> Capabilities<H> for AzureExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}

impl ProviderBuilder for AzureExtBuilder {
    type Output = AzureExt;
    type ApiKey = AzureOpenAIAuth;

    const BASE_URL: &'static str = "";

    fn finish<H>(
        &self,
        mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
        use AzureOpenAIAuth::*;

        let auth = builder.get_api_key().clone();

        match auth {
            Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
            ApiKey(key) => {
                // Azure API keys are sent in a dedicated `api-key` header
                // rather than as a bearer token.
                let k = http::HeaderName::from_static("api-key");
                let v = http::HeaderValue::from_str(key.as_str())?;

                builder.headers_mut().insert(k, v);
            }
        }

        Ok(builder)
    }
}

impl<H> ClientBuilder<H> {
    /// Overrides the Azure OpenAI API version (defaults to
    /// `DEFAULT_API_VERSION`).
    pub fn api_version(mut self, api_version: &str) -> Self {
        self.ext_mut().api_version = api_version.into();

        self
    }
}

impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
    /// Sets the Azure OpenAI resource endpoint, e.g.
    /// `https://{your-resource-name}.openai.azure.com`. Required.
    pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
        self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
            endpoint: Some(endpoint),
            api_version,
        })
    }
}

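/// Authentication for Azure OpenAI: either a static API key, sent as an
/// `api-key` header, or a bearer token (e.g. one issued by Microsoft Entra
/// ID), sent as an `Authorization` header.
///
/// A short sketch with placeholder values:
///
/// ```rust,ignore
/// use rig::providers::azure::AzureOpenAIAuth;
///
/// let with_key = AzureOpenAIAuth::ApiKey("my-azure-key".to_string());
/// // The blanket `From<impl Into<String>>` conversion produces a bearer token:
/// let with_token: AzureOpenAIAuth = "my-entra-token".into();
/// ```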
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    ApiKey(String),
    Token(String),
}

impl ApiKey for AzureOpenAIAuth {}

// Manual `Debug` implementation so credentials are never printed.
impl std::fmt::Debug for AzureOpenAIAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
            Self::Token(_) => write!(f, "Token <REDACTED>"),
        }
    }
}

impl<S> From<S> for AzureOpenAIAuth
where
    S: Into<String>,
{
    fn from(token: S) -> Self {
        AzureOpenAIAuth::Token(token.into())
    }
}

impl<T> Client<T>
where
    T: HttpClientExt,
{
    fn endpoint(&self) -> &str {
        &self.ext().endpoint
    }

    fn api_version(&self) -> &str {
        &self.ext().api_version
    }

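    // Azure OpenAI addresses models by *deployment*; every request URL has
    // the shape
    // `{endpoint}/openai/deployments/{deployment-id}/{operation}?api-version={version}`.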
    fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/embeddings?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    #[cfg(feature = "audio")]
    fn post_audio_generation(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/audio/speech?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    fn post_chat_completion(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
        // Transcription posts to the `audio/transcriptions` operation;
        // `audio/translations` is a different endpoint that always outputs
        // English.
        let url = format!(
            "{}/openai/deployments/{}/audio/transcriptions?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }

    #[cfg(feature = "image")]
    fn post_image_generation(
        &self,
        deployment_id: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!(
            "{}/openai/deployments/{}/images/generations?api-version={}",
            self.endpoint(),
            deployment_id.trim_start_matches('/'),
            self.api_version()
        );

        self.post(&url)
    }
}

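/// Parameters accepted by `ProviderClient::from_val`: an API key, an API
/// version, and the Azure resource endpoint.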
pub struct AzureOpenAIClientParams {
    api_key: String,
    version: String,
    endpoint: String,
}

impl ProviderClient for Client {
    type Input = AzureOpenAIClientParams;

    /// Builds a client from the environment: reads `AZURE_API_KEY` (or
    /// `AZURE_TOKEN`) plus `AZURE_API_VERSION` and `AZURE_ENDPOINT`, and
    /// panics if the required variables are missing.
    fn from_env() -> Self {
        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
            AzureOpenAIAuth::ApiKey(api_key)
        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
            AzureOpenAIAuth::Token(token)
        } else {
            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
        };

        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");

        Self::builder()
            .api_key(auth)
            .azure_endpoint(azure_endpoint)
            .api_version(&api_version)
            .build()
            .unwrap()
    }

    fn from_val(
        AzureOpenAIClientParams {
            api_key,
            version,
            endpoint,
        }: Self::Input,
    ) -> Self {
        let auth = AzureOpenAIAuth::ApiKey(api_key);

        Self::builder()
            .api_key(auth)
            .azure_endpoint(endpoint)
            .api_version(&version)
            .build()
            .unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

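/// Default output dimensionality for known embedding models: 3072 for
/// `text-embedding-3-large`, 1536 for the `-3-small` and `ada-002` models,
/// and `None` for unrecognized identifiers.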
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    match identifier {
        TEXT_EMBEDDING_3_LARGE => Some(3_072),
        TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
        _ => None,
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl GetTokenUsage for Usage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.prompt_tokens as u64;
        usage.total_tokens = self.total_tokens as u64;
        // Guard against a malformed response where prompt > total, which
        // would otherwise underflow.
        usage.output_tokens = usage.total_tokens.saturating_sub(usage.input_tokens);

        Some(usage)
    }
}

impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}

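/// An Azure OpenAI embedding model, addressed by its deployment name.
///
/// A hypothetical usage sketch, mirroring the tests at the bottom of this
/// file. It assumes the `AZURE_*` environment variables read by `from_env`
/// are set, so it is not compiled as a doctest:
///
/// ```rust,ignore
/// use rig::client::embeddings::EmbeddingsClient;
/// use rig::embeddings::EmbeddingModel;
/// use rig::providers::azure::{self, TEXT_EMBEDDING_3_SMALL};
///
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// let client = azure::Client::from_env();
/// let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
/// let embeddings = model.embed_texts(vec!["Hello, world!".to_string()]).await?;
/// # Ok(())
/// # }
/// ```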
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let mut body = json!({
            "input": documents,
        });

        // `text-embedding-ada-002` does not accept a `dimensions` parameter.
        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
            body["dimensions"] = json!(self.ndims);
        }

        let body = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post_embedding(self.model.as_str())?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents)
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
        let model = model.into();
        // Fall back to the model's known default dimensionality, then to 0.
        let ndims = ndims
            .or_else(|| model_dimensions_from_identifier(&model))
            .unwrap_or_default();

        Self {
            client,
            model,
            ndims,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
        let ndims = ndims.unwrap_or_default();

        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

pub const O1: &str = "o1";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_MINI: &str = "o1-mini";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AzureOpenAICompletionRequest {
    model: String,
    pub messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}

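/// Builds the wire request from a `CompletionRequest`: the system preamble
/// (if any) comes first, followed by any normalized documents, then the
/// flattened chat history.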
impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let model = req.model.clone().unwrap_or_else(|| model.to_string());

        if req.tool_choice.is_some() {
            tracing::warn!(
                "Tool choice is not currently supported by the Azure OpenAI provider. This should be fixed in Rig 0.25."
            );
        }

        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openrouter::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model,
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}

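/// An Azure OpenAI chat-completion model, addressed by its deployment name.
///
/// A hypothetical usage sketch mirroring the tests below; it assumes the
/// `AZURE_*` environment variables are set and is not compiled as a doctest:
///
/// ```rust,ignore
/// use rig::client::completion::CompletionClient;
/// use rig::providers::azure::{self, GPT_4O_MINI};
///
/// let client = azure::Client::from_env();
/// let model = client.completion_model(GPT_4O_MINI);
/// ```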
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Azure OpenAI completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let mut request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Force SSE streaming and ask the API to report token usage in the
        // final chunk.
        let params = json_utils::merge(
            request
                .additional_params
                .unwrap_or_else(|| serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}

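/// An Azure OpenAI transcription model (e.g. a `whisper` deployment); the
/// audio payload is sent as a multipart form.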
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = TranscriptionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body =
            MultipartForm::new().part(Part::bytes("file", data).filename(request.filename.clone()));

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt);
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional parameters to Azure OpenAI transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post_transcription(&self.model)?
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

#[cfg(feature = "image")]
pub use image_generation::*;
use tracing::{Instrument, Level, enabled, info_span};
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

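    /// An Azure OpenAI image-generation model (e.g. a `dall-e-3` deployment);
    /// images are returned base64-encoded (`"response_format": "b64_json"`).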
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                model: model.into(),
            }
        }

        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_image_generation(&self.model)?
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }
}
#[cfg(feature = "audio")]
pub use audio_generation::*;

#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

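    /// An Azure OpenAI text-to-speech model, addressed by its deployment name
    /// (e.g. a `tts-1` deployment; the name here is illustrative).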
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            // `post_audio_generation` expects the deployment name; it builds
            // the `/audio/speech` URL itself.
            let req = self
                .client
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}

#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding_dimensions() {
        let _ = tracing_subscriber::fmt::try_init();

        let ndims = 256;
        let client = Client::<reqwest::Client>::from_env();
        let model = client.embedding_model_with_ndims(TEXT_EMBEDDING_3_SMALL, ndims);
        let embedding = model.embed_text("Hello, world!").await.unwrap();

        assert_eq!(embedding.vec.len(), ndims);

        tracing::info!("Azure dimensions embedding: {:?}", embedding);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                model: None,
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
                output_schema: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }

    #[tokio::test]
    async fn test_client_initialization() {
        let _client: crate::providers::azure::Client<reqwest::Client> =
            crate::providers::azure::Client::builder()
                .api_key("test")
                .azure_endpoint("test".to_string())
                .build()
                .expect("Client::builder() failed");
    }
}