// gemini_rs/client.rs

1use std::{
2    fmt::Write as _,
3    ops::{Deref, DerefMut},
4    sync::{Arc, LazyLock},
5};
6
7use futures::FutureExt as _;
8use reqwest::Method;
9use secrecy::{ExposeSecret as _, SecretString};
10
11use crate::{Chat, Error, Result, StreamGenerateContent, chat, types};
12
/// Scheme and host every request URL starts with; a route's relative URI is appended after a
/// single `/`.
pub(crate) const BASE_URI: &str = "https://generativelanguage.googleapis.com";
14
15pub struct Route<T> {
16    pub(crate) client: Client,
17    pub(crate) kind: T,
18}
19
20impl<T> Route<T> {
21    fn new(client: &Client, kind: T) -> Self {
22        Self {
23            client: client.clone(),
24            kind,
25        }
26    }
27}
28
29impl<T: Request> IntoFuture for Route<T> {
30    type Output = Result<T::Model>;
31    type IntoFuture = futures::future::BoxFuture<'static, Self::Output>;
32
33    fn into_future(self) -> Self::IntoFuture {
34        async move {
35            let mut request = self
36                .client
37                .reqwest
38                .request(T::METHOD, format!("{BASE_URI}/{self}"));
39
40            if let Some(body) = self.kind.body() {
41                request = request.json(&body);
42            };
43
44            let response = request.send().await?;
45            let raw_json = response.text().await?;
46
47            match serde_json::from_str::<types::ApiResponse<T::Model>>(&raw_json)? {
48                types::ApiResponse::Ok(response) => Ok(response),
49                types::ApiResponse::Err(api_error) => Err(Error::Gemini(api_error.error)),
50            }
51        }
52        .boxed()
53    }
54}
55
56impl Deref for Route<GenerateContent> {
57    type Target = GenerateContent;
58
59    fn deref(&self) -> &Self::Target {
60        &self.kind
61    }
62}
63
64impl DerefMut for Route<GenerateContent> {
65    fn deref_mut(&mut self) -> &mut Self::Target {
66        &mut self.kind
67    }
68}
69
70impl<T: Request> std::fmt::Display for Route<T> {
71    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        let mut fmt = Formatter::new(fmt);
73        self.kind.format_uri(&mut fmt)?;
74        fmt.write_query_param("key", &self.client.key.expose_secret())
75    }
76}
77
78/// Covers the 20% of use cases that [Chat] doesn't
79#[derive(Clone)]
80pub struct Client {
81    pub(crate) inner: Arc<ClientInner>,
82}
83
84impl Deref for Client {
85    type Target = ClientInner;
86
87    fn deref(&self) -> &Self::Target {
88        &self.inner
89    }
90}
91
92impl Default for Client {
93    fn default() -> Self {
94        Self {
95            inner: ClientInner::new(None),
96        }
97    }
98}
99
100impl Client {
101    pub fn new(key: impl Into<SecretString>) -> Self {
102        Self {
103            inner: ClientInner::new(Some(key.into())),
104        }
105    }
106
107    pub fn chat(&self, model: &str) -> Chat<chat::Text> {
108        Chat::new(self, model)
109    }
110
111    pub fn models(&self) -> Route<Models> {
112        Route::new(self, Models::default())
113    }
114
115    pub fn generate_content(&self, model: &str) -> Route<GenerateContent> {
116        Route::new(self, GenerateContent::new(model.into()))
117    }
118
119    pub fn stream_generate_content(&self, model: &str) -> Route<StreamGenerateContent> {
120        Route::new(self, StreamGenerateContent::new(model))
121    }
122
123    pub fn instance() -> Client {
124        static STATIC_INSTANCE: LazyLock<Client> = LazyLock::new(Client::default);
125        STATIC_INSTANCE.clone()
126    }
127}
128
129pub struct GenerateContent {
130    pub(crate) model: Box<str>,
131    pub body: types::GenerateContent,
132}
133
134impl GenerateContent {
135    pub fn new(model: Box<str>) -> Self {
136        Self {
137            model,
138            body: types::GenerateContent::default(),
139        }
140    }
141
142    pub fn config(&mut self, config: types::GenerationConfig) {
143        self.body.generation_config = Some(config);
144    }
145
146    pub fn safety_settings(&mut self, safety_settings: Vec<types::SafetySettings>) {
147        self.body.safety_settings = safety_settings;
148    }
149
150    pub fn system_instruction(&mut self, instruction: &str) {
151        self.body.system_instruction = Some(types::SystemInstructionContent {
152            parts: vec![types::SystemInstructionPart {
153                text: Some(instruction.into()),
154            }],
155        });
156    }
157    pub fn tool_config(&mut self, conf: types::ToolConfig) {
158        self.body.tool_config = Some(conf);
159    }
160    pub fn contents(&mut self, contents: Vec<types::Content>) {
161        self.body.contents = contents;
162    }
163
164    pub fn message(&mut self, message: &str) {
165        self.body.contents.push(types::Content {
166            role: types::Role::User,
167            parts: vec![types::Part::text(message)],
168        });
169    }
170    pub fn tools(&mut self, tools: Vec<types::Tools>) {
171        self.body.tools = tools;
172    }
173}
174
175impl Request for GenerateContent {
176    type Model = types::Response;
177    type Body = types::GenerateContent;
178
179    const METHOD: Method = Method::POST;
180
181    fn format_uri(&self, fmt: &mut Formatter<'_, '_>) -> std::fmt::Result {
182        fmt.write_str("v1beta/")?;
183        fmt.write_str("models/")?;
184        fmt.write_str(&self.model)?;
185        fmt.write_str(":generateContent")
186    }
187
188    fn body(&self) -> Option<Self::Body> {
189        Some(self.body.clone())
190    }
191}
192
/// State of the `models` listing request; both paging parameters start out unset.
#[derive(Default)]
pub struct Models {
    /// Value for the `page_size` query parameter, if any.
    page_size: Option<usize>,
    /// Value for the `page_token` query parameter, if any.
    page_token: Option<Box<str>>,
}

impl Models {
    /// Sets the value sent as the `page_size` query parameter.
    pub fn page_size(&mut self, size: usize) {
        self.page_size = Some(size);
    }

    /// Sets the value sent as the `page_token` query parameter (stored as an owned copy).
    pub fn page_token(&mut self, token: &str) {
        let owned: Box<str> = token.into();
        self.page_token = Some(owned);
    }
}
208
209impl Request for Models {
210    type Model = types::Models;
211    type Body = ();
212
213    const METHOD: Method = Method::GET;
214
215    fn format_uri(&self, fmt: &mut Formatter<'_, '_>) -> std::fmt::Result {
216        fmt.write_str("v1beta/")?;
217        fmt.write_str("models")?;
218        fmt.write_optional_query_param("page_size", self.page_size.as_ref())?;
219        fmt.write_optional_query_param("page_token", self.page_token.as_ref())
220    }
221}
222
/// URI writer handed to `Request::format_uri`.
///
/// It wraps a [`std::fmt::Formatter`] (reachable through `Deref`/`DerefMut` for writing the
/// path) and remembers whether a query parameter has been written yet, so parameters are
/// joined with `?` first and `&` afterwards.
pub struct Formatter<'me, 'buffer> {
    /// Underlying formatter all output goes to.
    formatter: &'me mut std::fmt::Formatter<'buffer>,
    /// `true` until the first query parameter has been written.
    is_first: bool,
}

impl<'buffer> Deref for Formatter<'_, 'buffer> {
    type Target = std::fmt::Formatter<'buffer>;

    fn deref(&self) -> &Self::Target {
        self.formatter
    }
}

impl DerefMut for Formatter<'_, '_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.formatter
    }
}

impl<'me, 'buffer> Formatter<'me, 'buffer> {
    /// Wraps `formatter`; no query parameter has been written yet.
    pub(crate) fn new(formatter: &'me mut std::fmt::Formatter<'buffer>) -> Self {
        Self {
            formatter,
            is_first: true,
        }
    }

    /// Appends `key=value` to the query string, preceded by `?` for the first parameter and
    /// `&` for every later one.
    ///
    /// `key` is written verbatim. `value` is percent-encoded (RFC 3986: everything except
    /// ASCII letters, digits, `-`, `.`, `_` and `~`), so values containing characters such as
    /// `&`, `=`, `+` or `#` cannot corrupt the query string.
    pub(crate) fn write_query_param(
        &mut self,
        key: &str,
        value: &impl std::fmt::Display,
    ) -> std::fmt::Result {
        let separator = if self.is_first { '?' } else { '&' };
        self.formatter.write_char(separator)?;
        self.is_first = false;

        self.formatter.write_str(key)?;
        self.formatter.write_char('=')?;
        write!(PercentEncoded(&mut *self.formatter), "{value}")
    }

    /// Like [`Formatter::write_query_param`], but writes nothing when `value` is `None`.
    fn write_optional_query_param(
        &mut self,
        key: &str,
        value: Option<&impl std::fmt::Display>,
    ) -> std::fmt::Result {
        match value {
            Some(value) => self.write_query_param(key, value),
            None => Ok(()),
        }
    }
}

/// `fmt::Write` adapter that percent-encodes everything written through it.
struct PercentEncoded<'a, 'buffer>(&'a mut std::fmt::Formatter<'buffer>);

impl std::fmt::Write for PercentEncoded<'_, '_> {
    fn write_str(&mut self, text: &str) -> std::fmt::Result {
        // Encoding works on UTF-8 bytes, so non-ASCII characters become one `%XX` per byte.
        for byte in text.bytes() {
            let unreserved =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~');
            if unreserved {
                self.0.write_char(char::from(byte))?;
            } else {
                write!(self.0, "%{byte:02X}")?;
            }
        }
        Ok(())
    }
}
279
280pub struct ClientInner {
281    pub(crate) reqwest: reqwest::Client,
282    key: SecretString,
283}
284
285impl ClientInner {
286    fn new(key: Option<SecretString>) -> Arc<Self> {
287        Self {
288            reqwest: reqwest::Client::new(),
289            key: key
290                .or_else(|| std::env::var("GEMINI_API_KEY").ok().map(Into::into))
291                .expect("API key must be set either via argument or GEMINI_API_KEY environment variable"),
292        }
293        .into()
294    }
295}
296
297pub trait Request: Send + Sized + 'static {
298    type Model: serde::de::DeserializeOwned + Send + 'static;
299    type Body: serde::ser::Serialize;
300
301    const METHOD: Method;
302
303    fn format_uri(&self, fmt: &mut Formatter<'_, '_>) -> std::fmt::Result;
304
305    fn body(&self) -> Option<Self::Body> {
306        None
307    }
308}