adk_core/
model.rs

use crate::{Result, types::Content};
use async_trait::async_trait;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;

/// A pinned, boxed stream of `LlmResponse` results produced by a model call.
pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;

/// Interface implemented by language model backends.
#[async_trait]
pub trait Llm: Send + Sync {
    /// The name of this model backend.
    fn name(&self) -> &str;
    /// Send `req` to the model and return a stream of responses; `stream`
    /// indicates whether incremental (partial) responses are requested.
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
}

/// A request sent to a model backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    pub model: String,
    pub contents: Vec<Content>,
    pub config: Option<GenerateContentConfig>,
    /// Tools available for this request, keyed by name; never serialized.
    #[serde(skip)]
    pub tools: HashMap<String, serde_json::Value>,
}

/// Generation parameters for a request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateContentConfig {
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<i32>,
    pub max_output_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_schema: Option<serde_json::Value>,
}

/// A response, or streamed chunk of a response, from a model backend.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmResponse {
    pub content: Option<Content>,
    pub usage_metadata: Option<UsageMetadata>,
    pub finish_reason: Option<FinishReason>,
    pub partial: bool,
    pub turn_complete: bool,
    pub interrupted: bool,
    pub error_code: Option<String>,
    pub error_message: Option<String>,
}

/// Token counts reported by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageMetadata {
    pub prompt_token_count: i32,
    pub candidates_token_count: i32,
    pub total_token_count: i32,
}

/// The reason the model stopped generating.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    Stop,
    MaxTokens,
    Safety,
    Recitation,
    Other,
}

impl LlmRequest {
    /// Create a request for `model` with the given conversation contents.
    pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
        Self { model: model.into(), contents, config: None, tools: HashMap::new() }
    }

    /// Set the response schema for structured output.
    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
        let config = self.config.get_or_insert(GenerateContentConfig {
            temperature: None,
            top_p: None,
            top_k: None,
            max_output_tokens: None,
            response_schema: None,
        });
        config.response_schema = Some(schema);
        self
    }

    /// Set the generation config, replacing any existing one.
    pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
        self.config = Some(config);
        self
    }
}

impl LlmResponse {
    /// Build a complete, non-streaming response containing `content`.
    pub fn new(content: Content) -> Self {
        Self {
            content: Some(content),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_request_creation() {
        let req = LlmRequest::new("test-model", vec![]);
        assert_eq!(req.model, "test-model");
        assert!(req.contents.is_empty());
    }

    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let req = LlmRequest::new("test-model", vec![]).with_response_schema(schema.clone());

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert!(config.response_schema.is_some());
        assert_eq!(config.response_schema.unwrap(), schema);
    }

    #[test]
    fn test_llm_request_with_config() {
        let config = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            max_output_tokens: Some(1024),
            response_schema: None,
        };
        let req = LlmRequest::new("test-model", vec![]).with_config(config);

        assert!(req.config.is_some());
        let config = req.config.unwrap();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_output_tokens, Some(1024));
    }
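
    // A sketch of a serialization check for the serde attributes above:
    // `tools` is marked `#[serde(skip)]` and should never appear in serialized
    // output, and `response_schema` should be omitted while it is `None`.
    // Uses `serde_json::to_value`, which is already a dependency of this module.
    #[test]
    fn test_llm_request_serialization_skips_fields() {
        let config = GenerateContentConfig {
            temperature: Some(0.5),
            top_p: None,
            top_k: None,
            max_output_tokens: None,
            response_schema: None,
        };
        let req = LlmRequest::new("test-model", vec![]).with_config(config);

        let value = serde_json::to_value(&req).expect("request should serialize");
        assert_eq!(value["model"], "test-model");
        assert!(value.get("tools").is_none());
        assert!(value["config"].get("response_schema").is_none());
    }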

    #[test]
    fn test_llm_response_creation() {
        let content = Content::new("assistant");
        let resp = LlmResponse::new(content);
        assert!(resp.content.is_some());
        assert!(resp.turn_complete);
        assert!(!resp.partial);
        assert_eq!(resp.finish_reason, Some(FinishReason::Stop));
    }

    #[test]
    fn test_finish_reason() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
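
    // A minimal sketch of an `Llm` implementation showing how the trait's
    // streaming contract can be satisfied by boxing a one-item stream.
    // `MockLlm` is illustrative only; the sketch assumes `Content` and the
    // crate's error type are `Send`, and that the `futures` crate's default
    // `executor` feature is enabled (for `block_on`).
    struct MockLlm;

    #[async_trait::async_trait]
    impl Llm for MockLlm {
        fn name(&self) -> &str {
            "mock-llm"
        }

        async fn generate_content(
            &self,
            _req: LlmRequest,
            _stream: bool,
        ) -> Result<LlmResponseStream> {
            // Wrap a single complete response in a one-item stream.
            let resp: Result<LlmResponse> = Ok(LlmResponse::new(Content::new("assistant")));
            let stream: LlmResponseStream = Box::pin(futures::stream::once(async move { resp }));
            Ok(stream)
        }
    }

    #[test]
    fn test_mock_llm_yields_single_response() {
        use futures::StreamExt;

        let llm = MockLlm;
        let req = LlmRequest::new("mock-model", vec![]);
        let stream = match futures::executor::block_on(llm.generate_content(req, false)) {
            Ok(stream) => stream,
            Err(_) => panic!("mock request should not fail"),
        };
        let responses = futures::executor::block_on(stream.collect::<Vec<_>>());
        assert_eq!(responses.len(), 1);
        match &responses[0] {
            Ok(resp) => assert!(resp.turn_complete),
            Err(_) => panic!("expected a successful response"),
        }
    }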
}