#![warn(rust_2018_idioms, missing_docs)]
#![warn(clippy::dbg_macro, clippy::print_stdout)]
#![doc = include_str!("../README.md")]
pub mod errors;
pub mod meta;
pub use crate::errors::{ClientError, ClientResult};
use crate::meta::{FileInfo, PromptInfo};
use bytes::Bytes;
use errors::{ApiBody, ApiError};
use futures_util::StreamExt;
use log::trace;
use meta::{Event, History, OtherEvent, Prompt, PromptStatus};
use reqwest::{
Body, IntoUrl, Response,
multipart::{self},
};
use serde_json::{Value, json};
use std::{
collections::HashMap,
ops::{Deref, DerefMut},
};
use tokio::{
sync::mpsc,
task::JoinHandle,
time::{Duration, sleep},
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use url::Url;
use uuid::Uuid;
/// Builder for a [`ComfyUIClient`] and its companion [`EventStream`].
///
/// Create one with [`ClientBuilder::new`], adjust it with the setter methods, then
/// finish with [`ClientBuilder::build`] (HTTP client plus websocket event stream)
/// or [`ClientBuilder::build_only_http`] (HTTP client only).
#[derive(Debug, Clone)]
pub struct ClientBuilder {
    /// Root URL of the ComfyUI server; endpoint paths are joined onto it.
    base_url: Url,
    /// Capacity of the channel that buffers websocket events for the consumer.
    channel_bound: usize,
    /// Whether the background task reconnects after a websocket read error.
    reconnect_web_socket: bool,
}
impl ClientBuilder {
    /// Creates a builder targeting the ComfyUI server at `base_url`.
    ///
    /// Defaults: an event channel capacity of 100 and websocket reconnection
    /// enabled.
    ///
    /// Endpoints are resolved with [`Url::join`], so a base URL that carries a
    /// path should end with `/`; otherwise its last path segment is replaced.
    ///
    /// # Errors
    ///
    /// Returns an error if `base_url` cannot be converted into a [`Url`].
    pub fn new(base_url: impl IntoUrl) -> ClientResult<Self> {
        Ok(Self {
            base_url: base_url.into_url()?,
            channel_bound: 100,
            reconnect_web_socket: true,
        })
    }

    /// Sets how many events may be buffered between the background websocket
    /// task and the [`EventStream`] before the task waits for the consumer.
    ///
    /// A value of `0` is treated as `1` when the channel is created, because a
    /// zero-capacity `tokio::sync::mpsc` channel is not allowed.
    pub fn channel_bound(mut self, channel_bound: usize) -> Self {
        self.channel_bound = channel_bound;
        self
    }

    /// Controls what the background task does when reading from the websocket
    /// fails.
    ///
    /// When `true` (the default), the error is delivered as
    /// `Event::Other(OtherEvent::WSReceiveError(_))`. The task then retries the
    /// connection once per second, emitting `OtherEvent::WSReconnectError` for
    /// each failed attempt and `OtherEvent::WSReconnectSuccess` on success.
    ///
    /// When `false`, the error is delivered as an `Err` item and the task ends.
    ///
    /// In both modes, a stream that ends without a read error is not reconnected.
    pub fn reconnect_web_socket(mut self, reconnect: bool) -> Self {
        self.reconnect_web_socket = reconnect;
        self
    }

    /// Connects the websocket and returns the HTTP client together with the
    /// stream of server events.
    ///
    /// The websocket URL is the base URL joined with `ws`, with the scheme
    /// switched to `wss` (for `https`) or `ws`. A `clientId` query parameter
    /// carries a freshly generated UUID; the returned [`ComfyUIClient`] uses the
    /// same id when queueing prompts.
    ///
    /// A background task forwards decoded events into the returned
    /// [`EventStream`]. It stops in any of these cases:
    /// - the receiving side is closed or dropped;
    /// - the websocket stream ends without a read error;
    /// - a read error occurs while reconnection is disabled.
    ///
    /// # Errors
    ///
    /// Returns an error if the websocket URL cannot be built or the initial
    /// connection fails.
    pub async fn build(self) -> ClientResult<(ComfyUIClient, EventStream)> {
        let base_url = self.base_url;
        let http_client = reqwest::Client::new();
        let client_id = Uuid::new_v4().to_string();
        let reconnect_web_socket = self.reconnect_web_socket;
        // `mpsc::channel` panics on a zero capacity, so never hand it one.
        let (ev_tx, ev_rx) = mpsc::channel(self.channel_bound.max(1));
        let ws_url = Self::generate_websocket_url(base_url.clone(), &client_id)?;
        let (ws_stream, _) = connect_async(&ws_url).await?;
        let stream_handle = tokio::spawn(async move {
            // Only the read half is used; the write half is dropped right away.
            let (_, mut read_stream) = ws_stream.split();
            // Each outer iteration drains one websocket connection.
            loop {
                // Cleared on a read error or when the receiver is gone; still
                // `true` if the stream simply ended (`next()` returned `None`).
                let mut connection_alive = true;
                while let Some(msg) = read_stream.next().await {
                    match msg {
                        Ok(message) => {
                            let ev = EventStream::handle_message(message);
                            // Frames that decode to `Ok(None)` are skipped.
                            let Some(ev) = ev.transpose() else {
                                continue;
                            };
                            if ev_tx.send(ev).await.is_err() {
                                connection_alive = false;
                                break;
                            }
                        }
                        Err(err) => {
                            connection_alive = false;
                            if reconnect_web_socket {
                                // Report the failure as an event; reconnect below.
                                let _ = ev_tx
                                    .send(Ok(Event::Other(OtherEvent::WSReceiveError(err))))
                                    .await;
                            } else {
                                // Surface the failure as a stream error; the task ends below.
                                let _ = ev_tx.send(Err(ClientError::from(err))).await;
                            }
                            break;
                        }
                    }
                }
                if !reconnect_web_socket || ev_tx.is_closed() {
                    break;
                }
                // The stream ended without a read error: do not reconnect.
                if connection_alive {
                    break;
                }
                // Retry once per second until a connection succeeds or the
                // receiver has been dropped.
                loop {
                    sleep(Duration::from_secs(1)).await;
                    if ev_tx.is_closed() {
                        break;
                    }
                    match connect_async(&ws_url).await.map(|x| x.0) {
                        Ok(new_stream) => {
                            (_, read_stream) = new_stream.split();
                            let _ = ev_tx
                                .send(Ok(Event::Other(OtherEvent::WSReconnectSuccess)))
                                .await;
                            break;
                        }
                        Err(err) => {
                            let err = ClientError::Tungstenite(err);
                            if ev_tx
                                .send(Ok(Event::Other(OtherEvent::WSReconnectError(err))))
                                .await
                                .is_err()
                            {
                                break;
                            }
                        }
                    }
                }
                if ev_tx.is_closed() {
                    break;
                }
            }
        });
        let rx_stream = ReceiverStream::new(ev_rx);
        let client = ComfyUIClient {
            base_url,
            http_client,
            client_id,
        };
        let stream = EventStream {
            stream_handle,
            rx_stream,
        };
        Ok((client, stream))
    }

    /// Builds only the HTTP client, without opening a websocket.
    ///
    /// The client still gets a freshly generated UUID as its client id, but no
    /// event stream listens under that id.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub async fn build_only_http(self) -> ClientResult<ComfyUIClient> {
        let base_url = self.base_url;
        let http_client = reqwest::Client::new();
        let client_id = Uuid::new_v4().to_string();
        Ok(ComfyUIClient {
            base_url,
            http_client,
            client_id,
        })
    }

    /// Derives the websocket endpoint from the HTTP base URL.
    ///
    /// The scheme becomes `wss` for `https` and `ws` otherwise. The path is
    /// joined with `ws`, and `clientId=<client_id>` is appended to the query.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::SetWsScheme` if the scheme cannot be changed, or a
    /// URL parse error if joining `ws` fails.
    fn generate_websocket_url(base_url: Url, client_id: &str) -> ClientResult<Url> {
        let mut ws_url = base_url;
        let scheme = if ws_url.scheme() == "https" {
            "wss"
        } else {
            "ws"
        };
        ws_url
            .set_scheme(scheme)
            .map_err(|_| ClientError::SetWsScheme)?;
        ws_url = ws_url.join("ws")?;
        ws_url.query_pairs_mut().append_pair("clientId", client_id);
        Ok(ws_url)
    }
}
/// HTTP client for a ComfyUI server.
///
/// Obtained from [`ClientBuilder::build`] or [`ClientBuilder::build_only_http`].
/// Clones share the underlying `reqwest::Client` (and therefore its connection
/// pool) and carry the same client id.
#[derive(Debug, Clone)]
pub struct ComfyUIClient {
    /// Id sent as `client_id` when queueing prompts. For clients created by
    /// `ClientBuilder::build`, it is also the websocket `clientId`.
    client_id: String,
    /// Root URL that endpoint paths are joined onto.
    base_url: Url,
    /// HTTP client used for every request.
    http_client: reqwest::Client,
}
impl ComfyUIClient {
pub async fn get_history(&self, prompt_id: &str) -> ClientResult<Option<History>> {
let resp = self
.http_client
.get(self.base_url.join(&format!("history/{prompt_id}"))?)
.send()
.await?;
let resp = Self::error_for_status(resp).await?;
let mut histories = resp.json::<HashMap<String, History>>().await?;
Ok(histories.remove(prompt_id))
}
pub async fn get_prompt(&self) -> ClientResult<PromptInfo> {
let resp = self
.http_client
.get(self.base_url.join("prompt")?)
.send()
.await?;
let resp = Self::error_for_status(resp).await?;
Ok(resp.json().await?)
}
pub async fn get_view(&self, file_info: &FileInfo) -> ClientResult<Bytes> {
let resp = self
.http_client
.get(self.base_url.join("view")?)
.query(file_info)
.send()
.await?;
let resp = Self::error_for_status(resp).await?;
Ok(resp.bytes().await?)
}
pub async fn post_prompt(&self, prompt: impl Into<Prompt<'_>>) -> ClientResult<PromptStatus> {
let prompt = match prompt.into() {
Prompt::Str(prompt) => &serde_json::from_str::<Value>(prompt)?,
Prompt::Value(prompt) => prompt,
};
let data = json!({"client_id": &self.client_id, "prompt": prompt});
let resp = self
.http_client
.post(self.base_url.join("prompt")?)
.json(&data)
.send()
.await?;
let resp = Self::error_for_status(resp).await?;
Ok(resp.json().await?)
}
pub async fn upload_image(
&self, body: impl Into<Body>, info: &FileInfo, overwrite: bool,
) -> ClientResult<FileInfo> {
let part = multipart::Part::stream(body).file_name(info.filename.to_string());
let mut form = multipart::Form::new()
.part("image", part)
.text("overwrite", overwrite.to_string())
.text("type", info.r#type.to_string());
if !info.subfolder.is_empty() {
form = form.text("subfolder", info.subfolder.to_string());
}
let resp = self
.http_client
.post(self.base_url.join("upload/image")?)
.multipart(form)
.send()
.await?;
let resp = Self::error_for_status(resp).await?;
Ok(resp.json().await?)
}
async fn error_for_status(resp: Response) -> ClientResult<Response> {
let status = resp.status();
if status.is_client_error() || status.is_server_error() {
let body = resp.text().await?;
let body = match serde_json::from_str::<Value>(&body) {
Ok(value) => ApiBody::Json(value),
Err(_) => ApiBody::Text(body),
};
Err(ApiError { status, body }.into())
} else {
Ok(resp)
}
}
}
/// Stream of [`Event`]s received over the ComfyUI websocket.
///
/// Produced by [`ClientBuilder::build`]. Items are read through the inner
/// [`ReceiverStream`], which is exposed via `Deref`/`DerefMut`. Each item is a
/// `ClientResult<Event>`.
pub struct EventStream {
    /// Background task that reads the websocket and forwards events.
    stream_handle: JoinHandle<()>,
    /// Receiving side of the event channel fed by `stream_handle`.
    rx_stream: ReceiverStream<ClientResult<Event>>,
}

impl std::fmt::Debug for EventStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Hand-written so the impl does not depend on the queued item type
        // implementing `Debug`; the receiver itself is omitted.
        f.debug_struct("EventStream")
            .field("stream_handle", &self.stream_handle)
            .finish_non_exhaustive()
    }
}
impl EventStream {
    /// Decodes one websocket frame into an [`Event`].
    ///
    /// Text frames are parsed as JSON. JSON that deserializes into a known
    /// [`Event`] is returned as that event; any other JSON becomes
    /// `Event::Unknown`. Every non-text frame yields `Ok(None)`, which tells the
    /// caller to skip it.
    ///
    /// # Errors
    ///
    /// Returns an error when a text frame is not valid JSON.
    fn handle_message(msg: Message) -> ClientResult<Option<Event>> {
        let Message::Text(text) = msg else {
            return Ok(None);
        };
        trace!(message:% = text.as_str(); "received websocket message");
        let value: Value = serde_json::from_slice(text.as_bytes())?;
        // Keep the raw JSON around so unrecognised payloads are not lost.
        let event = serde_json::from_value::<Event>(value.clone())
            .unwrap_or_else(|_| Event::Unknown(value));
        Ok(Some(event))
    }
}
/// Dropping the stream aborts the background websocket task, so the task does
/// not outlive its consumer.
impl Drop for EventStream {
    fn drop(&mut self) {
        let background = &self.stream_handle;
        background.abort();
    }
}
/// Exposes the inner [`ReceiverStream`] read-only, so its methods can be called
/// directly on an `EventStream`.
impl Deref for EventStream {
    type Target = ReceiverStream<ClientResult<Event>>;

    fn deref(&self) -> &Self::Target {
        let Self { rx_stream, .. } = self;
        rx_stream
    }
}

/// Exposes the inner [`ReceiverStream`] mutably, so `StreamExt` methods such as
/// `next().await` can be used on an `EventStream`.
impl DerefMut for EventStream {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { rx_stream, .. } = self;
        rx_stream
    }
}