aichat 0.30.0

All-in-one LLM CLI Tool
use super::vertexai::*;
use super::*;

use anyhow::{Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};

/// Default base URL used by the request builders in this file when the
/// client's `get_api_base()` lookup fails.
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";

/// Deserializable configuration for a Gemini client entry.
///
/// Every field except `models` is optional; `models` falls back to an empty
/// list when absent because of `#[serde(default)]`.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct GeminiConfig {
    /// Optional name for this client entry.
    pub name: Option<String>,
    /// Optional API key. NOTE(review): presumably what `get_api_key` reads —
    /// confirm against the `config_get_fn!` macro definition.
    pub api_key: Option<String>,
    /// Optional base URL override. NOTE(review): presumably what
    /// `get_api_base` reads; when that lookup fails the request builders in
    /// this file use `API_BASE` instead.
    pub api_base: Option<String>,
    /// Model entries for this client; empty when not provided.
    #[serde(default)]
    pub models: Vec<ModelData>,
    /// Optional request patch (semantics defined by `RequestPatch`, not shown here).
    pub patch: Option<RequestPatch>,
    /// Optional extra settings (semantics defined by `ExtraConfig`, not shown here).
    pub extra: Option<ExtraConfig>,
}

impl GeminiClient {
    // Generates the `get_api_key` / `get_api_base` accessors used by the
    // request builders below. NOTE(review): the lookup rules live in the
    // `config_get_fn!` macro, which is defined outside this file.
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);

    /// Single prompt entry for the `api_key` setting, labelled "API Key".
    /// NOTE(review): the meaning of the trailing `None` is defined by
    /// `PromptAction`, which is not visible here.
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}

// Wires `GeminiClient` into the client trait via the project macro:
//   - chat completions: `prepare_chat_completions` builds the request, and the
//     `gemini_chat_completions*` functions are supplied for the non-streaming
//     and streaming cases;
//   - embeddings: `prepare_embeddings` builds the request, `embeddings` sends
//     it and decodes the response;
//   - rerank: only the `noop_*` placeholders are supplied.
// NOTE(review): exact dispatch is defined by `impl_client_trait!`, which is
// not visible in this file.
impl_client_trait!(
    GeminiClient,
    (
        prepare_chat_completions,
        gemini_chat_completions,
        gemini_chat_completions_streaming
    ),
    (prepare_embeddings, embeddings),
    (noop_prepare_rerank, noop_rerank),
);

/// Builds the request for a Gemini chat completion.
///
/// The URL has the form `{api_base}/models/{model}:{action}?key={api_key}`,
/// where `action` is `streamGenerateContent` when `data.stream` is set and
/// `generateContent` otherwise. The base URL is the client's configured one,
/// or `API_BASE` when that lookup fails; trailing slashes are trimmed.
///
/// # Errors
///
/// Returns an error when the API key lookup fails or when
/// `gemini_build_chat_completions_body` cannot build the body.
fn prepare_chat_completions(
    self_: &GeminiClient,
    data: ChatCompletionsData,
) -> Result<RequestData> {
    let api_key = self_.get_api_key()?;
    let configured_base = self_
        .get_api_base()
        .unwrap_or_else(|_| API_BASE.to_string());

    // Must be read before `data` is moved into the body builder below.
    let action = if data.stream {
        "streamGenerateContent"
    } else {
        "generateContent"
    };

    let base = configured_base.trim_end_matches('/');
    let model_name = self_.model.real_name();
    // The API key is passed as the `key` query parameter.
    let url = format!("{base}/models/{model_name}:{action}?key={api_key}");

    let body = gemini_build_chat_completions_body(data, &self_.model)?;

    Ok(RequestData::new(url, body))
}

fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<RequestData> {
    let api_key = self_.get_api_key()?;
    let api_base = self_
        .get_api_base()
        .unwrap_or_else(|_| API_BASE.to_string());

    let url = format!(
        "{}/models/{}:batchEmbedContents?key={}",
        api_base.trim_end_matches('/'),
        self_.model.real_name(),
        api_key
    );

    let model_id = format!("models/{}", self_.model.real_name());

    let requests: Vec<_> = data
        .texts
        .iter()
        .map(|text| {
            json!({
                "model": model_id,
                "content": {
                    "parts": [
                        {
                            "text": text
                        }
                    ]
                },
            })
        })
        .collect();

    let body = json!({
        "requests": requests,
    });

    let request_data = RequestData::new(url, body);

    Ok(request_data)
}

/// Sends the prepared embeddings request and extracts the vectors.
///
/// The response body is always parsed as JSON first. On a non-success HTTP
/// status the parsed body and the status code are handed to `catch_error`;
/// otherwise the body is decoded as `EmbeddingsResBody` and each entry's
/// `values` is returned in response order.
///
/// # Errors
///
/// Fails when sending the request fails, when the body is not valid JSON,
/// when `catch_error` reports an error, or when the JSON does not match the
/// expected shape ("Invalid embeddings data").
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
    let response = builder.send().await?;
    // Capture the status before `json()` consumes the response.
    let status = response.status();
    let payload: Value = response.json().await?;

    if !status.is_success() {
        catch_error(&payload, status.as_u16())?;
    }

    let parsed = serde_json::from_value::<EmbeddingsResBody>(payload)
        .context("Invalid embeddings data")?;

    Ok(parsed
        .embeddings
        .into_iter()
        .map(|item| item.values)
        .collect())
}

/// Shape of the embeddings response body that `embeddings` decodes:
/// a top-level `embeddings` array.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    /// One entry per returned embedding, consumed in array order.
    embeddings: Vec<EmbeddingsResBodyEmbedding>,
}

/// A single element of the response's `embeddings` array.
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbedding {
    /// The embedding vector components.
    values: Vec<f32>,
}