worker/ai.rs
1use crate::streams::ByteStream;
2use crate::{env::EnvBinding, send::SendFuture};
3use crate::{Error, Result};
4use serde::de::DeserializeOwned;
5use serde::Serialize;
6use wasm_bindgen::{JsCast, JsValue};
7use wasm_bindgen_futures::JsFuture;
8use web_sys::ReadableStream;
9use worker_sys::Ai as AiSys;
10
11/// Enables access to Workers AI functionality.
12#[derive(Debug)]
13pub struct Ai(AiSys);
14
15impl Ai {
16 /// Execute a Workers AI operation using the specified model.
17 /// Various forms of the input are documented in the Workers
18 /// AI documentation.
19 pub async fn run<T: Serialize, U: DeserializeOwned>(
20 &self,
21 model: impl AsRef<str>,
22 input: T,
23 ) -> Result<U> {
24 let fut = SendFuture::new(JsFuture::from(
25 self.0
26 .run(model.as_ref(), serde_wasm_bindgen::to_value(&input)?),
27 ));
28 match fut.await {
29 Ok(output) => Ok(serde_wasm_bindgen::from_value(output)?),
30 Err(err) => Err(Error::from(err)),
31 }
32 }
33
34 /// Execute a Workers AI operation that returns binary data as a [`ByteStream`].
35 ///
36 /// This method is designed for AI models that return raw bytes, such as:
37 /// - Image generation models (e.g., Stable Diffusion)
38 /// - Text-to-speech models
39 /// - Any other model that returns binary output
40 ///
41 /// The returned [`ByteStream`] implements [`Stream`](futures_util::Stream) and can be:
42 /// - Streamed directly to a [`Response`] using [`Response::from_stream`]
43 /// - Collected into a `Vec<u8>` by iterating over the chunks
44 ///
45 /// # Examples
46 ///
47 /// ## Streaming directly to a response (recommended)
48 ///
49 /// This approach is more memory-efficient as it doesn't buffer the entire
50 /// response in memory:
51 ///
52 /// ```ignore
53 /// use worker::*;
54 /// use serde::Serialize;
55 ///
56 /// #[derive(Serialize)]
57 /// struct ImageGenRequest {
58 /// prompt: String,
59 /// }
60 ///
61 /// async fn generate_image(env: &Env) -> Result<Response> {
62 /// let ai = env.ai("AI")?;
63 /// let request = ImageGenRequest {
64 /// prompt: "a beautiful sunset".to_string(),
65 /// };
66 /// let stream = ai.run_bytes(
67 /// "@cf/stabilityai/stable-diffusion-xl-base-1.0",
68 /// &request
69 /// ).await?;
70 ///
71 /// // Stream directly to the response
72 /// let mut response = Response::from_stream(stream)?;
73 /// response.headers_mut().set("Content-Type", "image/png")?;
74 /// Ok(response)
75 /// }
76 /// ```
77 ///
78 /// ## Collecting into bytes
79 ///
80 /// Use this approach if you need to inspect or modify the bytes before sending:
81 ///
82 /// ```ignore
83 /// use worker::*;
84 /// use serde::Serialize;
85 /// use futures_util::StreamExt;
86 ///
87 /// #[derive(Serialize)]
88 /// struct ImageGenRequest {
89 /// prompt: String,
90 /// }
91 ///
92 /// async fn generate_image(env: &Env) -> Result<Response> {
93 /// let ai = env.ai("AI")?;
94 /// let request = ImageGenRequest {
95 /// prompt: "a beautiful sunset".to_string(),
96 /// };
97 /// let mut stream = ai.run_bytes(
98 /// "@cf/stabilityai/stable-diffusion-xl-base-1.0",
99 /// &request
100 /// ).await?;
101 ///
102 /// // Collect all chunks into a Vec<u8>
103 /// let mut bytes = Vec::new();
104 /// while let Some(chunk) = stream.next().await {
105 /// bytes.extend_from_slice(&chunk?);
106 /// }
107 ///
108 /// let mut response = Response::from_bytes(bytes)?;
109 /// response.headers_mut().set("Content-Type", "image/png")?;
110 /// Ok(response)
111 /// }
112 /// ```
113 pub async fn run_bytes<T: Serialize>(
114 &self,
115 model: impl AsRef<str>,
116 input: T,
117 ) -> Result<ByteStream> {
118 let fut = SendFuture::new(JsFuture::from(
119 self.0
120 .run(model.as_ref(), serde_wasm_bindgen::to_value(&input)?),
121 ));
122 match fut.await {
123 Ok(output) => {
124 if output.is_instance_of::<ReadableStream>() {
125 let stream = ReadableStream::unchecked_from_js(output);
126 Ok(ByteStream::from(stream))
127 } else {
128 Err(Error::RustError(
129 "AI model did not return binary data. Use run() for non-binary responses."
130 .into(),
131 ))
132 }
133 }
134 Err(err) => Err(Error::from(err)),
135 }
136 }
137}
138
139unsafe impl Sync for Ai {}
140unsafe impl Send for Ai {}
141
142impl From<AiSys> for Ai {
143 fn from(inner: AiSys) -> Self {
144 Self(inner)
145 }
146}
147
148impl AsRef<JsValue> for Ai {
149 fn as_ref(&self) -> &JsValue {
150 &self.0
151 }
152}
153
154impl From<Ai> for JsValue {
155 fn from(database: Ai) -> Self {
156 JsValue::from(database.0)
157 }
158}
159
160impl JsCast for Ai {
161 fn instanceof(val: &JsValue) -> bool {
162 val.is_instance_of::<AiSys>()
163 }
164
165 fn unchecked_from_js(val: JsValue) -> Self {
166 Self(val.into())
167 }
168
169 fn unchecked_from_js_ref(val: &JsValue) -> &Self {
170 unsafe { &*(val as *const JsValue as *const Self) }
171 }
172}
173
174impl EnvBinding for Ai {
175 const TYPE_NAME: &'static str = "Ai";
176
177 fn get(val: JsValue) -> Result<Self> {
178 let obj = js_sys::Object::from(val);
179 if obj.constructor().name() == Self::TYPE_NAME {
180 Ok(obj.unchecked_into())
181 } else {
182 Err(format!(
183 "Binding cannot be cast to the type {} from {}",
184 Self::TYPE_NAME,
185 obj.constructor().name()
186 )
187 .into())
188 }
189 }
190}