1use std::{
2 fmt::Write as _,
3 ops::{Deref, DerefMut},
4 sync::{Arc, LazyLock},
5};
6
7use futures::FutureExt as _;
8use reqwest::Method;
9use secrecy::{ExposeSecret as _, SecretString};
10
11use crate::{Chat, Error, Result, StreamGenerateContent, chat, types};
12
/// Scheme and host that every request URI built in this module is prefixed with.
pub(crate) const BASE_URI: &str = "https://generativelanguage.googleapis.com";
14
15pub struct Route<T> {
16 pub(crate) client: Client,
17 pub(crate) kind: T,
18}
19
20impl<T> Route<T> {
21 fn new(client: &Client, kind: T) -> Self {
22 Self {
23 client: client.clone(),
24 kind,
25 }
26 }
27}
28
29impl<T: Request> IntoFuture for Route<T> {
30 type Output = Result<T::Model>;
31 type IntoFuture = futures::future::BoxFuture<'static, Self::Output>;
32
33 fn into_future(self) -> Self::IntoFuture {
34 async move {
35 let mut request = self
36 .client
37 .reqwest
38 .request(T::METHOD, format!("{BASE_URI}/{self}"));
39
40 if let Some(body) = self.kind.body() {
41 request = request.json(&body);
42 };
43
44 let response = request.send().await?;
45 let raw_json = response.text().await?;
46
47 match serde_json::from_str::<types::ApiResponse<T::Model>>(&raw_json)? {
48 types::ApiResponse::Ok(response) => Ok(response),
49 types::ApiResponse::Err(api_error) => Err(Error::Gemini(api_error.error)),
50 }
51 }
52 .boxed()
53 }
54}
55
56impl Deref for Route<GenerateContent> {
57 type Target = GenerateContent;
58
59 fn deref(&self) -> &Self::Target {
60 &self.kind
61 }
62}
63
64impl DerefMut for Route<GenerateContent> {
65 fn deref_mut(&mut self) -> &mut Self::Target {
66 &mut self.kind
67 }
68}
69
70impl<T: Request> std::fmt::Display for Route<T> {
71 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 let mut fmt = Formatter::new(fmt);
73 self.kind.format_uri(&mut fmt)?;
74 fmt.write_query_param("key", &self.client.key.expose_secret())
75 }
76}
77
78#[derive(Clone)]
80pub struct Client {
81 pub(crate) inner: Arc<ClientInner>,
82}
83
84impl Deref for Client {
85 type Target = ClientInner;
86
87 fn deref(&self) -> &Self::Target {
88 &self.inner
89 }
90}
91
92impl Default for Client {
93 fn default() -> Self {
94 Self {
95 inner: ClientInner::new(None),
96 }
97 }
98}
99
100impl Client {
101 pub fn new(key: impl Into<SecretString>) -> Self {
102 Self {
103 inner: ClientInner::new(Some(key.into())),
104 }
105 }
106
107 pub fn chat(&self, model: &str) -> Chat<chat::Text> {
108 Chat::new(self, model)
109 }
110
111 pub fn models(&self) -> Route<Models> {
112 Route::new(self, Models::default())
113 }
114
115 pub fn generate_content(&self, model: &str) -> Route<GenerateContent> {
116 Route::new(self, GenerateContent::new(model.into()))
117 }
118
119 pub fn stream_generate_content(&self, model: &str) -> Route<StreamGenerateContent> {
120 Route::new(self, StreamGenerateContent::new(model))
121 }
122
123 pub fn instance() -> Client {
124 static STATIC_INSTANCE: LazyLock<Client> = LazyLock::new(Client::default);
125 STATIC_INSTANCE.clone()
126 }
127}
128
129pub struct GenerateContent {
130 pub(crate) model: Box<str>,
131 pub body: types::GenerateContent,
132}
133
134impl GenerateContent {
135 pub fn new(model: Box<str>) -> Self {
136 Self {
137 model,
138 body: types::GenerateContent::default(),
139 }
140 }
141
142 pub fn config(&mut self, config: types::GenerationConfig) {
143 self.body.generation_config = Some(config);
144 }
145
146 pub fn safety_settings(&mut self, safety_settings: Vec<types::SafetySettings>) {
147 self.body.safety_settings = safety_settings;
148 }
149
150 pub fn system_instruction(&mut self, instruction: &str) {
151 self.body.system_instruction = Some(types::SystemInstructionContent {
152 parts: vec![types::SystemInstructionPart {
153 text: Some(instruction.into()),
154 }],
155 });
156 }
157 pub fn tool_config(&mut self, conf: types::ToolConfig) {
158 self.body.tool_config = Some(conf);
159 }
160 pub fn contents(&mut self, contents: Vec<types::Content>) {
161 self.body.contents = contents;
162 }
163
164 pub fn message(&mut self, message: &str) {
165 self.body.contents.push(types::Content {
166 role: types::Role::User,
167 parts: vec![types::Part::text(message)],
168 });
169 }
170 pub fn tools(&mut self, tools: Vec<types::Tools>) {
171 self.body.tools = tools;
172 }
173}
174
175impl Request for GenerateContent {
176 type Model = types::Response;
177 type Body = types::GenerateContent;
178
179 const METHOD: Method = Method::POST;
180
181 fn format_uri(&self, fmt: &mut Formatter<'_, '_>) -> std::fmt::Result {
182 fmt.write_str("v1beta/")?;
183 fmt.write_str("models/")?;
184 fmt.write_str(&self.model)?;
185 fmt.write_str(":generateContent")
186 }
187
188 fn body(&self) -> Option<Self::Body> {
189 Some(self.body.clone())
190 }
191}
192
/// Paging options for the models listing route.
///
/// Both options start out unset; an unset option is left out of the request
/// URI.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Models {
    /// Value of the `page_size` query parameter, when set.
    page_size: Option<usize>,
    /// Value of the `page_token` query parameter, when set.
    page_token: Option<Box<str>>,
}
198
199impl Models {
200 pub fn page_size(&mut self, size: usize) {
201 self.page_size = size.into();
202 }
203
204 pub fn page_token(&mut self, token: &str) {
205 self.page_token = Some(Box::from(token));
206 }
207}
208
209impl Request for Models {
210 type Model = types::Models;
211 type Body = ();
212
213 const METHOD: Method = Method::GET;
214
215 fn format_uri(&self, fmt: &mut Formatter<'_, '_>) -> std::fmt::Result {
216 fmt.write_str("v1beta/")?;
217 fmt.write_str("models")?;
218 fmt.write_optional_query_param("page_size", self.page_size.as_ref())?;
219 fmt.write_optional_query_param("page_token", self.page_token.as_ref())
220 }
221}
222
/// Wrapper around [`std::fmt::Formatter`] that tracks whether a query
/// parameter has been written yet, so the right separator can be chosen.
pub struct Formatter<'me, 'buffer> {
    /// Underlying formatter that receives all output.
    formatter: &'me mut std::fmt::Formatter<'buffer>,
    /// `true` until the first query parameter has been written.
    is_first: bool,
}
227
228impl<'buffer> Deref for Formatter<'_, 'buffer> {
229 type Target = std::fmt::Formatter<'buffer>;
230
231 fn deref(&self) -> &Self::Target {
232 self.formatter
233 }
234}
235
236impl DerefMut for Formatter<'_, '_> {
237 fn deref_mut(&mut self) -> &mut Self::Target {
238 self.formatter
239 }
240}
241
242impl<'me, 'buffer> Formatter<'me, 'buffer> {
243 pub(crate) fn new(formatter: &'me mut std::fmt::Formatter<'buffer>) -> Self {
244 Self {
245 formatter,
246 is_first: true,
247 }
248 }
249
250 pub(crate) fn write_query_param(
251 &mut self,
252 key: &str,
253 value: &impl std::fmt::Display,
254 ) -> std::fmt::Result {
255 if self.is_first {
256 self.formatter.write_char('?')?;
257 self.is_first = false;
258 } else {
259 self.formatter.write_char('&')?;
260 }
261
262 self.formatter.write_str(key)?;
263 self.formatter.write_char('=')?;
264 std::fmt::Display::fmt(value, self.formatter)
265 }
266
267 fn write_optional_query_param(
268 &mut self,
269 key: &str,
270 value: Option<&impl std::fmt::Display>,
271 ) -> std::fmt::Result {
272 if let Some(value) = value {
273 self.write_query_param(key, value)
274 } else {
275 Ok(())
276 }
277 }
278}
279
280pub struct ClientInner {
281 pub(crate) reqwest: reqwest::Client,
282 key: SecretString,
283}
284
285impl ClientInner {
286 fn new(key: Option<SecretString>) -> Arc<Self> {
287 Self {
288 reqwest: reqwest::Client::new(),
289 key: key
290 .or_else(|| std::env::var("GEMINI_API_KEY").ok().map(Into::into))
291 .expect("API key must be set either via argument or GEMINI_API_KEY environment variable"),
292 }
293 .into()
294 }
295}
296
297pub trait Request: Send + Sized + 'static {
298 type Model: serde::de::DeserializeOwned + Send + 'static;
299 type Body: serde::ser::Serialize;
300
301 const METHOD: Method;
302
303 fn format_uri(&self, fmt: &mut Formatter<'_, '_>) -> std::fmt::Result;
304
305 fn body(&self) -> Option<Self::Body> {
306 None
307 }
308}