Skip to main content

adk_core/
model.rs

1use crate::schema_adapter::{GenericSchemaAdapter, SchemaAdapter};
2use crate::{Result, types::Content};
3use async_trait::async_trait;
4use futures::stream::Stream;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::pin::Pin;
8
9/// A pinned, boxed stream of [`LlmResponse`] results from a model.
10pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
11
12/// The core trait for all LLM providers.
13///
14/// Implementations wrap a specific model API (Gemini, OpenAI, Anthropic, etc.)
15/// and produce a stream of responses for a given request.
16#[async_trait]
17pub trait Llm: Send + Sync {
18    /// Returns the model identifier (e.g., "gemini-2.5-flash").
19    fn name(&self) -> &str;
20    /// Generates content from the given request, optionally streaming.
21    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
22
23    /// Returns the schema adapter for this provider.
24    ///
25    /// The schema adapter normalizes raw JSON Schema from MCP tools into the
26    /// format accepted by this provider's function-calling API.
27    ///
28    /// Default implementation returns [`GenericSchemaAdapter`], which applies
29    /// safe transforms suitable for most providers. Override this method to
30    /// return a provider-specific adapter (e.g., `GeminiSchemaAdapter`,
31    /// `OpenAiStrictSchemaAdapter`).
32    fn schema_adapter(&self) -> &dyn SchemaAdapter {
33        &GenericSchemaAdapter
34    }
35}
36
37/// A request to an LLM provider.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct LlmRequest {
40    /// The model identifier to use for generation.
41    pub model: String,
42    /// The conversation contents (system, user, model messages).
43    pub contents: Vec<Content>,
44    /// Optional generation configuration (temperature, tokens, etc.).
45    pub config: Option<GenerateContentConfig>,
46    /// Tool declarations keyed by tool name.
47    #[serde(skip)]
48    pub tools: HashMap<String, serde_json::Value>,
49}
50
51/// Configuration for LLM content generation.
52#[derive(Debug, Clone, Default, Serialize, Deserialize)]
53pub struct GenerateContentConfig {
54    /// Sampling temperature (0.0 = deterministic, higher = more random).
55    pub temperature: Option<f32>,
56    /// Nucleus sampling threshold.
57    pub top_p: Option<f32>,
58    /// Top-k sampling parameter.
59    pub top_k: Option<i32>,
60    /// Frequency penalty to reduce repetition.
61    #[serde(skip_serializing_if = "Option::is_none", default)]
62    pub frequency_penalty: Option<f32>,
63    /// Presence penalty to encourage topic diversity.
64    #[serde(skip_serializing_if = "Option::is_none", default)]
65    pub presence_penalty: Option<f32>,
66    /// Maximum number of output tokens to generate.
67    pub max_output_tokens: Option<i32>,
68    /// Random seed for reproducible generation.
69    #[serde(skip_serializing_if = "Option::is_none", default)]
70    pub seed: Option<i64>,
71    /// Number of top log probabilities to return per token.
72    #[serde(skip_serializing_if = "Option::is_none", default)]
73    pub top_logprobs: Option<u8>,
74    /// Sequences that stop generation when encountered.
75    #[serde(default, skip_serializing_if = "Vec::is_empty")]
76    pub stop_sequences: Vec<String>,
77    /// JSON Schema for structured output.
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub response_schema: Option<serde_json::Value>,
80
81    /// Optional cached content name for Gemini provider.
82    /// When set, the Gemini provider attaches this to the generation request.
83    #[serde(skip_serializing_if = "Option::is_none", default)]
84    pub cached_content: Option<String>,
85
86    /// Provider-specific request options keyed by provider namespace.
87    #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
88    pub extensions: serde_json::Map<String, serde_json::Value>,
89}
90
91/// A response from an LLM provider.
92#[derive(Debug, Clone, Default, Serialize, Deserialize)]
93pub struct LlmResponse {
94    /// The generated content (text, function calls, etc.).
95    pub content: Option<Content>,
96    /// Token usage statistics.
97    pub usage_metadata: Option<UsageMetadata>,
98    /// Reason the model stopped generating.
99    pub finish_reason: Option<FinishReason>,
100    /// Citation sources referenced in the response.
101    #[serde(skip_serializing_if = "Option::is_none")]
102    pub citation_metadata: Option<CitationMetadata>,
103    /// Whether this is a partial streaming chunk.
104    pub partial: bool,
105    /// Whether the model has finished its turn.
106    pub turn_complete: bool,
107    /// Whether the response was interrupted.
108    pub interrupted: bool,
109    /// Error code from the provider, if any.
110    pub error_code: Option<String>,
111    /// Error message from the provider, if any.
112    pub error_message: Option<String>,
113    /// Provider-specific metadata (e.g., response IDs, routing info).
114    #[serde(skip_serializing_if = "Option::is_none", default)]
115    pub provider_metadata: Option<serde_json::Value>,
116}
117
118/// Trait for LLM providers that support prompt caching.
119///
120/// Providers implementing this trait can create and delete cached content
121/// resources, enabling automatic prompt caching lifecycle management by the
122/// runner. The runner stores an `Option<Arc<dyn CacheCapable>>` alongside the
123/// primary `Arc<dyn Llm>` and calls these methods when [`ContextCacheConfig`]
124/// is active.
125///
126/// # Example
127///
128/// ```rust,ignore
129/// use adk_core::CacheCapable;
130///
131/// let cache_name = model
132///     .create_cache("You are a helpful assistant.", &tools, 600)
133///     .await?;
134/// // ... use cache_name in generation requests ...
135/// model.delete_cache(&cache_name).await?;
136/// ```
137#[async_trait]
138pub trait CacheCapable: Send + Sync {
139    /// Create a cached content resource from the given system instruction,
140    /// tool definitions, and TTL.
141    ///
142    /// Returns the provider-specific cache name (e.g. `"cachedContents/abc123"`
143    /// for Gemini) that can be attached to subsequent generation requests via
144    /// [`GenerateContentConfig::cached_content`].
145    async fn create_cache(
146        &self,
147        system_instruction: &str,
148        tools: &HashMap<String, serde_json::Value>,
149        ttl_seconds: u32,
150    ) -> Result<String>;
151
152    /// Delete a previously created cached content resource by name.
153    async fn delete_cache(&self, name: &str) -> Result<()>;
154}
155
156/// Configuration for automatic prompt caching lifecycle management.
157///
158/// When set on runner configuration, the runner will automatically create and manage
159/// cached content resources for supported providers (currently Gemini).
160///
161/// # Example
162///
163/// ```rust
164/// use adk_core::ContextCacheConfig;
165///
166/// let config = ContextCacheConfig {
167///     min_tokens: 4096,
168///     ttl_seconds: 600,
169///     cache_intervals: 3,
170/// };
171/// assert_eq!(config.min_tokens, 4096);
172/// ```
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct ContextCacheConfig {
175    /// Minimum system instruction + tool token count to trigger caching.
176    /// Set to 0 to disable caching.
177    pub min_tokens: u32,
178
179    /// Cache time-to-live in seconds.
180    /// Set to 0 to disable caching.
181    pub ttl_seconds: u32,
182
183    /// Maximum number of LLM invocations before cache refresh.
184    /// After this many invocations, the runner creates a new cache
185    /// and deletes the old one.
186    pub cache_intervals: u32,
187}
188
189impl Default for ContextCacheConfig {
190    fn default() -> Self {
191        Self { min_tokens: 4096, ttl_seconds: 600, cache_intervals: 3 }
192    }
193}
194
195/// Token usage statistics from an LLM response.
196#[derive(Debug, Clone, Default, Serialize, Deserialize)]
197pub struct UsageMetadata {
198    /// Number of tokens in the prompt.
199    pub prompt_token_count: i32,
200    /// Number of tokens in the generated response.
201    pub candidates_token_count: i32,
202    /// Total tokens (prompt + response).
203    pub total_token_count: i32,
204
205    /// Tokens read from cache (Gemini/Anthropic).
206    #[serde(skip_serializing_if = "Option::is_none", default)]
207    pub cache_read_input_token_count: Option<i32>,
208
209    /// Tokens written to cache during this request.
210    #[serde(skip_serializing_if = "Option::is_none", default)]
211    pub cache_creation_input_token_count: Option<i32>,
212
213    /// Tokens used for thinking/reasoning (thinking models).
214    #[serde(skip_serializing_if = "Option::is_none", default)]
215    pub thinking_token_count: Option<i32>,
216
217    /// Audio input tokens (multimodal models).
218    #[serde(skip_serializing_if = "Option::is_none", default)]
219    pub audio_input_token_count: Option<i32>,
220
221    /// Audio output tokens (multimodal models).
222    #[serde(skip_serializing_if = "Option::is_none", default)]
223    pub audio_output_token_count: Option<i32>,
224
225    /// Estimated cost in USD for this request.
226    #[serde(skip_serializing_if = "Option::is_none", default)]
227    pub cost: Option<f64>,
228
229    /// Whether this request used a bring-your-own-key provider.
230    #[serde(skip_serializing_if = "Option::is_none", default)]
231    pub is_byok: Option<bool>,
232
233    /// Provider-specific usage details (e.g., server tool use, video tokens).
234    #[serde(skip_serializing_if = "Option::is_none", default)]
235    pub provider_usage: Option<serde_json::Value>,
236}
237
238/// Citation metadata emitted by model providers for source attribution.
239#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
240#[serde(rename_all = "camelCase")]
241pub struct CitationMetadata {
242    /// The list of citation sources in this response.
243    #[serde(default)]
244    pub citation_sources: Vec<CitationSource>,
245}
246
247/// One citation source with optional offsets.
248#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
249#[serde(rename_all = "camelCase")]
250pub struct CitationSource {
251    /// URI of the cited source.
252    pub uri: Option<String>,
253    /// Title of the cited source.
254    pub title: Option<String>,
255    /// Start character index in the response text.
256    pub start_index: Option<i32>,
257    /// End character index in the response text.
258    pub end_index: Option<i32>,
259    /// License of the cited source.
260    pub license: Option<String>,
261    /// Publication date of the cited source.
262    pub publication_date: Option<String>,
263}
264
265/// Reason the model stopped generating content.
266#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
267pub enum FinishReason {
268    /// Natural stop (end of response).
269    Stop,
270    /// Hit the maximum token limit.
271    MaxTokens,
272    /// Content filtered for safety.
273    Safety,
274    /// Content blocked due to recitation/copyright.
275    Recitation,
276    /// Other/unknown reason.
277    Other,
278}
279
280impl LlmRequest {
281    /// Creates a new request with the given model and contents.
282    pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
283        Self { model: model.into(), contents, config: None, tools: HashMap::new() }
284    }
285
286    /// Set the response schema for structured output.
287    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
288        let config = self.config.get_or_insert(GenerateContentConfig::default());
289        config.response_schema = Some(schema);
290        self
291    }
292
293    /// Set the generation config.
294    pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
295        self.config = Some(config);
296        self
297    }
298}
299
300impl LlmResponse {
301    /// Creates a complete (non-streaming) response with the given content.
302    pub fn new(content: Content) -> Self {
303        Self {
304            content: Some(content),
305            usage_metadata: None,
306            finish_reason: Some(FinishReason::Stop),
307            citation_metadata: None,
308            partial: false,
309            turn_complete: true,
310            interrupted: false,
311            error_code: None,
312            error_message: None,
313            provider_metadata: None,
314        }
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed request keeps the model name and the (empty) contents.
    #[test]
    fn test_llm_request_creation() {
        let request = LlmRequest::new("test-model", Vec::new());
        assert_eq!(request.model, "test-model");
        assert!(request.contents.is_empty());
    }

    /// `with_response_schema` creates a config on demand and stores the schema in it.
    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let request =
            LlmRequest::new("test-model", Vec::new()).with_response_schema(schema.clone());

        let config = request.config.expect("config should be created");
        assert_eq!(config.response_schema, Some(schema));
    }

    /// `with_config` stores the supplied config unchanged.
    #[test]
    fn test_llm_request_with_config() {
        let settings = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            frequency_penalty: Some(0.2),
            presence_penalty: Some(-0.3),
            max_output_tokens: Some(1024),
            seed: Some(42),
            top_logprobs: Some(5),
            stop_sequences: vec!["END".to_string()],
            ..Default::default()
        };
        let request = LlmRequest::new("test-model", Vec::new()).with_config(settings);

        let stored = request.config.expect("config should be set");
        assert_eq!(stored.temperature, Some(0.7));
        assert_eq!(stored.max_output_tokens, Some(1024));
        assert_eq!(stored.frequency_penalty, Some(0.2));
        assert_eq!(stored.presence_penalty, Some(-0.3));
        assert_eq!(stored.seed, Some(42));
        assert_eq!(stored.top_logprobs, Some(5));
        assert_eq!(stored.stop_sequences, vec!["END"]);
    }

    /// `LlmResponse::new` yields a complete, non-partial response that stopped naturally.
    #[test]
    fn test_llm_response_creation() {
        let response = LlmResponse::new(Content::new("assistant"));
        assert!(response.content.is_some());
        assert!(response.turn_complete);
        assert!(!response.partial);
        assert_eq!(response.finish_reason, Some(FinishReason::Stop));
        assert!(response.citation_metadata.is_none());
        assert!(response.provider_metadata.is_none());
    }

    /// A payload without citation metadata still deserializes, leaving the field `None`.
    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let payload = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let parsed: LlmResponse = serde_json::from_value(payload).expect("should deserialize");
        assert!(parsed.citation_metadata.is_none());
    }

    /// Citation metadata survives a JSON round-trip.
    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let source = CitationSource {
            uri: Some("https://example.com".to_string()),
            title: Some("Example".to_string()),
            start_index: Some(0),
            end_index: Some(5),
            license: None,
            publication_date: Some("2026-01-01T00:00:00Z".to_string()),
        };
        let original = LlmResponse {
            citation_metadata: Some(CitationMetadata { citation_sources: vec![source] }),
            ..LlmResponse::new(Content::new("model").with_text("hello"))
        };

        let json = serde_json::to_string(&original).expect("serialize");
        let restored: LlmResponse = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.citation_metadata, original.citation_metadata);
    }

    /// Every config field, including provider extensions, survives a JSON round-trip.
    #[test]
    fn test_generate_content_config_roundtrip_with_extensions() {
        let openrouter_options = serde_json::json!({
            "provider": {
                "zdr": true,
                "order": ["openai", "anthropic"]
            },
            "plugins": [
                { "id": "web", "enabled": true }
            ]
        });
        let extensions: serde_json::Map<String, serde_json::Value> =
            std::iter::once(("openrouter".to_string(), openrouter_options)).collect();

        let original = GenerateContentConfig {
            temperature: Some(0.4),
            top_p: Some(0.8),
            top_k: Some(12),
            frequency_penalty: Some(0.1),
            presence_penalty: Some(0.2),
            max_output_tokens: Some(512),
            seed: Some(7),
            top_logprobs: Some(3),
            stop_sequences: vec!["STOP".to_string(), "DONE".to_string()],
            response_schema: Some(serde_json::json!({
                "type": "object",
                "properties": { "answer": { "type": "string" } },
                "required": ["answer"]
            })),
            cached_content: Some("cachedContents/abc123".to_string()),
            extensions,
        };

        let json = serde_json::to_string(&original).expect("serialize");
        let restored: GenerateContentConfig = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(restored.temperature, original.temperature);
        assert_eq!(restored.top_p, original.top_p);
        assert_eq!(restored.top_k, original.top_k);
        assert_eq!(restored.frequency_penalty, original.frequency_penalty);
        assert_eq!(restored.presence_penalty, original.presence_penalty);
        assert_eq!(restored.max_output_tokens, original.max_output_tokens);
        assert_eq!(restored.seed, original.seed);
        assert_eq!(restored.top_logprobs, original.top_logprobs);
        assert_eq!(restored.stop_sequences, original.stop_sequences);
        assert_eq!(restored.response_schema, original.response_schema);
        assert_eq!(restored.cached_content, original.cached_content);
        assert_eq!(restored.extensions, original.extensions);
    }

    /// Provider metadata and the extended usage fields survive a JSON round-trip.
    #[test]
    fn test_llm_response_and_usage_roundtrip_with_provider_metadata() {
        let usage = UsageMetadata {
            prompt_token_count: 10,
            candidates_token_count: 20,
            total_token_count: 30,
            cache_read_input_token_count: Some(5),
            cache_creation_input_token_count: Some(2),
            thinking_token_count: Some(3),
            audio_input_token_count: Some(4),
            audio_output_token_count: Some(6),
            cost: Some(0.0125),
            is_byok: Some(true),
            provider_usage: Some(serde_json::json!({
                "server_tool_use": {
                    "web_search_requests": 1
                },
                "prompt_tokens_details": {
                    "video_tokens": 8
                }
            })),
        };
        let original = LlmResponse {
            usage_metadata: Some(usage),
            provider_metadata: Some(serde_json::json!({
                "openrouter": {
                    "responseId": "resp_123",
                    "outputItems": 2
                }
            })),
            ..LlmResponse::new(Content::new("model").with_text("hello"))
        };

        let json = serde_json::to_string(&original).expect("serialize");
        let restored: LlmResponse = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(restored.provider_metadata, original.provider_metadata);

        let usage_before = original.usage_metadata.as_ref();
        let usage_after = restored.usage_metadata.as_ref();
        assert_eq!(usage_after.and_then(|u| u.cost), usage_before.and_then(|u| u.cost));
        assert_eq!(usage_after.and_then(|u| u.is_byok), usage_before.and_then(|u| u.is_byok));
        assert_eq!(
            usage_after.and_then(|u| u.provider_usage.clone()),
            usage_before.and_then(|u| u.provider_usage.clone()),
        );
    }

    /// Finish reasons compare equal to themselves and unequal to other variants.
    #[test]
    fn test_finish_reason() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
}