1use std::fmt::Debug;
13
14use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
15#[cfg(feature = "image")]
16use crate::client::Nothing;
17use crate::client::{
18 self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
19};
20use crate::completion::GetTokenUsage;
21use crate::http_client::{self, HttpClientExt, bearer_auth_header};
22use crate::streaming::StreamingCompletionResponse;
23use crate::transcription::TranscriptionError;
24use crate::{
25 completion::{self, CompletionError, CompletionRequest},
26 embeddings::{self, EmbeddingError},
27 json_utils,
28 providers::openai,
29 telemetry::SpanCombinator,
30 transcription::{self},
31};
32use bytes::Bytes;
33use reqwest::multipart::Part;
34use serde::{Deserialize, Serialize};
35use serde_json::json;
/// Default Azure OpenAI REST `api-version` used when the builder does not
/// override it via `ClientBuilder::api_version`.
const DEFAULT_API_VERSION: &str = "2024-10-21";
41
/// Provider extension carrying the Azure-specific request configuration.
#[derive(Debug, Clone)]
pub struct AzureExt {
    /// Base URL of the Azure OpenAI resource
    /// (presumably `https://{resource}.openai.azure.com` — confirm with callers).
    endpoint: String,
    /// REST API version appended to every request URL as `api-version`.
    api_version: String,
}
47
48impl DebugExt for AzureExt {
49 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
50 [
51 ("endpoint", (&self.endpoint as &dyn Debug)),
52 ("api_version", (&self.api_version as &dyn Debug)),
53 ]
54 .into_iter()
55 }
56}
57
/// Builder-side state for [`AzureExt`]; the endpoint starts unset and must be
/// provided before the client can be built.
#[derive(Debug, Clone)]
pub struct AzureExtBuilder {
    /// Azure resource endpoint; `None` until `azure_endpoint` is called.
    endpoint: Option<String>,
    /// REST API version, preset to `DEFAULT_API_VERSION`.
    api_version: String,
}
67
68impl Default for AzureExtBuilder {
69 fn default() -> Self {
70 Self {
71 endpoint: None,
72 api_version: DEFAULT_API_VERSION.into(),
73 }
74 }
75}
76
/// Azure OpenAI client, specialized over an HTTP backend (reqwest by default).
pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
/// Builder for [`Client`], fixed to the Azure auth scheme.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;
80
81impl Provider for AzureExt {
82 type Builder = AzureExtBuilder;
83
84 const VERIFY_PATH: &'static str = "";
86
87 fn build<H>(
88 builder: &client::ClientBuilder<
89 Self::Builder,
90 <Self::Builder as ProviderBuilder>::ApiKey,
91 H,
92 >,
93 ) -> http_client::Result<Self> {
94 let AzureExtBuilder {
95 endpoint,
96 api_version,
97 ..
98 } = builder.ext().clone();
99
100 match endpoint {
101 Some(endpoint) => Ok(Self {
102 endpoint,
103 api_version,
104 }),
105 None => Err(http_client::Error::Instance(
106 "Azure client must be provided an endpoint prior to building".into(),
107 )),
108 }
109 }
110}
111
// Capability map for Azure OpenAI: completions, embeddings and transcription
// are always available; audio generation is gated behind the `audio` feature,
// and image generation is declared unsupported (`Nothing`) under the `image`
// feature.
impl<H> Capabilities<H> for AzureExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Capable<TranscriptionModel<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
121
122impl ProviderBuilder for AzureExtBuilder {
123 type Output = AzureExt;
124 type ApiKey = AzureOpenAIAuth;
125
126 const BASE_URL: &'static str = "";
127
128 fn finish<H>(
129 &self,
130 mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
131 ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
132 use AzureOpenAIAuth::*;
133
134 let auth = builder.get_api_key().clone();
135
136 match auth {
137 Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
138 ApiKey(key) => {
139 let k = http::HeaderName::from_static("api-key");
140 let v = http::HeaderValue::from_str(key.as_str())?;
141
142 builder.headers_mut().insert(k, v);
143 }
144 }
145
146 Ok(builder)
147 }
148}
149
150impl<H> ClientBuilder<H> {
151 pub fn api_version(mut self, api_version: &str) -> Self {
153 self.ext_mut().api_version = api_version.into();
154
155 self
156 }
157}
158
159impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
160 pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
162 self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
163 endpoint: Some(endpoint),
164 api_version,
165 })
166 }
167}
168
/// Credential accepted by the Azure OpenAI provider.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Static API key, sent via the `api-key` header.
    ApiKey(String),
    /// Bearer token (e.g. Microsoft Entra ID), sent via `Authorization`.
    Token(String),
}
174
// Marker impl so the enum can be used wherever the client machinery expects
// an API-key type.
impl ApiKey for AzureOpenAIAuth {}
176
177impl std::fmt::Debug for AzureOpenAIAuth {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 match self {
180 Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
181 Self::Token(_) => write!(f, "Token <REDACTED>"),
182 }
183 }
184}
185
// NOTE: any string-like value converts into the *Token* variant, never
// `ApiKey` — callers wanting API-key auth must construct
// `AzureOpenAIAuth::ApiKey` explicitly.
impl<S> From<S> for AzureOpenAIAuth
where
    S: Into<String>,
{
    fn from(token: S) -> Self {
        AzureOpenAIAuth::Token(token.into())
    }
}
194
195impl<T> Client<T>
196where
197 T: HttpClientExt,
198{
199 fn endpoint(&self) -> &str {
200 &self.ext().endpoint
201 }
202
203 fn api_version(&self) -> &str {
204 &self.ext().api_version
205 }
206
207 fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
208 let url = format!(
209 "{}/openai/deployments/{}/embeddings?api-version={}",
210 self.endpoint(),
211 deployment_id.trim_start_matches('/'),
212 self.api_version()
213 );
214
215 self.post(&url)
216 }
217
218 #[cfg(feature = "audio")]
219 fn post_audio_generation(
220 &self,
221 deployment_id: &str,
222 ) -> http_client::Result<http_client::Builder> {
223 let url = format!(
224 "{}/openai/deployments/{}/audio/speech?api-version={}",
225 self.endpoint(),
226 deployment_id.trim_start_matches('/'),
227 self.api_version()
228 );
229
230 self.post(url)
231 }
232
233 fn post_chat_completion(
234 &self,
235 deployment_id: &str,
236 ) -> http_client::Result<http_client::Builder> {
237 let url = format!(
238 "{}/openai/deployments/{}/chat/completions?api-version={}",
239 self.endpoint(),
240 deployment_id.trim_start_matches('/'),
241 self.api_version()
242 );
243
244 self.post(&url)
245 }
246
247 fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
248 let url = format!(
249 "{}/openai/deployments/{}/audio/translations?api-version={}",
250 self.endpoint(),
251 deployment_id.trim_start_matches('/'),
252 self.api_version()
253 );
254
255 self.post(&url)
256 }
257
258 #[cfg(feature = "image")]
259 fn post_image_generation(
260 &self,
261 deployment_id: &str,
262 ) -> http_client::Result<http_client::Builder> {
263 let url = format!(
264 "{}/openai/deployments/{}/images/generations?api-version={}",
265 self.endpoint(),
266 deployment_id.trim_start_matches('/'),
267 self.api_version()
268 );
269
270 self.post(&url)
271 }
272}
273
/// Plain-data input for `ProviderClient::from_val`.
pub struct AzureOpenAIClientParams {
    /// Azure OpenAI API key.
    api_key: String,
    /// REST `api-version` value.
    version: String,
    /// NOTE(review): despite the name, this value is consumed as the Azure
    /// *endpoint* URL in `from_val` — confirm and consider renaming.
    header: String,
}
279
280impl ProviderClient for Client {
281 type Input = AzureOpenAIClientParams;
282
283 fn from_env() -> Self {
285 let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
286 AzureOpenAIAuth::ApiKey(api_key)
287 } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
288 AzureOpenAIAuth::Token(token)
289 } else {
290 panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
291 };
292
293 let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
294 let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
295
296 Self::builder()
297 .api_key(auth)
298 .azure_endpoint(azure_endpoint)
299 .api_version(&api_version)
300 .build()
301 .unwrap()
302 }
303
304 fn from_val(
305 AzureOpenAIClientParams {
306 api_key,
307 version,
308 header,
309 }: Self::Input,
310 ) -> Self {
311 let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
312
313 Self::builder()
314 .api_key(auth)
315 .azure_endpoint(header)
316 .api_version(&version)
317 .build()
318 .unwrap()
319 }
320}
321
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error description from the provider.
    message: String,
}
326
/// Untagged wrapper: a successful payload deserializes as `Ok`, otherwise
/// serde falls through to the provider error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
333
/// Identifier for the `text-embedding-3-large` embedding model.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// Identifier for the `text-embedding-3-small` embedding model.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// Identifier for the legacy `text-embedding-ada-002` embedding model.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

/// Returns the known output dimensionality for a well-known embedding model,
/// or `None` when the identifier is not recognized.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    if identifier == TEXT_EMBEDDING_3_LARGE {
        Some(3_072)
    } else if identifier == TEXT_EMBEDDING_3_SMALL || identifier == TEXT_EMBEDDING_ADA_002 {
        Some(1_536)
    } else {
        None
    }
}
352
/// Successful response from the embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    /// One entry per input document, in request order.
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
360
// Provider errors surface as `EmbeddingError::ProviderError` with the raw
// message passed through.
impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}
366
// Flattens the untagged API wrapper into a `Result`.
impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}
375
/// A single embedding vector in an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    /// The embedding vector itself.
    pub embedding: Vec<f64>,
    /// Position of the corresponding input document.
    pub index: usize,
}
382
/// Token accounting returned by the embeddings endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the input.
    pub prompt_tokens: usize,
    /// Total tokens billed for the request.
    pub total_tokens: usize,
}
388
389impl GetTokenUsage for Usage {
390 fn token_usage(&self) -> Option<crate::completion::Usage> {
391 let mut usage = crate::completion::Usage::new();
392
393 usage.input_tokens = self.prompt_tokens as u64;
394 usage.total_tokens = self.total_tokens as u64;
395 usage.output_tokens = usage.total_tokens - usage.input_tokens;
396
397 Some(usage)
398 }
399}
400
// One-line human-readable summary, used in the embedding usage log line.
impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}
410
/// Embedding model handle bound to an Azure deployment.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    /// Azure deployment name, addressed in the request URL.
    pub model: String,
    /// Output dimensionality; `0` when unknown.
    ndims: usize,
}
417
// Azure implementation of rig's embedding interface. The deployment name is
// addressed in the URL; only `input` is sent in the JSON body.
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone,
{
    // Maximum number of documents accepted in a single request.
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(client.clone(), model, dims)
    }

    // Dimensionality as configured/inferred at construction time.
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds each document in `documents`, preserving input order.
    ///
    /// # Errors
    /// Fails on transport or serialization errors, on a provider-reported
    /// error, or when the response length does not match the input count.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::to_vec(&json!({
            "input": documents,
        }))?;

        let req = self
            .client
            .post_embedding(self.model.as_str())?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    // Guard against the provider returning a different number
                    // of vectors than documents before zipping them together.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx: surface the raw body as the provider error.
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}
488
489impl<T> EmbeddingModel<T> {
490 pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
491 let model = model.into();
492 let ndims = ndims
493 .or(model_dimensions_from_identifier(&model))
494 .unwrap_or_default();
495
496 Self {
497 client,
498 model,
499 ndims,
500 }
501 }
502
503 pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
504 let ndims = ndims.unwrap_or_default();
505
506 Self {
507 client,
508 model: model.into(),
509 ndims,
510 }
511 }
512}
513
/// `o1` reasoning model.
pub const O1: &str = "o1";
/// `o1-preview` reasoning model.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` reasoning model.
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` multimodal model.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` multimodal model.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` realtime model.
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
// NOTE(review): `GPT_4_TURBO` and `GPT_4_32K_0613` duplicate the values of
// `GPT_4` and `GPT_4_32K` — they look like copy-paste slips (expected
// "gpt-4-turbo" / "gpt-4-32k-0613"), but changing the values would silently
// redirect existing callers to different deployments; confirm before fixing.
pub const GPT_4_TURBO: &str = "gpt-4";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` model.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` model.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` model.
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
544
545#[derive(Debug, Serialize, Deserialize)]
546pub(super) struct AzureOpenAICompletionRequest {
547 model: String,
548 pub messages: Vec<openai::Message>,
549 #[serde(flatten, skip_serializing_if = "Option::is_none")]
550 temperature: Option<f64>,
551 #[serde(skip_serializing_if = "Vec::is_empty")]
552 tools: Vec<openai::ToolDefinition>,
553 #[serde(flatten, skip_serializing_if = "Option::is_none")]
554 tool_choice: Option<crate::providers::openrouter::ToolChoice>,
555 #[serde(flatten, skip_serializing_if = "Option::is_none")]
556 pub additional_params: Option<serde_json::Value>,
557}
558
// Assembles the Azure chat-completions payload from rig's generic request:
// system preamble first, then normalized documents, then the chat history.
impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        // NOTE(review): this warns that tool choice is unsupported, yet the
        // value is still converted and set on the request below — confirm
        // which of the two is the intended behavior.
        if req.tool_choice.is_some() {
            tracing::warn!(
                "Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25."
            );
        }

        // The system preamble, when present, must be the first message.
        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        // Each rig message may expand into several OpenAI wire messages, so
        // the conversion goes through Vec<Vec<_>> and is then flattened.
        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openrouter::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}
613
/// Chat-completion model handle bound to an Azure deployment.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Azure deployment name, addressed in the request URL.
    pub model: String,
}
620
621impl<T> CompletionModel<T> {
622 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
623 Self {
624 client,
625 model: model.into(),
626 }
627 }
628}
629
630impl<T> completion::CompletionModel for CompletionModel<T>
631where
632 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
633{
634 type Response = openai::CompletionResponse;
635 type StreamingResponse = openai::StreamingCompletionResponse;
636 type Client = Client<T>;
637
638 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
639 Self::new(client.clone(), model.into())
640 }
641
642 #[cfg_attr(feature = "worker", worker::send)]
643 async fn completion(
644 &self,
645 completion_request: CompletionRequest,
646 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
647 let span = if tracing::Span::current().is_disabled() {
648 info_span!(
649 target: "rig::completions",
650 "chat",
651 gen_ai.operation.name = "chat",
652 gen_ai.provider.name = "azure.openai",
653 gen_ai.request.model = self.model,
654 gen_ai.system_instructions = &completion_request.preamble,
655 gen_ai.response.id = tracing::field::Empty,
656 gen_ai.response.model = tracing::field::Empty,
657 gen_ai.usage.output_tokens = tracing::field::Empty,
658 gen_ai.usage.input_tokens = tracing::field::Empty,
659 gen_ai.input.messages = tracing::field::Empty,
660 gen_ai.output.messages = tracing::field::Empty,
661 )
662 } else {
663 tracing::Span::current()
664 };
665
666 let request =
667 AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
668
669 span.record_model_input(&request.messages);
670 let body = serde_json::to_vec(&request)?;
671
672 let req = self
673 .client
674 .post_chat_completion(&self.model)?
675 .body(body)
676 .map_err(http_client::Error::from)?;
677
678 async move {
679 let response = self.client.send::<_, Bytes>(req).await.unwrap();
680
681 let status = response.status();
682 let response_body = response.into_body().into_future().await?.to_vec();
683
684 if status.is_success() {
685 match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
686 &response_body,
687 )? {
688 ApiResponse::Ok(response) => {
689 let span = tracing::Span::current();
690 span.record_model_output(&response.choices);
691 span.record_response_metadata(&response);
692 span.record_token_usage(&response.usage);
693 tracing::trace!(
694 target: "rig::completions",
695 "Azure completion response: {}",
696 serde_json::to_string_pretty(&response)?
697 );
698 response.try_into()
699 }
700 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
701 }
702 } else {
703 Err(CompletionError::ProviderError(
704 String::from_utf8_lossy(&response_body).to_string(),
705 ))
706 }
707 }
708 .instrument(span)
709 .await
710 }
711
712 #[cfg_attr(feature = "worker", worker::send)]
713 async fn stream(
714 &self,
715 completion_request: CompletionRequest,
716 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
717 let preamble = completion_request.preamble.clone();
718 let mut request =
719 AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
720
721 let params = json_utils::merge(
722 request.additional_params.unwrap_or(serde_json::json!({})),
723 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
724 );
725
726 request.additional_params = Some(params);
727
728 let body = serde_json::to_vec(&request)?;
729
730 let req = self
731 .client
732 .post_chat_completion(&self.model)?
733 .body(body)
734 .map_err(http_client::Error::from)?;
735
736 let span = if tracing::Span::current().is_disabled() {
737 info_span!(
738 target: "rig::completions",
739 "chat_streaming",
740 gen_ai.operation.name = "chat_streaming",
741 gen_ai.provider.name = "azure.openai",
742 gen_ai.request.model = self.model,
743 gen_ai.system_instructions = &preamble,
744 gen_ai.response.id = tracing::field::Empty,
745 gen_ai.response.model = tracing::field::Empty,
746 gen_ai.usage.output_tokens = tracing::field::Empty,
747 gen_ai.usage.input_tokens = tracing::field::Empty,
748 gen_ai.input.messages = serde_json::to_string(&request.messages)?,
749 gen_ai.output.messages = tracing::field::Empty,
750 )
751 } else {
752 tracing::Span::current()
753 };
754
755 tracing_futures::Instrument::instrument(
756 send_compatible_streaming_request(self.client.http_client().clone(), req),
757 span,
758 )
759 .await
760 }
761}
762
/// Audio-transcription model handle bound to an Azure deployment.
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Azure deployment name, addressed in the request URL.
    pub model: String,
}
773
774impl<T> TranscriptionModel<T> {
775 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
776 Self {
777 client,
778 model: model.into(),
779 }
780 }
781}
782
783impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
784where
785 T: HttpClientExt + Clone + 'static,
786{
787 type Response = TranscriptionResponse;
788 type Client = Client<T>;
789
790 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
791 Self::new(client.clone(), model)
792 }
793
794 #[cfg_attr(feature = "worker", worker::send)]
795 async fn transcription(
796 &self,
797 request: transcription::TranscriptionRequest,
798 ) -> Result<
799 transcription::TranscriptionResponse<Self::Response>,
800 transcription::TranscriptionError,
801 > {
802 let data = request.data;
803
804 let mut body = reqwest::multipart::Form::new().part(
805 "file",
806 Part::bytes(data).file_name(request.filename.clone()),
807 );
808
809 if let Some(prompt) = request.prompt {
810 body = body.text("prompt", prompt.clone());
811 }
812
813 if let Some(ref temperature) = request.temperature {
814 body = body.text("temperature", temperature.to_string());
815 }
816
817 if let Some(ref additional_params) = request.additional_params {
818 for (key, value) in additional_params
819 .as_object()
820 .expect("Additional Parameters to OpenAI Transcription should be a map")
821 {
822 body = body.text(key.to_owned(), value.to_string());
823 }
824 }
825
826 let req = self
827 .client
828 .post_transcription(&self.model)?
829 .body(body)
830 .map_err(|e| TranscriptionError::HttpError(e.into()))?;
831
832 let response = self
833 .client
834 .http_client()
835 .send_multipart::<Bytes>(req)
836 .await?;
837 let status = response.status();
838 let response_body = response.into_body().into_future().await?.to_vec();
839
840 if status.is_success() {
841 match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
842 ApiResponse::Ok(response) => response.try_into(),
843 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
844 api_error_response.message,
845 )),
846 }
847 } else {
848 Err(TranscriptionError::ProviderError(
849 String::from_utf8_lossy(&response_body).to_string(),
850 ))
851 }
852 }
853}
854
855#[cfg(feature = "image")]
859pub use image_generation::*;
860use tracing::{Instrument, info_span};
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

    /// Image-generation model handle bound to an Azure deployment.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        /// Azure deployment name, addressed in the request URL.
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                model: model.into(),
            }
        }

        /// Requests a base64-encoded image for the given prompt/dimensions.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // `b64_json` keeps the image bytes inline in the response instead
            // of returning a short-lived URL.
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_image_generation(&self.model)?
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }
}
932#[cfg(feature = "audio")]
937pub use audio_generation::*;
938
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

    /// Text-to-speech model handle bound to an Azure deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        /// Azure deployment name for the speech model.
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Synthesizes speech for `request.text` and returns the raw audio
        /// bytes.
        ///
        /// # Errors
        /// Returns `AudioGenerationError` on transport failures or non-2xx
        /// responses.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            // FIX: `post_audio_generation` expects the *deployment id*, not a
            // URL fragment. Passing "/audio/speech" produced the malformed
            // path `/openai/deployments/audio/speech/audio/speech?...`.
            let req = self
                .client
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}
1014
// Live integration tests — `#[ignore]`d because they require real Azure
// credentials and an endpoint supplied via environment variables (see
// `ProviderClient::from_env`).
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::<reqwest::Client>::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}