1use crate::schema_adapter::{GenericSchemaAdapter, SchemaAdapter};
2use crate::{Result, types::Content};
3use async_trait::async_trait;
4use futures::stream::Stream;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::pin::Pin;
8
9pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
11
12#[async_trait]
17pub trait Llm: Send + Sync {
18 fn name(&self) -> &str;
20 async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
22
23 fn schema_adapter(&self) -> &dyn SchemaAdapter {
33 &GenericSchemaAdapter
34 }
35
36 fn uses_interactions_api(&self) -> bool {
46 false
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct LlmRequest {
53 pub model: String,
55 pub contents: Vec<Content>,
57 pub config: Option<GenerateContentConfig>,
59 #[serde(skip)]
61 pub tools: HashMap<String, serde_json::Value>,
62 #[serde(skip_serializing_if = "Option::is_none", default)]
67 pub previous_response_id: Option<String>,
68}
69
70#[derive(Debug, Clone, Default, Serialize, Deserialize)]
72pub struct GenerateContentConfig {
73 pub temperature: Option<f32>,
75 pub top_p: Option<f32>,
77 pub top_k: Option<i32>,
79 #[serde(skip_serializing_if = "Option::is_none", default)]
81 pub frequency_penalty: Option<f32>,
82 #[serde(skip_serializing_if = "Option::is_none", default)]
84 pub presence_penalty: Option<f32>,
85 pub max_output_tokens: Option<i32>,
87 #[serde(skip_serializing_if = "Option::is_none", default)]
89 pub seed: Option<i64>,
90 #[serde(skip_serializing_if = "Option::is_none", default)]
92 pub top_logprobs: Option<u8>,
93 #[serde(default, skip_serializing_if = "Vec::is_empty")]
95 pub stop_sequences: Vec<String>,
96 #[serde(skip_serializing_if = "Option::is_none")]
98 pub response_schema: Option<serde_json::Value>,
99
100 #[serde(skip_serializing_if = "Option::is_none", default)]
103 pub cached_content: Option<String>,
104
105 #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
107 pub extensions: serde_json::Map<String, serde_json::Value>,
108}
109
110#[derive(Debug, Clone, Default, Serialize, Deserialize)]
112pub struct LlmResponse {
113 pub content: Option<Content>,
115 pub usage_metadata: Option<UsageMetadata>,
117 pub finish_reason: Option<FinishReason>,
119 #[serde(skip_serializing_if = "Option::is_none")]
121 pub citation_metadata: Option<CitationMetadata>,
122 pub partial: bool,
124 pub turn_complete: bool,
126 pub interrupted: bool,
128 pub error_code: Option<String>,
130 pub error_message: Option<String>,
132 #[serde(skip_serializing_if = "Option::is_none", default)]
134 pub provider_metadata: Option<serde_json::Value>,
135 #[serde(skip_serializing_if = "Option::is_none", default)]
139 pub interaction_id: Option<String>,
140}
141
142#[async_trait]
162pub trait CacheCapable: Send + Sync {
163 async fn create_cache(
170 &self,
171 system_instruction: &str,
172 tools: &HashMap<String, serde_json::Value>,
173 ttl_seconds: u32,
174 ) -> Result<String>;
175
176 async fn delete_cache(&self, name: &str) -> Result<()>;
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct ContextCacheConfig {
199 pub min_tokens: u32,
202
203 pub ttl_seconds: u32,
206
207 pub cache_intervals: u32,
211}
212
213impl Default for ContextCacheConfig {
214 fn default() -> Self {
215 Self { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
216 }
217}
218
219#[derive(Debug, Clone, Default, Serialize, Deserialize)]
221pub struct UsageMetadata {
222 pub prompt_token_count: i32,
224 pub candidates_token_count: i32,
226 pub total_token_count: i32,
228
229 #[serde(skip_serializing_if = "Option::is_none", default)]
231 pub cache_read_input_token_count: Option<i32>,
232
233 #[serde(skip_serializing_if = "Option::is_none", default)]
235 pub cache_creation_input_token_count: Option<i32>,
236
237 #[serde(skip_serializing_if = "Option::is_none", default)]
239 pub thinking_token_count: Option<i32>,
240
241 #[serde(skip_serializing_if = "Option::is_none", default)]
243 pub audio_input_token_count: Option<i32>,
244
245 #[serde(skip_serializing_if = "Option::is_none", default)]
247 pub audio_output_token_count: Option<i32>,
248
249 #[serde(skip_serializing_if = "Option::is_none", default)]
251 pub cost: Option<f64>,
252
253 #[serde(skip_serializing_if = "Option::is_none", default)]
255 pub is_byok: Option<bool>,
256
257 #[serde(skip_serializing_if = "Option::is_none", default)]
259 pub provider_usage: Option<serde_json::Value>,
260}
261
262#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
264#[serde(rename_all = "camelCase")]
265pub struct CitationMetadata {
266 #[serde(default)]
268 pub citation_sources: Vec<CitationSource>,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
273#[serde(rename_all = "camelCase")]
274pub struct CitationSource {
275 pub uri: Option<String>,
277 pub title: Option<String>,
279 pub start_index: Option<i32>,
281 pub end_index: Option<i32>,
283 pub license: Option<String>,
285 pub publication_date: Option<String>,
287}
288
289#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
291pub enum FinishReason {
292 Stop,
294 MaxTokens,
296 Safety,
298 Recitation,
300 Other,
302}
303
304impl LlmRequest {
305 pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
307 Self {
308 model: model.into(),
309 contents,
310 config: None,
311 tools: HashMap::new(),
312 previous_response_id: None,
313 }
314 }
315
316 pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
318 let config = self.config.get_or_insert(GenerateContentConfig::default());
319 config.response_schema = Some(schema);
320 self
321 }
322
323 pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
325 self.config = Some(config);
326 self
327 }
328}
329
330impl LlmResponse {
331 pub fn new(content: Content) -> Self {
333 Self {
334 content: Some(content),
335 usage_metadata: None,
336 finish_reason: Some(FinishReason::Stop),
337 citation_metadata: None,
338 partial: false,
339 turn_complete: true,
340 interrupted: false,
341 error_code: None,
342 error_message: None,
343 provider_metadata: None,
344 interaction_id: None,
345 }
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` to a JSON string and parses it back into a new value.
    fn json_roundtrip<T>(value: &T) -> T
    where
        T: Serialize + serde::de::DeserializeOwned,
    {
        let encoded = serde_json::to_string(value).expect("serialize");
        serde_json::from_str(&encoded).expect("deserialize")
    }

    /// `LlmRequest::new` stores the model name and the given contents.
    #[test]
    fn test_llm_request_creation() {
        let request = LlmRequest::new("test-model", vec![]);

        assert_eq!(request.model, "test-model");
        assert!(request.contents.is_empty());
    }

    /// `with_response_schema` creates a config holding the given schema.
    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });

        let request = LlmRequest::new("test-model", vec![]).with_response_schema(schema.clone());

        let stored = request.config.and_then(|cfg| cfg.response_schema);
        assert_eq!(stored, Some(schema));
    }

    /// `with_config` stores the supplied config unchanged.
    #[test]
    fn test_llm_request_with_config() {
        let supplied = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            frequency_penalty: Some(0.2),
            presence_penalty: Some(-0.3),
            max_output_tokens: Some(1024),
            seed: Some(42),
            top_logprobs: Some(5),
            stop_sequences: vec!["END".to_string()],
            ..Default::default()
        };

        let request = LlmRequest::new("test-model", vec![]).with_config(supplied);

        // `unwrap` doubles as the "config is present" check.
        let applied = request.config.unwrap();
        assert_eq!(applied.temperature, Some(0.7));
        assert_eq!(applied.max_output_tokens, Some(1024));
        assert_eq!(applied.frequency_penalty, Some(0.2));
        assert_eq!(applied.presence_penalty, Some(-0.3));
        assert_eq!(applied.seed, Some(42));
        assert_eq!(applied.top_logprobs, Some(5));
        assert_eq!(applied.stop_sequences, vec!["END"]);
    }

    /// `LlmResponse::new` yields a complete, non-partial `Stop` response.
    #[test]
    fn test_llm_response_creation() {
        let response = LlmResponse::new(Content::new("assistant"));

        assert!(response.content.is_some());
        assert!(response.turn_complete);
        assert!(!response.partial);
        assert_eq!(response.finish_reason, Some(FinishReason::Stop));
        assert!(response.citation_metadata.is_none());
        assert!(response.provider_metadata.is_none());
    }

    /// A payload without citation data still deserializes.
    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let payload = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let parsed: LlmResponse = serde_json::from_value(payload).expect("should deserialize");

        assert!(parsed.citation_metadata.is_none());
    }

    /// Citation metadata survives a JSON round trip.
    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let source = CitationSource {
            uri: Some("https://example.com".to_string()),
            title: Some("Example".to_string()),
            start_index: Some(0),
            end_index: Some(5),
            license: None,
            publication_date: Some("2026-01-01T00:00:00Z".to_string()),
        };
        let original = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: Some(CitationMetadata { citation_sources: vec![source] }),
            turn_complete: true,
            ..Default::default()
        };

        let restored = json_roundtrip(&original);

        assert_eq!(restored.citation_metadata, original.citation_metadata);
    }

    /// Every config field, including extensions, survives a JSON round trip.
    #[test]
    fn test_generate_content_config_roundtrip_with_extensions() {
        let openrouter = serde_json::json!({
            "provider": {
                "zdr": true,
                "order": ["openai", "anthropic"]
            },
            "plugins": [
                { "id": "web", "enabled": true }
            ]
        });
        let mut extensions = serde_json::Map::new();
        extensions.insert("openrouter".to_string(), openrouter);

        let original = GenerateContentConfig {
            temperature: Some(0.4),
            top_p: Some(0.8),
            top_k: Some(12),
            frequency_penalty: Some(0.1),
            presence_penalty: Some(0.2),
            max_output_tokens: Some(512),
            seed: Some(7),
            top_logprobs: Some(3),
            stop_sequences: vec!["STOP".to_string(), "DONE".to_string()],
            response_schema: Some(serde_json::json!({
                "type": "object",
                "properties": { "answer": { "type": "string" } },
                "required": ["answer"]
            })),
            cached_content: Some("cachedContents/abc123".to_string()),
            extensions,
        };

        let restored = json_roundtrip(&original);

        // The config has no `PartialEq`, so compare field by field.
        assert_eq!(restored.temperature, original.temperature);
        assert_eq!(restored.top_p, original.top_p);
        assert_eq!(restored.top_k, original.top_k);
        assert_eq!(restored.frequency_penalty, original.frequency_penalty);
        assert_eq!(restored.presence_penalty, original.presence_penalty);
        assert_eq!(restored.max_output_tokens, original.max_output_tokens);
        assert_eq!(restored.seed, original.seed);
        assert_eq!(restored.top_logprobs, original.top_logprobs);
        assert_eq!(restored.stop_sequences, original.stop_sequences);
        assert_eq!(restored.response_schema, original.response_schema);
        assert_eq!(restored.cached_content, original.cached_content);
        assert_eq!(restored.extensions, original.extensions);
    }

    /// Provider metadata and the extended usage fields survive a round trip.
    #[test]
    fn test_llm_response_and_usage_roundtrip_with_provider_metadata() {
        let usage = UsageMetadata {
            prompt_token_count: 10,
            candidates_token_count: 20,
            total_token_count: 30,
            cache_read_input_token_count: Some(5),
            cache_creation_input_token_count: Some(2),
            thinking_token_count: Some(3),
            audio_input_token_count: Some(4),
            audio_output_token_count: Some(6),
            cost: Some(0.0125),
            is_byok: Some(true),
            provider_usage: Some(serde_json::json!({
                "server_tool_use": {
                    "web_search_requests": 1
                },
                "prompt_tokens_details": {
                    "video_tokens": 8
                }
            })),
        };
        let original = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: Some(usage),
            finish_reason: Some(FinishReason::Stop),
            turn_complete: true,
            provider_metadata: Some(serde_json::json!({
                "openrouter": {
                    "responseId": "resp_123",
                    "outputItems": 2
                }
            })),
            ..Default::default()
        };

        let restored = json_roundtrip(&original);

        assert_eq!(restored.provider_metadata, original.provider_metadata);

        let restored_usage = restored.usage_metadata.as_ref();
        let original_usage = original.usage_metadata.as_ref();
        assert_eq!(
            restored_usage.and_then(|u| u.cost),
            original_usage.and_then(|u| u.cost),
        );
        assert_eq!(
            restored_usage.and_then(|u| u.is_byok),
            original_usage.and_then(|u| u.is_byok),
        );
        assert_eq!(
            restored_usage.and_then(|u| u.provider_usage.clone()),
            original_usage.and_then(|u| u.provider_usage.clone()),
        );
    }

    /// `FinishReason` equality distinguishes variants.
    #[test]
    fn test_finish_reason() {
        let stop = FinishReason::Stop;

        assert_eq!(stop, FinishReason::Stop);
        assert_ne!(stop, FinishReason::MaxTokens);
    }
}