use super::{
client::Client,
http::{multipart, Method, Request},
request::{
AudioToTextRequest, Bytes, ChatMessagesRequest, CompletionMessagesRequest,
ConversationsDeleteRequest, ConversationsRenameRequest, ConversationsRequest,
FilesUploadRequest, MessagesFeedbacksRequest, MessagesRequest, MessagesSuggestedRequest,
MetaRequest, ParametersRequest, ResponseMode, StreamTaskStopRequest, TextToAudioRequest,
WorkflowsRunRequest,
},
response::{
parse_error_response, parse_response, AudioToTextResponse, ChatMessagesResponse,
CompletionMessagesResponse, ConversationsResponse, FilesUploadResponse, MessagesResponse,
MessagesSuggestedResponse, MetaResponse, ParametersResponse, ResultResponse,
SseMessageEventStream, WorkflowsRunResponse,
},
};
use anyhow::{bail, Result as AnyResult};
use eventsource_stream::Eventsource;
use futures::stream::Stream;
use std::fmt::{Display, Formatter, Result as FmtResult};
/// Relative endpoint paths of the Dify service API.
///
/// Some paths are templates containing a `{...}` placeholder (`{task_id}`, `{message_id}`,
/// `{conversation_id}`) that must be substituted with an identifier before the request is sent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ApiPath {
    ChatMessages,
    FilesUpload,
    ChatMessagesStop,
    MessagesFeedbacks,
    MessagesSuggested,
    Messages,
    Conversations,
    ConversationsDelete,
    ConversationsRename,
    AudioToText,
    TextToAudio,
    Parameters,
    Meta,
    WorkflowsRun,
    WorkflowsStop,
    CompletionMessages,
    CompletionMessagesStop,
}

impl ApiPath {
    /// Returns the path template for this endpoint, starting with `/v1/`.
    ///
    /// Placeholders such as `{task_id}` are returned verbatim; substitution is the caller's job.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::ChatMessages => "/v1/chat-messages",
            Self::FilesUpload => "/v1/files/upload",
            Self::ChatMessagesStop => "/v1/chat-messages/{task_id}/stop",
            Self::MessagesFeedbacks => "/v1/messages/{message_id}/feedbacks",
            Self::MessagesSuggested => "/v1/messages/{message_id}/suggested",
            Self::Messages => "/v1/messages",
            Self::Conversations => "/v1/conversations",
            Self::ConversationsDelete => "/v1/conversations/{conversation_id}",
            Self::ConversationsRename => "/v1/conversations/{conversation_id}/name",
            Self::AudioToText => "/v1/audio-to-text",
            Self::TextToAudio => "/v1/text-to-audio",
            Self::Parameters => "/v1/parameters",
            Self::Meta => "/v1/meta",
            Self::WorkflowsRun => "/v1/workflows/run",
            Self::WorkflowsStop => "/v1/workflows/{task_id}/stop",
            Self::CompletionMessages => "/v1/completion-messages",
            Self::CompletionMessagesStop => "/v1/completion-messages/{task_id}/stop",
        }
    }
}

impl Display for ApiPath {
    /// Writes the same text as [`ApiPath::as_str`].
    ///
    /// `Formatter::pad` is used (rather than `write!`) so width, fill and alignment flags such
    /// as `{:>20}` are honoured.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        f.pad(self.as_str())
    }
}
/// Optional user hook that may rewrite each outgoing [`Request`] right before it is executed.
type BeforeSend = Option<Box<dyn Fn(Request) -> Request + Sync + Send + 'static>>;

/// Typed wrappers around the Dify service API endpoints, borrowing a configured [`Client`].
pub struct Api<'a> {
    /// Hook set through `Api::before_send`; while it is `None`, requests are sent unchanged.
    before_send_hook: BeforeSend,
    /// Client used to build and execute every request.
    pub(crate) client: &'a Client,
}
impl<'a> Api<'a> {
    /// Creates an API wrapper around `client` with no `before_send` hook installed.
    pub fn new(client: &'a Client) -> Self {
        Self {
            before_send_hook: None,
            client,
        }
    }

    /// Installs `hook`, which receives every outgoing [`Request`] and returns the request that
    /// is actually executed. Installing a new hook replaces the previous one.
    pub fn before_send<F>(&mut self, hook: F)
    where
        F: Fn(Request) -> Request + Send + Sync + 'static,
    {
        self.before_send_hook = Some(Box::new(hook));
    }

    /// Runs the `before_send` hook (if any) on `req`, then executes it through the client.
    async fn send(&self, mut req: Request) -> AnyResult<reqwest::Response> {
        if let Some(hook) = self.before_send_hook.as_ref() {
            req = hook(req);
        }
        self.client.execute(req).await
    }

    /// Passes `resp` through when its status is 2xx; otherwise reads the body and returns it as
    /// an error.
    ///
    /// The streaming endpoints call this before handing the body to the SSE parser, so an error
    /// answer is reported to the caller instead of being parsed as an event stream.
    async fn ensure_success(resp: reqwest::Response) -> AnyResult<reqwest::Response> {
        let status = resp.status();
        if status.is_success() {
            return Ok(resp);
        }
        let text = resp.text().await?;
        let parsed: AnyResult<()> = parse_error_response(&text);
        parsed?;
        // Fallback in case `parse_error_response` does not treat this body as an error.
        Err(anyhow::anyhow!("unexpected HTTP status {}: {}", status, text))
    }

    /// Joins the configured base URL with `api_path`.
    ///
    /// Trailing `/` characters are stripped from the base so the joined URL never contains `//`
    /// between the two parts. Path placeholders are left for the caller to substitute.
    fn build_request_api(&self, api_path: ApiPath) -> String {
        let base = self.client.config.base_url.trim_end_matches('/');
        format!("{}{}", base, api_path.as_str())
    }

    /// Builds the `POST /v1/chat-messages` request shared by the blocking and streaming variants.
    fn create_chat_messages_request(&self, req: ChatMessagesRequest) -> AnyResult<Request> {
        let url = self.build_request_api(ApiPath::ChatMessages);
        self.client.create_request(url, Method::POST, req)
    }

    /// Sends a chat message (`POST /v1/chat-messages`) and waits for the complete answer.
    ///
    /// `response_mode` is forced to [`ResponseMode::Blocking`] whatever the caller set.
    ///
    /// # Errors
    /// Fails if the request cannot be built or sent, or if `parse_response` rejects the body.
    pub async fn chat_messages(
        &self,
        mut req_data: ChatMessagesRequest,
    ) -> AnyResult<ChatMessagesResponse> {
        req_data.response_mode = ResponseMode::Blocking;
        let req = self.create_chat_messages_request(req_data)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<ChatMessagesResponse>(&text)
    }

    /// Sends a chat message (`POST /v1/chat-messages`) and returns the answer as a
    /// server-sent-event stream.
    ///
    /// `response_mode` is forced to [`ResponseMode::Streaming`] whatever the caller set.
    ///
    /// # Errors
    /// Fails if the request cannot be built or sent, or if the server answers with a non-2xx
    /// status.
    pub async fn chat_messages_stream(
        &self,
        mut req_data: ChatMessagesRequest,
    ) -> AnyResult<SseMessageEventStream<impl Stream<Item = Result<Bytes, reqwest::Error>>>> {
        req_data.response_mode = ResponseMode::Streaming;
        let req = self.create_chat_messages_request(req_data)?;
        let resp = Self::ensure_success(self.send(req).await?).await?;
        Ok(SseMessageEventStream::new(resp.bytes_stream().eventsource()))
    }

    /// Uploads an image to `/v1/files/upload` as a multipart form with `user` and `file` fields.
    ///
    /// The file type is sniffed from the bytes with `infer`. The part is named
    /// `image_file.<ext>` and gets the detected MIME type.
    ///
    /// # Errors
    /// Fails with `FilesUploadRequest.File Illegal` if the bytes are not a recognised image, or
    /// if the request fails or its body is rejected by `parse_response`.
    pub async fn files_upload(
        &self,
        req_data: FilesUploadRequest,
    ) -> AnyResult<FilesUploadResponse> {
        if !infer::is_image(&req_data.file) {
            bail!("FilesUploadRequest.File Illegal");
        }
        // `is_image` matched, so `get` should find a type. Still propagate rather than panic.
        let kind = infer::get(&req_data.file)
            .ok_or_else(|| anyhow::anyhow!("FilesUploadRequest.File Illegal"))?;
        let file_part = multipart::Part::stream(req_data.file)
            .file_name(format!("image_file.{}", kind.extension()))
            .mime_str(kind.mime_type())?;
        let form = multipart::Form::new()
            .text("user", req_data.user)
            .part("file", file_part);
        let url = self.build_request_api(ApiPath::FilesUpload);
        let req = self.client.create_multipart_request(url, form)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<FilesUploadResponse>(&text)
    }

    /// Shared implementation of the `.../{task_id}/stop` endpoints (sent as `POST`).
    ///
    /// The task id is moved into the URL (inserted verbatim, not percent-encoded). The request
    /// body keeps an empty `task_id`, presumably so it is not sent twice. NOTE(review): confirm
    /// against the request type's serialization.
    async fn stream_task_stop(
        &self,
        mut req_data: StreamTaskStopRequest,
        api_path: ApiPath,
    ) -> AnyResult<ResultResponse> {
        if req_data.task_id.is_empty() {
            bail!("StreamTaskStopRequest.TaskId Illegal");
        }
        let task_id = std::mem::take(&mut req_data.task_id);
        let url = self.build_request_api(api_path).replace("{task_id}", &task_id);
        let req = self.client.create_request(url, Method::POST, req_data)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<ResultResponse>(&text)
    }

    /// Stops a streaming chat response (`POST /v1/chat-messages/{task_id}/stop`).
    ///
    /// # Errors
    /// Fails if `task_id` is empty, or if the request or response parsing fails.
    pub async fn chat_messages_stop(
        &self,
        req_data: StreamTaskStopRequest,
    ) -> AnyResult<ResultResponse> {
        self.stream_task_stop(req_data, ApiPath::ChatMessagesStop)
            .await
    }

    /// Fetches suggested follow-up questions (`GET /v1/messages/{message_id}/suggested`).
    ///
    /// # Errors
    /// Fails if `message_id` is empty, or if the request or response parsing fails.
    pub async fn messages_suggested(
        &self,
        mut req_data: MessagesSuggestedRequest,
    ) -> AnyResult<MessagesSuggestedResponse> {
        if req_data.message_id.is_empty() {
            bail!("MessagesSuggestedRequest.MessageID Illegal");
        }
        let message_id = std::mem::take(&mut req_data.message_id);
        let url = self
            .build_request_api(ApiPath::MessagesSuggested)
            .replace("{message_id}", &message_id);
        let req = self.client.create_request(url, Method::GET, req_data)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<MessagesSuggestedResponse>(&text)
    }

    /// Submits feedback for a message (`POST /v1/messages/{message_id}/feedbacks`).
    ///
    /// # Errors
    /// Fails if `message_id` is empty, or if the request or response parsing fails.
    pub async fn messages_feedbacks(
        &self,
        mut req_data: MessagesFeedbacksRequest,
    ) -> AnyResult<ResultResponse> {
        if req_data.message_id.is_empty() {
            bail!("MessagesFeedbacksRequest.MessageID Illegal");
        }
        let message_id = std::mem::take(&mut req_data.message_id);
        let url = self
            .build_request_api(ApiPath::MessagesFeedbacks)
            .replace("{message_id}", &message_id);
        let req = self.client.create_request(url, Method::POST, req_data)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<ResultResponse>(&text)
    }

    /// Lists conversations (`GET /v1/conversations`).
    ///
    /// # Errors
    /// Fails if `user` is empty, or if the request or response parsing fails.
    pub async fn conversations(
        &self,
        req_data: ConversationsRequest,
    ) -> AnyResult<ConversationsResponse> {
        if req_data.user.is_empty() {
            bail!("ConversationsRequest.User Illegal");
        }
        let url = self.build_request_api(ApiPath::Conversations);
        let req = self.client.create_request(url, Method::GET, req_data)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<ConversationsResponse>(&text)
    }

    /// Lists the messages of a conversation (`GET /v1/messages`).
    ///
    /// # Errors
    /// Fails if `conversation_id` is empty, or if the request or response parsing fails.
    pub async fn messages(&self, req_data: MessagesRequest) -> AnyResult<MessagesResponse> {
        if req_data.conversation_id.is_empty() {
            bail!("MessagesRequest.ConversationID Illegal");
        }
        let url = self.build_request_api(ApiPath::Messages);
        let req = self.client.create_request(url, Method::GET, req_data)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<MessagesResponse>(&text)
    }

    /// Renames a conversation (`POST /v1/conversations/{conversation_id}/name`).
    ///
    /// An explicit `name` is required unless `auto_generate` asks the server to generate one.
    ///
    /// # Errors
    /// Fails if `conversation_id` is empty, if `name` is missing while `auto_generate` is
    /// false, or if the request or response parsing fails.
    pub async fn conversations_renaming(
        &self,
        mut req_data: ConversationsRenameRequest,
    ) -> AnyResult<ResultResponse> {
        if req_data.conversation_id.is_empty() {
            bail!("ConversationsRenameRequest.ConversationID Illegal");
        }
        // A name may only be omitted when the server is asked to generate one. The previous
        // check was inverted and rejected exactly the valid auto-generate request.
        if !req_data.auto_generate && req_data.name.is_none() {
            bail!("ConversationsRenameRequest.Name Illegal");
        }
        let conversation_id = std::mem::take(&mut req_data.conversation_id);
        let url = self
            .build_request_api(ApiPath::ConversationsRename)
            .replace("{conversation_id}", &conversation_id);
        let req = self.client.create_request(url, Method::POST, req_data)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<ResultResponse>(&text)
    }

    /// Deletes a conversation (`DELETE /v1/conversations/{conversation_id}`).
    ///
    /// Any 2xx status (typically 204 No Content) counts as success. Other statuses have their
    /// body handed to `parse_error_response`.
    ///
    /// # Errors
    /// Fails if `conversation_id` is empty, if the request fails, or on a non-2xx answer.
    pub async fn conversations_delete(
        &self,
        mut req_data: ConversationsDeleteRequest,
    ) -> AnyResult<()> {
        if req_data.conversation_id.is_empty() {
            bail!("ConversationsDeleteRequest.ConversationID Illegal");
        }
        let conversation_id = std::mem::take(&mut req_data.conversation_id);
        let url = self
            .build_request_api(ApiPath::ConversationsDelete)
            .replace("{conversation_id}", &conversation_id);
        let req = self.client.create_request(url, Method::DELETE, req_data)?;
        let resp = self.send(req).await?;
        // Accept every 2xx, not only 204, so a success answered with 200 is not reported as
        // an error.
        if resp.status().is_success() {
            Ok(())
        } else {
            let text = resp.text().await?;
            parse_error_response(&text)
        }
    }

    /// Converts text to speech (`POST /v1/text-to-audio`) and returns the raw audio bytes.
    ///
    /// # Errors
    /// Fails if `text` is empty, or if the request fails. Also fails if the response has no
    /// readable `Content-Type`, or if the type is not `audio/*`; in that case the body is
    /// handed to `parse_error_response`.
    pub async fn text_to_audio(&self, req_data: TextToAudioRequest) -> AnyResult<Bytes> {
        if req_data.text.is_empty() {
            bail!("TextToAudioRequest.Text Illegal");
        }
        let url = self.build_request_api(ApiPath::TextToAudio);
        let req = self.client.create_request(url, Method::POST, req_data)?;
        let resp = self.send(req).await?;
        let content_type = resp
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .ok_or_else(|| anyhow::anyhow!("Content-Type is missing"))?
            .to_str()?;
        if content_type.starts_with("audio/") {
            return Ok(resp.bytes().await?);
        }
        let text = resp.text().await?;
        parse_error_response(&text)
    }

    /// Transcribes audio by uploading it to `/v1/audio-to-text` as a multipart form with `user`
    /// and `file` fields.
    ///
    /// The file type is sniffed from the bytes with `infer`. The part is named
    /// `audio_file.<ext>` and gets the detected MIME type.
    ///
    /// # Errors
    /// Fails with `AudioToTextRequest.File Illegal` if the bytes are not recognised audio, or
    /// if the request fails or its body is rejected by `parse_response`.
    pub async fn audio_to_text(
        &self,
        req_data: AudioToTextRequest,
    ) -> AnyResult<AudioToTextResponse> {
        if !infer::is_audio(&req_data.file) {
            bail!("AudioToTextRequest.File Illegal");
        }
        // `is_audio` matched, so `get` should find a type. Still propagate rather than panic.
        let kind = infer::get(&req_data.file)
            .ok_or_else(|| anyhow::anyhow!("AudioToTextRequest.File Illegal"))?;
        let file_part = multipart::Part::stream(req_data.file)
            .file_name(format!("audio_file.{}", kind.extension()))
            .mime_str(kind.mime_type())?;
        let form = multipart::Form::new()
            .text("user", req_data.user)
            .part("file", file_part);
        let url = self.build_request_api(ApiPath::AudioToText);
        let req = self.client.create_multipart_request(url, form)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<AudioToTextResponse>(&text)
    }

    /// Fetches application parameters (`GET /v1/parameters`).
    ///
    /// # Errors
    /// Fails if `user` is empty, or if the request or response parsing fails.
    pub async fn parameters(&self, req_data: ParametersRequest) -> AnyResult<ParametersResponse> {
        if req_data.user.is_empty() {
            bail!("ParametersRequest.User Illegal");
        }
        let url = self.build_request_api(ApiPath::Parameters);
        let req = self.client.create_request(url, Method::GET, req_data)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<ParametersResponse>(&text)
    }

    /// Fetches application meta information (`GET /v1/meta`).
    ///
    /// # Errors
    /// Fails if `user` is empty, or if the request or response parsing fails.
    pub async fn meta(&self, req_data: MetaRequest) -> AnyResult<MetaResponse> {
        if req_data.user.is_empty() {
            bail!("MetaRequest.User Illegal");
        }
        let url = self.build_request_api(ApiPath::Meta);
        let req = self.client.create_request(url, Method::GET, req_data)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<MetaResponse>(&text)
    }

    /// Builds the `POST /v1/workflows/run` request shared by the blocking and streaming variants.
    fn create_workflows_run_request(&self, req: WorkflowsRunRequest) -> AnyResult<Request> {
        let url = self.build_request_api(ApiPath::WorkflowsRun);
        self.client.create_request(url, Method::POST, req)
    }

    /// Runs a workflow (`POST /v1/workflows/run`) and waits for the result.
    ///
    /// `response_mode` is forced to [`ResponseMode::Blocking`] whatever the caller set.
    ///
    /// # Errors
    /// Fails if the request cannot be built or sent, or if `parse_response` rejects the body.
    pub async fn workflows_run(
        &self,
        mut req_data: WorkflowsRunRequest,
    ) -> AnyResult<WorkflowsRunResponse> {
        req_data.response_mode = ResponseMode::Blocking;
        let req = self.create_workflows_run_request(req_data)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<WorkflowsRunResponse>(&text)
    }

    /// Runs a workflow (`POST /v1/workflows/run`) and returns its progress as a
    /// server-sent-event stream.
    ///
    /// `response_mode` is forced to [`ResponseMode::Streaming`] whatever the caller set.
    ///
    /// # Errors
    /// Fails if the request cannot be built or sent, or if the server answers with a non-2xx
    /// status.
    pub async fn workflows_run_stream(
        &self,
        mut req_data: WorkflowsRunRequest,
    ) -> AnyResult<SseMessageEventStream<impl Stream<Item = Result<Bytes, reqwest::Error>>>> {
        req_data.response_mode = ResponseMode::Streaming;
        let req = self.create_workflows_run_request(req_data)?;
        let resp = Self::ensure_success(self.send(req).await?).await?;
        Ok(SseMessageEventStream::new(resp.bytes_stream().eventsource()))
    }

    /// Stops a streaming workflow run (`POST /v1/workflows/{task_id}/stop`).
    ///
    /// # Errors
    /// Fails if `task_id` is empty, or if the request or response parsing fails.
    pub async fn workflows_stop(
        &self,
        req_data: StreamTaskStopRequest,
    ) -> AnyResult<ResultResponse> {
        self.stream_task_stop(req_data, ApiPath::WorkflowsStop)
            .await
    }

    /// Builds the `POST /v1/completion-messages` request shared by the blocking and streaming
    /// variants.
    fn create_completion_messages_request(
        &self,
        req: CompletionMessagesRequest,
    ) -> AnyResult<Request> {
        let url = self.build_request_api(ApiPath::CompletionMessages);
        self.client.create_request(url, Method::POST, req)
    }

    /// Sends a completion request (`POST /v1/completion-messages`) and waits for the full result.
    ///
    /// `response_mode` is forced to [`ResponseMode::Blocking`] whatever the caller set.
    ///
    /// # Errors
    /// Fails if the request cannot be built or sent, or if `parse_response` rejects the body.
    pub async fn completion_messages(
        &self,
        mut req_data: CompletionMessagesRequest,
    ) -> AnyResult<CompletionMessagesResponse> {
        req_data.response_mode = ResponseMode::Blocking;
        let req = self.create_completion_messages_request(req_data)?;
        let resp = self.send(req).await?;
        let text = resp.text().await?;
        parse_response::<CompletionMessagesResponse>(&text)
    }

    /// Sends a completion request (`POST /v1/completion-messages`) and returns the result as a
    /// server-sent-event stream.
    ///
    /// `response_mode` is forced to [`ResponseMode::Streaming`] whatever the caller set.
    ///
    /// # Errors
    /// Fails if the request cannot be built or sent, or if the server answers with a non-2xx
    /// status.
    pub async fn completion_messages_stream(
        &self,
        mut req_data: CompletionMessagesRequest,
    ) -> AnyResult<SseMessageEventStream<impl Stream<Item = Result<Bytes, reqwest::Error>>>> {
        req_data.response_mode = ResponseMode::Streaming;
        let req = self.create_completion_messages_request(req_data)?;
        let resp = Self::ensure_success(self.send(req).await?).await?;
        Ok(SseMessageEventStream::new(resp.bytes_stream().eventsource()))
    }

    /// Stops a streaming completion (`POST /v1/completion-messages/{task_id}/stop`).
    ///
    /// # Errors
    /// Fails if `task_id` is empty, or if the request or response parsing fails.
    pub async fn completion_messages_stop(
        &self,
        req_data: StreamTaskStopRequest,
    ) -> AnyResult<ResultResponse> {
        self.stream_task_stop(req_data, ApiPath::CompletionMessagesStop)
            .await
    }
}