1use crate::schema_adapter::{GenericSchemaAdapter, SchemaAdapter};
2use crate::{Result, types::Content};
3use async_trait::async_trait;
4use futures::stream::Stream;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::pin::Pin;
8
9pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
11
12#[async_trait]
17pub trait Llm: Send + Sync {
18 fn name(&self) -> &str;
20 async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
22
23 fn schema_adapter(&self) -> &dyn SchemaAdapter {
33 &GenericSchemaAdapter
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct LlmRequest {
40 pub model: String,
42 pub contents: Vec<Content>,
44 pub config: Option<GenerateContentConfig>,
46 #[serde(skip)]
48 pub tools: HashMap<String, serde_json::Value>,
49}
50
51#[derive(Debug, Clone, Default, Serialize, Deserialize)]
53pub struct GenerateContentConfig {
54 pub temperature: Option<f32>,
56 pub top_p: Option<f32>,
58 pub top_k: Option<i32>,
60 #[serde(skip_serializing_if = "Option::is_none", default)]
62 pub frequency_penalty: Option<f32>,
63 #[serde(skip_serializing_if = "Option::is_none", default)]
65 pub presence_penalty: Option<f32>,
66 pub max_output_tokens: Option<i32>,
68 #[serde(skip_serializing_if = "Option::is_none", default)]
70 pub seed: Option<i64>,
71 #[serde(skip_serializing_if = "Option::is_none", default)]
73 pub top_logprobs: Option<u8>,
74 #[serde(default, skip_serializing_if = "Vec::is_empty")]
76 pub stop_sequences: Vec<String>,
77 #[serde(skip_serializing_if = "Option::is_none")]
79 pub response_schema: Option<serde_json::Value>,
80
81 #[serde(skip_serializing_if = "Option::is_none", default)]
84 pub cached_content: Option<String>,
85
86 #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
88 pub extensions: serde_json::Map<String, serde_json::Value>,
89}
90
91#[derive(Debug, Clone, Default, Serialize, Deserialize)]
93pub struct LlmResponse {
94 pub content: Option<Content>,
96 pub usage_metadata: Option<UsageMetadata>,
98 pub finish_reason: Option<FinishReason>,
100 #[serde(skip_serializing_if = "Option::is_none")]
102 pub citation_metadata: Option<CitationMetadata>,
103 pub partial: bool,
105 pub turn_complete: bool,
107 pub interrupted: bool,
109 pub error_code: Option<String>,
111 pub error_message: Option<String>,
113 #[serde(skip_serializing_if = "Option::is_none", default)]
115 pub provider_metadata: Option<serde_json::Value>,
116}
117
118#[async_trait]
138pub trait CacheCapable: Send + Sync {
139 async fn create_cache(
146 &self,
147 system_instruction: &str,
148 tools: &HashMap<String, serde_json::Value>,
149 ttl_seconds: u32,
150 ) -> Result<String>;
151
152 async fn delete_cache(&self, name: &str) -> Result<()>;
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct ContextCacheConfig {
175 pub min_tokens: u32,
178
179 pub ttl_seconds: u32,
182
183 pub cache_intervals: u32,
187}
188
189impl Default for ContextCacheConfig {
190 fn default() -> Self {
191 Self { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
192 }
193}
194
195#[derive(Debug, Clone, Default, Serialize, Deserialize)]
197pub struct UsageMetadata {
198 pub prompt_token_count: i32,
200 pub candidates_token_count: i32,
202 pub total_token_count: i32,
204
205 #[serde(skip_serializing_if = "Option::is_none", default)]
207 pub cache_read_input_token_count: Option<i32>,
208
209 #[serde(skip_serializing_if = "Option::is_none", default)]
211 pub cache_creation_input_token_count: Option<i32>,
212
213 #[serde(skip_serializing_if = "Option::is_none", default)]
215 pub thinking_token_count: Option<i32>,
216
217 #[serde(skip_serializing_if = "Option::is_none", default)]
219 pub audio_input_token_count: Option<i32>,
220
221 #[serde(skip_serializing_if = "Option::is_none", default)]
223 pub audio_output_token_count: Option<i32>,
224
225 #[serde(skip_serializing_if = "Option::is_none", default)]
227 pub cost: Option<f64>,
228
229 #[serde(skip_serializing_if = "Option::is_none", default)]
231 pub is_byok: Option<bool>,
232
233 #[serde(skip_serializing_if = "Option::is_none", default)]
235 pub provider_usage: Option<serde_json::Value>,
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
240#[serde(rename_all = "camelCase")]
241pub struct CitationMetadata {
242 #[serde(default)]
244 pub citation_sources: Vec<CitationSource>,
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
249#[serde(rename_all = "camelCase")]
250pub struct CitationSource {
251 pub uri: Option<String>,
253 pub title: Option<String>,
255 pub start_index: Option<i32>,
257 pub end_index: Option<i32>,
259 pub license: Option<String>,
261 pub publication_date: Option<String>,
263}
264
265#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
267pub enum FinishReason {
268 Stop,
270 MaxTokens,
272 Safety,
274 Recitation,
276 Other,
278}
279
280impl LlmRequest {
281 pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
283 Self { model: model.into(), contents, config: None, tools: HashMap::new() }
284 }
285
286 pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
288 let config = self.config.get_or_insert(GenerateContentConfig::default());
289 config.response_schema = Some(schema);
290 self
291 }
292
293 pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
295 self.config = Some(config);
296 self
297 }
298}
299
300impl LlmResponse {
301 pub fn new(content: Content) -> Self {
303 Self {
304 content: Some(content),
305 usage_metadata: None,
306 finish_reason: Some(FinishReason::Stop),
307 citation_metadata: None,
308 partial: false,
309 turn_complete: true,
310 interrupted: false,
311 error_code: None,
312 error_message: None,
313 provider_metadata: None,
314 }
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// `LlmRequest::new` keeps the model name and the (empty) contents it was given.
    #[test]
    fn test_llm_request_creation() {
        let request = LlmRequest::new("test-model", Vec::new());

        assert_eq!(request.model, "test-model");
        assert!(request.contents.is_empty());
    }

    /// `with_response_schema` creates a config when none exists and stores the schema in it.
    #[test]
    fn test_llm_request_with_response_schema() {
        let expected = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });

        let request =
            LlmRequest::new("test-model", Vec::new()).with_response_schema(expected.clone());

        let stored = request.config.expect("with_response_schema should create a config");
        assert_eq!(stored.response_schema, Some(expected));
    }

    /// `with_config` stores the supplied config unchanged.
    #[test]
    fn test_llm_request_with_config() {
        let request = LlmRequest::new("test-model", Vec::new()).with_config(GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            frequency_penalty: Some(0.2),
            presence_penalty: Some(-0.3),
            max_output_tokens: Some(1024),
            seed: Some(42),
            top_logprobs: Some(5),
            stop_sequences: vec![String::from("END")],
            ..GenerateContentConfig::default()
        });

        let stored = request.config.expect("with_config should store the config");
        assert_eq!(stored.temperature, Some(0.7));
        assert_eq!(stored.max_output_tokens, Some(1024));
        assert_eq!(stored.frequency_penalty, Some(0.2));
        assert_eq!(stored.presence_penalty, Some(-0.3));
        assert_eq!(stored.seed, Some(42));
        assert_eq!(stored.top_logprobs, Some(5));
        assert_eq!(stored.stop_sequences, ["END"]);
    }

    /// `LlmResponse::new` yields a complete, non-partial response that finished with `Stop`.
    #[test]
    fn test_llm_response_creation() {
        let response = LlmResponse::new(Content::new("assistant"));

        assert!(response.content.is_some());
        assert!(response.turn_complete);
        assert!(!response.partial);
        assert_eq!(response.finish_reason, Some(FinishReason::Stop));
        assert!(response.citation_metadata.is_none());
        assert!(response.provider_metadata.is_none());
    }

    /// A payload without citation data still deserializes, leaving `citation_metadata` unset.
    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let payload = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let parsed: LlmResponse = serde_json::from_value(payload).expect("should deserialize");

        assert!(parsed.citation_metadata.is_none());
    }

    /// Citation metadata survives a serialize/deserialize round trip.
    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let source = CitationSource {
            uri: Some(String::from("https://example.com")),
            title: Some(String::from("Example")),
            start_index: Some(0),
            end_index: Some(5),
            license: None,
            publication_date: Some(String::from("2026-01-01T00:00:00Z")),
        };
        let original = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: Some(CitationMetadata { citation_sources: vec![source] }),
            turn_complete: true,
            // Everything else stays at its default: no usage, no error, not partial/interrupted.
            ..LlmResponse::default()
        };

        let wire = serde_json::to_string(&original).expect("serialize");
        let restored: LlmResponse = serde_json::from_str(&wire).expect("deserialize");

        assert_eq!(restored.citation_metadata, original.citation_metadata);
    }

    /// Every `GenerateContentConfig` field, including `extensions`, survives a round trip.
    #[test]
    fn test_generate_content_config_roundtrip_with_extensions() {
        let openrouter_settings = serde_json::json!({
            "provider": {
                "zdr": true,
                "order": ["openai", "anthropic"]
            },
            "plugins": [
                { "id": "web", "enabled": true }
            ]
        });
        let mut extensions = serde_json::Map::new();
        extensions.insert(String::from("openrouter"), openrouter_settings);

        let answer_schema = serde_json::json!({
            "type": "object",
            "properties": { "answer": { "type": "string" } },
            "required": ["answer"]
        });

        let original = GenerateContentConfig {
            temperature: Some(0.4),
            top_p: Some(0.8),
            top_k: Some(12),
            frequency_penalty: Some(0.1),
            presence_penalty: Some(0.2),
            max_output_tokens: Some(512),
            seed: Some(7),
            top_logprobs: Some(3),
            stop_sequences: vec![String::from("STOP"), String::from("DONE")],
            response_schema: Some(answer_schema),
            cached_content: Some(String::from("cachedContents/abc123")),
            extensions,
        };

        let wire = serde_json::to_string(&original).expect("serialize");
        let restored: GenerateContentConfig = serde_json::from_str(&wire).expect("deserialize");

        assert_eq!(restored.temperature, original.temperature);
        assert_eq!(restored.top_p, original.top_p);
        assert_eq!(restored.top_k, original.top_k);
        assert_eq!(restored.frequency_penalty, original.frequency_penalty);
        assert_eq!(restored.presence_penalty, original.presence_penalty);
        assert_eq!(restored.max_output_tokens, original.max_output_tokens);
        assert_eq!(restored.seed, original.seed);
        assert_eq!(restored.top_logprobs, original.top_logprobs);
        assert_eq!(restored.stop_sequences, original.stop_sequences);
        assert_eq!(restored.response_schema, original.response_schema);
        assert_eq!(restored.cached_content, original.cached_content);
        assert_eq!(restored.extensions, original.extensions);
    }

    /// Provider metadata and the optional usage fields survive a round trip.
    #[test]
    fn test_llm_response_and_usage_roundtrip_with_provider_metadata() {
        let usage = UsageMetadata {
            prompt_token_count: 10,
            candidates_token_count: 20,
            total_token_count: 30,
            cache_read_input_token_count: Some(5),
            cache_creation_input_token_count: Some(2),
            thinking_token_count: Some(3),
            audio_input_token_count: Some(4),
            audio_output_token_count: Some(6),
            cost: Some(0.0125),
            is_byok: Some(true),
            provider_usage: Some(serde_json::json!({
                "server_tool_use": {
                    "web_search_requests": 1
                },
                "prompt_tokens_details": {
                    "video_tokens": 8
                }
            })),
        };
        let original = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: Some(usage),
            finish_reason: Some(FinishReason::Stop),
            turn_complete: true,
            provider_metadata: Some(serde_json::json!({
                "openrouter": {
                    "responseId": "resp_123",
                    "outputItems": 2
                }
            })),
            // No citations, no error, not partial/interrupted.
            ..LlmResponse::default()
        };

        let wire = serde_json::to_string(&original).expect("serialize");
        let restored: LlmResponse = serde_json::from_str(&wire).expect("deserialize");

        assert_eq!(restored.provider_metadata, original.provider_metadata);

        let sent = original.usage_metadata.as_ref();
        let received = restored.usage_metadata.as_ref();
        assert_eq!(received.and_then(|u| u.cost), sent.and_then(|u| u.cost));
        assert_eq!(received.and_then(|u| u.is_byok), sent.and_then(|u| u.is_byok));
        assert_eq!(
            received.and_then(|u| u.provider_usage.clone()),
            sent.and_then(|u| u.provider_usage.clone()),
        );
    }

    /// `FinishReason` variants compare equal to themselves and unequal to each other.
    #[test]
    fn test_finish_reason() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
}