//! comfyui-client 0.1.0
//!
//! Rust client for ComfyUI.
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{collections::HashMap, fmt::Debug};

/// Prompt-level information reported by the server.
///
/// At present this only wraps the [`ExecInfo`] describing the execution queue.
#[derive(Debug, Serialize, Deserialize)]
pub struct PromptInfo {
    /// Execution details (such as the remaining queue length) for the prompt.
    pub exec_info: ExecInfo,
}

/// Execution-queue details.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecInfo {
    /// How many tasks are still waiting in the execution queue.
    pub queue_remaining: usize,
}

/// Describes a file by its name, the subfolder it lives in, and its type.
///
/// Equality compares all three fields.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct FileInfo {
    /// File name.
    ///
    /// When deserializing, the value is also accepted under the key `name`.
    #[serde(alias = "name")]
    pub filename: String,
    /// Subfolder that contains the file.
    pub subfolder: String,
    /// File type, serialized under the key `type`.
    pub r#type: String,
}

/// A submitted prompt: its identifier, its number, and any per-node errors.
#[derive(Debug, Serialize, Deserialize)]
pub struct Prompt {
    /// String identifier of the prompt.
    pub prompt_id: String,
    /// Numeric identifier of the prompt.
    pub number: usize,
    /// Error details keyed by node identifier, kept as raw JSON values.
    pub node_errors: HashMap<String, Value>,
}

/// Output history recorded for a prompt.
#[derive(Debug, Serialize, Deserialize)]
pub struct History {
    /// Image outputs keyed by output identifier.
    pub outputs: HashMap<String, Images>,
}

/// Wrapper around a list of image files that may be absent.
#[derive(Debug, Serialize, Deserialize)]
pub struct Images {
    /// Image file descriptors, or `None` when no list is present.
    pub images: Option<Vec<FileInfo>>,
}

/// Events emitted by the server.
///
/// Each event is (de)serialized as an adjacently tagged object. The
/// `snake_case` variant name goes under `"type"` and the payload goes under
/// `"data"`, e.g. `{"type": "execution_start", "data": {"prompt_id": "..."}}`.
///
/// [`Event::Unknown`] is marked `#[serde(skip)]`, so serde never produces or
/// emits it:
/// - deserializing an object whose `"type"` matches none of the other
///   variants returns an error instead of yielding `Unknown`;
/// - serializing an `Unknown` value returns an error.
///
/// The variant can only be constructed explicitly, presumably by code that
/// parses raw messages and wants to keep the original JSON of an
/// unrecognized event.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", content = "data", rename_all = "snake_case")]
pub enum Event {
    /// Queue status update (`"type": "status"`).
    Status(StatusEvent),
    /// Progress update for the current node (`"type": "progress"`).
    Progress(ProgressEvent),
    /// A node finished executing and produced output (`"type": "executed"`).
    Executed(ExecutedEvent),
    /// A node is currently executing (`"type": "executing"`).
    Executing(ExecutingEvent),
    /// Execution of a prompt has started (`"type": "execution_start"`).
    ExecutionStart(ExecutionStartEvent),
    /// An error occurred during execution (`"type": "execution_error"`).
    ExecutionError(ExecutionErrorEvent),
    /// Results for some nodes were taken from the cache
    /// (`"type": "execution_cached"`).
    ExecutionCached(ExecutionCachedEvent),
    /// Execution was interrupted (`"type": "execution_interrupted"`).
    ExecutionInterrupted(ExecutionInterruptedEvent),
    /// Raw JSON of an event that is not otherwise modeled.
    ///
    /// Skipped by serde in both directions; see the enum-level docs.
    #[serde(skip)]
    Unknown(Value),
}

/// Payload of a `status` event.
#[derive(Debug, Serialize, Deserialize)]
pub struct StatusEvent {
    /// Execution-queue details carried by the event.
    pub status: ExecInfo,
}

/// Payload of a `progress` event: how far along the work is out of a maximum.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProgressEvent {
    /// Progress reached so far.
    pub value: usize,
    /// Value that `value` counts up to.
    pub max: usize,
}

/// Represents the output of an executed node.
#[derive(Serialize, Deserialize, Debug)]
pub struct Output {
    /// A list of image file information objects.
    pub images: Vec<FileInfo>,
}

/// Payload of an `executed` event: a node of a prompt finished and produced
/// output.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutedEvent {
    /// Identifier of the node that finished.
    pub node: String,
    /// Identifier of the prompt the node belongs to.
    pub prompt_id: String,
    /// Output produced by the node.
    pub output: Output,
}

/// Payload of an `executing` event: which node is running, and for which
/// prompt.
///
/// NOTE(review): `node` is a non-optional `String`, so a message carrying
/// `"node": null` fails to deserialize. Confirm whether the server ever sends
/// that.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutingEvent {
    /// Identifier of the node that is running.
    pub node: String,
    /// Identifier of the prompt being executed.
    pub prompt_id: String,
}

/// Payload of an `execution_start` event.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionStartEvent {
    /// Identifier of the prompt whose execution began.
    pub prompt_id: String,
}

/// Payload of an `execution_error` event.
///
/// Identifies the failing node, describes the exception, and includes the
/// node's inputs and outputs at the time of failure.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionErrorEvent {
    /// Identifier of the prompt that failed.
    pub prompt_id: String,
    /// Identifier of the node that raised the error.
    pub node_id: String,
    /// Type of the node that raised the error.
    pub node_type: String,
    /// Identifiers of the nodes that had already executed.
    pub executed: Vec<String>,
    /// Message of the exception.
    pub exception_message: String,
    /// Type name of the exception.
    pub exception_type: String,
    /// Traceback lines of the exception.
    pub traceback: Vec<String>,
    /// Input values at the time of the error, kept as raw JSON.
    pub current_inputs: HashMap<String, Value>,
    /// Output values at the time of the error, kept as raw JSON.
    pub current_outputs: HashMap<String, Value>,
}

/// Payload of an `execution_cached` event: nodes whose results came from the
/// cache.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionCachedEvent {
    /// Identifiers of the cached nodes.
    pub nodes: Vec<String>,
    /// Identifier of the prompt the cached nodes belong to.
    pub prompt_id: String,
}

/// Payload of an `execution_interrupted` event: where and for which prompt
/// execution stopped.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionInterruptedEvent {
    /// Identifier of the interrupted prompt.
    pub prompt_id: String,
    /// Identifier of the node at which execution stopped.
    pub node_id: String,
    /// Type of the node at which execution stopped.
    pub node_type: String,
    /// Identifiers of the nodes that had completed before the interruption.
    pub executed: Vec<String>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes `event` to a JSON value and asserts it equals `expected`.
    fn assert_event_json(event: &Event, expected: serde_json::Value) {
        let actual = serde_json::to_value(event).unwrap();
        assert_eq!(actual, expected);
    }

    /// Checks the adjacently tagged JSON shape produced for several event kinds.
    #[test]
    fn test_serialize_event() {
        let status = Event::Status(StatusEvent {
            status: ExecInfo { queue_remaining: 0 },
        });
        assert_event_json(
            &status,
            json!({
                "type": "status",
                "data": { "status": { "queue_remaining": 0 } }
            }),
        );

        let start = Event::ExecutionStart(ExecutionStartEvent {
            prompt_id: "xxxxxx".to_string(),
        });
        assert_event_json(
            &start,
            json!({
                "type": "execution_start",
                "data": { "prompt_id": "xxxxxx" }
            }),
        );
    }
}