//! Provider-agnostic LLM plumbing: request/response types, the [`BaseLlm`]
//! trait, and [`TokenProvider`] implementations for authenticating API calls.

pub mod gemini;
pub mod registry;

pub use gemini::{GeminiLlm, GeminiLlmParams};
pub use registry::LlmRegistry;

use async_trait::async_trait;
use rs_genai::prelude::{Content, Part, Tool};
use serde::{Deserialize, Serialize};

/// Supplies a bearer token for authenticating outbound LLM API calls.
pub trait TokenProvider: Send + Sync {
    fn token(&self) -> String;
}

/// Reads the token from the `GOOGLE_ACCESS_TOKEN` environment variable.
/// Falls back to an empty string when the variable is unset.
pub struct EnvTokenProvider;

impl TokenProvider for EnvTokenProvider {
    fn token(&self) -> String {
        std::env::var("GOOGLE_ACCESS_TOKEN").unwrap_or_default()
    }
}

/// Fetches a token by shelling out to `gcloud auth print-access-token`,
/// caching the result for `ttl` to avoid spawning a subprocess per request.
pub struct GcloudTokenProvider {
    cache: parking_lot::Mutex<(String, std::time::Instant)>,
    ttl: std::time::Duration,
}

impl GcloudTokenProvider {
    pub fn new(ttl: std::time::Duration) -> Self {
        Self {
            // An empty token marks the cache as cold; the timestamp is
            // ignored until the first successful fetch.
            cache: parking_lot::Mutex::new((String::new(), std::time::Instant::now())),
            ttl,
        }
    }
}

impl TokenProvider for GcloudTokenProvider {
    fn token(&self) -> String {
        let mut guard = self.cache.lock();
        let (ref mut cached_token, ref mut fetched_at) = *guard;

        // Serve from the cache while the token is present and still fresh.
        if !cached_token.is_empty() && fetched_at.elapsed() < self.ttl {
            return cached_token.clone();
        }

        // Note: this blocks on a subprocess while holding the lock, so
        // concurrent callers wait for one fetch instead of spawning duplicates.
        match std::process::Command::new("gcloud")
            .args(["auth", "print-access-token"])
            .output()
        {
            Ok(output) if output.status.success() => {
                let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
                *cached_token = token.clone();
                *fetched_at = std::time::Instant::now();
                token
            }
            // Fall back to the environment without touching the cache, so a
            // transient gcloud failure is retried on the next call.
            _ => std::env::var("GOOGLE_ACCESS_TOKEN").unwrap_or_default(),
        }
    }
}

/// A provider-agnostic generation request. Optional fields are skipped
/// during serialization when unset, keeping the wire format compact.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmRequest {
    pub contents: Vec<Content>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tools: Vec<Tool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mime_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_json_schema: Option<serde_json::Value>,
}

impl LlmRequest {
    /// Builds a single-turn request from one user text message.
    pub fn from_text(text: impl Into<String>) -> Self {
        Self {
            contents: vec![Content {
                role: Some(rs_genai::prelude::Role::User),
                parts: vec![Part::Text { text: text.into() }],
            }],
            ..Default::default()
        }
    }

    /// Builds a request from a pre-assembled conversation history.
    pub fn from_contents(contents: Vec<Content>) -> Self {
        Self {
            contents,
            ..Default::default()
        }
    }
}

/// A normalized model response: the generated content plus optional
/// finish reason and token accounting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    pub content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<TokenUsage>,
}

impl LlmResponse {
    /// Concatenates all text parts of the response into one string.
    pub fn text(&self) -> String {
        self.content
            .parts
            .iter()
            .filter_map(|p| match p {
                Part::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .collect()
    }

    /// Returns every function call the model requested, in order.
    pub fn function_calls(&self) -> Vec<&rs_genai::prelude::FunctionCall> {
        self.content
            .parts
            .iter()
            .filter_map(|p| match p {
                Part::FunctionCall { function_call } => Some(function_call),
                _ => None,
            })
            .collect()
    }
}

/// Token counts reported by the provider for one request/response pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Errors surfaced by [`BaseLlm`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
    #[error("LLM request failed: {0}")]
    RequestFailed(String),
    #[error("Model not available: {0}")]
    ModelNotAvailable(String),
    #[error("Rate limited")]
    RateLimited,
    #[error("Content filtered")]
    ContentFiltered,
    #[error("{0}")]
    Other(String),
}

/// The common interface every model backend implements. Object-safe, so
/// backends can be stored as `Box<dyn BaseLlm>` (see the registry).
#[async_trait]
pub trait BaseLlm: Send + Sync {
    fn model_id(&self) -> &str;

    async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError>;

    /// Optional pre-flight hook; the default is a no-op.
    async fn warm_up(&self) -> Result<(), LlmError> {
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
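
    // A minimal sketch of a custom provider: any `Send + Sync` type that can
    // return a token string implements `TokenProvider`. `StaticToken` is
    // illustrative only, not part of the public API.
    #[test]
    fn custom_token_provider() {
        struct StaticToken(&'static str);
        impl TokenProvider for StaticToken {
            fn token(&self) -> String {
                self.0.to_string()
            }
        }
        let provider: Box<dyn TokenProvider> = Box::new(StaticToken("abc"));
        assert_eq!(provider.token(), "abc");
    }
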
    #[test]
    fn llm_request_from_text() {
        let req = LlmRequest::from_text("Hello!");
        assert_eq!(req.contents.len(), 1);
        assert!(req.system_instruction.is_none());
        assert!(req.tools.is_empty());
    }
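
    // Sketch of the wire-format behavior: the `skip_serializing_if` attributes
    // drop unset optional fields, so a bare text request serializes without
    // `tools`, `temperature`, or the other knobs.
    #[test]
    fn llm_request_serializes_compactly() {
        let req = LlmRequest::from_text("Hi");
        let json = serde_json::to_value(&req).expect("request should serialize");
        assert!(json.get("tools").is_none());
        assert!(json.get("temperature").is_none());
        assert!(json.get("system_instruction").is_none());
    }
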
    #[test]
    fn llm_request_from_contents() {
        let contents = vec![Content {
            role: Some(rs_genai::prelude::Role::User),
            parts: vec![Part::Text {
                text: "Hello".into(),
            }],
        }];
        let req = LlmRequest::from_contents(contents);
        assert_eq!(req.contents.len(), 1);
    }

    #[test]
    fn llm_response_text() {
        let resp = LlmResponse {
            content: Content {
                role: Some(rs_genai::prelude::Role::Model),
                parts: vec![
                    Part::Text {
                        text: "Hello ".into(),
                    },
                    Part::Text {
                        text: "world!".into(),
                    },
                ],
            },
            finish_reason: Some("STOP".into()),
            usage: None,
        };
        assert_eq!(resp.text(), "Hello world!");
    }

    #[test]
    fn llm_response_function_calls() {
        let resp = LlmResponse {
            content: Content {
                role: Some(rs_genai::prelude::Role::Model),
                parts: vec![Part::FunctionCall {
                    function_call: rs_genai::prelude::FunctionCall {
                        name: "get_weather".into(),
                        args: serde_json::json!({"city": "London"}),
                        id: None,
                    },
                }],
            },
            finish_reason: None,
            usage: None,
        };
        let calls = resp.function_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
    }
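
    // Sketch: `thiserror` derives `Display` from the `#[error(...)]`
    // attributes, so variants format as asserted below.
    #[test]
    fn llm_error_display() {
        let err = LlmError::RequestFailed("timeout".into());
        assert_eq!(err.to_string(), "LLM request failed: timeout");
        assert_eq!(LlmError::RateLimited.to_string(), "Rate limited");
    }
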
    // Compile-time assertion that `BaseLlm` stays object-safe, so backends
    // can be stored as `dyn BaseLlm` trait objects.
    #[test]
    fn base_llm_is_object_safe() {
        fn _assert(_: &dyn BaseLlm) {}
    }
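
    // A minimal sketch of a `BaseLlm` backend for tests: it echoes the first
    // request content back as the model turn. `EchoLlm` is illustrative and
    // not part of the crate's API; the async method is exercised only at the
    // type level here to avoid pulling in a runtime.
    struct EchoLlm;

    #[async_trait]
    impl BaseLlm for EchoLlm {
        fn model_id(&self) -> &str {
            "echo"
        }

        async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
            let content = request
                .contents
                .into_iter()
                .next()
                .ok_or_else(|| LlmError::RequestFailed("empty request".into()))?;
            Ok(LlmResponse {
                content,
                finish_reason: Some("STOP".into()),
                usage: None,
            })
        }
    }

    #[test]
    fn mock_llm_has_model_id() {
        let llm: Box<dyn BaseLlm> = Box::new(EchoLlm);
        assert_eq!(llm.model_id(), "echo");
    }
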
    #[test]
    fn token_usage() {
        let usage = TokenUsage {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
        };
        assert_eq!(usage.total_tokens, 30);
    }
}