use std::pin::Pin;
use std::str::FromStr;
use crate::chat_completions::{
ChatCompletion, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse,
};
use crate::embeddings::{Embeddings, EmbeddingsRequest, EmbeddingsResponse};
use crate::utils::uri::ensure_no_trailing_slash;
use crate::{Error, Result};
use async_stream::stream;
use async_trait::async_trait;
use derive_builder::Builder;
use futures::{Stream, StreamExt};
use reqwest::header::HeaderName;
use secrecy::{ExposeSecret, SecretString};
pub const AZURE_OPENAI_API_KEY_ENV_VAR: &str = "AZURE_OPENAI_API_KEY";
#[derive(Debug, Clone, Builder)]
#[builder(derive(Debug))]
#[builder(setter(into))]
pub struct Client {
#[builder(default)]
http_client: reqwest::Client,
api_version: String,
base_url: String,
auth: Auth,
}
#[derive(Debug, Clone)]
pub enum Auth {
BearerToken(SecretString),
ApiKey(SecretString),
}
impl Client {
pub fn new(auth: Auth, base_url: &str, api_version: &str) -> Result<Self> {
Ok(Self {
http_client: reqwest::Client::default(),
api_version: api_version.to_string(),
base_url: ensure_no_trailing_slash(base_url),
auth,
})
}
}
impl Client {
fn get_headers(&self) -> Result<reqwest::header::HeaderMap> {
let mut headers = reqwest::header::HeaderMap::new();
let (auth_header_key, auth_header_value) = match &self.auth {
Auth::BearerToken(secret_box) => (
reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!(
"Bearer {}",
secret_box.expose_secret()
))
.map_err(|e| {
Error::InvalidHeaderValue(reqwest::header::AUTHORIZATION.to_string(), e)
})?,
),
Auth::ApiKey(secret_box) => (
HeaderName::from_str("api-key")
.map_err(|e| Error::InvalidHeaderName("api-key".to_owned(), e))?,
reqwest::header::HeaderValue::from_str(&secret_box.expose_secret())
.map_err(|e| Error::InvalidHeaderValue("api-key".to_string(), e))?,
),
};
headers.insert(auth_header_key, auth_header_value);
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_static("application/json"),
);
Ok(headers)
}
fn get_url(&self, model: &str) -> String {
format!(
"{}/openai/deployments/{}/chat/completions?api-version={}",
self.base_url, model, self.api_version
)
}
fn get_embeddings_url(&self, model: &str) -> String {
format!(
"{}/openai/deployments/{}/embeddings?api-version={}",
self.base_url, model, self.api_version
)
}
}
#[async_trait]
impl ChatCompletion for Client {
async fn chat_completions(
&self,
request: &ChatCompletionRequest,
) -> Result<ChatCompletionResponse> {
if let Some(stream) = request.stream {
if stream {
return Err(Error::StreamingNotSupported(
"Streaming is not supported when using chat_completions() api".to_string(),
));
}
}
if let Some(token) = &request.cancellation_token {
if token.is_cancelled() {
return Err(Error::Cancelled);
}
}
let headers = self.get_headers()?;
let url = self.get_url(&request.model);
let (abort_handle, abort_registration) = futures::future::AbortHandle::new_pair();
if let Some(token) = &request.cancellation_token {
let token = token.clone();
tokio::spawn(async move {
token.cancelled().await;
abort_handle.abort();
});
}
let request_future = self
.http_client
.post(url)
.headers(headers)
.json(request)
.send();
let response =
match futures::future::Abortable::new(request_future, abort_registration).await {
Ok(response) => response?,
Err(futures::future::Aborted) => {
return Err(Error::Cancelled);
}
};
if !response.status().is_success() {
return Err(Error::UnknownError(response.text().await?));
}
let chat_completion_response = response.json::<ChatCompletionResponse>().await?;
Ok(chat_completion_response)
}
async fn stream_chat_completions(
&self,
request: &ChatCompletionRequest,
) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
if let Some(stream) = request.stream {
if !stream {
return Err(Error::StreamingNotSupported(
"Streaming required when using stream_chat_completions() api".to_string(),
));
}
}
if let Some(token) = &request.cancellation_token {
if token.is_cancelled() {
return Ok(Box::pin(futures::stream::empty()));
}
}
let mut json = serde_json::to_value(request)?;
json["stream"] = serde_json::Value::Bool(true);
let (abort_handle, abort_registration) = futures::future::AbortHandle::new_pair();
if let Some(token) = &request.cancellation_token {
let token = token.clone();
tokio::spawn(async move {
token.cancelled().await;
abort_handle.abort();
});
}
let request_future = self
.http_client
.post(self.get_url(&request.model))
.headers(self.get_headers()?)
.body(json.to_string())
.send();
let response =
match futures::future::Abortable::new(request_future, abort_registration).await {
Ok(response) => response?,
Err(futures::future::Aborted) => {
return Ok(Box::pin(futures::stream::empty()));
}
};
if !response.status().is_success() {
return Err(Error::UnknownError(response.text().await?));
}
let byte_stream = response.bytes_stream();
let cancellation_token = request.cancellation_token.clone();
let result_stream = stream! {
let mut stream = byte_stream;
let mut buffer = String::new();
while let Some(chunk_result) = stream.next().await {
if let Some(token) = &cancellation_token {
if token.is_cancelled() {
break;
}
}
match chunk_result {
Ok(chunk) => {
let chunk_str = String::from_utf8_lossy(&chunk);
buffer.push_str(&chunk_str);
while let Some(pos) = buffer.find("\n\n") {
let message = buffer[..pos].to_string();
buffer = buffer[pos + 2..].to_string();
if message.starts_with("data: ") {
let data = &message["data: ".len()..];
if data == "[DONE]" {
break;
}
match serde_json::from_str::<ChatCompletionChunk>(data) {
Ok(v) => yield Ok(v),
Err(e) => yield Err(Error::SerdeJsonError(e)),
}
}
}
},
Err(e) => yield Err(Error::UnknownError(format!("Failed to read response: {}", e))),
}
}
};
Ok(Box::pin(result_stream))
}
}
#[async_trait]
impl Embeddings for Client {
async fn create_embeddings(&self, request: &EmbeddingsRequest) -> Result<EmbeddingsResponse> {
if let Some(token) = &request.cancellation_token {
if token.is_cancelled() {
return Err(Error::Cancelled);
}
}
let headers = self.get_headers()?;
let url = self.get_embeddings_url(&request.model);
let (abort_handle, abort_registration) = futures::future::AbortHandle::new_pair();
if let Some(token) = &request.cancellation_token {
let token = token.clone();
tokio::spawn(async move {
token.cancelled().await;
abort_handle.abort();
});
}
let request_future = self
.http_client
.post(url)
.headers(headers)
.json(request)
.send();
let response =
match futures::future::Abortable::new(request_future, abort_registration).await {
Ok(response) => response?,
Err(futures::future::Aborted) => {
return Err(Error::Cancelled);
}
};
if !response.status().is_success() {
return Err(Error::UnknownError(response.text().await?));
}
let embeddings_response = response.json::<EmbeddingsResponse>().await?;
Ok(embeddings_response)
}
}
impl super::Client for Client {}