use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use super::{
CapabilitySet, GenerationConfig, Message, MessageRole, ToolResultPart, VendorExtensions,
};
/// A model request whose conversation may be supplied either as `input`
/// items or as plain `messages`.
///
/// Callers that do not care which form was populated should read the
/// conversation through `normalized_input`, `normalized_messages` and
/// `normalized_instructions`, which reconcile the two representations.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmRequest {
/// Model name/identifier. This field has no `#[serde(default)]`, so it is
/// required when deserializing.
pub model: String,
/// Explicit instructions text; omitted from serialized output when `None`.
/// When `None`, `normalized_instructions` derives instructions from
/// System/Developer messages instead.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Conversation as request items (messages and tool results). When non-empty
/// it takes precedence over `messages` in `normalized_input`.
/// Defaults to empty when absent; omitted from output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub input: Vec<RequestItem>,
/// Conversation as plain messages. Used by `normalized_input` only when
/// `input` is empty. Defaults to empty when absent; omitted when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub messages: Vec<Message>,
/// Requested capabilities; defaults when absent and is omitted from output
/// when `CapabilitySet::is_empty` reports true.
#[serde(default, skip_serializing_if = "CapabilitySet::is_empty")]
pub capabilities: CapabilitySet,
/// Generation settings (its `max_output_tokens` is read by
/// `estimated_tokens`); omitted when `GenerationConfig::is_default` is true.
#[serde(default, skip_serializing_if = "GenerationConfig::is_default")]
pub generation: GenerationConfig,
/// Request metadata, typed as `VendorExtensions`; omitted when empty.
/// NOTE(review): the emptiness check is `BTreeMap::is_empty`, which assumes
/// `VendorExtensions` is a `BTreeMap` alias — confirm.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub metadata: VendorExtensions,
/// Vendor-specific extension fields; omitted when empty.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub vendor_extensions: VendorExtensions,
}
impl LlmRequest {
    /// Output budget assumed by `estimated_tokens` when
    /// `generation.max_output_tokens` is unset.
    const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 1024;

    /// Characters-per-token ratio used by the prompt-size heuristic.
    const CHARS_PER_TOKEN: usize = 4;

    /// Returns the conversation as a list of [`RequestItem`]s.
    ///
    /// A non-empty `input` wins and is returned as-is (cloned); `messages` is
    /// then ignored. Otherwise every entry of `messages` is wrapped as
    /// [`RequestItem::Message`].
    pub fn normalized_input(&self) -> Vec<RequestItem> {
        if self.input.is_empty() {
            self.messages.iter().cloned().map(RequestItem::from).collect()
        } else {
            self.input.clone()
        }
    }

    /// Returns the conversation as plain [`Message`]s.
    ///
    /// A non-empty `messages` wins and is returned as-is (cloned). Otherwise
    /// the message items of `input` are extracted and tool results are dropped.
    pub fn normalized_messages(&self) -> Vec<Message> {
        if self.messages.is_empty() {
            self.input
                .iter()
                .filter_map(RequestItem::as_message)
                .cloned()
                .collect()
        } else {
            self.messages.clone()
        }
    }

    /// Returns the instructions text for this request, if any.
    ///
    /// An explicit `instructions` value is returned unchanged, even if empty.
    /// Otherwise the non-empty plain text of every `System` and `Developer`
    /// message from [`Self::normalized_messages`] is joined with blank lines.
    /// `None` is returned when that produces an empty string.
    pub fn normalized_instructions(&self) -> Option<String> {
        if let Some(instructions) = &self.instructions {
            return Some(instructions.clone());
        }
        let folded = self
            .normalized_messages()
            .into_iter()
            .filter(|message| matches!(message.role, MessageRole::System | MessageRole::Developer))
            .map(|message| message.plain_text())
            .filter(|text| !text.is_empty())
            .collect::<Vec<_>>()
            .join("\n\n");
        Some(folded).filter(|text| !text.is_empty())
    }

    /// Heuristic prompt size in tokens.
    ///
    /// Counts the characters of the conversation (same precedence as
    /// [`Self::normalized_input`]) plus those of
    /// [`Self::normalized_instructions`], and divides by `CHARS_PER_TOKEN`.
    /// The result is never less than 1 and saturates at `u32::MAX` instead of
    /// wrapping.
    pub fn estimated_prompt_tokens(&self) -> u32 {
        // Measure borrowed items directly: going through `normalized_input()`
        // would clone the entire conversation just to count characters.
        let body_chars: usize = if self.input.is_empty() {
            self.messages
                .iter()
                .map(|message| message.estimated_chars())
                .sum()
        } else {
            self.input.iter().map(RequestItem::estimated_chars).sum()
        };
        // NOTE(review): when `instructions` is `None` and only one of
        // `input`/`messages` is populated, the folded System/Developer text
        // comes from messages already measured in `body_chars`, so it is
        // counted twice. Kept to preserve existing estimates — confirm intent.
        let instruction_chars = self
            .normalized_instructions()
            .map_or(0, |instructions| instructions.len());
        let tokens =
            (body_chars.saturating_add(instruction_chars) / Self::CHARS_PER_TOKEN).max(1);
        // Clamp before narrowing: a bare `as u32` would silently wrap.
        tokens.min(u32::MAX as usize) as u32
    }

    /// Prompt estimate plus the output budget.
    ///
    /// The output budget is `generation.max_output_tokens`, or
    /// `DEFAULT_MAX_OUTPUT_TOKENS` when unset. Addition saturates at
    /// `u32::MAX`. A huge deserialized budget therefore cannot panic (debug)
    /// or wrap to a tiny number (release).
    pub fn estimated_tokens(&self) -> u32 {
        let output_budget = self
            .generation
            .max_output_tokens
            .unwrap_or(Self::DEFAULT_MAX_OUTPUT_TOKENS);
        self.estimated_prompt_tokens().saturating_add(output_budget)
    }
}
/// One entry of `LlmRequest::input`.
///
/// Serialized as an internally tagged object: a `"type"` field carries the
/// snake_case variant name (`"message"` or `"tool_result"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum RequestItem {
/// A conversation message, nested under a `message` key.
Message { message: Message },
/// A tool result whose fields are flattened next to the `type` tag.
/// NOTE(review): a `type` field inside `ToolResultPart` would collide with
/// the tag — confirm it has none.
ToolResult {
#[serde(flatten)]
result: ToolResultPart,
},
}
impl RequestItem {
fn estimated_chars(&self) -> usize {
match self {
Self::Message { message } => message.estimated_chars(),
Self::ToolResult { result } => result.output.to_string().len(),
}
}
pub(crate) fn as_message(&self) -> Option<&Message> {
match self {
Self::Message { message } => Some(message),
Self::ToolResult { .. } => None,
}
}
}
impl From<Message> for RequestItem {
fn from(message: Message) -> Self {
Self::Message { message }
}
}