use async_trait::async_trait;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use crate::error::LlmError;
/// Role tag carried by every [`ChatMessage`].
///
/// Serialized and deserialized using the lowercase variant name
/// (`system`, `user`, `assistant`, `tool`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// Role assigned by [`ChatMessage::system`].
System,
/// Role assigned by [`ChatMessage::user`].
User,
/// Role assigned by [`ChatMessage::assistant`] and
/// [`ChatMessage::assistant_with_tool_calls`].
Assistant,
/// Role assigned by [`ChatMessage::tool_result`].
Tool,
}
/// A single message in a chat conversation.
///
/// The three optional fields are omitted from the serialized form when they
/// are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// Role associated with this message.
pub role: Role,
/// Text body of the message. Always present; it is the empty string when
/// [`ChatMessage::assistant_with_tool_calls`] is given no content.
pub content: String,
/// Identifier of the tool call this message relates to. Only
/// [`ChatMessage::tool_result`] populates it.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Name attached to the message. Only [`ChatMessage::tool_result`]
/// populates it (with the `name` argument it receives).
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Tool calls attached to the message. Only
/// [`ChatMessage::assistant_with_tool_calls`] populates it, and only when
/// the supplied list is non-empty.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
impl ChatMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: content.into(),
tool_call_id: None,
name: None,
tool_calls: None,
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: content.into(),
tool_call_id: None,
name: None,
tool_calls: None,
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: content.into(),
tool_call_id: None,
name: None,
tool_calls: None,
}
}
pub fn assistant_with_tool_calls(content: Option<String>, tool_calls: Vec<ToolCall>) -> Self {
Self {
role: Role::Assistant,
content: content.unwrap_or_default(),
tool_call_id: None,
name: None,
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
}
}
pub fn tool_result(
tool_call_id: impl Into<String>,
name: impl Into<String>,
content: impl Into<String>,
) -> Self {
Self {
role: Role::Tool,
content: content.into(),
tool_call_id: Some(tool_call_id.into()),
name: Some(name.into()),
tool_calls: None,
}
}
}
/// Input to [`LlmProvider::complete`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
/// Messages that make up the request.
pub messages: Vec<ChatMessage>,
/// Optional token limit; `None` unless set (e.g. via `with_max_tokens`).
/// NOTE(review): presumably caps generated tokens — confirm per provider.
pub max_tokens: Option<u32>,
/// Optional temperature; `None` unless set (e.g. via `with_temperature`).
/// NOTE(review): presumably the sampling temperature — confirm per provider.
pub temperature: Option<f32>,
/// Optional stop sequences; [`CompletionRequest::new`] leaves this `None`.
pub stop_sequences: Option<Vec<String>>,
/// String key/value pairs attached to the request; starts out empty.
/// How providers interpret them is not visible in this file.
pub metadata: std::collections::HashMap<String, String>,
}
impl CompletionRequest {
    /// Builds a request from `messages`, leaving every optional setting
    /// unset and the metadata map empty.
    pub fn new(messages: Vec<ChatMessage>) -> Self {
        use std::collections::HashMap;

        let metadata = HashMap::new();
        Self {
            messages,
            max_tokens: None,
            temperature: None,
            stop_sequences: None,
            metadata,
        }
    }

    /// Returns the request with `max_tokens` set to `Some(max_tokens)`.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Returns the request with `temperature` set to `Some(temperature)`.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }
}
/// Successful result of [`LlmProvider::complete`].
#[derive(Debug, Clone)]
pub struct CompletionResponse {
/// Text returned for the request.
pub content: String,
/// Input-side token count.
/// NOTE(review): presumably the prompt tokens reported by the provider — confirm.
pub input_tokens: u32,
/// Output-side token count.
/// NOTE(review): presumably the generated tokens reported by the provider — confirm.
pub output_tokens: u32,
/// Finish reason recorded for this response.
pub finish_reason: FinishReason,
/// Optional identifier of this response.
/// NOTE(review): possibly related to the response-chain methods on
/// `LlmProvider` — confirm against implementors.
pub response_id: Option<String>,
}
/// Finish reason attached to [`CompletionResponse`] and
/// [`ToolCompletionResponse`].
///
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names; the mapping from provider output is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FinishReason {
/// Presumably: generation ended normally.
Stop,
/// Presumably: a length/token limit was reached.
Length,
/// Presumably: the model requested one or more tool calls.
ToolUse,
/// Presumably: output was blocked by a content filter.
ContentFilter,
/// Presumably: the reason was missing or not recognized.
Unknown,
}
/// Description of a tool supplied in [`ToolCompletionRequest::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Name of the tool.
pub name: String,
/// Free-text description of the tool.
pub description: String,
/// Arbitrary JSON value describing the tool's parameters.
/// NOTE(review): presumably a JSON Schema object — confirm against providers.
pub parameters: serde_json::Value,
}
/// A tool invocation, as carried in [`ChatMessage::tool_calls`] and
/// [`ToolCompletionResponse::tool_calls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this call.
/// NOTE(review): presumably what `ChatMessage::tool_result` expects as its
/// `tool_call_id` — confirm against callers.
pub id: String,
/// Name of the tool being called.
pub name: String,
/// Arguments for the call, as an arbitrary JSON value.
pub arguments: serde_json::Value,
}
/// Outcome of running a tool.
///
/// Nothing in this file constructs or consumes this type; its fields mirror
/// the arguments of [`ChatMessage::tool_result`] plus an error flag.
#[derive(Debug, Clone)]
pub struct ToolResult {
/// Identifier of the tool call this result belongs to.
pub tool_call_id: String,
/// Name of the tool.
pub name: String,
/// Textual result content.
pub content: String,
/// Error flag. NOTE(review): presumably `true` when the tool failed and
/// `content` holds an error description — confirm against callers.
pub is_error: bool,
}
/// Input to [`LlmProvider::complete_with_tools`].
#[derive(Debug, Clone)]
pub struct ToolCompletionRequest {
/// Messages that make up the request.
pub messages: Vec<ChatMessage>,
/// Tool definitions supplied with the request.
pub tools: Vec<ToolDefinition>,
/// Optional token limit; `None` unless set (e.g. via `with_max_tokens`).
/// NOTE(review): presumably caps generated tokens — confirm per provider.
pub max_tokens: Option<u32>,
/// Optional temperature; `None` unless set (e.g. via `with_temperature`).
/// NOTE(review): presumably the sampling temperature — confirm per provider.
pub temperature: Option<f32>,
/// Optional tool-choice setting as a free-form string; `None` unless set
/// (e.g. via `with_tool_choice`). Accepted values are not defined in this file.
pub tool_choice: Option<String>,
/// String key/value pairs attached to the request; starts out empty.
/// How providers interpret them is not visible in this file.
pub metadata: std::collections::HashMap<String, String>,
}
impl ToolCompletionRequest {
    /// Builds a request from `messages` and `tools`, leaving every optional
    /// setting unset and the metadata map empty.
    pub fn new(messages: Vec<ChatMessage>, tools: Vec<ToolDefinition>) -> Self {
        use std::collections::HashMap;

        let metadata = HashMap::new();
        Self {
            messages,
            tools,
            max_tokens: None,
            temperature: None,
            tool_choice: None,
            metadata,
        }
    }

    /// Returns the request with `max_tokens` set to `Some(max_tokens)`.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Returns the request with `temperature` set to `Some(temperature)`.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Returns the request with `tool_choice` set to the given value,
    /// converted into a `String`.
    pub fn with_tool_choice(self, choice: impl Into<String>) -> Self {
        let tool_choice = Some(choice.into());
        Self {
            tool_choice,
            ..self
        }
    }
}
/// Successful result of [`LlmProvider::complete_with_tools`].
#[derive(Debug, Clone)]
pub struct ToolCompletionResponse {
/// Text returned for the request, if any.
pub content: Option<String>,
/// Tool calls returned with the response; may be empty.
pub tool_calls: Vec<ToolCall>,
/// Input-side token count.
/// NOTE(review): presumably the prompt tokens reported by the provider — confirm.
pub input_tokens: u32,
/// Output-side token count.
/// NOTE(review): presumably the generated tokens reported by the provider — confirm.
pub output_tokens: u32,
/// Finish reason recorded for this response.
pub finish_reason: FinishReason,
/// Optional identifier of this response.
/// NOTE(review): possibly related to the response-chain methods on
/// `LlmProvider` — confirm against implementors.
pub response_id: Option<String>,
}
/// Metadata about a model, as returned by [`LlmProvider::model_metadata`].
#[derive(Debug, Clone)]
pub struct ModelMetadata {
/// Model identifier; the default `model_metadata` implementation fills it
/// from `model_name()`.
pub id: String,
/// Context length, when known; the default `model_metadata`
/// implementation leaves it `None`.
/// NOTE(review): presumably measured in tokens — confirm.
pub context_length: Option<u32>,
}
/// Common interface over LLM backends.
///
/// Implementors must provide [`model_name`](LlmProvider::model_name),
/// [`cost_per_token`](LlmProvider::cost_per_token),
/// [`complete`](LlmProvider::complete) and
/// [`complete_with_tools`](LlmProvider::complete_with_tools); every other
/// method has a default implementation that may be overridden.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Name of the model this provider is configured with.
    fn model_name(&self) -> &str;

    /// Per-token prices as an `(input, output)` pair, in the order consumed
    /// by [`calculate_cost`](LlmProvider::calculate_cost).
    fn cost_per_token(&self) -> (Decimal, Decimal);

    /// Runs a completion for `request`.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError>;

    /// Runs a completion for `request`, which also carries tool definitions.
    async fn complete_with_tools(
        &self,
        request: ToolCompletionRequest,
    ) -> Result<ToolCompletionResponse, LlmError>;

    /// Lists model names. The default implementation returns an empty list.
    async fn list_models(&self) -> Result<Vec<String>, LlmError> {
        let models: Vec<String> = Vec::new();
        Ok(models)
    }

    /// Returns metadata for the model. The default implementation uses
    /// `model_name()` as the id and reports no context length.
    async fn model_metadata(&self) -> Result<ModelMetadata, LlmError> {
        let id = String::from(self.model_name());
        Ok(ModelMetadata {
            id,
            context_length: None,
        })
    }

    /// Name of the currently active model as an owned string. The default
    /// implementation returns a copy of `model_name()`.
    fn active_model_name(&self) -> String {
        String::from(self.model_name())
    }

    /// Switches the model at runtime.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with
    /// `LlmError::RequestFailed`, reporting the provider as `"unknown"`.
    fn set_model(&self, _model: &str) -> Result<(), LlmError> {
        let provider = String::from("unknown");
        let reason = String::from("Runtime model switching not supported by this provider");
        Err(LlmError::RequestFailed { provider, reason })
    }

    /// Hook for associating a response id with a thread. The default
    /// implementation does nothing.
    fn seed_response_chain(&self, _thread_id: &str, _response_id: String) {}

    /// Looks up the response-chain id for a thread. The default
    /// implementation always returns `None`.
    fn get_response_chain_id(&self, _thread_id: &str) -> Option<String> {
        None
    }

    /// Computes the cost of a call as
    /// `input_rate * input_tokens + output_rate * output_tokens`, using the
    /// rates from [`cost_per_token`](LlmProvider::cost_per_token).
    fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> Decimal {
        let (input_rate, output_rate) = self.cost_per_token();
        let input_total = input_rate * Decimal::from(input_tokens);
        let output_total = output_rate * Decimal::from(output_tokens);
        input_total + output_total
    }
}