use super::request::*;
use super::sessions::Session;
#[cfg(feature = "reqwest")]
use crate::gemini::error::GeminiResponseStreamError;
use bytes::Bytes;
use derive_new::new;
use futures::Stream;
#[cfg(feature = "reqwest")]
use reqwest::Response;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
pin::Pin,
task::{Context, Poll},
};
/// Finish reason attached to a [`Candidate`] (its `finish_reason` field).
///
/// Variants are (de)serialized in `SCREAMING_SNAKE_CASE`, e.g. `MaxTokens` <-> `"MAX_TOKENS"`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    /// Wire value `"FINISH_REASON_UNSPECIFIED"`.
    FinishReasonUnspecified,
    /// Wire value `"STOP"`.
    Stop,
    /// Wire value `"MAX_TOKENS"`.
    MaxTokens,
    /// Wire value `"SAFETY"`.
    Safety,
    /// Wire value `"RECITATION"`.
    Recitation,
    /// Wire value `"LANGUAGE"`.
    Language,
    /// Wire value `"OTHER"`.
    Other,
    /// Wire value `"BLOCKLIST"`.
    Blocklist,
    /// Wire value `"PROHIBITED_CONTENT"`.
    ProhibitedContent,
    /// Wire value `"SPII"`.
    Spii,
    /// Wire value `"MALFORMED_FUNCTION_CALL"`.
    MalformedFunctionCall,
    /// Wire value `"IMAGE_SAFETY"`.
    ImageSafety,
}
/// One generated candidate inside a [`GeminiResponse`].
///
/// Field names are (de)serialized in camelCase (`content`, `finishReason`).
#[derive(Clone, Debug, Serialize, Deserialize, new)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
    /// Content of this candidate (JSON field `content`).
    pub content: Chat,
    /// Why this candidate finished (JSON field `finishReason`); left out of the serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,
}
/// A deserialized Gemini API response; when streaming, each `data:` event becomes one value.
///
/// Field names are (de)serialized in camelCase. Optional fields are left out of the
/// serialized output when they are `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiResponse {
    /// Generated candidates (JSON field `candidates`).
    pub candidates: Vec<Candidate>,
    /// Usage metadata (JSON field `usageMetadata`), kept as untyped JSON.
    pub usage_metadata: Value,
    /// Model version string (JSON field `modelVersion`).
    pub model_version: String,
    /// Prompt feedback (JSON field `promptFeedback`), kept as untyped JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<Value>,
    /// Response identifier (JSON field `responseId`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_id: Option<String>,
    /// Model status (JSON field `modelStatus`), kept as untyped JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_status: Option<Value>,
}
impl GeminiResponse {
    /// Reads the body of `response` and deserializes it as JSON into a `GeminiResponse`.
    ///
    /// # Errors
    ///
    /// Returns the `reqwest::Error` produced when the body cannot be read or decoded.
    #[cfg(feature = "reqwest")]
    pub(crate) async fn new(response: Response) -> Result<GeminiResponse, reqwest::Error> {
        response.json().await
    }

    /// Parses a `GeminiResponse` from JSON text.
    ///
    /// # Errors
    ///
    /// Returns a `serde_json::Error` if `string` is not valid JSON for this type.
    pub(crate) fn from_str(string: impl AsRef<str>) -> Result<Self, serde_json::Error> {
        serde_json::from_str(string.as_ref())
    }

    /// Returns the content of the first candidate.
    ///
    /// # Panics
    ///
    /// Panics if the response contains no candidates.
    pub fn get_chat(&self) -> &Chat {
        &self
            .candidates
            .first()
            .expect("GeminiResponse contains no candidates")
            .content
    }

    /// Returns the finish reason of the first candidate.
    ///
    /// Returns `None` when the first candidate has no finish reason, and also when the
    /// response has no candidates at all (rather than panicking).
    pub fn get_finish_reason(&self) -> Option<&FinishReason> {
        self.candidates
            .first()
            .and_then(|candidate| candidate.finish_reason.as_ref())
    }

    /// Deserializes the text of the first candidate's parts as JSON into `T`.
    ///
    /// # Errors
    ///
    /// Returns a `serde_json::Error` if the combined text is not valid JSON for `T`.
    ///
    /// # Panics
    ///
    /// Panics if the response contains no candidates (see [`GeminiResponse::get_chat`]).
    pub fn get_json<T>(&self) -> Result<T, serde_json::Error>
    where
        T: serde::de::DeserializeOwned,
    {
        Self::parse_json(self.get_chat().parts())
    }

    /// Extracts the text of `parts` via `Chat::extract_text_all` and deserializes it as JSON.
    ///
    /// NOTE(review): the `""` argument looks like a join separator for the parts' text —
    /// confirm against `Chat::extract_text_all`.
    ///
    /// # Errors
    ///
    /// Returns a `serde_json::Error` if the extracted text is not valid JSON for `T`.
    pub fn parse_json<T>(parts: &[Part]) -> Result<T, serde_json::Error>
    where
        T: serde::de::DeserializeOwned,
    {
        let unescaped_str = Chat::extract_text_all(parts, "");
        serde_json::from_str::<T>(&unescaped_str)
    }
}
#[cfg(feature = "reqwest")]
pin_project_lite::pin_project! {
/// Stream adapter over a raw HTTP byte stream made of `data:`-prefixed, `\r\n\r\n`-delimited
/// messages (server-sent-events framing).
///
/// Each complete message payload is parsed as a `GeminiResponse`, recorded in `session`, and
/// mapped to an item of type `T` by `data_extractor`.
pub struct ResponseStream<F,T>
where F:FnMut(&Session, GeminiResponse) -> T{
// Underlying HTTP body chunks; pin-projected so `poll_next` can poll it.
#[pin]
response_stream:Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Unpin + Send + 'static>,
// Updated with every successfully parsed response before the extractor runs.
session: Session,
// Maps (session, parsed response) to the item yielded by the stream.
data_extractor: F,
// Bytes received from `response_stream` that have not yet formed a complete message.
buffer: Vec<u8>,
}
}
#[cfg(feature = "reqwest")]
impl<F, T> Stream for ResponseStream<F, T>
where
    F: FnMut(&Session, GeminiResponse) -> T,
{
    type Item = Result<T, GeminiResponseStreamError>;

    /// Yields one item per complete `data:` message read from the underlying byte stream.
    ///
    /// How messages are handled:
    ///
    /// - Bytes are buffered and split on `\r\n\r\n`.
    /// - A message is used when it starts with `data:` and its trimmed payload is non-empty.
    ///   The payload is parsed as a [`GeminiResponse`], passed to `Session::update`, and
    ///   mapped through `data_extractor`.
    /// - Messages without a `data:` prefix, or with an empty payload, are skipped.
    ///
    /// Once the inner stream is exhausted it is replaced with an empty stream. Polling this
    /// stream again after it finishes therefore keeps returning `None` and never re-polls a
    /// terminated inner stream. Whitespace left in the buffer at that point is ignored.
    ///
    /// # Errors
    ///
    /// Yields `GeminiResponseStreamError::InvalidResposeFormat` when:
    ///
    /// - a payload is not valid `GeminiResponse` JSON, or
    /// - the inner stream ends while non-whitespace bytes remain that never formed a
    ///   complete message.
    ///
    /// Yields `GeminiResponseStreamError::ReqwestError` when the inner stream reports a
    /// transport error.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        const EVENT_DELIMITER: &[u8] = b"\r\n\r\n";

        let mut this = self.project();
        // Invariant: no delimiter starts before this offset of `buffer`. After a chunk is
        // appended, only the tail is searched again. This avoids quadratic rescans when one
        // large message arrives in many small chunks during a single poll.
        let mut scan_from = 0;
        loop {
            if let Some(relative_pos) = this.buffer[scan_from..]
                .windows(EVENT_DELIMITER.len())
                .position(|window| window == EVENT_DELIMITER)
            {
                let message_end = scan_from + relative_pos + EVENT_DELIMITER.len();
                let message_bytes = this.buffer.drain(..message_end).collect::<Vec<u8>>();
                // Whatever remains in the buffer has not been searched yet.
                scan_from = 0;
                if let Some(json_str) = String::from_utf8_lossy(&message_bytes)
                    .strip_prefix("data:")
                    .map(|s| s.trim())
                {
                    if !json_str.is_empty() {
                        let response = match GeminiResponse::from_str(json_str) {
                            Ok(resp) => resp,
                            Err(e) => {
                                let err = GeminiResponseStreamError::InvalidResposeFormat(format!(
                                    "JSON parsing error [{}]: {}",
                                    e, json_str
                                ));
                                return Poll::Ready(Some(Err(err)));
                            }
                        };
                        this.session.update(&response);
                        let data = (this.data_extractor)(this.session, response);
                        return Poll::Ready(Some(Ok(data)));
                    }
                }
                continue;
            }
            match this.response_stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(bytes))) => {
                    // The old buffer was searched without a match, except possibly its last
                    // `len - 1` bytes, where a delimiter could straddle the old tail and the
                    // new chunk.
                    scan_from = this.buffer.len().saturating_sub(EVENT_DELIMITER.len() - 1);
                    this.buffer.extend_from_slice(&bytes);
                }
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => {
                    // The `Stream` contract does not allow polling a finished stream again.
                    // Drop the exhausted stream and swap in an always-empty one so later
                    // polls are safe.
                    this.response_stream.set(Box::new(
                        futures::stream::empty::<Result<Bytes, reqwest::Error>>(),
                    ));
                    // Trailing line breaks or other whitespace are not an incomplete message.
                    if this.buffer.iter().all(|byte| byte.is_ascii_whitespace()) {
                        this.buffer.clear();
                        return Poll::Ready(None);
                    }
                    let err_text = String::from_utf8_lossy(this.buffer).into_owned();
                    let err = GeminiResponseStreamError::InvalidResposeFormat(format!(
                        "Stream ended with incomplete data in the buffer: {}",
                        err_text
                    ));
                    this.buffer.clear();
                    return Poll::Ready(Some(Err(err)));
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Some(Err(GeminiResponseStreamError::ReqwestError(e))));
                }
            }
        }
    }
}
#[cfg(feature = "reqwest")]
impl<F, T> ResponseStream<F, T>
where
    F: FnMut(&Session, GeminiResponse) -> T,
{
    /// Wraps `response_stream` so that every parsed response updates `session` and is mapped
    /// through `data_extractor`. The internal byte buffer starts out empty.
    pub(crate) fn new(
        response_stream: Box<
            dyn Stream<Item = Result<Bytes, reqwest::Error>> + Unpin + Send + 'static,
        >,
        session: Session,
        data_extractor: F,
    ) -> Self {
        let buffer = Vec::default();
        Self {
            buffer,
            data_extractor,
            session,
            response_stream,
        }
    }

    /// Borrows the session, which reflects every response yielded so far.
    pub fn get_session(&self) -> &Session {
        let Self { session, .. } = self;
        session
    }

    /// Consumes the stream and returns its session; any unread data is dropped.
    pub fn get_session_owned(self) -> Session {
        let Self { session, .. } = self;
        session
    }
}
/// A [`ResponseStream`] whose extractor is a plain function pointer producing a
/// [`GeminiResponse`], so each stream item is a `GeminiResponse` chunk.
///
/// The explicit `for<'a>` spells out the same higher-ranked fn-pointer type that the elided
/// form `fn(&Session, GeminiResponse) -> GeminiResponse` denotes.
#[cfg(feature = "reqwest")]
pub type GeminiResponseStream =
    ResponseStream<for<'a> fn(&'a Session, GeminiResponse) -> GeminiResponse, GeminiResponse>;