use serde::{Deserialize, Serialize};
use std::collections::HashMap;
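
/// Re-export of the `llm` dependency so downstream users can reach its types directly.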
pub use llm;
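
/// The LLM providers supported by this crate.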
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LlmBackend {
    OpenAI,
    Anthropic,
    XAI,
    Google,
    Ollama,
    OpenRouter,
}

impl LlmBackend {
    /// Environment variable that holds this backend's API key (or host, for Ollama).
    pub fn api_key_env_var(&self) -> &'static str {
        match self {
            Self::OpenAI => "OPENAI_API_KEY",
            Self::Anthropic => "ANTHROPIC_API_KEY",
            Self::XAI => "XAI_API_KEY",
            Self::Google => "GOOGLE_API_KEY",
            Self::Ollama => "OLLAMA_HOST",
            Self::OpenRouter => "OPENROUTER_API_KEY",
        }
    }

    /// Default model identifier for this backend.
    pub fn default_model(&self) -> &'static str {
        match self {
            Self::OpenAI => "gpt-4o",
            Self::Anthropic => "claude-sonnet-4-20250514",
            Self::XAI => "grok-2",
            Self::Google => "gemini-2.0-flash",
            Self::Ollama => "llama3.2",
            Self::OpenRouter => "anthropic/claude-sonnet-4",
        }
    }
}
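
/// Configuration for multi-backend LLM usage: a primary backend, ordered fallbacks,
/// model overrides, and common generation settings.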
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiLlmConfig {
    pub primary_backend: LlmBackend,
    pub fallback_backends: Vec<LlmBackend>,
    pub model_overrides: HashMap<String, String>,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub streaming: bool,
    pub timeout_secs: u64,
}

impl Default for MultiLlmConfig {
    fn default() -> Self {
        Self {
            primary_backend: LlmBackend::Anthropic,
            fallback_backends: vec![LlmBackend::OpenAI, LlmBackend::OpenRouter],
            model_overrides: HashMap::new(),
            max_tokens: Some(4096),
            temperature: Some(0.7),
            streaming: true,
            timeout_secs: 120,
        }
    }
}
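
/// A single chat message with a role (`user`, `assistant`, or `system`), content, and optional name.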
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: String,
    pub name: Option<String>,
}

impl Message {
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: "user".to_string(),
            content: content.into(),
            name: None,
        }
    }

    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: "assistant".to_string(),
            content: content.into(),
            name: None,
        }
    }

    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: "system".to_string(),
            content: content.into(),
            name: None,
        }
    }
}
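
/// A tool the model may call, described by a name, a description, and a JSON schema for its parameters.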
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}
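
/// A tool invocation requested by the model, with the arguments it supplied.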
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub id: String,
    pub name: String,
    pub arguments: serde_json::Value,
}
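
/// The outcome of a completion request, including any tool calls, token usage,
/// and the backend and model that produced it.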
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub content: String,
    pub tool_calls: Vec<ToolCall>,
    pub finish_reason: String,
    pub usage: TokenUsage,
    pub backend: LlmBackend,
    pub model: String,
}
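
/// Token counts for a prompt and its completion.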
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
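
/// An in-memory conversation history with attached metadata.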
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConversationMemory {
    pub id: String,
    pub messages: Vec<Message>,
    pub metadata: HashMap<String, serde_json::Value>,
}

impl ConversationMemory {
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            messages: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Appends a message to the conversation history.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    /// Returns at most the last `n` messages (fewer if the history is shorter).
    pub fn last_messages(&self, n: usize) -> &[Message] {
        let start = self.messages.len().saturating_sub(n);
        &self.messages[start..]
    }

    /// Removes all messages while keeping the id and metadata.
    pub fn clear(&mut self) {
        self.messages.clear();
    }

    pub fn len(&self) -> usize {
        self.messages.len()
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
}
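
/// Builder that assembles a [`ToolDefinition`] with a JSON object schema for its parameters.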
pub struct ToolBuilder {
    name: String,
    description: String,
    properties: serde_json::Map<String, serde_json::Value>,
    required: Vec<String>,
}

impl ToolBuilder {
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            properties: serde_json::Map::new(),
            required: Vec::new(),
        }
    }

    /// Adds a string parameter to the tool's schema.
    pub fn string_param(
        mut self,
        name: impl Into<String>,
        description: impl Into<String>,
        required: bool,
    ) -> Self {
        let name = name.into();
        self.properties.insert(
            name.clone(),
            serde_json::json!({
                "type": "string",
                "description": description.into()
            }),
        );
        if required {
            self.required.push(name);
        }
        self
    }

    /// Adds a numeric parameter to the tool's schema.
    pub fn number_param(
        mut self,
        name: impl Into<String>,
        description: impl Into<String>,
        required: bool,
    ) -> Self {
        let name = name.into();
        self.properties.insert(
            name.clone(),
            serde_json::json!({
                "type": "number",
                "description": description.into()
            }),
        );
        if required {
            self.required.push(name);
        }
        self
    }

    /// Adds a boolean parameter to the tool's schema.
    pub fn bool_param(
        mut self,
        name: impl Into<String>,
        description: impl Into<String>,
        required: bool,
    ) -> Self {
        let name = name.into();
        self.properties.insert(
            name.clone(),
            serde_json::json!({
                "type": "boolean",
                "description": description.into()
            }),
        );
        if required {
            self.required.push(name);
        }
        self
    }

    /// Adds an array parameter whose items all share the given JSON type.
    pub fn array_param(
        mut self,
        name: impl Into<String>,
        description: impl Into<String>,
        item_type: &str,
        required: bool,
    ) -> Self {
        let name = name.into();
        self.properties.insert(
            name.clone(),
            serde_json::json!({
                "type": "array",
                "items": { "type": item_type },
                "description": description.into()
            }),
        );
        if required {
            self.required.push(name);
        }
        self
    }

    /// Adds a string parameter restricted to the given set of allowed values.
    pub fn enum_param(
        mut self,
        name: impl Into<String>,
        description: impl Into<String>,
        values: &[&str],
        required: bool,
    ) -> Self {
        let name = name.into();
        self.properties.insert(
            name.clone(),
            serde_json::json!({
                "type": "string",
                "enum": values,
                "description": description.into()
            }),
        );
        if required {
            self.required.push(name);
        }
        self
    }

    /// Finalizes the builder into a [`ToolDefinition`] with an object schema.
    pub fn build(self) -> ToolDefinition {
        ToolDefinition {
            name: self.name,
            description: self.description,
            parameters: serde_json::json!({
                "type": "object",
                "properties": self.properties,
                "required": self.required
            }),
        }
    }
}
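
/// Returns `true` if the environment variable for the backend's API key (or host, for Ollama) is set.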
pub fn is_backend_available(backend: LlmBackend) -> bool {
    std::env::var(backend.api_key_env_var()).is_ok()
}
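
/// Lists every backend whose credential environment variable is currently set.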
pub fn available_backends() -> Vec<LlmBackend> {
    [
        LlmBackend::OpenAI,
        LlmBackend::Anthropic,
        LlmBackend::XAI,
        LlmBackend::Google,
        LlmBackend::Ollama,
        LlmBackend::OpenRouter,
    ]
    .into_iter()
    .filter(|b| is_backend_available(*b))
    .collect()
}
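
/// Rough token estimate using the common ~4 characters per token heuristic.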
pub fn estimate_tokens(text: &str) -> u32 {
    (text.len() / 4) as u32
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = MultiLlmConfig::default();
        assert_eq!(config.primary_backend, LlmBackend::Anthropic);
        assert!(config.streaming);
    }

    #[test]
    fn test_message_construction() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content, "Hello");
        let msg = Message::system("You are helpful");
        assert_eq!(msg.role, "system");
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("search", "Search for information")
            .string_param("query", "Search query", true)
            .number_param("limit", "Max results", false)
            .enum_param("type", "Search type", &["web", "images", "news"], false)
            .build();
        assert_eq!(tool.name, "search");
        assert!(tool.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&serde_json::json!("query")));
    }

    #[test]
    fn test_conversation_memory() {
        let mut memory = ConversationMemory::new("test-conv");
        memory.add_message(Message::user("Hello"));
        memory.add_message(Message::assistant("Hi there!"));
        assert_eq!(memory.len(), 2);
        assert_eq!(memory.last_messages(1).len(), 1);
        assert_eq!(memory.last_messages(1)[0].role, "assistant");
    }

    #[test]
    fn test_backend_env_vars() {
        assert_eq!(LlmBackend::OpenAI.api_key_env_var(), "OPENAI_API_KEY");
        assert_eq!(LlmBackend::Anthropic.api_key_env_var(), "ANTHROPIC_API_KEY");
    }

    #[test]
    fn test_token_estimate() {
        let text = "Hello, world! This is a test.";
        let tokens = estimate_tokens(text);
        assert!(tokens > 0);
        assert!(tokens < 100);
    }
}