pub mod anthropic;
pub mod embedding;
pub mod embedding_factory;
pub mod factory;
pub mod ollama_embedding;
pub mod openai_compat;
pub mod openai_embedding;
pub mod pricing;
use crate::error::LlmError;
/// Token counts reported by a provider for a single LLM call.
///
/// `Default` yields zeroed counters, suitable as a seed for `accumulate`;
/// serde derives allow usage to be persisted or reported as JSON.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
/// Tokens consumed by the prompt/input side of the request.
pub input_tokens: u32,
/// Tokens produced in the model's completion.
pub output_tokens: u32,
}
impl TokenUsage {
    /// Adds `other`'s counters into `self`.
    ///
    /// Uses saturating arithmetic: plain `+=` on `u32` panics on overflow in
    /// debug builds and silently wraps in release builds, so a long-running
    /// accumulator could either crash or under-report. Saturating at
    /// `u32::MAX` keeps the totals monotonic and panic-free.
    pub fn accumulate(&mut self, other: &TokenUsage) {
        self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
        self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
    }
}
/// A completed (non-streaming) response from an LLM provider.
#[derive(Debug, Clone)]
pub struct LlmResponse {
/// The generated completion text.
pub text: String,
/// Token accounting for this single call.
pub usage: TokenUsage,
/// Model identifier associated with this response — presumably the model
/// name echoed back by the provider; confirm against each implementation.
pub model: String,
}
/// Tunable parameters for a single generation request.
#[derive(Debug, Clone)]
pub struct GenerationParams {
/// Upper bound on tokens the model may generate (default: 512).
pub max_tokens: u32,
/// Sampling temperature (default: 0.7).
pub temperature: f32,
/// Optional system prompt (default: `None`).
pub system_prompt: Option<String>,
}
impl Default for GenerationParams {
fn default() -> Self {
Self {
max_tokens: 512,
temperature: 0.7,
system_prompt: None,
}
}
}
/// Common interface implemented by every chat-completion backend in this
/// module tree (e.g. the `anthropic` and `openai_compat` submodules —
/// confirm which implement it; not visible from here).
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
/// Stable identifier for this provider implementation.
fn name(&self) -> &str;
/// Sends one system + user message pair and returns the completion
/// text together with token usage and the responding model's name.
///
/// NOTE(review): `system` overlaps with `params.system_prompt`; which one
/// takes precedence is implementation-defined from here — confirm per provider.
async fn complete(
&self,
system: &str,
user_message: &str,
params: &GenerationParams,
) -> Result<LlmResponse, LlmError>;
/// Lightweight probe of backend availability; `Ok(())` means usable.
async fn health_check(&self) -> Result<(), LlmError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` must start every counter at zero.
    #[test]
    fn token_usage_default_is_zero() {
        let TokenUsage { input_tokens, output_tokens } = TokenUsage::default();
        assert_eq!((input_tokens, output_tokens), (0, 0));
    }

    /// A single accumulate adds both counters component-wise.
    #[test]
    fn token_usage_accumulate() {
        let mut acc = TokenUsage { input_tokens: 100, output_tokens: 50 };
        acc.accumulate(&TokenUsage { input_tokens: 200, output_tokens: 80 });
        assert_eq!((acc.input_tokens, acc.output_tokens), (300, 130));
    }

    /// Repeated accumulation sums every contribution.
    #[test]
    fn token_usage_accumulate_multiple() {
        let mut acc = TokenUsage::default();
        (1..=5).for_each(|k| {
            acc.accumulate(&TokenUsage {
                input_tokens: k * 10,
                output_tokens: k * 5,
            });
        });
        // 10+20+30+40+50 and 5+10+15+20+25.
        assert_eq!(acc.input_tokens, 150);
        assert_eq!(acc.output_tokens, 75);
    }

    /// Accumulating a zeroed usage is a no-op.
    #[test]
    fn token_usage_accumulate_zero() {
        let mut acc = TokenUsage { input_tokens: 42, output_tokens: 17 };
        acc.accumulate(&TokenUsage::default());
        assert_eq!((acc.input_tokens, acc.output_tokens), (42, 17));
    }

    /// Defaults: 512 max tokens, 0.7 temperature, no system prompt.
    #[test]
    fn generation_params_default() {
        let p = GenerationParams::default();
        assert_eq!(p.max_tokens, 512);
        assert!((p.temperature - 0.7).abs() < f32::EPSILON);
        assert!(p.system_prompt.is_none());
    }

    /// Struct-update syntax keeps untouched fields at their defaults.
    #[test]
    fn generation_params_with_system_prompt() {
        let prompt = "You are a helpful assistant.";
        let p = GenerationParams {
            system_prompt: Some(prompt.to_string()),
            ..Default::default()
        };
        assert_eq!(p.system_prompt.as_deref(), Some(prompt));
        assert_eq!(p.max_tokens, 512);
    }

    /// All three response fields round-trip through construction.
    #[test]
    fn llm_response_fields() {
        let resp = LlmResponse {
            text: String::from("Hello, world!"),
            usage: TokenUsage { input_tokens: 10, output_tokens: 3 },
            model: String::from("gpt-4o-mini"),
        };
        assert_eq!(resp.text, "Hello, world!");
        assert_eq!((resp.usage.input_tokens, resp.usage.output_tokens), (10, 3));
        assert_eq!(resp.model, "gpt-4o-mini");
    }

    /// JSON serialize → deserialize preserves both counters.
    #[test]
    fn token_usage_serde_roundtrip() {
        let before = TokenUsage { input_tokens: 100, output_tokens: 50 };
        let json = serde_json::to_string(&before).expect("serialize");
        let after: TokenUsage = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(after.input_tokens, 100);
        assert_eq!(after.output_tokens, 50);
    }

    /// Cloning copies both counters.
    #[test]
    fn token_usage_clone() {
        let original = TokenUsage { input_tokens: 42, output_tokens: 17 };
        let copy = original.clone();
        assert_eq!((copy.input_tokens, copy.output_tokens), (42, 17));
    }

    /// Cloning copies every field, including the owned prompt.
    #[test]
    fn generation_params_clone() {
        let original = GenerationParams {
            max_tokens: 1024,
            temperature: 0.5,
            system_prompt: Some("test".to_string()),
        };
        let copy = original.clone();
        assert_eq!(copy.max_tokens, 1024);
        assert!((copy.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(copy.system_prompt.as_deref(), Some("test"));
    }
}