cloudflare/endpoints/ai/
execute_model.rs

1use serde::{Deserialize, Serialize};
2
3use crate::framework::endpoint::RequestBody;
4use crate::framework::response::ApiSuccess;
5use crate::framework::{
6    endpoint::{EndpointSpec, Method},
7    response::ApiResult,
8};
9
10/// Get an inference from a model.
11#[derive(Clone, Debug, Serialize, Deserialize)]
12pub struct ExecuteModel<'a> {
13    pub account_identifier: &'a str,
14    pub model_name: &'a str,
15    pub params: ExecuteModelParams,
16}
17
18impl EndpointSpec for ExecuteModel<'_> {
19    type JsonResponse = ExecuteModelResult;
20    type ResponseType = ApiSuccess<Self::JsonResponse>;
21
22    fn method(&self) -> Method {
23        Method::POST
24    }
25
26    fn path(&self) -> String {
27        format!(
28            "accounts/{}/ai/run/{}",
29            self.account_identifier, self.model_name
30        )
31    }
32
33    #[inline]
34    fn body(&self) -> Option<RequestBody> {
35        let body = serde_json::to_string(&self.params).unwrap();
36        Some(RequestBody::Json(body))
37    }
38}
39
40/// Represents various inference tasks supported by Workers AI.
41#[derive(Clone, Debug, Serialize, Deserialize)]
42#[serde(untagged)]
43pub enum ExecuteModelParams {
44    /// Text Classification task.
45    ///
46    /// Classifies the input text into predefined categories.
47    TextClassification {
48        /// The text that you want to classify.
49        /// Must be at least 1 character long.
50        text: String,
51    },
52
53    /// Text-to-Image generation task.
54    ///
55    /// Generates an image based on the provided text description.
56    TextToImage(TextToImageParams),
57
58    /// Text-to-Speech generation task.
59    ///
60    /// Converts text into speech.
61    TextToSpeech(TextToSpeechParams),
62
63    /// Text Embedding generation task.
64    ///
65    /// Converts text into numerical embeddings.
66    TextEmbeddings {
67        /// The array of texts to embed.
68        text: Vec<String>,
69    },
70
71    /// Automatic Speech Recognition task.
72    ///
73    /// Converts audio into text, with optional translation.
74    AutomaticSpeechRecognition(AutomaticSpeechRecognitionParams),
75
76    /// Image Classification task.
77    ///
78    /// Classifies an image into predefined categories.
79    ImageClassification {
80        /// An array of integers representing the image data (8-bit unsigned integer values).
81        image: Vec<u8>,
82    },
83
84    /// Object Detection task.
85    ///
86    /// Detects objects in the input image.
87    ObjectDetection {
88        /// An array of integers representing the image data (8-bit unsigned integer values).
89        image: Vec<u8>,
90    },
91
92    /// General Prompt task.
93    ///
94    /// Generates a response based on the provided input text.
95    Prompt(PromptParams),
96
97    /// Messages task.
98    ///
99    /// Handles conversation-based input and output.
100    Messages(MessagesParams),
101
102    /// Translation task.
103    /// Translates text into the specified language.
104    Translation(TranslationParams),
105
106    /// Summarization task.
107    /// Summarizes the provided input text.
108    Summarization(SummarizationParams),
109
110    /// Image-to-Text task.
111    /// Converts an image into text-based descriptions.
112    ImageToText(ImageToTextParams),
113}
114
115/// Parameters for the `TextToImage` task.
116#[derive(Clone, Debug, Default, Serialize, Deserialize)]
117pub struct TextToImageParams {
118    /// A text description of the image to generate.
119    /// Must be at least 1 character long.
120    pub prompt: String,
121
122    /// Controls how closely the generated image should adhere to the prompt.
123    pub guidance: Option<f64>,
124
125    /// The height of the generated image in pixels. Must be between 256 and 2048.
126    pub height: Option<u32>,
127
128    /// An array of integers representing the image data for img2img tasks.
129    pub image: Option<Vec<u8>>,
130
131    /// A base64-encoded string of the input image for img2img tasks.
132    pub image_b64: Option<String>,
133
134    /// An array of integers representing mask image data for inpainting.
135    pub mask: Option<Vec<u8>>,
136
137    /// Text describing elements to avoid in the generated image.
138    pub negative_prompt: Option<String>,
139
140    /// The number of diffusion steps (max 20).
141    pub num_steps: Option<u32>,
142
143    /// Random seed for reproducibility.
144    pub seed: Option<u64>,
145
146    /// Strength of transformation for img2img tasks (0.0 to 1.0).
147    pub strength: Option<f64>,
148
149    /// The width of the generated image in pixels. Must be between 256 and 2048.
150    pub width: Option<u32>,
151}
152
153/// Parameters for the `TextToSpeech` task.
154#[derive(Clone, Debug, Default, Serialize, Deserialize)]
155pub struct TextToSpeechParams {
156    /// The text to generate speech from.
157    /// Must be at least 1 character long.
158    pub prompt: String,
159
160    /// The language for the generated speech. Defaults to "en".
161    pub lang: Option<String>,
162}
163
164/// Parameters for the `AutomaticSpeechRecognition` task.
165#[derive(Clone, Debug, Default, Serialize, Deserialize)]
166pub struct AutomaticSpeechRecognitionParams {
167    /// An array of integers representing the audio data (8-bit unsigned integer values).
168    pub audio: Vec<u8>,
169
170    /// The language of the recorded audio.
171    pub source_lang: Option<String>,
172
173    /// The target language for translation (currently only English is supported).
174    pub target_lang: Option<String>,
175}
176
177/// Parameters for the `Prompt` task.
178#[derive(Clone, Debug, Default, Serialize, Deserialize)]
179pub struct PromptParams {
180    /// The input text prompt for the model.
181    /// Must be between `1` and `131072` characters long.
182    pub prompt: String,
183
184    /// Decreases the likelihood of repeating the same lines verbatim (0 to 2).
185    #[serde(skip_serializing_if = "Option::is_none")]
186    pub frequency_penalty: Option<f64>,
187
188    /// Name of the LoRA (Low-Rank Adaptation) model to fine-tune the base model.
189    #[serde(skip_serializing_if = "Option::is_none")]
190    pub lora: Option<String>,
191
192    /// The maximum number of tokens to generate in the response.
193    #[serde(skip_serializing_if = "Option::is_none")]
194    pub max_tokens: Option<u32>,
195
196    /// Increases the likelihood of introducing new topics (0 to 2).
197    #[serde(skip_serializing_if = "Option::is_none")]
198    pub presence_penalty: Option<f64>,
199
200    /// If `true`, bypasses chat templates and uses the model's raw format.
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub raw: Option<bool>,
203
204    /// Penalty for repeated tokens (`0` to `2`).
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub repetition_penalty: Option<f64>,
207
208    /// Random seed for reproducibility (`1` to `9999999999`).
209    #[serde(skip_serializing_if = "Option::is_none")]
210    pub seed: Option<u64>,
211
212    /// If `true`, streams the response incrementally using SSE.
213    #[serde(skip_serializing_if = "Option::is_none")]
214    pub stream: Option<bool>,
215
216    /// Controls the randomness of the output (`0` to `5`).
217    #[serde(skip_serializing_if = "Option::is_none")]
218    pub temperature: Option<f64>,
219
220    /// Limits the AI to top 'k' most probable words (`1` to `50`).
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub top_k: Option<u32>,
223
224    /// Adjusts creativity of responses (`0` to `2`).
225    #[serde(skip_serializing_if = "Option::is_none")]
226    pub top_p: Option<f64>,
227}
228
229/// Parameters for the `Messages` task.
230#[derive(Clone, Debug, Default, Serialize, Deserialize)]
231pub struct MessagesParams {
232    /// The conversation history as an array of message objects.
233    pub messages: Vec<Message>,
234
235    /// Decreases the likelihood of repeating the same lines verbatim (`0` to `2`).
236    #[serde(skip_serializing_if = "Option::is_none")]
237    pub frequency_penalty: Option<f64>,
238
239    /// An array of functions or tools available for the assistant.
240    #[serde(skip_serializing_if = "Option::is_none")]
241    pub functions: Option<Vec<AssistantFunction>>,
242
243    /// The maximum number of tokens to generate in the response.
244    #[serde(skip_serializing_if = "Option::is_none")]
245    pub max_tokens: Option<u32>,
246
247    /// Increases the likelihood of introducing new topics (`0` to `2`).
248    #[serde(skip_serializing_if = "Option::is_none")]
249    pub presence_penalty: Option<f64>,
250
251    /// Penalty for repeated tokens (`0` to `2`).
252    #[serde(skip_serializing_if = "Option::is_none")]
253    pub repetition_penalty: Option<f64>,
254
255    /// Random seed for reproducibility (`1` to `9999999999`).
256    #[serde(skip_serializing_if = "Option::is_none")]
257    pub seed: Option<u64>,
258
259    /// If `true`, streams the response incrementally using SSE.
260    #[serde(skip_serializing_if = "Option::is_none")]
261    pub stream: Option<bool>,
262
263    /// Controls the randomness of the output (`0` to `5`).
264    #[serde(skip_serializing_if = "Option::is_none")]
265    pub temperature: Option<f64>,
266
267    /// A list of tools available for the assistant.
268    #[serde(skip_serializing_if = "Option::is_none")]
269    pub tools: Option<Vec<AssistantTool>>,
270
271    /// Limits the AI to top `k` most probable words (`1` to `50`).
272    #[serde(skip_serializing_if = "Option::is_none")]
273    pub top_k: Option<u32>,
274
275    /// Adjusts creativity of responses (`0` to `2`).
276    #[serde(skip_serializing_if = "Option::is_none")]
277    pub top_p: Option<f64>,
278}
279
280/// Represents a single message in a conversation.
281#[derive(Clone, Debug, Serialize, Deserialize)]
282pub struct Message {
283    /// The content of the message.
284    pub content: String,
285
286    /// The role of the message sender (e.g., "user" or "assistant").
287    pub role: MessageRole,
288}
289
290impl Message {
291    pub fn system(content: String) -> Self {
292        Message {
293            content,
294            role: MessageRole::System,
295        }
296    }
297
298    pub fn user(content: String) -> Self {
299        Message {
300            content,
301            role: MessageRole::User,
302        }
303    }
304
305    pub fn assistant(content: String) -> Self {
306        Message {
307            content,
308            role: MessageRole::Assistant,
309        }
310    }
311}
312
313#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
314pub enum MessageRole {
315    #[serde(rename = "system")]
316    System,
317    #[serde(rename = "user")]
318    User,
319    #[serde(rename = "assistant")]
320    Assistant,
321}
322
323impl ToString for MessageRole {
324    fn to_string(&self) -> String {
325        match self {
326            MessageRole::System => "System".to_string(),
327            MessageRole::User => "User".to_string(),
328            MessageRole::Assistant => "Assistant".to_string(),
329        }
330    }
331}
332
333/// Represents a function or tool available for use by the assistant.
334#[derive(Clone, Debug, Serialize, Deserialize)]
335pub struct AssistantFunction {
336    /// The function code.
337    #[serde(skip_serializing_if = "Option::is_none")]
338    code: Option<String>,
339
340    /// The function name.
341    name: String,
342
343    /// The function parameters (if applicable).
344    #[serde(skip_serializing_if = "Option::is_none")]
345    parameters: Option<String>,
346}
347
348/// Represents a tool with additional details.
349#[derive(Clone, Debug, Serialize, Deserialize)]
350pub struct AssistantTool {
351    /// A description of the tool.
352    description: String,
353
354    /// The name of the tool.
355    name: String,
356
357    /// The parameters associated with the tool.
358    #[serde(skip_serializing_if = "Option::is_none")]
359    parameters: Option<String>,
360}
361
362/// Parameters for the `Translation` task.
363#[derive(Clone, Debug, Default, Serialize, Deserialize)]
364pub struct TranslationParams {
365    /// The target language code (e.g., `"es"` for Spanish).
366    pub target_lang: String,
367
368    /// The text to translate. Must be at least 1 character long.
369    pub text: String,
370
371    /// The source language code. Defaults to `"en"`.
372    #[serde(skip_serializing_if = "Option::is_none")]
373    pub source_lang: Option<String>,
374}
375
376/// Parameters for the `Summarization` task.
377#[derive(Clone, Debug, Default, Serialize, Deserialize)]
378pub struct SummarizationParams {
379    /// The text to summarize. Must be at least 1 character long.
380    pub input_text: String,
381
382    /// The maximum length of the generated summary in tokens.
383    #[serde(skip_serializing_if = "Option::is_none")]
384    pub max_length: Option<u32>,
385}
386
387/// Parameters for the `ImageToText` task.
388#[derive(Clone, Debug, Default, Serialize, Deserialize)]
389pub struct ImageToTextParams {
390    /// An array of integers representing the image data.
391    pub image: Vec<u8>,
392
393    /// The maximum number of tokens to generate in the response.
394    #[serde(skip_serializing_if = "Option::is_none")]
395    pub max_tokens: Option<u32>,
396
397    /// The input text prompt for the model.
398    #[serde(skip_serializing_if = "Option::is_none")]
399    pub prompt: Option<String>,
400
401    /// If `true`, bypasses chat templates and uses the model's raw format.
402    #[serde(skip_serializing_if = "Option::is_none")]
403    pub raw: Option<bool>,
404
405    /// Controls the randomness of the output; higher values produce more random results.
406    #[serde(skip_serializing_if = "Option::is_none")]
407    pub temperature: Option<f64>,
408}
409
410/// Enum representing various AI processing results, including text classification,
411/// text-to-image generation, audio generation, and more.
412#[derive(Clone, Debug, Deserialize, Serialize)]
413#[serde(untagged)]
414pub enum ExecuteModelResult {
415    /// Results of text classification, containing an array of classification results.
416    TextClassification(Vec<TextClassificationResult>),
417
418    /// The generated image in PNG format.
419    TextToImage(String),
420
421    /// The generated audio in MP3 format, base64-encoded.
422    Audio(AudioResult),
423
424    /// Text embeddings, containing a nested array of embedding values and their shape.
425    TextEmbeddings(TextEmbeddingsResult),
426
427    /// Results of automatic speech recognition.
428    AutomaticSpeechRecognition(AutomaticSpeechRecognitionResult),
429
430    /// Results of image classification, containing predicted categories and confidence scores.
431    ImageClassification(Vec<ImageClassificationResult>),
432
433    /// Results of object detection within an input image.
434    ObjectDetection(Vec<ObjectDetectionResult>),
435
436    /// Generated text response and tool calls from the model.
437    ResponseAndToolCallsResult(ResponseAndToolCallsResult),
438
439    /// Results of text translation into a target language.
440    Translation(TranslationResult),
441
442    /// Results of text summarization.
443    Summarization(SummarizationResult),
444
445    /// Generated description for an input image.
446    ImageToText(ImageToTextResult),
447}
448
449impl ApiResult for ExecuteModelResult {}
450
451/// Represents a single text classification result.
452#[derive(Clone, Debug, Deserialize, Serialize)]
453pub struct TextClassificationResult {
454    /// The classification label assigned to the text (e.g., `'POSITIVE'` or `'NEGATIVE'`).
455    pub label: String,
456
457    /// Confidence score indicating the likelihood of the label.
458    pub score: f64,
459}
460
461/// Represents the generated audio.
462#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
463pub struct AudioResult {
464    /// The generated audio in MP3 format, base64-encoded.
465    pub audio: String,
466}
467
468/// Represents text embeddings.
469///
470/// When the `ndarray` feature is enabled, the embeddings are automatically deserialized into an
471/// `ndarray::ArrayD<f64>`.
472#[derive(Clone, Debug, Deserialize, Serialize)]
473pub struct TextEmbeddingsResult {
474    #[cfg(feature = "ndarray")]
475    /// Embeddings of the requested text values.
476    pub data: ndarray::ArrayD<f64>,
477
478    #[cfg(not(feature = "ndarray"))]
479    /// Embeddings of the requested text values.
480    pub data: Vec<serde_json::Value>,
481
482    /// The shape of the embedding array.
483    pub shape: Vec<usize>,
484}
485
486/// Represents automatic speech recognition results.
487#[derive(Clone, Debug, Deserialize, Serialize)]
488pub struct AutomaticSpeechRecognitionResult {
489    /// The transcription of the audio.
490    pub text: String,
491
492    /// The transcription in VTT format.
493    #[serde(skip_serializing_if = "Option::is_none")]
494    pub vtt: Option<String>,
495
496    /// The word count of the transcription.
497    #[serde(skip_serializing_if = "Option::is_none")]
498    pub word_count: Option<usize>,
499
500    /// Array of words with timing information.
501    #[serde(default, skip_serializing_if = "Vec::is_empty")]
502    pub words: Vec<WordTiming>,
503}
504
505/// Represents timing information for words in an automatic speech recognition result.
506#[derive(Clone, Debug, Deserialize, Serialize)]
507pub struct WordTiming {
508    /// The start time of the word.
509    pub start: f64,
510
511    /// The end time of the word.
512    pub end: f64,
513
514    /// The word itself.
515    pub word: String,
516}
517
518/// Represents a single image classification result.
519#[derive(Clone, Debug, Deserialize, Serialize)]
520pub struct ImageClassificationResult {
521    /// The predicted category or class for the input image.
522    pub label: String,
523
524    /// Confidence score for the classification.
525    pub score: f64,
526}
527
528/// Represents a single object detection result.
529#[derive(Clone, Debug, Deserialize, Serialize)]
530pub struct ObjectDetectionResult {
531    /// The bounding box around the detected object.
532    #[serde(rename = "box")]
533    pub bounding_box: BoundingBox,
534
535    /// The class label or name of the detected object.
536    #[serde(skip_serializing_if = "Option::is_none")]
537    pub label: Option<String>,
538
539    /// Confidence score for the object detection.
540    pub score: f64,
541}
542
543/// Represents the bounding box coordinates for an object.
544#[derive(Clone, Debug, Deserialize, Serialize)]
545pub struct BoundingBox {
546    /// The minimum x-coordinate.
547    pub xmin: f64,
548
549    /// The maximum x-coordinate.
550    pub xmax: f64,
551
552    /// The minimum y-coordinate.
553    pub ymin: f64,
554
555    /// The maximum y-coordinate.
556    pub ymax: f64,
557}
558
559/// Represents a generated text response and tool calls from the model.
560#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
561pub struct ResponseAndToolCallsResult {
562    /// The generated text response.
563    pub response: String,
564
565    /// Array of tool call requests made during the response generation.
566    #[serde(default, skip_serializing_if = "Vec::is_empty")]
567    pub tool_calls: Vec<ToolCall>,
568    // TODO: Missing `usage` field
569}
570
571/// Represents a single tool call request during response generation.
572#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
573pub struct ToolCall {
574    /// The name of the tool.
575    pub name: String,
576
577    /// The arguments passed to the tool.
578    pub arguments: String,
579}
580
581/// Represents translation results.
582#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
583pub struct TranslationResult {
584    /// The translated text in the target language.
585    pub translated_text: String,
586}
587
588/// Represents summarization results.
589#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
590pub struct SummarizationResult {
591    /// The summarized text.
592    pub summary: String,
593}
594
595/// Represents a generated description for an input image.
596#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
597pub struct ImageToTextResult {
598    /// Generated description for an input image.
599    pub description: String,
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Covers the use-case showcased on the website's Workers AI beta: a body
    /// holding only a `response` key must deserialize into the
    /// `ResponseAndToolCallsResult` variant.
    #[test]
    fn test_deserialize_response_and_tool_calls_result() {
        let payload = r#"
        {"response":"\"A short story\""}
        "#;

        let parsed = serde_json::from_str::<ExecuteModelResult>(payload).unwrap();
        let is_expected_variant =
            matches!(parsed, ExecuteModelResult::ResponseAndToolCallsResult(_));
        assert!(is_expected_variant);
    }
}