comfyui_client/
lib.rs

1#![warn(rust_2018_idioms, missing_docs)]
2#![warn(clippy::dbg_macro, clippy::print_stdout)]
3#![doc = include_str!("../README.md")]
4
5/// Module containing error definitions.
6mod errors;
7/// Module containing metadata such as prompt and file information.
8pub mod meta;
9
10pub use crate::errors::{ClientError, ClientResult};
11use crate::meta::{FileInfo, PromptInfo};
12use bytes::Bytes;
13use cfg_if::cfg_if;
14use errors::{ApiBody, ApiError};
15use futures_util::StreamExt;
16use meta::{Event, History, Prompt};
17use reqwest::{
18    Body, IntoUrl, Response,
19    multipart::{self},
20};
21use serde::Serialize;
22use serde_json::{Value, json};
23use std::{
24    collections::HashMap,
25    ops::{Deref, DerefMut},
26};
27use tokio::{sync::mpsc, task::JoinHandle};
28use tokio_stream::wrappers::ReceiverStream;
29use tokio_tungstenite::{
30    connect_async,
31    tungstenite::{self, Message},
32};
33use url::Url;
34use uuid::Uuid;
35
36/// A builder for creating a [`ComfyUIClient`] instance.
37///
38/// This builder helps initialize the client with the provided base URL and sets
39/// up a websocket connection to stream events.
40pub struct ClientBuilder {
41    base_url: Url,
42    channel_bound: usize,
43}
44
45impl ClientBuilder {
46    /// Creates a new [`ClientBuilder`] instance.
47    ///
48    /// # Parameters
49    ///
50    /// - `base_url`: The base URL of the ComfyUI service.
51    ///
52    /// # Returns
53    ///
54    /// A new instance of [`ClientBuilder`] wrapped in a `ClientResult`, or an
55    /// error if the URL is invalid.
56    pub fn new(base_url: impl IntoUrl) -> ClientResult<Self> {
57        Ok(Self {
58            base_url: base_url.into_url()?,
59            channel_bound: 100,
60        })
61    }
62
63    /// Builds the [`ComfyUIClient`] along with an associated [`EventStream`].
64    ///
65    /// This method establishes a websocket connection and spawns an
66    /// asynchronous task to process incoming messages.
67    ///
68    /// # Returns
69    ///
70    /// A tuple containing the [`ComfyUIClient`] and [`EventStream`] on success,
71    /// or an error.
72    pub async fn build(self) -> ClientResult<(ComfyUIClient, EventStream)> {
73        let base_url = self.base_url;
74        let http_client = reqwest::Client::new();
75        let client_id = Uuid::new_v4().to_string();
76
77        let (msg_tx, msg_rx) = mpsc::channel(self.channel_bound);
78
79        let ws_url = Self::generate_websocket_url(base_url.clone(), &client_id)?;
80        let (stream, _) = if ws_url.scheme() == "wss" {
81            cfg_if! {
82                if #[cfg(feature = "rustls")] {
83                    let root_store = rustls::RootCertStore {
84                        roots: webpki_roots::TLS_SERVER_ROOTS.into(),
85                    };
86                    let config = rustls::ClientConfig::builder()
87                        .with_root_certificates(root_store)
88                        .with_no_client_auth();
89
90                    tokio_tungstenite::connect_async_tls_with_config(
91                        ws_url,
92                        None,
93                        false,
94                        Some(tokio_tungstenite::Connector::Rustls(std::sync::Arc::new(config))),
95                    )
96                    .await?
97                } else {
98                    connect_async(ws_url).await?
99                }
100            }
101        } else {
102            connect_async(ws_url).await?
103        };
104
105        let stream_handle = tokio::spawn(async move {
106            let (_, mut read_stream) = stream.split();
107            while let Some(msg) = read_stream.next().await {
108                let msg = EventStream::handle_message(msg);
109                let Some(msg) = msg.transpose() else {
110                    continue;
111                };
112                if msg_tx.send(msg).await.is_err() {
113                    break;
114                }
115            }
116        });
117
118        let rx_stream = ReceiverStream::new(msg_rx);
119
120        let client = ComfyUIClient {
121            base_url,
122            http_client,
123            client_id,
124        };
125
126        let stream = EventStream {
127            stream_handle,
128            rx_stream,
129        };
130
131        Ok((client, stream))
132    }
133
134    /// Generates the websocket URL based on the base URL and client ID.
135    ///
136    /// This method changes the URL scheme to `wss` if the base URL uses HTTPS,
137    /// or `ws` otherwise, appends the `ws` path, and adds a query parameter
138    /// for `clientId`.
139    ///
140    /// # Parameters
141    ///
142    /// - `base_url`: The base URL of the ComfyUI service.
143    /// - `client_id`: The unique identifier for the client.
144    ///
145    /// # Returns
146    ///
147    /// The generated websocket URL on success, or an error if the URL cannot be
148    /// modified.
149    fn generate_websocket_url(base_url: Url, client_id: &str) -> ClientResult<Url> {
150        let mut ws_url = base_url;
151        let scheme = if ws_url.scheme() == "https" {
152            "wss"
153        } else {
154            "ws"
155        };
156        ws_url
157            .set_scheme(scheme)
158            .map_err(|_| ClientError::SetWsScheme)?;
159        ws_url = ws_url.join("ws")?;
160        ws_url.query_pairs_mut().append_pair("clientId", client_id);
161        Ok(ws_url)
162    }
163}
164
165/// A client for interacting with the ComfyUI service.
166///
167/// This client provides methods to fetch history, prompts, views, and to upload
168/// images.
169pub struct ComfyUIClient {
170    client_id: String,
171    base_url: Url,
172    http_client: reqwest::Client,
173}
174
175impl ComfyUIClient {
176    /// Retrieves the history for a specified prompt.
177    ///
178    /// Sends a GET request to the `history/{prompt_id}` endpoint and parses the
179    /// returned history data.
180    ///
181    /// # Parameters
182    ///
183    /// - `prompt_id`: The ID of the prompt whose history is being requested.
184    ///
185    /// # Returns
186    ///
187    /// An optional [`History`] object wrapped in a `ClientResult`. Returns
188    /// `None` if the history is not found.
189    pub async fn get_history(&self, prompt_id: &str) -> ClientResult<Option<History>> {
190        let resp = self
191            .http_client
192            .get(self.base_url.join(&format!("history/{prompt_id}"))?)
193            .send()
194            .await?;
195        let resp = Self::error_for_status(resp).await?;
196        let mut histories = resp.json::<HashMap<String, History>>().await?;
197        Ok(histories.remove(prompt_id))
198    }
199
200    /// Retrieves the current prompt information.
201    ///
202    /// Sends a GET request to the `prompt` endpoint and returns the parsed
203    /// [`PromptInfo`] data.
204    ///
205    /// # Returns
206    ///
207    /// A [`PromptInfo`] object on success, or an error.
208    pub async fn get_prompt(&self) -> ClientResult<PromptInfo> {
209        let resp = self
210            .http_client
211            .get(self.base_url.join("prompt")?)
212            .send()
213            .await?;
214        let resp = Self::error_for_status(resp).await?;
215        Ok(resp.json().await?)
216    }
217
218    /// Retrieves view data corresponding to the provided file information.
219    ///
220    /// Sends a GET request to the `view` endpoint, including the file
221    /// information as query parameters.
222    ///
223    /// # Parameters
224    ///
225    /// - `file_info`: A [`FileInfo`] object containing details about the file.
226    ///
227    /// # Returns
228    ///
229    /// The response as a [`Bytes`] object on success, or an error.
230    pub async fn get_view(&self, file_info: &FileInfo) -> ClientResult<Bytes> {
231        let resp = self
232            .http_client
233            .get(self.base_url.join("view")?)
234            .query(file_info)
235            .send()
236            .await?;
237        let resp = Self::error_for_status(resp).await?;
238        Ok(resp.bytes().await?)
239    }
240
241    /// Sends a prompt in string format.
242    ///
243    /// Parses the input string as JSON and calls [`Self::post_prompt_value`] to
244    /// send the prompt.
245    ///
246    /// # Parameters
247    ///
248    /// - `prompt`: A string slice representing the prompt in JSON format.
249    ///
250    /// # Returns
251    ///
252    /// A [`Prompt`] object on success, or an error.
253    pub async fn post_prompt_str(&self, prompt: &str) -> ClientResult<Prompt> {
254        let prompt = serde_json::from_str::<Value>(prompt)?;
255        self.post_prompt_value(&prompt).await
256    }
257
258    /// Sends a prompt from any serializable data.
259    ///
260    /// Converts the provided data into JSON and calls
261    /// [`Self::post_prompt_value`] to send the prompt.
262    ///
263    /// # Parameters
264    ///
265    /// - `prompt`: A reference to any data that implements [`Serialize`].
266    ///
267    /// # Returns
268    ///
269    /// A [`Prompt`] object on success, or an error.
270    pub async fn post_prompt<T: Serialize>(&self, prompt: &T) -> ClientResult<Prompt> {
271        let prompt = serde_json::to_value(prompt)?;
272        self.post_prompt_value(&prompt).await
273    }
274
275    /// Sends a prompt in JSON format.
276    ///
277    /// Constructs the request payload (including the client ID and prompt data)
278    /// and sends a POST request to the `prompt` endpoint.
279    ///
280    /// # Parameters
281    ///
282    /// - `prompt`: A JSON value representing the prompt data.
283    ///
284    /// # Returns
285    ///
286    /// A [`Prompt`] object on success, or an error.
287    pub async fn post_prompt_value(&self, prompt: &Value) -> ClientResult<Prompt> {
288        let data = json!({"client_id": &self.client_id, "prompt": prompt});
289        let resp = self
290            .http_client
291            .post(self.base_url.join("prompt")?)
292            .json(&data)
293            .send()
294            .await?;
295        let resp = Self::error_for_status(resp).await?;
296        Ok(resp.json().await?)
297    }
298
299    /// Uploads an image.
300    ///
301    /// Constructs a multipart form containing the image data and file
302    /// information, then sends a POST request to the `upload/image` endpoint.
303    ///
304    /// # Parameters
305    ///
306    /// - `body`: The image data, convertible into a [`Body`].
307    /// - `info`: A [`FileInfo`] object containing details about the image file.
308    /// - `overwrite`: A boolean indicating whether to overwrite an existing
309    ///   file.
310    ///
311    /// # Returns
312    ///
313    /// An updated [`FileInfo`] object on success, or an error.
314    pub async fn upload_image(
315        &self, body: impl Into<Body>, info: &FileInfo, overwrite: bool,
316    ) -> ClientResult<FileInfo> {
317        let part = multipart::Part::stream(body).file_name(info.filename.to_string());
318        let mut form = multipart::Form::new()
319            .part("image", part)
320            .text("overwrite", overwrite.to_string())
321            .text("type", info.r#type.to_string());
322        if !info.subfolder.is_empty() {
323            form = form.text("subfolder", info.subfolder.to_string());
324        }
325
326        let resp = self
327            .http_client
328            .post(self.base_url.join("upload/image")?)
329            .multipart(form)
330            .send()
331            .await?;
332
333        let resp = Self::error_for_status(resp).await?;
334        Ok(resp.json().await?)
335    }
336
337    /// Checks the HTTP response status code and returns an error if it
338    /// indicates failure.
339    ///
340    /// If the response status is a client or server error, this method attempts
341    /// to parse the response body as JSON. If parsing fails, it returns the
342    /// body as text.
343    ///
344    /// # Parameters
345    ///
346    /// - `resp`: The HTTP response to evaluate.
347    ///
348    /// # Returns
349    ///
350    /// The original response if the status is successful, or an error if the
351    /// status indicates a failure.
352    async fn error_for_status(resp: Response) -> ClientResult<Response> {
353        let status = resp.status();
354        if status.is_client_error() || status.is_server_error() {
355            let body = resp.text().await?;
356            let body = match serde_json::from_str::<Value>(&body) {
357                Ok(value) => ApiBody::Json(value),
358                Err(_) => ApiBody::Text(body),
359            };
360            Err(ApiError { status, body }.into())
361        } else {
362            Ok(resp)
363        }
364    }
365}
366
367/// A structure representing the event stream received via a websocket
368/// connection.
369///
370/// This stream continuously processes events from the ComfyUI service.
371pub struct EventStream {
372    stream_handle: JoinHandle<()>,
373    rx_stream: ReceiverStream<ClientResult<Event>>,
374}
375
376impl EventStream {
377    /// Handles a single websocket message and attempts to parse it as an
378    /// [`Event`].
379    ///
380    /// For text messages, it tries to deserialize the message into an
381    /// [`Event`]. If the deserialization fails, it wraps the message as
382    /// [`Event::Unknown`]. Other message types are ignored.
383    ///
384    /// # Parameters
385    ///
386    /// - `msg`: A result containing a [`Message`] from the websocket.
387    ///
388    /// # Returns
389    ///
390    /// An `Option<Event>` wrapped in a `ClientResult`. Returns `None` for
391    /// unsupported message types.
392    fn handle_message(msg: tungstenite::Result<Message>) -> ClientResult<Option<Event>> {
393        let msg = msg?;
394        match msg {
395            Message::Text(b) => {
396                let value = serde_json::from_slice::<Value>(b.as_bytes())?;
397                match serde_json::from_slice::<Event>(b.as_bytes()) {
398                    Ok(ev) => Ok(Some(ev)),
399                    Err(_) => Ok(Some(Event::Unknown(value))),
400                }
401            }
402            _ => Ok(None),
403        }
404    }
405}
406
407impl Drop for EventStream {
408    /// When the [`EventStream`] is dropped, abort the associated websocket
409    /// handling task.
410    fn drop(&mut self) {
411        self.stream_handle.abort();
412    }
413}
414
415impl Deref for EventStream {
416    type Target = ReceiverStream<ClientResult<Event>>;
417
418    /// Allows access to the inner [`ReceiverStream`] containing the events.
419    fn deref(&self) -> &Self::Target {
420        &self.rx_stream
421    }
422}
423
424impl DerefMut for EventStream {
425    /// Allows mutable access to the inner [`ReceiverStream`].
426    fn deref_mut(&mut self) -> &mut Self::Target {
427        &mut self.rx_stream
428    }
429}