adk_core/model.rs

use crate::{Result, types::Content};
use async_trait::async_trait;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;

/// A pinned, boxed stream of incremental model responses.
pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;

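/// A model backend that can serve `generate_content` calls.
///
/// A minimal implementation sketch; `EchoLlm` and its single-chunk reply are
/// illustrative, not part of this crate:
///
/// ```ignore
/// use futures::stream;
///
/// struct EchoLlm;
///
/// #[async_trait]
/// impl Llm for EchoLlm {
///     fn name(&self) -> &str {
///         "echo"
///     }
///
///     async fn generate_content(&self, req: LlmRequest, _stream: bool) -> Result<LlmResponseStream> {
///         // Echo the last input content back as one complete chunk.
///         let content = req.contents.last().cloned().unwrap_or_else(|| Content::new("model"));
///         Ok(Box::pin(stream::once(async move { Ok(LlmResponse::new(content)) })))
///     }
/// }
/// ```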
#[async_trait]
pub trait Llm: Send + Sync {
    /// A short identifier for the backing model or provider.
    fn name(&self) -> &str;

    /// Sends `req` to the backend. When `stream` is true the returned stream
    /// may yield multiple partial responses; otherwise it yields a single
    /// complete response.
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
}

/// A single request to an [`Llm`] backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    pub model: String,
    pub contents: Vec<Content>,
    pub config: Option<GenerateContentConfig>,
    /// Tool declarations keyed by tool name; never serialized with the request.
    #[serde(skip)]
    pub tools: HashMap<String, serde_json::Value>,
}

/// Generation parameters forwarded to the model.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerateContentConfig {
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<i32>,
    pub max_output_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_schema: Option<serde_json::Value>,
}

/// One chunk of model output; may be a partial piece of a streaming turn.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmResponse {
    pub content: Option<Content>,
    pub usage_metadata: Option<UsageMetadata>,
    pub finish_reason: Option<FinishReason>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citation_metadata: Option<CitationMetadata>,
    /// True when this is an incremental chunk of a streaming response.
    pub partial: bool,
    /// True once the model has finished the current turn.
    pub turn_complete: bool,
    pub interrupted: bool,
    pub error_code: Option<String>,
    pub error_message: Option<String>,
}

/// Token accounting reported by the backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageMetadata {
    pub prompt_token_count: i32,
    pub candidates_token_count: i32,
    pub total_token_count: i32,
}

/// Citation metadata emitted by model providers for source attribution.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    pub citation_sources: Vec<CitationSource>,
}

/// One citation source with optional offsets.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    pub uri: Option<String>,
    pub title: Option<String>,
    pub start_index: Option<i32>,
    pub end_index: Option<i32>,
    pub license: Option<String>,
    pub publication_date: Option<String>,
}

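// With the camelCase rename, citation metadata serializes to JSON shaped like
// the following (null-valued optional fields elided for brevity):
//
//     {"citationSources": [{"uri": "https://example.com", "startIndex": 0, "endIndex": 5}]}
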
/// Why the model stopped generating.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    Stop,
    MaxTokens,
    Safety,
    Recitation,
    Other,
}

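// Backends usually receive the finish reason as a provider-specific string;
// a mapping sketch (the raw string values are illustrative, not a fixed wire
// format):
//
//     fn map_finish_reason(raw: &str) -> FinishReason {
//         match raw {
//             "STOP" => FinishReason::Stop,
//             "MAX_TOKENS" => FinishReason::MaxTokens,
//             "SAFETY" => FinishReason::Safety,
//             "RECITATION" => FinishReason::Recitation,
//             _ => FinishReason::Other,
//         }
//     }
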
impl LlmRequest {
    /// Creates a request for `model` with the given conversation contents.
    pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
        Self { model: model.into(), contents, config: None, tools: HashMap::new() }
    }

    /// Set the response schema for structured output.
    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
        self.config
            .get_or_insert_with(GenerateContentConfig::default)
            .response_schema = Some(schema);
        self
    }

    /// Set the generation config.
    pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
        self.config = Some(config);
        self
    }
}

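// Putting the builders together: a caller might assemble a structured-output
// request like this (model name, prompt, and schema are illustrative, and
// `Content::with_text` is the helper used in the tests below):
//
//     let req = LlmRequest::new("some-model", vec![Content::new("user").with_text("Name one color.")])
//         .with_config(GenerateContentConfig { temperature: Some(0.2), ..Default::default() })
//         .with_response_schema(serde_json::json!({
//             "type": "object",
//             "properties": { "color": { "type": "string" } }
//         }));
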
impl LlmResponse {
    /// Builds a complete (non-partial) response that finishes the turn.
    pub fn new(content: Content) -> Self {
        Self {
            content: Some(content),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
        }
    }
}

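// Consuming a response stream: a sketch that drains chunks with
// `futures::StreamExt`, assuming `model: &dyn Llm` and an `LlmRequest` named
// `req` are already in scope:
//
//     use futures::StreamExt;
//
//     let mut stream = model.generate_content(req, true).await?;
//     while let Some(chunk) = stream.next().await {
//         let resp = chunk?;
//         if let Some(content) = resp.content {
//             // handle partial or final content here
//         }
//         if resp.turn_complete {
//             break;
//         }
//     }
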
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_request_creation() {
        let req = LlmRequest::new("test-model", vec![]);
        assert_eq!(req.model, "test-model");
        assert!(req.contents.is_empty());
    }

    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let req = LlmRequest::new("test-model", vec![]).with_response_schema(schema.clone());

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert!(config.response_schema.is_some());
        assert_eq!(config.response_schema.unwrap(), schema);
    }

    #[test]
    fn test_llm_request_with_config() {
        let config = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            max_output_tokens: Some(1024),
            response_schema: None,
        };
        let req = LlmRequest::new("test-model", vec![]).with_config(config);

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_output_tokens, Some(1024));
    }

    #[test]
    fn test_llm_response_creation() {
        let content = Content::new("assistant");
        let resp = LlmResponse::new(content);
        assert!(resp.content.is_some());
        assert!(resp.turn_complete);
        assert!(!resp.partial);
        assert_eq!(resp.finish_reason, Some(FinishReason::Stop));
        assert!(resp.citation_metadata.is_none());
    }

    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let json = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let response: LlmResponse = serde_json::from_value(json).expect("should deserialize");
        assert!(response.citation_metadata.is_none());
    }

    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: Some(CitationMetadata {
                citation_sources: vec![CitationSource {
                    uri: Some("https://example.com".to_string()),
                    title: Some("Example".to_string()),
                    start_index: Some(0),
                    end_index: Some(5),
                    license: None,
                    publication_date: Some("2026-01-01T00:00:00Z".to_string()),
                }],
            }),
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.citation_metadata, response.citation_metadata);
    }

    #[test]
    fn test_finish_reason() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
}