1use crate::{Result, types::Content};
2use async_trait::async_trait;
3use futures::stream::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::pin::Pin;
7
/// Boxed, sendable stream of [`LlmResponse`] results, as returned by
/// [`Llm::generate_content`].
pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
9
/// Common interface implemented by every model backend.
#[async_trait]
pub trait Llm: Send + Sync {
    /// Identifier for this backend (presumably the model/provider name —
    /// confirm with implementors).
    fn name(&self) -> &str;

    /// Sends `req` to the model and returns the response(s) as a stream.
    /// `stream` selects streaming vs. single-shot generation; either way the
    /// result is delivered through an [`LlmResponseStream`].
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
}
15
/// A request passed to [`Llm::generate_content`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Target model identifier.
    pub model: String,
    /// Ordered prompt / conversation contents.
    pub contents: Vec<Content>,
    /// Optional generation parameters; `None` means provider defaults.
    pub config: Option<GenerateContentConfig>,
    /// Tool definitions keyed by name (presumably JSON tool declarations —
    /// confirm against callers). Skipped by serde: never serialized, and
    /// deserialization always yields an empty map.
    #[serde(skip)]
    pub tools: HashMap<String, serde_json::Value>,
}
24
/// Provider-agnostic generation parameters carried by an [`LlmRequest`].
///
/// NOTE(review): the first three fields and `max_output_tokens` serialize as
/// `null` when unset, while the rest are omitted entirely — confirm this
/// asymmetry is intentional for the target wire format.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerateContentConfig {
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus (top-p) sampling mass.
    pub top_p: Option<f32>,
    /// Top-k sampling cutoff.
    pub top_k: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub presence_penalty: Option<f32>,
    /// Hard cap on generated tokens.
    pub max_output_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub top_logprobs: Option<u8>,
    /// Sequences that stop generation when emitted.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    /// JSON schema constraining structured output (see
    /// [`LlmRequest::with_response_schema`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_schema: Option<serde_json::Value>,

    /// Name of a server-side context cache to reuse (see [`CacheCapable`]).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cached_content: Option<String>,

    /// Provider-specific passthrough options, keyed by provider
    /// (e.g. "openrouter" in the tests below).
    #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub extensions: serde_json::Map<String, serde_json::Value>,
}
53
/// A single model response — either a complete turn or one streamed chunk.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Generated content, if any was produced.
    pub content: Option<Content>,
    /// Token/cost accounting for this response.
    pub usage_metadata: Option<UsageMetadata>,
    /// Why generation stopped, when known.
    pub finish_reason: Option<FinishReason>,
    /// Source citations attached to the content, when the provider emits them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
    /// Whether this is an incremental (streamed) piece rather than a full reply.
    pub partial: bool,
    /// Whether the model's turn is finished after this response.
    pub turn_complete: bool,
    /// Whether generation was cut off before completing.
    pub interrupted: bool,
    /// Provider error code, if the response represents a failure.
    pub error_code: Option<String>,
    /// Human-readable error description accompanying `error_code`.
    pub error_message: Option<String>,
    /// Raw provider-specific metadata (opaque JSON).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provider_metadata: Option<serde_json::Value>,
}
69
/// Optional capability for backends that support server-side context caching.
#[async_trait]
pub trait CacheCapable: Send + Sync {
    /// Creates a cache entry holding `system_instruction` and the given tool
    /// declarations, expiring after `ttl_seconds`. Returns the cache name
    /// (presumably the value to place in
    /// `GenerateContentConfig::cached_content` — confirm with implementors).
    async fn create_cache(
        &self,
        system_instruction: &str,
        tools: &HashMap<String, serde_json::Value>,
        ttl_seconds: u32,
    ) -> Result<String>;

    /// Deletes the cache entry identified by `name`.
    async fn delete_cache(&self, name: &str) -> Result<()>;
}
107
/// Client-side policy controlling when and how long context caching is used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextCacheConfig {
    /// Minimum prompt size, in tokens, before caching is attempted
    /// (default 4096).
    pub min_tokens: u32,

    /// Cache entry time-to-live in seconds (default 600).
    pub ttl_seconds: u32,

    /// Default 3. NOTE(review): unit is not evident from this file — confirm
    /// whether this counts requests/turns between cache refreshes.
    pub cache_intervals: u32,
}
140
141impl Default for ContextCacheConfig {
142 fn default() -> Self {
143 Self { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
144 }
145}
146
/// Token, cost, and provider-specific usage accounting for a response.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageMetadata {
    /// Tokens consumed by the prompt.
    pub prompt_token_count: i32,
    /// Tokens produced in the candidate output(s).
    pub candidates_token_count: i32,
    /// Total tokens for the request/response pair.
    pub total_token_count: i32,

    /// Prompt tokens served from a context cache, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_read_input_token_count: Option<i32>,

    /// Tokens spent writing to a context cache, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_creation_input_token_count: Option<i32>,

    /// Tokens spent on model "thinking"/reasoning, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thinking_token_count: Option<i32>,

    /// Audio tokens in the input, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_input_token_count: Option<i32>,

    /// Audio tokens in the output, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_output_token_count: Option<i32>,

    /// Monetary cost of the call, when reported (unit/currency not evident
    /// from this file — presumably USD; confirm with providers).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cost: Option<f64>,

    /// Whether the call used a bring-your-own-key credential, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub is_byok: Option<bool>,

    /// Raw provider-specific usage breakdown (opaque JSON).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provider_usage: Option<serde_json::Value>,
}
177
/// Collection of citations attached to a response; serialized in camelCase
/// to match the provider wire format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    /// Individual citations; missing on the wire deserializes as empty.
    #[serde(default)]
    pub citation_sources: Vec<CitationSource>,
}
185
/// A single cited source, all fields optional; camelCase on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    /// URI of the cited material.
    pub uri: Option<String>,
    /// Title of the cited material.
    pub title: Option<String>,
    /// Start of the cited span (presumably an offset into the generated
    /// content — confirm units with the provider docs).
    pub start_index: Option<i32>,
    /// End of the cited span (see `start_index` note).
    pub end_index: Option<i32>,
    /// License of the cited material, when known.
    pub license: Option<String>,
    /// Publication date string (tests use RFC 3339, e.g. "2026-01-01T00:00:00Z").
    pub publication_date: Option<String>,
}
197
/// Why the model stopped generating.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    /// Natural end of generation.
    Stop,
    /// Hit the output-token limit.
    MaxTokens,
    /// Stopped by a safety filter.
    Safety,
    /// Stopped for recitation of source material.
    Recitation,
    /// Any other provider-specific reason.
    Other,
}
206
207impl LlmRequest {
208 pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
209 Self { model: model.into(), contents, config: None, tools: HashMap::new() }
210 }
211
212 pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
214 let config = self.config.get_or_insert(GenerateContentConfig::default());
215 config.response_schema = Some(schema);
216 self
217 }
218
219 pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
221 self.config = Some(config);
222 self
223 }
224}
225
226impl LlmResponse {
227 pub fn new(content: Content) -> Self {
228 Self {
229 content: Some(content),
230 usage_metadata: None,
231 finish_reason: Some(FinishReason::Stop),
232 citation_metadata: None,
233 partial: false,
234 turn_complete: true,
235 interrupted: false,
236 error_code: None,
237 error_message: None,
238 provider_metadata: None,
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    // `new` stores model name and contents untouched.
    #[test]
    fn test_llm_request_creation() {
        let req = LlmRequest::new("test-model", vec![]);
        assert_eq!(req.model, "test-model");
        assert!(req.contents.is_empty());
    }

    // `with_response_schema` lazily creates a config and stores the schema.
    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let req = LlmRequest::new("test-model", vec![]).with_response_schema(schema.clone());

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert!(config.response_schema.is_some());
        assert_eq!(config.response_schema.unwrap(), schema);
    }

    // `with_config` stores the full config verbatim, including every field
    // set here.
    #[test]
    fn test_llm_request_with_config() {
        let config = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            frequency_penalty: Some(0.2),
            presence_penalty: Some(-0.3),
            max_output_tokens: Some(1024),
            seed: Some(42),
            top_logprobs: Some(5),
            stop_sequences: vec!["END".to_string()],
            ..Default::default()
        };
        let req = LlmRequest::new("test-model", vec![]).with_config(config);

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_output_tokens, Some(1024));
        assert_eq!(config.frequency_penalty, Some(0.2));
        assert_eq!(config.presence_penalty, Some(-0.3));
        assert_eq!(config.seed, Some(42));
        assert_eq!(config.top_logprobs, Some(5));
        assert_eq!(config.stop_sequences, vec!["END"]);
    }

    // `LlmResponse::new` marks a complete, non-partial turn with Stop finish.
    #[test]
    fn test_llm_response_creation() {
        let content = Content::new("assistant");
        let resp = LlmResponse::new(content);
        assert!(resp.content.is_some());
        assert!(resp.turn_complete);
        assert!(!resp.partial);
        assert_eq!(resp.finish_reason, Some(FinishReason::Stop));
        assert!(resp.citation_metadata.is_none());
        assert!(resp.provider_metadata.is_none());
    }

    // Absent `citation_metadata` on the wire deserializes to `None` rather
    // than failing.
    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let json = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let response: LlmResponse = serde_json::from_value(json).expect("should deserialize");
        assert!(response.citation_metadata.is_none());
    }

    // Citation metadata survives a serialize/deserialize round trip intact.
    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: Some(CitationMetadata {
                citation_sources: vec![CitationSource {
                    uri: Some("https://example.com".to_string()),
                    title: Some("Example".to_string()),
                    start_index: Some(0),
                    end_index: Some(5),
                    license: None,
                    publication_date: Some("2026-01-01T00:00:00Z".to_string()),
                }],
            }),
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.citation_metadata, response.citation_metadata);
    }

    // Every config field — including provider-specific `extensions` — round
    // trips through JSON unchanged.
    #[test]
    fn test_generate_content_config_roundtrip_with_extensions() {
        let mut extensions = serde_json::Map::new();
        extensions.insert(
            "openrouter".to_string(),
            serde_json::json!({
                "provider": {
                    "zdr": true,
                    "order": ["openai", "anthropic"]
                },
                "plugins": [
                    { "id": "web", "enabled": true }
                ]
            }),
        );

        let config = GenerateContentConfig {
            temperature: Some(0.4),
            top_p: Some(0.8),
            top_k: Some(12),
            frequency_penalty: Some(0.1),
            presence_penalty: Some(0.2),
            max_output_tokens: Some(512),
            seed: Some(7),
            top_logprobs: Some(3),
            stop_sequences: vec!["STOP".to_string(), "DONE".to_string()],
            response_schema: Some(serde_json::json!({
                "type": "object",
                "properties": { "answer": { "type": "string" } },
                "required": ["answer"]
            })),
            cached_content: Some("cachedContents/abc123".to_string()),
            extensions,
        };

        let encoded = serde_json::to_string(&config).expect("serialize");
        let decoded: GenerateContentConfig = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.temperature, config.temperature);
        assert_eq!(decoded.top_p, config.top_p);
        assert_eq!(decoded.top_k, config.top_k);
        assert_eq!(decoded.frequency_penalty, config.frequency_penalty);
        assert_eq!(decoded.presence_penalty, config.presence_penalty);
        assert_eq!(decoded.max_output_tokens, config.max_output_tokens);
        assert_eq!(decoded.seed, config.seed);
        assert_eq!(decoded.top_logprobs, config.top_logprobs);
        assert_eq!(decoded.stop_sequences, config.stop_sequences);
        assert_eq!(decoded.response_schema, config.response_schema);
        assert_eq!(decoded.cached_content, config.cached_content);
        assert_eq!(decoded.extensions, config.extensions);
    }

    // Provider metadata and the optional usage fields (cost, BYOK flag,
    // provider usage JSON) round trip unchanged.
    #[test]
    fn test_llm_response_and_usage_roundtrip_with_provider_metadata() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: Some(UsageMetadata {
                prompt_token_count: 10,
                candidates_token_count: 20,
                total_token_count: 30,
                cache_read_input_token_count: Some(5),
                cache_creation_input_token_count: Some(2),
                thinking_token_count: Some(3),
                audio_input_token_count: Some(4),
                audio_output_token_count: Some(6),
                cost: Some(0.0125),
                is_byok: Some(true),
                provider_usage: Some(serde_json::json!({
                    "server_tool_use": {
                        "web_search_requests": 1
                    },
                    "prompt_tokens_details": {
                        "video_tokens": 8
                    }
                })),
            }),
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: Some(serde_json::json!({
                "openrouter": {
                    "responseId": "resp_123",
                    "outputItems": 2
                }
            })),
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.provider_metadata, response.provider_metadata);
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.cost),
            response.usage_metadata.as_ref().and_then(|u| u.cost),
        );
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.is_byok),
            response.usage_metadata.as_ref().and_then(|u| u.is_byok),
        );
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.provider_usage.clone()),
            response.usage_metadata.as_ref().and_then(|u| u.provider_usage.clone()),
        );
    }

    // Derived PartialEq distinguishes variants as expected.
    #[test]
    fn test_finish_reason() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
}