comfyui_client/lib.rs
1#![warn(rust_2018_idioms, missing_docs)]
2#![warn(clippy::dbg_macro, clippy::print_stdout)]
3#![doc = include_str!("../README.md")]
4
5/// Module containing error definitions.
6pub mod errors;
7/// Module containing metadata such as prompt and file information.
8pub mod meta;
9
10pub use crate::errors::{ClientError, ClientResult};
11use crate::meta::{FileInfo, PromptInfo};
12use bytes::Bytes;
13use errors::{ApiBody, ApiError};
14use futures_util::stream::{Stream, StreamExt};
15use log::trace;
16use meta::{ComfyEvent, ConnectionEvent, Event, History, Prompt, PromptStatus};
17use pin_project_lite::pin_project;
18use reqwest::{
19 Body, IntoUrl, Response,
20 multipart::{self},
21};
22use serde_json::{Value, json};
23use std::{
24 collections::HashMap,
25 pin::Pin,
26 task::{Context, Poll},
27};
28use tokio::{
29 sync::mpsc,
30 time::{Duration, sleep},
31};
32use tokio_stream::wrappers::ReceiverStream;
33use tokio_tungstenite::{connect_async, tungstenite::Message};
34use url::Url;
35use uuid::Uuid;
36
/// A builder for creating a [`ComfyUIClient`] instance.
///
/// Use [`ClientBuilder::build`] to obtain a client together with a websocket
/// [`EventStream`] of real-time events, or [`ClientBuilder::build_only_http`]
/// for a client that only performs HTTP (REST) requests.
#[derive(Debug, Clone)]
pub struct ClientBuilder<U> {
    /// Base URL of the ComfyUI service; converted with `IntoUrl` at build time.
    base_url: U,
    /// Requested capacity of the event channel used by `build`.
    channel_bound: usize,
    /// Whether the background task reconnects when the websocket ends.
    reconnect_web_socket: bool,
}
46
47impl<U: IntoUrl> ClientBuilder<U> {
48 /// Creates a new [`ClientBuilder`] instance.
49 ///
50 /// # Parameters
51 ///
52 /// - `base_url`: The base URL of the ComfyUI service.
53 ///
54 /// # Returns
55 ///
56 /// A new instance of [`ClientBuilder`] wrapped in a `ClientResult`, or an
57 /// error if the URL is invalid.
58 pub fn new(base_url: U) -> Self {
59 Self {
60 base_url,
61 channel_bound: 100,
62 reconnect_web_socket: true,
63 }
64 }
65
66 /// Sets the capacity of the internal channel used for event streaming.
67 ///
68 /// This controls how many events can be buffered before backpressure is
69 /// applied. The default value is 100.
70 ///
71 /// # Parameters
72 ///
73 /// - `channel_bound`: The maximum number of events the channel can hold.
74 ///
75 /// # Returns
76 ///
77 /// The updated [`ClientBuilder`] instance.
78 pub fn channel_bound(mut self, channel_bound: usize) -> Self {
79 self.channel_bound = channel_bound;
80 self
81 }
82
83 /// Sets whether the websocket should attempt to reconnect automatically
84 /// when disconnected.
85 ///
86 /// By default, reconnection is enabled (`true`).
87 ///
88 /// # Parameters
89 ///
90 /// - `reconnect`: Whether to attempt reconnection when the WebSocket
91 /// connection drops unexpectedly.
92 ///
93 /// # Returns
94 ///
95 /// The updated [`ClientBuilder`] instance.
96 pub fn reconnect_web_socket(mut self, reconnect: bool) -> Self {
97 self.reconnect_web_socket = reconnect;
98 self
99 }
100
101 /// Builds the [`ComfyUIClient`] along with an associated [`EventStream`]
102 /// and a background task handle.
103 ///
104 /// This method establishes a websocket connection and spawns an
105 /// asynchronous task to process incoming messages. If reconnection is
106 /// enabled, the task will automatically attempt to reconnect when the
107 /// WebSocket connection drops unexpectedly.
108 ///
109 /// # Returns
110 ///
111 /// A tuple containing:
112 /// - The [`ComfyUIClient`] for HTTP API interactions
113 /// - An [`EventStream`] for receiving real-time events
114 ///
115 /// The WebSocket connection will be automatically closed when the
116 /// [`EventStream`] is dropped, as the background task managing the
117 /// connection terminates when the stream is no longer being consumed.
118 ///
119 /// Returns an error if the initial connection cannot be established.
120 pub async fn build(self) -> ClientResult<(ComfyUIClient, EventStream)> {
121 let base_url = self.base_url.into_url()?;
122 let http_client = reqwest::Client::new();
123 let client_id = Uuid::new_v4().to_string();
124 let reconnect_web_socket = self.reconnect_web_socket;
125
126 let (ev_tx, ev_rx) = mpsc::channel(self.channel_bound);
127
128 let ws_url = Self::generate_websocket_url(base_url.clone(), &client_id)?;
129
130 // Initial connection
131 let (ws_stream, _) = connect_async(&ws_url).await?;
132
133 // Spawn the stream handling task with reconnection support
134 tokio::spawn(async move {
135 let (_, mut read_stream) = ws_stream.split();
136
137 loop {
138 // Process messages until the connection drops or channel is closed
139 loop {
140 tokio::select! {
141 // Check for new WebSocket messages
142 msg = read_stream.next() => {
143 match msg {
144 Some(Ok(message)) => {
145 let ev = EventStream::handle_message(message);
146 let Some(ev) = ev.transpose() else {
147 continue;
148 };
149 if ev_tx.send(ev).await.is_err() {
150 return;
151 }
152 }
153 Some(Err(err)) => {
154 // If reconnect is enabled, wrap error in OtherEvent, otherwise pass
155 // through as ClientError
156 if reconnect_web_socket {
157 // Send receive error as an Event::Other
158 if ev_tx
159 .send(Ok(Event::Connection(ConnectionEvent::WSReceiveError(err))))
160 .await.is_err() {
161 return;
162 }
163 } else {
164 // Without reconnect, send as ClientError
165 if ev_tx.send(Err(ClientError::from(err))).await.is_err() {
166 return;
167 }
168 }
169
170 break;
171 }
172 None => {
173 break;
174 }
175 }
176 }
177
178 // Check if the channel is closed
179 _ = ev_tx.closed() => {
180 // Channel is closed, exit immediately
181 return;
182 }
183 }
184 }
185
186 // If reconnect is disabled, exit the loop
187 if !reconnect_web_socket {
188 return;
189 }
190
191 // Attempt to reconnect with a small delay until successful or channel closed
192 loop {
193 tokio::select! {
194 _ = sleep(Duration::from_secs(1)) => {
195 }
196
197 // Check if the channel is closed
198 _ = ev_tx.closed() => {
199 // Channel is closed, exit immediately
200 return;
201 }
202 }
203
204 // Try to establish a new connection
205 tokio::select! {
206 conn_result = connect_async(&ws_url) => {
207 match conn_result {
208 Ok(new_stream) => {
209 // Successfully reconnected
210 (_, read_stream) = new_stream.0.split();
211 // Send reconnection success event
212 if ev_tx
213 .send(Ok(Event::Connection(ConnectionEvent::WSReconnectSuccess)))
214 .await.is_err() {
215 // Channel is closed, exit immediately
216 return;
217 }
218 // Exit the reconnection loop to start using the new read_stream
219 break;
220 }
221 Err(err) => {
222 // Failed to reconnect, send error as Event::Other
223 let err = ClientError::Tungstenite(err);
224 if ev_tx
225 .send(Ok(Event::Connection(ConnectionEvent::WSReconnectError(err))))
226 .await
227 .is_err()
228 {
229 // Channel is closed, exit immediately
230 return;
231 }
232 }
233 }
234 }
235
236 // Check if the channel is closed during connection attempt
237 _ = ev_tx.closed() => {
238 // Channel is closed, exit immediately
239 return;
240 }
241 }
242 }
243 }
244 });
245
246 let rx_stream = ReceiverStream::new(ev_rx);
247
248 let client = ComfyUIClient {
249 base_url,
250 http_client,
251 client_id,
252 };
253
254 let stream = EventStream { rx_stream };
255
256 Ok((client, stream))
257 }
258
259 /// Builds a [`ComfyUIClient`] instance configured for HTTP-only
260 /// communication.
261 ///
262 /// This method initializes the client without establishing a websocket
263 /// connection, enabling you to interact with the ComfyUI service using
264 /// only HTTP (REST) requests.
265 ///
266 /// # Returns
267 ///
268 /// A [`ComfyUIClient`] instance on success, or an error.
269 pub async fn build_only_http(self) -> ClientResult<ComfyUIClient> {
270 let base_url = self.base_url.into_url()?;
271 let http_client = reqwest::Client::new();
272 let client_id = Uuid::new_v4().to_string();
273
274 Ok(ComfyUIClient {
275 base_url,
276 http_client,
277 client_id,
278 })
279 }
280
281 /// Generates the websocket URL based on the base URL and client ID.
282 ///
283 /// This method changes the URL scheme to `wss` if the base URL uses HTTPS,
284 /// or `ws` otherwise, appends the `ws` path, and adds a query parameter
285 /// for `clientId`.
286 ///
287 /// # Parameters
288 ///
289 /// - `base_url`: The base URL of the ComfyUI service.
290 /// - `client_id`: The unique identifier for the client.
291 ///
292 /// # Returns
293 ///
294 /// The generated websocket URL on success, or an error if the URL cannot be
295 /// modified.
296 fn generate_websocket_url(base_url: Url, client_id: &str) -> ClientResult<Url> {
297 let mut ws_url = base_url;
298 let scheme = if ws_url.scheme() == "https" {
299 "wss"
300 } else {
301 "ws"
302 };
303 ws_url
304 .set_scheme(scheme)
305 .map_err(|_| ClientError::SetWsScheme)?;
306 ws_url = ws_url.join("ws")?;
307 ws_url.query_pairs_mut().append_pair("clientId", client_id);
308 Ok(ws_url)
309 }
310}
311
312/// A client for interacting with the ComfyUI service.
313///
314/// This client provides methods to fetch history, prompts, views, and to upload
315/// images.
316pub struct ComfyUIClient {
317 client_id: String,
318 base_url: Url,
319 http_client: reqwest::Client,
320}
321
322impl ComfyUIClient {
323 /// Retrieves the history for a specified prompt.
324 ///
325 /// Sends a GET request to the `history/{prompt_id}` endpoint and parses the
326 /// returned history data.
327 ///
328 /// # Parameters
329 ///
330 /// - `prompt_id`: The ID of the prompt whose history is being requested.
331 ///
332 /// # Returns
333 ///
334 /// An optional [`History`] object wrapped in a `ClientResult`. Returns
335 /// `None` if the history is not found.
336 pub async fn get_history(&self, prompt_id: &str) -> ClientResult<Option<History>> {
337 let resp = self
338 .http_client
339 .get(self.base_url.join(&format!("history/{prompt_id}"))?)
340 .send()
341 .await?;
342 let resp = Self::error_for_status(resp).await?;
343 let mut histories = resp.json::<HashMap<String, History>>().await?;
344 Ok(histories.remove(prompt_id))
345 }
346
347 /// Retrieves the current prompt information.
348 ///
349 /// Sends a GET request to the `prompt` endpoint and returns the parsed
350 /// [`PromptInfo`] data.
351 ///
352 /// # Returns
353 ///
354 /// A [`PromptInfo`] object on success, or an error.
355 pub async fn get_prompt(&self) -> ClientResult<PromptInfo> {
356 let resp = self
357 .http_client
358 .get(self.base_url.join("prompt")?)
359 .send()
360 .await?;
361 let resp = Self::error_for_status(resp).await?;
362 Ok(resp.json().await?)
363 }
364
365 /// Retrieves view data corresponding to the provided file information.
366 ///
367 /// Sends a GET request to the `view` endpoint, including the file
368 /// information as query parameters.
369 ///
370 /// # Parameters
371 ///
372 /// - `file_info`: A [`FileInfo`] object containing details about the file.
373 ///
374 /// # Returns
375 ///
376 /// The response as a [`Bytes`] object on success, or an error.
377 pub async fn get_view(&self, file_info: &FileInfo) -> ClientResult<Bytes> {
378 let resp = self
379 .http_client
380 .get(self.base_url.join("view")?)
381 .query(file_info)
382 .send()
383 .await?;
384 let resp = Self::error_for_status(resp).await?;
385 Ok(resp.bytes().await?)
386 }
387
388 /// Sends a prompt in JSON format.
389 ///
390 /// Constructs the request payload (including the client ID and prompt data)
391 /// and sends a POST request to the `prompt` endpoint.
392 ///
393 /// # Parameters
394 ///
395 /// - `prompt`: representing the prompt data.
396 ///
397 /// # Returns
398 ///
399 /// A [`PromptStatus`] object on success, or an error.
400 pub async fn post_prompt(&self, prompt: impl Into<Prompt<'_>>) -> ClientResult<PromptStatus> {
401 let prompt = match prompt.into() {
402 Prompt::Str(prompt) => &serde_json::from_str::<Value>(prompt)?,
403 Prompt::Value(prompt) => prompt,
404 };
405 let data = json!({"client_id": &self.client_id, "prompt": prompt});
406 let resp = self
407 .http_client
408 .post(self.base_url.join("prompt")?)
409 .json(&data)
410 .send()
411 .await?;
412 let resp = Self::error_for_status(resp).await?;
413 Ok(resp.json().await?)
414 }
415
416 /// Uploads an image.
417 ///
418 /// Constructs a multipart form containing the image data and file
419 /// information, then sends a POST request to the `upload/image` endpoint.
420 ///
421 /// # Parameters
422 ///
423 /// - `body`: The image data, convertible into a [`Body`].
424 /// - `info`: A [`FileInfo`] object containing details about the image file.
425 /// - `overwrite`: A boolean indicating whether to overwrite an existing
426 /// file.
427 ///
428 /// # Returns
429 ///
430 /// An updated [`FileInfo`] object on success, or an error.
431 pub async fn upload_image(
432 &self, body: impl Into<Body>, info: &FileInfo, overwrite: bool,
433 ) -> ClientResult<FileInfo> {
434 let part = multipart::Part::stream(body).file_name(info.filename.to_string());
435 let mut form = multipart::Form::new()
436 .part("image", part)
437 .text("overwrite", overwrite.to_string())
438 .text("type", info.r#type.to_string());
439 if !info.subfolder.is_empty() {
440 form = form.text("subfolder", info.subfolder.to_string());
441 }
442
443 let resp = self
444 .http_client
445 .post(self.base_url.join("upload/image")?)
446 .multipart(form)
447 .send()
448 .await?;
449
450 let resp = Self::error_for_status(resp).await?;
451 Ok(resp.json().await?)
452 }
453
454 /// Checks the HTTP response status code and returns an error if it
455 /// indicates failure.
456 ///
457 /// If the response status is a client or server error, this method attempts
458 /// to parse the response body as JSON. If parsing fails, it returns the
459 /// body as text.
460 ///
461 /// # Parameters
462 ///
463 /// - `resp`: The HTTP response to evaluate.
464 ///
465 /// # Returns
466 ///
467 /// The original response if the status is successful, or an error if the
468 /// status indicates a failure.
469 async fn error_for_status(resp: Response) -> ClientResult<Response> {
470 let status = resp.status();
471 if status.is_client_error() || status.is_server_error() {
472 let body = resp.text().await?;
473 let body = match serde_json::from_str::<Value>(&body) {
474 Ok(value) => ApiBody::Json(value),
475 Err(_) => ApiBody::Text(body),
476 };
477 Err(ApiError { status, body }.into())
478 } else {
479 Ok(resp)
480 }
481 }
482}
483
pin_project! {
    /// A stream of events received over the ComfyUI websocket connection.
    ///
    /// A background task spawned by [`ClientBuilder::build`] reads the
    /// websocket and forwards each item through a bounded channel. This stream
    /// is the receiving end of that channel. When reconnection is enabled in
    /// [`ClientBuilder`], the task also reconnects after the connection ends.
    ///
    /// Each item is a `ClientResult<Event>`:
    /// - `Ok(Event::Comfy(..))` for text messages from the ComfyUI service.
    ///   JSON that matches no known event becomes `ComfyEvent::Unknown`.
    /// - `Ok(Event::Connection(..))` for websocket connection state changes
    ///   (receive errors and reconnect attempts) when reconnection is enabled.
    /// - `Err(..)` when a text message is not valid JSON, or when a websocket
    ///   read error occurs while reconnection is disabled.
    ///
    /// The stream ends once the background task stops. Dropping the stream
    /// makes the task exit and release the connection.
    pub struct EventStream {
        // Receiving half of the channel fed by the background websocket task.
        #[pin]
        rx_stream: ReceiverStream<ClientResult<Event>>,
    }
}
502
503impl EventStream {
504 /// Handles a single websocket message and attempts to parse it as an
505 /// [`Event`].
506 ///
507 /// For text messages, it tries to deserialize the message into a
508 /// [`ComfyEvent`] and wraps it in `Event::Comfy`.
509 /// If deserialization fails, it wraps the raw value as
510 /// `Event::Comfy(ComfyEvent::Unknown)`.
511 /// Non-text message types are ignored and return `None`.
512 ///
513 /// # Parameters
514 ///
515 /// - `msg`: A [`Message`] from the websocket.
516 ///
517 /// # Returns
518 ///
519 /// An `Option<Event>` wrapped in a `ClientResult`. Returns `None` for
520 /// unsupported message types.
521 fn handle_message(msg: Message) -> ClientResult<Option<Event>> {
522 match msg {
523 Message::Text(b) => {
524 trace!(message:% = b.as_str(); "received websocket message");
525 let value = serde_json::from_slice::<Value>(b.as_bytes())?;
526 match serde_json::from_value::<ComfyEvent>(value.clone()) {
527 Ok(ev) => Ok(Some(Event::Comfy(ev))),
528 Err(_) => Ok(Some(Event::Comfy(ComfyEvent::Unknown(value)))),
529 }
530 }
531 _ => Ok(None),
532 }
533 }
534}
535
536impl Stream for EventStream {
537 type Item = ClientResult<Event>;
538
539 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
540 let this = self.project();
541 this.rx_stream.poll_next(cx)
542 }
543
544 fn size_hint(&self) -> (usize, Option<usize>) {
545 self.rx_stream.size_hint()
546 }
547}
548
549#[cfg(test)]
550mod tests {
551 use super::*;
552
553 #[test]
554 fn test_builder() {
555 let _ = ClientBuilder::new("http://example.org/");
556 let _ = ClientBuilder::new("http://example.org/".parse::<Url>().unwrap());
557 }
558}