use crate::ClientError;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{collections::HashMap, fmt::Debug};
use tokio_tungstenite::tungstenite;
/// Queue snapshot that wraps a single [`ExecInfo`].
///
/// NOTE(review): presumably the body returned by the server's prompt-info
/// endpoint; confirm against the caller that deserializes it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptInfo {
    /// Execution / queue details reported under the `exec_info` key.
    pub exec_info: ExecInfo,
}
/// Execution details carried inside [`PromptInfo`] and [`StatusEventStatus`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecInfo {
    /// Value of the `queue_remaining` key (presumably the number of prompts
    /// still waiting to run — confirm).
    pub queue_remaining: usize,
}
/// Reference to a file held by the server (e.g. a generated image).
///
/// Serializes as `{"filename": ..., "subfolder": ..., "type": ...}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FileInfo {
    /// File name. When deserializing, the key `name` is accepted as an
    /// alternative to `filename`; serialization always writes `filename`.
    #[serde(alias = "name")]
    pub filename: String,
    /// Sub-directory the file lives in (may be empty).
    pub subfolder: String,
    /// Storage kind, written on the wire as `type` (presumably values such as
    /// `"output"` / `"input"` / `"temp"` — confirm).
    pub r#type: String,
}
/// Status returned for a submitted prompt.
///
/// NOTE(review): presumably the response to queueing a prompt; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptStatus {
    /// Identifier assigned to the prompt.
    pub prompt_id: String,
    /// Value of the `number` key (presumably the queue position — confirm).
    pub number: usize,
    /// Per-node error details as raw JSON, keyed by the server's node key
    /// (presumably node id — confirm).
    pub node_errors: HashMap<String, Value>,
}
/// History entry for a prompt: the outputs it produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct History {
    /// Output files grouped per key (presumably the producing node id — confirm).
    pub outputs: HashMap<String, Images>,
}
/// Files emitted by a single output entry of a [`History`].
///
/// Either list may be absent (`None`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Images {
    /// Still images, if any.
    pub images: Option<Vec<FileInfo>>,
    /// Animated outputs stored under the `gifs` key, if any.
    pub gifs: Option<Vec<FileInfo>>,
}
/// Any event surfaced to users of the client.
///
/// Either a decoded server message ([`ComfyEvent`]) or a change in the
/// underlying connection ([`ConnectionEvent`]). Marked `#[non_exhaustive]`, so
/// downstream `match`es need a wildcard arm.
// `Debug` is derived so events can be logged and used in `unwrap`/`assert`
// messages; both payload types already implement `Debug`.
#[derive(Debug)]
#[non_exhaustive]
pub enum Event {
    /// A message decoded from the server.
    Comfy(ComfyEvent),
    /// A connection-level event (reconnects, receive errors).
    Connection(ConnectionEvent),
}
/// An event message exchanged as JSON (presumably over the server websocket).
///
/// The enum is internally tagged: the JSON `type` field selects the variant,
/// using the `snake_case` form of the variant name (e.g. `ExecutionStart` ↔
/// `"execution_start"`). The variant's fields sit alongside `type` in the same
/// JSON object.
// `Clone` is derived to match every payload struct, all of which are `Clone`.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ComfyEvent {
    /// `"status"`: queue status update.
    Status {
        data: StatusEventData,
        /// Optional `sid` field (presumably the websocket session id — confirm).
        sid: Option<String>,
    },
    /// `"progress"`: progress counter update.
    Progress {
        data: ProgressEventData,
    },
    /// `"executed"`: a node finished and may carry output.
    Executed {
        data: ExecutedEventData,
    },
    /// `"executing"`: the node currently executing changed.
    Executing {
        data: ExecutingEventData,
    },
    /// `"execution_start"`: a prompt started executing.
    ExecutionStart {
        data: ExecutionStartEventData,
    },
    /// `"execution_error"`: a node raised an error.
    ExecutionError {
        data: ExecutionErrorEventData,
    },
    /// `"execution_cached"`: nodes whose results were served from cache.
    ExecutionCached {
        data: ExecutionCachedEventData,
    },
    /// `"execution_interrupted"`: execution was interrupted.
    ExecutionInterrupted {
        data: ExecutionInterruptedEventData,
    },
    /// `"execution_success"`: a prompt finished successfully.
    ExecutionSuccess {
        data: ExecutionSuccessEventData,
    },
    /// Raw JSON for a message that matched none of the variants above.
    ///
    /// `#[serde(skip)]` hides this variant from serde entirely. Deserialization
    /// never produces it, and an unrecognised `type` is a decode error.
    /// Serializing it returns an error. Code that wants to keep unrecognised
    /// messages must construct it manually.
    #[serde(skip)]
    Unknown(Value),
}
/// Connection-level events, as opposed to decoded server messages.
///
/// Marked `#[non_exhaustive]`; not `Clone` because the wrapped error types are
/// carried by value.
#[non_exhaustive]
#[derive(Debug)]
pub enum ConnectionEvent {
    /// The websocket was re-established.
    WSReconnectSuccess,
    /// A reconnection attempt failed with the contained client error.
    WSReconnectError(ClientError),
    /// Receiving from the websocket failed with the contained tungstenite error.
    WSReceiveError(tungstenite::Error),
}
/// `data` payload of a [`ComfyEvent::Status`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusEventData {
    /// Nested `status` object.
    pub status: StatusEventStatus,
}
/// The `status` object inside [`StatusEventData`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusEventStatus {
    /// Queue details under the `exec_info` key.
    pub exec_info: ExecInfo,
}
/// `data` payload of a [`ComfyEvent::Progress`] message.
///
/// Any extra keys in the payload are ignored on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressEventData {
    /// Current progress value (presumably steps completed — confirm).
    pub value: usize,
    /// Upper bound for `value`.
    pub max: usize,
}
/// Output attached to an [`ExecutedEventData`].
///
/// The `images` key is parsed into typed [`FileInfo`]s. Every other key of the
/// JSON object is collected, unparsed, into `others` via `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutedOutput {
    /// Image files produced by the node, if any.
    pub images: Option<Vec<FileInfo>>,
    /// All remaining output keys as raw JSON.
    #[serde(flatten)]
    pub others: HashMap<String, Value>,
}
/// `data` payload of a [`ComfyEvent::Executed`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutedEventData {
    /// Node that finished executing.
    pub node: String,
    /// Prompt the node belongs to.
    pub prompt_id: String,
    /// The node's output, if the message carried one.
    pub output: Option<ExecutedOutput>,
}
/// `data` payload of a [`ComfyEvent::Executing`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutingEventData {
    /// Node now executing. May be `null`/absent (presumably signalling that the
    /// prompt has finished — confirm).
    pub node: Option<String>,
    /// Optional `display_node` key.
    pub display_node: Option<String>,
    /// Prompt being executed.
    pub prompt_id: String,
}
/// `data` payload of a [`ComfyEvent::ExecutionStart`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionStartEventData {
    /// Prompt that started executing.
    pub prompt_id: String,
    /// Server-supplied timestamp (unit not visible here — presumably
    /// milliseconds since the Unix epoch; confirm).
    pub timestamp: u64,
}
/// `data` payload of a [`ComfyEvent::ExecutionError`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionErrorEventData {
    /// Prompt during which the error occurred.
    pub prompt_id: String,
    /// Node that raised the error.
    pub node_id: String,
    /// Type name of that node.
    pub node_type: String,
    /// Nodes reported as already executed.
    pub executed: Vec<String>,
    /// Error message text.
    pub exception_message: String,
    /// Error type name.
    pub exception_type: String,
    /// Traceback lines.
    pub traceback: Vec<String>,
    /// Inputs of the failing node as raw JSON.
    pub current_inputs: HashMap<String, Value>,
    /// Outputs reported by the server as raw JSON.
    /// NOTE(review): confirm the server sends an object (not a list) here.
    pub current_outputs: HashMap<String, Value>,
}
/// `data` payload of a [`ComfyEvent::ExecutionCached`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionCachedEventData {
    /// Nodes whose results were reused.
    pub nodes: Vec<String>,
    /// Prompt the nodes belong to.
    pub prompt_id: String,
    /// Server-supplied timestamp (unit not visible here — confirm).
    pub timestamp: u64,
}
/// `data` payload of a [`ComfyEvent::ExecutionInterrupted`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionInterruptedEventData {
    /// Prompt that was interrupted.
    pub prompt_id: String,
    /// Node reported at the point of interruption.
    pub node_id: String,
    /// Type name of that node.
    pub node_type: String,
    /// Nodes reported as already executed.
    pub executed: Vec<String>,
}
/// `data` payload of a [`ComfyEvent::ExecutionSuccess`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionSuccessEventData {
    /// Prompt that completed.
    pub prompt_id: String,
}
/// A prompt to submit, borrowed either as text or as a parsed JSON value.
///
/// `Str` presumably holds serialized JSON text (confirm at the submitting call
/// site). Conversions are provided from `&str`, `&String` and `&Value`.
///
/// Both variants hold shared references, so the enum is cheap to copy.
// `Debug`/`Clone`/`Copy` are derived so a prompt can be logged and reused
// after being passed by value.
#[derive(Debug, Clone, Copy)]
pub enum Prompt<'a> {
    /// Borrowed prompt text.
    Str(&'a str),
    /// Borrowed, already-parsed JSON value.
    Value(&'a Value),
}

impl<'a> From<&'a str> for Prompt<'a> {
    fn from(value: &'a str) -> Self {
        Prompt::Str(value)
    }
}

impl<'a> From<&'a String> for Prompt<'a> {
    fn from(value: &'a String) -> Self {
        Prompt::Str(value.as_str())
    }
}

impl<'a> From<&'a Value> for Prompt<'a> {
    fn from(value: &'a Value) -> Self {
        Prompt::Value(value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializing events yields the internally tagged, snake_case JSON shape.
    #[test]
    fn test_serialize_event() {
        let ev = ComfyEvent::Status {
            data: StatusEventData {
                status: StatusEventStatus {
                    exec_info: ExecInfo { queue_remaining: 0 },
                },
            },
            sid: None,
        };
        let value = serde_json::to_value(&ev).unwrap();
        assert_eq!(
            value,
            json!({
                "type": "status",
                "data": {
                    "status": {
                        "exec_info": {
                            "queue_remaining": 0,
                        }
                    }
                },
                "sid": null
            })
        );
        let ev = ComfyEvent::ExecutionStart {
            data: ExecutionStartEventData {
                prompt_id: "xxxxxx".to_string(),
                timestamp: 123456789,
            },
        };
        let value = serde_json::to_value(&ev).unwrap();
        assert_eq!(
            value,
            json!({
                "type": "execution_start",
                "data": {
                    "prompt_id": "xxxxxx",
                    "timestamp": 123456789
                }
            })
        );
    }

    /// Decoding picks the variant from `type`; absent/null `Option`s become `None`.
    #[test]
    fn test_deserialize_event() {
        let ev: ComfyEvent = serde_json::from_value(json!({
            "type": "executing",
            "data": { "node": null, "prompt_id": "abc" }
        }))
        .unwrap();
        match ev {
            ComfyEvent::Executing { data } => {
                assert_eq!(data.node, None);
                assert_eq!(data.display_node, None);
                assert_eq!(data.prompt_id, "abc");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    /// `Unknown` is `#[serde(skip)]`: it is never decoded and cannot be encoded.
    #[test]
    fn test_unknown_variant_is_skipped_by_serde() {
        let unrecognised = json!({ "type": "some.custom_event", "data": {} });
        assert!(serde_json::from_value::<ComfyEvent>(unrecognised).is_err());
        assert!(serde_json::to_value(&ComfyEvent::Unknown(json!({}))).is_err());
    }

    /// `FileInfo` accepts `name` as an alias for `filename` and writes `type`.
    #[test]
    fn test_file_info_name_alias() {
        let by_alias: FileInfo = serde_json::from_value(json!({
            "name": "a.png", "subfolder": "", "type": "output"
        }))
        .unwrap();
        let by_field: FileInfo = serde_json::from_value(json!({
            "filename": "a.png", "subfolder": "", "type": "output"
        }))
        .unwrap();
        assert_eq!(by_alias, by_field);
        assert_eq!(
            serde_json::to_value(&by_alias).unwrap(),
            json!({ "filename": "a.png", "subfolder": "", "type": "output" })
        );
    }

    /// Keys other than `images` land in the flattened `others` map.
    #[test]
    fn test_executed_output_flatten() {
        let out: ExecutedOutput = serde_json::from_value(json!({
            "images": [{ "filename": "a.png", "subfolder": "", "type": "output" }],
            "text": ["hello"]
        }))
        .unwrap();
        assert_eq!(out.images.as_ref().map(Vec::len), Some(1));
        assert_eq!(out.others.get("text"), Some(&json!(["hello"])));
        assert!(!out.others.contains_key("images"));
    }
}