use super::openai::{TranscriptionResponse, send_compatible_streaming_request};

use crate::client::{
    CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient,
    VerifyError,
};
use crate::completion::GetTokenUsage;
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::merge;
use crate::streaming::StreamingCompletionResponse;
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
    providers::openai,
    telemetry::SpanCombinator,
    transcription::{self, TranscriptionError},
};
use bytes::Bytes;
use reqwest::header::AUTHORIZATION;
use reqwest::multipart::Part;
use serde::Deserialize;
use serde_json::json;
use tracing::{Instrument, info_span};

const DEFAULT_API_VERSION: &str = "2024-10-21";

pub struct ClientBuilder<'a, T = reqwest::Client> {
    auth: AzureOpenAIAuth,
    api_version: Option<&'a str>,
    azure_endpoint: &'a str,
    http_client: T,
}

impl<'a, T> ClientBuilder<'a, T>
where
    T: Default,
{
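    /// Creates a new [`ClientBuilder`] from an auth value (API key or bearer
    /// token) and the Azure OpenAI resource endpoint.
    ///
    /// Illustrative sketch only: the key, endpoint, and `rig::providers::azure`
    /// path below are placeholders/assumptions, not values taken from this file.
    ///
    /// ```ignore
    /// use rig::providers::azure::{AzureOpenAIAuth, Client, ClientBuilder};
    ///
    /// let client: Client = ClientBuilder::new(
    ///     AzureOpenAIAuth::ApiKey("my-azure-api-key".to_string()),
    ///     "https://my-resource.openai.azure.com",
    /// )
    /// .api_version("2024-10-21")
    /// .build();
    /// ```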
    pub fn new(auth: impl Into<AzureOpenAIAuth>, endpoint: &'a str) -> Self {
        Self {
            auth: auth.into(),
            api_version: None,
            azure_endpoint: endpoint,
            http_client: Default::default(),
        }
    }
}

impl<'a, T> ClientBuilder<'a, T> {
    pub fn api_version(mut self, api_version: &'a str) -> Self {
        self.api_version = Some(api_version);
        self
    }

    pub fn azure_endpoint(mut self, azure_endpoint: &'a str) -> Self {
        self.azure_endpoint = azure_endpoint;
        self
    }

    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
        ClientBuilder {
            auth: self.auth,
            api_version: self.api_version,
            azure_endpoint: self.azure_endpoint,
            http_client,
        }
    }

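    /// Builds the [`Client`], falling back to `DEFAULT_API_VERSION` when no API
    /// version was set explicitly.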
    pub fn build(self) -> Client<T> {
        let api_version = self.api_version.unwrap_or(DEFAULT_API_VERSION);

        Client {
            api_version: api_version.to_string(),
            azure_endpoint: self.azure_endpoint.to_string(),
            auth: self.auth,
            http_client: self.http_client,
        }
    }
}

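/// A client for the Azure OpenAI service.
///
/// Requests are routed to
/// `{azure_endpoint}/openai/deployments/{deployment}/...` with the configured
/// `api-version` query parameter and authenticated via [`AzureOpenAIAuth`].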
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    api_version: String,
    azure_endpoint: String,
    auth: AzureOpenAIAuth,
    http_client: T,
}

impl<T> std::fmt::Debug for Client<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("azure_endpoint", &self.azure_endpoint)
            .field("http_client", &self.http_client)
            .field("auth", &"<REDACTED>")
            .field("api_version", &self.api_version)
            .finish()
    }
}

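/// Authentication for Azure OpenAI: either a resource API key (sent as the
/// `api-key` header) or a bearer token, e.g. a Microsoft Entra ID token (sent
/// as `Authorization: Bearer ...`).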
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    ApiKey(String),
    Token(String),
}

impl std::fmt::Debug for AzureOpenAIAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
            Self::Token(_) => write!(f, "Token <REDACTED>"),
        }
    }
}

impl From<String> for AzureOpenAIAuth {
    fn from(token: String) -> Self {
        AzureOpenAIAuth::Token(token)
    }
}

impl AzureOpenAIAuth {
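    /// Maps the auth variant to its request header: `ApiKey` becomes an
    /// `api-key` header, while `Token` becomes `Authorization: Bearer <token>`.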
    fn as_header(&self) -> (reqwest::header::HeaderName, reqwest::header::HeaderValue) {
        match self {
            AzureOpenAIAuth::ApiKey(api_key) => (
                "api-key".parse().expect("Header value should parse"),
                api_key.parse().expect("API key should parse"),
            ),
            AzureOpenAIAuth::Token(token) => (
                AUTHORIZATION,
                format!("Bearer {token}")
                    .parse()
                    .expect("Token should parse"),
            ),
        }
    }
}

impl<T> Client<T>
where
    T: Default,
{
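    /// Creates an Azure OpenAI client builder.
    ///
    /// Illustrative sketch only: the key, endpoint, deployment name, and
    /// `rig::providers::azure` path are placeholders/assumptions rather than
    /// values taken from this file.
    ///
    /// ```ignore
    /// use rig::providers::azure::{AzureOpenAIAuth, Client};
    ///
    /// let client: Client = Client::builder(
    ///     AzureOpenAIAuth::ApiKey("my-azure-api-key".to_string()),
    ///     "https://my-resource.openai.azure.com",
    /// )
    /// .api_version("2024-10-21")
    /// .build();
    ///
    /// // Models are addressed by deployment name, e.g.:
    /// // let gpt4o = client.completion_model("my-gpt-4o-deployment");
    /// ```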
    pub fn builder(auth: impl Into<AzureOpenAIAuth>, endpoint: &str) -> ClientBuilder<'_, T> {
        ClientBuilder::new(auth, endpoint)
    }

    pub fn new(auth: impl Into<AzureOpenAIAuth>, endpoint: &str) -> Self {
        Self::builder(auth, endpoint).build()
    }
}

impl<T> Client<T>
where
    T: HttpClientExt,
{
    fn post(&self, url: String) -> http_client::Builder {
        let (key, value) = self.auth.as_header();

        http_client::Request::post(url).header(key, value)
    }

    fn post_embedding(&self, deployment_id: &str) -> http_client::Builder {
        let url = format!(
            "{}/openai/deployments/{}/embeddings?api-version={}",
            self.azure_endpoint,
            deployment_id.trim_start_matches('/'),
            self.api_version
        );

        self.post(url)
    }

    async fn send<U, R>(
        &self,
        req: http_client::Request<U>,
    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
    where
        U: Into<Bytes> + Send,
        R: From<Bytes> + Send + 'static,
    {
        self.http_client.send(req).await
    }
}

impl Client<reqwest::Client> {
    fn reqwest_post(&self, url: String) -> reqwest::RequestBuilder {
        let (key, val) = self.auth.as_header();

        self.http_client.post(url).header(key, val)
    }

    #[cfg(feature = "audio")]
    fn post_audio_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/openai/deployments/{}/audio/speech?api-version={}",
            self.azure_endpoint, deployment_id, self.api_version
        )
        .replace("//", "/");

        self.reqwest_post(url)
    }

    fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.azure_endpoint, deployment_id, self.api_version
        )
        .replace("//", "/");

        self.reqwest_post(url)
    }

    fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/openai/deployments/{}/audio/transcriptions?api-version={}",
            self.azure_endpoint, deployment_id, self.api_version
        )
        .replace("//", "/");

        self.reqwest_post(url)
    }

    #[cfg(feature = "image")]
    fn post_image_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
        let url = format!(
            "{}/openai/deployments/{}/images/generations?api-version={}",
            self.azure_endpoint, deployment_id, self.api_version
        )
        .replace("//", "/");

        self.reqwest_post(url)
    }
}

impl ProviderClient for Client<reqwest::Client> {
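    /// Builds a client from environment variables: `AZURE_API_KEY` or
    /// `AZURE_TOKEN` for authentication, plus `AZURE_API_VERSION` and
    /// `AZURE_ENDPOINT`. Panics if no auth variable is set or if either of the
    /// other two is missing.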
    fn from_env() -> Self {
        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
            AzureOpenAIAuth::ApiKey(api_key)
        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
            AzureOpenAIAuth::Token(token)
        } else {
            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
        };

        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");

        Self::builder(auth, &azure_endpoint)
            .api_version(&api_version)
            .build()
    }

    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::ApiKeyWithVersionAndHeader(api_key, version, header) =
            input
        else {
            panic!("Incorrect provider value type")
        };
        let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
        Self::builder(auth, &header).api_version(&version).build()
    }
}

impl CompletionClient for Client<reqwest::Client> {
    type CompletionModel = CompletionModel<reqwest::Client>;

    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
        CompletionModel::new(self.clone(), model)
    }
}

impl EmbeddingsClient for Client<reqwest::Client> {
    type EmbeddingModel = EmbeddingModel<reqwest::Client>;

    fn embedding_model(&self, model: &str) -> EmbeddingModel<reqwest::Client> {
        let ndims = match model {
            TEXT_EMBEDDING_3_LARGE => 3072,
            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
            _ => 0,
        };
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    fn embedding_model_with_ndims(
        &self,
        model: &str,
        ndims: usize,
    ) -> EmbeddingModel<reqwest::Client> {
        EmbeddingModel::new(self.clone(), model, ndims)
    }
}

impl TranscriptionClient for Client<reqwest::Client> {
    type TranscriptionModel = TranscriptionModel<reqwest::Client>;

    fn transcription_model(&self, model: &str) -> TranscriptionModel<reqwest::Client> {
        TranscriptionModel::new(self.clone(), model)
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

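// ================================================================
// Azure OpenAI embedding model names
// ================================================================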
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}

impl GetTokenUsage for Usage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

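        // The embeddings usage payload only reports prompt and total tokens,
        // so the output token count is derived as the difference.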
        usage.input_tokens = self.prompt_tokens as u64;
        usage.total_tokens = self.total_tokens as u64;
        usage.output_tokens = usage.total_tokens - usage.input_tokens;

        Some(usage)
    }
}

impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}

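/// An Azure OpenAI embedding model, addressed by its deployment name.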
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone,
{
    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::to_vec(&json!({
            "input": documents,
        }))?;

        let req = self
            .client
            .post_embedding(&self.model)
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.to_string(),
            ndims,
        }
    }
}

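// ================================================================
// Azure OpenAI completion model names
// ================================================================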
pub const O1: &str = "o1";
pub const O1_PREVIEW: &str = "o1-preview";
pub const O1_MINI: &str = "o1-mini";
pub const GPT_4O: &str = "gpt-4o";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
pub const GPT_4_TURBO: &str = "gpt-4";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";

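/// An Azure OpenAI chat completion model, addressed by its deployment name.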
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

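    /// Builds the OpenAI-style chat completion request body: the system
    /// preamble (if present) comes first, followed by any normalized documents
    /// and the chat history. When tools are supplied they are attached along
    /// with `"tool_choice": "auto"`, and `additional_params` are merged last so
    /// callers can override the defaults.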
    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<serde_json::Value, CompletionError> {
        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };
        if let Some(docs) = completion_request.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }
        let chat_history: Vec<openai::Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": "auto",
            })
        };

        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}

impl completion::CompletionModel for CompletionModel<reqwest::Client> {
    type Response = openai::CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let request = self.create_completion_request(completion_request)?;
        span.record_model_input(
            &request
                .get("messages")
                .expect("Converting JSON should not fail"),
        );

        async move {
            let response = self
                .client
                .post_chat_completion(&self.model)
                .json(&request)
                .send()
                .await
                .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?;

            if response.status().is_success() {
                let t = response.text().await.map_err(|e| {
                    CompletionError::HttpError(http_client::Error::Instance(e.into()))
                })?;
                tracing::debug!(target: "rig", "Azure completion response: {}", t);

                match serde_json::from_str::<ApiResponse<openai::CompletionResponse>>(&t)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_model_output(&response.choices);
                        span.record_response_metadata(&response);
                        span.record_token_usage(&response.usage);
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    response.text().await.map_err(|e| {
                        CompletionError::HttpError(http_client::Error::Instance(e.into()))
                    })?,
                ))
            }
        }
        .instrument(span)
        .await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = request.preamble.clone();
        let mut request = self.create_completion_request(request)?;

        request = merge(
            request,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        let builder = self
            .client
            .post_chat_completion(self.model.as_str())
            .header("Content-Type", "application/json")
            .json(&request);

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "azure.openai",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing_futures::Instrument::instrument(send_compatible_streaming_request(builder), span)
            .await
    }
}

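/// An Azure OpenAI transcription (speech-to-text) model, addressed by its
/// deployment name.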
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

impl transcription::TranscriptionModel for TranscriptionModel<reqwest::Client> {
    type Response = TranscriptionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new().part(
            "file",
            Part::bytes(data).file_name(request.filename.clone()),
        );

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to OpenAI Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let response = self
            .client
            .post_transcription(&self.model)
            .multipart(body)
            .send()
            .await
            .map_err(|e| TranscriptionError::HttpError(http_client::Error::Instance(e.into())))?;

        if response.status().is_success() {
            match response
                .json::<ApiResponse<TranscriptionResponse>>()
                .await
                .map_err(|e| {
                    TranscriptionError::HttpError(http_client::Error::Instance(e.into()))
                })? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                response.text().await.map_err(|e| {
                    TranscriptionError::HttpError(http_client::Error::Instance(e.into()))
                })?,
            ))
        }
    }
}

#[cfg(feature = "image")]
pub use image_generation::*;

#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use crate::client::ImageGenerationClient;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::providers::azure::{ApiResponse, Client};
    use crate::providers::openai::ImageGenerationResponse;
    use crate::{http_client, image_generation};
    use serde_json::json;

    #[derive(Clone)]
    pub struct ImageGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        pub model: String,
    }
    impl image_generation::ImageGenerationModel for ImageGenerationModel<reqwest::Client> {
        type Response = ImageGenerationResponse;

        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let request = json!({
                "model": self.model,
                "prompt": generation_request.prompt,
                "size": format!("{}x{}", generation_request.width, generation_request.height),
                "response_format": "b64_json"
            });

            let response = self
                .client
                .post_image_generation(&self.model)
                .header("Content-Type", "application/json")
                .json(&request)
                .send()
                .await
                .map_err(|e| {
                    ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
                })?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await.map_err(|e| {
                        ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
                    })?
                )));
            }

            let t = response.text().await.map_err(|e| {
                ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
            })?;

            match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&t)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client<reqwest::Client> {
        type ImageGenerationModel = ImageGenerationModel<reqwest::Client>;

        fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
            ImageGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}

#[cfg(feature = "audio")]
pub use audio_generation::*;

#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::Client;
    use crate::audio_generation::{
        AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
    };
    use crate::client::AudioGenerationClient;
    use crate::{audio_generation, http_client};
    use bytes::Bytes;
    use serde_json::json;

    #[derive(Clone)]
    pub struct AudioGenerationModel<T = reqwest::Client> {
        client: Client<T>,
        model: String,
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel<reqwest::Client> {
        type Response = Bytes;

        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
            let request = json!({
                "model": self.model,
                "input": request.text,
                "voice": request.voice,
                "speed": request.speed,
            });

            let response = self
                .client
                .post_audio_generation(&self.model)
                .header("Content-Type", "application/json")
                .json(&request)
                .send()
                .await
                .map_err(|e| {
                    AudioGenerationError::HttpError(http_client::Error::Instance(e.into()))
                })?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await.map_err(|e| {
                        AudioGenerationError::HttpError(http_client::Error::Instance(e.into()))
                    })?
                )));
            }

            let bytes = response.bytes().await.map_err(|e| {
                AudioGenerationError::HttpError(http_client::Error::Instance(e.into()))
            })?;

            Ok(AudioGenerationResponse {
                audio: bytes.to_vec(),
                response: bytes,
            })
        }
    }

    impl AudioGenerationClient for Client<reqwest::Client> {
        type AudioGenerationModel = AudioGenerationModel<reqwest::Client>;

        fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
            AudioGenerationModel {
                client: self.clone(),
                model: model.to_string(),
            }
        }
    }
}

impl VerifyClient for Client<reqwest::Client> {
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        Ok(())
    }
}

#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}