use super::request::{RequestBuilder, RequestSpec};
use crate::Config;
use crate::error::{OpenAIError, ProcessingError};
use crate::service::executor::HttpExecutor;
use eventsource_stream::{Event, EventStreamError, Eventsource};
use futures::StreamExt;
use std::any::type_name;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
enum SseEventResult<T>
where
T: serde::de::DeserializeOwned,
{
Skip,
Data(T),
Done,
Error(OpenAIError),
}
/// HTTP transport used for API calls.
///
/// Every request is delegated to the wrapped [`HttpExecutor`]; this type adds
/// JSON decoding and server-sent-event streaming on top of it.
pub struct HttpTransport {
    /// Performs the requests and holds the shared configuration handle.
    executor: HttpExecutor,
}
impl HttpTransport {
pub fn new(config: Config) -> HttpTransport {
HttpTransport {
executor: HttpExecutor::new(config),
}
}
pub(crate) fn config(&self) -> Arc<RwLock<Config>> {
self.executor.config()
}
pub async fn post_json<U, F, T>(&self, params: RequestSpec<U, F>) -> Result<T, OpenAIError>
where
U: FnOnce(&Config) -> String,
F: FnOnce(&Config, &mut RequestBuilder),
T: serde::de::DeserializeOwned,
{
let res = self.executor.post(params).await?;
let raw = res.text().await.map_err(ProcessingError::TextRead)?;
serde_json::from_str(&raw).map_err(|_| {
ProcessingError::Conversion {
raw,
target_type: type_name::<T>().to_string(),
}
.into()
})
}
pub async fn get_json<U, F, T>(&self, params: RequestSpec<U, F>) -> Result<T, OpenAIError>
where
U: FnOnce(&Config) -> String,
F: FnOnce(&Config, &mut RequestBuilder),
T: serde::de::DeserializeOwned,
{
let res = self.executor.get(params).await?;
let raw = res.text().await.map_err(ProcessingError::TextRead)?;
serde_json::from_str(&raw).map_err(|_| {
ProcessingError::Conversion {
raw,
target_type: type_name::<T>().to_string(),
}
.into()
})
}
pub async fn post_json_stream<U, F, T>(
&self,
params: RequestSpec<U, F>,
) -> Result<tokio_stream::wrappers::ReceiverStream<Result<T, OpenAIError>>, OpenAIError>
where
U: FnOnce(&Config) -> String,
F: FnOnce(&Config, &mut RequestBuilder),
T: serde::de::DeserializeOwned + Send + 'static,
{
let res = self.executor.post(params).await?;
let mut event_stream = res.bytes_stream().eventsource();
let (tx, rx) = tokio::sync::mpsc::channel(32);
tokio::spawn(async move {
while let Some(event_result) = event_stream.next().await {
let process_result = Self::process_stream_event(event_result).await;
match process_result {
SseEventResult::Skip => continue,
SseEventResult::Data(chunk) => {
if tx.send(Ok(chunk)).await.is_err() {
break;
}
}
SseEventResult::Done => break,
SseEventResult::Error(error) => {
if tx.send(Err(error)).await.is_err() {
break;
}
}
}
}
drop(tx);
});
Ok(ReceiverStream::new(rx))
}
async fn process_stream_event<T>(
event_result: Result<Event, EventStreamError<reqwest::Error>>,
) -> SseEventResult<T>
where
T: serde::de::DeserializeOwned + Send + 'static,
{
match event_result {
Ok(event) => {
if event.data.is_empty() {
return SseEventResult::Skip;
}
if event.data == "[DONE]" {
SseEventResult::Done
} else {
match serde_json::from_str::<T>(&event.data) {
Ok(chunk) => SseEventResult::Data(chunk),
Err(_) => SseEventResult::Error(
ProcessingError::Conversion {
raw: event.data,
target_type: type_name::<T>().to_string(),
}
.into(),
),
}
}
}
Err(e) => SseEventResult::Error(OpenAIError::from_eventsource_stream_error(e)),
}
}
pub async fn refresh_client(&self) {
self.executor.rebuild_reqwest_client().await;
}
}