use serde::{Deserialize, Serialize};
use crate::framework::endpoint::RequestBody;
use crate::framework::response::ApiSuccess;
use crate::framework::{
endpoint::{EndpointSpec, Method},
response::ApiResult,
};
/// Request that runs a model via the
/// `accounts/{account_identifier}/ai/run/{model_name}` endpoint.
///
/// The two identifiers are borrowed for `'a`; `params` is owned and is
/// serialized into the JSON request body.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExecuteModel<'a> {
    /// Value substituted into the account segment of the request path.
    pub account_identifier: &'a str,
    /// Value substituted into the model segment of the request path.
    pub model_name: &'a str,
    /// Model input; its JSON serialization is the request body.
    pub params: ExecuteModelParams,
}
/// Describes the HTTP request issued for an [`ExecuteModel`] call.
impl EndpointSpec for ExecuteModel<'_> {
    type JsonResponse = ExecuteModelResult;
    type ResponseType = ApiSuccess<Self::JsonResponse>;

    /// The request is always sent as a `POST`.
    fn method(&self) -> Method {
        Method::POST
    }

    /// Builds the relative path `accounts/{account_identifier}/ai/run/{model_name}`.
    fn path(&self) -> String {
        format!(
            "accounts/{}/ai/run/{}",
            self.account_identifier, self.model_name
        )
    }

    /// Serializes `params` to a JSON string and wraps it as the request body.
    ///
    /// # Panics
    ///
    /// Panics if `serde_json` cannot serialize `params`. The parameter types
    /// declared in this file are built only from strings, numbers, booleans
    /// and vectors of those, so this is treated as an invariant violation
    /// rather than a recoverable error.
    #[inline]
    fn body(&self) -> Option<RequestBody> {
        let json = serde_json::to_string(&self.params)
            .expect("ExecuteModelParams should always serialize to JSON");
        Some(RequestBody::Json(json))
    }
}
/// Input accepted by [`ExecuteModel`], one variant per kind of model input.
///
/// The enum is `#[serde(untagged)]`: when serialized, only the fields of the
/// selected variant are written and no variant name appears in the JSON.
/// When deserialized, serde tries the variants in declaration order and keeps
/// the first one that matches, so the order below is significant.
///
/// NOTE(review): several variants have overlapping shapes (for example
/// `ImageClassification` and `ObjectDetection` both consist of a single
/// `image` field), so deserializing cannot always recover the variant that
/// was originally serialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ExecuteModelParams {
    /// A single `text` string.
    TextClassification { text: String },
    /// See [`TextToImageParams`].
    TextToImage(TextToImageParams),
    /// See [`TextToSpeechParams`].
    TextToSpeech(TextToSpeechParams),
    /// A list of `text` strings.
    TextEmbeddings { text: Vec<String> },
    /// See [`AutomaticSpeechRecognitionParams`].
    AutomaticSpeechRecognition(AutomaticSpeechRecognitionParams),
    /// Raw `image` bytes.
    ImageClassification { image: Vec<u8> },
    /// Raw `image` bytes.
    ObjectDetection { image: Vec<u8> },
    /// See [`PromptParams`].
    Prompt(PromptParams),
    /// See [`MessagesParams`].
    Messages(MessagesParams),
    /// See [`TranslationParams`].
    Translation(TranslationParams),
    /// See [`SummarizationParams`].
    Summarization(SummarizationParams),
    /// See [`ImageToTextParams`].
    ImageToText(ImageToTextParams),
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TextToImageParams {
pub prompt: String,
pub guidance: Option<f64>,
pub height: Option<u32>,
pub image: Option<Vec<u8>>,
pub image_b64: Option<String>,
pub mask: Option<Vec<u8>>,
pub negative_prompt: Option<String>,
pub num_steps: Option<u32>,
pub seed: Option<u64>,
pub strength: Option<f64>,
pub width: Option<u32>,
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TextToSpeechParams {
pub prompt: String,
pub lang: Option<String>,
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct AutomaticSpeechRecognitionParams {
pub audio: Vec<u8>,
pub source_lang: Option<String>,
pub target_lang: Option<String>,
}
/// Input for [`ExecuteModelParams::Prompt`]: a required `prompt` plus
/// optional settings.
///
/// Each optional field is skipped during serialization when it is `None`;
/// `Default` yields an empty prompt with every option unset.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PromptParams {
    /// Prompt text (required).
    pub prompt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub raw: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
}
/// Input for [`ExecuteModelParams::Messages`]: a list of [`Message`]s plus
/// optional settings.
///
/// Each optional field is skipped during serialization when it is `None`;
/// `Default` yields an empty message list with every option unset.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct MessagesParams {
    /// The messages to send (required).
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    /// Optional list of [`AssistantFunction`] definitions.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub functions: Option<Vec<AssistantFunction>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Optional list of [`AssistantTool`] definitions.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<AssistantTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
}
/// One entry of [`MessagesParams::messages`]: some text and the role it is
/// attributed to.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
    /// Text of the message.
    pub content: String,
    /// Who the message is attributed to.
    pub role: MessageRole,
}
impl Message {
pub fn system(content: String) -> Self {
Message {
content,
role: MessageRole::System,
}
}
pub fn user(content: String) -> Self {
Message {
content,
role: MessageRole::User,
}
}
pub fn assistant(content: String) -> Self {
Message {
content,
role: MessageRole::Assistant,
}
}
}
/// Role attached to a [`Message`].
///
/// Serde reads and writes each variant as its lowercase name
/// (`"system"`, `"user"`, `"assistant"`).
#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    System,
    User,
    Assistant,
}
/// Formats the role as its capitalized name (`System`, `User`, `Assistant`).
///
/// `Display` is implemented instead of `ToString` so that `to_string()` keeps
/// working through the standard blanket impl while the role also becomes
/// usable with `{}` in formatting macros. These labels are capitalized,
/// unlike the lowercase names used when the role is serialized with serde.
impl std::fmt::Display for MessageRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            MessageRole::System => "System",
            MessageRole::User => "User",
            MessageRole::Assistant => "Assistant",
        };
        // `pad` honours width/alignment flags supplied by the caller.
        f.pad(label)
    }
}
/// Function definition that can be listed in [`MessagesParams::functions`].
///
/// The fields are public so that callers can construct values directly; the
/// type has no constructor and does not implement `Default`.
/// `code` and `parameters` are omitted from the serialized JSON when `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AssistantFunction {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
    /// Name of the function (required).
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<String>,
}
/// Tool definition that can be listed in [`MessagesParams::tools`].
///
/// The fields are public so that callers can construct values directly; the
/// type has no constructor and does not implement `Default`.
/// `parameters` is omitted from the serialized JSON when `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AssistantTool {
    /// Description of the tool (required).
    pub description: String,
    /// Name of the tool (required).
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<String>,
}
/// Input for [`ExecuteModelParams::Translation`].
///
/// `source_lang` is skipped during serialization when it is `None`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TranslationParams {
    /// Target language (required).
    pub target_lang: String,
    /// Text to translate (required).
    pub text: String,
    /// Optional source language.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source_lang: Option<String>,
}
/// Input for [`ExecuteModelParams::Summarization`].
///
/// `max_length` is skipped during serialization when it is `None`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SummarizationParams {
    /// Text to summarize (required).
    pub input_text: String,
    /// Optional length limit; units are not visible here (TODO confirm
    /// against the API documentation).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_length: Option<u32>,
}
/// Input for [`ExecuteModelParams::ImageToText`].
///
/// Only `image` is required; each optional field is skipped during
/// serialization when it is `None`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ImageToTextParams {
    /// Image bytes (required).
    pub image: Vec<u8>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub raw: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
}
/// Result payload of an [`ExecuteModel`] request.
///
/// The enum is `#[serde(untagged)]`: serde tries the variants in declaration
/// order and keeps the first one whose shape matches, so the order below is
/// significant.
///
/// `ObjectDetection` is deliberately listed before `TextClassification`.
/// Unknown fields are ignored when deserializing the element structs, so a
/// list of `{box, label, score}` objects also satisfies
/// [`TextClassificationResult`] (`label` + `score`); trying the stricter
/// shape first keeps object-detection payloads from being captured by the
/// looser one. A consequence is that an empty list deserializes as
/// `ObjectDetection`.
///
/// NOTE(review): [`TextClassificationResult`] and
/// [`ImageClassificationResult`] have identical fields, so a classification
/// list always deserializes as `TextClassification`; `ImageClassification`
/// is only produced when constructed explicitly.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ExecuteModelResult {
    /// A list of detections, each carrying a bounding box.
    ObjectDetection(Vec<ObjectDetectionResult>),
    /// A list of `label`/`score` pairs.
    TextClassification(Vec<TextClassificationResult>),
    /// A bare string.
    TextToImage(String),
    /// An object with an `audio` string.
    Audio(AudioResult),
    /// An object with `data` and `shape`.
    TextEmbeddings(TextEmbeddingsResult),
    /// An object with a `text` transcript and optional extras.
    AutomaticSpeechRecognition(AutomaticSpeechRecognitionResult),
    /// A list of `label`/`score` pairs (see the note on the enum).
    ImageClassification(Vec<ImageClassificationResult>),
    /// An object with a `response` string and optional `tool_calls`.
    ResponseAndToolCallsResult(ResponseAndToolCallsResult),
    /// An object with `translated_text`.
    Translation(TranslationResult),
    /// An object with `summary`.
    Summarization(SummarizationResult),
    /// An object with `description`.
    ImageToText(ImageToTextResult),
}

/// Opts [`ExecuteModelResult`] into the framework's `ApiResult` trait; no
/// trait items are overridden here.
impl ApiResult for ExecuteModelResult {}
/// One element of [`ExecuteModelResult::TextClassification`]: a label and
/// its score.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TextClassificationResult {
    /// Label assigned to the input.
    pub label: String,
    /// Score reported for `label`.
    pub score: f64,
}
/// Payload of [`ExecuteModelResult::Audio`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct AudioResult {
    /// Audio returned as a string; the encoding is not visible here
    /// (TODO confirm against the API documentation).
    pub audio: String,
}
/// Payload of [`ExecuteModelResult::TextEmbeddings`].
///
/// The type of `data` depends on the `ndarray` cargo feature.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TextEmbeddingsResult {
    /// Embedding data as a dynamic-dimension `f64` array (feature `ndarray`
    /// enabled).
    #[cfg(feature = "ndarray")]
    pub data: ndarray::ArrayD<f64>,
    /// Embedding data as raw JSON values (feature `ndarray` disabled).
    #[cfg(not(feature = "ndarray"))]
    pub data: Vec<serde_json::Value>,
    /// Presumably the dimensions of `data` (TODO confirm against the API
    /// documentation).
    pub shape: Vec<usize>,
}
/// Payload of [`ExecuteModelResult::AutomaticSpeechRecognition`].
///
/// Only `text` is required when deserializing.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AutomaticSpeechRecognitionResult {
    /// Recognized text.
    pub text: String,
    /// Optional `vtt` string; skipped during serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vtt: Option<String>,
    /// Optional word count; skipped during serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub word_count: Option<usize>,
    /// Per-word timings. Defaults to an empty list when the field is missing
    /// and is skipped during serialization when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub words: Vec<WordTiming>,
}
/// One entry of [`AutomaticSpeechRecognitionResult::words`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WordTiming {
    /// Start of the word; presumably in seconds (TODO confirm).
    pub start: f64,
    /// End of the word; presumably in seconds (TODO confirm).
    pub end: f64,
    /// The word itself.
    pub word: String,
}
/// One element of [`ExecuteModelResult::ImageClassification`]: a label and
/// its score.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ImageClassificationResult {
    /// Label assigned to the image.
    pub label: String,
    /// Score reported for `label`.
    pub score: f64,
}
/// One element of [`ExecuteModelResult::ObjectDetection`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ObjectDetectionResult {
    /// Bounding box of the detection; stored under the JSON key `box`.
    #[serde(rename = "box")]
    pub bounding_box: BoundingBox,
    /// Optional label; skipped during serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub label: Option<String>,
    /// Score reported for the detection.
    pub score: f64,
}
/// Rectangle described by its minimum and maximum `x` and `y` values, as
/// used by [`ObjectDetectionResult::bounding_box`].
///
/// The coordinate units and origin are not visible here (TODO confirm
/// against the API documentation).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BoundingBox {
    pub xmin: f64,
    pub xmax: f64,
    pub ymin: f64,
    pub ymax: f64,
}
/// Payload of [`ExecuteModelResult::ResponseAndToolCallsResult`]: a response
/// string plus any tool calls.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct ResponseAndToolCallsResult {
    /// Response text.
    pub response: String,
    /// Tool calls. Defaults to an empty list when the field is missing and
    /// is skipped during serialization when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
}
/// One entry of [`ResponseAndToolCallsResult::tool_calls`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct ToolCall {
    /// Name of the called tool.
    pub name: String,
    /// Arguments of the call, held as a string.
    pub arguments: String,
}
/// Payload of [`ExecuteModelResult::Translation`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct TranslationResult {
    /// The translated text.
    pub translated_text: String,
}
/// Payload of [`ExecuteModelResult::Summarization`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct SummarizationResult {
    /// The summary text.
    pub summary: String,
}
/// Payload of [`ExecuteModelResult::ImageToText`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub struct ImageToTextResult {
    /// The description text.
    pub description: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON object carrying only a `response` string must resolve to the
    /// `ResponseAndToolCallsResult` variant of the untagged result enum.
    #[test]
    fn test_deserialize_response_and_tool_calls_result() {
        // The raw string is kept flush-left so its contents stay unchanged.
        let payload = r#"
{"response":"\"A short story\""}
"#;
        let parsed = serde_json::from_str::<ExecuteModelResult>(payload).unwrap();
        assert!(matches!(
            parsed,
            ExecuteModelResult::ResponseAndToolCallsResult(_)
        ));
    }
}