//! adk_core/model.rs
//!
//! Model-layer abstractions: the [`Llm`] provider trait, request/response
//! types, usage and citation metadata, and prompt-caching support.
use std::collections::HashMap;
use std::pin::Pin;

use async_trait::async_trait;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};

use crate::{Result, types::Content};
7
/// Boxed, sendable stream of incremental [`LlmResponse`] results, as
/// returned by [`Llm::generate_content`].
pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
9
/// Common interface implemented by every LLM provider backend.
#[async_trait]
pub trait Llm: Send + Sync {
    /// Identifier for this model/provider (e.g. a model name).
    fn name(&self) -> &str;

    /// Send `req` to the model and return a stream of responses.
    ///
    /// When `stream` is true, implementations may yield incremental partial
    /// responses; otherwise the stream carries a single complete response.
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
}
15
/// A single request to an [`Llm`] backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Target model identifier.
    pub model: String,
    /// Prompt / conversation contents, in order.
    pub contents: Vec<Content>,
    /// Optional generation parameters; `None` means provider defaults.
    pub config: Option<GenerateContentConfig>,
    /// Tool declarations keyed by tool name. Runtime-only: never serialized.
    #[serde(skip)]
    pub tools: HashMap<String, serde_json::Value>,
}
24
25#[derive(Debug, Clone, Default, Serialize, Deserialize)]
26pub struct GenerateContentConfig {
27    pub temperature: Option<f32>,
28    pub top_p: Option<f32>,
29    pub top_k: Option<i32>,
30    pub max_output_tokens: Option<i32>,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub response_schema: Option<serde_json::Value>,
33
34    /// Optional cached content name for Gemini provider.
35    /// When set, the Gemini provider attaches this to the generation request.
36    #[serde(skip_serializing_if = "Option::is_none", default)]
37    pub cached_content: Option<String>,
38}
39
/// One response (complete, or a partial streaming chunk) from an [`Llm`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Generated content, if any was produced.
    pub content: Option<Content>,
    /// Token accounting reported by the provider.
    pub usage_metadata: Option<UsageMetadata>,
    /// Why generation stopped, when known.
    pub finish_reason: Option<FinishReason>,
    /// Source citations, when the provider emits them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
    /// True when this is an incremental streaming chunk.
    pub partial: bool,
    /// True when the model's turn is finished.
    pub turn_complete: bool,
    /// True when generation was interrupted before completion.
    pub interrupted: bool,
    /// Provider error code, if this response represents an error.
    pub error_code: Option<String>,
    /// Human-readable error description, if any.
    pub error_message: Option<String>,
}
53
/// Trait for LLM providers that support prompt caching.
///
/// Providers implementing this trait can create and delete cached content
/// resources, enabling automatic prompt caching lifecycle management by the
/// runner. The runner stores an `Option<Arc<dyn CacheCapable>>` alongside the
/// primary `Arc<dyn Llm>` and calls these methods when [`ContextCacheConfig`]
/// is active.
///
/// # Example
///
/// ```rust,ignore
/// use adk_core::CacheCapable;
///
/// let cache_name = model
///     .create_cache("You are a helpful assistant.", &tools, 600)
///     .await?;
/// // ... use cache_name in generation requests ...
/// model.delete_cache(&cache_name).await?;
/// ```
#[async_trait]
pub trait CacheCapable: Send + Sync {
    /// Create a cached content resource from the given system instruction,
    /// tool definitions, and TTL.
    ///
    /// Returns the provider-specific cache name (e.g. `"cachedContents/abc123"`
    /// for Gemini) that can be attached to subsequent generation requests via
    /// [`GenerateContentConfig::cached_content`].
    ///
    /// # Errors
    ///
    /// Returns a provider-specific error when the cache cannot be created.
    async fn create_cache(
        &self,
        system_instruction: &str,
        tools: &HashMap<String, serde_json::Value>,
        ttl_seconds: u32,
    ) -> Result<String>;

    /// Delete a previously created cached content resource by name.
    ///
    /// # Errors
    ///
    /// Returns a provider-specific error when the resource cannot be deleted
    /// (e.g. an unknown or already-expired cache name).
    async fn delete_cache(&self, name: &str) -> Result<()>;
}
91
/// Configuration for automatic prompt caching lifecycle management.
///
/// When set on runner configuration, the runner will automatically create and manage
/// cached content resources for supported providers (currently Gemini).
/// See the [`Default`] impl for the recommended defaults.
///
/// # Example
///
/// ```rust
/// use adk_core::ContextCacheConfig;
///
/// let config = ContextCacheConfig {
///     min_tokens: 4096,
///     ttl_seconds: 600,
///     cache_intervals: 3,
/// };
/// assert_eq!(config.min_tokens, 4096);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextCacheConfig {
    /// Minimum system instruction + tool token count to trigger caching.
    /// Set to 0 to disable caching.
    pub min_tokens: u32,

    /// Cache time-to-live in seconds.
    /// Set to 0 to disable caching.
    pub ttl_seconds: u32,

    /// Maximum number of LLM invocations before cache refresh.
    /// After this many invocations, the runner creates a new cache
    /// and deletes the old one.
    pub cache_intervals: u32,
}
124
125impl Default for ContextCacheConfig {
126    fn default() -> Self {
127        Self { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
128    }
129}
130
/// Token-usage accounting attached to an [`LlmResponse`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageMetadata {
    /// Tokens consumed by the prompt.
    pub prompt_token_count: i32,
    /// Tokens produced across all candidates.
    pub candidates_token_count: i32,
    /// Total tokens (prompt + candidates).
    pub total_token_count: i32,

    // Cache-related counts are only reported by providers that support
    // prompt caching, hence optional and omitted when absent.
    /// Prompt tokens served from a cached content resource.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_read_input_token_count: Option<i32>,

    /// Tokens spent creating a cached content resource.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub cache_creation_input_token_count: Option<i32>,

    /// Tokens spent on model "thinking", when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thinking_token_count: Option<i32>,

    /// Audio input tokens, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_input_token_count: Option<i32>,

    /// Audio output tokens, when reported.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub audio_output_token_count: Option<i32>,
}
152
/// Citation metadata emitted by model providers for source attribution.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    /// Zero or more cited sources; defaults to empty when absent in the wire
    /// format.
    #[serde(default)]
    pub citation_sources: Vec<CitationSource>,
}
160
/// One citation source with optional offsets.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    /// Source URI, if known.
    pub uri: Option<String>,
    /// Source title, if known.
    pub title: Option<String>,
    /// Start offset of the cited span within the generated content.
    pub start_index: Option<i32>,
    /// End offset of the cited span within the generated content.
    pub end_index: Option<i32>,
    /// License of the cited material, if reported.
    pub license: Option<String>,
    /// Publication date of the source, if reported.
    pub publication_date: Option<String>,
}
172
/// Why the model stopped generating.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    /// Natural end of generation.
    Stop,
    /// Hit the output token limit.
    MaxTokens,
    /// Stopped by a safety filter.
    Safety,
    /// Stopped for reciting source/training material.
    Recitation,
    /// Any other provider-specific reason.
    Other,
}
181
182impl LlmRequest {
183    pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
184        Self { model: model.into(), contents, config: None, tools: HashMap::new() }
185    }
186
187    /// Set the response schema for structured output.
188    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
189        let config = self.config.get_or_insert(GenerateContentConfig::default());
190        config.response_schema = Some(schema);
191        self
192    }
193
194    /// Set the generation config.
195    pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
196        self.config = Some(config);
197        self
198    }
199}
200
201impl LlmResponse {
202    pub fn new(content: Content) -> Self {
203        Self {
204            content: Some(content),
205            usage_metadata: None,
206            finish_reason: Some(FinishReason::Stop),
207            citation_metadata: None,
208            partial: false,
209            turn_complete: true,
210            interrupted: false,
211            error_code: None,
212            error_message: None,
213        }
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    // `LlmRequest::new` stores the model name and contents as given.
    #[test]
    fn test_llm_request_creation() {
        let req = LlmRequest::new("test-model", vec![]);
        assert_eq!(req.model, "test-model");
        assert!(req.contents.is_empty());
    }

    // `with_response_schema` creates a default config on demand and stores
    // the schema inside it.
    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let req = LlmRequest::new("test-model", vec![]).with_response_schema(schema.clone());

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert!(config.response_schema.is_some());
        assert_eq!(config.response_schema.unwrap(), schema);
    }

    // `with_config` replaces the whole generation config.
    #[test]
    fn test_llm_request_with_config() {
        let config = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            max_output_tokens: Some(1024),
            ..Default::default()
        };
        let req = LlmRequest::new("test-model", vec![]).with_config(config);

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_output_tokens, Some(1024));
    }

    // `LlmResponse::new` yields a complete, non-partial response with
    // `FinishReason::Stop` and no citations.
    #[test]
    fn test_llm_response_creation() {
        let content = Content::new("assistant");
        let resp = LlmResponse::new(content);
        assert!(resp.content.is_some());
        assert!(resp.turn_complete);
        assert!(!resp.partial);
        assert_eq!(resp.finish_reason, Some(FinishReason::Stop));
        assert!(resp.citation_metadata.is_none());
    }

    // Deserialization tolerates a missing `citation_metadata` field
    // (guaranteed by `skip_serializing_if`/Option handling).
    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let json = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let response: LlmResponse = serde_json::from_value(json).expect("should deserialize");
        assert!(response.citation_metadata.is_none());
    }

    // Citation metadata survives a serialize → deserialize round trip intact,
    // including camelCase field renaming.
    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: Some(CitationMetadata {
                citation_sources: vec![CitationSource {
                    uri: Some("https://example.com".to_string()),
                    title: Some("Example".to_string()),
                    start_index: Some(0),
                    end_index: Some(5),
                    license: None,
                    publication_date: Some("2026-01-01T00:00:00Z".to_string()),
                }],
            }),
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.citation_metadata, response.citation_metadata);
    }

    // Sanity check for the derived `PartialEq` on `FinishReason`.
    #[test]
    fn test_finish_reason() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
}