1use super::{
47 client::Client,
48 http::{multipart, Method, Request},
49 request::{
50 AudioToTextRequest, Bytes, ChatMessagesRequest, CompletionMessagesRequest,
51 ConversationsDeleteRequest, ConversationsRenameRequest, ConversationsRequest,
52 FilesUploadRequest, MessagesFeedbacksRequest, MessagesRequest, MessagesSuggestedRequest,
53 MetaRequest, ParametersRequest, ResponseMode, StreamTaskStopRequest, TextToAudioRequest,
54 WorkflowsRunRequest,
55 },
56 response::{
57 parse_error_response, parse_response, AudioToTextResponse, ChatMessagesResponse,
58 CompletionMessagesResponse, ConversationsResponse, FilesUploadResponse, MessagesResponse,
59 MessagesSuggestedResponse, MetaResponse, ParametersResponse, ResultResponse,
60 SseMessageEventStream, WorkflowsRunResponse,
61 },
62};
63use anyhow::{bail, Result as AnyResult};
64use eventsource_stream::Eventsource;
65use futures::stream::Stream;
66use std::fmt::{Display, Formatter, Result as FmtResult};
67
/// Endpoints exposed by [`Api`], identified by their path template relative to the
/// client's base URL.
///
/// The enum is a plain fieldless tag, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ApiPath {
    /// `/v1/chat-messages`
    ChatMessages,
    /// `/v1/files/upload`
    FilesUpload,
    /// `/v1/chat-messages/{task_id}/stop`
    ChatMessagesStop,
    /// `/v1/messages/{message_id}/feedbacks`
    MessagesFeedbacks,
    /// `/v1/messages/{message_id}/suggested`
    MessagesSuggested,
    /// `/v1/messages`
    Messages,
    /// `/v1/conversations`
    Conversations,
    /// `/v1/conversations/{conversation_id}`
    ConversationsDelete,
    /// `/v1/conversations/{conversation_id}/name`
    ConversationsRename,
    /// `/v1/audio-to-text`
    AudioToText,
    /// `/v1/text-to-audio`
    TextToAudio,
    /// `/v1/parameters`
    Parameters,
    /// `/v1/meta`
    Meta,

    /// `/v1/workflows/run`
    WorkflowsRun,
    /// `/v1/workflows/{task_id}/stop`
    WorkflowsStop,

    /// `/v1/completion-messages`
    CompletionMessages,
    /// `/v1/completion-messages/{task_id}/stop`
    CompletionMessagesStop,
}
113
114impl ApiPath {
116 pub fn as_str(&self) -> &'static str {
124 match self {
125 ApiPath::ChatMessages => "/v1/chat-messages",
126 ApiPath::FilesUpload => "/v1/files/upload",
127 ApiPath::ChatMessagesStop => "/v1/chat-messages/{task_id}/stop",
128 ApiPath::MessagesFeedbacks => "/v1/messages/{message_id}/feedbacks",
129 ApiPath::MessagesSuggested => "/v1/messages/{message_id}/suggested",
130 ApiPath::Messages => "/v1/messages",
131 ApiPath::Conversations => "/v1/conversations",
132 ApiPath::ConversationsDelete => "/v1/conversations/{conversation_id}",
133 ApiPath::ConversationsRename => "/v1/conversations/{conversation_id}/name",
134 ApiPath::AudioToText => "/v1/audio-to-text",
135 ApiPath::TextToAudio => "/v1/text-to-audio",
136 ApiPath::Parameters => "/v1/parameters",
137 ApiPath::Meta => "/v1/meta",
138 ApiPath::WorkflowsRun => "/v1/workflows/run",
139 ApiPath::WorkflowsStop => "/v1/workflows/{task_id}/stop",
140 ApiPath::CompletionMessages => "/v1/completion-messages",
141 ApiPath::CompletionMessagesStop => "/v1/completion-messages/{task_id}/stop",
142 }
143 }
144}
145
146impl Display for ApiPath {
147 fn fmt(&self, f: &mut Formatter) -> FmtResult {
148 write!(f, "{}", self.as_str())
149 }
150}
151
152type BeforeSend = Option<Box<dyn Fn(Request) -> Request + Send + Sync>>;
154
155pub struct Api<'a> {
157 before_send_hook: BeforeSend,
158 pub(crate) client: &'a Client,
159}
160
161impl<'a> Api<'a> {
163 pub fn new(client: &'a Client) -> Self {
168 Self {
169 before_send_hook: None,
170 client,
171 }
172 }
173
174 pub fn before_send<F>(&mut self, hook: F)
183 where
184 F: Fn(Request) -> Request + Send + Sync + 'static,
185 {
186 self.before_send_hook = Some(Box::new(hook));
187 }
188
189 async fn send(&self, mut req: Request) -> AnyResult<reqwest::Response> {
197 if let Some(hook) = self.before_send_hook.as_ref() {
198 req = hook(req);
199 }
200 self.client.execute(req).await
201 }
202
203 fn build_request_api(&self, api_path: ApiPath) -> String {
211 self.client.config.base_url.clone() + api_path.as_str()
212 }
213
214 fn create_chat_messages_request(&self, req: ChatMessagesRequest) -> AnyResult<Request> {
225 let url = self.build_request_api(ApiPath::ChatMessages);
226 self.client.create_request(url, Method::POST, req)
227 }
228
229 pub async fn chat_messages(
237 &self,
238 mut req_data: ChatMessagesRequest,
239 ) -> AnyResult<ChatMessagesResponse> {
240 req_data.response_mode = ResponseMode::Blocking;
241
242 let req = self.create_chat_messages_request(req_data)?;
243 let resp = self.send(req).await?;
244 let text = resp.text().await?;
245 parse_response::<ChatMessagesResponse>(&text)
246 }
247
248 pub async fn chat_messages_stream(
273 &self,
274 mut req_data: ChatMessagesRequest,
275 ) -> AnyResult<SseMessageEventStream<impl Stream<Item = Result<Bytes, reqwest::Error>>>> {
276 req_data.response_mode = ResponseMode::Streaming;
277
278 let req = self.create_chat_messages_request(req_data)?;
279 let resp = self.send(req).await?;
280 let stream = resp.bytes_stream().eventsource();
281 let s = SseMessageEventStream::new(stream);
282
283 Ok(s)
284 }
285
286 pub async fn files_upload(
297 &self,
298 req_data: FilesUploadRequest,
299 ) -> AnyResult<FilesUploadResponse> {
300 if !infer::is_image(&req_data.file) {
301 bail!("FilesUploadRequest.File Illegal");
302 }
303 let kind = infer::get(&req_data.file).expect("Failed to get file type");
304 let file_part = multipart::Part::stream(req_data.file)
305 .file_name(format!("image_file.{}", kind.extension()))
306 .mime_str(kind.mime_type())?;
307 let form = multipart::Form::new()
308 .text("user", req_data.user)
309 .part("file", file_part);
310
311 let url = self.build_request_api(ApiPath::FilesUpload);
312 let req = self.client.create_multipart_request(url, form)?;
313 let resp = self.send(req).await?;
314 let text = resp.text().await?;
315 parse_response::<FilesUploadResponse>(&text)
316 }
317
318 async fn stream_task_stop(
328 &self,
329 mut req_data: StreamTaskStopRequest,
330 api_path: ApiPath,
331 ) -> AnyResult<ResultResponse> {
332 if req_data.task_id.is_empty() {
333 bail!("StreamTaskStopRequest.TaskId Illegal");
334 }
335
336 let url = self.build_request_api(api_path);
337 let url = url.replace("{task_id}", &req_data.task_id);
338
339 req_data.task_id = String::new();
340 let req = self.client.create_request(url, Method::POST, req_data)?;
341 let resp = self.send(req).await?;
342 let text = resp.text().await?;
343 parse_response::<ResultResponse>(&text)
344 }
345
346 pub async fn chat_messages_stop(
354 &self,
355 req_data: StreamTaskStopRequest,
356 ) -> AnyResult<ResultResponse> {
357 self.stream_task_stop(req_data, ApiPath::ChatMessagesStop)
358 .await
359 }
360
361 pub async fn messages_suggested(
369 &self,
370 mut req_data: MessagesSuggestedRequest,
371 ) -> AnyResult<MessagesSuggestedResponse> {
372 if req_data.message_id.is_empty() {
373 bail!("MessagesSuggestedRequest.MessageID Illegal");
374 }
375
376 let url = self.build_request_api(ApiPath::MessagesSuggested);
377 let url = url.replace("{message_id}", &req_data.message_id);
378
379 req_data.message_id = String::new();
380 let req = self.client.create_request(url, Method::GET, req_data)?;
381 let resp = self.send(req).await?;
382 let text = resp.text().await?;
383 parse_response::<MessagesSuggestedResponse>(&text)
384 }
385
386 pub async fn messages_feedbacks(
394 &self,
395 mut req_data: MessagesFeedbacksRequest,
396 ) -> AnyResult<ResultResponse> {
397 if req_data.message_id.is_empty() {
398 bail!("MessagesFeedbacksRequest.MessageID Illegal");
399 }
400
401 let url = self.build_request_api(ApiPath::MessagesFeedbacks);
402 let url = url.replace("{message_id}", &req_data.message_id);
403
404 req_data.message_id = String::new();
405 let req = self.client.create_request(url, Method::POST, req_data)?;
406 let resp = self.send(req).await?;
407 let text = resp.text().await?;
408 parse_response::<ResultResponse>(&text)
409 }
410
411 pub async fn conversations(
419 &self,
420 req_data: ConversationsRequest,
421 ) -> AnyResult<ConversationsResponse> {
422 if req_data.user.is_empty() {
423 bail!("ConversationsRequest.User Illegal");
424 }
425
426 let url = self.build_request_api(ApiPath::Conversations);
427 let req = self.client.create_request(url, Method::GET, req_data)?;
428 let resp = self.send(req).await?;
429 let text = resp.text().await?;
430 parse_response::<ConversationsResponse>(&text)
431 }
432
433 pub async fn messages(&self, req_data: MessagesRequest) -> AnyResult<MessagesResponse> {
441 if req_data.conversation_id.is_empty() {
442 bail!("MessagesRequest.ConversationID Illegal");
443 }
444
445 let url = self.build_request_api(ApiPath::Messages);
446 let req = self.client.create_request(url, Method::GET, req_data)?;
447 let resp = self.send(req).await?;
448 let text = resp.text().await?;
449 parse_response::<MessagesResponse>(&text)
450 }
451
452 pub async fn conversations_renaming(
460 &self,
461 mut req_data: ConversationsRenameRequest,
462 ) -> AnyResult<ResultResponse> {
463 if req_data.conversation_id.is_empty() {
464 bail!("ConversationsRenameRequest.ConversationID Illegal");
465 }
466 if req_data.auto_generate && req_data.name.is_none() {
467 bail!("ConversationsRenameRequest.Name Illegal");
468 }
469
470 let url = self.build_request_api(ApiPath::ConversationsRename);
471 let url = url.replace("{conversation_id}", &req_data.conversation_id);
472
473 req_data.conversation_id = String::new();
474 let req = self.client.create_request(url, Method::POST, req_data)?;
475 let resp = self.send(req).await?;
476 let text = resp.text().await?;
477 parse_response::<ResultResponse>(&text)
478 }
479
480 pub async fn conversations_delete(
488 &self,
489 mut req_data: ConversationsDeleteRequest,
490 ) -> AnyResult<()> {
491 if req_data.conversation_id.is_empty() {
492 bail!("ConversationsDeleteRequest.ConversationID Illegal");
493 }
494
495 let url = self.build_request_api(ApiPath::ConversationsDelete);
496 let url = url.replace("{conversation_id}", &req_data.conversation_id);
497
498 req_data.conversation_id = String::new();
499 let req = self.client.create_request(url, Method::DELETE, req_data)?;
500 let resp = self.send(req).await?;
501 if resp.status().as_u16() == 204 {
503 Ok(())
504 } else {
505 let text = resp.text().await?;
507 parse_error_response(&text)
508 }
509 }
510
511 pub async fn text_to_audio(&self, req_data: TextToAudioRequest) -> AnyResult<Bytes> {
519 if req_data.text.is_empty() {
520 bail!("TextToAudioRequest.Text Illegal");
521 }
522
523 let url = self.build_request_api(ApiPath::TextToAudio);
524 let req = self.client.create_request(url, Method::POST, req_data)?;
525 let resp = self.send(req).await?;
526 let content_type = resp.headers().get(reqwest::header::CONTENT_TYPE);
528 let content_type = content_type
529 .ok_or(anyhow::anyhow!("Content-Type is missing"))?
530 .to_str()?;
531 if content_type.starts_with("audio/") {
533 let bytes = resp.bytes().await?;
534 return Ok(bytes);
535 }
536 let text = resp.text().await?;
537 parse_error_response(&text)
538 }
539
540 pub async fn audio_to_text(
548 &self,
549 req_data: AudioToTextRequest,
550 ) -> AnyResult<AudioToTextResponse> {
551 if !infer::is_audio(&req_data.file) {
552 bail!("AudioToTextRequest.File Illegal");
553 }
554 let kind = infer::get(&req_data.file).expect("Failed to get file type");
555 let file_part = multipart::Part::stream(req_data.file)
556 .file_name(format!("audio_file.{}", kind.extension()))
557 .mime_str(kind.mime_type())?;
558 let form = multipart::Form::new()
559 .text("user", req_data.user)
560 .part("file", file_part);
561
562 let url = self.build_request_api(ApiPath::AudioToText);
563 let req = self.client.create_multipart_request(url, form)?;
564 let resp = self.send(req).await?;
565 let text = resp.text().await?;
566 parse_response::<AudioToTextResponse>(&text)
567 }
568
569 pub async fn parameters(&self, req_data: ParametersRequest) -> AnyResult<ParametersResponse> {
577 if req_data.user.is_empty() {
578 bail!("ParametersRequest.User Illegal");
579 }
580
581 let url = self.build_request_api(ApiPath::Parameters);
582 let req = self.client.create_request(url, Method::GET, req_data)?;
583 let resp = self.send(req).await?;
584 let text = resp.text().await?;
585 parse_response::<ParametersResponse>(&text)
586 }
587
588 pub async fn meta(&self, req_data: MetaRequest) -> AnyResult<MetaResponse> {
596 if req_data.user.is_empty() {
597 bail!("MetaRequest.User Illegal");
598 }
599
600 let url = self.build_request_api(ApiPath::Meta);
601 let req = self.client.create_request(url, Method::GET, req_data)?;
602 let resp = self.send(req).await?;
603 let text = resp.text().await?;
604 parse_response::<MetaResponse>(&text)
605 }
606
607 fn create_workflows_run_request(&self, req: WorkflowsRunRequest) -> AnyResult<Request> {
615 let url = self.build_request_api(ApiPath::WorkflowsRun);
616 self.client.create_request(url, Method::POST, req)
617 }
618
619 pub async fn workflows_run(
627 &self,
628 mut req_data: WorkflowsRunRequest,
629 ) -> AnyResult<WorkflowsRunResponse> {
630 req_data.response_mode = ResponseMode::Blocking;
631
632 let req = self.create_workflows_run_request(req_data)?;
633 let resp = self.send(req).await?;
634 let text = resp.text().await?;
635 parse_response::<WorkflowsRunResponse>(&text)
636 }
637
638 pub async fn workflows_run_stream(
649 &self,
650 mut req_data: WorkflowsRunRequest,
651 ) -> AnyResult<SseMessageEventStream<impl Stream<Item = Result<Bytes, reqwest::Error>>>> {
652 req_data.response_mode = ResponseMode::Streaming;
653
654 let req = self.create_workflows_run_request(req_data)?;
655 let resp = self.send(req).await?;
656 let stream = resp.bytes_stream().eventsource();
657 let s = SseMessageEventStream::new(stream);
658 Ok(s)
659 }
660
661 pub async fn workflows_stop(
669 &self,
670 req_data: StreamTaskStopRequest,
671 ) -> AnyResult<ResultResponse> {
672 self.stream_task_stop(req_data, ApiPath::WorkflowsStop)
673 .await
674 }
675
676 fn create_completion_messages_request(
684 &self,
685 req: CompletionMessagesRequest,
686 ) -> AnyResult<Request> {
687 let url = self.build_request_api(ApiPath::CompletionMessages);
688 self.client.create_request(url, Method::POST, req)
689 }
690
691 pub async fn completion_messages(
700 &self,
701 mut req_data: CompletionMessagesRequest,
702 ) -> AnyResult<CompletionMessagesResponse> {
703 req_data.response_mode = ResponseMode::Blocking;
704
705 let req = self.create_completion_messages_request(req_data)?;
706 let resp = self.send(req).await?;
707 let text = resp.text().await?;
708 parse_response::<CompletionMessagesResponse>(&text)
709 }
710
711 pub async fn completion_messages_stream(
722 &self,
723 mut req_data: CompletionMessagesRequest,
724 ) -> AnyResult<SseMessageEventStream<impl Stream<Item = Result<Bytes, reqwest::Error>>>> {
725 req_data.response_mode = ResponseMode::Streaming;
726
727 let req = self.create_completion_messages_request(req_data)?;
728 let resp = self.send(req).await?;
729 let stream = resp.bytes_stream().eventsource();
730 let s = SseMessageEventStream::new(stream);
731 Ok(s)
732 }
733
734 pub async fn completion_messages_stop(
743 &self,
744 req_data: StreamTaskStopRequest,
745 ) -> AnyResult<ResultResponse> {
746 self.stream_task_stop(req_data, ApiPath::CompletionMessagesStop)
747 .await
748 }
749}