cmfy 0.4.1

A CLI companion app for ComfyUI
Documentation
use crate::{
    dto::{self, PromptBatch},
    error::Result,
    websocket::MessageStream,
};
use reqwest::Url;
use serde::{de::DeserializeOwned, Serialize};

/// Client for a ComfyUI server.
///
/// REST calls are made against `http://{hostname}:{port}/...` and the event
/// stream is opened at `ws://{hostname}:{port}/ws`.
#[derive(Debug, Clone)]
pub struct Client {
    /// Reusable HTTP client shared by every REST request.
    client: reqwest::Client,
    /// Server host name or IP address, interpolated verbatim into request URLs.
    pub hostname: String,
    /// Server TCP port.
    pub port: u32,
    /// Identifier sent as `client_id` with submitted prompts and as the
    /// `clientId` query parameter of the websocket connection.
    pub id: String,
}

impl Client {
    /// Creates a client for the server at `hostname:port`.
    ///
    /// `id` is sent as `client_id` when submitting prompts and as the `clientId`
    /// query parameter when opening the websocket (see [`Client::listen`]).
    pub fn new(hostname: impl AsRef<str>, port: u32, id: String) -> Self {
        Self {
            client: reqwest::Client::new(),
            hostname: hostname.as_ref().to_string(),
            port,
            id,
        }
    }

    /// Returns the server root, `http://{hostname}:{port}`, as a parsed URL.
    ///
    /// # Errors
    /// Fails if `hostname` and `port` do not form a valid URL.
    pub fn base_url(&self) -> Result<Url> {
        let address = format!("http://{}:{}", self.hostname, self.port);
        Ok(Url::parse(&address)?)
    }

    /// Builds the absolute HTTP address of `route`, which is given without a leading `/`.
    fn route_url(&self, route: impl AsRef<str>) -> String {
        format!("http://{}:{}/{}", self.hostname, self.port, route.as_ref())
    }

    /// Returns the `/api/view` URL from which `image` can be downloaded.
    ///
    /// The fields of `image` are URL-encoded into the query string.
    ///
    /// # Panics
    /// Panics if `image` cannot be encoded as a flat query string, or if
    /// `hostname`/`port` do not form a valid URL.
    pub fn url_for_image(&self, image: &dto::Image) -> Url {
        let params = serde_urlencoded::to_string(image)
            .expect("dto::Image should serialize to a flat query string");
        let address = self.route_url(format!("api/view?{}", params));
        Url::parse(&address).expect("hostname and port should form a valid URL")
    }

    /// Sends `GET /{route}` and deserializes the JSON response body into `R`.
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx status codes, or a body that is not
    /// valid JSON for `R`.
    pub async fn get<R: DeserializeOwned>(&self, route: impl AsRef<str>) -> Result<R> {
        let response = self.client.get(self.route_url(route)).send().await?;
        let body = response.error_for_status()?.bytes().await?;
        Ok(serde_json::from_slice(&body)?)
    }

    /// Sends `POST /{route}` with `payload` serialized as JSON.
    ///
    /// Returns `Ok(None)` when the server replies with an empty body. Otherwise
    /// the body is deserialized into `R`.
    ///
    /// # Errors
    /// Fails if `payload` cannot be serialized, on transport errors, on non-2xx
    /// status codes, or if a non-empty body is not valid JSON for `R`.
    pub async fn post<R: DeserializeOwned>(&self, route: impl AsRef<str>, payload: &impl Serialize) -> Result<Option<R>> {
        let body = self.send_post(route, payload).await?.bytes().await?;
        if body.is_empty() {
            return Ok(None);
        }
        Ok(Some(serde_json::from_slice(&body)?))
    }

    /// Sends `POST /{route}` with a JSON `payload` and returns the response
    /// once its status has been checked to be successful.
    async fn send_post(&self, route: impl AsRef<str>, payload: &impl Serialize) -> Result<reqwest::Response> {
        let body = serde_json::to_string(payload)?;
        let response = self
            .client
            .post(self.route_url(route))
            .header(reqwest::header::CONTENT_TYPE, "application/json")
            .body(body)
            .send()
            .await?;
        Ok(response.error_for_status()?)
    }

    /// Sends `POST /{route}` with a JSON `payload`, treating any 2xx reply as
    /// success regardless of what (if anything) the body contains.
    async fn post_discarding_body(&self, route: impl AsRef<str>, payload: &impl Serialize) -> Result<()> {
        // Read the body to completion, then drop it: these endpoints are
        // called for their side effect only.
        self.send_post(route, payload).await?.bytes().await?;
        Ok(())
    }

    /// Fetches `GET /system_stats`.
    pub async fn system_stats(&self) -> Result<dto::SystemStats> {
        self.get("system_stats").await
    }

    /// Fetches `GET /history`.
    pub async fn history(&self) -> Result<dto::History> {
        self.get("history").await
    }

    /// Fetches `GET /queue`.
    pub async fn queue(&self) -> Result<dto::Queue> {
        self.get("queue").await
    }

    /// Posts `{"clear": true}` to `route`.
    async fn clear(&self, route: impl AsRef<str>) -> Result<()> {
        let payload = serde_json::json!({ "clear": true });
        self.post_discarding_body(route, &payload).await
    }

    /// Asks the server to clear its queue (`POST /queue` with `{"clear": true}`).
    pub async fn clear_queue(&self) -> Result<()> {
        self.clear("queue").await
    }

    /// Asks the server to clear its history (`POST /history` with `{"clear": true}`).
    pub async fn clear_history(&self) -> Result<()> {
        self.clear("history").await
    }

    /// Removes one prompt from the server history
    /// (`POST /history` with `{"delete": [prompt_id]}`).
    pub async fn delete_from_history(&self, prompt_id: impl AsRef<str>) -> Result<()> {
        let payload = serde_json::json!({ "delete": [prompt_id.as_ref()] });
        self.post_discarding_body("history", &payload).await
    }

    /// Interrupts the currently running prompt (`POST /interrupt`).
    pub async fn cancel_running_prompt(&self) -> Result<()> {
        // Send `{}` rather than JSON `null`: an `/interrupt` handler that looks up
        // optional fields on the parsed body cannot do so on `null`, while an
        // empty object simply means "no options".
        let payload = serde_json::json!({});
        self.post_discarding_body("interrupt", &payload).await
    }

    /// Submits `nodes` for execution (`POST /prompt`), tagged with this client's id.
    ///
    /// # Errors
    /// Fails if the request fails or the server returns an empty body.
    pub async fn submit(&self, nodes: &dto::PromptNodes) -> Result<dto::SubmitResponse> {
        let payload = serde_json::json!({
            "client_id": &self.id,
            "prompt": nodes
        });
        let response = self.post("prompt", &payload).await?;
        response.ok_or_else(|| "invalid response".into())
    }

    /// Collects batch entries from the history and/or the queue, sorted by index.
    ///
    /// Each source is fetched only when its flag is set.
    pub async fn collect_prompt_batch(&self, history: bool, queue: bool) -> Result<PromptBatch> {
        let mut batch = Vec::new();
        if history {
            batch.extend(self.history().await?.into_batch_entries());
        }
        if queue {
            batch.extend(self.queue().await?.into_batch_entries());
        }
        // `sort_by` is stable, so entries with equal indices keep the
        // history-before-queue order in which they were appended.
        batch.sort_by(|l, r| l.inner.index.cmp(&r.inner.index));
        Ok(batch)
    }

    /// Opens the websocket at `ws://{hostname}:{port}/ws?clientId={id}`.
    pub async fn listen(&self) -> Result<MessageStream> {
        let address = format!("ws://{}:{}/ws?clientId={}", self.hostname, self.port, self.id);
        MessageStream::open(address).await
    }
}