use serde::{Deserialize, Serialize};
use katu_core::{FinishReason, Message, ToolChoice, ToolDefinition, Usage};
use crate::cache::CachePolicy;
use katu_core::GenerationOptions;
use crate::http::HttpOptions;
use crate::model::ModelRef;
/// A single request to a language model, bundling the target model, the
/// conversation content and optional per-request overrides.
///
/// Optional fields are omitted from the serialized form when unset, and an
/// empty `tools` list is omitted as well. Every omitted field is restored to
/// its empty value on deserialization, so a serialized request round-trips.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// The model this request targets. Its `generation`, `http` and
    /// `cache_policy` settings are consulted by the `resolved_*` methods.
    pub model: ModelRef,
    /// Optional system text (e.g. a system prompt); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Messages sent with the request. Always serialized, even when empty.
    pub messages: Vec<Message>,
    /// Tool definitions attached to the request.
    ///
    /// `default` is required alongside `skip_serializing_if`: an empty list is
    /// not written out, so without it deserializing that output would fail
    /// with a "missing field `tools`" error.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Optional tool-choice setting; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Request-level generation options; combined with the model's options by
    /// `resolved_generation`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation: Option<GenerationOptions>,
    /// Request-level cache policy; takes priority over the model's policy in
    /// `resolved_cache`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache: Option<CachePolicy>,
    /// Request-level HTTP options; combined with the model's options by
    /// `resolved_http`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub http: Option<HttpOptions>,
    /// Free-form JSON value, presumably provider-specific settings; it is not
    /// interpreted anywhere in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<serde_json::Value>,
    /// Free-form JSON value for caller-supplied metadata; it is not
    /// interpreted anywhere in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
impl LlmRequest {
    /// Creates a request for `model` with empty `messages` and `tools` lists
    /// and every optional field set to `None`.
    pub fn new(model: ModelRef) -> Self {
        Self {
            model,
            messages: Vec::new(),
            tools: Vec::new(),
            system: None,
            tool_choice: None,
            generation: None,
            cache: None,
            http: None,
            provider_options: None,
            metadata: None,
        }
    }

    /// Sets the system text, replacing any previous value.
    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        let system: String = system.into();
        self.system = Some(system);
        self
    }

    /// Appends one message to the end of the message list.
    pub fn with_message(mut self, message: Message) -> Self {
        self.messages.push(message);
        self
    }

    /// Replaces the whole message list (previously added messages are dropped).
    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
        self.messages = messages;
        self
    }

    /// Replaces the whole list of tool definitions.
    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
        self.tools = tools;
        self
    }

    /// Sets the tool-choice setting.
    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
        self.tool_choice = Some(choice);
        self
    }

    /// Sets the request-level generation options.
    pub fn with_generation(mut self, generation: GenerationOptions) -> Self {
        self.generation = Some(generation);
        self
    }

    /// Sets the request-level cache policy.
    pub fn with_cache(mut self, cache: CachePolicy) -> Self {
        self.cache = Some(cache);
        self
    }

    /// Sets the request-level HTTP options.
    pub fn with_http(mut self, http: HttpOptions) -> Self {
        self.http = Some(http);
        self
    }

    /// Sets the free-form provider options value.
    pub fn with_provider_options(mut self, options: serde_json::Value) -> Self {
        self.provider_options = Some(options);
        self
    }

    /// Sets the free-form metadata value.
    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
        self.metadata = Some(metadata);
        self
    }

    /// Returns the generation options that apply to this request.
    ///
    /// When both the model and the request carry options, the result is
    /// `model_options.merge(request_options)`; the per-field precedence is
    /// defined by `GenerationOptions::merge`. When only one side is set it is
    /// cloned, and when neither is set `GenerationOptions::default()` is
    /// returned.
    pub fn resolved_generation(&self) -> GenerationOptions {
        let model_level = self.model.generation.as_ref();
        let request_level = self.generation.as_ref();

        match model_level {
            Some(base) => match request_level {
                Some(overrides) => base.merge(overrides),
                None => base.clone(),
            },
            None => request_level
                .cloned()
                .unwrap_or_else(GenerationOptions::default),
        }
    }

    /// Returns the HTTP options that apply to this request, if any.
    ///
    /// When both the model and the request carry options, the result is
    /// `model_options.merge(request_options)`. When only one side is set it is
    /// cloned, and `None` is returned when neither is set.
    pub fn resolved_http(&self) -> Option<HttpOptions> {
        let model_http = self.model.http.as_ref();
        let request_http = self.http.as_ref();

        if let (Some(base), Some(overrides)) = (model_http, request_http) {
            return Some(base.merge(overrides));
        }

        // At most one side is set here, so `or` picks whichever exists.
        model_http.or(request_http).cloned()
    }

    /// Returns the cache policy that applies to this request.
    ///
    /// The request's own policy wins; otherwise the model's policy is used,
    /// and the default policy is returned when neither is set.
    pub fn resolved_cache(&self) -> CachePolicy {
        if let Some(policy) = &self.cache {
            return policy.clone();
        }

        self.model.cache_policy.clone().unwrap_or_default()
    }
}
/// The result of a completed model call: a message together with the reason
/// generation finished and the associated usage figures.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LlmResponse {
    /// The message carried by the response.
    pub message: Message,
    /// Why generation finished (for example `FinishReason::Stop`).
    pub finish_reason: FinishReason,
    /// Usage figures reported for the call.
    pub usage: Usage,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::ModelLimits;
    use katu_core::{ModelId, ProviderId, RouteId};

    /// Builds the `ModelRef` fixture shared by the tests in this module.
    fn sample_model() -> ModelRef {
        let limits = ModelLimits {
            context_window: 128_000,
            max_output_tokens: 4096,
        };

        ModelRef::new(
            ModelId::new("gpt-4o"),
            ProviderId::new("openai"),
            RouteId::new("openai-chat"),
            "https://api.openai.com/v1",
            limits,
        )
    }

    /// A freshly constructed request carries no optional content.
    #[test]
    fn test_new_is_minimal() {
        let request = LlmRequest::new(sample_model());

        assert!(request.system.is_none());
        assert!(request.messages.is_empty());
        assert!(request.tools.is_empty());
        assert!(request.tool_choice.is_none());
        assert!(request.generation.is_none());
    }

    /// Builder methods can be chained and each one stores its argument.
    #[test]
    fn test_builder_chain() {
        let generation = GenerationOptions::new().with_max_tokens(1024);
        let request = LlmRequest::new(sample_model())
            .with_system("You are helpful.")
            .with_message(Message::user("Hi"))
            .with_generation(generation);

        assert_eq!(request.system.as_deref(), Some("You are helpful."));
        assert_eq!(request.messages.len(), 1);

        let stored = request.generation.as_ref().unwrap();
        assert_eq!(stored.max_tokens, Some(1024));
    }

    /// Request-level values win, while model-only values are kept.
    #[test]
    fn test_resolved_generation_request_overrides_model() {
        let model_options = GenerationOptions::new()
            .with_max_tokens(2048)
            .with_temperature(0.5);
        let request_options = GenerationOptions::new().with_max_tokens(4096);

        let request = LlmRequest::new(sample_model().with_generation(model_options))
            .with_generation(request_options);
        let effective = request.resolved_generation();

        assert_eq!(effective.max_tokens, Some(4096));
        assert_eq!(effective.temperature, Some(0.5));
    }

    /// With no request-level options, the model's options are used as-is.
    #[test]
    fn test_resolved_generation_model_only() {
        let model_options = GenerationOptions::new().with_temperature(0.7);
        let request = LlmRequest::new(sample_model().with_generation(model_options));

        let effective = request.resolved_generation();

        assert_eq!(effective.temperature, Some(0.7));
        assert_eq!(effective.max_tokens, None);
    }

    /// With options on neither side, the default options are returned.
    #[test]
    fn test_resolved_generation_neither() {
        let request = LlmRequest::new(sample_model());

        let effective = request.resolved_generation();

        assert_eq!(effective, GenerationOptions::default());
    }

    /// With no policy on either side, the resolved policy is `Auto`.
    #[test]
    fn test_resolved_cache_defaults_to_auto() {
        let request = LlmRequest::new(sample_model());

        assert_eq!(request.resolved_cache(), CachePolicy::Auto);
    }

    /// A request-level cache policy takes priority over the model's policy.
    #[test]
    fn test_resolved_cache_request_overrides_model() {
        let model = sample_model().with_cache_policy(CachePolicy::Auto);
        let request = LlmRequest::new(model).with_cache(CachePolicy::None);

        assert_eq!(request.resolved_cache(), CachePolicy::None);
    }

    /// Headers from the model and from the request both survive the merge.
    #[test]
    fn test_resolved_http_merge() {
        let model_http = HttpOptions::new()
            .with_header("x-base", "1")
            .with_query_param("v", "1");
        let request_http = HttpOptions::new().with_header("x-req", "2");

        let request = LlmRequest::new(sample_model().with_http(model_http))
            .with_http(request_http);
        let headers = request.resolved_http().unwrap().headers.unwrap();

        assert_eq!(headers.get("x-base").unwrap(), "1");
        assert_eq!(headers.get("x-req").unwrap(), "2");
    }

    /// `LlmResponse` can be built from its parts and exposes them unchanged.
    #[test]
    fn test_llm_response() {
        let reply = LlmResponse {
            message: Message::assistant("Hello!"),
            finish_reason: FinishReason::Stop,
            usage: Usage::default(),
        };

        assert_eq!(reply.finish_reason, FinishReason::Stop);
    }
}