1use crate::schema_adapter::{GenericSchemaAdapter, SchemaAdapter};
2use crate::{Result, types::Content};
3use async_trait::async_trait;
4use futures::stream::Stream;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::pin::Pin;
8
/// Boxed, pinned stream of response chunks produced by an [`Llm`] backend.
pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
10
/// A provider-agnostic large-language-model backend.
///
/// Implementors turn an [`LlmRequest`] into a stream of [`LlmResponse`]
/// chunks; `Send + Sync` so a backend can be shared across tasks.
#[async_trait]
pub trait Llm: Send + Sync {
    /// Identifier for this backend (e.g. a provider or model family name).
    fn name(&self) -> &str;

    /// Sends `req` to the model and returns a stream of response chunks.
    ///
    /// `stream` requests incremental (streaming) output; the return type is
    /// a stream either way, so non-streaming backends can yield one item.
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;

    /// Adapter used to translate response schemas for this backend.
    fn schema_adapter(&self) -> &dyn SchemaAdapter {
        // `GenericSchemaAdapter` here is a value expression (presumably a
        // unit struct); the reference is const-promoted to `'static`, which
        // is what lets a default method return a borrow with no `self` data.
        &GenericSchemaAdapter
    }
}
29
/// A single generation request sent to an [`Llm`] backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Provider-specific model identifier.
    pub model: String,
    /// Conversation contents (prompt / history), in order.
    pub contents: Vec<Content>,
    /// Optional generation parameters; `None` means provider defaults.
    pub config: Option<GenerateContentConfig>,
    /// Tool declarations keyed by tool name. Excluded from (de)serialization
    /// (`#[serde(skip)]`), so this is always empty after deserializing;
    /// presumably backends attach tools to the wire request separately.
    #[serde(skip)]
    pub tools: HashMap<String, serde_json::Value>,
}
38
/// Generation parameters for a request. All fields are optional; unset
/// fields fall back to provider defaults, and most are omitted from the
/// serialized form when absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerateContentConfig {
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability mass cutoff.
    pub top_p: Option<f32>,
    /// Top-k sampling cutoff.
    pub top_k: Option<i32>,
    /// Penalty applied to frequently repeated tokens.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub frequency_penalty: Option<f32>,
    /// Penalty applied to tokens already present in the output.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub presence_penalty: Option<f32>,
    /// Hard cap on generated tokens.
    pub max_output_tokens: Option<i32>,
    /// Seed for reproducible sampling, where the provider supports it.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub seed: Option<i64>,
    /// Number of log-probabilities to return per token.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub top_logprobs: Option<u8>,
    /// Sequences that terminate generation when produced.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    /// JSON schema the response must conform to (structured output).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_schema: Option<serde_json::Value>,

    /// Identifier of a previously created server-side context cache
    /// (see [`CacheCapable::create_cache`]) to reuse for this request.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cached_content: Option<String>,

    /// Provider-specific extension payloads keyed by provider/extension
    /// name (e.g. an `"openrouter"` entry with routing options).
    #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub extensions: serde_json::Map<String, serde_json::Value>,
}
67
/// One chunk of model output. A streaming turn yields several of these
/// (with `partial = true`) before a final chunk with `turn_complete = true`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Generated content for this chunk, if any.
    pub content: Option<Content>,
    /// Token accounting for the request, when the provider reports it.
    pub usage_metadata: Option<UsageMetadata>,
    /// Why generation stopped; typically set on the final chunk.
    pub finish_reason: Option<FinishReason>,
    /// Source citations attached to the content, when provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
    /// True when this chunk is an incremental piece of a streaming turn.
    pub partial: bool,
    /// True when this chunk ends the model's turn.
    pub turn_complete: bool,
    /// True when generation was cut off before completing.
    pub interrupted: bool,
    /// Provider error code, when the chunk represents a failure.
    pub error_code: Option<String>,
    /// Human-readable error description accompanying `error_code`.
    pub error_message: Option<String>,
    /// Raw provider-specific response details, passed through untouched.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provider_metadata: Option<serde_json::Value>,
}
83
/// Optional capability for backends that support server-side context
/// caching of a system instruction plus tool declarations.
#[async_trait]
pub trait CacheCapable: Send + Sync {
    /// Creates a cache entry for `system_instruction` and `tools`, alive
    /// for `ttl_seconds`. Returns the cache's name/identifier — presumably
    /// suitable for `GenerateContentConfig::cached_content` (confirm with
    /// the backend implementation).
    async fn create_cache(
        &self,
        system_instruction: &str,
        tools: &HashMap<String, serde_json::Value>,
        ttl_seconds: u32,
    ) -> Result<String>;

    /// Deletes the cache entry identified by `name`.
    async fn delete_cache(&self, name: &str) -> Result<()>;
}
121
/// Tuning knobs for context caching. Exact semantics depend on the
/// cache-capable backend consuming this config (not visible here).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextCacheConfig {
    /// Minimum prompt size, in tokens, before caching is attempted
    /// (default 4096).
    pub min_tokens: u32,

    /// Time-to-live for created cache entries, in seconds (default 600).
    pub ttl_seconds: u32,

    /// Interval count used by the caching policy (default 3) —
    /// NOTE(review): presumably how many turns a cache entry is reused
    /// before refresh; confirm against the backend that reads it.
    pub cache_intervals: u32,
}
154
155impl Default for ContextCacheConfig {
156 fn default() -> Self {
157 Self { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
158 }
159}
160
/// Token accounting for one request/response, normalized across providers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageMetadata {
    /// Tokens consumed by the prompt/input.
    pub prompt_token_count: i32,
    /// Tokens produced in the candidate(s)/output.
    pub candidates_token_count: i32,
    /// Total tokens billed for the call.
    pub total_token_count: i32,

    /// Prompt tokens served from a context cache, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_read_input_token_count: Option<i32>,

    /// Tokens spent writing a new context-cache entry, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_creation_input_token_count: Option<i32>,

    /// Tokens spent on hidden reasoning/"thinking", when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thinking_token_count: Option<i32>,

    /// Audio tokens in the input, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_input_token_count: Option<i32>,

    /// Audio tokens in the output, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_output_token_count: Option<i32>,

    /// Monetary cost of the call, when the provider reports it
    /// (units/currency are provider-defined — confirm per backend).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cost: Option<f64>,

    /// Whether the call used a bring-your-own-key arrangement, when known.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub is_byok: Option<bool>,

    /// Raw provider-specific usage details, passed through untouched.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provider_usage: Option<serde_json::Value>,
}
191
/// Collection of source citations attached to generated content.
/// Serialized in camelCase to match provider wire formats.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    /// Individual citations; defaults to empty when absent in the payload.
    #[serde(default)]
    pub citation_sources: Vec<CitationSource>,
}
199
/// One cited source for a span of generated content.
/// Serialized in camelCase to match provider wire formats.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    /// URI of the cited source.
    pub uri: Option<String>,
    /// Title of the cited source.
    pub title: Option<String>,
    /// Start of the cited span — presumably an offset into the generated
    /// content; units (bytes vs. chars) are provider-defined.
    pub start_index: Option<i32>,
    /// End of the cited span (see `start_index` caveat).
    pub end_index: Option<i32>,
    /// License of the cited material, when known.
    pub license: Option<String>,
    /// Publication date of the source (tests use RFC 3339 strings).
    pub publication_date: Option<String>,
}
211
/// Why the model stopped generating. Variants serialize under their exact
/// names (no `rename_all` applied).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    /// Natural end of generation.
    Stop,
    /// Output token limit reached.
    MaxTokens,
    /// Halted by a safety filter.
    Safety,
    /// Halted for recitation of source material.
    Recitation,
    /// Any other provider-specific reason.
    Other,
}
220
221impl LlmRequest {
222 pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
223 Self { model: model.into(), contents, config: None, tools: HashMap::new() }
224 }
225
226 pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
228 let config = self.config.get_or_insert(GenerateContentConfig::default());
229 config.response_schema = Some(schema);
230 self
231 }
232
233 pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
235 self.config = Some(config);
236 self
237 }
238}
239
240impl LlmResponse {
241 pub fn new(content: Content) -> Self {
242 Self {
243 content: Some(content),
244 usage_metadata: None,
245 finish_reason: Some(FinishReason::Stop),
246 citation_metadata: None,
247 partial: false,
248 turn_complete: true,
249 interrupted: false,
250 error_code: None,
251 error_message: None,
252 provider_metadata: None,
253 }
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    // `new` stores the model name and contents verbatim.
    #[test]
    fn test_llm_request_creation() {
        let req = LlmRequest::new("test-model", vec![]);
        assert_eq!(req.model, "test-model");
        assert!(req.contents.is_empty());
    }

    // `with_response_schema` lazily creates a config and stores the schema.
    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let req = LlmRequest::new("test-model", vec![]).with_response_schema(schema.clone());

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert!(config.response_schema.is_some());
        assert_eq!(config.response_schema.unwrap(), schema);
    }

    // `with_config` replaces the config wholesale; all fields survive.
    #[test]
    fn test_llm_request_with_config() {
        let config = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            frequency_penalty: Some(0.2),
            presence_penalty: Some(-0.3),
            max_output_tokens: Some(1024),
            seed: Some(42),
            top_logprobs: Some(5),
            stop_sequences: vec!["END".to_string()],
            ..Default::default()
        };
        let req = LlmRequest::new("test-model", vec![]).with_config(config);

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_output_tokens, Some(1024));
        assert_eq!(config.frequency_penalty, Some(0.2));
        assert_eq!(config.presence_penalty, Some(-0.3));
        assert_eq!(config.seed, Some(42));
        assert_eq!(config.top_logprobs, Some(5));
        assert_eq!(config.stop_sequences, vec!["END"]);
    }

    // `LlmResponse::new` marks a complete, non-partial turn with Stop.
    #[test]
    fn test_llm_response_creation() {
        let content = Content::new("assistant");
        let resp = LlmResponse::new(content);
        assert!(resp.content.is_some());
        assert!(resp.turn_complete);
        assert!(!resp.partial);
        assert_eq!(resp.finish_reason, Some(FinishReason::Stop));
        assert!(resp.citation_metadata.is_none());
        assert!(resp.provider_metadata.is_none());
    }

    // Payloads without a `citation_metadata` key still deserialize.
    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let json = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let response: LlmResponse = serde_json::from_value(json).expect("should deserialize");
        assert!(response.citation_metadata.is_none());
    }

    // Citation metadata survives a serialize -> deserialize round trip.
    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: Some(CitationMetadata {
                citation_sources: vec![CitationSource {
                    uri: Some("https://example.com".to_string()),
                    title: Some("Example".to_string()),
                    start_index: Some(0),
                    end_index: Some(5),
                    license: None,
                    publication_date: Some("2026-01-01T00:00:00Z".to_string()),
                }],
            }),
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.citation_metadata, response.citation_metadata);
    }

    // Every config field, including opaque `extensions`, round-trips.
    #[test]
    fn test_generate_content_config_roundtrip_with_extensions() {
        let mut extensions = serde_json::Map::new();
        extensions.insert(
            "openrouter".to_string(),
            serde_json::json!({
                "provider": {
                    "zdr": true,
                    "order": ["openai", "anthropic"]
                },
                "plugins": [
                    { "id": "web", "enabled": true }
                ]
            }),
        );

        let config = GenerateContentConfig {
            temperature: Some(0.4),
            top_p: Some(0.8),
            top_k: Some(12),
            frequency_penalty: Some(0.1),
            presence_penalty: Some(0.2),
            max_output_tokens: Some(512),
            seed: Some(7),
            top_logprobs: Some(3),
            stop_sequences: vec!["STOP".to_string(), "DONE".to_string()],
            response_schema: Some(serde_json::json!({
                "type": "object",
                "properties": { "answer": { "type": "string" } },
                "required": ["answer"]
            })),
            cached_content: Some("cachedContents/abc123".to_string()),
            extensions,
        };

        let encoded = serde_json::to_string(&config).expect("serialize");
        let decoded: GenerateContentConfig = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.temperature, config.temperature);
        assert_eq!(decoded.top_p, config.top_p);
        assert_eq!(decoded.top_k, config.top_k);
        assert_eq!(decoded.frequency_penalty, config.frequency_penalty);
        assert_eq!(decoded.presence_penalty, config.presence_penalty);
        assert_eq!(decoded.max_output_tokens, config.max_output_tokens);
        assert_eq!(decoded.seed, config.seed);
        assert_eq!(decoded.top_logprobs, config.top_logprobs);
        assert_eq!(decoded.stop_sequences, config.stop_sequences);
        assert_eq!(decoded.response_schema, config.response_schema);
        assert_eq!(decoded.cached_content, config.cached_content);
        assert_eq!(decoded.extensions, config.extensions);
    }

    // Provider metadata and extended usage fields round-trip intact.
    #[test]
    fn test_llm_response_and_usage_roundtrip_with_provider_metadata() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: Some(UsageMetadata {
                prompt_token_count: 10,
                candidates_token_count: 20,
                total_token_count: 30,
                cache_read_input_token_count: Some(5),
                cache_creation_input_token_count: Some(2),
                thinking_token_count: Some(3),
                audio_input_token_count: Some(4),
                audio_output_token_count: Some(6),
                cost: Some(0.0125),
                is_byok: Some(true),
                provider_usage: Some(serde_json::json!({
                    "server_tool_use": {
                        "web_search_requests": 1
                    },
                    "prompt_tokens_details": {
                        "video_tokens": 8
                    }
                })),
            }),
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: Some(serde_json::json!({
                "openrouter": {
                    "responseId": "resp_123",
                    "outputItems": 2
                }
            })),
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.provider_metadata, response.provider_metadata);
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.cost),
            response.usage_metadata.as_ref().and_then(|u| u.cost),
        );
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.is_byok),
            response.usage_metadata.as_ref().and_then(|u| u.is_byok),
        );
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.provider_usage.clone()),
            response.usage_metadata.as_ref().and_then(|u| u.provider_usage.clone()),
        );
    }

    // Derived PartialEq distinguishes variants.
    #[test]
    fn test_finish_reason() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
}