comfyui_client/
lib.rs

1#![warn(rust_2018_idioms, missing_docs)]
2#![warn(clippy::dbg_macro, clippy::print_stdout)]
3#![doc = include_str!("../README.md")]
4
5/// Module containing error definitions.
6pub mod errors;
7/// Module containing metadata such as prompt and file information.
8pub mod meta;
9
10pub use crate::errors::{ClientError, ClientResult};
11use crate::meta::{FileInfo, PromptInfo};
12use bytes::Bytes;
13use errors::{ApiBody, ApiError};
14use futures_util::stream::{Stream, StreamExt};
15use log::trace;
16use meta::{ComfyEvent, ConnectionEvent, Event, History, Prompt, PromptStatus};
17use pin_project_lite::pin_project;
18use reqwest::{
19    Body, IntoUrl, Response,
20    multipart::{self},
21};
22use serde_json::{Value, json};
23use std::{
24    collections::HashMap,
25    pin::Pin,
26    task::{Context, Poll},
27};
28use tokio::{
29    sync::mpsc,
30    time::{Duration, sleep},
31};
32use tokio_stream::wrappers::ReceiverStream;
33use tokio_tungstenite::{connect_async, tungstenite::Message};
34use url::Url;
35use uuid::Uuid;
36
/// A builder for creating a [`ComfyUIClient`] instance.
///
/// Configure it with [`ClientBuilder::channel_bound`] and
/// [`ClientBuilder::reconnect_web_socket`], then call [`ClientBuilder::build`]
/// to also open a websocket event stream, or [`ClientBuilder::build_only_http`]
/// for an HTTP-only client.
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until `build` or `build_only_http` is called"]
pub struct ClientBuilder<U> {
    /// Base URL of the ComfyUI service; only parsed when the client is built.
    base_url: U,
    /// Capacity of the event channel between the websocket task and the stream.
    channel_bound: usize,
    /// Whether the background task reconnects after the websocket drops.
    reconnect_web_socket: bool,
}
46
47impl<U: IntoUrl> ClientBuilder<U> {
48    /// Creates a new [`ClientBuilder`] instance.
49    ///
50    /// # Parameters
51    ///
52    /// - `base_url`: The base URL of the ComfyUI service.
53    ///
54    /// # Returns
55    ///
56    /// A new instance of [`ClientBuilder`] wrapped in a `ClientResult`, or an
57    /// error if the URL is invalid.
58    pub fn new(base_url: U) -> Self {
59        Self {
60            base_url,
61            channel_bound: 100,
62            reconnect_web_socket: true,
63        }
64    }
65
66    /// Sets the capacity of the internal channel used for event streaming.
67    ///
68    /// This controls how many events can be buffered before backpressure is
69    /// applied. The default value is 100.
70    ///
71    /// # Parameters
72    ///
73    /// - `channel_bound`: The maximum number of events the channel can hold.
74    ///
75    /// # Returns
76    ///
77    /// The updated [`ClientBuilder`] instance.
78    pub fn channel_bound(mut self, channel_bound: usize) -> Self {
79        self.channel_bound = channel_bound;
80        self
81    }
82
83    /// Sets whether the websocket should attempt to reconnect automatically
84    /// when disconnected.
85    ///
86    /// By default, reconnection is enabled (`true`).
87    ///
88    /// # Parameters
89    ///
90    /// - `reconnect`: Whether to attempt reconnection when the WebSocket
91    ///   connection drops unexpectedly.
92    ///
93    /// # Returns
94    ///
95    /// The updated [`ClientBuilder`] instance.
96    pub fn reconnect_web_socket(mut self, reconnect: bool) -> Self {
97        self.reconnect_web_socket = reconnect;
98        self
99    }
100
101    /// Builds the [`ComfyUIClient`] along with an associated [`EventStream`]
102    /// and a background task handle.
103    ///
104    /// This method establishes a websocket connection and spawns an
105    /// asynchronous task to process incoming messages. If reconnection is
106    /// enabled, the task will automatically attempt to reconnect when the
107    /// WebSocket connection drops unexpectedly.
108    ///
109    /// # Returns
110    ///
111    /// A tuple containing:
112    /// - The [`ComfyUIClient`] for HTTP API interactions
113    /// - An [`EventStream`] for receiving real-time events
114    ///
115    /// The WebSocket connection will be automatically closed when the
116    /// [`EventStream`] is dropped, as the background task managing the
117    /// connection terminates when the stream is no longer being consumed.
118    ///
119    /// Returns an error if the initial connection cannot be established.
120    pub async fn build(self) -> ClientResult<(ComfyUIClient, EventStream)> {
121        let base_url = self.base_url.into_url()?;
122        let http_client = reqwest::Client::new();
123        let client_id = Uuid::new_v4().to_string();
124        let reconnect_web_socket = self.reconnect_web_socket;
125
126        let (ev_tx, ev_rx) = mpsc::channel(self.channel_bound);
127
128        let ws_url = Self::generate_websocket_url(base_url.clone(), &client_id)?;
129
130        // Initial connection
131        let (ws_stream, _) = connect_async(&ws_url).await?;
132
133        // Spawn the stream handling task with reconnection support
134        tokio::spawn(async move {
135            let (_, mut read_stream) = ws_stream.split();
136
137            loop {
138                // Process messages until the connection drops or channel is closed
139                loop {
140                    tokio::select! {
141                        // Check for new WebSocket messages
142                        msg = read_stream.next() => {
143                            match msg {
144                                Some(Ok(message)) => {
145                                    let ev = EventStream::handle_message(message);
146                                    let Some(ev) = ev.transpose() else {
147                                        continue;
148                                    };
149                                    if ev_tx.send(ev).await.is_err() {
150                                        return;
151                                    }
152                                }
153                                Some(Err(err)) => {
154                                    // If reconnect is enabled, wrap error in OtherEvent, otherwise pass
155                                    // through as ClientError
156                                    if reconnect_web_socket {
157                                        // Send receive error as an Event::Other
158                                        if ev_tx
159                                            .send(Ok(Event::Connection(ConnectionEvent::WSReceiveError(err))))
160                                            .await.is_err() {
161                                                return;
162                                            }
163                                    } else {
164                                        // Without reconnect, send as ClientError
165                                        if ev_tx.send(Err(ClientError::from(err))).await.is_err() {
166                                            return;
167                                        }
168                                    }
169
170                                    break;
171                                }
172                                None => {
173                                    break;
174                                }
175                            }
176                        }
177
178                        // Check if the channel is closed
179                        _ = ev_tx.closed() => {
180                            // Channel is closed, exit immediately
181                            return;
182                        }
183                    }
184                }
185
186                // If reconnect is disabled, exit the loop
187                if !reconnect_web_socket {
188                    return;
189                }
190
191                // Attempt to reconnect with a small delay until successful or channel closed
192                loop {
193                    tokio::select! {
194                        _ = sleep(Duration::from_secs(1)) => {
195                        }
196
197                        // Check if the channel is closed
198                        _ = ev_tx.closed() => {
199                            // Channel is closed, exit immediately
200                            return;
201                        }
202                    }
203
204                    // Try to establish a new connection
205                    tokio::select! {
206                        conn_result = connect_async(&ws_url) => {
207                            match conn_result {
208                                Ok(new_stream) => {
209                                    // Successfully reconnected
210                                    (_, read_stream) = new_stream.0.split();
211                                    // Send reconnection success event
212                                    if ev_tx
213                                        .send(Ok(Event::Connection(ConnectionEvent::WSReconnectSuccess)))
214                                        .await.is_err() {
215                                            // Channel is closed, exit immediately
216                                            return;
217                                        }
218                                    // Exit the reconnection loop to start using the new read_stream
219                                    break;
220                                }
221                                Err(err) => {
222                                    // Failed to reconnect, send error as Event::Other
223                                    let err = ClientError::Tungstenite(err);
224                                    if ev_tx
225                                        .send(Ok(Event::Connection(ConnectionEvent::WSReconnectError(err))))
226                                        .await
227                                        .is_err()
228                                    {
229                                        // Channel is closed, exit immediately
230                                        return;
231                                    }
232                                }
233                            }
234                        }
235
236                        // Check if the channel is closed during connection attempt
237                        _ = ev_tx.closed() => {
238                            // Channel is closed, exit immediately
239                            return;
240                        }
241                    }
242                }
243            }
244        });
245
246        let rx_stream = ReceiverStream::new(ev_rx);
247
248        let client = ComfyUIClient {
249            base_url,
250            http_client,
251            client_id,
252        };
253
254        let stream = EventStream { rx_stream };
255
256        Ok((client, stream))
257    }
258
259    /// Builds a [`ComfyUIClient`] instance configured for HTTP-only
260    /// communication.
261    ///
262    /// This method initializes the client without establishing a websocket
263    /// connection, enabling you to interact with the ComfyUI service using
264    /// only HTTP (REST) requests.
265    ///
266    /// # Returns
267    ///
268    /// A [`ComfyUIClient`] instance on success, or an error.
269    pub async fn build_only_http(self) -> ClientResult<ComfyUIClient> {
270        let base_url = self.base_url.into_url()?;
271        let http_client = reqwest::Client::new();
272        let client_id = Uuid::new_v4().to_string();
273
274        Ok(ComfyUIClient {
275            base_url,
276            http_client,
277            client_id,
278        })
279    }
280
281    /// Generates the websocket URL based on the base URL and client ID.
282    ///
283    /// This method changes the URL scheme to `wss` if the base URL uses HTTPS,
284    /// or `ws` otherwise, appends the `ws` path, and adds a query parameter
285    /// for `clientId`.
286    ///
287    /// # Parameters
288    ///
289    /// - `base_url`: The base URL of the ComfyUI service.
290    /// - `client_id`: The unique identifier for the client.
291    ///
292    /// # Returns
293    ///
294    /// The generated websocket URL on success, or an error if the URL cannot be
295    /// modified.
296    fn generate_websocket_url(base_url: Url, client_id: &str) -> ClientResult<Url> {
297        let mut ws_url = base_url;
298        let scheme = if ws_url.scheme() == "https" {
299            "wss"
300        } else {
301            "ws"
302        };
303        ws_url
304            .set_scheme(scheme)
305            .map_err(|_| ClientError::SetWsScheme)?;
306        ws_url = ws_url.join("ws")?;
307        ws_url.query_pairs_mut().append_pair("clientId", client_id);
308        Ok(ws_url)
309    }
310}
311
312/// A client for interacting with the ComfyUI service.
313///
314/// This client provides methods to fetch history, prompts, views, and to upload
315/// images.
316pub struct ComfyUIClient {
317    client_id: String,
318    base_url: Url,
319    http_client: reqwest::Client,
320}
321
322impl ComfyUIClient {
323    /// Retrieves the history for a specified prompt.
324    ///
325    /// Sends a GET request to the `history/{prompt_id}` endpoint and parses the
326    /// returned history data.
327    ///
328    /// # Parameters
329    ///
330    /// - `prompt_id`: The ID of the prompt whose history is being requested.
331    ///
332    /// # Returns
333    ///
334    /// An optional [`History`] object wrapped in a `ClientResult`. Returns
335    /// `None` if the history is not found.
336    pub async fn get_history(&self, prompt_id: &str) -> ClientResult<Option<History>> {
337        let resp = self
338            .http_client
339            .get(self.base_url.join(&format!("history/{prompt_id}"))?)
340            .send()
341            .await?;
342        let resp = Self::error_for_status(resp).await?;
343        let mut histories = resp.json::<HashMap<String, History>>().await?;
344        Ok(histories.remove(prompt_id))
345    }
346
347    /// Retrieves the current prompt information.
348    ///
349    /// Sends a GET request to the `prompt` endpoint and returns the parsed
350    /// [`PromptInfo`] data.
351    ///
352    /// # Returns
353    ///
354    /// A [`PromptInfo`] object on success, or an error.
355    pub async fn get_prompt(&self) -> ClientResult<PromptInfo> {
356        let resp = self
357            .http_client
358            .get(self.base_url.join("prompt")?)
359            .send()
360            .await?;
361        let resp = Self::error_for_status(resp).await?;
362        Ok(resp.json().await?)
363    }
364
365    /// Retrieves view data corresponding to the provided file information.
366    ///
367    /// Sends a GET request to the `view` endpoint, including the file
368    /// information as query parameters.
369    ///
370    /// # Parameters
371    ///
372    /// - `file_info`: A [`FileInfo`] object containing details about the file.
373    ///
374    /// # Returns
375    ///
376    /// The response as a [`Bytes`] object on success, or an error.
377    pub async fn get_view(&self, file_info: &FileInfo) -> ClientResult<Bytes> {
378        let resp = self
379            .http_client
380            .get(self.base_url.join("view")?)
381            .query(file_info)
382            .send()
383            .await?;
384        let resp = Self::error_for_status(resp).await?;
385        Ok(resp.bytes().await?)
386    }
387
388    /// Sends a prompt in JSON format.
389    ///
390    /// Constructs the request payload (including the client ID and prompt data)
391    /// and sends a POST request to the `prompt` endpoint.
392    ///
393    /// # Parameters
394    ///
395    /// - `prompt`: representing the prompt data.
396    ///
397    /// # Returns
398    ///
399    /// A [`PromptStatus`] object on success, or an error.
400    pub async fn post_prompt(&self, prompt: impl Into<Prompt<'_>>) -> ClientResult<PromptStatus> {
401        let prompt = match prompt.into() {
402            Prompt::Str(prompt) => &serde_json::from_str::<Value>(prompt)?,
403            Prompt::Value(prompt) => prompt,
404        };
405        let data = json!({"client_id": &self.client_id, "prompt": prompt});
406        let resp = self
407            .http_client
408            .post(self.base_url.join("prompt")?)
409            .json(&data)
410            .send()
411            .await?;
412        let resp = Self::error_for_status(resp).await?;
413        Ok(resp.json().await?)
414    }
415
416    /// Uploads an image.
417    ///
418    /// Constructs a multipart form containing the image data and file
419    /// information, then sends a POST request to the `upload/image` endpoint.
420    ///
421    /// # Parameters
422    ///
423    /// - `body`: The image data, convertible into a [`Body`].
424    /// - `info`: A [`FileInfo`] object containing details about the image file.
425    /// - `overwrite`: A boolean indicating whether to overwrite an existing
426    ///   file.
427    ///
428    /// # Returns
429    ///
430    /// An updated [`FileInfo`] object on success, or an error.
431    pub async fn upload_image(
432        &self, body: impl Into<Body>, info: &FileInfo, overwrite: bool,
433    ) -> ClientResult<FileInfo> {
434        let part = multipart::Part::stream(body).file_name(info.filename.to_string());
435        let mut form = multipart::Form::new()
436            .part("image", part)
437            .text("overwrite", overwrite.to_string())
438            .text("type", info.r#type.to_string());
439        if !info.subfolder.is_empty() {
440            form = form.text("subfolder", info.subfolder.to_string());
441        }
442
443        let resp = self
444            .http_client
445            .post(self.base_url.join("upload/image")?)
446            .multipart(form)
447            .send()
448            .await?;
449
450        let resp = Self::error_for_status(resp).await?;
451        Ok(resp.json().await?)
452    }
453
454    /// Checks the HTTP response status code and returns an error if it
455    /// indicates failure.
456    ///
457    /// If the response status is a client or server error, this method attempts
458    /// to parse the response body as JSON. If parsing fails, it returns the
459    /// body as text.
460    ///
461    /// # Parameters
462    ///
463    /// - `resp`: The HTTP response to evaluate.
464    ///
465    /// # Returns
466    ///
467    /// The original response if the status is successful, or an error if the
468    /// status indicates a failure.
469    async fn error_for_status(resp: Response) -> ClientResult<Response> {
470        let status = resp.status();
471        if status.is_client_error() || status.is_server_error() {
472            let body = resp.text().await?;
473            let body = match serde_json::from_str::<Value>(&body) {
474                Ok(value) => ApiBody::Json(value),
475                Err(_) => ApiBody::Text(body),
476            };
477            Err(ApiError { status, body }.into())
478        } else {
479            Ok(resp)
480        }
481    }
482}
483
484pin_project! {
485    /// A structure representing the event stream received via a websocket connection.
486    ///
487    /// This stream continuously processes events from the ComfyUI service.
488    /// It handles WebSocket connection management including automatic reconnection
489    /// when enabled through the [`ClientBuilder`].
490    ///
491    /// The stream emits various events including:
492    /// - ComfyUI service events (execution status, errors, etc.) via `Event::Comfy`
493    /// - WebSocket connection state changes via `Event::Connection`
494    ///
495    /// All WebSocket communication is managed by a background task, allowing the stream
496    /// to be consumed without worrying about connection details.
497    pub struct EventStream {
498        #[pin]
499        rx_stream: ReceiverStream<ClientResult<Event>>,
500    }
501}
502
503impl EventStream {
504    /// Handles a single websocket message and attempts to parse it as an
505    /// [`Event`].
506    ///
507    /// For text messages, it tries to deserialize the message into a
508    /// [`ComfyEvent`] and wraps it in `Event::Comfy`.
509    /// If deserialization fails, it wraps the raw value as
510    /// `Event::Comfy(ComfyEvent::Unknown)`.
511    /// Non-text message types are ignored and return `None`.
512    ///
513    /// # Parameters
514    ///
515    /// - `msg`: A [`Message`] from the websocket.
516    ///
517    /// # Returns
518    ///
519    /// An `Option<Event>` wrapped in a `ClientResult`. Returns `None` for
520    /// unsupported message types.
521    fn handle_message(msg: Message) -> ClientResult<Option<Event>> {
522        match msg {
523            Message::Text(b) => {
524                trace!(message:% = b.as_str(); "received websocket message");
525                let value = serde_json::from_slice::<Value>(b.as_bytes())?;
526                match serde_json::from_value::<ComfyEvent>(value.clone()) {
527                    Ok(ev) => Ok(Some(Event::Comfy(ev))),
528                    Err(_) => Ok(Some(Event::Comfy(ComfyEvent::Unknown(value)))),
529                }
530            }
531            _ => Ok(None),
532        }
533    }
534}
535
536impl Stream for EventStream {
537    type Item = ClientResult<Event>;
538
539    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
540        let this = self.project();
541        this.rx_stream.poll_next(cx)
542    }
543
544    fn size_hint(&self) -> (usize, Option<usize>) {
545        self.rx_stream.size_hint()
546    }
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder() {
        let _ = ClientBuilder::new("http://example.org/");
        let _ = ClientBuilder::new("http://example.org/".parse::<Url>().unwrap());
    }

    /// The builder starts with documented defaults and setters overwrite them.
    #[test]
    fn test_builder_defaults_and_setters() {
        let builder = ClientBuilder::new("http://example.org/");
        assert_eq!(builder.channel_bound, 100);
        assert!(builder.reconnect_web_socket);

        let builder = builder.channel_bound(8).reconnect_web_socket(false);
        assert_eq!(builder.channel_bound, 8);
        assert!(!builder.reconnect_web_socket);
    }

    /// http maps to ws, https maps to wss, and the client ID is a query pair.
    #[test]
    fn test_generate_websocket_url() {
        let url = ClientBuilder::<&str>::generate_websocket_url(
            "http://127.0.0.1:8188/".parse().unwrap(),
            "abc",
        )
        .unwrap();
        assert_eq!(url.as_str(), "ws://127.0.0.1:8188/ws?clientId=abc");

        let url = ClientBuilder::<&str>::generate_websocket_url(
            "https://example.org/comfy/".parse().unwrap(),
            "abc",
        )
        .unwrap();
        assert_eq!(url.as_str(), "wss://example.org/comfy/ws?clientId=abc");
    }
}