1use crate::{Result, types::Content};
2use async_trait::async_trait;
3use futures::stream::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::pin::Pin;
7
8pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmResponse>> + Send>>;
9
10#[async_trait]
11pub trait Llm: Send + Sync {
12 fn name(&self) -> &str;
13 async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream>;
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct LlmRequest {
18 pub model: String,
19 pub contents: Vec<Content>,
20 pub config: Option<GenerateContentConfig>,
21 #[serde(skip)]
22 pub tools: HashMap<String, serde_json::Value>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct GenerateContentConfig {
27 pub temperature: Option<f32>,
28 pub top_p: Option<f32>,
29 pub top_k: Option<i32>,
30 pub max_output_tokens: Option<i32>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub response_schema: Option<serde_json::Value>,
33}
34
35#[derive(Debug, Clone, Default, Serialize, Deserialize)]
36pub struct LlmResponse {
37 pub content: Option<Content>,
38 pub usage_metadata: Option<UsageMetadata>,
39 pub finish_reason: Option<FinishReason>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub citation_metadata: Option<CitationMetadata>,
42 pub partial: bool,
43 pub turn_complete: bool,
44 pub interrupted: bool,
45 pub error_code: Option<String>,
46 pub error_message: Option<String>,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct UsageMetadata {
51 pub prompt_token_count: i32,
52 pub candidates_token_count: i32,
53 pub total_token_count: i32,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
58#[serde(rename_all = "camelCase")]
59pub struct CitationMetadata {
60 pub citation_sources: Vec<CitationSource>,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65#[serde(rename_all = "camelCase")]
66pub struct CitationSource {
67 pub uri: Option<String>,
68 pub title: Option<String>,
69 pub start_index: Option<i32>,
70 pub end_index: Option<i32>,
71 pub license: Option<String>,
72 pub publication_date: Option<String>,
73}
74
75#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
76pub enum FinishReason {
77 Stop,
78 MaxTokens,
79 Safety,
80 Recitation,
81 Other,
82}
83
84impl LlmRequest {
85 pub fn new(model: impl Into<String>, contents: Vec<Content>) -> Self {
86 Self { model: model.into(), contents, config: None, tools: HashMap::new() }
87 }
88
89 pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
91 let config = self.config.get_or_insert(GenerateContentConfig {
92 temperature: None,
93 top_p: None,
94 top_k: None,
95 max_output_tokens: None,
96 response_schema: None,
97 });
98 config.response_schema = Some(schema);
99 self
100 }
101
102 pub fn with_config(mut self, config: GenerateContentConfig) -> Self {
104 self.config = Some(config);
105 self
106 }
107}
108
109impl LlmResponse {
110 pub fn new(content: Content) -> Self {
111 Self {
112 content: Some(content),
113 usage_metadata: None,
114 finish_reason: Some(FinishReason::Stop),
115 citation_metadata: None,
116 partial: false,
117 turn_complete: true,
118 interrupted: false,
119 error_code: None,
120 error_message: None,
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores the model and contents and leaves config and tools empty.
    #[test]
    fn test_llm_request_creation() {
        let req = LlmRequest::new("test-model", vec![]);
        assert_eq!(req.model, "test-model");
        assert!(req.contents.is_empty());
        assert!(req.config.is_none());
        assert!(req.tools.is_empty());
    }

    /// `with_response_schema` creates a config when none exists and stores the schema.
    #[test]
    fn test_llm_request_with_response_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let req = LlmRequest::new("test-model", vec![]).with_response_schema(schema.clone());

        let config = req.config.expect("with_response_schema should create a config");
        assert_eq!(config.response_schema, Some(schema));
    }

    /// `with_response_schema` keeps fields of an already-present config intact.
    #[test]
    fn test_with_response_schema_preserves_existing_config() {
        let base = GenerateContentConfig {
            temperature: Some(0.2),
            top_p: None,
            top_k: Some(8),
            max_output_tokens: None,
            response_schema: None,
        };
        let schema = serde_json::json!({ "type": "string" });
        let req = LlmRequest::new("test-model", vec![])
            .with_config(base)
            .with_response_schema(schema.clone());

        let config = req.config.expect("config should still be set");
        assert_eq!(config.temperature, Some(0.2));
        assert_eq!(config.top_k, Some(8));
        assert_eq!(config.response_schema, Some(schema));
    }

    /// `with_config` installs the given config as-is.
    #[test]
    fn test_llm_request_with_config() {
        let config = GenerateContentConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(40),
            max_output_tokens: Some(1024),
            response_schema: None,
        };
        let req = LlmRequest::new("test-model", vec![]).with_config(config);

        let config = req.config.expect("with_config should set the config");
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_output_tokens, Some(1024));
    }

    /// `LlmResponse::new` yields a finished, non-partial response without errors.
    #[test]
    fn test_llm_response_creation() {
        let content = Content::new("assistant");
        let resp = LlmResponse::new(content);
        assert!(resp.content.is_some());
        assert!(resp.turn_complete);
        assert!(!resp.partial);
        assert!(!resp.interrupted);
        assert_eq!(resp.finish_reason, Some(FinishReason::Stop));
        assert!(resp.citation_metadata.is_none());
        assert!(resp.error_code.is_none());
        assert!(resp.error_message.is_none());
    }

    /// A payload without `citation_metadata` deserializes with `None` citations.
    #[test]
    fn test_llm_response_deserialize_without_citations() {
        let json = serde_json::json!({
            "content": {
                "role": "model",
                "parts": [{"text": "hello"}]
            },
            "partial": false,
            "turn_complete": true,
            "interrupted": false
        });

        let response: LlmResponse = serde_json::from_value(json).expect("should deserialize");
        assert!(response.citation_metadata.is_none());
    }

    /// Citations and flags survive a JSON round trip.
    #[test]
    fn test_llm_response_roundtrip_with_citations() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: Some(CitationMetadata {
                citation_sources: vec![CitationSource {
                    uri: Some("https://example.com".to_string()),
                    title: Some("Example".to_string()),
                    start_index: Some(0),
                    end_index: Some(5),
                    license: None,
                    publication_date: Some("2026-01-01T00:00:00Z".to_string()),
                }],
            }),
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
        };

        let encoded = serde_json::to_string(&response).expect("serialize");
        let decoded: LlmResponse = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.citation_metadata, response.citation_metadata);
        assert_eq!(decoded.finish_reason, response.finish_reason);
        assert_eq!(decoded.turn_complete, response.turn_complete);
        assert_eq!(decoded.partial, response.partial);
    }

    #[test]
    fn test_finish_reason() {
        assert_eq!(FinishReason::Stop, FinishReason::Stop);
        assert_ne!(FinishReason::Stop, FinishReason::MaxTokens);
    }
}