// adk_core/model.rs

1use crate::schema_adapter::{GenericSchemaAdapter, SchemaAdapter};
2use crate::{Result, types::Content};
3use async_trait::async_trait;
4use futures::stream::Stream;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::pin::Pin;
8
9/// A pinned, boxed stream of [`LlmResponse`] results from a model.
10pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
11
12/// The core trait for all LLM providers.
13///
14/// Implementations wrap a specific model API (Gemini, OpenAI, Anthropic, etc.)
15/// and produce a stream of responses for a given request.
16#[async_trait]
17pub trait Llm: Send + Sync {
18    /// Returns the model identifier (e.g., "gemini-2.5-flash").
19    fn name(&self) -> &str;
20    /// Generates content from the given request, optionally streaming.
21    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
22
23    /// Returns the schema adapter for this provider.
24    ///
25    /// The schema adapter normalizes raw JSON Schema from MCP tools into the
26    /// format accepted by this provider's function-calling API.
27    ///
28    /// Default implementation returns [`GenericSchemaAdapter`], which applies
29    /// safe transforms suitable for most providers. Override this method to
30    /// return a provider-specific adapter (e.g., `GeminiSchemaAdapter`,
31    /// `OpenAiStrictSchemaAdapter`).
32    fn schema_adapter(&self) -> &dyn SchemaAdapter {
33        &GenericSchemaAdapter
34    }
35
36    /// Returns `true` if this model is configured to use a server-managed
37    /// environment (e.g., Gemini Interactions API) where the provider owns
38    /// the tool-calling loop and filesystem.
39    ///
40    /// This is used at agent build time to detect conflicts with client-side
41    /// sandbox tools that would produce competing filesystems.
42    ///
43    /// Default implementation returns `false`. Override in providers that
44    /// support server-managed environments.
45    fn uses_interactions_api(&self) -> bool {
46        false
47    }
48}
49
50/// A request to an LLM provider.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct LlmRequest {
53    /// The model identifier to use for generation.
54    pub model: String,
55    /// The conversation contents (system, user, model messages).
56    pub contents: Vec<Content>,
57    /// Optional generation configuration (temperature, tokens, etc.).
58    pub config: Option<GenerateContentConfig>,
59    /// Tool declarations keyed by tool name.
60    #[serde(skip)]
61    pub tools: HashMap<String, serde_json::Value>,
62    /// Provider-neutral identifier of the previous response to continue from.
63    /// The agent populates this from the most recent event's `interaction_id`.
64    /// Maps to `previous_interaction_id` for the Gemini Interactions transport;
65    /// ignored by transports that do not support response chaining.
66    #[serde(skip_serializing_if = "Option::is_none", default)]
67    pub previous_response_id: Option<String>,
68}
69
70/// Configuration for LLM content generation.
71#[derive(Debug, Clone, Default, Serialize, Deserialize)]
72pub struct GenerateContentConfig {
73    /// Sampling temperature (0.0 = deterministic, higher = more random).
74    pub temperature: Option<f32>,
75    /// Nucleus sampling threshold.
76    pub top_p: Option<f32>,
77    /// Top-k sampling parameter.
78    pub top_k: Option<i32>,
79    /// Frequency penalty to reduce repetition.
80    #[serde(skip_serializing_if = "Option::is_none", default)]
81    pub frequency_penalty: Option<f32>,
82    /// Presence penalty to encourage topic diversity.
83    #[serde(skip_serializing_if = "Option::is_none", default)]
84    pub presence_penalty: Option<f32>,
85    /// Maximum number of output tokens to generate.
86    pub max_output_tokens: Option<i32>,
87    /// Random seed for reproducible generation.
88    #[serde(skip_serializing_if = "Option::is_none", default)]
89    pub seed: Option<i64>,
90    /// Number of top log probabilities to return per token.
91    #[serde(skip_serializing_if = "Option::is_none", default)]
92    pub top_logprobs: Option<u8>,
93    /// Sequences that stop generation when encountered.
94    #[serde(default, skip_serializing_if = "Vec::is_empty")]
95    pub stop_sequences: Vec<String>,
96    /// JSON Schema for structured output.
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub response_schema: Option<serde_json::Value>,
99
100    /// Optional cached content name for Gemini provider.
101    /// When set, the Gemini provider attaches this to the generation request.
102    #[serde(skip_serializing_if = "Option::is_none", default)]
103    pub cached_content: Option<String>,
104
105    /// Provider-specific request options keyed by provider namespace.
106    #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
107    pub extensions: serde_json::Map<String, serde_json::Value>,
108}
109
110/// A response from an LLM provider.
111#[derive(Debug, Clone, Default, Serialize, Deserialize)]
112pub struct LlmResponse {
113    /// The generated content (text, function calls, etc.).
114    pub content: Option<Content>,
115    /// Token usage statistics.
116    pub usage_metadata: Option<UsageMetadata>,
117    /// Reason the model stopped generating.
118    pub finish_reason: Option<FinishReason>,
119    /// Citation sources referenced in the response.
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub citation_metadata: Option<CitationMetadata>,
122    /// Whether this is a partial streaming chunk.
123    pub partial: bool,
124    /// Whether the model has finished its turn.
125    pub turn_complete: bool,
126    /// Whether the response was interrupted.
127    pub interrupted: bool,
128    /// Error code from the provider, if any.
129    pub error_code: Option<String>,
130    /// Error message from the provider, if any.
131    pub error_message: Option<String>,
132    /// Provider-specific metadata (e.g., response IDs, routing info).
133    #[serde(skip_serializing_if = "Option::is_none", default)]
134    pub provider_metadata: Option<serde_json::Value>,
135    /// Interactions API: the server-assigned interaction id for this turn, used
136    /// to continue the conversation via `previous_interaction_id`. `None` for
137    /// the generateContent transport and non-Gemini providers.
138    #[serde(skip_serializing_if = "Option::is_none", default)]
139    pub interaction_id: Option<String>,
140}
141
142/// Trait for LLM providers that support prompt caching.
143///
144/// Providers implementing this trait can create and delete cached content
145/// resources, enabling automatic prompt caching lifecycle management by the
146/// runner. The runner stores an `Option<Arc<dyn CacheCapable>>` alongside the
147/// primary `Arc<dyn Llm>` and calls these methods when [`ContextCacheConfig`]
148/// is active.
149///
150/// # Example
151///
152/// ```rust,ignore
153/// use adk_core::CacheCapable;
154///
155/// let cache_name = model
156///     .create_cache("You are a helpful assistant.", &tools, 600)
157///     .await?;
158/// // ... use cache_name in generation requests ...
159/// model.delete_cache(&cache_name).await?;
160/// ```
161#[async_trait]
162pub trait CacheCapable: Send + Sync {
163    /// Create a cached content resource from the given system instruction,
164    /// tool definitions, and TTL.
165    ///
166    /// Returns the provider-specific cache name (e.g. `"cachedContents/abc123"`
167    /// for Gemini) that can be attached to subsequent generation requests via
168    /// [`GenerateContentConfig::cached_content`].
169    async fn create_cache(
170        &self,
171        system_instruction: &str,
172        tools: &HashMap<String, serde_json::Value>,
173        ttl_seconds: u32,
174    ) -> Result<String>;
175
176    /// Delete a previously created cached content resource by name.
177    async fn delete_cache(&self, name: &str) -> Result<()>;
178}
179
180/// Configuration for automatic prompt caching lifecycle management.
181///
182/// When set on runner configuration, the runner will automatically create and manage
183/// cached content resources for supported providers (currently Gemini).
184///
185/// # Example
186///
187/// ```rust
188/// use adk_core::ContextCacheConfig;
189///
190/// let config = ContextCacheConfig {
191///     min_tokens: 4096,
192///     ttl_seconds: 600,
193///     cache_intervals: 3,
194/// };
195/// assert_eq!(config.min_tokens, 4096);
196/// ```
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct ContextCacheConfig {
199    /// Minimum system instruction + tool token count to trigger caching.
200    /// Set to 0 to disable caching.
201    pub min_tokens: u32,
202
203    /// Cache time-to-live in seconds.
204    /// Set to 0 to disable caching.
205    pub ttl_seconds: u32,
206
207    /// Maximum number of LLM invocations before cache refresh.
208    /// After this many invocations, the runner creates a new cache
209    /// and deletes the old one.
210    pub cache_intervals: u32,
211}
212
213impl Default for ContextCacheConfig {
214    fn default() -> Self {
215        Self { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
216    }
217}
218
219/// Token usage statistics from an LLM response.
220#[derive(Debug, Clone, Default, Serialize, Deserialize)]
221pub struct UsageMetadata {
222    /// Number of tokens in the prompt.
223    pub prompt_token_count: i32,
224    /// Number of tokens in the generated response.
225    pub candidates_token_count: i32,
226    /// Total tokens (prompt + response).
227    pub total_token_count: i32,
228
229    /// Tokens read from cache (Gemini/Anthropic).
230    #[serde(skip_serializing_if = "Option::is_none", default)]
231    pub cache_read_input_token_count: Option<i32>,
232
233    /// Tokens written to cache during this request.
234    #[serde(skip_serializing_if = "Option::is_none", default)]
235    pub cache_creation_input_token_count: Option<i32>,
236
237    /// Tokens used for thinking/reasoning (thinking models).
238    #[serde(skip_serializing_if = "Option::is_none", default)]
239    pub thinking_token_count: Option<i32>,
240
241    /// Audio input tokens (multimodal models).
242    #[serde(skip_serializing_if = "Option::is_none", default)]
243    pub audio_input_token_count: Option<i32>,
244
245    /// Audio output tokens (multimodal models).
246    #[serde(skip_serializing_if = "Option::is_none", default)]
247    pub audio_output_token_count: Option<i32>,
248
249    /// Estimated cost in USD for this request.
250    #[serde(skip_serializing_if = "Option::is_none", default)]
251    pub cost: Option<f64>,
252
253    /// Whether this request used a bring-your-own-key provider.
254    #[serde(skip_serializing_if = "Option::is_none", default)]
255    pub is_byok: Option<bool>,
256
257    /// Provider-specific usage details (e.g., server tool use, video tokens).
258    #[serde(skip_serializing_if = "Option::is_none", default)]
259    pub provider_usage: Option<serde_json::Value>,
260}
261
262/// Citation metadata emitted by model providers for source attribution.
263#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
264#[serde(rename_all = "camelCase")]
265pub struct CitationMetadata {
266    /// The list of citation sources in this response.
267    #[serde(default)]
268    pub citation_sources: Vec<CitationSource>,
269}
270
271/// One citation source with optional offsets.
272#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
273#[serde(rename_all = "camelCase")]
274pub struct CitationSource {
275    /// URI of the cited source.
276    pub uri: Option<String>,
277    /// Title of the cited source.
278    pub title: Option<String>,
279    /// Start character index in the response text.
280    pub start_index: Option<i32>,
281    /// End character index in the response text.
282    pub end_index: Option<i32>,
283    /// License of the cited source.
284    pub license: Option<String>,
285    /// Publication date of the cited source.
286    pub publication_date: Option<String>,
287}
288
289/// Reason the model stopped generating content.
290#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
291pub enum FinishReason {
292    /// Natural stop (end of response).
293    Stop,
294    /// Hit the maximum token limit.
295    MaxTokens,
296    /// Content filtered for safety.
297    Safety,
298    /// Content blocked due to recitation/copyright.
299    Recitation,
300    /// Other/unknown reason.
301    Other,
302}
303
304impl LlmRequest {
305    /// Creates a new request with the given model and contents.
306    pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
307        Self {
308            model: model.into(),
309            contents,
310            config: None,
311            tools: HashMap::new(),
312            previous_response_id: None,
313        }
314    }
315
316    /// Set the response schema for structured output.
317    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
318        let config = self.config.get_or_insert(GenerateContentConfig::default());
319        config.response_schema = Some(schema);
320        self
321    }
322
323    /// Set the generation config.
324    pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
325        self.config = Some(config);
326        self
327    }
328}
329
330impl LlmResponse {
331    /// Creates a complete (non-streaming) response with the given content.
332    pub fn new(content: Content) -> Self {
333        Self {
334            content: Some(content),
335            usage_metadata: None,
336            finish_reason: Some(FinishReason::Stop),
337            citation_metadata: None,
338            partial: false,
339            turn_complete: true,
340            interrupted: false,
341            error_code: None,
342            error_message: None,
343            provider_metadata: None,
344            interaction_id: None,
345        }
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// `LlmRequest::new` keeps the model name and the (empty) contents.
    #[test]
    fn test_llm_request_creation() {
        let request = LlmRequest::new("test-model", Vec::new());
        assert_eq!(request.model, "test-model");
        assert!(request.contents.is_empty());
    }

    /// `with_response_schema` creates a config that holds the schema.
    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let request =
            LlmRequest::new("test-model", Vec::new()).with_response_schema(schema.clone());

        // A missing config or a missing schema both collapse to `None` here,
        // so the single comparison below covers every failure mode.
        let stored_schema = request.config.and_then(|config| config.response_schema);
        assert_eq!(stored_schema, Some(schema));
    }

    /// `with_config` stores the supplied config unchanged.
    #[test]
    fn test_llm_request_with_config() {
        let supplied = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            frequency_penalty: Some(0.2),
            presence_penalty: Some(-0.3),
            max_output_tokens: Some(1024),
            seed: Some(42),
            top_logprobs: Some(5),
            stop_sequences: vec!["END".to_string()],
            ..Default::default()
        };
        let request = LlmRequest::new("test-model", Vec::new()).with_config(supplied);

        let stored = request.config.unwrap();
        assert_eq!(stored.temperature, Some(0.7));
        assert_eq!(stored.max_output_tokens, Some(1024));
        assert_eq!(stored.frequency_penalty, Some(0.2));
        assert_eq!(stored.presence_penalty, Some(-0.3));
        assert_eq!(stored.seed, Some(42));
        assert_eq!(stored.top_logprobs, Some(5));
        assert_eq!(stored.stop_sequences, vec!["END".to_string()]);
    }

    /// `LlmResponse::new` yields a complete, non-partial response.
    #[test]
    fn test_llm_response_creation() {
        let response = LlmResponse::new(Content::new("assistant"));

        assert!(response.content.is_some());
        assert!(response.turn_complete);
        assert!(!response.partial);
        assert_eq!(response.finish_reason, Some(FinishReason::Stop));
        assert_eq!(response.citation_metadata, None);
        assert_eq!(response.provider_metadata, None);
    }

    /// A payload without `citation_metadata` still deserializes.
    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let payload = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let parsed = serde_json::from_value::<LlmResponse>(payload).expect("should deserialize");
        assert!(parsed.citation_metadata.is_none());
    }

    /// Citation metadata survives a JSON round trip.
    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let source = CitationSource {
            uri: Some("https://example.com".to_string()),
            title: Some("Example".to_string()),
            start_index: Some(0),
            end_index: Some(5),
            license: None,
            publication_date: Some("2026-01-01T00:00:00Z".to_string()),
        };
        // `LlmResponse::new` supplies the same remaining field values that a
        // full struct literal would spell out: finish reason `Stop`, turn
        // complete, everything else `None`/`false`.
        let response = LlmResponse {
            citation_metadata: Some(CitationMetadata { citation_sources: vec![source] }),
            ..LlmResponse::new(Content::new("model").with_text("hello"))
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded = serde_json::from_str::<LlmResponse>(&encoded).expect("deserialize");
        assert_eq!(decoded.citation_metadata, response.citation_metadata);
    }

    /// Every config field, including `extensions`, survives a round trip.
    #[test]
    fn test_generate_content_config_roundtrip_with_extensions() {
        let openrouter_options = serde_json::json!({
            "provider": {
                "zdr": true,
                "order": ["openai", "anthropic"]
            },
            "plugins": [
                { "id": "web", "enabled": true }
            ]
        });
        let extensions: serde_json::Map<String, serde_json::Value> =
            std::iter::once(("openrouter".to_string(), openrouter_options)).collect();

        let config = GenerateContentConfig {
            temperature: Some(0.4),
            top_p: Some(0.8),
            top_k: Some(12),
            frequency_penalty: Some(0.1),
            presence_penalty: Some(0.2),
            max_output_tokens: Some(512),
            seed: Some(7),
            top_logprobs: Some(3),
            stop_sequences: vec!["STOP".to_string(), "DONE".to_string()],
            response_schema: Some(serde_json::json!({
                "type": "object",
                "properties": { "answer": { "type": "string" } },
                "required": ["answer"]
            })),
            cached_content: Some("cachedContents/abc123".to_string()),
            extensions,
        };

        let encoded = serde_json::to_string(&config).expect("serialize");
        let decoded =
            serde_json::from_str::<GenerateContentConfig>(&encoded).expect("deserialize");

        // Sampling parameters.
        assert_eq!(decoded.temperature, config.temperature);
        assert_eq!(decoded.top_p, config.top_p);
        assert_eq!(decoded.top_k, config.top_k);
        assert_eq!(decoded.frequency_penalty, config.frequency_penalty);
        assert_eq!(decoded.presence_penalty, config.presence_penalty);
        // Limits and reproducibility.
        assert_eq!(decoded.max_output_tokens, config.max_output_tokens);
        assert_eq!(decoded.seed, config.seed);
        assert_eq!(decoded.top_logprobs, config.top_logprobs);
        assert_eq!(decoded.stop_sequences, config.stop_sequences);
        // Structured output, caching, and provider extensions.
        assert_eq!(decoded.response_schema, config.response_schema);
        assert_eq!(decoded.cached_content, config.cached_content);
        assert_eq!(decoded.extensions, config.extensions);
    }

    /// Provider metadata and the extended usage fields survive a round trip.
    #[test]
    fn test_llm_response_and_usage_roundtrip_with_provider_metadata() {
        let usage = UsageMetadata {
            prompt_token_count: 10,
            candidates_token_count: 20,
            total_token_count: 30,
            cache_read_input_token_count: Some(5),
            cache_creation_input_token_count: Some(2),
            thinking_token_count: Some(3),
            audio_input_token_count: Some(4),
            audio_output_token_count: Some(6),
            cost: Some(0.0125),
            is_byok: Some(true),
            provider_usage: Some(serde_json::json!({
                "server_tool_use": {
                    "web_search_requests": 1
                },
                "prompt_tokens_details": {
                    "video_tokens": 8
                }
            })),
        };
        let response = LlmResponse {
            usage_metadata: Some(usage),
            provider_metadata: Some(serde_json::json!({
                "openrouter": {
                    "responseId": "resp_123",
                    "outputItems": 2
                }
            })),
            ..LlmResponse::new(Content::new("model").with_text("hello"))
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded = serde_json::from_str::<LlmResponse>(&encoded).expect("deserialize");

        // Accessors that flatten "no usage metadata" and "field unset" into
        // `None`, so each comparison also fails if the metadata was dropped.
        let cost_of = |r: &LlmResponse| r.usage_metadata.as_ref().and_then(|u| u.cost);
        let byok_of = |r: &LlmResponse| r.usage_metadata.as_ref().and_then(|u| u.is_byok);
        let provider_usage_of =
            |r: &LlmResponse| r.usage_metadata.as_ref().and_then(|u| u.provider_usage.clone());

        assert_eq!(decoded.provider_metadata, response.provider_metadata);
        assert_eq!(cost_of(&decoded), cost_of(&response));
        assert_eq!(byok_of(&decoded), byok_of(&response));
        assert_eq!(provider_usage_of(&decoded), provider_usage_of(&response));
    }

    /// `FinishReason` equality distinguishes variants.
    #[test]
    fn test_finish_reason() {
        let stop = FinishReason::Stop;
        assert_eq!(stop, FinishReason::Stop);
        assert_ne!(stop, FinishReason::MaxTokens);
    }
}