1use std::fmt::Debug;
26
27use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
28use crate::client::{
29 self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
30 ProviderClient,
31};
32use crate::completion::GetTokenUsage;
33use crate::http_client::multipart::Part;
34use crate::http_client::{self, HttpClientExt, MultipartForm, bearer_auth_header};
35use crate::streaming::StreamingCompletionResponse;
36use crate::transcription::TranscriptionError;
37use crate::{
38 completion::{self, CompletionError, CompletionRequest},
39 embeddings::{self, EmbeddingError},
40 json_utils,
41 providers::openai,
42 telemetry::SpanCombinator,
43 transcription::{self},
44};
45use bytes::Bytes;
46use serde::{Deserialize, Serialize};
47use serde_json::json;
/// API version used by [`AzureExtBuilder::default`] when the caller does not
/// override it through `ClientBuilder::api_version`.
const DEFAULT_API_VERSION: &str = "2024-10-21";
53
54#[derive(Debug, Clone)]
55pub struct AzureExt {
56 endpoint: String,
57 api_version: String,
58}
59
60impl DebugExt for AzureExt {
61 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
62 [
63 ("endpoint", (&self.endpoint as &dyn Debug)),
64 ("api_version", (&self.api_version as &dyn Debug)),
65 ]
66 .into_iter()
67 }
68}
69
70#[derive(Debug, Clone)]
75pub struct AzureExtBuilder {
76 endpoint: Option<String>,
77 api_version: String,
78}
79
80impl Default for AzureExtBuilder {
81 fn default() -> Self {
82 Self {
83 endpoint: None,
84 api_version: DEFAULT_API_VERSION.into(),
85 }
86 }
87}
88
/// Azure OpenAI client; `H` is the HTTP client type and defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
/// Builder for [`Client`], authenticated with [`AzureOpenAIAuth`]; `H` defaults to
/// the `Missing` marker type.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;
92
/// Registers [`AzureExt`] as a provider extension that is built by [`AzureExtBuilder`].
impl Provider for AzureExt {
    type Builder = AzureExtBuilder;

    // Intentionally empty in this module.
    // NOTE(review): presumably no generic verification path applies because Azure
    // URLs are deployment-specific — confirm against the `Provider` trait contract.
    const VERIFY_PATH: &'static str = "";
}
99
/// Declares which model families the Azure client exposes.
///
/// Completion, embeddings and transcription are always available; audio
/// generation is available behind the `audio` feature. Model listing, reranking
/// and (behind the `image` feature) image generation are declared as `Nothing`.
impl<H> Capabilities<H> for AzureExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    // NOTE(review): an `ImageGenerationModel` is defined in this file, yet the
    // capability is `Nothing` — confirm whether this is intentional.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
    type Rerank = Nothing;
}
111
112impl ProviderBuilder for AzureExtBuilder {
113 type Extension<H>
114 = AzureExt
115 where
116 H: HttpClientExt;
117 type ApiKey = AzureOpenAIAuth;
118
119 const BASE_URL: &'static str = "";
120
121 fn build<H>(
122 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
123 ) -> http_client::Result<Self::Extension<H>>
124 where
125 H: HttpClientExt,
126 {
127 let AzureExtBuilder {
128 endpoint,
129 api_version,
130 ..
131 } = builder.ext().clone();
132
133 match endpoint {
134 Some(endpoint) => Ok(AzureExt {
135 endpoint,
136 api_version,
137 }),
138 None => Err(http_client::Error::Instance(
139 "Azure client must be provided an endpoint prior to building".into(),
140 )),
141 }
142 }
143
144 fn finish<H>(
145 &self,
146 mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
147 ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
148 use AzureOpenAIAuth::*;
149
150 let auth = builder.get_api_key().clone();
151
152 match auth {
153 Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
154 ApiKey(key) => {
155 let k = http::HeaderName::from_static("api-key");
156 let v = http::HeaderValue::from_str(key.as_str())?;
157
158 builder.headers_mut().insert(k, v);
159 }
160 }
161
162 Ok(builder)
163 }
164}
165
166impl<H> ClientBuilder<H> {
167 pub fn api_version(mut self, api_version: &str) -> Self {
169 self.ext_mut().api_version = api_version.into();
170
171 self
172 }
173}
174
175impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
176 pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
178 self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
179 endpoint: Some(endpoint),
180 api_version,
181 })
182 }
183}
184
185#[derive(Clone)]
188pub enum AzureOpenAIAuth {
189 ApiKey(String),
190 Token(String),
191}
192
193impl ApiKey for AzureOpenAIAuth {}
194
195impl std::fmt::Debug for AzureOpenAIAuth {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 match self {
198 Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
199 Self::Token(_) => write!(f, "Token <REDACTED>"),
200 }
201 }
202}
203
204impl<S> From<S> for AzureOpenAIAuth
205where
206 S: Into<String>,
207{
208 fn from(token: S) -> Self {
209 AzureOpenAIAuth::Token(token.into())
210 }
211}
212
213impl<T> Client<T>
214where
215 T: HttpClientExt,
216{
217 fn endpoint(&self) -> &str {
218 &self.ext().endpoint
219 }
220
221 fn api_version(&self) -> &str {
222 &self.ext().api_version
223 }
224
225 fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
226 let url = format!(
227 "{}/openai/deployments/{}/embeddings?api-version={}",
228 self.endpoint(),
229 deployment_id.trim_start_matches('/'),
230 self.api_version()
231 );
232
233 self.post(&url)
234 }
235
236 #[cfg(feature = "audio")]
237 fn post_audio_generation(
238 &self,
239 deployment_id: &str,
240 ) -> http_client::Result<http_client::Builder> {
241 let url = format!(
242 "{}/openai/deployments/{}/audio/speech?api-version={}",
243 self.endpoint(),
244 deployment_id.trim_start_matches('/'),
245 self.api_version()
246 );
247
248 self.post(url)
249 }
250
251 fn post_chat_completion(
252 &self,
253 deployment_id: &str,
254 ) -> http_client::Result<http_client::Builder> {
255 let url = format!(
256 "{}/openai/deployments/{}/chat/completions?api-version={}",
257 self.endpoint(),
258 deployment_id.trim_start_matches('/'),
259 self.api_version()
260 );
261
262 self.post(&url)
263 }
264
265 fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
266 let url = format!(
267 "{}/openai/deployments/{}/audio/translations?api-version={}",
268 self.endpoint(),
269 deployment_id.trim_start_matches('/'),
270 self.api_version()
271 );
272
273 self.post(&url)
274 }
275
276 #[cfg(feature = "image")]
277 fn post_image_generation(
278 &self,
279 deployment_id: &str,
280 ) -> http_client::Result<http_client::Builder> {
281 let url = format!(
282 "{}/openai/deployments/{}/images/generations?api-version={}",
283 self.endpoint(),
284 deployment_id.trim_start_matches('/'),
285 self.api_version()
286 );
287
288 self.post(&url)
289 }
290}
291
/// Input accepted by `Client::from_val`.
pub struct AzureOpenAIClientParams {
    /// API key; wrapped in `AzureOpenAIAuth::ApiKey` by `from_val`.
    api_key: String,
    /// API version passed to `ClientBuilder::api_version`.
    version: String,
    /// Despite its name, `from_val` passes this value to `azure_endpoint`.
    header: String,
}
297
298impl ProviderClient for Client {
299 type Input = AzureOpenAIClientParams;
300 type Error = crate::client::ProviderClientError;
301
302 fn from_env() -> Result<Self, Self::Error> {
304 let auth = if let Some(api_key) = crate::client::optional_env_var("AZURE_API_KEY")? {
305 AzureOpenAIAuth::ApiKey(api_key)
306 } else if let Some(token) = crate::client::optional_env_var("AZURE_TOKEN")? {
307 AzureOpenAIAuth::Token(token)
308 } else {
309 return Err(crate::client::ProviderClientError::InvalidConfiguration(
310 "either `AZURE_API_KEY` or `AZURE_TOKEN` must be set",
311 ));
312 };
313
314 let api_version = crate::client::required_env_var("AZURE_API_VERSION")?;
315 let azure_endpoint = crate::client::required_env_var("AZURE_ENDPOINT")?;
316
317 Self::builder()
318 .api_key(auth)
319 .azure_endpoint(azure_endpoint)
320 .api_version(&api_version)
321 .build()
322 .map_err(Into::into)
323 }
324
325 fn from_val(
326 AzureOpenAIClientParams {
327 api_key,
328 version,
329 header,
330 }: Self::Input,
331 ) -> Result<Self, Self::Error> {
332 let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
333
334 Self::builder()
335 .api_key(auth)
336 .azure_endpoint(header)
337 .api_version(&version)
338 .build()
339 .map_err(Into::into)
340 }
341}
342
/// Error payload shape this module deserializes: a single `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Response body that is either the expected payload `T` or an error payload.
///
/// Being `untagged`, serde tries the variants in declaration order, so `Ok(T)`
/// is attempted before `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
354
/// `text-embedding-3-large` embedding model identifier.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model identifier.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model identifier.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

/// Returns the dimension count associated with a known embedding model
/// identifier, or `None` for any other identifier.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    if identifier == TEXT_EMBEDDING_3_LARGE {
        return Some(3_072);
    }

    let is_1536_model = [TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002].contains(&identifier);
    is_1536_model.then_some(1_536)
}
373
/// Successful embeddings response as deserialized by this module.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    /// One entry per embedded input; `embed_texts` rejects the response when the
    /// count differs from the number of inputs.
    pub data: Vec<EmbeddingData>,
    pub model: String,
    /// Token accounting reported for the request.
    pub usage: Usage,
}
381
382impl From<ApiErrorResponse> for EmbeddingError {
383 fn from(err: ApiErrorResponse) -> Self {
384 EmbeddingError::ProviderError(err.message)
385 }
386}
387
388impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
389 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
390 match value {
391 ApiResponse::Ok(response) => Ok(response),
392 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
393 }
394 }
395}
396
/// A single embedding entry inside [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// Index reported by the provider. `embed_texts` pairs entries with inputs
    /// by position and does not read this field.
    pub index: usize,
}
403
/// Token usage reported in an embeddings response.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
409
410impl GetTokenUsage for Usage {
411 fn token_usage(&self) -> crate::completion::Usage {
412 let mut usage = crate::completion::Usage::new();
413
414 usage.input_tokens = self.prompt_tokens as u64;
415 usage.total_tokens = self.total_tokens as u64;
416 usage.output_tokens = usage.total_tokens - usage.input_tokens;
417
418 usage
419 }
420}
421
422impl std::fmt::Display for Usage {
423 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424 write!(
425 f,
426 "Prompt tokens: {} Total tokens: {}",
427 self.prompt_tokens, self.total_tokens
428 )
429 }
430}
431
/// Embedding model bound to an Azure client.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    /// Used as the deployment id in the request URL.
    pub model: String,
    /// Requested dimension count; `0` means no `dimensions` field is sent.
    ndims: usize,
}
438
439impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
440where
441 T: HttpClientExt + Default + Clone + 'static,
442{
443 const MAX_DOCUMENTS: usize = 1024;
444
445 type Client = Client<T>;
446
447 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
448 Self::new(client.clone(), model, dims)
449 }
450
451 fn ndims(&self) -> usize {
452 self.ndims
453 }
454
455 async fn embed_texts(
456 &self,
457 documents: impl IntoIterator<Item = String>,
458 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
459 let documents = documents.into_iter().collect::<Vec<_>>();
460
461 let mut body = json!({
462 "input": documents,
463 });
464
465 let body_object = body.as_object_mut().ok_or_else(|| {
466 EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
467 })?;
468
469 if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
470 body_object.insert("dimensions".to_owned(), json!(self.ndims));
471 }
472
473 let body = serde_json::to_vec(&body)?;
474
475 let req = self
476 .client
477 .post_embedding(self.model.as_str())?
478 .body(body)
479 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
480
481 let response = self.client.send(req).await?;
482
483 if response.status().is_success() {
484 let body: Vec<u8> = response.into_body().await?;
485 let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
486
487 match body {
488 ApiResponse::Ok(response) => {
489 tracing::info!(target: "rig",
490 "Azure embedding token usage: {}",
491 response.usage
492 );
493
494 if response.data.len() != documents.len() {
495 return Err(EmbeddingError::ResponseError(
496 "Response data length does not match input length".into(),
497 ));
498 }
499
500 Ok(response
501 .data
502 .into_iter()
503 .zip(documents.into_iter())
504 .map(|(embedding, document)| embeddings::Embedding {
505 document,
506 vec: embedding.embedding,
507 })
508 .collect())
509 }
510 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
511 }
512 } else {
513 let text = http_client::text(response).await?;
514 Err(EmbeddingError::ProviderError(text))
515 }
516 }
517}
518
519impl<T> EmbeddingModel<T> {
520 pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
521 let model = model.into();
522 let ndims = ndims
523 .or(model_dimensions_from_identifier(&model))
524 .unwrap_or_default();
525
526 Self {
527 client,
528 model,
529 ndims,
530 }
531 }
532
533 pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
534 let ndims = ndims.unwrap_or_default();
535
536 Self {
537 client,
538 model: model.into(),
539 ndims,
540 }
541 }
542}
543
/// `o1` completion model identifier.
pub const O1: &str = "o1";
/// `o1-preview` completion model identifier.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` completion model identifier.
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` completion model identifier.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model identifier.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` completion model identifier.
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
/// NOTE(review): carries the same value as [`GPT_4`] (`gpt-4`) — confirm this is intended.
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` completion model identifier.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` completion model identifier.
pub const GPT_4_32K: &str = "gpt-4-32k";
/// NOTE(review): carries the same value as [`GPT_4_32K`] (`gpt-4-32k`) — confirm this is intended.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` completion model identifier.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` completion model identifier.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model identifier.
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
574
/// JSON body sent to the chat completions route.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AzureOpenAICompletionRequest {
    model: String,
    /// Optional system preamble followed by the converted chat history.
    pub messages: Vec<openai::Message>,
    // Omitted from the JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Omitted from the JSON when no tools are supplied.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::ToolChoice>,
    /// Extra parameters, flattened into the top level of the request body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
588
589impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
590 type Error = CompletionError;
591
592 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
593 let chat_history = req.chat_history_with_documents();
594 let model = req.model.clone().unwrap_or_else(|| model.to_string());
595 if req.tool_choice.is_some() {
596 tracing::warn!("Tool choice is currently not supported in Azure OpenAI.");
597 }
598
599 let mut full_history: Vec<openai::Message> = match &req.preamble {
600 Some(preamble) => vec![openai::Message::system(preamble)],
601 None => vec![],
602 };
603
604 let chat_history: Vec<openai::Message> = chat_history
605 .into_iter()
606 .map(|message| message.try_into())
607 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
608 .into_iter()
609 .flatten()
610 .collect();
611
612 full_history.extend(chat_history);
613
614 let tool_choice = req
615 .tool_choice
616 .clone()
617 .map(crate::providers::openai::ToolChoice::try_from)
618 .transpose()?;
619
620 let additional_params = if let Some(schema) = req.output_schema {
621 let name = schema
622 .as_object()
623 .and_then(|o| o.get("title"))
624 .and_then(|v| v.as_str())
625 .unwrap_or("response_schema")
626 .to_string();
627 let mut schema_value = schema.to_value();
628 openai::sanitize_schema(&mut schema_value);
629 let response_format = serde_json::json!({
630 "response_format": {
631 "type": "json_schema",
632 "json_schema": {
633 "name": name,
634 "strict": true,
635 "schema": schema_value
636 }
637 }
638 });
639 Some(match req.additional_params {
640 Some(existing) => json_utils::merge(existing, response_format),
641 None => response_format,
642 })
643 } else {
644 req.additional_params
645 };
646
647 Ok(Self {
648 model: model.to_string(),
649 messages: full_history,
650 temperature: req.temperature,
651 tools: req
652 .tools
653 .clone()
654 .into_iter()
655 .map(openai::ToolDefinition::from)
656 .collect::<Vec<_>>(),
657 tool_choice,
658 additional_params,
659 })
660 }
661}
662
663#[derive(Clone)]
664pub struct CompletionModel<T = reqwest::Client> {
665 client: Client<T>,
666 pub model: String,
668}
669
670impl<T> CompletionModel<T> {
671 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
672 Self {
673 client,
674 model: model.into(),
675 }
676 }
677}
678
/// Chat completion (blocking and streaming) against the Azure chat completions route.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;
    type Client = Client<T>;

    /// Creates a model from a borrowed client, cloning it.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Sends a non-streaming chat completion request.
    ///
    /// # Errors
    ///
    /// Returns `CompletionError::ProviderError` for non-success statuses (with
    /// the lossily decoded body) and for provider error payloads; propagates
    /// request conversion, HTTP and JSON failures.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        // Open a dedicated span only when the current span is disabled; otherwise reuse it.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printing is only paid for when TRACE is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        // The URL uses `self.model` as the deployment id.
        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;

            // Capture the status before the response is consumed by `into_body`.
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        // This future is instrumented with `span`, so the current
                        // span here is the one selected above.
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Azure OpenAI completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    /// Sends a streaming chat completion request.
    ///
    /// `stream: true` and `stream_options.include_usage: true` are merged into
    /// the request's additional parameters before sending.
    ///
    /// # Errors
    ///
    /// Propagates request conversion, JSON, HTTP and streaming failures.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        // Cloned up front because the request is consumed by `try_from` below.
        let preamble = completion_request.preamble.clone();
        let mut request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        // Same span selection as `completion`, under the `chat_streaming` name.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing_futures::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}
820
821#[derive(Clone)]
826pub struct TranscriptionModel<T = reqwest::Client> {
827 client: Client<T>,
828 pub model: String,
830}
831
832impl<T> TranscriptionModel<T> {
833 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
834 Self {
835 client,
836 model: model.into(),
837 }
838 }
839}
840
841impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
842where
843 T: HttpClientExt + Clone + 'static,
844{
845 type Response = TranscriptionResponse;
846 type Client = Client<T>;
847
848 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
849 Self::new(client.clone(), model)
850 }
851
852 async fn transcription(
853 &self,
854 request: transcription::TranscriptionRequest,
855 ) -> Result<
856 transcription::TranscriptionResponse<Self::Response>,
857 transcription::TranscriptionError,
858 > {
859 let data = request.data;
860
861 let mut body =
862 MultipartForm::new().part(Part::bytes("file", data).filename(request.filename.clone()));
863
864 if let Some(prompt) = request.prompt {
865 body = body.text("prompt", prompt.clone());
866 }
867
868 if let Some(ref temperature) = request.temperature {
869 body = body.text("temperature", temperature.to_string());
870 }
871
872 if let Some(ref additional_params) = request.additional_params {
873 let params = additional_params.as_object().ok_or_else(|| {
874 TranscriptionError::RequestError(Box::new(std::io::Error::new(
875 std::io::ErrorKind::InvalidInput,
876 "additional transcription parameters must be a JSON object",
877 )))
878 })?;
879
880 for (key, value) in params {
881 body = body.text(key.to_owned(), value.to_string());
882 }
883 }
884
885 let req = self
886 .client
887 .post_transcription(&self.model)?
888 .body(body)
889 .map_err(|e| TranscriptionError::HttpError(e.into()))?;
890
891 let response = self.client.send_multipart::<Bytes>(req).await?;
892 let status = response.status();
893 let response_body = response.into_body().into_future().await?.to_vec();
894
895 if status.is_success() {
896 match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
897 ApiResponse::Ok(response) => response.try_into(),
898 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
899 api_error_response.message,
900 )),
901 }
902 } else {
903 Err(TranscriptionError::ProviderError(
904 String::from_utf8_lossy(&response_body).to_string(),
905 ))
906 }
907 }
908}
909
910#[cfg(feature = "image")]
914pub use image_generation::*;
915use tracing::{Instrument, Level, enabled, info_span};
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

    /// Image generation model bound to an Azure client.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        /// Used both as the deployment id in the URL and as the `model` field of the body.
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        /// Creates a model from a borrowed client, cloning it.
        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            let client = client.clone();
            let model = model.into();

            Self { client, model }
        }

        /// Requests an image of `width`x`height` for the prompt, asking for a
        /// `b64_json` response.
        ///
        /// # Errors
        ///
        /// Returns `ProviderError` (prefixed with the status) for non-success
        /// statuses and for provider error payloads; propagates HTTP and JSON failures.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let size = format!("{}x{}", generation_request.width, generation_request.height);
            let payload = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": size,
                "response_format": "b64_json"
            });

            let payload = serde_json::to_vec(&payload)?;

            let req = self
                .client
                .post_image_generation(&self.model)?
                .body(payload)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            // Capture the status before the response is consumed by `into_body`.
            let status = response.status();
            let raw = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&raw)? {
                    ApiResponse::Ok(response) => response.try_into(),
                    ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
                }
            } else {
                Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&raw)
                )))
            }
        }
    }
}
986#[cfg(feature = "audio")]
991pub use audio_generation::*;
992
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

    /// Audio (speech) generation model bound to an Azure client.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        /// Deployment name; used as the deployment id in the URL and as the
        /// `model` field of the request body.
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        /// Creates an audio generation model for the given deployment.
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        /// Creates a model from a borrowed client, cloning it.
        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Synthesizes speech for the request text and returns the raw audio bytes.
        ///
        /// # Errors
        ///
        /// Returns `ProviderError` (prefixed with the status) for non-success
        /// statuses; propagates HTTP and JSON failures.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            // `post_audio_generation` takes the deployment id and appends the
            // `audio/speech` route itself. Passing the route here would build
            // `.../deployments/audio/speech/audio/speech`, so the configured
            // deployment name is passed instead.
            let req = self
                .client
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            // Capture the status before the response is consumed by `into_body`.
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}
1068
/// Tests for the Azure provider.
///
/// Every test that builds its client with `Client::from_env` is `#[ignore]`d;
/// only `test_client_initialization` runs by default.
#[cfg(test)]
mod azure_tests {
    use schemars::JsonSchema;

    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;
    use crate::prelude::TypedPrompt;
    use crate::providers::openai::GPT_5_MINI;

    /// Embeds a single string with the small embedding model and logs the result.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() -> anyhow::Result<()> {
        // `try_init` fails if a subscriber is already installed; that is fine here.
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env()?;
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model.embed_texts(vec!["Hello, world!".to_string()]).await?;

        tracing::info!("Azure embedding: {:?}", embeddings);
        Ok(())
    }

    /// Requests a 256-dimension embedding and checks the returned vector length.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding_dimensions() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let ndims = 256;
        let client = Client::from_env()?;
        let model = client.embedding_model_with_ndims(TEXT_EMBEDDING_3_SMALL, ndims);
        let embedding = model.embed_text("Hello, world!").await?;

        anyhow::ensure!(
            embedding.vec.len() == ndims,
            "expected embedding dimensions {ndims}, got {}",
            embedding.vec.len()
        );

        tracing::info!("Azure dimensions embedding: {:?}", embedding);
        Ok(())
    }

    /// Sends a minimal chat completion request and logs the response.
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env()?;
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                model: None,
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
                output_schema: None,
            })
            .await?;

        tracing::info!("Azure completion: {:?}", completion);
        Ok(())
    }

    /// Extracts a typed `Person` through an output schema and checks both fields.
    #[tokio::test]
    #[ignore]
    async fn test_azure_structured_output() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        #[derive(Debug, Deserialize, JsonSchema)]
        struct Person {
            name: String,
            age: u32,
        }

        let client = Client::from_env()?;
        let agent = client
            .agent(GPT_5_MINI)
            .preamble("You are a helpful assistant that extracts personal details.")
            .max_tokens(100)
            .output_schema::<Person>()
            .build();

        let result: Person = agent
            .prompt_typed("Hello! My name is John Doe and I'm 54 years old.")
            .await?;

        anyhow::ensure!(
            result.name == "John Doe",
            "expected name John Doe, got {}",
            result.name
        );
        anyhow::ensure!(result.age == 54, "expected age 54, got {}", result.age);

        tracing::info!("Extracted person: {:?}", result);
        Ok(())
    }

    /// Builds a client from placeholder values; only construction is checked.
    #[tokio::test]
    async fn test_client_initialization() {
        let _client = crate::providers::azure::Client::builder()
            .api_key("test")
            .azure_endpoint("test".to_string())
            .build()
            .expect("Client::builder() failed");
    }
}