agent_core_runtime/client/providers/gemini/
mod.rs1mod sse;
4mod types;
5
6use async_stream::stream;
7use futures::Stream;
8
9use crate::client::error::LlmError;
10use crate::client::http::HttpClient;
11use crate::client::models::{Message, MessageOptions, StreamEvent};
12use crate::client::traits::LlmProvider;
13use std::future::Future;
14use std::pin::Pin;
15
/// Error code passed to `LlmError::new` when the streamed response body
/// cannot be decoded.
const ERROR_SSE_DECODE: &str = "SSE_DECODE_ERROR";

/// Human-readable message paired with [`ERROR_SSE_DECODE`] when a streamed
/// chunk is not valid UTF-8.
const MSG_INVALID_UTF8: &str = "Invalid UTF-8 in stream";
25
/// `LlmProvider` implementation for the Gemini API.
///
/// Stores the API key used to build request headers and the default model
/// name, which a request can override through `MessageOptions::model`.
pub struct GeminiProvider {
    /// Key handed to `types::get_request_headers` on every request.
    api_key: String,
    /// Model used when `MessageOptions::model` is `None`.
    model: String,
}

impl GeminiProvider {
    /// Builds a provider from an API key and a default model name.
    pub fn new(api_key: String, model: String) -> Self {
        GeminiProvider { model, api_key }
    }

    /// Returns the default model name supplied to [`GeminiProvider::new`].
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
49
50impl LlmProvider for GeminiProvider {
51 fn send_msg(
52 &self,
53 client: &HttpClient,
54 messages: &[Message],
55 options: &MessageOptions,
56 ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
57 let client = client.clone();
59 let api_key = self.api_key.clone();
60 let model = options.model.as_deref().unwrap_or(&self.model).to_string();
61 let messages = messages.to_vec();
62 let options = options.clone();
63
64 Box::pin(async move {
65 let body = types::build_request_body(&messages, &options)?;
67
68 let headers = types::get_request_headers(&api_key)?;
70 let headers_ref: Vec<(&str, &str)> = headers
71 .iter()
72 .map(|(k, v)| (*k, v.as_str()))
73 .collect();
74
75 let url = types::get_api_url(&model);
77
78 let response = client.post(&url, &headers_ref, &body).await?;
80
81 types::parse_response(&response)
83 })
84 }
85
86 fn send_msg_stream(
87 &self,
88 client: &HttpClient,
89 messages: &[Message],
90 options: &MessageOptions,
91 ) -> Pin<Box<dyn Future<Output = Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>, LlmError>> + Send>> {
92 let client = client.clone();
94 let api_key = self.api_key.clone();
95 let model = options.model.as_deref().unwrap_or(&self.model).to_string();
96 let messages = messages.to_vec();
97 let options = options.clone();
98
99 Box::pin(async move {
100 let body = types::build_request_body(&messages, &options)?;
102
103 let headers = types::get_request_headers(&api_key)?;
105 let headers_ref: Vec<(&str, &str)> = headers
106 .iter()
107 .map(|(k, v)| (*k, v.as_str()))
108 .collect();
109
110 let url = types::get_streaming_api_url(&model);
112
113 let byte_stream = client.post_stream(&url, &headers_ref, &body).await?;
115
116 use futures::StreamExt;
118 let event_stream = stream! {
119 let mut buffer = String::new();
120 let mut byte_stream = byte_stream;
121 let mut message_started = false;
122 let mut stream_state = sse::StreamState::default();
123
124 while let Some(chunk_result) = byte_stream.next().await {
125 match chunk_result {
126 Ok(bytes) => {
127 if let Ok(text) = std::str::from_utf8(&bytes) {
129 buffer.push_str(text);
130 } else {
131 yield Err(LlmError::new(ERROR_SSE_DECODE, MSG_INVALID_UTF8));
132 break;
133 }
134
135 let (events, remaining) = sse::parse_sse_chunk(&buffer);
137 buffer = remaining;
138
139 for sse_event in events {
141 match sse::parse_stream_event(&sse_event, &mut stream_state) {
142 Ok(stream_events) => {
143 if !message_started && !stream_events.is_empty() {
145 message_started = true;
146 yield Ok(StreamEvent::MessageStart {
147 message_id: String::new(),
148 model: model.clone(),
149 });
150 }
151
152 for stream_event in stream_events {
153 yield Ok(stream_event);
154 }
155 }
156 Err(e) => {
157 yield Err(e);
158 return;
159 }
160 }
161 }
162 }
163 Err(e) => {
164 yield Err(e);
165 break;
166 }
167 }
168 }
169
170 if message_started {
172 yield Ok(StreamEvent::MessageStop);
173 }
174 };
175
176 Ok(Box::pin(event_stream) as Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>)
177 })
178 }
179}