//! adk_core/model.rs — core LLM abstractions: the [`Llm`] trait,
//! request/response types, usage accounting, and prompt-cache configuration.
1use crate::{Result, types::Content};
2use async_trait::async_trait;
3use futures::stream::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::pin::Pin;
7
/// Boxed, sendable stream of [`LlmResponse`] items, as returned by
/// [`Llm::generate_content`]. Each item is a `Result` so providers can
/// surface mid-stream failures.
pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
9
/// Abstraction over an LLM provider backend.
#[async_trait]
pub trait Llm: Send + Sync {
    /// Identifier of this model/provider, used for display and routing.
    fn name(&self) -> &str;
    /// Generate content for `req`, returning a stream of responses.
    ///
    /// When `stream` is `true` the provider may yield multiple partial
    /// chunks; otherwise the stream typically yields a single complete
    /// response. NOTE(review): exact streaming semantics are
    /// provider-defined — not visible from this trait alone.
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
}
15
/// A request to an LLM provider: target model, conversation contents,
/// optional generation configuration, and tool declarations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Target model identifier.
    pub model: String,
    /// Ordered conversation contents sent to the model.
    pub contents: Vec<Content>,
    /// Optional generation parameters; `None` means provider defaults.
    pub config: Option<GenerateContentConfig>,
    /// Tool declarations keyed by tool name. Never serialized with the
    /// request body; providers attach tools through their own wire format.
    #[serde(skip)]
    pub tools: HashMap<String, serde_json::Value>,
}
24
25#[derive(Debug, Clone, Default, Serialize, Deserialize)]
26pub struct GenerateContentConfig {
27    pub temperature: Option<f32>,
28    pub top_p: Option<f32>,
29    pub top_k: Option<i32>,
30    #[serde(skip_serializing_if = "Option::is_none", default)]
31    pub frequency_penalty: Option<f32>,
32    #[serde(skip_serializing_if = "Option::is_none", default)]
33    pub presence_penalty: Option<f32>,
34    pub max_output_tokens: Option<i32>,
35    #[serde(skip_serializing_if = "Option::is_none", default)]
36    pub seed: Option<i64>,
37    #[serde(skip_serializing_if = "Option::is_none", default)]
38    pub top_logprobs: Option<u8>,
39    #[serde(default, skip_serializing_if = "Vec::is_empty")]
40    pub stop_sequences: Vec<String>,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub response_schema: Option<serde_json::Value>,
43
44    /// Optional cached content name for Gemini provider.
45    /// When set, the Gemini provider attaches this to the generation request.
46    #[serde(skip_serializing_if = "Option::is_none", default)]
47    pub cached_content: Option<String>,
48
49    /// Provider-specific request options keyed by provider namespace.
50    #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
51    pub extensions: serde_json::Map<String, serde_json::Value>,
52}
53
/// A single response (or streamed chunk) from an LLM provider.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Generated content, if any was produced for this chunk.
    pub content: Option<Content>,
    /// Token usage accounting, when reported by the provider.
    pub usage_metadata: Option<UsageMetadata>,
    /// Why generation stopped, when known.
    pub finish_reason: Option<FinishReason>,
    /// Source citations, when the provider emits attribution data.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
    /// `true` when this is a partial (streamed) chunk rather than a full turn.
    pub partial: bool,
    /// `true` when the model's turn is complete.
    pub turn_complete: bool,
    /// `true` when generation was interrupted before completion.
    pub interrupted: bool,
    /// Provider error code, when the response represents a failure.
    pub error_code: Option<String>,
    /// Human-readable error message accompanying `error_code`.
    pub error_message: Option<String>,
    /// Provider-specific response metadata (opaque JSON keyed by provider).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provider_metadata: Option<serde_json::Value>,
}
69
/// Trait for LLM providers that support prompt caching.
///
/// Providers implementing this trait can create and delete cached content
/// resources, enabling automatic prompt caching lifecycle management by the
/// runner. The runner stores an `Option<Arc<dyn CacheCapable>>` alongside the
/// primary `Arc<dyn Llm>` and calls these methods when [`ContextCacheConfig`]
/// is active.
///
/// # Example
///
/// ```rust,ignore
/// use adk_core::CacheCapable;
///
/// let cache_name = model
///     .create_cache("You are a helpful assistant.", &tools, 600)
///     .await?;
/// // ... use cache_name in generation requests ...
/// model.delete_cache(&cache_name).await?;
/// ```
#[async_trait]
pub trait CacheCapable: Send + Sync {
    /// Create a cached content resource from the given system instruction,
    /// tool definitions, and TTL.
    ///
    /// Returns the provider-specific cache name (e.g. `"cachedContents/abc123"`
    /// for Gemini) that can be attached to subsequent generation requests via
    /// [`GenerateContentConfig::cached_content`].
    ///
    /// # Errors
    ///
    /// Returns an error if the provider rejects or fails the cache-creation
    /// request (provider-specific conditions).
    async fn create_cache(
        &self,
        system_instruction: &str,
        tools: &HashMap<String, serde_json::Value>,
        ttl_seconds: u32,
    ) -> Result<String>;

    /// Delete a previously created cached content resource by name.
    ///
    /// # Errors
    ///
    /// Returns an error if the provider fails the deletion (e.g. the
    /// resource no longer exists — provider-specific).
    async fn delete_cache(&self, name: &str) -> Result<()>;
}
107
/// Configuration for automatic prompt caching lifecycle management.
///
/// When set on runner configuration, the runner will automatically create and manage
/// cached content resources for supported providers (currently Gemini).
///
/// # Example
///
/// ```rust
/// use adk_core::ContextCacheConfig;
///
/// let config = ContextCacheConfig {
///     min_tokens: 4096,
///     ttl_seconds: 600,
///     cache_intervals: 3,
/// };
/// assert_eq!(config.min_tokens, 4096);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextCacheConfig {
    /// Minimum system instruction + tool token count to trigger caching.
    /// Set to 0 to disable caching.
    pub min_tokens: u32,

    /// Cache time-to-live in seconds (passed to
    /// [`CacheCapable::create_cache`]). Set to 0 to disable caching.
    pub ttl_seconds: u32,

    /// Maximum number of LLM invocations before cache refresh.
    /// After this many invocations, the runner creates a new cache
    /// and deletes the old one.
    pub cache_intervals: u32,
}
140
141impl Default for ContextCacheConfig {
142    fn default() -> Self {
143        Self { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
144    }
145}
146
/// Token usage and cost accounting reported alongside an [`LlmResponse`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageMetadata {
    /// Tokens consumed by the prompt (input).
    pub prompt_token_count: i32,
    /// Tokens produced across candidate outputs.
    pub candidates_token_count: i32,
    /// Total tokens for the request.
    pub total_token_count: i32,

    /// Prompt tokens served from a provider-side cache, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_read_input_token_count: Option<i32>,

    /// Prompt tokens written into a provider-side cache, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_creation_input_token_count: Option<i32>,

    /// Tokens spent on model "thinking"/reasoning, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thinking_token_count: Option<i32>,

    /// Audio input tokens, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_input_token_count: Option<i32>,

    /// Audio output tokens, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_output_token_count: Option<i32>,

    /// Monetary cost of the call, when the provider reports one.
    /// NOTE(review): currency/units are provider-defined — not visible here.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cost: Option<f64>,

    /// Whether the call used a bring-your-own-key credential, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub is_byok: Option<bool>,

    /// Raw provider-specific usage payload (opaque JSON).
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provider_usage: Option<serde_json::Value>,
}
177
/// Citation metadata emitted by model providers for source attribution.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    /// Individual citation sources; empty when the provider reports none.
    #[serde(default)]
    pub citation_sources: Vec<CitationSource>,
}
185
/// One citation source with optional offsets.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    /// URI of the cited source.
    pub uri: Option<String>,
    /// Human-readable title of the source.
    pub title: Option<String>,
    /// Start offset of the cited span in the generated output, when given.
    /// NOTE(review): byte vs. character indexing is provider-defined.
    pub start_index: Option<i32>,
    /// End offset of the cited span, when given.
    pub end_index: Option<i32>,
    /// License of the cited material, when reported.
    pub license: Option<String>,
    /// Publication date string of the source, when reported.
    pub publication_date: Option<String>,
}
197
/// Reason the model stopped generating.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    /// Natural end of generation.
    Stop,
    /// Output token limit reached.
    MaxTokens,
    /// Halted by safety filtering.
    Safety,
    /// Halted due to recitation detection.
    Recitation,
    /// Any other provider-specific reason.
    Other,
}
206
207impl LlmRequest {
208    pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
209        Self { model: model.into(), contents, config: None, tools: HashMap::new() }
210    }
211
212    /// Set the response schema for structured output.
213    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
214        let config = self.config.get_or_insert(GenerateContentConfig::default());
215        config.response_schema = Some(schema);
216        self
217    }
218
219    /// Set the generation config.
220    pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
221        self.config = Some(config);
222        self
223    }
224}
225
226impl LlmResponse {
227    pub fn new(content: Content) -> Self {
228        Self {
229            content: Some(content),
230            usage_metadata: None,
231            finish_reason: Some(FinishReason::Stop),
232            citation_metadata: None,
233            partial: false,
234            turn_complete: true,
235            interrupted: false,
236            error_code: None,
237            error_message: None,
238            provider_metadata: None,
239        }
240    }
241}
242
// Unit tests for request/response construction and serde round-trips.
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor sets the model name and leaves contents empty.
    #[test]
    fn test_llm_request_creation() {
        let req = LlmRequest::new("test-model", vec![]);
        assert_eq!(req.model, "test-model");
        assert!(req.contents.is_empty());
    }

    // `with_response_schema` lazily creates a config and stores the schema.
    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let req = LlmRequest::new("test-model", vec![]).with_response_schema(schema.clone());

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert!(config.response_schema.is_some());
        assert_eq!(config.response_schema.unwrap(), schema);
    }

    // `with_config` stores the whole config, field for field.
    #[test]
    fn test_llm_request_with_config() {
        let config = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            frequency_penalty: Some(0.2),
            presence_penalty: Some(-0.3),
            max_output_tokens: Some(1024),
            seed: Some(42),
            top_logprobs: Some(5),
            stop_sequences: vec!["END".to_string()],
            ..Default::default()
        };
        let req = LlmRequest::new("test-model", vec![]).with_config(config);

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_output_tokens, Some(1024));
        assert_eq!(config.frequency_penalty, Some(0.2));
        assert_eq!(config.presence_penalty, Some(-0.3));
        assert_eq!(config.seed, Some(42));
        assert_eq!(config.top_logprobs, Some(5));
        assert_eq!(config.stop_sequences, vec!["END"]);
    }

    // `LlmResponse::new` marks the turn complete with a `Stop` finish reason.
    #[test]
    fn test_llm_response_creation() {
        let content = Content::new("assistant");
        let resp = LlmResponse::new(content);
        assert!(resp.content.is_some());
        assert!(resp.turn_complete);
        assert!(!resp.partial);
        assert_eq!(resp.finish_reason, Some(FinishReason::Stop));
        assert!(resp.citation_metadata.is_none());
        assert!(resp.provider_metadata.is_none());
    }

    // A payload without `citation_metadata` must still deserialize (None).
    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let json = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let response: LlmResponse = serde_json::from_value(json).expect("should deserialize");
        assert!(response.citation_metadata.is_none());
    }

    // Citation metadata survives a serialize/deserialize round-trip.
    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: Some(CitationMetadata {
                citation_sources: vec![CitationSource {
                    uri: Some("https://example.com".to_string()),
                    title: Some("Example".to_string()),
                    start_index: Some(0),
                    end_index: Some(5),
                    license: None,
                    publication_date: Some("2026-01-01T00:00:00Z".to_string()),
                }],
            }),
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.citation_metadata, response.citation_metadata);
    }

    // Every config field, including the provider-extensions map, round-trips.
    #[test]
    fn test_generate_content_config_roundtrip_with_extensions() {
        let mut extensions = serde_json::Map::new();
        extensions.insert(
            "openrouter".to_string(),
            serde_json::json!({
                "provider": {
                    "zdr": true,
                    "order": ["openai", "anthropic"]
                },
                "plugins": [
                    { "id": "web", "enabled": true }
                ]
            }),
        );

        let config = GenerateContentConfig {
            temperature: Some(0.4),
            top_p: Some(0.8),
            top_k: Some(12),
            frequency_penalty: Some(0.1),
            presence_penalty: Some(0.2),
            max_output_tokens: Some(512),
            seed: Some(7),
            top_logprobs: Some(3),
            stop_sequences: vec!["STOP".to_string(), "DONE".to_string()],
            response_schema: Some(serde_json::json!({
                "type": "object",
                "properties": { "answer": { "type": "string" } },
                "required": ["answer"]
            })),
            cached_content: Some("cachedContents/abc123".to_string()),
            extensions,
        };

        let encoded = serde_json::to_string(&config).expect("serialize");
        let decoded: GenerateContentConfig = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.temperature, config.temperature);
        assert_eq!(decoded.top_p, config.top_p);
        assert_eq!(decoded.top_k, config.top_k);
        assert_eq!(decoded.frequency_penalty, config.frequency_penalty);
        assert_eq!(decoded.presence_penalty, config.presence_penalty);
        assert_eq!(decoded.max_output_tokens, config.max_output_tokens);
        assert_eq!(decoded.seed, config.seed);
        assert_eq!(decoded.top_logprobs, config.top_logprobs);
        assert_eq!(decoded.stop_sequences, config.stop_sequences);
        assert_eq!(decoded.response_schema, config.response_schema);
        assert_eq!(decoded.cached_content, config.cached_content);
        assert_eq!(decoded.extensions, config.extensions);
    }

    // Provider metadata and the extended usage fields round-trip intact.
    #[test]
    fn test_llm_response_and_usage_roundtrip_with_provider_metadata() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: Some(UsageMetadata {
                prompt_token_count: 10,
                candidates_token_count: 20,
                total_token_count: 30,
                cache_read_input_token_count: Some(5),
                cache_creation_input_token_count: Some(2),
                thinking_token_count: Some(3),
                audio_input_token_count: Some(4),
                audio_output_token_count: Some(6),
                cost: Some(0.0125),
                is_byok: Some(true),
                provider_usage: Some(serde_json::json!({
                    "server_tool_use": {
                        "web_search_requests": 1
                    },
                    "prompt_tokens_details": {
                        "video_tokens": 8
                    }
                })),
            }),
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: Some(serde_json::json!({
                "openrouter": {
                    "responseId": "resp_123",
                    "outputItems": 2
                }
            })),
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.provider_metadata, response.provider_metadata);
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.cost),
            response.usage_metadata.as_ref().and_then(|u| u.cost),
        );
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.is_byok),
            response.usage_metadata.as_ref().and_then(|u| u.is_byok),
        );
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.provider_usage.clone()),
            response.usage_metadata.as_ref().and_then(|u| u.provider_usage.clone()),
        );
    }

    // Derived equality for the finish-reason enum behaves as expected.
    #[test]
    fn test_finish_reason() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
}