use crate::{Result, types::Content};
use async_trait::async_trait;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;

/// Boxed, pinned stream of [`LlmResponse`] chunks produced by [`Llm::generate_content`].
pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
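
/// Core abstraction over a model provider.
///
/// A minimal sketch of an implementor (illustrative only; it assumes the
/// `Content::new(..).with_text(..)` constructors exercised in the tests below):
///
/// ```ignore
/// use futures::stream;
///
/// struct EchoLlm;
///
/// #[async_trait]
/// impl Llm for EchoLlm {
///     fn name(&self) -> &str {
///         "echo"
///     }
///
///     async fn generate_content(
///         &self,
///         _req: LlmRequest,
///         _stream: bool,
///     ) -> Result<LlmResponseStream> {
///         // Yield one complete response as a single-item stream.
///         let resp = LlmResponse::new(Content::new("model").with_text("hello"));
///         Ok(Box::pin(stream::once(async move { Ok(resp) })))
///     }
/// }
/// ```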
#[async_trait]
pub trait Llm: Send + Sync {
    /// Stable identifier for this provider/model backend.
    fn name(&self) -> &str;

    /// Generates content for `req`; when `stream` is true, implementations
    /// should yield incremental chunks rather than a single final response.
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
}
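
/// A single request to an [`Llm`].
///
/// A construction sketch (`"some-model"` is a placeholder identifier):
///
/// ```ignore
/// let req = LlmRequest::new("some-model", vec![])
///     .with_config(GenerateContentConfig {
///         temperature: Some(0.2),
///         max_output_tokens: Some(256),
///         ..Default::default()
///     });
/// ```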
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    pub model: String,
    pub contents: Vec<Content>,
    pub config: Option<GenerateContentConfig>,
    /// Tool declarations keyed by name; carried in memory only and never serialized.
    #[serde(skip)]
    pub tools: HashMap<String, serde_json::Value>,
}
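
/// Generation parameters, all optional so providers can apply their own defaults.
///
/// Provider-specific options with no first-class field can ride along in
/// `extensions`, keyed by provider name. A sketch (the `"openrouter"` key and
/// payload shape mirror the round-trip test below):
///
/// ```ignore
/// let mut config = GenerateContentConfig::default();
/// config.extensions.insert(
///     "openrouter".to_string(),
///     serde_json::json!({ "provider": { "order": ["openai", "anthropic"] } }),
/// );
/// ```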
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerateContentConfig {
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub presence_penalty: Option<f32>,
    pub max_output_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub top_logprobs: Option<u8>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_schema: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cached_content: Option<String>,
    #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub extensions: serde_json::Map<String, serde_json::Value>,
}
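
/// One chunk of model output. Streaming providers typically emit a series of
/// `partial` chunks and finish with a chunk whose `turn_complete` is true.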
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmResponse {
    pub content: Option<Content>,
    pub usage_metadata: Option<UsageMetadata>,
    pub finish_reason: Option<FinishReason>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
    pub partial: bool,
    pub turn_complete: bool,
    pub interrupted: bool,
    pub error_code: Option<String>,
    pub error_message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provider_metadata: Option<serde_json::Value>,
}
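
/// Implemented by providers that support explicit context caching.
///
/// A usage sketch (hedged: `provider` stands for any type implementing both
/// [`Llm`] and [`CacheCapable`], and `tools` mirrors [`LlmRequest::tools`]):
///
/// ```ignore
/// let cache_name = provider.create_cache("You are helpful.", &tools, 600).await?;
/// // Reference the cache on later requests via `GenerateContentConfig::cached_content`,
/// // then clean up when done:
/// provider.delete_cache(&cache_name).await?;
/// ```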
#[async_trait]
pub trait CacheCapable: Send + Sync {
    /// Creates a cache entry and returns its resource name.
    async fn create_cache(
        &self,
        system_instruction: &str,
        tools: &HashMap<String, serde_json::Value>,
        ttl_seconds: u32,
    ) -> Result<String>;

    /// Deletes the cache entry identified by `name`.
    async fn delete_cache(&self, name: &str) -> Result<()>;
}
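
/// Tuning knobs for context caching. Exact interpretation is provider-specific;
/// the `Default` impl below supplies the baseline values.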
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextCacheConfig {
    pub min_tokens: u32,
    pub ttl_seconds: u32,
    pub cache_intervals: u32,
}

impl Default for ContextCacheConfig {
    fn default() -> Self {
        Self { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
    }
}
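
/// Token accounting for a single exchange. `cost` and `is_byok` (presumably
/// "bring your own key") are only populated by providers that report billing data.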
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageMetadata {
    pub prompt_token_count: i32,
    pub candidates_token_count: i32,
    pub total_token_count: i32,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_read_input_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_creation_input_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thinking_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_input_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_output_token_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cost: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub is_byok: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provider_usage: Option<serde_json::Value>,
}
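
/// Citation attributions for generated content. Field names serialize in
/// camelCase (`citationSources`) on the wire.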
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    #[serde(default)]
    pub citation_sources: Vec<CitationSource>,
}
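
/// A single cited source; `start_index`/`end_index` delimit the attributed
/// span of the generated text.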
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    pub uri: Option<String>,
    pub title: Option<String>,
    pub start_index: Option<i32>,
    pub end_index: Option<i32>,
    pub license: Option<String>,
    pub publication_date: Option<String>,
}
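
/// Why generation stopped. With no `rename` attribute, serde serializes the
/// variant names verbatim (`"Stop"`, `"MaxTokens"`, ...).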
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    Stop,
    MaxTokens,
    Safety,
    Recitation,
    Other,
}

impl LlmRequest {
    pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
        Self { model: model.into(), contents, config: None, tools: HashMap::new() }
    }

    /// Sets a JSON response schema, creating a default config if none exists yet.
    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
        let config = self.config.get_or_insert_with(GenerateContentConfig::default);
        config.response_schema = Some(schema);
        self
    }

    /// Replaces the generation config wholesale.
    pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
        self.config = Some(config);
        self
    }
}

impl LlmResponse {
    /// Builds a complete (non-streaming) response that finished normally.
    pub fn new(content: Content) -> Self {
        Self {
            content: Some(content),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_request_creation() {
        let req = LlmRequest::new("test-model", vec![]);
        assert_eq!(req.model, "test-model");
        assert!(req.contents.is_empty());
    }

    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let req = LlmRequest::new("test-model", vec![]).with_response_schema(schema.clone());
        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert!(config.response_schema.is_some());
        assert_eq!(config.response_schema.unwrap(), schema);
    }

    #[test]
    fn test_llm_request_with_config() {
        let config = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            frequency_penalty: Some(0.2),
            presence_penalty: Some(-0.3),
            max_output_tokens: Some(1024),
            seed: Some(42),
            top_logprobs: Some(5),
            stop_sequences: vec!["END".to_string()],
            ..Default::default()
        };
        let req = LlmRequest::new("test-model", vec![]).with_config(config);
        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_output_tokens, Some(1024));
        assert_eq!(config.frequency_penalty, Some(0.2));
        assert_eq!(config.presence_penalty, Some(-0.3));
        assert_eq!(config.seed, Some(42));
        assert_eq!(config.top_logprobs, Some(5));
        assert_eq!(config.stop_sequences, vec!["END"]);
    }

    #[test]
    fn test_llm_response_creation() {
        let content = Content::new("assistant");
        let resp = LlmResponse::new(content);
        assert!(resp.content.is_some());
        assert!(resp.turn_complete);
        assert!(!resp.partial);
        assert_eq!(resp.finish_reason, Some(FinishReason::Stop));
        assert!(resp.citation_metadata.is_none());
        assert!(resp.provider_metadata.is_none());
    }

    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let json = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });
        let response: LlmResponse = serde_json::from_value(json).expect("should deserialize");
        assert!(response.citation_metadata.is_none());
    }

    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: Some(CitationMetadata {
                citation_sources: vec![CitationSource {
                    uri: Some("https://example.com".to_string()),
                    title: Some("Example".to_string()),
                    start_index: Some(0),
                    end_index: Some(5),
                    license: None,
                    publication_date: Some("2026-01-01T00:00:00Z".to_string()),
                }],
            }),
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        };
        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.citation_metadata, response.citation_metadata);
    }

    #[test]
    fn test_generate_content_config_roundtrip_with_extensions() {
        let mut extensions = serde_json::Map::new();
        extensions.insert(
            "openrouter".to_string(),
            serde_json::json!({
                "provider": {
                    "zdr": true,
                    "order": ["openai", "anthropic"]
                },
                "plugins": [
                    { "id": "web", "enabled": true }
                ]
            }),
        );
        let config = GenerateContentConfig {
            temperature: Some(0.4),
            top_p: Some(0.8),
            top_k: Some(12),
            frequency_penalty: Some(0.1),
            presence_penalty: Some(0.2),
            max_output_tokens: Some(512),
            seed: Some(7),
            top_logprobs: Some(3),
            stop_sequences: vec!["STOP".to_string(), "DONE".to_string()],
            response_schema: Some(serde_json::json!({
                "type": "object",
                "properties": { "answer": { "type": "string" } },
                "required": ["answer"]
            })),
            cached_content: Some("cachedContents/abc123".to_string()),
            extensions,
        };
        let encoded = serde_json::to_string(&config).expect("serialize");
        let decoded: GenerateContentConfig = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.temperature, config.temperature);
        assert_eq!(decoded.top_p, config.top_p);
        assert_eq!(decoded.top_k, config.top_k);
        assert_eq!(decoded.frequency_penalty, config.frequency_penalty);
        assert_eq!(decoded.presence_penalty, config.presence_penalty);
        assert_eq!(decoded.max_output_tokens, config.max_output_tokens);
        assert_eq!(decoded.seed, config.seed);
        assert_eq!(decoded.top_logprobs, config.top_logprobs);
        assert_eq!(decoded.stop_sequences, config.stop_sequences);
        assert_eq!(decoded.response_schema, config.response_schema);
        assert_eq!(decoded.cached_content, config.cached_content);
        assert_eq!(decoded.extensions, config.extensions);
    }

    #[test]
    fn test_llm_response_and_usage_roundtrip_with_provider_metadata() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: Some(UsageMetadata {
                prompt_token_count: 10,
                candidates_token_count: 20,
                total_token_count: 30,
                cache_read_input_token_count: Some(5),
                cache_creation_input_token_count: Some(2),
                thinking_token_count: Some(3),
                audio_input_token_count: Some(4),
                audio_output_token_count: Some(6),
                cost: Some(0.0125),
                is_byok: Some(true),
                provider_usage: Some(serde_json::json!({
                    "server_tool_use": {
                        "web_search_requests": 1
                    },
                    "prompt_tokens_details": {
                        "video_tokens": 8
                    }
                })),
            }),
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: Some(serde_json::json!({
                "openrouter": {
                    "responseId": "resp_123",
                    "outputItems": 2
                }
            })),
        };
        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.provider_metadata, response.provider_metadata);
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.cost),
            response.usage_metadata.as_ref().and_then(|u| u.cost),
        );
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.is_byok),
            response.usage_metadata.as_ref().and_then(|u| u.is_byok),
        );
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.provider_usage.clone()),
            response.usage_metadata.as_ref().and_then(|u| u.provider_usage.clone()),
        );
    }

    #[test]
    fn test_finish_reason() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
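
    // A minimal stub implementor exercising the `Llm` trait and the stream
    // alias end to end. This is a sketch: it leans on `Content::with_text`
    // (used elsewhere in this module) and drives the futures with
    // `futures::executor::block_on`, which requires the `futures` crate's
    // default features.
    struct EchoLlm;

    #[async_trait]
    impl Llm for EchoLlm {
        fn name(&self) -> &str {
            "echo"
        }

        async fn generate_content(
            &self,
            _req: LlmRequest,
            _stream: bool,
        ) -> Result<LlmResponseStream> {
            let resp = LlmResponse::new(Content::new("model").with_text("hello"));
            Ok(Box::pin(futures::stream::once(async move { Ok(resp) })))
        }
    }

    #[test]
    fn test_llm_trait_stub_stream() {
        use futures::StreamExt;

        let llm = EchoLlm;
        assert_eq!(llm.name(), "echo");
        let mut stream = futures::executor::block_on(
            llm.generate_content(LlmRequest::new("echo", vec![]), false),
        )
        .expect("stream should be created");
        let first = futures::executor::block_on(stream.next())
            .expect("stream should yield one item")
            .expect("item should be Ok");
        assert!(first.turn_complete);
        assert_eq!(first.finish_reason, Some(FinishReason::Stop));
    }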
}