use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{collections::HashMap, fmt::Debug};
/// Envelope that carries an [`ExecInfo`] under the `exec_info` key.
///
/// NOTE(review): presumably the body returned when querying queue/prompt status —
/// confirm against the caller.
#[derive(Debug, Serialize, Deserialize)]
pub struct PromptInfo {
    /// Queue information nested under the `exec_info` key.
    pub exec_info: ExecInfo,
}
/// Queue status reduced to a single `queue_remaining` counter.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecInfo {
    /// Value of the `queue_remaining` key; presumably the number of prompts
    /// still waiting in the queue (TODO confirm).
    pub queue_remaining: usize,
}
/// Reference to a file by name, subfolder and a free-form `type` tag.
///
/// When deserializing, the file name is accepted under either `filename` or
/// `name`. It is always serialized as `filename`.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct FileInfo {
    /// File name; `name` is accepted as an input alias.
    #[serde(alias = "name")]
    pub filename: String,
    /// Subfolder the file lives in (may be an empty string).
    pub subfolder: String,
    /// `type` is a Rust keyword, hence the raw identifier; the JSON key is `type`.
    pub r#type: String,
}
/// Description of a prompt: its id, a number, and any per-node errors.
///
/// NOTE(review): presumably the response to submitting a prompt — confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct Prompt {
    /// Identifier of the prompt.
    pub prompt_id: String,
    /// Number reported alongside the id (exact meaning not visible here — TODO confirm).
    pub number: usize,
    /// Error payloads kept as raw JSON, keyed by string (presumably node ids — verify).
    pub node_errors: HashMap<String, Value>,
}
/// History entry mapping string keys to their [`Images`] output.
#[derive(Debug, Serialize, Deserialize)]
pub struct History {
    /// Outputs keyed by string (presumably node ids — verify against caller).
    pub outputs: HashMap<String, Images>,
}
/// Output entry that may or may not list image files.
#[derive(Debug, Serialize, Deserialize)]
pub struct Images {
    /// Image files, or `None` when the entry has no `images` list.
    pub images: Option<Vec<FileInfo>>,
}
/// A message distinguished by its `type` key, with its payload under `data`.
///
/// Adjacently tagged: every known variant maps to
/// `{"type": "<variant name in snake_case>", "data": { ...payload... }}`,
/// e.g. `ExecutionStart` <-> `"execution_start"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "data", rename_all = "snake_case")]
pub enum Event {
    /// `"status"`
    Status(StatusEvent),
    /// `"progress"`
    Progress(ProgressEvent),
    /// `"executed"`
    Executed(ExecutedEvent),
    /// `"executing"`
    Executing(ExecutingEvent),
    /// `"execution_start"`
    ExecutionStart(ExecutionStartEvent),
    /// `"execution_error"`
    ExecutionError(ExecutionErrorEvent),
    /// `"execution_cached"`
    ExecutionCached(ExecutionCachedEvent),
    /// `"execution_interrupted"`
    ExecutionInterrupted(ExecutionInterruptedEvent),
    /// Raw JSON for messages that do not fit a known variant.
    ///
    /// Because of `#[serde(skip)]`, deserialization never produces this variant
    /// (unrecognized tags are an error), and serializing it returns an error.
    /// NOTE(review): presumably constructed by callers as a parse-failure fallback — confirm.
    #[serde(skip)]
    Unknown(Value),
}
/// Payload of a `"status"` event: an [`ExecInfo`] under the `status` key.
#[derive(Debug, Serialize, Deserialize)]
pub struct StatusEvent {
    /// Queue information carried by the status message.
    pub status: ExecInfo,
}
/// Payload of a `"progress"` event: a current value out of a maximum.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProgressEvent {
    /// Current progress value.
    pub value: usize,
    /// Upper bound that `value` counts toward.
    pub max: usize,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Output {
pub images: Vec<FileInfo>,
}
/// Payload of an `"executed"` event: which node ran, for which prompt, and its output.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutedEvent {
    /// Identifier of the node that finished.
    pub node: String,
    /// Prompt the node belongs to.
    pub prompt_id: String,
    /// Output reported for that node.
    pub output: Output,
}
/// Payload of an `"executing"` event: the node currently running for a prompt.
///
/// NOTE(review): if the sender can emit `"node": null` (e.g. to signal that a prompt
/// finished), deserializing into `String` will fail — verify, and consider
/// `Option<String>` in a future breaking change.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutingEvent {
    /// Identifier of the node being executed.
    pub node: String,
    /// Prompt the node belongs to.
    pub prompt_id: String,
}
/// Payload of an `"execution_start"` event.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionStartEvent {
    /// Prompt whose execution started.
    pub prompt_id: String,
}
/// Payload of an `"execution_error"` event: where the failure happened and what was raised.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionErrorEvent {
    /// Prompt that failed.
    pub prompt_id: String,
    /// Identifier of the failing node.
    pub node_id: String,
    /// Type name of the failing node.
    pub node_type: String,
    /// Identifiers listed as executed (presumably nodes that ran before the failure — verify).
    pub executed: Vec<String>,
    /// Message of the raised exception.
    pub exception_message: String,
    /// Type name of the raised exception.
    pub exception_type: String,
    /// Traceback, one string per entry.
    pub traceback: Vec<String>,
    /// Inputs of the failing node, kept as raw JSON.
    pub current_inputs: HashMap<String, Value>,
    /// Outputs of the failing node, kept as raw JSON.
    pub current_outputs: HashMap<String, Value>,
}
/// Payload of an `"execution_cached"` event: nodes served from cache for a prompt.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionCachedEvent {
    /// Identifiers of the cached nodes.
    pub nodes: Vec<String>,
    /// Prompt the nodes belong to.
    pub prompt_id: String,
}
/// Payload of an `"execution_interrupted"` event.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionInterruptedEvent {
    /// Prompt that was interrupted.
    pub prompt_id: String,
    /// Identifier of the node running at the interruption.
    pub node_id: String,
    /// Type name of that node.
    pub node_type: String,
    /// Identifiers listed as executed (presumably nodes completed before the interrupt — verify).
    pub executed: Vec<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Known variants serialize as `{"type": <snake_case name>, "data": <payload>}`.
    #[test]
    fn test_serialize_event() {
        let ev = Event::Status(StatusEvent {
            status: ExecInfo { queue_remaining: 0 },
        });
        let value = serde_json::to_value(&ev).unwrap();
        assert_eq!(
            value,
            json!({
                "type": "status",
                "data": {
                    "status": {
                        "queue_remaining": 0,
                    }
                }
            })
        );
        let ev = Event::ExecutionStart(ExecutionStartEvent {
            prompt_id: "xxxxxx".to_string(),
        });
        let value = serde_json::to_value(&ev).unwrap();
        assert_eq!(
            value,
            json!({
                "type": "execution_start",
                "data": {
                    "prompt_id": "xxxxxx",
                }
            })
        );
    }

    /// The tagged JSON form deserializes back into the matching variant.
    #[test]
    fn test_deserialize_event() {
        let ev: Event = serde_json::from_value(json!({
            "type": "execution_start",
            "data": { "prompt_id": "xxxxxx" }
        }))
        .unwrap();
        match ev {
            Event::ExecutionStart(e) => assert_eq!(e.prompt_id, "xxxxxx"),
            other => panic!("unexpected event: {:?}", other),
        }

        let ev: Event = serde_json::from_value(json!({
            "type": "progress",
            "data": { "value": 3, "max": 10 }
        }))
        .unwrap();
        match ev {
            Event::Progress(p) => {
                assert_eq!(p.value, 3);
                assert_eq!(p.max, 10);
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    /// `Unknown` is `#[serde(skip)]`: unrecognized tags are a parse error and the
    /// variant itself refuses to serialize.
    #[test]
    fn test_unknown_variant_is_skipped() {
        let parsed = serde_json::from_value::<Event>(json!({
            "type": "something_new",
            "data": {}
        }));
        assert!(parsed.is_err());
        assert!(serde_json::to_value(&Event::Unknown(json!({ "a": 1 }))).is_err());
    }

    /// `FileInfo` accepts `name` as an alias on input but always emits `filename`.
    #[test]
    fn test_file_info_name_alias() {
        let via_alias: FileInfo = serde_json::from_value(json!({
            "name": "a.png",
            "subfolder": "",
            "type": "output"
        }))
        .unwrap();
        let via_field: FileInfo = serde_json::from_value(json!({
            "filename": "a.png",
            "subfolder": "",
            "type": "output"
        }))
        .unwrap();
        assert_eq!(via_alias, via_field);
        assert_eq!(
            serde_json::to_value(&via_alias).unwrap(),
            json!({ "filename": "a.png", "subfolder": "", "type": "output" })
        );
    }
}