gemini_chat_api/
client.rs

1//! Async client for Google Gemini Chat API.
2
3use crate::enums::{gemini_headers, rotate_cookies_headers, Endpoint, Model};
4use crate::error::{Error, Result};
5use crate::utils::upload_file;
6
7use rand::Rng;
8use regex::Regex;
9use reqwest::cookie::Jar;
10use reqwest::{Client, Url};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13
14use std::path::Path;
15use std::sync::Arc;
16use std::time::Duration;
17
/// Regex that captures the `SNlM0e` token value from the Gemini init page.
///
/// The key and value may be wrapped in either single or double quotes; capture
/// group 1 holds the token itself.
const SNLM0E_PATTERN: &str = "[\"']SNlM0e[\"']\\s*:\\s*[\"']([^\"']+)[\"']";
19
20/// Response from a chat request.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ChatResponse {
23    /// The main text content of the response.
24    pub content: String,
25    /// Current conversation ID.
26    pub conversation_id: String,
27    /// Current response ID.
28    pub response_id: String,
29    /// Query used for factuality checking.
30    pub factuality_queries: Option<Value>,
31    /// Original text query.
32    pub text_query: String,
33    /// Alternative response choices.
34    pub choices: Vec<Choice>,
35    /// Whether an error occurred.
36    pub error: bool,
37}
38
39/// An alternative response choice.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct Choice {
42    /// Choice identifier.
43    pub id: String,
44    /// Choice content text.
45    pub content: String,
46}
47
48/// Saved conversation data for persistence.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SavedConversation {
51    pub conversation_name: String,
52    #[serde(rename = "_reqid")]
53    pub reqid: u32,
54    pub conversation_id: String,
55    pub response_id: String,
56    pub choice_id: String,
57    #[serde(rename = "SNlM0e")]
58    pub snlm0e: String,
59    pub model_name: String,
60    pub timestamp: String,
61}
62
63/// Async chatbot client for interacting with Google Gemini.
64///
65/// # Example
66/// ```no_run
67/// use gemini_chat_api::{AsyncChatbot, Model};
68///
69/// #[tokio::main]
70/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
71///     let chatbot = AsyncChatbot::new(
72///         "your_psid",
73///         "your_psidts",
74///         Model::default(),
75///         None,
76///         30,
77///     ).await?;
78///
79///     let response = chatbot.ask("Hello!", None).await?;
80///     println!("{}", response.content);
81///     Ok(())
82/// }
83/// ```
84pub struct AsyncChatbot {
85    client: Client,
86    snlm0e: String,
87    conversation_id: String,
88    response_id: String,
89    choice_id: String,
90    reqid: u32,
91    secure_1psidts: String,
92    model: Model,
93    proxy: Option<String>,
94}
95
96impl AsyncChatbot {
97    /// Creates a new AsyncChatbot instance.
98    ///
99    /// # Arguments
100    /// * `secure_1psid` - The __Secure-1PSID cookie value
101    /// * `secure_1psidts` - The __Secure-1PSIDTS cookie value
102    /// * `model` - The Gemini model to use
103    /// * `proxy` - Optional proxy URL
104    /// * `timeout` - Request timeout in seconds
105    ///
106    /// # Returns
107    /// A new initialized AsyncChatbot
108    ///
109    /// # Errors
110    /// Returns an error if authentication fails or network is unavailable.
111    pub async fn new(
112        secure_1psid: &str,
113        secure_1psidts: &str,
114        model: Model,
115        proxy: Option<&str>,
116        timeout: u64,
117    ) -> Result<Self> {
118        if secure_1psid.is_empty() {
119            return Err(Error::Authentication(
120                "__Secure-1PSID cookie is required".to_string(),
121            ));
122        }
123
124        // Build cookie jar with proper Secure cookie attributes
125        let jar = Jar::default();
126        let url: Url = "https://gemini.google.com".parse().unwrap();
127        // Secure cookies need proper attributes in the cookie string
128        jar.add_cookie_str(
129            &format!(
130                "__Secure-1PSID={}; Domain=.google.com; Path=/; Secure; SameSite=None",
131                secure_1psid
132            ),
133            &url,
134        );
135        jar.add_cookie_str(
136            &format!(
137                "__Secure-1PSIDTS={}; Domain=.google.com; Path=/; Secure; SameSite=None",
138                secure_1psidts
139            ),
140            &url,
141        );
142
143        // Build headers
144        let mut headers = gemini_headers();
145        if let Some(model_headers) = model.headers() {
146            headers.extend(model_headers);
147        }
148
149        // Build client
150        let mut builder = Client::builder()
151            .cookie_provider(Arc::new(jar))
152            .default_headers(headers)
153            .timeout(Duration::from_secs(timeout));
154
155        if let Some(proxy_url) = proxy {
156            builder = builder.proxy(reqwest::Proxy::all(proxy_url)?);
157        }
158
159        let client = builder.build()?;
160
161        let mut chatbot = Self {
162            client,
163            snlm0e: String::new(),
164            conversation_id: String::new(),
165            response_id: String::new(),
166            choice_id: String::new(),
167            reqid: rand::thread_rng().gen_range(1000000..9999999),
168            secure_1psidts: secure_1psidts.to_string(),
169            model,
170            proxy: proxy.map(|s| s.to_string()),
171        };
172
173        // Fetch the SNlM0e token
174        chatbot.snlm0e = chatbot.get_snlm0e().await?;
175
176        Ok(chatbot)
177    }
178
179    /// Fetches the SNlM0e value required for API requests.
180    async fn get_snlm0e(&mut self) -> Result<String> {
181        // Proactively try to rotate cookies if PSIDTS is missing
182        if self.secure_1psidts.is_empty() {
183            let _ = self.rotate_cookies().await;
184        }
185
186        let response = self.client.get(Endpoint::Init.url()).send().await?;
187
188        let status = response.status();
189        let text = response.text().await?;
190
191        if !status.is_success() {
192            if status.as_u16() == 401 || status.as_u16() == 403 {
193                return Err(Error::Authentication(format!(
194                    "Authentication failed (status {}). Check cookies.",
195                    status
196                )));
197            }
198            return Err(Error::Parse(format!("HTTP error: {}", status)));
199        }
200
201        // Check for authentication redirect - be precise to avoid false positives
202        // Only trigger if it's an actual login page, not just any page with google.com links
203        if text.contains("\"identifier-shown\"")
204            || text.contains("SignIn?continue")
205            || text.contains("Sign in - Google Accounts")
206        {
207            return Err(Error::Authentication(
208                "Authentication failed. Cookies might be invalid or expired.".to_string(),
209            ));
210        }
211
212        // Extract SNlM0e using regex
213        let re = Regex::new(SNLM0E_PATTERN).unwrap();
214        match re.captures(&text) {
215            Some(caps) => Ok(caps.get(1).unwrap().as_str().to_string()),
216            None => {
217                if text.contains("429") {
218                    Err(Error::Parse(
219                        "SNlM0e not found. Rate limit likely exceeded.".to_string(),
220                    ))
221                } else {
222                    Err(Error::Parse(
223                        "SNlM0e value not found in response. Check cookie validity.".to_string(),
224                    ))
225                }
226            }
227        }
228    }
229
230    /// Rotates the __Secure-1PSIDTS cookie.
231    async fn rotate_cookies(&mut self) -> Result<Option<String>> {
232        let response = self
233            .client
234            .post(Endpoint::RotateCookies.url())
235            .headers(rotate_cookies_headers())
236            .body(r#"[000,"-0000000000000000000"]"#)
237            .send()
238            .await?;
239
240        if !response.status().is_success() {
241            return Ok(None);
242        }
243
244        // Check for new cookie in response
245        // Note: Reqwest's cookie store automatically handles Set-Cookie headers for the client
246        // But we want to update our struct field too
247        for cookie in response.cookies() {
248            if cookie.name() == "__Secure-1PSIDTS" {
249                let new_value = cookie.value().to_string();
250                self.secure_1psidts = new_value.clone();
251                return Ok(Some(new_value));
252            }
253        }
254
255        Ok(None)
256    }
257
258    /// Sends a message to Gemini and returns the response.
259    ///
260    /// # Arguments
261    /// * `message` - The message text to send
262    /// * `image` - Optional image data to include
263    ///
264    /// # Returns
265    /// A ChatResponse containing the Gemini reply and metadata
266    pub async fn ask(&mut self, message: &str, image: Option<&[u8]>) -> Result<ChatResponse> {
267        if self.snlm0e.is_empty() {
268            return Err(Error::NotInitialized(
269                "AsyncChatbot not properly initialized. SNlM0e is missing.".to_string(),
270            ));
271        }
272
273        // Handle image upload if provided
274        let image_upload_id = if let Some(img_data) = image {
275            Some(upload_file(img_data, self.proxy.as_deref()).await?)
276        } else {
277            None
278        };
279
280        // Prepare message structure
281        let message_struct: Value = if let Some(ref upload_id) = image_upload_id {
282            serde_json::json!([
283                [message],
284                [[[upload_id, 1]]],
285                [&self.conversation_id, &self.response_id, &self.choice_id]
286            ])
287        } else {
288            serde_json::json!([
289                [message],
290                null,
291                [&self.conversation_id, &self.response_id, &self.choice_id]
292            ])
293        };
294
295        // Prepare request
296        let freq_value = serde_json::json!([null, serde_json::to_string(&message_struct)?]);
297        let params = [
298            ("bl", "boq_assistant-bard-web-server_20240625.13_p0"),
299            ("_reqid", &self.reqid.to_string()),
300            ("rt", "c"),
301        ];
302
303        let form_data = [
304            ("f.req", serde_json::to_string(&freq_value)?),
305            ("at", self.snlm0e.clone()),
306        ];
307
308        let response = self
309            .client
310            .post(Endpoint::Generate.url())
311            .query(&params)
312            .form(&form_data)
313            .send()
314            .await?;
315
316        if !response.status().is_success() {
317            return Err(Error::Network(response.error_for_status().unwrap_err()));
318        }
319
320        let text = response.text().await?;
321        self.parse_response(&text)
322    }
323
324    /// Parses the Gemini API response text.
325    fn parse_response(&mut self, text: &str) -> Result<ChatResponse> {
326        let lines: Vec<&str> = text.lines().collect();
327        if lines.len() < 3 {
328            return Err(Error::Parse(format!(
329                "Unexpected response format. Content: {}...",
330                &text[..text.len().min(200)]
331            )));
332        }
333
334        // Find the main response body
335        let mut body: Option<Value> = None;
336
337        for line in &lines {
338            // Skip empty lines and security prefix
339            if line.is_empty() || *line == ")]}" {
340                continue;
341            }
342
343            let mut clean_line = *line;
344            if clean_line.starts_with(")]}") {
345                clean_line = clean_line.get(4..).unwrap_or("").trim();
346            }
347
348            if !clean_line.starts_with('[') {
349                continue;
350            }
351
352            if let Ok(response_json) = serde_json::from_str::<Value>(clean_line) {
353                if let Some(arr) = response_json.as_array() {
354                    for part in arr {
355                        if let Some(part_arr) = part.as_array() {
356                            if part_arr.len() > 2
357                                && part_arr.first().and_then(|v| v.as_str()) == Some("wrb.fr")
358                            {
359                                if let Some(inner_str) = part_arr.get(2).and_then(|v| v.as_str()) {
360                                    if let Ok(main_part) = serde_json::from_str::<Value>(inner_str)
361                                    {
362                                        if main_part
363                                            .as_array()
364                                            .map(|a| a.len() > 4 && !a[4].is_null())
365                                            .unwrap_or(false)
366                                        {
367                                            body = Some(main_part);
368                                            break;
369                                        }
370                                    }
371                                }
372                            }
373                        }
374                    }
375                }
376
377                if body.is_some() {
378                    break;
379                }
380            }
381        }
382
383        let body = body.ok_or_else(|| {
384            Error::Parse("Failed to parse response body. No valid data found.".to_string())
385        })?;
386
387        // Extract data
388        let body_arr = body.as_array().unwrap();
389
390        // Extract content
391        // Structure: body[4][0][1][0] -> content
392        let content = body_arr
393            .get(4)
394            .and_then(|v| v.as_array())
395            .and_then(|a| a.first())
396            .and_then(|v| v.as_array())
397            .and_then(|a| a.get(1))
398            .and_then(|v| v.as_array())
399            .and_then(|a| a.first())
400            .and_then(|v| v.as_str())
401            .unwrap_or("")
402            .to_string();
403
404        // Extract conversation metadata
405        let conversation_id = body_arr
406            .get(1)
407            .and_then(|v| v.as_array())
408            .and_then(|a| a.first())
409            .and_then(|v| v.as_str())
410            .unwrap_or(&self.conversation_id)
411            .to_string();
412
413        let response_id = body_arr
414            .get(1)
415            .and_then(|v| v.as_array())
416            .and_then(|a| a.get(1))
417            .and_then(|v| v.as_str())
418            .unwrap_or(&self.response_id)
419            .to_string();
420
421        // Extract other data
422        let factuality_queries = body_arr.get(3).cloned();
423        let text_query = body_arr
424            .get(2)
425            .and_then(|v| v.as_array())
426            .and_then(|a| a.first())
427            .and_then(|v| v.as_str())
428            .unwrap_or("")
429            .to_string();
430
431        // Extract choices
432        let mut choices = Vec::new();
433        if let Some(candidates) = body_arr.get(4).and_then(|v| v.as_array()) {
434            for candidate in candidates {
435                if let Some(cand_arr) = candidate.as_array() {
436                    if cand_arr.len() > 1 {
437                        let id = cand_arr
438                            .first()
439                            .and_then(|v| v.as_str())
440                            .unwrap_or("")
441                            .to_string();
442                        let choice_content = cand_arr
443                            .get(1)
444                            .and_then(|v| v.as_array())
445                            .and_then(|a| a.first())
446                            .and_then(|v| v.as_str())
447                            .unwrap_or("")
448                            .to_string();
449                        choices.push(Choice {
450                            id,
451                            content: choice_content,
452                        });
453                    }
454                }
455            }
456        }
457
458        let choice_id = choices
459            .first()
460            .map(|c| c.id.clone())
461            .unwrap_or_else(|| self.choice_id.clone());
462
463        // Update state
464        self.conversation_id = conversation_id.clone();
465        self.response_id = response_id.clone();
466        self.choice_id = choice_id;
467        self.reqid += rand::thread_rng().gen_range(1000..9000);
468
469        Ok(ChatResponse {
470            content,
471            conversation_id,
472            response_id,
473            factuality_queries,
474            text_query,
475            choices,
476            error: false,
477        })
478    }
479
480    /// Saves the current conversation to a file.
481    pub async fn save_conversation(&self, file_path: &str, conversation_name: &str) -> Result<()> {
482        let mut conversations = self.load_conversations(file_path).await?;
483
484        let conversation_data = SavedConversation {
485            conversation_name: conversation_name.to_string(),
486            reqid: self.reqid,
487            conversation_id: self.conversation_id.clone(),
488            response_id: self.response_id.clone(),
489            choice_id: self.choice_id.clone(),
490            snlm0e: self.snlm0e.clone(),
491            model_name: self.model.name().to_string(),
492            timestamp: chrono_now(),
493        };
494
495        // Update or add conversation
496        let mut found = false;
497        for conv in &mut conversations {
498            if conv.conversation_name == conversation_name {
499                *conv = conversation_data.clone();
500                found = true;
501                break;
502            }
503        }
504        if !found {
505            conversations.push(conversation_data);
506        }
507
508        // Ensure parent directory exists
509        if let Some(parent) = Path::new(file_path).parent() {
510            std::fs::create_dir_all(parent)?;
511        }
512
513        let json = serde_json::to_string_pretty(&conversations)?;
514        std::fs::write(file_path, json)?;
515
516        Ok(())
517    }
518
519    /// Loads all saved conversations from a file.
520    pub async fn load_conversations(&self, file_path: &str) -> Result<Vec<SavedConversation>> {
521        if !Path::new(file_path).exists() {
522            return Ok(Vec::new());
523        }
524
525        let content = std::fs::read_to_string(file_path)?;
526        let conversations: Vec<SavedConversation> = serde_json::from_str(&content)?;
527        Ok(conversations)
528    }
529
530    /// Loads a specific conversation by name.
531    pub async fn load_conversation(
532        &mut self,
533        file_path: &str,
534        conversation_name: &str,
535    ) -> Result<bool> {
536        let conversations = self.load_conversations(file_path).await?;
537
538        for conv in conversations {
539            if conv.conversation_name == conversation_name {
540                self.reqid = conv.reqid;
541                self.conversation_id = conv.conversation_id;
542                self.response_id = conv.response_id;
543                self.choice_id = conv.choice_id;
544                self.snlm0e = conv.snlm0e;
545
546                if let Some(model) = Model::from_name(&conv.model_name) {
547                    self.model = model;
548                }
549
550                return Ok(true);
551            }
552        }
553
554        Ok(false)
555    }
556
557    /// Gets the current conversation ID.
558    pub fn conversation_id(&self) -> &str {
559        &self.conversation_id
560    }
561
562    /// Gets the current model.
563    pub fn model(&self) -> &Model {
564        &self.model
565    }
566
567    /// Resets the conversation state (IDs) to start a fresh conversation session.
568    /// This keeps authentication valid (SNlM0e, cookies) but generates new conversation IDs.
569    pub fn reset(&mut self) {
570        self.conversation_id.clear();
571        self.response_id.clear();
572        self.choice_id.clear();
573        self.reqid = rand::thread_rng().gen_range(1000000..9999999);
574    }
575}
576
/// Returns the current Unix time in whole seconds, as a decimal string.
///
/// Implemented with `std::time` only, to avoid a `chrono` dependency. If the
/// system clock reads earlier than the Unix epoch, the result is `"0"`.
fn chrono_now() -> String {
    let seconds = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);
    seconds.to_string()
}