1use super::request::{Feedback, FileType};
68use anyhow::{anyhow, bail, Result as AnyResult};
69use eventsource_stream::EventStream;
70use futures::Stream;
71use pin_project_lite::pin_project;
72use serde::{Deserialize, Serialize};
73use serde_json::Value as JsonValue;
74use serde_with::{serde_as, EnumMap};
75use std::{
76 collections::HashMap,
77 fmt::{Display, Formatter, Result as FmtResult},
78 pin::Pin,
79 task::{Context, Poll},
80};
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ErrorResponse {
85 pub code: String,
86 pub message: String,
87 pub status: u32,
88}
89
90impl Display for ErrorResponse {
91 fn fmt(&self, f: &mut Formatter) -> FmtResult {
92 write!(f, "{}", serde_json::to_string(&self).unwrap())
93 }
94}
95
96impl ErrorResponse {
97 pub fn unknown<T>(message: T) -> Self
98 where
99 T: ToString,
100 {
101 ErrorResponse {
102 code: "unknown_error".into(),
103 message: message.to_string(),
104 status: 503,
105 }
106 }
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ResultResponse {
112 pub result: String,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct MessageBase {
118 pub message_id: String,
120 pub conversation_id: Option<String>,
122 pub created_at: u64,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ChatMessagesResponse {
129 #[serde(flatten)]
131 pub base: MessageBase,
132 pub event: String,
134 pub mode: AppMode,
136 pub answer: String,
138 pub metadata: HashMap<String, JsonValue>,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144#[serde(tag = "event", rename_all = "snake_case")]
145pub enum SseMessageEvent {
146 Message {
148 #[serde(flatten)]
150 base: Option<MessageBase>,
151 id: String,
153 task_id: String,
155 answer: String,
157 #[serde(flatten)]
158 extra: HashMap<String, JsonValue>,
159 },
160 MessageFile {
162 #[serde(flatten)]
164 base: Option<MessageBase>,
165 id: String,
167 #[serde(rename = "type")]
169 type_: FileType,
170 belongs_to: BelongsTo,
172 url: String,
174 #[serde(flatten)]
175 extra: HashMap<String, JsonValue>,
176 },
177 MessageEnd {
179 #[serde(flatten)]
181 base: Option<MessageBase>,
182 id: String,
184 task_id: String,
186 metadata: HashMap<String, JsonValue>,
188 #[serde(flatten)]
189 extra: HashMap<String, JsonValue>,
190 },
191 MessageReplace {
194 #[serde(flatten)]
196 base: Option<MessageBase>,
197 task_id: String,
199 answer: String,
201 #[serde(flatten)]
202 extra: HashMap<String, JsonValue>,
203 },
204 WorkflowStarted {
206 #[serde(flatten)]
208 base: Option<MessageBase>,
209 task_id: String,
211 workflow_run_id: String,
213 data: WorkflowStartedData,
215 #[serde(flatten)]
216 extra: HashMap<String, JsonValue>,
217 },
218
219 NodeStarted {
221 #[serde(flatten)]
223 base: Option<MessageBase>,
224 task_id: String,
226 workflow_run_id: String,
228 data: NodeStartedData,
230 #[serde(flatten)]
231 extra: HashMap<String, JsonValue>,
232 },
233 NodeFinished {
235 #[serde(flatten)]
237 base: Option<MessageBase>,
238 task_id: String,
240 workflow_run_id: String,
242 data: NodeFinishedData,
244 #[serde(flatten)]
245 extra: HashMap<String, JsonValue>,
246 },
247 WorkflowFinished {
249 #[serde(flatten)]
251 base: Option<MessageBase>,
252 task_id: String,
254 workflow_run_id: String,
256 data: WorkflowFinishedData,
258 #[serde(flatten)]
259 extra: HashMap<String, JsonValue>,
260 },
261 AgentMessage {
263 #[serde(flatten)]
265 base: Option<MessageBase>,
266 id: String,
268 task_id: String,
270 answer: String,
272 #[serde(flatten)]
273 extra: HashMap<String, JsonValue>,
274 },
275 AgentThought {
277 #[serde(flatten)]
279 base: Option<MessageBase>,
280 id: String,
282 task_id: String,
284 position: u32,
286 thought: String,
288 observation: String,
290 tool: String,
292 tool_labels: JsonValue,
294 tool_input: String,
296 message_files: Vec<String>,
298 },
299 Error {
301 #[serde(flatten)]
303 base: Option<MessageBase>,
304 status: u32,
306 code: String,
308 message: String,
310 #[serde(flatten)]
311 extra: HashMap<String, JsonValue>,
312 },
313 Ping,
315}
316
317#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct WorkflowStartedData {
320 pub id: String,
322 pub workflow_id: String,
324 pub sequence_number: u32,
326 pub inputs: JsonValue,
328 pub created_at: u64,
330}
331
332#[derive(Debug, Clone, Serialize, Deserialize)]
334pub struct WorkflowFinishedData {
335 pub id: String,
337 pub workflow_id: String,
339 pub status: FinishedStatus,
341 pub outputs: Option<JsonValue>,
343 pub error: Option<String>,
345 pub elapsed_time: Option<f64>,
347 pub total_tokens: Option<u32>,
349 pub total_steps: u32,
351 pub created_at: u64,
353 pub finished_at: u64,
355 #[serde(flatten)]
356 pub extra: HashMap<String, JsonValue>,
357}
358
359#[derive(Debug, Clone, Serialize, Deserialize)]
361pub struct NodeStartedData {
362 pub id: String,
364 pub node_id: String,
366 pub node_type: String,
368 pub title: String,
370 pub index: u32,
372 pub predecessor_node_id: Option<String>,
374 pub inputs: Option<JsonValue>,
376 pub created_at: u64,
378}
379
380#[derive(Debug, Clone, Serialize, Deserialize)]
382pub struct NodeFinishedData {
383 pub id: String,
385 pub node_id: String,
387 pub index: u32,
389 pub predecessor_node_id: Option<String>,
391 pub inputs: Option<JsonValue>,
393 pub process_data: Option<JsonValue>,
395 pub outputs: Option<JsonValue>,
397 pub status: FinishedStatus,
399 pub error: Option<String>,
401 pub elapsed_time: Option<f64>,
403 pub execution_metadata: Option<ExecutionMetadata>,
405 pub created_at: u64,
407 #[serde(flatten)]
408 pub extra: HashMap<String, JsonValue>,
409}
410
411#[derive(Debug, Clone, Serialize, Deserialize)]
413#[serde(rename_all = "snake_case")]
414pub enum FinishedStatus {
415 Running,
416 Succeeded,
417 Failed,
418 Stopped,
419}
420
421#[derive(Debug, Clone, Serialize, Deserialize)]
423pub struct ExecutionMetadata {
424 pub total_tokens: Option<u32>,
426 pub total_price: Option<String>,
428 pub currency: Option<String>,
430}
431
432#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
434#[serde(rename_all = "kebab-case")]
435pub enum AppMode {
436 Completion,
437 Workflow,
438 Chat,
439 AdvancedChat,
440 AgentChat,
441 Channel,
442}
443
444#[derive(Debug, Clone, Deserialize, Serialize)]
446pub struct MessagesSuggestedResponse {
447 pub result: String,
448 pub data: Vec<String>,
450}
451
452#[derive(Debug, Clone, Deserialize, Serialize)]
454pub struct MessagesResponse {
455 pub limit: u32,
457 pub has_more: bool,
459 pub data: Vec<MessageData>,
461}
462
463#[derive(Debug, Clone, Deserialize, Serialize)]
465pub struct MessageData {
466 pub id: String,
468 pub conversation_id: String,
470 pub inputs: JsonValue,
472 pub query: String,
474 pub answer: String,
476 pub message_files: Vec<MessageFile>,
478 pub feedback: Option<MessageFeedback>,
480 pub retriever_resources: Vec<JsonValue>,
482 pub created_at: u64,
484}
485
486#[derive(Debug, Clone, Serialize, Deserialize)]
488pub struct MessageFile {
489 pub id: String,
491 #[serde(rename = "type")]
493 pub type_: FileType,
494 pub url: String,
496 pub belongs_to: BelongsTo,
498}
499
500#[derive(Debug, Clone, Serialize, Deserialize)]
502#[serde(rename_all = "snake_case")]
503pub enum BelongsTo {
504 User,
505 Assistant,
506}
507
508#[derive(Debug, Clone, Deserialize, Serialize)]
510pub struct MessageFeedback {
511 pub rating: Feedback,
513}
514
515#[derive(Debug, Clone, Deserialize, Serialize)]
517pub struct ConversationsResponse {
518 pub has_more: bool,
520 pub limit: u32,
522 pub data: Vec<ConversationData>,
524}
525
526#[derive(Debug, Clone, Deserialize, Serialize)]
528pub struct ConversationData {
529 pub id: String,
531 pub name: String,
533 pub inputs: HashMap<String, String>,
535 pub introduction: String,
537 pub created_at: u64,
539}
540
541#[serde_as]
542#[derive(Debug, Clone, Deserialize, Serialize)]
543pub struct ParametersResponse {
545 pub opening_statement: String,
547 pub suggested_questions: Vec<String>,
549 pub suggested_questions_after_answer: ParameterSuggestedQuestionsAfterAnswer,
551 pub speech_to_text: ParameterSpeechToText,
553 pub retriever_resource: ParameterRetrieverResource,
555 pub annotation_reply: ParameterAnnotationReply,
557 pub user_input_form: Vec<ParameterUserInputFormItem>,
559 #[serde_as(as = "EnumMap")]
561 pub file_upload: Vec<ParameterFileUploadItem>,
562 pub system_parameters: SystemParameters,
563}
564
565#[derive(Debug, Clone, Deserialize, Serialize)]
566pub struct ParameterSuggestedQuestionsAfterAnswer {
568 pub enabled: bool,
570}
571
572#[derive(Debug, Clone, Deserialize, Serialize)]
573pub struct ParameterSpeechToText {
575 pub enabled: bool,
577}
578
579#[derive(Debug, Clone, Deserialize, Serialize)]
580pub struct ParameterRetrieverResource {
582 pub enabled: bool,
584}
585
586#[derive(Debug, Clone, Deserialize, Serialize)]
587pub struct ParameterAnnotationReply {
589 pub enabled: bool,
591}
592
593#[derive(Debug, Clone, Deserialize, Serialize)]
594#[serde(rename_all = "snake_case")]
595pub enum ParameterUserInputFormItem {
597 #[serde(rename = "text-input")]
599 TextInput {
600 label: String,
602 variable: String,
604 required: bool,
606 },
607 Paragraph {
609 label: String,
611 variable: String,
613 required: bool,
615 },
616 Number {
618 label: String,
620 variable: String,
622 required: bool,
624 },
625 Select {
626 label: String,
628 variable: String,
630 required: bool,
632 options: Vec<String>,
634 },
635}
636
637#[derive(Debug, Clone, Deserialize, Serialize)]
638#[serde(rename_all = "snake_case")]
639pub enum ParameterFileUploadItem {
641 Image {
643 enabled: bool,
645 number_limits: u32,
647 transfer_methods: Vec<TransferMethod>,
649 },
650}
651
652#[derive(Debug, Clone, Deserialize, Serialize)]
654#[serde(rename_all = "snake_case")]
655pub enum TransferMethod {
656 RemoteUrl,
657 LocalFile,
658}
659
660#[derive(Debug, Clone, Deserialize, Serialize)]
662pub struct SystemParameters {
663 pub image_file_size_limit: String,
665}
666
667#[derive(Debug, Clone, Deserialize, Serialize)]
668pub struct MetaResponse {
670 pub tool_icons: HashMap<String, ToolIcon>,
671}
672
673#[derive(Debug, Clone, Deserialize, Serialize)]
675#[serde(untagged)]
676pub enum ToolIcon {
677 Url(String),
678 Emoji { background: String, content: String },
679}
680
681#[derive(Debug, Clone, Deserialize, Serialize)]
683pub struct AudioToTextResponse {
684 pub text: String,
686}
687
688#[derive(Debug, Clone, Deserialize, Serialize)]
690pub struct FilesUploadResponse {
691 pub id: String,
693 pub name: String,
695 pub size: u64,
697 pub extension: String,
699 pub mime_type: String,
701 pub created_by: String,
703 pub created_at: u64,
705}
706
707#[derive(Debug, Clone, Deserialize, Serialize)]
709pub struct WorkflowsRunResponse {
710 pub workflow_run_id: String,
712 pub task_id: String,
714 pub data: WorkflowFinishedData,
716}
717
718#[derive(Debug, Clone, Serialize, Deserialize)]
720pub struct CompletionMessagesResponse {
721 #[serde(flatten)]
723 pub base: MessageBase,
724 pub task_id: String,
726 pub event: String,
728 pub mode: AppMode,
730 pub answer: String,
732 pub metadata: HashMap<String, JsonValue>,
734}
735
pin_project! {
    // Adapter over an SSE `EventStream` that yields decoded `SseMessageEvent`s.
    pub struct SseMessageEventStream<S>
    {
        // Underlying SSE event source. Structurally pinned so it can be polled
        // through `Pin<&mut Self>`.
        #[pin]
        stream: EventStream<S>,
        // Set once `stream` has yielded `None`. After that, polling returns
        // `None` without touching `stream` again.
        terminated: bool,
    }
}
745
746impl<S> SseMessageEventStream<S> {
747 pub fn new(stream: EventStream<S>) -> Self {
749 Self {
750 stream,
751 terminated: false,
752 }
753 }
754}
755
756impl<S, B, E> Stream for SseMessageEventStream<S>
757where
758 S: Stream<Item = Result<B, E>>,
759 B: AsRef<[u8]>,
760 E: Display,
761{
762 type Item = AnyResult<SseMessageEvent>;
763
764 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
765 let mut this = self.project();
766 if *this.terminated {
767 return Poll::Ready(None);
768 }
769
770 loop {
771 match this.stream.as_mut().poll_next(cx) {
772 Poll::Ready(Some(Ok(event))) => match event.event.as_str() {
773 "message" => match serde_json::from_str::<SseMessageEvent>(&event.data) {
774 Ok(msg_event) => return Poll::Ready(Some(Ok(msg_event))),
775 Err(e) => return Poll::Ready(Some(Err(e.into()))),
776 },
777 _ => {}
778 },
779 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(anyhow!(e.to_string())))),
780 Poll::Ready(None) => {
781 *this.terminated = true;
782 return Poll::Ready(None);
783 }
784 Poll::Pending => return Poll::Pending,
785 }
786 }
787 }
788}
789
790pub(crate) fn parse_response<T>(text: &str) -> AnyResult<T>
792where
793 T: serde::de::DeserializeOwned,
794{
795 if let Ok(data) = serde_json::from_str::<T>(text) {
796 Ok(data)
797 } else {
798 parse_error_response(text)
799 }
800}
801
802pub(crate) fn parse_error_response<T>(text: &str) -> AnyResult<T> {
804 if let Ok(err) = serde_json::from_str::<ErrorResponse>(text) {
805 bail!(err)
806 } else {
807 bail!(ErrorResponse::unknown(text))
808 }
809}