gemini-client-api 7.4.5

Library for using the Google Gemini API, with automatic context management, schema generation, function calling, and more.
Documentation
use super::request::*;
use super::sessions::Session;
#[cfg(feature = "reqwest")]
use crate::gemini::error::GeminiResponseStreamError;
use bytes::Bytes;
use derive_new::new;
use futures::Stream;
#[cfg(feature = "reqwest")]
use reqwest::Response;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
    pin::Pin,
    task::{Context, Poll},
};

/// Why the model stopped generating tokens for a candidate.
///
/// Values are (de)serialized in `SCREAMING_SNAKE_CASE`, e.g. `MaxTokens` <-> `"MAX_TOKENS"`.
/// Any finish-reason string not listed here deserializes to [`FinishReason::Other`]. This
/// keeps a response that carries a newer reason parseable instead of failing as a whole.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    /// Default value. This value is unused.
    FinishReasonUnspecified,
    /// Natural stop point of the model or provided stop sequence.
    Stop,
    /// The maximum number of tokens as specified in the request was reached.
    MaxTokens,
    /// The response candidate content was flagged for safety reasons.
    Safety,
    /// The response candidate content was flagged for recitation reasons.
    Recitation,
    /// The response candidate content was flagged for using an unsupported language.
    Language,
    /// Token generation stopped because the content contains forbidden terms.
    Blocklist,
    /// Token generation stopped for potentially containing prohibited content.
    ProhibitedContent,
    /// Token generation stopped because the content potentially contains Sensitive Personally Identifiable Information (SPII).
    Spii,
    /// The function call generated by the model is invalid.
    MalformedFunctionCall,
    /// Token generation stopped because generated images contain safety violations.
    ImageSafety,
    /// Unknown reason.
    ///
    /// Also produced for any finish reason not listed above. Serializing this variant always
    /// yields `"OTHER"`, so the original unknown string is not preserved.
    // serde requires the `#[serde(other)]` fallback to be the last variant, hence its position.
    #[serde(other)]
    Other,
}

/// One generated candidate within a [`GeminiResponse`].
///
/// `#[derive(new)]` generates a `new` constructor that takes the fields in declaration order
/// (`content`, then `finish_reason`), so the field order here is part of the public API.
#[derive(Serialize, Deserialize, Clone, Debug, new)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
    /// The content produced for this candidate.
    pub content: Chat,
    /// Why generation stopped, if the API reported it; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,
}

/// A response from the Gemini API.
///
/// This is either a complete HTTP response body or the JSON payload of one streamed event.
/// Field names are mapped to/from camelCase JSON keys (e.g. `usage_metadata` <-> `usageMetadata`).
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GeminiResponse {
    /// Generated candidates. The `get_*` accessors only look at the first one.
    pub candidates: Vec<Candidate>,
    /// Usage metadata kept as untyped JSON (presumably token counts — not interpreted here).
    pub usage_metadata: Value,
    /// Model version string reported by the API.
    pub model_version: String,
    /// Optional prompt feedback kept as untyped JSON; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_feedback: Option<Value>,
    /// Optional response identifier; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_id: Option<String>,
    /// Optional model status kept as untyped JSON; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_status: Option<Value>,
}
impl GeminiResponse {
    #[cfg(feature = "reqwest")]
    pub(crate) async fn new(response: Response) -> Result<GeminiResponse, reqwest::Error> {
        response.json().await
    }
    pub(crate) fn from_str(string: impl AsRef<str>) -> Result<Self, serde_json::Error> {
        serde_json::from_str(string.as_ref())
    }
    pub fn get_chat(&self) -> &Chat {
        &self.candidates[0].content
    }
    pub fn get_finish_reason(&self) -> Option<&FinishReason> {
        self.candidates[0].finish_reason.as_ref()
    }
    pub fn get_json<T>(&self) -> Result<T, serde_json::Error>
    where
        T: serde::de::DeserializeOwned,
    {
        Self::parse_json(self.get_chat().parts())
    }
    pub fn parse_json<T>(parts: &[Part]) -> Result<T, serde_json::Error>
    where
        T: serde::de::DeserializeOwned,
    {
        let unescaped_str = Chat::extract_text_all(parts, "");
        serde_json::from_str::<T>(&unescaped_str)
    }
}

#[cfg(feature = "reqwest")]
pin_project_lite::pin_project! {
    /// A `Stream` adapter over the server-sent-events byte stream of a streaming Gemini request.
    ///
    /// Incoming byte chunks are accumulated in an internal buffer until a complete event is
    /// available. Each `data:` event is then parsed into a `GeminiResponse` and applied to the
    /// owned `Session`. Finally, the `data_extractor` closure turns it into the yielded item `T`.
    pub struct ResponseStream<F,T>
        where F:FnMut(&Session, GeminiResponse) -> T{
        // Raw HTTP body chunks (or transport errors) from the underlying response.
        #[pin]
        response_stream:Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Unpin + Send + 'static>,
        // Conversation state; updated with every parsed response before extraction.
        session: Session,
        // Maps (updated session, parsed response) to the item yielded by the stream.
        data_extractor: F,
        // Bytes received but not yet consumed as a complete event.
        buffer: Vec<u8>,
    }
}
#[cfg(feature = "reqwest")]
impl<F, T> Stream for ResponseStream<F, T>
where
    F: FnMut(&Session, GeminiResponse) -> T,
{
    type Item = Result<T, GeminiResponseStreamError>;

    /// Yields one item per server-sent `data:` event received on the underlying byte stream.
    ///
    /// Events are separated by a blank line, either `\r\n\r\n` or `\n\n`. Each event's JSON
    /// payload is parsed into a [`GeminiResponse`] and applied to the session via
    /// `Session::update`. It is then handed to the data extractor, whose result is yielded.
    /// Events without a non-empty `data:` payload are skipped.
    ///
    /// When the byte stream ends, leftover whitespace is ignored. A final `data:` event that
    /// lacks its terminating blank line is still processed.
    ///
    /// # Errors
    /// Yields `GeminiResponseStreamError::InvalidResposeFormat` in two cases: an event payload
    /// is not a valid `GeminiResponse`, or the stream ends with leftover bytes that are neither
    /// whitespace nor a `data:` event. Yields `GeminiResponseStreamError::ReqwestError` when the
    /// underlying byte stream fails.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            // Obtain the next complete event, reading more bytes from the network if needed.
            let event: Vec<u8> =
                if let Some((event_len, delimiter_len)) = find_event_end(this.buffer.as_slice()) {
                    let mut event: Vec<u8> =
                        this.buffer.drain(..event_len + delimiter_len).collect();
                    event.truncate(event_len);
                    event
                } else {
                    match this.response_stream.as_mut().poll_next(cx) {
                        Poll::Ready(Some(Ok(bytes))) => {
                            this.buffer.extend_from_slice(&bytes);
                            continue;
                        }
                        Poll::Ready(Some(Err(e))) => {
                            return Poll::Ready(Some(Err(
                                GeminiResponseStreamError::ReqwestError(e),
                            )));
                        }
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(None) => {
                            // Take the leftover so the buffer is empty on any later poll and
                            // the same data cannot be reported twice.
                            let leftover = std::mem::take(&mut *this.buffer);
                            if leftover.iter().all(u8::is_ascii_whitespace) {
                                return Poll::Ready(None);
                            }
                            if data_payload(&leftover).is_none() {
                                let err_text = String::from_utf8_lossy(&leftover).into_owned();
                                let err = GeminiResponseStreamError::InvalidResposeFormat(format!(
                                    "Stream ended with incomplete data in the buffer: {}",
                                    err_text
                                ));
                                return Poll::Ready(Some(Err(err)));
                            }
                            // The last event may legitimately lack its terminating blank line.
                            leftover
                        }
                    }
                };

            // Events with no `data:` payload (comments, other fields, empty data) are skipped.
            let json_str = match data_payload(&event) {
                Some(json) => json,
                None => continue,
            };

            let response = match GeminiResponse::from_str(&json_str) {
                Ok(resp) => resp,
                Err(e) => {
                    let err = GeminiResponseStreamError::InvalidResposeFormat(format!(
                        "JSON parsing error [{}]: {}",
                        e, json_str
                    ));
                    return Poll::Ready(Some(Err(err)));
                }
            };

            // Fold the chunk into the session first so the extractor sees the updated state.
            this.session.update(&response);
            let data = (this.data_extractor)(this.session, response);
            return Poll::Ready(Some(Ok(data)));
        }
    }
}

/// Locates the first server-sent-event terminator (a blank line) in `buffer`.
///
/// Returns `(event_len, delimiter_len)`: the event occupies `buffer[..event_len]` and is
/// followed by a `delimiter_len`-byte terminator. Both `\r\n\r\n` and `\n\n` are recognized,
/// and the earliest occurrence wins.
#[cfg(feature = "reqwest")]
fn find_event_end(buffer: &[u8]) -> Option<(usize, usize)> {
    let crlf = buffer
        .windows(4)
        .position(|window| window == b"\r\n\r\n")
        .map(|pos| (pos, 4));
    let lf = buffer
        .windows(2)
        .position(|window| window == b"\n\n")
        .map(|pos| (pos, 2));
    match (crlf, lf) {
        (Some(a), Some(b)) => Some(if a.0 <= b.0 { a } else { b }),
        (found, None) | (None, found) => found,
    }
}

/// Returns the trimmed payload of an event that starts with `data:`.
///
/// Returns `None` when the event does not start with `data:` or when its payload is empty after
/// trimming. Invalid UTF-8 is replaced lossily.
#[cfg(feature = "reqwest")]
fn data_payload(event: &[u8]) -> Option<String> {
    let text = String::from_utf8_lossy(event);
    let payload = text.strip_prefix("data:")?.trim();
    if payload.is_empty() {
        None
    } else {
        Some(payload.to_owned())
    }
}
#[cfg(feature = "reqwest")]
impl<F, T> ResponseStream<F, T>
where
    F: FnMut(&Session, GeminiResponse) -> T,
{
    /// Wraps a raw byte stream together with the session it updates and the closure that
    /// turns each parsed response into a stream item. The internal buffer starts out empty.
    pub(crate) fn new(
        response_stream: Box<
            dyn Stream<Item = Result<Bytes, reqwest::Error>> + Unpin + Send + 'static,
        >,
        session: Session,
        data_extractor: F,
    ) -> Self {
        let buffer: Vec<u8> = Vec::new();
        ResponseStream {
            response_stream,
            session,
            data_extractor,
            buffer,
        }
    }

    /// Borrows the session in its current state, i.e. after the events consumed so far.
    pub fn get_session(&self) -> &Session {
        let Self { session, .. } = self;
        session
    }

    /// Consumes the stream and hands back ownership of its session.
    pub fn get_session_owned(self) -> Session {
        let Self { session, .. } = self;
        session
    }
}
/// A [`ResponseStream`] whose extractor is a plain function pointer that produces a
/// [`GeminiResponse`] for every event.
#[cfg(feature = "reqwest")]
pub type GeminiResponseStream = ResponseStream<
    fn(&Session, GeminiResponse) -> GeminiResponse,
    GeminiResponse,
>;