1use std::fmt::Debug;
23
24use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
25use crate::client::{
26 self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
27 ProviderClient,
28};
29use crate::completion::GetTokenUsage;
30use crate::http_client::multipart::Part;
31use crate::http_client::{self, HttpClientExt, MultipartForm, bearer_auth_header};
32use crate::streaming::StreamingCompletionResponse;
33use crate::transcription::TranscriptionError;
34use crate::{
35 completion::{self, CompletionError, CompletionRequest},
36 embeddings::{self, EmbeddingError},
37 json_utils,
38 providers::openai,
39 telemetry::SpanCombinator,
40 transcription::{self},
41};
42use bytes::Bytes;
43use serde::{Deserialize, Serialize};
44use serde_json::json;
/// Default `api-version` query value used when the builder is not given one
/// explicitly (see [`AzureExtBuilder::default`]).
const DEFAULT_API_VERSION: &str = "2024-10-21";
50
/// Azure-specific client extension state carried by the generic client:
/// the per-resource endpoint and the `api-version` appended to request URLs.
#[derive(Debug, Clone)]
pub struct AzureExt {
    // Base resource URL (presumably `https://{resource}.openai.azure.com` —
    // confirm against deployment docs); required at build time.
    endpoint: String,
    // Value of the `api-version` query parameter on every request.
    api_version: String,
}
56
57impl DebugExt for AzureExt {
58 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn std::fmt::Debug)> {
59 [
60 ("endpoint", (&self.endpoint as &dyn Debug)),
61 ("api_version", (&self.api_version as &dyn Debug)),
62 ]
63 .into_iter()
64 }
65}
66
/// Builder-side state for [`AzureExt`]. The endpoint is mandatory — building
/// without one fails — while `api_version` falls back to
/// [`DEFAULT_API_VERSION`].
#[derive(Debug, Clone)]
pub struct AzureExtBuilder {
    endpoint: Option<String>,
    api_version: String,
}
76
77impl Default for AzureExtBuilder {
78 fn default() -> Self {
79 Self {
80 endpoint: None,
81 api_version: DEFAULT_API_VERSION.into(),
82 }
83 }
84}
85
/// Azure OpenAI client: the shared generic client specialized with the Azure
/// extension, defaulting to a `reqwest` HTTP backend.
pub type Client<H = reqwest::Client> = client::Client<AzureExt, H>;
/// Builder for [`Client`], parameterized over the HTTP implementation.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H>;
89
impl Provider for AzureExt {
    type Builder = AzureExtBuilder;

    // NOTE(review): empty path — presumably disables endpoint verification
    // for Azure; confirm against the `Provider` trait's contract.
    const VERIFY_PATH: &'static str = "";
}
96
/// Capability map for Azure: completions, embeddings and transcription are
/// always available; audio generation only with the `audio` feature.
impl<H> Capabilities<H> for AzureExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    // NOTE(review): an `ImageGenerationModel` is defined below under the
    // `image` feature but is not registered here — confirm this is intended.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
107
impl ProviderBuilder for AzureExtBuilder {
    type Extension<H>
        = AzureExt
    where
        H: HttpClientExt;
    type ApiKey = AzureOpenAIAuth;

    // Azure has no fixed base URL; the per-resource endpoint is supplied via
    // `azure_endpoint` instead, so this stays empty.
    const BASE_URL: &'static str = "";

    /// Finalizes the Azure extension, failing if no endpoint was configured.
    fn build<H>(
        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        let AzureExtBuilder {
            endpoint,
            api_version,
            ..
        } = builder.ext().clone();

        match endpoint {
            Some(endpoint) => Ok(AzureExt {
                endpoint,
                api_version,
            }),
            None => Err(http_client::Error::Instance(
                "Azure client must be provided an endpoint prior to building".into(),
            )),
        }
    }

    /// Installs the authentication header on the builder: a standard
    /// `Authorization: Bearer …` for tokens, or Azure's `api-key` header for
    /// API keys.
    fn finish<H>(
        &self,
        mut builder: client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<client::ClientBuilder<Self, Self::ApiKey, H>> {
        use AzureOpenAIAuth::*;

        let auth = builder.get_api_key().clone();

        match auth {
            Token(token) => bearer_auth_header(builder.headers_mut(), token.as_str())?,
            ApiKey(key) => {
                let k = http::HeaderName::from_static("api-key");
                let v = http::HeaderValue::from_str(key.as_str())?;

                builder.headers_mut().insert(k, v);
            }
        }

        Ok(builder)
    }
}
161
impl<H> ClientBuilder<H> {
    /// Overrides the `api-version` query value sent with every request
    /// (defaults to [`DEFAULT_API_VERSION`]).
    pub fn api_version(mut self, api_version: &str) -> Self {
        self.ext_mut().api_version = api_version.into();

        self
    }
}
170
impl<H> client::ClientBuilder<AzureExtBuilder, AzureOpenAIAuth, H> {
    /// Sets the Azure resource endpoint. This is required: `build()` fails
    /// if it was never called.
    pub fn azure_endpoint(self, endpoint: String) -> ClientBuilder<H> {
        // Preserve any previously configured api_version while filling in
        // the endpoint.
        self.over_ext(|AzureExtBuilder { api_version, .. }| AzureExtBuilder {
            endpoint: Some(endpoint),
            api_version,
        })
    }
}
180
/// Authentication material for Azure OpenAI: either a resource API key
/// (sent in the `api-key` header) or a bearer token (sent via
/// `Authorization: Bearer …`).
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    ApiKey(String),
    Token(String),
}
188
// Marker impl so the enum can serve as the builder's credential type.
impl ApiKey for AzureOpenAIAuth {}
190
191impl std::fmt::Debug for AzureOpenAIAuth {
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 match self {
194 Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
195 Self::Token(_) => write!(f, "Token <REDACTED>"),
196 }
197 }
198}
199
/// Blanket conversion from any string-like value.
///
/// NOTE(review): this produces a bearer `Token`, not an `ApiKey` — callers
/// wanting key auth must construct that variant explicitly. Confirm this
/// default is intentional.
impl<S> From<S> for AzureOpenAIAuth
where
    S: Into<String>,
{
    fn from(token: S) -> Self {
        AzureOpenAIAuth::Token(token.into())
    }
}
208
209impl<T> Client<T>
210where
211 T: HttpClientExt,
212{
213 fn endpoint(&self) -> &str {
214 &self.ext().endpoint
215 }
216
217 fn api_version(&self) -> &str {
218 &self.ext().api_version
219 }
220
221 fn post_embedding(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
222 let url = format!(
223 "{}/openai/deployments/{}/embeddings?api-version={}",
224 self.endpoint(),
225 deployment_id.trim_start_matches('/'),
226 self.api_version()
227 );
228
229 self.post(&url)
230 }
231
232 #[cfg(feature = "audio")]
233 fn post_audio_generation(
234 &self,
235 deployment_id: &str,
236 ) -> http_client::Result<http_client::Builder> {
237 let url = format!(
238 "{}/openai/deployments/{}/audio/speech?api-version={}",
239 self.endpoint(),
240 deployment_id.trim_start_matches('/'),
241 self.api_version()
242 );
243
244 self.post(url)
245 }
246
247 fn post_chat_completion(
248 &self,
249 deployment_id: &str,
250 ) -> http_client::Result<http_client::Builder> {
251 let url = format!(
252 "{}/openai/deployments/{}/chat/completions?api-version={}",
253 self.endpoint(),
254 deployment_id.trim_start_matches('/'),
255 self.api_version()
256 );
257
258 self.post(&url)
259 }
260
261 fn post_transcription(&self, deployment_id: &str) -> http_client::Result<http_client::Builder> {
262 let url = format!(
263 "{}/openai/deployments/{}/audio/translations?api-version={}",
264 self.endpoint(),
265 deployment_id.trim_start_matches('/'),
266 self.api_version()
267 );
268
269 self.post(&url)
270 }
271
272 #[cfg(feature = "image")]
273 fn post_image_generation(
274 &self,
275 deployment_id: &str,
276 ) -> http_client::Result<http_client::Builder> {
277 let url = format!(
278 "{}/openai/deployments/{}/images/generations?api-version={}",
279 self.endpoint(),
280 deployment_id.trim_start_matches('/'),
281 self.api_version()
282 );
283
284 self.post(&url)
285 }
286}
287
/// Parameters accepted by [`ProviderClient::from_val`] for Azure.
pub struct AzureOpenAIClientParams {
    api_key: String,
    // `api-version` query value.
    version: String,
    // NOTE(review): despite the name, `from_val` passes this value as the
    // Azure *endpoint* — confirm and consider renaming.
    header: String,
}
293
294impl ProviderClient for Client {
295 type Input = AzureOpenAIClientParams;
296 type Error = crate::client::ProviderClientError;
297
298 fn from_env() -> Result<Self, Self::Error> {
300 let auth = if let Some(api_key) = crate::client::optional_env_var("AZURE_API_KEY")? {
301 AzureOpenAIAuth::ApiKey(api_key)
302 } else if let Some(token) = crate::client::optional_env_var("AZURE_TOKEN")? {
303 AzureOpenAIAuth::Token(token)
304 } else {
305 return Err(crate::client::ProviderClientError::InvalidConfiguration(
306 "either `AZURE_API_KEY` or `AZURE_TOKEN` must be set",
307 ));
308 };
309
310 let api_version = crate::client::required_env_var("AZURE_API_VERSION")?;
311 let azure_endpoint = crate::client::required_env_var("AZURE_ENDPOINT")?;
312
313 Self::builder()
314 .api_key(auth)
315 .azure_endpoint(azure_endpoint)
316 .api_version(&api_version)
317 .build()
318 .map_err(Into::into)
319 }
320
321 fn from_val(
322 AzureOpenAIClientParams {
323 api_key,
324 version,
325 header,
326 }: Self::Input,
327 ) -> Result<Self, Self::Error> {
328 let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
329
330 Self::builder()
331 .api_key(auth)
332 .azure_endpoint(header)
333 .api_version(&version)
334 .build()
335 .map_err(Into::into)
336 }
337}
338
/// Error payload shape returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
343
/// Untagged response wrapper: serde first tries the success payload `T`,
/// then falls back to the provider error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
350
/// Identifier of the `text-embedding-3-large` model.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// Identifier of the `text-embedding-3-small` model.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// Identifier of the `text-embedding-ada-002` model.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

/// Default embedding width for the known embedding model identifiers above;
/// `None` for anything unrecognized.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    if identifier == TEXT_EMBEDDING_3_LARGE {
        Some(3_072)
    } else if identifier == TEXT_EMBEDDING_3_SMALL || identifier == TEXT_EMBEDDING_ADA_002 {
        Some(1_536)
    } else {
        None
    }
}
369
/// Successful embeddings response payload (OpenAI-compatible shape).
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
377
// Surface API error messages as provider errors for the embeddings path.
impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}
383
// Flatten the untagged wrapper into the crate's Result type.
impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}
392
/// One embedding vector plus its position in the request batch.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
399
/// Token-usage section of an embeddings response. Azure reports only prompt
/// and total counts here.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
405
406impl GetTokenUsage for Usage {
407 fn token_usage(&self) -> Option<crate::completion::Usage> {
408 let mut usage = crate::completion::Usage::new();
409
410 usage.input_tokens = self.prompt_tokens as u64;
411 usage.total_tokens = self.total_tokens as u64;
412 usage.output_tokens = usage.total_tokens - usage.input_tokens;
413
414 Some(usage)
415 }
416}
417
418impl std::fmt::Display for Usage {
419 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
420 write!(
421 f,
422 "Prompt tokens: {} Total tokens: {}",
423 self.prompt_tokens, self.total_tokens
424 )
425 }
426}
427
/// Azure OpenAI embedding model bound to a specific deployment.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    // Azure deployment name; also used to infer default dimensions.
    pub model: String,
    // Requested embedding width; 0 acts as "unspecified" (see embed_texts).
    ndims: usize,
}
434
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone + 'static,
{
    // Upper bound on documents per request batch.
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds a batch of texts in a single request against the deployment.
    ///
    /// The `dimensions` field is only sent when a nonzero width was
    /// configured and the model is not ada-002 (which this code treats as
    /// not accepting the parameter). The response is validated to contain
    /// exactly one vector per input document.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let mut body = json!({
            "input": documents,
        });

        let body_object = body.as_object_mut().ok_or_else(|| {
            EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
        })?;

        // ndims == 0 is the "unspecified" sentinel; ada-002 gets no
        // `dimensions` field at all.
        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
            body_object.insert("dimensions".to_owned(), json!(self.ndims));
        }

        let body = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post_embedding(self.model.as_str())?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    // Pairing below relies on the API returning one vector
                    // per input, in order — fail loudly if counts differ.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx: surface the raw body text as a provider error.
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}
514
impl<T> EmbeddingModel<T> {
    /// Creates an embedding model; when `ndims` is `None`, a default width
    /// is inferred from known model identifiers, otherwise 0 ("unspecified").
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
        let model = model.into();
        let ndims = ndims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();

        Self {
            client,
            model,
            ndims,
        }
    }

    /// Like `new`, but without the identifier-based dimension inference.
    ///
    /// NOTE(review): unlike `new`, a `None` here always yields 0 dimensions
    /// even for known models — confirm this asymmetry is intentional.
    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
        let ndims = ndims.unwrap_or_default();

        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}
539
// Default deployment/model identifiers for Azure OpenAI chat models.
// NOTE(review): several constants intentionally(?) share a value — Azure
// distinguishes e.g. turbo and 0613 variants by a separate model *version*
// rather than the model name. Confirm before "fixing" the duplicates.
pub const O1: &str = "o1";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_MINI: &str = "o1-mini";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
// Same value as `GPT_4` (see note above).
pub const GPT_4_TURBO: &str = "gpt-4";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_32K: &str = "gpt-4-32k";
// Same value as `GPT_4_32K` (see note above).
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
570
/// Wire format of an Azure OpenAI chat-completions request body.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AzureOpenAICompletionRequest {
    model: String,
    pub messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::ToolChoice>,
    // Flattened so arbitrary extra fields (e.g. response_format, stream)
    // land at the top level of the request body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
584
/// Builds the wire-format request from a generic [`CompletionRequest`] plus
/// the model's default deployment name (used when the request names none).
///
/// Message order: optional system preamble, then normalized documents, then
/// the converted chat history. When an output schema is present it is
/// sanitized and injected as a strict `response_format` json_schema, merged
/// over any caller-supplied additional params.
impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        // Per-request model override takes precedence over the configured one.
        let model = req.model.clone().unwrap_or_else(|| model.to_string());
        // NOTE(review): this warns that tool_choice is unsupported, yet the
        // value is converted and serialized below — confirm the warning is
        // not stale.
        if req.tool_choice.is_some() {
            tracing::warn!(
                "Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25."
            );
        }

        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        // Each history entry may expand to several OpenAI messages; convert
        // fallibly, then flatten.
        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        let additional_params = if let Some(schema) = req.output_schema {
            // Prefer the schema's own `title` as the response-format name.
            let name = schema
                .as_object()
                .and_then(|o| o.get("title"))
                .and_then(|v| v.as_str())
                .unwrap_or("response_schema")
                .to_string();
            let mut schema_value = schema.to_value();
            openai::sanitize_schema(&mut schema_value);
            let response_format = serde_json::json!({
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": name,
                        "strict": true,
                        "schema": schema_value
                    }
                }
            });
            Some(match req.additional_params {
                Some(existing) => json_utils::merge(existing, response_format),
                None => response_format,
            })
        } else {
            req.additional_params
        };

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params,
        })
    }
}
667
/// Azure OpenAI chat-completion model bound to a specific deployment.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    // Azure deployment name used in the request URL and as default model id.
    pub model: String,
}
674
675impl<T> CompletionModel<T> {
676 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
677 Self {
678 client,
679 model: model.into(),
680 }
681 }
682}
683
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Performs a non-streaming chat completion against the deployment.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        // Only open a fresh telemetry span if we are not already inside one.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printing the request is non-trivial, so gate it on TRACE.
        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        // Record response id/model and token usage on the span.
                        let span = tracing::Span::current();
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Azure OpenAI completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as a provider error.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    /// Performs a streaming chat completion; forces `stream: true` and
    /// `stream_options.include_usage` into the request's extra params, then
    /// delegates to the shared OpenAI-compatible SSE handler.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let mut request =
            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Azure OpenAI completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post_chat_completion(&self.model)?
            .body(body)
            .map_err(http_client::Error::from)?;

        // Same reuse-or-create span policy as `completion`.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing_futures::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}
825
/// Azure OpenAI audio transcription model bound to a specific deployment.
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    // Azure deployment name used in the request URL.
    pub model: String,
}
836
837impl<T> TranscriptionModel<T> {
838 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
839 Self {
840 client,
841 model: model.into(),
842 }
843 }
844}
845
impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = TranscriptionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Transcribes audio by POSTing a multipart form (file, optional prompt,
    /// temperature, plus any additional params flattened to text fields).
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body =
            MultipartForm::new().part(Part::bytes("file", data).filename(request.filename.clone()));

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            // Only a JSON object can be flattened into form fields.
            let params = additional_params.as_object().ok_or_else(|| {
                TranscriptionError::RequestError(Box::new(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "additional transcription parameters must be a JSON object",
                )))
            })?;

            // Values are serialized with `to_string`, so string params will
            // include their JSON quotes — NOTE(review): confirm intended.
            for (key, value) in params {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post_transcription(&self.model)?
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self.client.send_multipart::<Bytes>(req).await?;
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            // Non-2xx: surface the raw body as a provider error.
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}
914
915#[cfg(feature = "image")]
919pub use image_generation::*;
920use tracing::{Instrument, Level, enabled, info_span};
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use bytes::Bytes;
    use serde_json::json;

    /// Image-generation model backed by an Azure OpenAI image deployment.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        // Azure deployment name used in the request URL.
        pub model: String,
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                model: model.into(),
            }
        }

        /// Generates an image, requesting a base64 payload
        /// (`response_format: "b64_json"`).
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post_image_generation(&self.model)?
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            // Non-2xx: include the status code alongside the raw body.
            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }
}
991#[cfg(feature = "audio")]
996pub use audio_generation::*;
997
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::http_client::HttpClientExt;
    use bytes::Bytes;
    use serde_json::json;

    /// Text-to-speech model backed by an Azure OpenAI audio deployment.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        // Azure deployment name used in the request URL.
        model: String,
    }

    impl<T> AudioGenerationModel<T> {
        /// Creates an audio model bound to `client` for the given deployment.
        pub fn new(client: Client<T>, deployment_name: impl Into<String>) -> Self {
            Self {
                client,
                model: deployment_name.into(),
            }
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = Bytes;
        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Synthesizes speech for `request.text`; the raw audio bytes are
        /// returned both as the payload and as the typed response.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let body = serde_json::to_vec(&request)?;

            // Fixed: the literal "/audio/speech" was previously passed as the
            // deployment id, yielding `.../deployments/audio/speech/audio/
            // speech?...`; the deployment (model) name must be passed here —
            // `post_audio_generation` appends the route itself.
            let req = self
                .client
                .post_audio_generation(&self.model)?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| AudioGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?;

            // Non-2xx: include the status code alongside the raw body.
            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            Ok(AudioGenerationResponse {
                audio: response_body.to_vec(),
                response: response_body,
            })
        }
    }
}
1073
#[cfg(test)]
mod azure_tests {
    use schemars::JsonSchema;

    use super::*;

    use crate::OneOrMany;
    use crate::client::{completion::CompletionClient, embeddings::EmbeddingsClient};
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;
    use crate::prelude::TypedPrompt;
    use crate::providers::openai::GPT_5_MINI;

    // The network-backed tests below are #[ignore]d; run them manually with
    // AZURE_API_KEY/AZURE_TOKEN, AZURE_API_VERSION and AZURE_ENDPOINT set.

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env()?;
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model.embed_texts(vec!["Hello, world!".to_string()]).await?;

        tracing::info!("Azure embedding: {:?}", embeddings);
        Ok(())
    }

    // Verifies that a custom `dimensions` request is honored by the API.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding_dimensions() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let ndims = 256;
        let client = Client::from_env()?;
        let model = client.embedding_model_with_ndims(TEXT_EMBEDDING_3_SMALL, ndims);
        let embedding = model.embed_text("Hello, world!").await?;

        anyhow::ensure!(
            embedding.vec.len() == ndims,
            "expected embedding dimensions {ndims}, got {}",
            embedding.vec.len()
        );

        tracing::info!("Azure dimensions embedding: {:?}", embedding);
        Ok(())
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env()?;
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                model: None,
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
                output_schema: None,
            })
            .await?;

        tracing::info!("Azure completion: {:?}", completion);
        Ok(())
    }

    // Exercises the structured-output (response_format json_schema) path.
    #[tokio::test]
    #[ignore]
    async fn test_azure_structured_output() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        #[derive(Debug, Deserialize, JsonSchema)]
        struct Person {
            name: String,
            age: u32,
        }

        let client = Client::from_env()?;
        let agent = client
            .agent(GPT_5_MINI)
            .preamble("You are a helpful assistant that extracts personal details.")
            .max_tokens(100)
            .output_schema::<Person>()
            .build();

        let result: Person = agent
            .prompt_typed("Hello! My name is John Doe and I'm 54 years old.")
            .await?;

        anyhow::ensure!(
            result.name == "John Doe",
            "expected name John Doe, got {}",
            result.name
        );
        anyhow::ensure!(result.age == 54, "expected age 54, got {}", result.age);

        tracing::info!("Extracted person: {:?}", result);
        Ok(())
    }

    // Offline sanity check: the builder succeeds once an endpoint is set.
    #[tokio::test]
    async fn test_client_initialization() {
        let _client = crate::providers::azure::Client::builder()
            .api_key("test")
            .azure_endpoint("test".to_string())
            .build()
            .expect("Client::builder() failed");
    }
}
1188}