comfyui_client/lib.rs
1#![warn(rust_2018_idioms, missing_docs)]
2#![warn(clippy::dbg_macro, clippy::print_stdout)]
3#![doc = include_str!("../README.md")]
4
5/// Module containing error definitions.
6mod errors;
7/// Module containing metadata such as prompt and file information.
8pub mod meta;
9
10pub use crate::errors::{ClientError, ClientResult};
11use crate::meta::{FileInfo, PromptInfo};
12use bytes::Bytes;
13use cfg_if::cfg_if;
14use errors::{ApiBody, ApiError};
15use futures_util::StreamExt;
16use meta::{Event, History, Prompt};
17use reqwest::{
18 Body, IntoUrl, Response,
19 multipart::{self},
20};
21use serde::Serialize;
22use serde_json::{Value, json};
23use std::{
24 collections::HashMap,
25 ops::{Deref, DerefMut},
26};
27use tokio::{sync::mpsc, task::JoinHandle};
28use tokio_stream::wrappers::ReceiverStream;
29use tokio_tungstenite::{
30 connect_async,
31 tungstenite::{self, Message},
32};
33use url::Url;
34use uuid::Uuid;
35
36/// A builder for creating a [`ComfyUIClient`] instance.
37///
38/// This builder helps initialize the client with the provided base URL and sets
39/// up a websocket connection to stream events.
40pub struct ClientBuilder {
41 base_url: Url,
42 channel_bound: usize,
43}
44
45impl ClientBuilder {
46 /// Creates a new [`ClientBuilder`] instance.
47 ///
48 /// # Parameters
49 ///
50 /// - `base_url`: The base URL of the ComfyUI service.
51 ///
52 /// # Returns
53 ///
54 /// A new instance of [`ClientBuilder`] wrapped in a `ClientResult`, or an
55 /// error if the URL is invalid.
56 pub fn new(base_url: impl IntoUrl) -> ClientResult<Self> {
57 Ok(Self {
58 base_url: base_url.into_url()?,
59 channel_bound: 100,
60 })
61 }
62
63 /// Builds the [`ComfyUIClient`] along with an associated [`EventStream`].
64 ///
65 /// This method establishes a websocket connection and spawns an
66 /// asynchronous task to process incoming messages.
67 ///
68 /// # Returns
69 ///
70 /// A tuple containing the [`ComfyUIClient`] and [`EventStream`] on success,
71 /// or an error.
72 pub async fn build(self) -> ClientResult<(ComfyUIClient, EventStream)> {
73 let base_url = self.base_url;
74 let http_client = reqwest::Client::new();
75 let client_id = Uuid::new_v4().to_string();
76
77 let (msg_tx, msg_rx) = mpsc::channel(self.channel_bound);
78
79 let ws_url = Self::generate_websocket_url(base_url.clone(), &client_id)?;
80 let (stream, _) = if ws_url.scheme() == "wss" {
81 cfg_if! {
82 if #[cfg(feature = "rustls")] {
83 let root_store = rustls::RootCertStore {
84 roots: webpki_roots::TLS_SERVER_ROOTS.into(),
85 };
86 let config = rustls::ClientConfig::builder()
87 .with_root_certificates(root_store)
88 .with_no_client_auth();
89
90 tokio_tungstenite::connect_async_tls_with_config(
91 ws_url,
92 None,
93 false,
94 Some(tokio_tungstenite::Connector::Rustls(std::sync::Arc::new(config))),
95 )
96 .await?
97 } else {
98 connect_async(ws_url).await?
99 }
100 }
101 } else {
102 connect_async(ws_url).await?
103 };
104
105 let stream_handle = tokio::spawn(async move {
106 let (_, mut read_stream) = stream.split();
107 while let Some(msg) = read_stream.next().await {
108 let msg = EventStream::handle_message(msg);
109 let Some(msg) = msg.transpose() else {
110 continue;
111 };
112 if msg_tx.send(msg).await.is_err() {
113 break;
114 }
115 }
116 });
117
118 let rx_stream = ReceiverStream::new(msg_rx);
119
120 let client = ComfyUIClient {
121 base_url,
122 http_client,
123 client_id,
124 };
125
126 let stream = EventStream {
127 stream_handle,
128 rx_stream,
129 };
130
131 Ok((client, stream))
132 }
133
134 /// Generates the websocket URL based on the base URL and client ID.
135 ///
136 /// This method changes the URL scheme to `wss` if the base URL uses HTTPS,
137 /// or `ws` otherwise, appends the `ws` path, and adds a query parameter
138 /// for `clientId`.
139 ///
140 /// # Parameters
141 ///
142 /// - `base_url`: The base URL of the ComfyUI service.
143 /// - `client_id`: The unique identifier for the client.
144 ///
145 /// # Returns
146 ///
147 /// The generated websocket URL on success, or an error if the URL cannot be
148 /// modified.
149 fn generate_websocket_url(base_url: Url, client_id: &str) -> ClientResult<Url> {
150 let mut ws_url = base_url;
151 let scheme = if ws_url.scheme() == "https" {
152 "wss"
153 } else {
154 "ws"
155 };
156 ws_url
157 .set_scheme(scheme)
158 .map_err(|_| ClientError::SetWsScheme)?;
159 ws_url = ws_url.join("ws")?;
160 ws_url.query_pairs_mut().append_pair("clientId", client_id);
161 Ok(ws_url)
162 }
163}
164
165/// A client for interacting with the ComfyUI service.
166///
167/// This client provides methods to fetch history, prompts, views, and to upload
168/// images.
169pub struct ComfyUIClient {
170 client_id: String,
171 base_url: Url,
172 http_client: reqwest::Client,
173}
174
175impl ComfyUIClient {
176 /// Retrieves the history for a specified prompt.
177 ///
178 /// Sends a GET request to the `history/{prompt_id}` endpoint and parses the
179 /// returned history data.
180 ///
181 /// # Parameters
182 ///
183 /// - `prompt_id`: The ID of the prompt whose history is being requested.
184 ///
185 /// # Returns
186 ///
187 /// An optional [`History`] object wrapped in a `ClientResult`. Returns
188 /// `None` if the history is not found.
189 pub async fn get_history(&self, prompt_id: &str) -> ClientResult<Option<History>> {
190 let resp = self
191 .http_client
192 .get(self.base_url.join(&format!("history/{prompt_id}"))?)
193 .send()
194 .await?;
195 let resp = Self::error_for_status(resp).await?;
196 let mut histories = resp.json::<HashMap<String, History>>().await?;
197 Ok(histories.remove(prompt_id))
198 }
199
200 /// Retrieves the current prompt information.
201 ///
202 /// Sends a GET request to the `prompt` endpoint and returns the parsed
203 /// [`PromptInfo`] data.
204 ///
205 /// # Returns
206 ///
207 /// A [`PromptInfo`] object on success, or an error.
208 pub async fn get_prompt(&self) -> ClientResult<PromptInfo> {
209 let resp = self
210 .http_client
211 .get(self.base_url.join("prompt")?)
212 .send()
213 .await?;
214 let resp = Self::error_for_status(resp).await?;
215 Ok(resp.json().await?)
216 }
217
218 /// Retrieves view data corresponding to the provided file information.
219 ///
220 /// Sends a GET request to the `view` endpoint, including the file
221 /// information as query parameters.
222 ///
223 /// # Parameters
224 ///
225 /// - `file_info`: A [`FileInfo`] object containing details about the file.
226 ///
227 /// # Returns
228 ///
229 /// The response as a [`Bytes`] object on success, or an error.
230 pub async fn get_view(&self, file_info: &FileInfo) -> ClientResult<Bytes> {
231 let resp = self
232 .http_client
233 .get(self.base_url.join("view")?)
234 .query(file_info)
235 .send()
236 .await?;
237 let resp = Self::error_for_status(resp).await?;
238 Ok(resp.bytes().await?)
239 }
240
241 /// Sends a prompt in string format.
242 ///
243 /// Parses the input string as JSON and calls [`Self::post_prompt_value`] to
244 /// send the prompt.
245 ///
246 /// # Parameters
247 ///
248 /// - `prompt`: A string slice representing the prompt in JSON format.
249 ///
250 /// # Returns
251 ///
252 /// A [`Prompt`] object on success, or an error.
253 pub async fn post_prompt_str(&self, prompt: &str) -> ClientResult<Prompt> {
254 let prompt = serde_json::from_str::<Value>(prompt)?;
255 self.post_prompt_value(&prompt).await
256 }
257
258 /// Sends a prompt from any serializable data.
259 ///
260 /// Converts the provided data into JSON and calls
261 /// [`Self::post_prompt_value`] to send the prompt.
262 ///
263 /// # Parameters
264 ///
265 /// - `prompt`: A reference to any data that implements [`Serialize`].
266 ///
267 /// # Returns
268 ///
269 /// A [`Prompt`] object on success, or an error.
270 pub async fn post_prompt<T: Serialize>(&self, prompt: &T) -> ClientResult<Prompt> {
271 let prompt = serde_json::to_value(prompt)?;
272 self.post_prompt_value(&prompt).await
273 }
274
275 /// Sends a prompt in JSON format.
276 ///
277 /// Constructs the request payload (including the client ID and prompt data)
278 /// and sends a POST request to the `prompt` endpoint.
279 ///
280 /// # Parameters
281 ///
282 /// - `prompt`: A JSON value representing the prompt data.
283 ///
284 /// # Returns
285 ///
286 /// A [`Prompt`] object on success, or an error.
287 pub async fn post_prompt_value(&self, prompt: &Value) -> ClientResult<Prompt> {
288 let data = json!({"client_id": &self.client_id, "prompt": prompt});
289 let resp = self
290 .http_client
291 .post(self.base_url.join("prompt")?)
292 .json(&data)
293 .send()
294 .await?;
295 let resp = Self::error_for_status(resp).await?;
296 Ok(resp.json().await?)
297 }
298
299 /// Uploads an image.
300 ///
301 /// Constructs a multipart form containing the image data and file
302 /// information, then sends a POST request to the `upload/image` endpoint.
303 ///
304 /// # Parameters
305 ///
306 /// - `body`: The image data, convertible into a [`Body`].
307 /// - `info`: A [`FileInfo`] object containing details about the image file.
308 /// - `overwrite`: A boolean indicating whether to overwrite an existing
309 /// file.
310 ///
311 /// # Returns
312 ///
313 /// An updated [`FileInfo`] object on success, or an error.
314 pub async fn upload_image(
315 &self, body: impl Into<Body>, info: &FileInfo, overwrite: bool,
316 ) -> ClientResult<FileInfo> {
317 let part = multipart::Part::stream(body).file_name(info.filename.to_string());
318 let mut form = multipart::Form::new()
319 .part("image", part)
320 .text("overwrite", overwrite.to_string())
321 .text("type", info.r#type.to_string());
322 if !info.subfolder.is_empty() {
323 form = form.text("subfolder", info.subfolder.to_string());
324 }
325
326 let resp = self
327 .http_client
328 .post(self.base_url.join("upload/image")?)
329 .multipart(form)
330 .send()
331 .await?;
332
333 let resp = Self::error_for_status(resp).await?;
334 Ok(resp.json().await?)
335 }
336
337 /// Checks the HTTP response status code and returns an error if it
338 /// indicates failure.
339 ///
340 /// If the response status is a client or server error, this method attempts
341 /// to parse the response body as JSON. If parsing fails, it returns the
342 /// body as text.
343 ///
344 /// # Parameters
345 ///
346 /// - `resp`: The HTTP response to evaluate.
347 ///
348 /// # Returns
349 ///
350 /// The original response if the status is successful, or an error if the
351 /// status indicates a failure.
352 async fn error_for_status(resp: Response) -> ClientResult<Response> {
353 let status = resp.status();
354 if status.is_client_error() || status.is_server_error() {
355 let body = resp.text().await?;
356 let body = match serde_json::from_str::<Value>(&body) {
357 Ok(value) => ApiBody::Json(value),
358 Err(_) => ApiBody::Text(body),
359 };
360 Err(ApiError { status, body }.into())
361 } else {
362 Ok(resp)
363 }
364 }
365}
366
367/// A structure representing the event stream received via a websocket
368/// connection.
369///
370/// This stream continuously processes events from the ComfyUI service.
371pub struct EventStream {
372 stream_handle: JoinHandle<()>,
373 rx_stream: ReceiverStream<ClientResult<Event>>,
374}
375
376impl EventStream {
377 /// Handles a single websocket message and attempts to parse it as an
378 /// [`Event`].
379 ///
380 /// For text messages, it tries to deserialize the message into an
381 /// [`Event`]. If the deserialization fails, it wraps the message as
382 /// [`Event::Unknown`]. Other message types are ignored.
383 ///
384 /// # Parameters
385 ///
386 /// - `msg`: A result containing a [`Message`] from the websocket.
387 ///
388 /// # Returns
389 ///
390 /// An `Option<Event>` wrapped in a `ClientResult`. Returns `None` for
391 /// unsupported message types.
392 fn handle_message(msg: tungstenite::Result<Message>) -> ClientResult<Option<Event>> {
393 let msg = msg?;
394 match msg {
395 Message::Text(b) => {
396 let value = serde_json::from_slice::<Value>(b.as_bytes())?;
397 match serde_json::from_slice::<Event>(b.as_bytes()) {
398 Ok(ev) => Ok(Some(ev)),
399 Err(_) => Ok(Some(Event::Unknown(value))),
400 }
401 }
402 _ => Ok(None),
403 }
404 }
405}
406
407impl Drop for EventStream {
408 /// When the [`EventStream`] is dropped, abort the associated websocket
409 /// handling task.
410 fn drop(&mut self) {
411 self.stream_handle.abort();
412 }
413}
414
415impl Deref for EventStream {
416 type Target = ReceiverStream<ClientResult<Event>>;
417
418 /// Allows access to the inner [`ReceiverStream`] containing the events.
419 fn deref(&self) -> &Self::Target {
420 &self.rx_stream
421 }
422}
423
424impl DerefMut for EventStream {
425 /// Allows mutable access to the inner [`ReceiverStream`].
426 fn deref_mut(&mut self) -> &mut Self::Target {
427 &mut self.rx_stream
428 }
429}