use std::sync::mpsc::Sender;
use futures::StreamExt;
use reqwest::Client;
use serde::Serialize;
use tokio_util::sync::CancellationToken;
use super::AiError;
use super::sse::{GeminiEventParser, SseParser};
use crate::ai::ai_state::AiResponse;
/// Base endpoint for Gemini model operations; the model name and RPC method
/// are appended to it when a request URL is built.
const GEMINI_API_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";

/// Streaming client for the Gemini `streamGenerateContent` endpoint.
///
/// `Debug` is implemented by hand (instead of derived) so that the API key is
/// never written into `{:?}` output such as logs or panic messages.
#[derive(Clone)]
pub struct AsyncGeminiClient {
    /// HTTP client used to send requests.
    client: Client,
    /// Secret API key used to authenticate requests.
    api_key: String,
    /// Model name inserted into the request URL path.
    model: String,
}

impl std::fmt::Debug for AsyncGeminiClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncGeminiClient")
            .field("client", &self.client)
            // Never expose the secret; only signal that one is configured.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .finish()
    }
}
impl AsyncGeminiClient {
pub fn new(api_key: String, model: String) -> Self {
Self {
client: Client::new(),
api_key,
model,
}
}
#[cfg(test)]
pub fn api_key(&self) -> &str {
&self.api_key
}
#[cfg(test)]
pub fn model(&self) -> &str {
&self.model
}
fn build_request_body(&self, prompt: &str) -> Result<String, AiError> {
#[derive(Serialize)]
struct Part {
text: String,
}
#[derive(Serialize)]
struct Content {
role: String,
parts: Vec<Part>,
}
#[derive(Serialize)]
struct RequestBody {
contents: Vec<Content>,
}
let body = RequestBody {
contents: vec![Content {
role: "user".to_string(),
parts: vec![Part {
text: prompt.to_string(),
}],
}],
};
serde_json::to_string(&body).map_err(|e| AiError::Parse {
provider: "Gemini".to_string(),
message: format!("Failed to serialize request body: {}", e),
})
}
fn build_url(&self) -> String {
format!(
"{}/{}:streamGenerateContent?alt=sse&key={}",
GEMINI_API_URL, self.model, self.api_key
)
}
pub async fn stream_with_cancel(
&self,
prompt: &str,
request_id: u64,
cancel_token: CancellationToken,
response_tx: Sender<AiResponse>,
) -> Result<(), AiError> {
if cancel_token.is_cancelled() {
return Err(AiError::Cancelled);
}
let body = self.build_request_body(prompt)?;
let url = self.build_url();
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.body(body)
.send()
.await
.map_err(|e| AiError::Network {
provider: "Gemini".to_string(),
message: e.to_string(),
})?;
if !response.status().is_success() {
let code = response.status().as_u16();
let message = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(AiError::Api {
provider: "Gemini".to_string(),
code,
message,
});
}
let mut stream = response.bytes_stream();
let mut sse_parser = SseParser::new(GeminiEventParser);
loop {
tokio::select! {
biased;
_ = cancel_token.cancelled() => {
return Err(AiError::Cancelled);
}
chunk = stream.next() => {
match chunk {
Some(Ok(bytes)) => {
for text in sse_parser.parse_chunk(&bytes) {
if response_tx
.send(AiResponse::Chunk {
text,
request_id,
})
.is_err()
{
return Ok(());
}
}
}
Some(Err(e)) => {
return Err(AiError::Network {
provider: "Gemini".to_string(),
message: e.to_string(),
});
}
None => {
break;
}
}
}
}
}
Ok(())
}
}
// Unit tests for this module live in the file named by `#[path]` below and are
// compiled only for test builds.
#[cfg(test)]
#[path = "async_gemini_tests.rs"]
mod async_gemini_tests;