1use super::openai::{TranscriptionResponse, send_compatible_streaming_request};
13
14use crate::completion::GetTokenUsage;
15use crate::http_client::{self, HttpClientExt};
16use crate::json_utils::merge;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19 completion::{self, CompletionError, CompletionRequest},
20 embeddings::{self, EmbeddingError},
21 json_utils,
22 providers::openai,
23 telemetry::SpanCombinator,
24 transcription::{self, TranscriptionError},
25};
26use bytes::Bytes;
27use reqwest::header::AUTHORIZATION;
28use reqwest::multipart::Part;
29use serde::Deserialize;
30use serde_json::json;
/// Fallback Azure OpenAI API version used when the builder is not given one explicitly.
const DEFAULT_API_VERSION: &str = "2024-10-21";
36
/// Builder for an Azure OpenAI [`Client`].
///
/// Borrows the endpoint and API-version strings for its lifetime; they are
/// copied into owned `String`s only at [`ClientBuilder::build`] time.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    // Credential (API key or bearer token) used for every request.
    auth: AzureOpenAIAuth,
    // `None` means "use DEFAULT_API_VERSION" at build time.
    api_version: Option<&'a str>,
    // Base URL of the Azure OpenAI resource, e.g. "https://my-res.openai.azure.com".
    azure_endpoint: &'a str,
    // Underlying HTTP transport; defaults to `reqwest::Client`.
    http_client: T,
}
43
44impl<'a, T> ClientBuilder<'a, T>
45where
46 T: Default,
47{
48 pub fn new(auth: impl Into<AzureOpenAIAuth>, endpoint: &'a str) -> Self {
49 Self {
50 auth: auth.into(),
51 api_version: None,
52 azure_endpoint: endpoint,
53 http_client: Default::default(),
54 }
55 }
56}
57
58impl<'a, T> ClientBuilder<'a, T> {
59 pub fn new_with_client(
60 auth: impl Into<AzureOpenAIAuth>,
61 endpoint: &'a str,
62 http_client: T,
63 ) -> Self {
64 Self {
65 auth: auth.into(),
66 api_version: None,
67 azure_endpoint: endpoint,
68 http_client,
69 }
70 }
71
72 pub fn api_version(mut self, api_version: &'a str) -> Self {
74 self.api_version = Some(api_version);
75 self
76 }
77
78 pub fn azure_endpoint(mut self, azure_endpoint: &'a str) -> Self {
80 self.azure_endpoint = azure_endpoint;
81 self
82 }
83
84 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
85 ClientBuilder {
86 auth: self.auth,
87 api_version: self.api_version,
88 azure_endpoint: self.azure_endpoint,
89 http_client,
90 }
91 }
92
93 pub fn build(self) -> Client<T> {
94 let api_version = self.api_version.unwrap_or(DEFAULT_API_VERSION);
95
96 Client {
97 api_version: api_version.to_string(),
98 azure_endpoint: self.azure_endpoint.to_string(),
99 auth: self.auth,
100 http_client: self.http_client,
101 }
102 }
103}
104
/// Azure OpenAI client: owns the endpoint, API version, credentials, and the
/// HTTP transport used by every model created from it.
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    // API version appended as `api-version=` to every request URL.
    api_version: String,
    // Base URL of the Azure OpenAI resource.
    azure_endpoint: String,
    // Credential attached as a header to every request.
    auth: AzureOpenAIAuth,
    // Underlying HTTP transport.
    http_client: T,
}
112
impl<T> std::fmt::Debug for Client<T>
where
    T: std::fmt::Debug,
{
    // Manual impl so the credential is never printed; `auth` is redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("azure_endpoint", &self.azure_endpoint)
            .field("http_client", &self.http_client)
            .field("auth", &"<REDACTED>")
            .field("api_version", &self.api_version)
            .finish()
    }
}
126
/// Credential used to authenticate against Azure OpenAI.
#[derive(Clone)]
pub enum AzureOpenAIAuth {
    /// Sent as an `api-key` header.
    ApiKey(String),
    /// Sent as an `Authorization: Bearer …` header (e.g. an Entra ID token).
    Token(String),
}
132
133impl std::fmt::Debug for AzureOpenAIAuth {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 match self {
136 Self::ApiKey(_) => write!(f, "API key <REDACTED>"),
137 Self::Token(_) => write!(f, "Token <REDACTED>"),
138 }
139 }
140}
141
142impl From<String> for AzureOpenAIAuth {
143 fn from(token: String) -> Self {
144 AzureOpenAIAuth::Token(token)
145 }
146}
147
impl AzureOpenAIAuth {
    /// Renders the credential as the (header-name, header-value) pair Azure
    /// expects: `api-key: …` for API keys, `Authorization: Bearer …` for tokens.
    ///
    /// Panics only if the credential contains bytes that are invalid in an
    /// HTTP header value.
    fn as_header(&self) -> (reqwest::header::HeaderName, reqwest::header::HeaderValue) {
        match self {
            AzureOpenAIAuth::ApiKey(api_key) => (
                "api-key".parse().expect("Header value should parse"),
                api_key.parse().expect("API key should parse"),
            ),
            AzureOpenAIAuth::Token(token) => (
                AUTHORIZATION,
                format!("Bearer {token}")
                    .parse()
                    .expect("Token should parse"),
            ),
        }
    }
}
164
// Convenience constructors for the default `reqwest`-backed client.
impl Client<reqwest::Client> {
    /// Starts a [`ClientBuilder`] for the given credential and endpoint.
    pub fn builder(
        auth: impl Into<AzureOpenAIAuth>,
        endpoint: &str,
    ) -> ClientBuilder<'_, reqwest::Client> {
        ClientBuilder::new(auth, endpoint)
    }

    /// Builds a client with the default API version and HTTP client.
    pub fn new(auth: impl Into<AzureOpenAIAuth>, endpoint: &str) -> Self {
        Self::builder(auth, endpoint).build()
    }

    /// Builds a client from environment variables; see
    /// [`ProviderClient::from_env`] for the variables read and panics raised.
    pub fn from_env() -> Self {
        <Self as ProviderClient>::from_env()
    }
}
192
193impl<T> Client<T>
194where
195 T: HttpClientExt,
196{
197 fn post(&self, url: String) -> http_client::Builder {
198 let (key, value) = self.auth.as_header();
199
200 http_client::Request::post(url).header(key, value)
201 }
202
203 fn post_embedding(&self, deployment_id: &str) -> http_client::Builder {
204 let url = format!(
205 "{}/openai/deployments/{}/embeddings?api-version={}",
206 self.azure_endpoint,
207 deployment_id.trim_start_matches('/'),
208 self.api_version
209 );
210
211 self.post(url)
212 }
213
214 #[cfg(feature = "audio")]
215 fn post_audio_generation(&self, deployment_id: &str) -> http_client::Builder {
216 let url = format!(
217 "{}/openai/deployments/{}/audio/speech?api-version={}",
218 self.azure_endpoint,
219 deployment_id.trim_start_matches('/'),
220 self.api_version
221 );
222
223 self.post(url)
224 }
225
226 fn post_chat_completion(&self, deployment_id: &str) -> http_client::Builder {
227 let url = format!(
228 "{}/openai/deployments/{}/chat/completions?api-version={}",
229 self.azure_endpoint,
230 deployment_id.trim_start_matches('/'),
231 self.api_version
232 );
233
234 self.post(url)
235 }
236
237 fn post_transcription(&self, deployment_id: &str) -> http_client::Builder {
238 let url = format!(
239 "{}/openai/deployments/{}/audio/translations?api-version={}",
240 self.azure_endpoint,
241 deployment_id.trim_start_matches('/'),
242 self.api_version
243 );
244
245 self.post(url)
246 }
247
248 #[cfg(feature = "image")]
249 fn post_image_generation(&self, deployment_id: &str) -> http_client::Builder {
250 let url = format!(
251 "{}/openai/deployments/{}/images/generations?api-version={}",
252 self.azure_endpoint,
253 deployment_id.trim_start_matches('/'),
254 self.api_version
255 );
256
257 self.post(url)
258 }
259
260 async fn send<U, R>(
261 &self,
262 req: http_client::Request<U>,
263 ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
264 where
265 U: Into<Bytes> + Send,
266 R: From<Bytes> + Send + 'static,
267 {
268 self.http_client.send(req).await
269 }
270}
271
impl<T> ProviderClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    /// Builds a client from `AZURE_API_KEY` or `AZURE_TOKEN` (in that
    /// priority order), plus `AZURE_API_VERSION` and `AZURE_ENDPOINT`.
    ///
    /// # Panics
    /// Panics when neither credential variable is set, or when
    /// `AZURE_API_VERSION` / `AZURE_ENDPOINT` is missing.
    fn from_env() -> Self {
        let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") {
            AzureOpenAIAuth::ApiKey(api_key)
        } else if let Ok(token) = std::env::var("AZURE_TOKEN") {
            AzureOpenAIAuth::Token(token)
        } else {
            panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set");
        };

        let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set");
        let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set");

        ClientBuilder::<T>::new(auth, &azure_endpoint)
            .api_version(&api_version)
            .build()
    }

    /// Builds a client from a dynamic provider value.
    ///
    /// # Panics
    /// Panics when `input` is not the `ApiKeyWithVersionAndHeader` variant.
    // NOTE(review): the `header` component is used as the Azure endpoint
    // here — confirm that is the intended meaning of that field.
    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::ApiKeyWithVersionAndHeader(api_key, version, header) =
            input
        else {
            panic!("Incorrect provider value type")
        };
        let auth = AzureOpenAIAuth::ApiKey(api_key.to_string());
        ClientBuilder::<T>::new(auth, &header)
            .api_version(&version)
            .build()
    }
}
306
impl<T> CompletionClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type CompletionModel = CompletionModel<T>;

    /// Creates a completion model bound to the named Azure deployment.
    fn completion_model(&self, model: &str) -> Self::CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
328
329impl<T> EmbeddingsClient for Client<T>
330where
331 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
332{
333 type EmbeddingModel = EmbeddingModel<T>;
334
335 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
349 let ndims = match model {
350 TEXT_EMBEDDING_3_LARGE => 3072,
351 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
352 _ => 0,
353 };
354 EmbeddingModel::new(self.clone(), model, ndims)
355 }
356
357 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
369 EmbeddingModel::new(self.clone(), model, ndims)
370 }
371}
372
impl<T> TranscriptionClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    type TranscriptionModel = TranscriptionModel<T>;

    /// Creates a transcription model bound to the named Azure deployment.
    fn transcription_model(&self, model: &str) -> Self::TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
394
/// Error payload returned by the Azure OpenAI API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged envelope: a successful payload deserializes as `Ok`, otherwise
/// serde falls through to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
406
/// `text-embedding-3-large` deployment name (3072 dimensions).
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` deployment name (1536 dimensions).
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` deployment name (1536 dimensions).
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
416
/// Successful response body of the Azure embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    // One entry per input document, in request order.
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
424
425impl From<ApiErrorResponse> for EmbeddingError {
426 fn from(err: ApiErrorResponse) -> Self {
427 EmbeddingError::ProviderError(err.message)
428 }
429}
430
431impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
432 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
433 match value {
434 ApiResponse::Ok(response) => Ok(response),
435 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
436 }
437 }
438}
439
/// One embedding vector within an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    // Position of the corresponding input document.
    pub index: usize,
}
446
/// Token accounting for an embeddings call; Azure reports only prompt and
/// total counts here.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
452
453impl GetTokenUsage for Usage {
454 fn token_usage(&self) -> Option<crate::completion::Usage> {
455 let mut usage = crate::completion::Usage::new();
456
457 usage.input_tokens = self.prompt_tokens as u64;
458 usage.total_tokens = self.total_tokens as u64;
459 usage.output_tokens = usage.total_tokens - usage.input_tokens;
460
461 Some(usage)
462 }
463}
464
impl std::fmt::Display for Usage {
    /// Human-readable one-line summary, used in token-usage log lines.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Prompt tokens: {} Total tokens: {}",
            self.prompt_tokens, self.total_tokens
        )
    }
}
474
/// Embedding model bound to an Azure OpenAI deployment.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    // Azure deployment name (doubles as the model identifier).
    pub model: String,
    // Expected output dimensionality; 0 when unknown.
    ndims: usize,
}
481
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Default + Clone,
{
    const MAX_DOCUMENTS: usize = 1024;

    /// Output dimensionality configured for this model (0 when unknown).
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds the given documents in a single request and pairs each
    /// returned vector with its source document, preserving request order.
    ///
    /// # Errors
    /// Fails on transport errors, non-success HTTP statuses, provider
    /// errors, or when the response length does not match the input length.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Materialize so the count can be validated against the response.
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::to_vec(&json!({
            "input": documents,
        }))?;

        let req = self
            .client
            .post_embedding(&self.model)
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            // Untagged envelope: success payload or provider error message.
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Azure embedding token usage: {}",
                        response.usage
                    );

                    // Guard the zip below: a short response would otherwise
                    // silently drop documents.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx: surface the raw body as the provider error.
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}
547
548impl<T> EmbeddingModel<T> {
549 pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
550 Self {
551 client,
552 model: model.to_string(),
553 ndims,
554 }
555 }
556}
557
/// `o1` reasoning model deployment name.
pub const O1: &str = "o1";
/// `o1-preview` deployment name.
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-mini` deployment name.
pub const O1_MINI: &str = "o1-mini";
/// `gpt-4o` deployment name.
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` deployment name.
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-realtime-preview` deployment name.
pub const GPT_4O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview";
// NOTE(review): GPT_4_TURBO resolves to "gpt-4" (same as GPT_4) and
// GPT_4_32K_0613 resolves to "gpt-4-32k" (same as GPT_4_32K) — these look
// like copy-paste values; confirm the intended deployment names.
pub const GPT_4_TURBO: &str = "gpt-4";
/// `gpt-4` deployment name.
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-32k` deployment name.
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_32K_0613: &str = "gpt-4-32k";
/// `gpt-3.5-turbo` deployment name.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-instruct` deployment name.
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` deployment name.
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
587
/// Chat-completion model bound to an Azure OpenAI deployment.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    // Azure deployment name (doubles as the model identifier).
    pub model: String,
}
594
impl<T> CompletionModel<T> {
    /// Binds a completion model to a client and deployment name.
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

    /// Assembles the JSON request body shared by `completion` and `stream`.
    ///
    /// Message order is: system preamble (if any), then normalized
    /// documents, then the chat history. Tools, when present, are attached
    /// with `tool_choice: "auto"`. Caller-supplied `additional_params` are
    /// merged last and therefore override the generated fields.
    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<serde_json::Value, CompletionError> {
        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };
        if let Some(docs) = completion_request.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }
        // Each chat message may expand to several OpenAI messages; flatten.
        let chat_history: Vec<openai::Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": "auto",
            })
        };

        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}
651
652impl<T> completion::CompletionModel for CompletionModel<T>
653where
654 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
655{
656 type Response = openai::CompletionResponse;
657 type StreamingResponse = openai::StreamingCompletionResponse;
658
659 #[cfg_attr(feature = "worker", worker::send)]
660 async fn completion(
661 &self,
662 completion_request: CompletionRequest,
663 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
664 let span = if tracing::Span::current().is_disabled() {
665 info_span!(
666 target: "rig::completions",
667 "chat",
668 gen_ai.operation.name = "chat",
669 gen_ai.provider.name = "azure.openai",
670 gen_ai.request.model = self.model,
671 gen_ai.system_instructions = &completion_request.preamble,
672 gen_ai.response.id = tracing::field::Empty,
673 gen_ai.response.model = tracing::field::Empty,
674 gen_ai.usage.output_tokens = tracing::field::Empty,
675 gen_ai.usage.input_tokens = tracing::field::Empty,
676 gen_ai.input.messages = tracing::field::Empty,
677 gen_ai.output.messages = tracing::field::Empty,
678 )
679 } else {
680 tracing::Span::current()
681 };
682 let request = self.create_completion_request(completion_request)?;
683 span.record_model_input(
684 &request
685 .get("messages")
686 .expect("Converting JSON should not fail"),
687 );
688 let body = serde_json::to_vec(&request)?;
689
690 let req = self
691 .client
692 .post_chat_completion(&self.model)
693 .header("Content-Type", "application/json")
694 .body(body)
695 .map_err(http_client::Error::from)?;
696
697 async move {
698 let response = self.client.http_client.send::<_, Bytes>(req).await.unwrap();
699
700 let status = response.status();
701 let response_body = response.into_body().into_future().await?.to_vec();
702
703 if status.is_success() {
704 match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(&response_body)? {
705 ApiResponse::Ok(response) => {
706 let span = tracing::Span::current();
707 span.record_model_output(&response.choices);
708 span.record_response_metadata(&response);
709 span.record_token_usage(&response.usage);
710 tracing::debug!(target: "rig", "Azure completion output: {}", serde_json::to_string_pretty(&response)?);
711 response.try_into()
712 }
713 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
714 }
715 } else {
716 Err(CompletionError::ProviderError(
717 String::from_utf8_lossy(&response_body).to_string()
718 ))
719 }
720 }
721 .instrument(span)
722 .await
723 }
724
725 #[cfg_attr(feature = "worker", worker::send)]
726 async fn stream(
727 &self,
728 request: CompletionRequest,
729 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
730 let preamble = request.preamble.clone();
731 let mut request = self.create_completion_request(request)?;
732
733 request = merge(
734 request,
735 json!({"stream": true, "stream_options": {"include_usage": true}}),
736 );
737
738 let body = serde_json::to_vec(&request)?;
739
740 let req = self
741 .client
742 .post_chat_completion(&self.model)
743 .header("Content-Type", "application/json")
744 .body(body)
745 .map_err(http_client::Error::from)?;
746
747 let span = if tracing::Span::current().is_disabled() {
748 info_span!(
749 target: "rig::completions",
750 "chat_streaming",
751 gen_ai.operation.name = "chat_streaming",
752 gen_ai.provider.name = "azure.openai",
753 gen_ai.request.model = self.model,
754 gen_ai.system_instructions = &preamble,
755 gen_ai.response.id = tracing::field::Empty,
756 gen_ai.response.model = tracing::field::Empty,
757 gen_ai.usage.output_tokens = tracing::field::Empty,
758 gen_ai.usage.input_tokens = tracing::field::Empty,
759 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
760 gen_ai.output.messages = tracing::field::Empty,
761 )
762 } else {
763 tracing::Span::current()
764 };
765
766 tracing_futures::Instrument::instrument(
767 send_compatible_streaming_request(self.client.http_client.clone(), req),
768 span,
769 )
770 .await
771 }
772}
773
/// Speech-to-text model bound to an Azure OpenAI deployment.
#[derive(Clone)]
pub struct TranscriptionModel<T = reqwest::Client> {
    client: Client<T>,
    // Azure deployment name (doubles as the model identifier).
    pub model: String,
}
784
785impl<T> TranscriptionModel<T> {
786 pub fn new(client: Client<T>, model: &str) -> Self {
787 Self {
788 client,
789 model: model.to_string(),
790 }
791 }
792}
793
impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = TranscriptionResponse;

    /// Uploads the audio as multipart form data and returns the parsed
    /// transcription response.
    ///
    /// # Errors
    /// Fails on transport errors, non-success HTTP statuses, or
    /// provider-reported errors.
    ///
    /// # Panics
    /// Panics when `additional_params` is present but not a JSON object.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        // Required part: the audio bytes, labelled with the original filename.
        let mut body = reqwest::multipart::Form::new().part(
            "file",
            Part::bytes(data).file_name(request.filename.clone()),
        );

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        // Extra parameters are flattened into additional form fields.
        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to OpenAI Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .post_transcription(&self.model)
            .body(body)
            .map_err(|e| TranscriptionError::HttpError(e.into()))?;

        let response = self.client.http_client.send_multipart::<Bytes>(req).await?;
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            // Non-2xx: surface the raw body as the provider error.
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}
856
857#[cfg(feature = "image")]
861pub use image_generation::*;
862use tracing::{Instrument, info_span};
863#[cfg(feature = "image")]
864#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
865mod image_generation {
866 use crate::client::ImageGenerationClient;
867 use crate::http_client::HttpClientExt;
868 use crate::image_generation;
869 use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
870 use crate::providers::azure::{ApiResponse, Client};
871 use crate::providers::openai::ImageGenerationResponse;
872 use bytes::Bytes;
873 use serde_json::json;
874
875 #[derive(Clone)]
876 pub struct ImageGenerationModel<T = reqwest::Client> {
877 client: Client<T>,
878 pub model: String,
879 }
880
881 impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
882 where
883 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
884 {
885 type Response = ImageGenerationResponse;
886
887 #[cfg_attr(feature = "worker", worker::send)]
888 async fn image_generation(
889 &self,
890 generation_request: ImageGenerationRequest,
891 ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
892 {
893 let request = json!({
894 "model": self.model,
895 "prompt": generation_request.prompt,
896 "size": format!("{}x{}", generation_request.width, generation_request.height),
897 "response_format": "b64_json"
898 });
899
900 let body = serde_json::to_vec(&request)?;
901
902 let req = self
903 .client
904 .post_image_generation(&self.model)
905 .header("Content-Type", "application/json")
906 .body(body)
907 .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
908
909 let response = self.client.http_client.send::<_, Bytes>(req).await?;
910 let status = response.status();
911 let response_body = response.into_body().into_future().await?.to_vec();
912
913 if !status.is_success() {
914 return Err(ImageGenerationError::ProviderError(format!(
915 "{status}: {}",
916 String::from_utf8_lossy(&response_body)
917 )));
918 }
919
920 match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
921 ApiResponse::Ok(response) => response.try_into(),
922 ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
923 }
924 }
925 }
926
927 impl<T> ImageGenerationClient for Client<T>
928 where
929 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
930 {
931 type ImageGenerationModel = ImageGenerationModel<T>;
932
933 fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
934 ImageGenerationModel {
935 client: self.clone(),
936 model: model.to_string(),
937 }
938 }
939 }
940}
941use crate::client::{
946 CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient,
947 VerifyError,
948};
949#[cfg(feature = "audio")]
950pub use audio_generation::*;
951
952#[cfg(feature = "audio")]
953#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
954mod audio_generation {
955 use super::Client;
956 use crate::audio_generation::{
957 self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
958 };
959 use crate::client::AudioGenerationClient;
960 use crate::http_client::HttpClientExt;
961 use bytes::Bytes;
962 use serde_json::json;
963
964 #[derive(Clone)]
965 pub struct AudioGenerationModel<T = reqwest::Client> {
966 client: Client<T>,
967 model: String,
968 }
969
970 impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
971 where
972 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
973 {
974 type Response = Bytes;
975
976 #[cfg_attr(feature = "worker", worker::send)]
977 async fn audio_generation(
978 &self,
979 request: AudioGenerationRequest,
980 ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
981 let request = json!({
982 "model": self.model,
983 "input": request.text,
984 "voice": request.voice,
985 "speed": request.speed,
986 });
987
988 let body = serde_json::to_vec(&request)?;
989
990 let req = self
991 .client
992 .post_audio_generation("/audio/speech")
993 .header("Content-Type", "application/json")
994 .body(body)
995 .map_err(|e| AudioGenerationError::HttpError(e.into()))?;
996
997 let response = self.client.http_client.send::<_, Bytes>(req).await?;
998 let status = response.status();
999 let response_body = response.into_body().into_future().await?;
1000
1001 if !status.is_success() {
1002 return Err(AudioGenerationError::ProviderError(format!(
1003 "{status}: {}",
1004 String::from_utf8_lossy(&response_body)
1005 )));
1006 }
1007
1008 Ok(AudioGenerationResponse {
1009 audio: response_body.to_vec(),
1010 response: response_body,
1011 })
1012 }
1013 }
1014
1015 impl<T> AudioGenerationClient for Client<T>
1016 where
1017 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
1018 {
1019 type AudioGenerationModel = AudioGenerationModel<T>;
1020
1021 fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
1022 AudioGenerationModel {
1023 client: self.clone(),
1024 model: model.to_string(),
1025 }
1026 }
1027 }
1028}
1029
impl<T> VerifyClient for Client<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    /// No-op verification: always returns `Ok(())` without contacting Azure.
    // NOTE(review): this does not validate credentials or the endpoint —
    // confirm whether a lightweight API call was intended here.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        Ok(())
    }
}
1041
// Integration tests against a live Azure deployment; marked #[ignore] because
// they require the AZURE_* environment variables and network access.
#[cfg(test)]
mod azure_tests {
    use super::*;

    use crate::OneOrMany;
    use crate::completion::CompletionModel;
    use crate::embeddings::EmbeddingModel;

    /// Smoke test: one embedding request round-trips successfully.
    #[tokio::test]
    #[ignore]
    async fn test_azure_embedding() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL);
        let embeddings = model
            .embed_texts(vec!["Hello, world!".to_string()])
            .await
            .unwrap();

        tracing::info!("Azure embedding: {:?}", embeddings);
    }

    /// Smoke test: one chat completion round-trips successfully.
    #[tokio::test]
    #[ignore]
    async fn test_azure_completion() {
        let _ = tracing_subscriber::fmt::try_init();

        let client = Client::from_env();
        let model = client.completion_model(GPT_4O_MINI);
        let completion = model
            .completion(CompletionRequest {
                preamble: Some("You are a helpful assistant.".to_string()),
                chat_history: OneOrMany::one("Hello!".into()),
                documents: vec![],
                max_tokens: Some(100),
                temperature: Some(0.0),
                tools: vec![],
                tool_choice: None,
                additional_params: None,
            })
            .await
            .unwrap();

        tracing::info!("Azure completion: {:?}", completion);
    }
}