comfyui_client/
lib.rs

1#![warn(rust_2018_idioms, missing_docs)]
2#![warn(clippy::dbg_macro, clippy::print_stdout)]
3#![doc = include_str!("../README.md")]
4
5/// Module containing error definitions.
6pub mod errors;
7/// Module containing metadata such as prompt and file information.
8pub mod meta;
9
10pub use crate::errors::{ClientError, ClientResult};
11use crate::meta::{FileInfo, PromptInfo};
12use bytes::Bytes;
13use cfg_if::cfg_if;
14use errors::{ApiBody, ApiError};
15use futures_util::StreamExt;
16use meta::{Event, History, Prompt};
17use reqwest::{
18    Body, IntoUrl, Response,
19    multipart::{self},
20};
21use serde::Serialize;
22use serde_json::{Value, json};
23use std::{
24    collections::HashMap,
25    ops::{Deref, DerefMut},
26};
27use tokio::{sync::mpsc, task::JoinHandle};
28use tokio_stream::wrappers::ReceiverStream;
29use tokio_tungstenite::{
30    connect_async,
31    tungstenite::{self, Message},
32};
33use url::Url;
34use uuid::Uuid;
35
36/// A builder for creating a [`ComfyUIClient`] instance.
37///
38/// This builder helps initialize the client with the provided base URL and sets
39/// up a websocket connection to stream events.
40pub struct ClientBuilder {
41    base_url: Url,
42    channel_bound: usize,
43}
44
45impl ClientBuilder {
46    /// Creates a new [`ClientBuilder`] instance.
47    ///
48    /// # Parameters
49    ///
50    /// - `base_url`: The base URL of the ComfyUI service.
51    ///
52    /// # Returns
53    ///
54    /// A new instance of [`ClientBuilder`] wrapped in a `ClientResult`, or an
55    /// error if the URL is invalid.
56    pub fn new(base_url: impl IntoUrl) -> ClientResult<Self> {
57        Ok(Self {
58            base_url: base_url.into_url()?,
59            channel_bound: 100,
60        })
61    }
62
63    /// Builds the [`ComfyUIClient`] along with an associated [`EventStream`].
64    ///
65    /// This method establishes a websocket connection and spawns an
66    /// asynchronous task to process incoming messages.
67    ///
68    /// # Returns
69    ///
70    /// A tuple containing the [`ComfyUIClient`] and [`EventStream`] on success,
71    /// or an error.
72    pub async fn build(self) -> ClientResult<(ComfyUIClient, EventStream)> {
73        let base_url = self.base_url;
74        let http_client = reqwest::Client::new();
75        let client_id = Uuid::new_v4().to_string();
76
77        let (ev_tx, ev_rx) = mpsc::channel(self.channel_bound);
78
79        let ws_url = Self::generate_websocket_url(base_url.clone(), &client_id)?;
80        let (stream, _) = if ws_url.scheme() == "wss" {
81            cfg_if! {
82                if #[cfg(feature = "rustls")] {
83                    let root_store = rustls::RootCertStore {
84                        roots: webpki_roots::TLS_SERVER_ROOTS.into(),
85                    };
86                    let config = rustls::ClientConfig::builder()
87                        .with_root_certificates(root_store)
88                        .with_no_client_auth();
89
90                    tokio_tungstenite::connect_async_tls_with_config(
91                        ws_url,
92                        None,
93                        false,
94                        Some(tokio_tungstenite::Connector::Rustls(std::sync::Arc::new(config))),
95                    )
96                    .await?
97                } else {
98                    connect_async(ws_url).await?
99                }
100            }
101        } else {
102            connect_async(ws_url).await?
103        };
104
105        let stream_handle = tokio::spawn(async move {
106            let (_, mut read_stream) = stream.split();
107            while let Some(msg) = read_stream.next().await {
108                let ev = EventStream::handle_message(msg);
109                let Some(ev) = ev.transpose() else {
110                    continue;
111                };
112                if ev_tx.send(ev).await.is_err() {
113                    break;
114                }
115            }
116        });
117
118        let rx_stream = ReceiverStream::new(ev_rx);
119
120        let client = ComfyUIClient {
121            base_url,
122            http_client,
123            client_id,
124        };
125
126        let stream = EventStream {
127            stream_handle,
128            rx_stream,
129        };
130
131        Ok((client, stream))
132    }
133
134    /// Builds a [`ComfyUIClient`] instance configured for HTTP-only
135    /// communication.
136    ///
137    /// This method initializes the client without establishing a websocket
138    /// connection, enabling you to interact with the ComfyUI service using
139    /// only HTTP (REST) requests.
140    ///
141    /// # Returns
142    ///
143    /// A [`ComfyUIClient`] instance on success, or an error.
144    pub async fn build_only_http(self) -> ClientResult<ComfyUIClient> {
145        let base_url = self.base_url;
146        let http_client = reqwest::Client::new();
147        let client_id = Uuid::new_v4().to_string();
148
149        Ok(ComfyUIClient {
150            base_url,
151            http_client,
152            client_id,
153        })
154    }
155
156    /// Generates the websocket URL based on the base URL and client ID.
157    ///
158    /// This method changes the URL scheme to `wss` if the base URL uses HTTPS,
159    /// or `ws` otherwise, appends the `ws` path, and adds a query parameter
160    /// for `clientId`.
161    ///
162    /// # Parameters
163    ///
164    /// - `base_url`: The base URL of the ComfyUI service.
165    /// - `client_id`: The unique identifier for the client.
166    ///
167    /// # Returns
168    ///
169    /// The generated websocket URL on success, or an error if the URL cannot be
170    /// modified.
171    fn generate_websocket_url(base_url: Url, client_id: &str) -> ClientResult<Url> {
172        let mut ws_url = base_url;
173        let scheme = if ws_url.scheme() == "https" {
174            "wss"
175        } else {
176            "ws"
177        };
178        ws_url
179            .set_scheme(scheme)
180            .map_err(|_| ClientError::SetWsScheme)?;
181        ws_url = ws_url.join("ws")?;
182        ws_url.query_pairs_mut().append_pair("clientId", client_id);
183        Ok(ws_url)
184    }
185}
186
187/// A client for interacting with the ComfyUI service.
188///
189/// This client provides methods to fetch history, prompts, views, and to upload
190/// images.
191pub struct ComfyUIClient {
192    client_id: String,
193    base_url: Url,
194    http_client: reqwest::Client,
195}
196
197impl ComfyUIClient {
198    /// Retrieves the history for a specified prompt.
199    ///
200    /// Sends a GET request to the `history/{prompt_id}` endpoint and parses the
201    /// returned history data.
202    ///
203    /// # Parameters
204    ///
205    /// - `prompt_id`: The ID of the prompt whose history is being requested.
206    ///
207    /// # Returns
208    ///
209    /// An optional [`History`] object wrapped in a `ClientResult`. Returns
210    /// `None` if the history is not found.
211    pub async fn get_history(&self, prompt_id: &str) -> ClientResult<Option<History>> {
212        let resp = self
213            .http_client
214            .get(self.base_url.join(&format!("history/{prompt_id}"))?)
215            .send()
216            .await?;
217        let resp = Self::error_for_status(resp).await?;
218        let mut histories = resp.json::<HashMap<String, History>>().await?;
219        Ok(histories.remove(prompt_id))
220    }
221
222    /// Retrieves the current prompt information.
223    ///
224    /// Sends a GET request to the `prompt` endpoint and returns the parsed
225    /// [`PromptInfo`] data.
226    ///
227    /// # Returns
228    ///
229    /// A [`PromptInfo`] object on success, or an error.
230    pub async fn get_prompt(&self) -> ClientResult<PromptInfo> {
231        let resp = self
232            .http_client
233            .get(self.base_url.join("prompt")?)
234            .send()
235            .await?;
236        let resp = Self::error_for_status(resp).await?;
237        Ok(resp.json().await?)
238    }
239
240    /// Retrieves view data corresponding to the provided file information.
241    ///
242    /// Sends a GET request to the `view` endpoint, including the file
243    /// information as query parameters.
244    ///
245    /// # Parameters
246    ///
247    /// - `file_info`: A [`FileInfo`] object containing details about the file.
248    ///
249    /// # Returns
250    ///
251    /// The response as a [`Bytes`] object on success, or an error.
252    pub async fn get_view(&self, file_info: &FileInfo) -> ClientResult<Bytes> {
253        let resp = self
254            .http_client
255            .get(self.base_url.join("view")?)
256            .query(file_info)
257            .send()
258            .await?;
259        let resp = Self::error_for_status(resp).await?;
260        Ok(resp.bytes().await?)
261    }
262
263    /// Sends a prompt in string format.
264    ///
265    /// Parses the input string as JSON and calls [`Self::post_prompt_value`] to
266    /// send the prompt.
267    ///
268    /// # Parameters
269    ///
270    /// - `prompt`: A string slice representing the prompt in JSON format.
271    ///
272    /// # Returns
273    ///
274    /// A [`Prompt`] object on success, or an error.
275    pub async fn post_prompt_str(&self, prompt: &str) -> ClientResult<Prompt> {
276        let prompt = serde_json::from_str::<Value>(prompt)?;
277        self.post_prompt_value(&prompt).await
278    }
279
280    /// Sends a prompt from any serializable data.
281    ///
282    /// Converts the provided data into JSON and calls
283    /// [`Self::post_prompt_value`] to send the prompt.
284    ///
285    /// # Parameters
286    ///
287    /// - `prompt`: A reference to any data that implements [`Serialize`].
288    ///
289    /// # Returns
290    ///
291    /// A [`Prompt`] object on success, or an error.
292    pub async fn post_prompt<T: Serialize>(&self, prompt: &T) -> ClientResult<Prompt> {
293        let prompt = serde_json::to_value(prompt)?;
294        self.post_prompt_value(&prompt).await
295    }
296
297    /// Sends a prompt in JSON format.
298    ///
299    /// Constructs the request payload (including the client ID and prompt data)
300    /// and sends a POST request to the `prompt` endpoint.
301    ///
302    /// # Parameters
303    ///
304    /// - `prompt`: A JSON value representing the prompt data.
305    ///
306    /// # Returns
307    ///
308    /// A [`Prompt`] object on success, or an error.
309    pub async fn post_prompt_value(&self, prompt: &Value) -> ClientResult<Prompt> {
310        let data = json!({"client_id": &self.client_id, "prompt": prompt});
311        let resp = self
312            .http_client
313            .post(self.base_url.join("prompt")?)
314            .json(&data)
315            .send()
316            .await?;
317        let resp = Self::error_for_status(resp).await?;
318        Ok(resp.json().await?)
319    }
320
321    /// Uploads an image.
322    ///
323    /// Constructs a multipart form containing the image data and file
324    /// information, then sends a POST request to the `upload/image` endpoint.
325    ///
326    /// # Parameters
327    ///
328    /// - `body`: The image data, convertible into a [`Body`].
329    /// - `info`: A [`FileInfo`] object containing details about the image file.
330    /// - `overwrite`: A boolean indicating whether to overwrite an existing
331    ///   file.
332    ///
333    /// # Returns
334    ///
335    /// An updated [`FileInfo`] object on success, or an error.
336    pub async fn upload_image(
337        &self, body: impl Into<Body>, info: &FileInfo, overwrite: bool,
338    ) -> ClientResult<FileInfo> {
339        let part = multipart::Part::stream(body).file_name(info.filename.to_string());
340        let mut form = multipart::Form::new()
341            .part("image", part)
342            .text("overwrite", overwrite.to_string())
343            .text("type", info.r#type.to_string());
344        if !info.subfolder.is_empty() {
345            form = form.text("subfolder", info.subfolder.to_string());
346        }
347
348        let resp = self
349            .http_client
350            .post(self.base_url.join("upload/image")?)
351            .multipart(form)
352            .send()
353            .await?;
354
355        let resp = Self::error_for_status(resp).await?;
356        Ok(resp.json().await?)
357    }
358
359    /// Checks the HTTP response status code and returns an error if it
360    /// indicates failure.
361    ///
362    /// If the response status is a client or server error, this method attempts
363    /// to parse the response body as JSON. If parsing fails, it returns the
364    /// body as text.
365    ///
366    /// # Parameters
367    ///
368    /// - `resp`: The HTTP response to evaluate.
369    ///
370    /// # Returns
371    ///
372    /// The original response if the status is successful, or an error if the
373    /// status indicates a failure.
374    async fn error_for_status(resp: Response) -> ClientResult<Response> {
375        let status = resp.status();
376        if status.is_client_error() || status.is_server_error() {
377            let body = resp.text().await?;
378            let body = match serde_json::from_str::<Value>(&body) {
379                Ok(value) => ApiBody::Json(value),
380                Err(_) => ApiBody::Text(body),
381            };
382            Err(ApiError { status, body }.into())
383        } else {
384            Ok(resp)
385        }
386    }
387}
388
389/// A structure representing the event stream received via a websocket
390/// connection.
391///
392/// This stream continuously processes events from the ComfyUI service.
393pub struct EventStream {
394    stream_handle: JoinHandle<()>,
395    rx_stream: ReceiverStream<ClientResult<Event>>,
396}
397
398impl EventStream {
399    /// Handles a single websocket message and attempts to parse it as an
400    /// [`Event`].
401    ///
402    /// For text messages, it tries to deserialize the message into an
403    /// [`Event`]. If the deserialization fails, it wraps the message as
404    /// [`Event::Unknown`]. Other message types are ignored.
405    ///
406    /// # Parameters
407    ///
408    /// - `msg`: A result containing a [`Message`] from the websocket.
409    ///
410    /// # Returns
411    ///
412    /// An `Option<Event>` wrapped in a `ClientResult`. Returns `None` for
413    /// unsupported message types.
414    fn handle_message(msg: tungstenite::Result<Message>) -> ClientResult<Option<Event>> {
415        let msg = msg?;
416        match msg {
417            Message::Text(b) => {
418                let value = serde_json::from_slice::<Value>(b.as_bytes())?;
419                match serde_json::from_value::<Event>(value.clone()) {
420                    Ok(ev) => Ok(Some(ev)),
421                    Err(_) => Ok(Some(Event::Unknown(value))),
422                }
423            }
424            _ => Ok(None),
425        }
426    }
427}
428
429impl Drop for EventStream {
430    /// When the [`EventStream`] is dropped, abort the associated websocket
431    /// handling task.
432    fn drop(&mut self) {
433        self.stream_handle.abort();
434    }
435}
436
437impl Deref for EventStream {
438    type Target = ReceiverStream<ClientResult<Event>>;
439
440    /// Allows access to the inner [`ReceiverStream`] containing the events.
441    fn deref(&self) -> &Self::Target {
442        &self.rx_stream
443    }
444}
445
446impl DerefMut for EventStream {
447    /// Allows mutable access to the inner [`ReceiverStream`].
448    fn deref_mut(&mut self) -> &mut Self::Target {
449        &mut self.rx_stream
450    }
451}