use super::request::{Feedback, FileType};
use anyhow::{anyhow, bail, Result as AnyResult};
use eventsource_stream::EventStream;
use futures::Stream;
use pin_project_lite::pin_project;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use serde_with::{serde_as, EnumMap};
use std::{
collections::HashMap,
fmt::{Display, Formatter, Result as FmtResult},
pin::Pin,
task::{Context, Poll},
};
/// Error body with a machine-readable `code`, a descriptive `message` and a
/// numeric `status`.
///
/// Also constructed locally by [`ErrorResponse::unknown`] to wrap bodies that
/// could not be decoded.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ErrorResponse {
    /// Error code string, e.g. `"unknown_error"`.
    pub code: String,
    /// Description of the failure.
    pub message: String,
    /// Numeric status; NOTE(review): presumably mirrors the HTTP status — confirm.
    pub status: u32,
}
impl Display for ErrorResponse {
fn fmt(&self, f: &mut Formatter) -> FmtResult {
write!(f, "{}", serde_json::to_string(&self).unwrap())
}
}
impl ErrorResponse {
pub fn unknown<T>(message: T) -> Self
where
T: ToString,
{
ErrorResponse {
code: "unknown_error".into(),
message: message.to_string(),
status: 503,
}
}
}
/// Response whose only payload is a `result` string.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ResultResponse {
    /// Result value as reported by the server.
    pub result: String,
}
/// Identifier and timestamp fields shared by several message payloads.
///
/// Other types embed this struct with `#[serde(flatten)]`, so these keys sit
/// at the top level of the JSON object.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MessageBase {
    /// Identifier of the message.
    pub message_id: String,
    /// Conversation identifier; `None` when absent or `null`.
    pub conversation_id: Option<String>,
    /// Creation time; NOTE(review): presumably Unix seconds — confirm.
    pub created_at: u64,
}
/// Chat message response body.
///
/// The `message_id`, `conversation_id` and `created_at` keys are read into
/// [`MessageBase`] via `#[serde(flatten)]`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatMessagesResponse {
    /// Flattened identifier/timestamp fields.
    #[serde(flatten)]
    pub base: MessageBase,
    /// Event name string as sent by the server.
    pub event: String,
    /// Application mode the reply was produced in.
    pub mode: AppMode,
    /// Answer text.
    pub answer: String,
    /// Free-form metadata object.
    pub metadata: HashMap<String, JsonValue>,
}
/// One event decoded from the server-sent event stream.
///
/// The variant is chosen by the JSON `"event"` field (internally tagged; variant
/// names are `snake_case`, e.g. `"message_end"`, `"workflow_started"`).
///
/// Most variants flatten an optional [`MessageBase`] (`message_id`,
/// `conversation_id`, `created_at`) and then collect any remaining undeclared
/// keys into `extra`.
///
/// NOTE(review): field order matters. `base` is declared before `extra` so the
/// struct-flatten takes the `MessageBase` keys before the map-flatten gathers
/// leftovers. Reordering would presumably duplicate those keys into `extra`.
/// Confirm against serde's flatten semantics before changing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event", rename_all = "snake_case")]
pub enum SseMessageEvent {
    /// `"message"` event carrying an `answer` string.
    Message {
        #[serde(flatten)]
        base: Option<MessageBase>,
        id: String,
        task_id: String,
        answer: String,
        // Undeclared keys; must stay after `base` (see enum docs).
        #[serde(flatten)]
        extra: HashMap<String, JsonValue>,
    },
    /// `"message_file"` event describing a file attached to the message.
    MessageFile {
        #[serde(flatten)]
        base: Option<MessageBase>,
        id: String,
        // `type` is a Rust keyword, hence the trailing underscore + rename.
        #[serde(rename = "type")]
        type_: FileType,
        belongs_to: BelongsTo,
        url: String,
        #[serde(flatten)]
        extra: HashMap<String, JsonValue>,
    },
    /// `"message_end"` event carrying a free-form `metadata` object.
    MessageEnd {
        #[serde(flatten)]
        base: Option<MessageBase>,
        id: String,
        task_id: String,
        metadata: HashMap<String, JsonValue>,
        #[serde(flatten)]
        extra: HashMap<String, JsonValue>,
    },
    /// `"message_replace"` event carrying an `answer` string.
    MessageReplace {
        #[serde(flatten)]
        base: Option<MessageBase>,
        task_id: String,
        answer: String,
        #[serde(flatten)]
        extra: HashMap<String, JsonValue>,
    },
    /// `"workflow_started"` event; details are in `data`.
    WorkflowStarted {
        #[serde(flatten)]
        base: Option<MessageBase>,
        task_id: String,
        workflow_run_id: String,
        data: WorkflowStartedData,
        #[serde(flatten)]
        extra: HashMap<String, JsonValue>,
    },
    /// `"node_started"` event; details are in `data`.
    NodeStarted {
        #[serde(flatten)]
        base: Option<MessageBase>,
        task_id: String,
        workflow_run_id: String,
        data: NodeStartedData,
        #[serde(flatten)]
        extra: HashMap<String, JsonValue>,
    },
    /// `"node_finished"` event; details are in `data`.
    NodeFinished {
        #[serde(flatten)]
        base: Option<MessageBase>,
        task_id: String,
        workflow_run_id: String,
        data: NodeFinishedData,
        #[serde(flatten)]
        extra: HashMap<String, JsonValue>,
    },
    /// `"workflow_finished"` event; details are in `data`.
    WorkflowFinished {
        #[serde(flatten)]
        base: Option<MessageBase>,
        task_id: String,
        workflow_run_id: String,
        data: WorkflowFinishedData,
        #[serde(flatten)]
        extra: HashMap<String, JsonValue>,
    },
    /// `"agent_message"` event carrying an `answer` string.
    AgentMessage {
        #[serde(flatten)]
        base: Option<MessageBase>,
        id: String,
        task_id: String,
        answer: String,
        #[serde(flatten)]
        extra: HashMap<String, JsonValue>,
    },
    /// `"agent_thought"` event.
    ///
    /// Unlike the other payload variants this has no `extra` map, so any
    /// undeclared keys are ignored during deserialization.
    AgentThought {
        #[serde(flatten)]
        base: Option<MessageBase>,
        id: String,
        task_id: String,
        position: u32,
        thought: String,
        observation: String,
        tool: String,
        tool_labels: JsonValue,
        tool_input: String,
        message_files: Vec<String>,
    },
    /// `"error"` event reported in-stream with `status`, `code` and `message`.
    Error {
        #[serde(flatten)]
        base: Option<MessageBase>,
        status: u32,
        code: String,
        message: String,
        #[serde(flatten)]
        extra: HashMap<String, JsonValue>,
    },
    /// `"ping"` event with no payload.
    Ping,
}
/// `data` payload of [`SseMessageEvent::WorkflowStarted`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WorkflowStartedData {
    /// Identifier of this record.
    pub id: String,
    /// Identifier of the workflow being run.
    pub workflow_id: String,
    /// Sequence number reported by the server.
    pub sequence_number: u32,
    /// Inputs as an arbitrary JSON value.
    pub inputs: JsonValue,
    /// Creation time; NOTE(review): presumably Unix seconds — confirm.
    pub created_at: u64,
}
/// `data` payload of [`SseMessageEvent::WorkflowFinished`], also used by
/// [`WorkflowsRunResponse`].
///
/// Keys not declared here are kept in `extra`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WorkflowFinishedData {
    /// Identifier of this record.
    pub id: String,
    /// Identifier of the workflow that ran.
    pub workflow_id: String,
    /// Final status of the run.
    pub status: FinishedStatus,
    /// Outputs, when present.
    pub outputs: Option<JsonValue>,
    /// Error text, when present.
    pub error: Option<String>,
    /// Elapsed time; NOTE(review): unit not visible here (seconds?) — confirm.
    pub elapsed_time: Option<f64>,
    /// Token count, when reported.
    pub total_tokens: Option<u32>,
    /// Number of steps executed.
    pub total_steps: u32,
    /// Creation time; NOTE(review): presumably Unix seconds — confirm.
    pub created_at: u64,
    /// Finish time; NOTE(review): presumably Unix seconds — confirm.
    pub finished_at: u64,
    /// Any remaining undeclared keys.
    #[serde(flatten)]
    pub extra: HashMap<String, JsonValue>,
}
/// `data` payload of [`SseMessageEvent::NodeStarted`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NodeStartedData {
    /// Identifier of this record.
    pub id: String,
    /// Identifier of the node.
    pub node_id: String,
    /// Node type as a raw string.
    pub node_type: String,
    /// Display title of the node.
    pub title: String,
    /// Index reported by the server.
    pub index: u32,
    /// Preceding node, when there is one.
    pub predecessor_node_id: Option<String>,
    /// Inputs, when present.
    pub inputs: Option<JsonValue>,
    /// Creation time; NOTE(review): presumably Unix seconds — confirm.
    pub created_at: u64,
}
/// `data` payload of [`SseMessageEvent::NodeFinished`].
///
/// Keys not declared here are kept in `extra`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NodeFinishedData {
    /// Identifier of this record.
    pub id: String,
    /// Identifier of the node.
    pub node_id: String,
    /// Index reported by the server.
    pub index: u32,
    /// Preceding node, when there is one.
    pub predecessor_node_id: Option<String>,
    /// Inputs, when present.
    pub inputs: Option<JsonValue>,
    /// Intermediate processing data, when present.
    pub process_data: Option<JsonValue>,
    /// Outputs, when present.
    pub outputs: Option<JsonValue>,
    /// Final status of the node.
    pub status: FinishedStatus,
    /// Error text, when present.
    pub error: Option<String>,
    /// Elapsed time; NOTE(review): unit not visible here (seconds?) — confirm.
    pub elapsed_time: Option<f64>,
    /// Token/price accounting, when present.
    pub execution_metadata: Option<ExecutionMetadata>,
    /// Creation time; NOTE(review): presumably Unix seconds — confirm.
    pub created_at: u64,
    /// Any remaining undeclared keys.
    #[serde(flatten)]
    pub extra: HashMap<String, JsonValue>,
}
/// Status of a workflow or node execution.
///
/// Serialized in `snake_case`: `"running"`, `"succeeded"`, `"failed"`,
/// `"stopped"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishedStatus {
    Running,
    Succeeded,
    Failed,
    Stopped,
}
/// Optional accounting attached to [`NodeFinishedData`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ExecutionMetadata {
    /// Token count, when reported.
    pub total_tokens: Option<u32>,
    /// Price kept as a string; NOTE(review): presumably a decimal string — confirm.
    pub total_price: Option<String>,
    /// Currency of `total_price`, when reported.
    pub currency: Option<String>,
}
/// Application mode.
///
/// Serialized in `kebab-case`, e.g. `"advanced-chat"`, `"agent-chat"`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum AppMode {
    Completion,
    Workflow,
    Chat,
    AdvancedChat,
    AgentChat,
    Channel,
}
/// Response carrying a `result` string and a list of suggestion strings.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MessagesSuggestedResponse {
    /// Result value as reported by the server.
    pub result: String,
    /// Suggestion strings.
    pub data: Vec<String>,
}
/// A page of messages with `limit` / `has_more` pagination fields.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MessagesResponse {
    /// Page size reported by the server.
    pub limit: u32,
    /// Whether more items are available.
    pub has_more: bool,
    /// Messages in this page.
    pub data: Vec<MessageData>,
}
/// A single stored message entry.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MessageData {
    /// Identifier of the message.
    pub id: String,
    /// Conversation the message belongs to.
    pub conversation_id: String,
    /// Inputs as an arbitrary JSON value.
    pub inputs: JsonValue,
    /// Query text.
    pub query: String,
    /// Answer text.
    pub answer: String,
    /// Files attached to the message.
    pub message_files: Vec<MessageFile>,
    /// Feedback, when any was recorded.
    pub feedback: Option<MessageFeedback>,
    /// Retriever resources as raw JSON values.
    pub retriever_resources: Vec<JsonValue>,
    /// Creation time; NOTE(review): presumably Unix seconds — confirm.
    pub created_at: u64,
}
/// File attached to a message.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MessageFile {
    /// Identifier of the file.
    pub id: String,
    /// File type; named `type_` because `type` is a keyword.
    #[serde(rename = "type")]
    pub type_: FileType,
    /// Location of the file.
    pub url: String,
    /// Which side of the conversation the file belongs to.
    pub belongs_to: BelongsTo,
}
/// Owner of a file: serialized as `"user"` or `"assistant"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum BelongsTo {
    User,
    Assistant,
}
/// Feedback recorded on a message.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MessageFeedback {
    /// Rating value (shared `Feedback` type from the request module).
    pub rating: Feedback,
}
/// A page of conversations with `has_more` / `limit` pagination fields.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ConversationsResponse {
    /// Whether more items are available.
    pub has_more: bool,
    /// Page size reported by the server.
    pub limit: u32,
    /// Conversations in this page.
    pub data: Vec<ConversationData>,
}
/// A single conversation entry.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ConversationData {
    /// Identifier of the conversation.
    pub id: String,
    /// Conversation name.
    pub name: String,
    /// Inputs; note that only string values are accepted by this type.
    pub inputs: HashMap<String, String>,
    /// Introduction text.
    pub introduction: String,
    /// Creation time; NOTE(review): presumably Unix seconds — confirm.
    pub created_at: u64,
}
/// Application parameter/configuration response.
// `#[serde_as]` must stay before `#[derive]` so it can rewrite the field attributes.
#[serde_as]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ParametersResponse {
    /// Opening statement text.
    pub opening_statement: String,
    /// Suggested question strings.
    pub suggested_questions: Vec<String>,
    /// Toggle for post-answer suggestions.
    pub suggested_questions_after_answer: ParameterSuggestedQuestionsAfterAnswer,
    /// Speech-to-text toggle.
    pub speech_to_text: ParameterSpeechToText,
    /// Retriever-resource toggle.
    pub retriever_resource: ParameterRetrieverResource,
    /// Annotation-reply toggle.
    pub annotation_reply: ParameterAnnotationReply,
    /// Input form definition, one item per field.
    pub user_input_form: Vec<ParameterUserInputFormItem>,
    /// Decoded from a JSON object keyed by variant name (serde_with `EnumMap`).
    #[serde_as(as = "EnumMap")]
    pub file_upload: Vec<ParameterFileUploadItem>,
    /// System-level limits.
    pub system_parameters: SystemParameters,
}
/// `{ "enabled": bool }` toggle for `suggested_questions_after_answer`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ParameterSuggestedQuestionsAfterAnswer {
    pub enabled: bool,
}
/// `{ "enabled": bool }` toggle for `speech_to_text`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ParameterSpeechToText {
    pub enabled: bool,
}
/// `{ "enabled": bool }` toggle for `retriever_resource`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ParameterRetrieverResource {
    pub enabled: bool,
}
/// `{ "enabled": bool }` toggle for `annotation_reply`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ParameterAnnotationReply {
    pub enabled: bool,
}
/// One input-form field definition.
///
/// Externally tagged: each item is a single-key object whose key is
/// `"text-input"`, `"paragraph"`, `"number"` or `"select"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ParameterUserInputFormItem {
    /// Uses a kebab-style tag, overriding the enum-wide `snake_case` rule.
    #[serde(rename = "text-input")]
    TextInput {
        label: String,
        variable: String,
        required: bool,
    },
    Paragraph {
        label: String,
        variable: String,
        required: bool,
    },
    Number {
        label: String,
        variable: String,
        required: bool,
    },
    /// Field with a fixed list of string `options`.
    Select {
        label: String,
        variable: String,
        required: bool,
        options: Vec<String>,
    },
}
/// File-upload setting, keyed by kind (currently only `"image"`).
///
/// Decoded via serde_with `EnumMap` on `ParametersResponse::file_upload`, so a
/// map like `{"image": {...}}` becomes a list of variants.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ParameterFileUploadItem {
    Image {
        enabled: bool,
        number_limits: u32,
        transfer_methods: Vec<TransferMethod>,
    },
}
/// File transfer method: serialized as `"remote_url"` or `"local_file"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum TransferMethod {
    RemoteUrl,
    LocalFile,
}
/// System-level limits reported with the application parameters.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SystemParameters {
    /// Image size limit kept as a string; NOTE(review): unit not visible here — confirm.
    pub image_file_size_limit: String,
}
/// Application meta information: a map from tool name to its icon.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MetaResponse {
    pub tool_icons: HashMap<String, ToolIcon>,
}
/// Tool icon, either a URL string or an emoji descriptor.
///
/// Untagged: a bare JSON string decodes as `Url`, while an object with
/// `background` and `content` decodes as `Emoji`. Variants are tried in
/// declaration order, so keep `Url` first.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolIcon {
    Url(String),
    Emoji {
        background: String,
        content: String,
    },
}
/// Speech-to-text result.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AudioToTextResponse {
    /// Transcribed text.
    pub text: String,
}
/// Metadata of an uploaded file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FilesUploadResponse {
    /// Identifier of the uploaded file.
    pub id: String,
    /// File name.
    pub name: String,
    /// Size; NOTE(review): presumably bytes — confirm.
    pub size: u64,
    /// File extension.
    pub extension: String,
    /// MIME type string.
    pub mime_type: String,
    /// Identifier of the uploader.
    pub created_by: String,
    /// Creation time; NOTE(review): presumably Unix seconds — confirm.
    pub created_at: u64,
}
/// Result of a workflow run, with the finished-run details in `data`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WorkflowsRunResponse {
    /// Identifier of the workflow run.
    pub workflow_run_id: String,
    /// Identifier of the task.
    pub task_id: String,
    /// Finished-run details.
    pub data: WorkflowFinishedData,
}
/// Completion message response body.
///
/// The `message_id`, `conversation_id` and `created_at` keys are read into
/// [`MessageBase`] via `#[serde(flatten)]`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionMessagesResponse {
    /// Flattened identifier/timestamp fields.
    #[serde(flatten)]
    pub base: MessageBase,
    /// Identifier of the task.
    pub task_id: String,
    /// Event name string as sent by the server.
    pub event: String,
    /// Application mode the reply was produced in.
    pub mode: AppMode,
    /// Answer text.
    pub answer: String,
    /// Free-form metadata object.
    pub metadata: HashMap<String, JsonValue>,
}
pin_project! {
    /// Stream adapter that decodes raw SSE events into [`SseMessageEvent`]s.
    ///
    /// Only events whose SSE event name is `"message"` are parsed as JSON;
    /// events with any other name are skipped. After the inner stream ends,
    /// every later poll returns `None`.
    pub struct SseMessageEventStream<S>
    {
        // Inner SSE decoder; structurally pinned so it can be polled through
        // `Pin<&mut Self>`.
        #[pin]
        stream: EventStream<S>,
        // Set once the inner stream yields `None`; short-circuits later polls.
        terminated: bool,
    }
}
impl<S> SseMessageEventStream<S> {
    /// Wraps an SSE [`EventStream`]; the adapter starts out not terminated.
    pub fn new(stream: EventStream<S>) -> Self {
        let terminated = false;
        Self { terminated, stream }
    }
}
impl<S, B, E> Stream for SseMessageEventStream<S>
where
S: Stream<Item = Result<B, E>>,
B: AsRef<[u8]>,
E: Display,
{
type Item = AnyResult<SseMessageEvent>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let mut this = self.project();
if *this.terminated {
return Poll::Ready(None);
}
loop {
match this.stream.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(event))) => match event.event.as_str() {
"message" => match serde_json::from_str::<SseMessageEvent>(&event.data) {
Ok(msg_event) => return Poll::Ready(Some(Ok(msg_event))),
Err(e) => return Poll::Ready(Some(Err(e.into()))),
},
_ => {}
},
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(anyhow!(e.to_string())))),
Poll::Ready(None) => {
*this.terminated = true;
return Poll::Ready(None);
}
Poll::Pending => return Poll::Pending,
}
}
}
}
/// Decodes `text` as `T`, falling back to interpreting it as an error body.
///
/// # Errors
/// When `text` does not deserialize as `T`, returns whatever
/// [`parse_error_response`] produces. The decode error for `T` itself is
/// discarded.
pub(crate) fn parse_response<T>(text: &str) -> AnyResult<T>
where
    T: serde::de::DeserializeOwned,
{
    serde_json::from_str::<T>(text).or_else(|_| parse_error_response(text))
}
pub(crate) fn parse_error_response<T>(text: &str) -> AnyResult<T> {
if let Ok(err) = serde_json::from_str::<ErrorResponse>(text) {
bail!(err)
} else {
bail!(ErrorResponse::unknown(text))
}
}