// worker/ai.rs

1use crate::streams::ByteStream;
2use crate::{env::EnvBinding, send::SendFuture};
3use crate::{Error, Result};
4use serde::de::DeserializeOwned;
5use serde::Serialize;
6use wasm_bindgen::{JsCast, JsValue};
7use wasm_bindgen_futures::JsFuture;
8use web_sys::ReadableStream;
9use worker_sys::Ai as AiSys;
10
11/// Enables access to Workers AI functionality.
12#[derive(Debug)]
13pub struct Ai(AiSys);
14
15impl Ai {
16    /// Execute a Workers AI operation using the specified model.
17    /// Various forms of the input are documented in the Workers
18    /// AI documentation.
19    pub async fn run<T: Serialize, U: DeserializeOwned>(
20        &self,
21        model: impl AsRef<str>,
22        input: T,
23    ) -> Result<U> {
24        let fut = SendFuture::new(JsFuture::from(
25            self.0
26                .run(model.as_ref(), serde_wasm_bindgen::to_value(&input)?),
27        ));
28        match fut.await {
29            Ok(output) => Ok(serde_wasm_bindgen::from_value(output)?),
30            Err(err) => Err(Error::from(err)),
31        }
32    }
33
34    /// Execute a Workers AI operation that returns binary data as a [`ByteStream`].
35    ///
36    /// This method is designed for AI models that return raw bytes, such as:
37    /// - Image generation models (e.g., Stable Diffusion)
38    /// - Text-to-speech models
39    /// - Any other model that returns binary output
40    ///
41    /// The returned [`ByteStream`] implements [`Stream`](futures_util::Stream) and can be:
42    /// - Streamed directly to a [`Response`] using [`Response::from_stream`]
43    /// - Collected into a `Vec<u8>` by iterating over the chunks
44    ///
45    /// # Examples
46    ///
47    /// ## Streaming directly to a response (recommended)
48    ///
49    /// This approach is more memory-efficient as it doesn't buffer the entire
50    /// response in memory:
51    ///
52    /// ```ignore
53    /// use worker::*;
54    /// use serde::Serialize;
55    ///
56    /// #[derive(Serialize)]
57    /// struct ImageGenRequest {
58    ///     prompt: String,
59    /// }
60    ///
61    /// async fn generate_image(env: &Env) -> Result<Response> {
62    ///     let ai = env.ai("AI")?;
63    ///     let request = ImageGenRequest {
64    ///         prompt: "a beautiful sunset".to_string(),
65    ///     };
66    ///     let stream = ai.run_bytes(
67    ///         "@cf/stabilityai/stable-diffusion-xl-base-1.0",
68    ///         &request
69    ///     ).await?;
70    ///
71    ///     // Stream directly to the response
72    ///     let mut response = Response::from_stream(stream)?;
73    ///     response.headers_mut().set("Content-Type", "image/png")?;
74    ///     Ok(response)
75    /// }
76    /// ```
77    ///
78    /// ## Collecting into bytes
79    ///
80    /// Use this approach if you need to inspect or modify the bytes before sending:
81    ///
82    /// ```ignore
83    /// use worker::*;
84    /// use serde::Serialize;
85    /// use futures_util::StreamExt;
86    ///
87    /// #[derive(Serialize)]
88    /// struct ImageGenRequest {
89    ///     prompt: String,
90    /// }
91    ///
92    /// async fn generate_image(env: &Env) -> Result<Response> {
93    ///     let ai = env.ai("AI")?;
94    ///     let request = ImageGenRequest {
95    ///         prompt: "a beautiful sunset".to_string(),
96    ///     };
97    ///     let mut stream = ai.run_bytes(
98    ///         "@cf/stabilityai/stable-diffusion-xl-base-1.0",
99    ///         &request
100    ///     ).await?;
101    ///
102    ///     // Collect all chunks into a Vec<u8>
103    ///     let mut bytes = Vec::new();
104    ///     while let Some(chunk) = stream.next().await {
105    ///         bytes.extend_from_slice(&chunk?);
106    ///     }
107    ///
108    ///     let mut response = Response::from_bytes(bytes)?;
109    ///     response.headers_mut().set("Content-Type", "image/png")?;
110    ///     Ok(response)
111    /// }
112    /// ```
113    pub async fn run_bytes<T: Serialize>(
114        &self,
115        model: impl AsRef<str>,
116        input: T,
117    ) -> Result<ByteStream> {
118        let fut = SendFuture::new(JsFuture::from(
119            self.0
120                .run(model.as_ref(), serde_wasm_bindgen::to_value(&input)?),
121        ));
122        match fut.await {
123            Ok(output) => {
124                if output.is_instance_of::<ReadableStream>() {
125                    let stream = ReadableStream::unchecked_from_js(output);
126                    Ok(ByteStream::from(stream))
127                } else {
128                    Err(Error::RustError(
129                        "AI model did not return binary data. Use run() for non-binary responses."
130                            .into(),
131                    ))
132                }
133            }
134            Err(err) => Err(Error::from(err)),
135        }
136    }
137}
138
139unsafe impl Sync for Ai {}
140unsafe impl Send for Ai {}
141
142impl From<AiSys> for Ai {
143    fn from(inner: AiSys) -> Self {
144        Self(inner)
145    }
146}
147
148impl AsRef<JsValue> for Ai {
149    fn as_ref(&self) -> &JsValue {
150        &self.0
151    }
152}
153
154impl From<Ai> for JsValue {
155    fn from(database: Ai) -> Self {
156        JsValue::from(database.0)
157    }
158}
159
160impl JsCast for Ai {
161    fn instanceof(val: &JsValue) -> bool {
162        val.is_instance_of::<AiSys>()
163    }
164
165    fn unchecked_from_js(val: JsValue) -> Self {
166        Self(val.into())
167    }
168
169    fn unchecked_from_js_ref(val: &JsValue) -> &Self {
170        unsafe { &*(val as *const JsValue as *const Self) }
171    }
172}
173
174impl EnvBinding for Ai {
175    const TYPE_NAME: &'static str = "Ai";
176
177    fn get(val: JsValue) -> Result<Self> {
178        let obj = js_sys::Object::from(val);
179        if obj.constructor().name() == Self::TYPE_NAME {
180            Ok(obj.unchecked_into())
181        } else {
182            Err(format!(
183                "Binding cannot be cast to the type {} from {}",
184                Self::TYPE_NAME,
185                obj.constructor().name()
186            )
187            .into())
188        }
189    }
190}