1use serde::{Deserialize, Serialize};
2
3use crate::framework::endpoint::RequestBody;
4use crate::framework::response::ApiSuccess;
5use crate::framework::{
6 endpoint::{EndpointSpec, Method},
7 response::ApiResult,
8};
9
/// Request that runs a model via `accounts/{account_identifier}/ai/run/{model_name}`.
///
/// `params` is serialized as the JSON request body.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExecuteModel<'a> {
    /// Account identifier interpolated into the request path.
    pub account_identifier: &'a str,
    /// Model name interpolated into the request path as-is.
    pub model_name: &'a str,
    /// Model inputs, sent as the JSON request body.
    pub params: ExecuteModelParams,
}
17
18impl EndpointSpec for ExecuteModel<'_> {
19 type JsonResponse = ExecuteModelResult;
20 type ResponseType = ApiSuccess<Self::JsonResponse>;
21
22 fn method(&self) -> Method {
23 Method::POST
24 }
25
26 fn path(&self) -> String {
27 format!(
28 "accounts/{}/ai/run/{}",
29 self.account_identifier, self.model_name
30 )
31 }
32
33 #[inline]
34 fn body(&self) -> Option<RequestBody> {
35 let body = serde_json::to_string(&self.params).unwrap();
36 Some(RequestBody::Json(body))
37 }
38}
39
/// Input payload for a model run, sent as the JSON request body.
///
/// The enum is `#[serde(untagged)]`: serializing a variant emits only its
/// fields (or its inner struct's fields) with no variant name, so the caller
/// picks the variant whose JSON shape the chosen model expects.
///
/// NOTE(review): untagged deserialization tries the variants in declaration
/// order and keeps the first that fits, and unknown fields are not rejected.
/// Several variants share their only required field: `text` (e.g.
/// `TextClassification` vs. `Translation`), `prompt` (`TextToImage`,
/// `TextToSpeech`, `Prompt`) and `image` (`ImageClassification`,
/// `ObjectDetection`, `ImageToText`). Deserializing a serialized value can
/// therefore yield a different variant than the one that was serialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ExecuteModelParams {
    /// Serialized as `{"text": "..."}`.
    TextClassification {
        text: String,
    },

    TextToImage(TextToImageParams),

    TextToSpeech(TextToSpeechParams),

    /// Serialized as `{"text": ["...", ...]}`.
    TextEmbeddings {
        text: Vec<String>,
    },

    AutomaticSpeechRecognition(AutomaticSpeechRecognitionParams),

    /// Serialized as `{"image": [...]}`, with the bytes written as a JSON
    /// array of integers.
    ImageClassification {
        image: Vec<u8>,
    },

    /// Same JSON shape as `ImageClassification`.
    ObjectDetection {
        image: Vec<u8>,
    },

    Prompt(PromptParams),

    Messages(MessagesParams),

    Translation(TranslationParams),

    Summarization(SummarizationParams),

    ImageToText(ImageToTextParams),
}
114
115#[derive(Clone, Debug, Default, Serialize, Deserialize)]
117pub struct TextToImageParams {
118 pub prompt: String,
121
122 pub guidance: Option<f64>,
124
125 pub height: Option<u32>,
127
128 pub image: Option<Vec<u8>>,
130
131 pub image_b64: Option<String>,
133
134 pub mask: Option<Vec<u8>>,
136
137 pub negative_prompt: Option<String>,
139
140 pub num_steps: Option<u32>,
142
143 pub seed: Option<u64>,
145
146 pub strength: Option<f64>,
148
149 pub width: Option<u32>,
151}
152
153#[derive(Clone, Debug, Default, Serialize, Deserialize)]
155pub struct TextToSpeechParams {
156 pub prompt: String,
159
160 pub lang: Option<String>,
162}
163
164#[derive(Clone, Debug, Default, Serialize, Deserialize)]
166pub struct AutomaticSpeechRecognitionParams {
167 pub audio: Vec<u8>,
169
170 pub source_lang: Option<String>,
172
173 pub target_lang: Option<String>,
175}
176
/// Parameters for a prompt-driven model run (`{"prompt": "...", ...}`).
///
/// Only `prompt` is required; every optional field is omitted from the JSON
/// body when `None`.
///
/// NOTE(review): `ExecuteModelResult` has no variant for streamed output, so
/// setting `stream: Some(true)` presumably yields a response this module
/// cannot decode. Confirm before relying on it.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PromptParams {
    pub prompt: String,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub raw: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
}
228
/// Parameters for a chat-style model run driven by a list of [`Message`]s.
///
/// Only `messages` is required; every optional field is omitted from the JSON
/// body when `None`. `functions` and `tools` carry [`AssistantFunction`] and
/// [`AssistantTool`] definitions respectively.
///
/// NOTE(review): as with `PromptParams`, `ExecuteModelResult` has no variant
/// for streamed output. Confirm how `stream: Some(true)` responses are handled.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct MessagesParams {
    pub messages: Vec<Message>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub functions: Option<Vec<AssistantFunction>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<AssistantTool>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
}
279
/// One entry of [`MessagesParams::messages`]: text content plus the role of
/// its author.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
    /// Message text.
    pub content: String,

    /// Author role; serialized in lowercase (`"system"`, `"user"`, `"assistant"`).
    pub role: MessageRole,
}
289
290impl Message {
291 pub fn system(content: String) -> Self {
292 Message {
293 content,
294 role: MessageRole::System,
295 }
296 }
297
298 pub fn user(content: String) -> Self {
299 Message {
300 content,
301 role: MessageRole::User,
302 }
303 }
304
305 pub fn assistant(content: String) -> Self {
306 Message {
307 content,
308 role: MessageRole::Assistant,
309 }
310 }
311}
312
313#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
314pub enum MessageRole {
315 #[serde(rename = "system")]
316 System,
317 #[serde(rename = "user")]
318 User,
319 #[serde(rename = "assistant")]
320 Assistant,
321}
322
323impl ToString for MessageRole {
324 fn to_string(&self) -> String {
325 match self {
326 MessageRole::System => "System".to_string(),
327 MessageRole::User => "User".to_string(),
328 MessageRole::Assistant => "Assistant".to_string(),
329 }
330 }
331}
332
333#[derive(Clone, Debug, Serialize, Deserialize)]
335pub struct AssistantFunction {
336 #[serde(skip_serializing_if = "Option::is_none")]
338 code: Option<String>,
339
340 name: String,
342
343 #[serde(skip_serializing_if = "Option::is_none")]
345 parameters: Option<String>,
346}
347
348#[derive(Clone, Debug, Serialize, Deserialize)]
350pub struct AssistantTool {
351 description: String,
353
354 name: String,
356
357 #[serde(skip_serializing_if = "Option::is_none")]
359 parameters: Option<String>,
360}
361
/// Parameters for a translation model run.
///
/// `target_lang` and `text` are required; `source_lang` is omitted from the
/// JSON body when `None`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TranslationParams {
    pub target_lang: String,

    pub text: String,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub source_lang: Option<String>,
}
375
/// Parameters for a summarization model run.
///
/// `input_text` is required; `max_length` is omitted from the JSON body when
/// `None`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SummarizationParams {
    pub input_text: String,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_length: Option<u32>,
}
386
/// Parameters for an image-to-text model run.
///
/// `image` is required and is serialized as a JSON array of integers (one per
/// byte). Every optional field is omitted from the JSON body when `None`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ImageToTextParams {
    pub image: Vec<u8>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub raw: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
}
409
410#[derive(Clone, Debug, Deserialize, Serialize)]
413#[serde(untagged)]
414pub enum ExecuteModelResult {
415 TextClassification(Vec<TextClassificationResult>),
417
418 TextToImage(String),
420
421 Audio(AudioResult),
423
424 TextEmbeddings(TextEmbeddingsResult),
426
427 AutomaticSpeechRecognition(AutomaticSpeechRecognitionResult),
429
430 ImageClassification(Vec<ImageClassificationResult>),
432
433 ObjectDetection(Vec<ObjectDetectionResult>),
435
436 ResponseAndToolCallsResult(ResponseAndToolCallsResult),
438
439 Translation(TranslationResult),
441
442 Summarization(SummarizationResult),
444
445 ImageToText(ImageToTextResult),
447}
448
449impl ApiResult for ExecuteModelResult {}
450
/// One `{label, score}` entry of a classification response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TextClassificationResult {
    pub label: String,

    pub score: f64,
}
460
/// Response carrying generated audio as a string (`{"audio": "..."}`).
///
/// NOTE(review): the string is presumably base64-encoded audio. Confirm this
/// against the API docs before decoding.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct AudioResult {
    pub audio: String,
}
467
/// Text-embedding response: the embedding `data` plus its `shape`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TextEmbeddingsResult {
    /// Embedding data as an n-dimensional `f64` array (with the `ndarray` feature).
    ///
    /// NOTE(review): ndarray's own serde format is believed to be an object
    /// with version/dimension/data fields rather than nested JSON arrays.
    /// Confirm that this field can deserialize the API's `data` payload.
    #[cfg(feature = "ndarray")]
    pub data: ndarray::ArrayD<f64>,

    /// Embedding data as raw JSON values (without the `ndarray` feature).
    #[cfg(not(feature = "ndarray"))]
    pub data: Vec<serde_json::Value>,

    /// Presumably the dimensions of `data`; verify against the API docs.
    pub shape: Vec<usize>,
}
485
/// Automatic-speech-recognition response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AutomaticSpeechRecognitionResult {
    /// Transcribed text.
    pub text: String,

    /// Omitted when serializing if `None`. Presumably a WebVTT rendering of
    /// the transcript; confirm this against the API docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vtt: Option<String>,

    /// Omitted when serializing if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub word_count: Option<usize>,

    /// Per-word timings. Defaults to empty when absent from the response and
    /// is omitted when serializing if empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub words: Vec<WordTiming>,
}
504
/// Timing of one transcribed word.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WordTiming {
    /// Start time of the word. NOTE(review): units not specified here,
    /// presumably seconds.
    pub start: f64,

    /// End time of the word, in the same units as `start`.
    pub end: f64,

    pub word: String,
}
517
/// One `{label, score}` entry of an image-classification response.
///
/// Same JSON shape as `TextClassificationResult`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ImageClassificationResult {
    pub label: String,

    pub score: f64,
}
527
/// One detected object: its bounding box, optional label and score.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ObjectDetectionResult {
    // The JSON field is `box`, which is a reserved keyword in Rust.
    #[serde(rename = "box")]
    pub bounding_box: BoundingBox,

    /// Omitted when serializing if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub label: Option<String>,

    pub score: f64,
}
542
/// Axis-aligned box given by its minimum and maximum x and y coordinates.
///
/// NOTE(review): the coordinate units (pixels vs. normalized) are not
/// specified here; confirm this against the API docs.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BoundingBox {
    pub xmin: f64,

    pub xmax: f64,

    pub ymin: f64,

    pub ymax: f64,
}
558
/// Generated text response with any tool calls requested by the model.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct ResponseAndToolCallsResult {
    /// Generated text.
    pub response: String,

    /// Defaults to empty when absent from the response; omitted when
    /// serializing if empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
}
570
/// A tool invocation requested by the model: the tool name plus its arguments.
///
/// NOTE(review): `arguments` is typed as a string. If the API returns the
/// arguments as a JSON object, deserializing this struct fails, and with it
/// the whole `ResponseAndToolCallsResult`. Confirm the wire format.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct ToolCall {
    pub name: String,

    pub arguments: String,
}
580
/// Translation response (`{"translated_text": "..."}`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct TranslationResult {
    pub translated_text: String,
}
587
/// Summarization response (`{"summary": "..."}`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct SummarizationResult {
    pub summary: String,
}
594
/// Image-to-text response (`{"description": "..."}`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct ImageToTextResult {
    pub description: String,
}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// A `{"response": ...}` body without `tool_calls` must deserialize into
    /// `ResponseAndToolCallsResult`. It must keep the response text verbatim
    /// (including the escaped quotes) and default `tool_calls` to empty.
    #[test]
    fn test_deserialize_response_and_tool_calls_result() {
        let json = r#"
        {"response":"\"A short story\""}
        "#;

        let response: ExecuteModelResult =
            serde_json::from_str(json).expect("valid ExecuteModelResult JSON");

        match response {
            ExecuteModelResult::ResponseAndToolCallsResult(result) => assert_eq!(
                result,
                ResponseAndToolCallsResult {
                    response: "\"A short story\"".to_string(),
                    tool_calls: Vec::new(),
                }
            ),
            other => panic!("expected ResponseAndToolCallsResult, got {:?}", other),
        }
    }
}
619}