1use std::fmt::Debug;
23
24use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
25use crate::client::{
26 self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
27 ProviderClient,
28};
29use crate::completion::GetTokenUsage;
30use crate::http_client::multipart::Part;
31use crate::http_client::{self, HttpClientExt, MultipartForm, bearer_auth_header};
32use crate::streaming::StreamingCompletionResponse;
33use crate::transcription::TranscriptionError;
34use crate::{
35 completion::{self, CompletionError, CompletionRequest},
36 embeddings::{self, EmbeddingError},
37 json_utils,
38 providers::openai,
39 telemetry::SpanCombinator,
40 transcription::{self},
41};
42use bytes::Bytes;
43use serde::{Deserialize, Serialize};
44use serde_json::json;
/// Default Azure OpenAI REST `api-version` used when the builder is not given one explicitly.
const DEFAULT_API_VERSION: &str = "2024-10-21";
50
/// Provider extension carrying the Azure-specific connection settings: the
/// resource endpoint and the `api-version` query parameter appended to every request.
#[derive(Debug, Clone)]
pub struct AzureExt {
    // Base URL of the Azure OpenAI resource; request URLs are built as
    // `{endpoint}/openai/deployments/...` (see `Client::post_*` helpers).
    endpoint: String,
    // Value sent as `?api-version=...` on every request.
    api_version: String,
}
56
57impl DebugExt for AzureExt {
58 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
59 [
60 ("endpoint", (&self.endpoint as &dyn Debug)),
61 ("api_version", (&self.api_version as &dyn Debug)),
62 ]
63 .into_iter()
64 }
65}
66
/// Builder-side state for [`AzureExt`]. The endpoint starts unset and must be
/// supplied (see `azure_endpoint`) before the client can be built.
#[derive(Debug, Clone)]
pub struct AzureExtBuilder {
    // Required; absence is reported as an error at build time.
    endpoint: Option<String>,
    // Defaults to `DEFAULT_API_VERSION`.
    api_version: String,
}
76
77impl Default for AzureExtBuilder {
78 fn default() -> Self {
79 Self {
80 endpoint: None,
81 api_version: DEFAULT_API_VERSION.into(),
82 }
83 }
84}
85
/// Azure OpenAI client, parameterized over the underlying HTTP client implementation.
pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
/// Builder for [`Client`], pairing the Azure extension state with Azure credentials.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;
89
impl Provider for AzureExt {
    type Builder = AzureExtBuilder;

    // Left empty: Azure has no fixed verification route here.
    // NOTE(review): confirm the framework treats "" as "skip verification".
    const VERIFY_PATH: &'static str = "";
}
96
/// Capability matrix: which model kinds this provider exposes for an HTTP client `H`.
impl<H> Capabilities<H> for AzureExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Capable<TranscriptionModel<H>>;
    // Model listing is not implemented for Azure.
    type ModelListing = Nothing;
    // NOTE(review): an `ImageGenerationModel` is defined later in this file but the
    // capability is declared `Nothing` here — confirm whether that is intentional.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
107
108impl ProviderBuilder for AzureExtBuilder {
109 type Extension<H>
110 = AzureExt
111 where
112 H: HttpClientExt;
113 type ApiKey = AzureOpenAIAuth;
114
115 const BASE_URL: &'static str = "";
116
117 fn build<H>(
118 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
119 ) -> http_client::Result<Self::Extension<H>>
120 where
121 H: HttpClientExt,
122 {
123 let AzureExtBuilder {
124 endpoint,
125 api_version,
126 ..
127 } = builder.ext().clone();
128
129 match endpoint {
130 Some(endpoint) => Ok(AzureExt {
131 endpoint,
132 api_version,
133 }),
134 None => Err(http_client::Error::Instance(
135 "Azure client must be provided an endpoint prior to building".into(),
136 )),
137 }
138 }
139
140 fn finish<H>(
141 &self,
142 mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
143 ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
144 use AzureOpenAIAuth::*;
145
146 let auth = builder.get_api_key().clone();
147
148 match auth {
149 Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
150 ApiKey(key) => {
151 let k = http::HeaderName::from_static("api-key");
152 let v = http::HeaderValue::from_str(key.as_str())?;
153
154 builder.headers_mut().insert(k, v);
155 }
156 }
157
158 Ok(builder)
159 }
160}
161
162impl<H> ClientBuilder<H> {
163 pub fn api_version(mut self, api_version: &str) -> Self {
165 self.ext_mut().api_version = api_version.into();
166
167 self
168 }
169}
170
171impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
172 pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
174 self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
175 endpoint: Some(endpoint),
176 api_version,
177 })
178 }
179}
180
/// Credential used to authenticate against Azure OpenAI: either a static
/// `api-key` header value or a bearer token.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Sent via the Azure-specific `api-key` request header.
    ApiKey(String),
    /// Sent via an `Authorization: Bearer ...` header.
    Token(String),
}

/// Marker impl so this credential type can be used by the generic client builder.
impl ApiKey for AzureOpenAIAuth {}
190
191impl std::fmt::Debug for AzureOpenAIAuth {
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 match self {
194 Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
195 Self::Token(_) => write!(f, "Token <REDACTED>"),
196 }
197 }
198}
199
200impl<S> From<S> for AzureOpenAIAuth
201where
202 S: Into<String>,
203{
204 fn from(token: S) -> Self {
205 AzureOpenAIAuth::Token(token.into())
206 }
207}
208
209impl<T> Client<T>
210where
211 T: HttpClientExt,
212{
213 fn endpoint(&self) -> &str {
214 &self.ext().endpoint
215 }
216
217 fn api_version(&self) -> &str {
218 &self.ext().api_version
219 }
220
221 fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
222 let url = format!(
223 "{}/openai/deployments/{}/embeddings?api-version={}",
224 self.endpoint(),
225 deployment_id.trim_start_matches('/'),
226 self.api_version()
227 );
228
229 self.post(&url)
230 }
231
232 #[cfg(feature = "audio")]
233 fn post_audio_generation(
234 &self,
235 deployment_id: &str,
236 ) -> http_client::Result<http_client::Builder> {
237 let url = format!(
238 "{}/openai/deployments/{}/audio/speech?api-version={}",
239 self.endpoint(),
240 deployment_id.trim_start_matches('/'),
241 self.api_version()
242 );
243
244 self.post(url)
245 }
246
247 fn post_chat_completion(
248 &self,
249 deployment_id: &str,
250 ) -> http_client::Result<http_client::Builder> {
251 let url = format!(
252 "{}/openai/deployments/{}/chat/completions?api-version={}",
253 self.endpoint(),
254 deployment_id.trim_start_matches('/'),
255 self.api_version()
256 );
257
258 self.post(&url)
259 }
260
261 fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
262 let url = format!(
263 "{}/openai/deployments/{}/audio/translations?api-version={}",
264 self.endpoint(),
265 deployment_id.trim_start_matches('/'),
266 self.api_version()
267 );
268
269 self.post(&url)
270 }
271
272 #[cfg(feature = "image")]
273 fn post_image_generation(
274 &self,
275 deployment_id: &str,
276 ) -> http_client::Result<http_client::Builder> {
277 let url = format!(
278 "{}/openai/deployments/{}/images/generations?api-version={}",
279 self.endpoint(),
280 deployment_id.trim_start_matches('/'),
281 self.api_version()
282 );
283
284 self.post(&url)
285 }
286}
287
/// Parameters accepted by [`ProviderClient::from_val`].
pub struct AzureOpenAIClientParams {
    // API key for the Azure resource.
    api_key: String,
    // Value for the `api-version` query parameter.
    version: String,
    // NOTE(review): despite the name, `from_val` feeds this into `azure_endpoint`,
    // so it carries the resource endpoint URL — confirm and consider renaming.
    header: String,
}
293
294impl ProviderClient for Client {
295 type Input = AzureOpenAIClientParams;
296
297 fn from_env() -> Self {
299 let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
300 AzureOpenAIAuth::ApiKey(api_key)
301 } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
302 AzureOpenAIAuth::Token(token)
303 } else {
304 panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
305 };
306
307 let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
308 let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");
309
310 Self::builder()
311 .api_key(auth)
312 .azure_endpoint(azure_endpoint)
313 .api_version(&api_version)
314 .build()
315 .unwrap()
316 }
317
318 fn from_val(
319 AzureOpenAIClientParams {
320 api_key,
321 version,
322 header,
323 }: Self::Input,
324 ) -> Self {
325 let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
326
327 Self::builder()
328 .api_key(auth)
329 .azure_endpoint(header)
330 .api_version(&version)
331 .build()
332 .unwrap()
333 }
334}
335
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged envelope: deserialization tries the success payload `T` first and
/// falls back to the error shape. Variant order matters with `#[serde(untagged)]`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
347
/// 3072-dimensional embedding model (see `model_dimensions_from_identifier`).
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// 1536-dimensional embedding model.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// Legacy 1536-dimensional model; `embed_texts` never sends a `dimensions`
/// field for this identifier.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
358
359fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
360 match identifier {
361 TEXT_EMBEDDING_3_LARGE => Some(3_072),
362 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
363 _ => None,
364 }
365}
366
/// Successful response body from the embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    // One entry per input document, in request order.
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
374
375impl From<ApiErrorResponse> for EmbeddingError {
376 fn from(err: ApiErrorResponse) -> Self {
377 EmbeddingError::ProviderError(err.message)
378 }
379}
380
381impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
382 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
383 match value {
384 ApiResponse::Ok(response) => Ok(response),
385 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
386 }
387 }
388}
389
/// A single embedding vector within an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    // Position of the corresponding input document.
    pub index: usize,
}

/// Token accounting as reported by the embeddings endpoint
/// (no separate output-token count is provided here).
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
402
403impl GetTokenUsage for Usage {
404 fn token_usage(&self) -> Option<crate::completion::Usage> {
405 let mut usage = crate::completion::Usage::new();
406
407 usage.input_tokens = self.prompt_tokens as u64;
408 usage.total_tokens = self.total_tokens as u64;
409 usage.output_tokens = usage.total_tokens - usage.input_tokens;
410
411 Some(usage)
412 }
413}
414
415impl std::fmt::Display for Usage {
416 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417 write!(
418 f,
419 "Prompt tokens: {} Total tokens: {}",
420 self.prompt_tokens, self.total_tokens
421 )
422 }
423}
424
/// Embedding model handle; `model` doubles as the Azure deployment id when
/// request URLs are built (see `Client::post_embedding`).
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    // Expected output dimensionality; 0 means "unknown / use the model default".
    ndims: usize,
}
431
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone + 'static,
{
    // Maximum number of documents per request.
    // NOTE(review): value taken as-is — confirm against current Azure service limits.
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds a batch of documents in a single request to this deployment.
    ///
    /// # Errors
    /// `EmbeddingError::ProviderError` for API-level errors and non-success HTTP
    /// statuses; `EmbeddingError::ResponseError` when the number of returned
    /// vectors does not match the number of inputs.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let mut body = json!({
            "input": documents,
        });

        // Only request an explicit dimensionality when one is configured; the
        // ada-002 model is skipped — presumably it rejects a `dimensions`
        // field (TODO confirm against service docs).
        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
            body["dimensions"] = json!(self.ndims);
        }

        let body = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post_embedding(self.model.as_str())?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    // The API returns one vector per input; anything else is malformed.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    // Pair each vector with its source document, relying on response order.
                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx: surface the raw response body as the error message.
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}
507
508impl<T> EmbeddingModel<T> {
509 pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
510 let model = model.into();
511 let ndims = ndims
512 .or(model_dimensions_from_identifier(&model))
513 .unwrap_or_default();
514
515 Self {
516 client,
517 model,
518 ndims,
519 }
520 }
521
522 pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
523 let ndims = ndims.unwrap_or_default();
524
525 Self {
526 client,
527 model: model.into(),
528 ndims,
529 }
530 }
531}
532
// Deployment-name identifiers for chat models commonly hosted on Azure OpenAI.
pub const O1: &str = "o1";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_MINI: &str = "o1-mini";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
// NOTE(review): resolves to "gpt-4", identical to `GPT_4` below —
// confirm whether "gpt-4-turbo" was intended.
pub const GPT_4_TURBO: &str = "gpt-4";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_32K: &str = "gpt-4-32k";
// NOTE(review): resolves to "gpt-4-32k", identical to `GPT_4_32K` —
// confirm whether "gpt-4-32k-0613" was intended.
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
563
/// Wire-format request body for Azure OpenAI chat completions.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AzureOpenAICompletionRequest {
    model: String,
    pub messages: Vec<openai::Message>,
    // Omitted from the payload when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Omitted when no tools are defined.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::ToolChoice>,
    // Arbitrary extra fields (e.g. `stream`, `response_format`) flattened into
    // the top level of the serialized request.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
577
578impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
579 type Error = CompletionError;
580
581 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
582 let model = req.model.clone().unwrap_or_else(|| model.to_string());
583 if req.tool_choice.is_some() {
585 tracing::warn!(
586 "Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25."
587 );
588 }
589
590 let mut full_history: Vec<openai::Message> = match &req.preamble {
591 Some(preamble) => vec![openai::Message::system(preamble)],
592 None => vec![],
593 };
594
595 if let Some(docs) = req.normalized_documents() {
596 let docs: Vec<openai::Message> = docs.try_into()?;
597 full_history.extend(docs);
598 }
599
600 let chat_history: Vec<openai::Message> = req
601 .chat_history
602 .clone()
603 .into_iter()
604 .map(|message| message.try_into())
605 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
606 .into_iter()
607 .flatten()
608 .collect();
609
610 full_history.extend(chat_history);
611
612 let tool_choice = req
613 .tool_choice
614 .clone()
615 .map(crate::providers::openai::ToolChoice::try_from)
616 .transpose()?;
617
618 let additional_params = if let Some(schema) = req.output_schema {
619 let name = schema
620 .as_object()
621 .and_then(|o| o.get("title"))
622 .and_then(|v| v.as_str())
623 .unwrap_or("response_schema")
624 .to_string();
625 let mut schema_value = schema.to_value();
626 openai::sanitize_schema(&mut schema_value);
627 let response_format = serde_json::json!({
628 "response_format": {
629 "type": "json_schema",
630 "json_schema": {
631 "name": name,
632 "strict": true,
633 "schema": schema_value
634 }
635 }
636 });
637 Some(match req.additional_params {
638 Some(existing) => json_utils::merge(existing, response_format),
639 None => response_format,
640 })
641 } else {
642 req.additional_params
643 };
644
645 Ok(Self {
646 model: model.to_string(),
647 messages: full_history,
648 temperature: req.temperature,
649 tools: req
650 .tools
651 .clone()
652 .into_iter()
653 .map(openai::ToolDefinition::from)
654 .collect::<Vec<_>>(),
655 tool_choice,
656 additional_params,
657 })
658 }
659}
660
/// Chat-completion model handle; `model` is used as the Azure deployment id in
/// request URLs (see `Client::post_chat_completion`).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
667
668impl<T> CompletionModel<T> {
669 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
670 Self {
671 client,
672 model: model.into(),
673 }
674 }
675}
676
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Performs a non-streaming chat completion against this deployment.
    ///
    /// # Errors
    /// `CompletionError::ProviderError` for API-level errors and non-success HTTP
    /// statuses; JSON (de)serialization failures convert via `?`.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        // Reuse an ambient tracing span when one is active; otherwise open a
        // GenAI-convention span whose response fields are recorded later.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printing is guarded so its cost is only paid at TRACE level.
        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        // The send/parse phase runs inside the span via `.instrument`.
        async move {
            let response = self.client.send::<_, Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        // Record telemetry on the span made current by `.instrument`.
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Azure OpenAI completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as the provider error message.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    /// Performs a streaming chat completion.
    ///
    /// Forces `stream: true` and `stream_options.include_usage: true` into the
    /// request's additional parameters, then delegates to the shared
    /// OpenAI-compatible streaming transport.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        // Captured before the request is consumed; used for the span field below.
        let preamble = completion_request.preamble.clone();
        let mut request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing_futures::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}
818
/// Speech-to-text model handle; `model` is used as the Azure deployment id in
/// request URLs (see `Client::post_transcription`).
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
829
830impl<T> TranscriptionModel<T> {
831 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
832 Self {
833 client,
834 model: model.into(),
835 }
836 }
837}
838
impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = TranscriptionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Uploads audio as multipart form data and returns the transcription.
    ///
    /// # Errors
    /// `TranscriptionError::ProviderError` for API-level errors and non-success
    /// HTTP statuses.
    ///
    /// # Panics
    /// Panics when `additional_params` is provided but is not a JSON object.
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        // The audio payload is sent as the `file` part, preserving the original filename.
        let mut body =
            MultipartForm::new().part(Part::bytes("file", data).filename(request.filename.clone()));

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        // Any extra parameters are flattened into additional form fields.
        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to OpenAI Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post_transcription(&self.model)?
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            // Non-2xx: surface the raw body as the provider error message.
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}
903
904#[cfg(feature = "image")]
908pub use image_generation::*;
909use tracing::{Instrument, Level, enabled, info_span};
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

    /// Image-generation model handle bound to an Azure deployment.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                model: model.into(),
            }
        }

        /// Requests a base64-encoded image for the prompt.
        ///
        /// # Errors
        /// `ImageGenerationError::ProviderError` formatted as `{status}: {body}`
        /// on non-success statuses, or the API's own error message.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // `b64_json` keeps the image bytes in the JSON payload rather than a URL.
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_image_generation(&self.model)?
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }
}
980#[cfg(feature = "audio")]
985pub use audio_generation::*;
986
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

    /// Text-to-speech model handle bound to an Azure audio deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        /// Creates a handle bound to the given Azure deployment name.
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Sends a speech-synthesis request and returns the raw audio bytes.
        ///
        /// # Errors
        /// `AudioGenerationError::ProviderError` formatted as `{status}: {body}`
        /// on any non-success HTTP status.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let payload = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&payload)?;

            // BUG FIX: the deployment name must be passed here. Previously the
            // route literal "/audio/speech" was used as the deployment id, which
            // `post_audio_generation` expanded into the malformed URL
            // `{endpoint}/openai/deployments/audio/speech/audio/speech?...`.
            let req = self
                .client
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}
1062
1063#[cfg(test)]
1064mod azure_tests {
1065 use schemars::JsonSchema;
1066
1067 use super::*;
1068
1069 use crate::OneOrMany;
1070 use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
1071 use crate::completion::CompletionModel;
1072 use crate::embeddings::EmbeddingModel;
1073 use crate::prelude::TypedPrompt;
1074 use crate::providers::openai::GPT_5_MINI;
1075
1076 #[tokio::test]
1077 #[ignore]
1078 async fn test_azure_embedding() {
1079 let _ = tracing_subscriber::fmt::try_init();
1080
1081 let client = Client::from_env();
1082 let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
1083 let embeddings = model
1084 .embed_texts(vec!["Hello, world!".to_string()])
1085 .await
1086 .unwrap();
1087
1088 tracing::info!("Azure embedding: {:?}", embeddings);
1089 }
1090
1091 #[tokio::test]
1092 #[ignore]
1093 async fn test_azure_embedding_dimensions() {
1094 let _ = tracing_subscriber::fmt::try_init();
1095
1096 let ndims = 256;
1097 let client = Client::from_env();
1098 let model = client.embedding_model_with_ndims(TEXT_EMBEDDING_3_SMALL, ndims);
1099 let embedding = model.embed_text("Hello, world!").await.unwrap();
1100
1101 assert!(embedding.vec.len() == ndims);
1102
1103 tracing::info!("Azure dimensions embedding: {:?}", embedding);
1104 }
1105
1106 #[tokio::test]
1107 #[ignore]
1108 async fn test_azure_completion() {
1109 let _ = tracing_subscriber::fmt::try_init();
1110
1111 let client = Client::from_env();
1112 let model = client.completion_model(GPT_4O_MINI);
1113 let completion = model
1114 .completion(CompletionRequest {
1115 model: None,
1116 preamble: Some("You are a helpful assistant.".to_string()),
1117 chat_history: OneOrMany::one("Hello!".into()),
1118 documents: vec![],
1119 max_tokens: Some(100),
1120 temperature: Some(0.0),
1121 tools: vec![],
1122 tool_choice: None,
1123 additional_params: None,
1124 output_schema: None,
1125 })
1126 .await
1127 .unwrap();
1128
1129 tracing::info!("Azure completion: {:?}", completion);
1130 }
1131
1132 #[tokio::test]
1133 #[ignore]
1134 async fn test_azure_structured_output() {
1135 let _ = tracing_subscriber::fmt::try_init();
1136
1137 #[derive(Debug, Deserialize, JsonSchema)]
1138 struct Person {
1139 name: String,
1140 age: u32,
1141 }
1142
1143 let client = Client::from_env();
1144 let agent = client
1145 .agent(GPT_5_MINI)
1146 .preamble("You are a helpful assistant that extracts personal details.")
1147 .max_tokens(100)
1148 .output_schema::<Person>()
1149 .build();
1150
1151 let result: Person = agent
1152 .prompt_typed("Hello! My name is John Doe and I'm 54 years old.")
1153 .await
1154 .expect("failed to extract person");
1155
1156 assert!(result.name == "John Doe");
1157 assert!(result.age == 54);
1158
1159 tracing::info!("Extracted person: {:?}", result);
1160 }
1161
1162 #[tokio::test]
1163 async fn test_client_initialization() {
1164 let _client = crate::providers::azure::Client::builder()
1165 .api_key("test")
1166 .azure_endpoint("test".to_string()) .build()
1168 .expect("Client::builder() failed");
1169 }
1170}