1use crate::{Result, types::Content};
2use async_trait::async_trait;
3use futures::stream::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::pin::Pin;
7
8pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
9
10#[async_trait]
11pub trait Llm: Send + Sync {
12 fn name(&self) -> &str;
13 async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct LlmRequest {
18 pub model: String,
19 pub contents: Vec<Content>,
20 pub config: Option<GenerateContentConfig>,
21 #[serde(skip)]
22 pub tools: HashMap<String, serde_json::Value>,
23}
24
25#[derive(Debug, Clone, Default, Serialize, Deserialize)]
26pub struct GenerateContentConfig {
27 pub temperature: Option<f32>,
28 pub top_p: Option<f32>,
29 pub top_k: Option<i32>,
30 pub max_output_tokens: Option<i32>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub response_schema: Option<serde_json::Value>,
33
34 #[serde(skip_serializing_if = "Option::is_none", default)]
37 pub cached_content: Option<String>,
38}
39
40#[derive(Debug, Clone, Default, Serialize, Deserialize)]
41pub struct LlmResponse {
42 pub content: Option<Content>,
43 pub usage_metadata: Option<UsageMetadata>,
44 pub finish_reason: Option<FinishReason>,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 pub citation_metadata: Option<CitationMetadata>,
47 pub partial: bool,
48 pub turn_complete: bool,
49 pub interrupted: bool,
50 pub error_code: Option<String>,
51 pub error_message: Option<String>,
52}
53
54#[async_trait]
74pub trait CacheCapable: Send + Sync {
75 async fn create_cache(
82 &self,
83 system_instruction: &str,
84 tools: &HashMap<String, serde_json::Value>,
85 ttl_seconds: u32,
86 ) -> Result<String>;
87
88 async fn delete_cache(&self, name: &str) -> Result<()>;
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct ContextCacheConfig {
111 pub min_tokens: u32,
114
115 pub ttl_seconds: u32,
118
119 pub cache_intervals: u32,
123}
124
125impl Default for ContextCacheConfig {
126 fn default() -> Self {
127 Self { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
128 }
129}
130
131#[derive(Debug, Clone, Default, Serialize, Deserialize)]
132pub struct UsageMetadata {
133 pub prompt_token_count: i32,
134 pub candidates_token_count: i32,
135 pub total_token_count: i32,
136
137 #[serde(skip_serializing_if = "Option::is_none", default)]
138 pub cache_read_input_token_count: Option<i32>,
139
140 #[serde(skip_serializing_if = "Option::is_none", default)]
141 pub cache_creation_input_token_count: Option<i32>,
142
143 #[serde(skip_serializing_if = "Option::is_none", default)]
144 pub thinking_token_count: Option<i32>,
145
146 #[serde(skip_serializing_if = "Option::is_none", default)]
147 pub audio_input_token_count: Option<i32>,
148
149 #[serde(skip_serializing_if = "Option::is_none", default)]
150 pub audio_output_token_count: Option<i32>,
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
155#[serde(rename_all = "camelCase")]
156pub struct CitationMetadata {
157 #[serde(default)]
158 pub citation_sources: Vec<CitationSource>,
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
163#[serde(rename_all = "camelCase")]
164pub struct CitationSource {
165 pub uri: Option<String>,
166 pub title: Option<String>,
167 pub start_index: Option<i32>,
168 pub end_index: Option<i32>,
169 pub license: Option<String>,
170 pub publication_date: Option<String>,
171}
172
173#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
174pub enum FinishReason {
175 Stop,
176 MaxTokens,
177 Safety,
178 Recitation,
179 Other,
180}
181
182impl LlmRequest {
183 pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
184 Self { model: model.into(), contents, config: None, tools: HashMap::new() }
185 }
186
187 pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
189 let config = self.config.get_or_insert(GenerateContentConfig::default());
190 config.response_schema = Some(schema);
191 self
192 }
193
194 pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
196 self.config = Some(config);
197 self
198 }
199}
200
201impl LlmResponse {
202 pub fn new(content: Content) -> Self {
203 Self {
204 content: Some(content),
205 usage_metadata: None,
206 finish_reason: Some(FinishReason::Stop),
207 citation_metadata: None,
208 partial: false,
209 turn_complete: true,
210 interrupted: false,
211 error_code: None,
212 error_message: None,
213 }
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_request_creation() {
        let req = LlmRequest::new("test-model", vec![]);
        assert_eq!(req.model, "test-model");
        assert!(req.contents.is_empty());
        assert!(req.config.is_none());
        assert!(req.tools.is_empty());
    }

    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let req = LlmRequest::new("test-model", vec![]).with_response_schema(schema.clone());

        let config = req.config.expect("with_response_schema should create a config");
        assert_eq!(config.response_schema, Some(schema));
    }

    #[test]
    fn test_response_schema_keeps_existing_config() {
        let base = GenerateContentConfig { temperature: Some(0.2), ..Default::default() };
        let schema = serde_json::json!({ "type": "string" });
        let req = LlmRequest::new("test-model", vec![])
            .with_config(base)
            .with_response_schema(schema.clone());

        let config = req.config.expect("config should be set");
        assert_eq!(config.temperature, Some(0.2));
        assert_eq!(config.response_schema, Some(schema));
    }

    #[test]
    fn test_llm_request_with_config() {
        let config = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            max_output_tokens: Some(1024),
            ..Default::default()
        };
        let req = LlmRequest::new("test-model", vec![]).with_config(config);

        let config = req.config.expect("with_config should set the config");
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.top_p, Some(0.9));
        assert_eq!(config.top_k, Some(40));
        assert_eq!(config.max_output_tokens, Some(1024));
        assert!(config.response_schema.is_none());
    }

    #[test]
    fn test_llm_response_creation() {
        let resp = LlmResponse::new(Content::new("assistant"));
        assert!(resp.content.is_some());
        assert!(resp.turn_complete);
        assert!(!resp.partial);
        assert!(!resp.interrupted);
        assert_eq!(resp.finish_reason, Some(FinishReason::Stop));
        assert!(resp.citation_metadata.is_none());
        assert!(resp.error_code.is_none());
        assert!(resp.error_message.is_none());
    }

    #[test]
    fn test_llm_response_serialization_omits_missing_citations() {
        let resp = LlmResponse::new(Content::new("model").with_text("hello"));
        let value = serde_json::to_value(&resp).expect("serialize");
        assert!(value.get("citation_metadata").is_none());
    }

    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let json = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let response: LlmResponse = serde_json::from_value(json).expect("should deserialize");
        assert!(response.citation_metadata.is_none());
    }

    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: Some(CitationMetadata {
                citation_sources: vec![CitationSource {
                    uri: Some("https://example.com".to_string()),
                    title: Some("Example".to_string()),
                    start_index: Some(0),
                    end_index: Some(5),
                    license: None,
                    publication_date: Some("2026-01-01T00:00:00Z".to_string()),
                }],
            }),
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.citation_metadata, response.citation_metadata);
    }

    #[test]
    fn test_usage_metadata_optional_fields() {
        let json = serde_json::json!({
            "prompt_token_count": 10,
            "candidates_token_count": 5,
            "total_token_count": 15
        });
        let usage: UsageMetadata = serde_json::from_value(json).expect("deserialize");
        assert_eq!(usage.total_token_count, 15);
        assert!(usage.cache_read_input_token_count.is_none());
        assert!(usage.thinking_token_count.is_none());

        // Unset optional breakdowns must not appear in the serialized form.
        let encoded = serde_json::to_value(&usage).expect("serialize");
        assert!(encoded.get("thinking_token_count").is_none());
        assert!(encoded.get("cache_read_input_token_count").is_none());
    }

    #[test]
    fn test_context_cache_config_default() {
        let cfg = ContextCacheConfig::default();
        assert_eq!(cfg.min_tokens, 4096);
        assert_eq!(cfg.ttl_seconds, 600);
        assert_eq!(cfg.cache_intervals, 3);
    }

    #[test]
    fn test_finish_reason() {
        // A reflexive `assert_eq!(x, x)` trips clippy's `eq_op`; check distinctness and the
        // wire representation instead.
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
        assert_eq!(
            serde_json::to_string(&FinishReason::MaxTokens).expect("serialize"),
            "\"MaxTokens\""
        );
        let decoded: FinishReason = serde_json::from_str("\"Safety\"").expect("deserialize");
        assert_eq!(decoded, FinishReason::Safety);
    }
}