adk_core/model.rs

use crate::schema_adapter::{GenericSchemaAdapter, SchemaAdapter};
use crate::{Result, types::Content};
use async_trait::async_trait;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;

pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;

#[async_trait]
pub trait Llm: Send + Sync {
    fn name(&self) -> &str;
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;

    /// Returns the schema adapter for this provider.
    ///
    /// The schema adapter normalizes raw JSON Schema from MCP tools into the
    /// format accepted by this provider's function-calling API.
    ///
    /// Default implementation returns [`GenericSchemaAdapter`], which applies
    /// safe transforms suitable for most providers. Override this method to
    /// return a provider-specific adapter (e.g., `GeminiSchemaAdapter`,
    /// `OpenAiStrictSchemaAdapter`).
    fn schema_adapter(&self) -> &dyn SchemaAdapter {
        &GenericSchemaAdapter
    }
}
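
// A minimal sketch of a provider implementation. `EchoLlm` is hypothetical and
// not part of this crate; it replies with a single, non-streaming response and
// keeps the default [`GenericSchemaAdapter`] by not overriding `schema_adapter`.
#[allow(dead_code)]
struct EchoLlm;

#[async_trait]
impl Llm for EchoLlm {
    fn name(&self) -> &str {
        "echo"
    }

    async fn generate_content(&self, _req: LlmRequest, _stream: bool) -> Result<LlmResponseStream> {
        // Wrap one complete response in a single-item stream.
        let item: Result<LlmResponse> = Ok(LlmResponse::new(Content::new("model").with_text("echo")));
        let stream: LlmResponseStream = Box::pin(futures::stream::once(async move { item }));
        Ok(stream)
    }
}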

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    pub model: String,
    pub contents: Vec<Content>,
    pub config: Option<GenerateContentConfig>,
    #[serde(skip)]
    pub tools: HashMap<String, serde_json::Value>,
}

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerateContentConfig {
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub presence_penalty: Option<f32>,
    pub max_output_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub top_logprobs: Option<u8>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_schema: Option<serde_json::Value>,

    /// Optional cached content name for the Gemini provider.
    /// When set, the Gemini provider attaches this to the generation request.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cached_content: Option<String>,

    /// Provider-specific request options keyed by provider namespace.
    #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub extensions: serde_json::Map<String, serde_json::Value>,
}
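
// A sketch of the consuming side of `extensions`: a provider looks up its own
// namespace and ignores everything else. The function and the namespace key it
// is called with are illustrative, not part of this crate's API.
#[allow(dead_code)]
fn example_provider_extension<'a>(
    config: &'a GenerateContentConfig,
    namespace: &str,
) -> Option<&'a serde_json::Value> {
    config.extensions.get(namespace)
}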

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmResponse {
    pub content: Option<Content>,
    pub usage_metadata: Option<UsageMetadata>,
    pub finish_reason: Option<FinishReason>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
    pub partial: bool,
    pub turn_complete: bool,
    pub interrupted: bool,
    pub error_code: Option<String>,
    pub error_message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provider_metadata: Option<serde_json::Value>,
}

/// Trait for LLM providers that support prompt caching.
///
/// Providers implementing this trait can create and delete cached content
/// resources, enabling automatic prompt caching lifecycle management by the
/// runner. The runner stores an `Option<Arc<dyn CacheCapable>>` alongside the
/// primary `Arc<dyn Llm>` and calls these methods when [`ContextCacheConfig`]
/// is active.
///
/// # Example
///
/// ```rust,ignore
/// use adk_core::CacheCapable;
///
/// let cache_name = model
///     .create_cache("You are a helpful assistant.", &tools, 600)
///     .await?;
/// // ... use cache_name in generation requests ...
/// model.delete_cache(&cache_name).await?;
/// ```
#[async_trait]
pub trait CacheCapable: Send + Sync {
    /// Create a cached content resource from the given system instruction,
    /// tool definitions, and TTL.
    ///
    /// Returns the provider-specific cache name (e.g. `"cachedContents/abc123"`
    /// for Gemini) that can be attached to subsequent generation requests via
    /// [`GenerateContentConfig::cached_content`].
    async fn create_cache(
        &self,
        system_instruction: &str,
        tools: &HashMap<String, serde_json::Value>,
        ttl_seconds: u32,
    ) -> Result<String>;

    /// Delete a previously created cached content resource by name.
    async fn delete_cache(&self, name: &str) -> Result<()>;
}

/// Configuration for automatic prompt caching lifecycle management.
///
/// When set on the runner configuration, the runner automatically creates and
/// manages cached content resources for supported providers (currently Gemini).
///
/// # Example
///
/// ```rust
/// use adk_core::ContextCacheConfig;
///
/// let config = ContextCacheConfig {
///     min_tokens: 4096,
///     ttl_seconds: 600,
///     cache_intervals: 3,
/// };
/// assert_eq!(config.min_tokens, 4096);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextCacheConfig {
    /// Minimum system instruction + tool token count to trigger caching.
    /// Set to 0 to disable caching.
    pub min_tokens: u32,

    /// Cache time-to-live in seconds.
    /// Set to 0 to disable caching.
    pub ttl_seconds: u32,

    /// Maximum number of LLM invocations before a cache refresh.
    /// After this many invocations, the runner creates a new cache
    /// and deletes the old one.
    pub cache_intervals: u32,
}

impl Default for ContextCacheConfig {
    fn default() -> Self {
        Self { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
    }
}
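
// A sketch, not the runner's actual implementation, of how `ContextCacheConfig`
// and `CacheCapable` are meant to fit together: when caching is enabled and the
// prompt is large enough, create a cache with the configured TTL; the returned
// name can then be set on `GenerateContentConfig::cached_content`.
#[allow(dead_code)]
async fn example_maybe_create_cache(
    model: &dyn CacheCapable,
    config: &ContextCacheConfig,
    system_instruction: &str,
    tools: &HashMap<String, serde_json::Value>,
    prompt_tokens: u32,
) -> Result<Option<String>> {
    // `min_tokens == 0` or `ttl_seconds == 0` disables caching entirely.
    if config.min_tokens == 0 || config.ttl_seconds == 0 || prompt_tokens < config.min_tokens {
        return Ok(None);
    }
    let name = model.create_cache(system_instruction, tools, config.ttl_seconds).await?;
    Ok(Some(name))
}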

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageMetadata {
    pub prompt_token_count: i32,
    pub candidates_token_count: i32,
    pub total_token_count: i32,

    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_read_input_token_count: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_creation_input_token_count: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thinking_token_count: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_input_token_count: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_output_token_count: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cost: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub is_byok: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provider_usage: Option<serde_json::Value>,
}

/// Citation metadata emitted by model providers for source attribution.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    #[serde(default)]
    pub citation_sources: Vec<CitationSource>,
}

/// One citation source with optional offsets.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    pub uri: Option<String>,
    pub title: Option<String>,
    pub start_index: Option<i32>,
    pub end_index: Option<i32>,
    pub license: Option<String>,
    pub publication_date: Option<String>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    Stop,
    MaxTokens,
    Safety,
    Recitation,
    Other,
}
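
// A sketch of mapping a provider's raw finish-reason string onto this enum; the
// string values are illustrative and do not correspond to any specific provider.
#[allow(dead_code)]
fn example_map_finish_reason(raw: &str) -> FinishReason {
    match raw {
        "stop" | "end_turn" => FinishReason::Stop,
        "length" | "max_tokens" => FinishReason::MaxTokens,
        "safety" | "content_filter" => FinishReason::Safety,
        "recitation" => FinishReason::Recitation,
        _ => FinishReason::Other,
    }
}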

impl LlmRequest {
    pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
        Self { model: model.into(), contents, config: None, tools: HashMap::new() }
    }

    /// Set the response schema for structured output.
    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
        let config = self.config.get_or_insert(GenerateContentConfig::default());
        config.response_schema = Some(schema);
        self
    }

    /// Set the generation config.
    pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
        self.config = Some(config);
        self
    }
}

impl LlmResponse {
    pub fn new(content: Content) -> Self {
        Self {
            content: Some(content),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        }
    }
}
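
// A sketch of building an intermediate streaming chunk by hand, assuming the
// convention that partial chunks carry `partial: true` and `turn_complete: false`
// while the final chunk is built with `LlmResponse::new`. Not part of the public API.
#[allow(dead_code)]
fn example_partial_chunk(text: &str) -> LlmResponse {
    LlmResponse {
        content: Some(Content::new("model").with_text(text)),
        partial: true,
        turn_complete: false,
        ..Default::default()
    }
}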

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_request_creation() {
        let req = LlmRequest::new("test-model", vec![]);
        assert_eq!(req.model, "test-model");
        assert!(req.contents.is_empty());
    }

    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let req = LlmRequest::new("test-model", vec![]).with_response_schema(schema.clone());

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert!(config.response_schema.is_some());
        assert_eq!(config.response_schema.unwrap(), schema);
    }

    #[test]
    fn test_llm_request_with_config() {
        let config = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            frequency_penalty: Some(0.2),
            presence_penalty: Some(-0.3),
            max_output_tokens: Some(1024),
            seed: Some(42),
            top_logprobs: Some(5),
            stop_sequences: vec!["END".to_string()],
            ..Default::default()
        };
        let req = LlmRequest::new("test-model", vec![]).with_config(config);

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_output_tokens, Some(1024));
        assert_eq!(config.frequency_penalty, Some(0.2));
        assert_eq!(config.presence_penalty, Some(-0.3));
        assert_eq!(config.seed, Some(42));
        assert_eq!(config.top_logprobs, Some(5));
        assert_eq!(config.stop_sequences, vec!["END"]);
    }

    #[test]
    fn test_llm_response_creation() {
        let content = Content::new("assistant");
        let resp = LlmResponse::new(content);
        assert!(resp.content.is_some());
        assert!(resp.turn_complete);
        assert!(!resp.partial);
        assert_eq!(resp.finish_reason, Some(FinishReason::Stop));
        assert!(resp.citation_metadata.is_none());
        assert!(resp.provider_metadata.is_none());
    }

    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let json = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let response: LlmResponse = serde_json::from_value(json).expect("should deserialize");
        assert!(response.citation_metadata.is_none());
    }

    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: Some(CitationMetadata {
                citation_sources: vec![CitationSource {
                    uri: Some("https://example.com".to_string()),
                    title: Some("Example".to_string()),
                    start_index: Some(0),
                    end_index: Some(5),
                    license: None,
                    publication_date: Some("2026-01-01T00:00:00Z".to_string()),
                }],
            }),
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.citation_metadata, response.citation_metadata);
    }

    #[test]
    fn test_generate_content_config_roundtrip_with_extensions() {
        let mut extensions = serde_json::Map::new();
        extensions.insert(
            "openrouter".to_string(),
            serde_json::json!({
                "provider": {
                    "zdr": true,
                    "order": ["openai", "anthropic"]
                },
                "plugins": [
                    { "id": "web", "enabled": true }
                ]
            }),
        );

        let config = GenerateContentConfig {
            temperature: Some(0.4),
            top_p: Some(0.8),
            top_k: Some(12),
            frequency_penalty: Some(0.1),
            presence_penalty: Some(0.2),
            max_output_tokens: Some(512),
            seed: Some(7),
            top_logprobs: Some(3),
            stop_sequences: vec!["STOP".to_string(), "DONE".to_string()],
            response_schema: Some(serde_json::json!({
                "type": "object",
                "properties": { "answer": { "type": "string" } },
                "required": ["answer"]
            })),
            cached_content: Some("cachedContents/abc123".to_string()),
            extensions,
        };

        let encoded = serde_json::to_string(&config).expect("serialize");
        let decoded: GenerateContentConfig = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.temperature, config.temperature);
        assert_eq!(decoded.top_p, config.top_p);
        assert_eq!(decoded.top_k, config.top_k);
        assert_eq!(decoded.frequency_penalty, config.frequency_penalty);
        assert_eq!(decoded.presence_penalty, config.presence_penalty);
        assert_eq!(decoded.max_output_tokens, config.max_output_tokens);
        assert_eq!(decoded.seed, config.seed);
        assert_eq!(decoded.top_logprobs, config.top_logprobs);
        assert_eq!(decoded.stop_sequences, config.stop_sequences);
        assert_eq!(decoded.response_schema, config.response_schema);
        assert_eq!(decoded.cached_content, config.cached_content);
        assert_eq!(decoded.extensions, config.extensions);
    }

    #[test]
    fn test_llm_response_and_usage_roundtrip_with_provider_metadata() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: Some(UsageMetadata {
                prompt_token_count: 10,
                candidates_token_count: 20,
                total_token_count: 30,
                cache_read_input_token_count: Some(5),
                cache_creation_input_token_count: Some(2),
                thinking_token_count: Some(3),
                audio_input_token_count: Some(4),
                audio_output_token_count: Some(6),
                cost: Some(0.0125),
                is_byok: Some(true),
                provider_usage: Some(serde_json::json!({
                    "server_tool_use": {
                        "web_search_requests": 1
                    },
                    "prompt_tokens_details": {
                        "video_tokens": 8
                    }
                })),
            }),
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: Some(serde_json::json!({
                "openrouter": {
                    "responseId": "resp_123",
                    "outputItems": 2
                }
            })),
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.provider_metadata, response.provider_metadata);
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.cost),
            response.usage_metadata.as_ref().and_then(|u| u.cost),
        );
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.is_byok),
            response.usage_metadata.as_ref().and_then(|u| u.is_byok),
        );
        assert_eq!(
            decoded.usage_metadata.as_ref().and_then(|u| u.provider_usage.clone()),
            response.usage_metadata.as_ref().and_then(|u| u.provider_usage.clone()),
        );
    }

    #[test]
    fn test_finish_reason() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
}