// cmfy 0.3.0
//
// A CLI companion app for ComfyUI.
// Documentation
use crate::Result;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::collections::BTreeMap;

/// A prompt record: its node graph plus associated metadata.
///
/// NOTE(review): the per-field meanings below are inferred from field names only —
/// confirm against the payload this type is (de)serialized from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Prompt {
    /// Presumably the prompt's number/position in a queue — TODO confirm.
    pub index: u64,
    /// Identifier of the prompt; stored as a plain string (format not validated here).
    pub uuid: String,
    /// The prompt's node graph, keyed by node id.
    pub nodes: PromptNodes,
    /// Untyped JSON; presumably metadata destined for (or read from) PNG output — TODO confirm.
    pub png_info: serde_json::Value,
    /// Node ids, presumably of the nodes whose results are collected as outputs — TODO confirm.
    pub output_nodes: Vec<String>,
}

/// The nodes of a prompt graph, keyed by node id.
///
/// `#[serde(transparent)]` makes this (de)serialize exactly like the inner map, with
/// no wrapper. Each value keeps its inputs as untyped `serde_json::Value`. Typed
/// access goes through the [`ClassType`]-based helpers on this type. Because the
/// storage is a `BTreeMap<String, _>`, iteration is in ascending lexicographic id
/// order (so `"10"` comes before `"9"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct PromptNodes(pub BTreeMap<String, Node<serde_json::Value>>);

/// A single graph node: its class name plus its inputs.
///
/// `I` is the input representation. Inside [`PromptNodes`] it is the untyped
/// `serde_json::Value`. A typed struct implementing [`ClassType`] can be used to
/// view the same inputs with concrete fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node<I> {
    /// Class name of the node; matched against `ClassType::CLASS_TYPE` for typed lookups.
    pub class_type: String,
    /// The node's inputs in representation `I`.
    pub inputs: I,
}

/// Reply deserialized after submitting a prompt (deserialize-only; never sent).
///
/// NOTE(review): field meanings are inferred from names — confirm against the
/// server's actual submit response.
#[derive(Deserialize, Debug, Clone)]
pub struct SubmitResponse {
    /// Presumably the queue number assigned to the submitted prompt — TODO confirm.
    pub number: u32,
    /// Identifier assigned to the submitted prompt.
    pub prompt_id: String,
    /// Untyped JSON describing per-node validation errors, if any (shape not modelled here).
    pub node_errors: serde_json::Value,
}

/// A typed node-input struct bound to a specific `class_type` string.
///
/// `PromptNodes` writes `CLASS_TYPE` into a node's `class_type` when a value is
/// stored. It compares against `CLASS_TYPE` when selecting nodes by type. The
/// `Serialize` supertrait lets a value be converted into the untyped JSON inputs
/// kept in the graph.
pub trait ClassType: Serialize {
    /// The `class_type` value identifying nodes of this type.
    const CLASS_TYPE: &str;
}

impl PromptNodes {
    /// Serializes `node` and stores it under `id`, tagged with `N::CLASS_TYPE`.
    ///
    /// Any node previously stored under `id` is replaced.
    ///
    /// # Errors
    ///
    /// Returns an error if `node` cannot be serialized to JSON.
    pub fn put<N: ClassType>(&mut self, id: String, node: N) -> Result<()> {
        // `ClassType: Serialize`, so no separate `Serialize` bound is needed.
        let class_type = N::CLASS_TYPE.to_owned();
        let inputs = serde_json::to_value(&node)?;
        self.0.insert(id, Node { class_type, inputs });
        Ok(())
    }

    /// Removes the node stored under `id` and returns its inputs deserialized as `N`.
    ///
    /// The stored `class_type` is not compared against `N::CLASS_TYPE`; only the
    /// inputs have to deserialize as `N`.
    ///
    /// # Errors
    ///
    /// Returns an error if no node is stored under `id`, or if its inputs cannot be
    /// deserialized as `N`. In both cases the graph is left unchanged.
    pub fn take<N: ClassType + DeserializeOwned>(&mut self, id: String) -> Result<N> {
        let node = self
            .0
            .get(&id)
            .ok_or_else(|| format!("node id '{}' not found", id))?;
        // Deserialize from a borrow *before* removing, so a parse failure does not
        // silently drop the node from the graph.
        let parsed = N::deserialize(&node.inputs)?;
        self.0.remove(&id);
        Ok(parsed)
    }

    /// Returns every node whose `class_type` equals `N::CLASS_TYPE`, deserialized as `N`
    /// and keyed by node id.
    ///
    /// Returns an empty map when no node has that class.
    ///
    /// # Errors
    ///
    /// Returns an error if any matching node's inputs cannot be deserialized as `N`.
    pub fn all_by_class<N>(&self) -> Result<BTreeMap<String, N>>
    where
        N: DeserializeOwned + ClassType,
    {
        // Deserialize the stored JSON value directly instead of round-tripping each
        // node through a serialized string.
        let nodes = self
            .0
            .iter()
            .filter(|(_, node)| node.class_type == N::CLASS_TYPE)
            .map(|(id, node)| N::deserialize(&node.inputs).map(|inputs| (id.clone(), inputs)))
            .collect::<serde_json::Result<BTreeMap<_, _>>>()?;
        Ok(nodes)
    }

    /// Returns the id and deserialized inputs of the first node whose `class_type`
    /// equals `N::CLASS_TYPE`.
    ///
    /// "First" follows `BTreeMap` key order, i.e. the lexicographically smallest id
    /// (`"10"` sorts before `"9"`). Only that node is deserialized; other nodes of the
    /// same class are not inspected.
    ///
    /// # Errors
    ///
    /// Returns an error if no node has class `N::CLASS_TYPE`, or if the first such
    /// node's inputs cannot be deserialized as `N`.
    pub fn first_by_class<N>(&self) -> Result<(String, N)>
    where
        N: DeserializeOwned + ClassType,
    {
        let (id, node) = self
            .0
            .iter()
            .find(|(_, node)| node.class_type == N::CLASS_TYPE)
            .ok_or_else(|| format!("node with class '{}' not found", N::CLASS_TYPE))?;
        let inputs = N::deserialize(&node.inputs)?;
        Ok((id.clone(), inputs))
    }

    /// Loads the first node of class `N` (see [`Self::first_by_class`]), applies
    /// `change` to it, and writes the result back under the same id.
    ///
    /// # Errors
    ///
    /// Returns an error if no node of class `N` exists, if it cannot be deserialized
    /// as `N`, or if the modified value cannot be serialized back to JSON. The graph is
    /// only modified when every step succeeds.
    pub fn change_first_by_class<N, C>(&mut self, change: C) -> Result<()>
    where
        N: DeserializeOwned + Serialize + ClassType,
        C: FnOnce(&mut N),
    {
        let (id, mut node) = self.first_by_class::<N>()?;
        change(&mut node);
        self.put(id, node)
    }
}