comfyui_client/lib.rs
1#![warn(rust_2018_idioms, missing_docs)]
2#![warn(clippy::dbg_macro, clippy::print_stdout)]
3#![doc = include_str!("../README.md")]
4
5/// Module containing error definitions.
6pub mod errors;
7/// Module containing metadata such as prompt and file information.
8pub mod meta;
9
10pub use crate::errors::{ClientError, ClientResult};
11use crate::meta::{FileInfo, PromptInfo};
12use bytes::Bytes;
13use cfg_if::cfg_if;
14use errors::{ApiBody, ApiError};
15use futures_util::StreamExt;
16use meta::{Event, History, Prompt};
17use reqwest::{
18 Body, IntoUrl, Response,
19 multipart::{self},
20};
21use serde::Serialize;
22use serde_json::{Value, json};
23use std::{
24 collections::HashMap,
25 ops::{Deref, DerefMut},
26};
27use tokio::{sync::mpsc, task::JoinHandle};
28use tokio_stream::wrappers::ReceiverStream;
29use tokio_tungstenite::{
30 connect_async,
31 tungstenite::{self, Message},
32};
33use url::Url;
34use uuid::Uuid;
35
36/// A builder for creating a [`ComfyUIClient`] instance.
37///
38/// This builder helps initialize the client with the provided base URL and sets
39/// up a websocket connection to stream events.
40pub struct ClientBuilder {
41 base_url: Url,
42 channel_bound: usize,
43}
44
45impl ClientBuilder {
46 /// Creates a new [`ClientBuilder`] instance.
47 ///
48 /// # Parameters
49 ///
50 /// - `base_url`: The base URL of the ComfyUI service.
51 ///
52 /// # Returns
53 ///
54 /// A new instance of [`ClientBuilder`] wrapped in a `ClientResult`, or an
55 /// error if the URL is invalid.
56 pub fn new(base_url: impl IntoUrl) -> ClientResult<Self> {
57 Ok(Self {
58 base_url: base_url.into_url()?,
59 channel_bound: 100,
60 })
61 }
62
63 /// Builds the [`ComfyUIClient`] along with an associated [`EventStream`].
64 ///
65 /// This method establishes a websocket connection and spawns an
66 /// asynchronous task to process incoming messages.
67 ///
68 /// # Returns
69 ///
70 /// A tuple containing the [`ComfyUIClient`] and [`EventStream`] on success,
71 /// or an error.
72 pub async fn build(self) -> ClientResult<(ComfyUIClient, EventStream)> {
73 let base_url = self.base_url;
74 let http_client = reqwest::Client::new();
75 let client_id = Uuid::new_v4().to_string();
76
77 let (ev_tx, ev_rx) = mpsc::channel(self.channel_bound);
78
79 let ws_url = Self::generate_websocket_url(base_url.clone(), &client_id)?;
80 let (stream, _) = if ws_url.scheme() == "wss" {
81 cfg_if! {
82 if #[cfg(feature = "rustls")] {
83 let root_store = rustls::RootCertStore {
84 roots: webpki_roots::TLS_SERVER_ROOTS.into(),
85 };
86 let config = rustls::ClientConfig::builder()
87 .with_root_certificates(root_store)
88 .with_no_client_auth();
89
90 tokio_tungstenite::connect_async_tls_with_config(
91 ws_url,
92 None,
93 false,
94 Some(tokio_tungstenite::Connector::Rustls(std::sync::Arc::new(config))),
95 )
96 .await?
97 } else {
98 connect_async(ws_url).await?
99 }
100 }
101 } else {
102 connect_async(ws_url).await?
103 };
104
105 let stream_handle = tokio::spawn(async move {
106 let (_, mut read_stream) = stream.split();
107 while let Some(msg) = read_stream.next().await {
108 let ev = EventStream::handle_message(msg);
109 let Some(ev) = ev.transpose() else {
110 continue;
111 };
112 if ev_tx.send(ev).await.is_err() {
113 break;
114 }
115 }
116 });
117
118 let rx_stream = ReceiverStream::new(ev_rx);
119
120 let client = ComfyUIClient {
121 base_url,
122 http_client,
123 client_id,
124 };
125
126 let stream = EventStream {
127 stream_handle,
128 rx_stream,
129 };
130
131 Ok((client, stream))
132 }
133
134 /// Builds a [`ComfyUIClient`] instance configured for HTTP-only
135 /// communication.
136 ///
137 /// This method initializes the client without establishing a websocket
138 /// connection, enabling you to interact with the ComfyUI service using
139 /// only HTTP (REST) requests.
140 ///
141 /// # Returns
142 ///
143 /// A [`ComfyUIClient`] instance on success, or an error.
144 pub async fn build_only_http(self) -> ClientResult<ComfyUIClient> {
145 let base_url = self.base_url;
146 let http_client = reqwest::Client::new();
147 let client_id = Uuid::new_v4().to_string();
148
149 Ok(ComfyUIClient {
150 base_url,
151 http_client,
152 client_id,
153 })
154 }
155
156 /// Generates the websocket URL based on the base URL and client ID.
157 ///
158 /// This method changes the URL scheme to `wss` if the base URL uses HTTPS,
159 /// or `ws` otherwise, appends the `ws` path, and adds a query parameter
160 /// for `clientId`.
161 ///
162 /// # Parameters
163 ///
164 /// - `base_url`: The base URL of the ComfyUI service.
165 /// - `client_id`: The unique identifier for the client.
166 ///
167 /// # Returns
168 ///
169 /// The generated websocket URL on success, or an error if the URL cannot be
170 /// modified.
171 fn generate_websocket_url(base_url: Url, client_id: &str) -> ClientResult<Url> {
172 let mut ws_url = base_url;
173 let scheme = if ws_url.scheme() == "https" {
174 "wss"
175 } else {
176 "ws"
177 };
178 ws_url
179 .set_scheme(scheme)
180 .map_err(|_| ClientError::SetWsScheme)?;
181 ws_url = ws_url.join("ws")?;
182 ws_url.query_pairs_mut().append_pair("clientId", client_id);
183 Ok(ws_url)
184 }
185}
186
187/// A client for interacting with the ComfyUI service.
188///
189/// This client provides methods to fetch history, prompts, views, and to upload
190/// images.
191pub struct ComfyUIClient {
192 client_id: String,
193 base_url: Url,
194 http_client: reqwest::Client,
195}
196
197impl ComfyUIClient {
198 /// Retrieves the history for a specified prompt.
199 ///
200 /// Sends a GET request to the `history/{prompt_id}` endpoint and parses the
201 /// returned history data.
202 ///
203 /// # Parameters
204 ///
205 /// - `prompt_id`: The ID of the prompt whose history is being requested.
206 ///
207 /// # Returns
208 ///
209 /// An optional [`History`] object wrapped in a `ClientResult`. Returns
210 /// `None` if the history is not found.
211 pub async fn get_history(&self, prompt_id: &str) -> ClientResult<Option<History>> {
212 let resp = self
213 .http_client
214 .get(self.base_url.join(&format!("history/{prompt_id}"))?)
215 .send()
216 .await?;
217 let resp = Self::error_for_status(resp).await?;
218 let mut histories = resp.json::<HashMap<String, History>>().await?;
219 Ok(histories.remove(prompt_id))
220 }
221
222 /// Retrieves the current prompt information.
223 ///
224 /// Sends a GET request to the `prompt` endpoint and returns the parsed
225 /// [`PromptInfo`] data.
226 ///
227 /// # Returns
228 ///
229 /// A [`PromptInfo`] object on success, or an error.
230 pub async fn get_prompt(&self) -> ClientResult<PromptInfo> {
231 let resp = self
232 .http_client
233 .get(self.base_url.join("prompt")?)
234 .send()
235 .await?;
236 let resp = Self::error_for_status(resp).await?;
237 Ok(resp.json().await?)
238 }
239
240 /// Retrieves view data corresponding to the provided file information.
241 ///
242 /// Sends a GET request to the `view` endpoint, including the file
243 /// information as query parameters.
244 ///
245 /// # Parameters
246 ///
247 /// - `file_info`: A [`FileInfo`] object containing details about the file.
248 ///
249 /// # Returns
250 ///
251 /// The response as a [`Bytes`] object on success, or an error.
252 pub async fn get_view(&self, file_info: &FileInfo) -> ClientResult<Bytes> {
253 let resp = self
254 .http_client
255 .get(self.base_url.join("view")?)
256 .query(file_info)
257 .send()
258 .await?;
259 let resp = Self::error_for_status(resp).await?;
260 Ok(resp.bytes().await?)
261 }
262
263 /// Sends a prompt in string format.
264 ///
265 /// Parses the input string as JSON and calls [`Self::post_prompt_value`] to
266 /// send the prompt.
267 ///
268 /// # Parameters
269 ///
270 /// - `prompt`: A string slice representing the prompt in JSON format.
271 ///
272 /// # Returns
273 ///
274 /// A [`Prompt`] object on success, or an error.
275 pub async fn post_prompt_str(&self, prompt: &str) -> ClientResult<Prompt> {
276 let prompt = serde_json::from_str::<Value>(prompt)?;
277 self.post_prompt_value(&prompt).await
278 }
279
280 /// Sends a prompt from any serializable data.
281 ///
282 /// Converts the provided data into JSON and calls
283 /// [`Self::post_prompt_value`] to send the prompt.
284 ///
285 /// # Parameters
286 ///
287 /// - `prompt`: A reference to any data that implements [`Serialize`].
288 ///
289 /// # Returns
290 ///
291 /// A [`Prompt`] object on success, or an error.
292 pub async fn post_prompt<T: Serialize>(&self, prompt: &T) -> ClientResult<Prompt> {
293 let prompt = serde_json::to_value(prompt)?;
294 self.post_prompt_value(&prompt).await
295 }
296
297 /// Sends a prompt in JSON format.
298 ///
299 /// Constructs the request payload (including the client ID and prompt data)
300 /// and sends a POST request to the `prompt` endpoint.
301 ///
302 /// # Parameters
303 ///
304 /// - `prompt`: A JSON value representing the prompt data.
305 ///
306 /// # Returns
307 ///
308 /// A [`Prompt`] object on success, or an error.
309 pub async fn post_prompt_value(&self, prompt: &Value) -> ClientResult<Prompt> {
310 let data = json!({"client_id": &self.client_id, "prompt": prompt});
311 let resp = self
312 .http_client
313 .post(self.base_url.join("prompt")?)
314 .json(&data)
315 .send()
316 .await?;
317 let resp = Self::error_for_status(resp).await?;
318 Ok(resp.json().await?)
319 }
320
321 /// Uploads an image.
322 ///
323 /// Constructs a multipart form containing the image data and file
324 /// information, then sends a POST request to the `upload/image` endpoint.
325 ///
326 /// # Parameters
327 ///
328 /// - `body`: The image data, convertible into a [`Body`].
329 /// - `info`: A [`FileInfo`] object containing details about the image file.
330 /// - `overwrite`: A boolean indicating whether to overwrite an existing
331 /// file.
332 ///
333 /// # Returns
334 ///
335 /// An updated [`FileInfo`] object on success, or an error.
336 pub async fn upload_image(
337 &self, body: impl Into<Body>, info: &FileInfo, overwrite: bool,
338 ) -> ClientResult<FileInfo> {
339 let part = multipart::Part::stream(body).file_name(info.filename.to_string());
340 let mut form = multipart::Form::new()
341 .part("image", part)
342 .text("overwrite", overwrite.to_string())
343 .text("type", info.r#type.to_string());
344 if !info.subfolder.is_empty() {
345 form = form.text("subfolder", info.subfolder.to_string());
346 }
347
348 let resp = self
349 .http_client
350 .post(self.base_url.join("upload/image")?)
351 .multipart(form)
352 .send()
353 .await?;
354
355 let resp = Self::error_for_status(resp).await?;
356 Ok(resp.json().await?)
357 }
358
359 /// Checks the HTTP response status code and returns an error if it
360 /// indicates failure.
361 ///
362 /// If the response status is a client or server error, this method attempts
363 /// to parse the response body as JSON. If parsing fails, it returns the
364 /// body as text.
365 ///
366 /// # Parameters
367 ///
368 /// - `resp`: The HTTP response to evaluate.
369 ///
370 /// # Returns
371 ///
372 /// The original response if the status is successful, or an error if the
373 /// status indicates a failure.
374 async fn error_for_status(resp: Response) -> ClientResult<Response> {
375 let status = resp.status();
376 if status.is_client_error() || status.is_server_error() {
377 let body = resp.text().await?;
378 let body = match serde_json::from_str::<Value>(&body) {
379 Ok(value) => ApiBody::Json(value),
380 Err(_) => ApiBody::Text(body),
381 };
382 Err(ApiError { status, body }.into())
383 } else {
384 Ok(resp)
385 }
386 }
387}
388
389/// A structure representing the event stream received via a websocket
390/// connection.
391///
392/// This stream continuously processes events from the ComfyUI service.
393pub struct EventStream {
394 stream_handle: JoinHandle<()>,
395 rx_stream: ReceiverStream<ClientResult<Event>>,
396}
397
398impl EventStream {
399 /// Handles a single websocket message and attempts to parse it as an
400 /// [`Event`].
401 ///
402 /// For text messages, it tries to deserialize the message into an
403 /// [`Event`]. If the deserialization fails, it wraps the message as
404 /// [`Event::Unknown`]. Other message types are ignored.
405 ///
406 /// # Parameters
407 ///
408 /// - `msg`: A result containing a [`Message`] from the websocket.
409 ///
410 /// # Returns
411 ///
412 /// An `Option<Event>` wrapped in a `ClientResult`. Returns `None` for
413 /// unsupported message types.
414 fn handle_message(msg: tungstenite::Result<Message>) -> ClientResult<Option<Event>> {
415 let msg = msg?;
416 match msg {
417 Message::Text(b) => {
418 let value = serde_json::from_slice::<Value>(b.as_bytes())?;
419 match serde_json::from_value::<Event>(value.clone()) {
420 Ok(ev) => Ok(Some(ev)),
421 Err(_) => Ok(Some(Event::Unknown(value))),
422 }
423 }
424 _ => Ok(None),
425 }
426 }
427}
428
429impl Drop for EventStream {
430 /// When the [`EventStream`] is dropped, abort the associated websocket
431 /// handling task.
432 fn drop(&mut self) {
433 self.stream_handle.abort();
434 }
435}
436
437impl Deref for EventStream {
438 type Target = ReceiverStream<ClientResult<Event>>;
439
440 /// Allows access to the inner [`ReceiverStream`] containing the events.
441 fn deref(&self) -> &Self::Target {
442 &self.rx_stream
443 }
444}
445
446impl DerefMut for EventStream {
447 /// Allows mutable access to the inner [`ReceiverStream`].
448 fn deref_mut(&mut self) -> &mut Self::Target {
449 &mut self.rx_stream
450 }
451}