use serde::{Deserialize, Serialize};
/// A single chat message: the sender's [`Role`] plus the message body.
///
/// Serialized through serde's derive with the field names `role` and
/// `content`, in that order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    /// Role of whoever produced this message.
    pub role: Role,
    /// Body of the message: plain text or a list of content parts.
    pub content: MessageContent,
}
/// Sender role attached to every [`Message`].
///
/// On the wire each variant is spelled as its lowercase name:
/// `"system"`, `"user"`, `"assistant"`, `"tool"`.
/// Marked `#[non_exhaustive]` so more roles can be added without breaking
/// downstream `match`es.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Wire name `"system"`.
    System,
    /// Wire name `"user"`.
    User,
    /// Wire name `"assistant"`.
    Assistant,
    /// Wire name `"tool"`.
    Tool,
}
/// Body of a [`Message`].
///
/// Uses serde's untagged representation: no discriminator is written, so
/// `Text` appears as a bare string and `Parts` as a bare sequence of
/// [`ContentPart`]s. When deserializing, serde tries the variants in
/// declaration order (`Text` first, then `Parts`), so the order below is
/// significant.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[non_exhaustive]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain-text body.
    Text(String),
    /// Multi-part body (text and/or images).
    Parts(Vec<ContentPart>),
}
/// One element of a multi-part [`MessageContent::Parts`] body.
///
/// Internally tagged: the variant name travels in a `"type"` field next to
/// the variant's own fields, using snake_case tag values `"text"`,
/// `"image_base64"` and `"image_url"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[non_exhaustive]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// A text fragment.
    Text { text: String },
    /// An inline image. Presumably `mime` is a media type (e.g. `image/png`)
    /// and `data` is base64-encoded bytes — neither is validated here.
    ImageBase64 { mime: String, data: String },
    /// An image referenced by URL; the URL is not validated here.
    ImageUrl { url: String },
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SamplingParams {
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub top_k: Option<u32>,
pub max_tokens: Option<u32>,
pub stop: Vec<String>,
pub presence_penalty: Option<f32>,
pub frequency_penalty: Option<f32>,
pub seed: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecuteBatch {
pub request_id: String,
pub model: String,
pub messages: Vec<Message>,
pub sampling: SamplingParams,
pub stream: bool,
pub estimated_tokens: u32,
}
impl ExecuteBatch {
    /// Returns the stored token estimate, clamped to be at least 1.
    ///
    /// A stored value of `0` is reported as `1`. Any other value is returned
    /// unchanged, so this accessor never yields zero.
    pub fn estimated_tokens(&self) -> u32 {
        match self.estimated_tokens {
            0 => 1,
            tokens => tokens,
        }
    }
}