use async_trait::async_trait;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use crate::llm::error::LlmError;
/// Author role attached to a [`ChatMessage`].
///
/// Variants are (de)serialized using their lowercase names
/// (`"system"`, `"user"`, `"assistant"`, `"tool"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// Role used by [`ChatMessage::system`].
System,
/// Role used by [`ChatMessage::user`] and [`ChatMessage::user_with_parts`].
User,
/// Role used by [`ChatMessage::assistant`] and
/// [`ChatMessage::assistant_with_tool_calls`].
Assistant,
/// Role used by [`ChatMessage::tool_result`].
Tool,
}
/// One element of a message's structured (multi-part) content.
///
/// Serialized as an internally tagged enum: the variant is identified by a
/// `"type"` field placed alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentPart {
/// Plain text part; tagged as `"text"`.
#[serde(rename = "text")]
Text { text: String },
/// Image reference part; tagged as `"image_url"`.
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrl },
}
/// Payload of a [`ContentPart::ImageUrl`] part.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
/// Location of the image.
pub url: String,
/// Optional detail setting; left out of the serialized form when `None`.
/// NOTE(review): presumably a provider-specific resolution hint — confirm
/// the accepted values against the provider integration.
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
/// A single message in a chat conversation.
///
/// Optional fields are left out of the serialized form when they are `None`
/// (or, for `content_parts`, when the vector is empty).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// Who authored the message.
pub role: Role,
/// Plain-text body of the message.
pub content: String,
/// Structured content parts (text and images). Defaults to an empty vector
/// when absent during deserialization and is skipped when empty during
/// serialization.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub content_parts: Vec<ContentPart>,
/// Id of the tool call this message answers; populated by
/// [`ChatMessage::tool_result`].
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Name associated with the message; [`ChatMessage::tool_result`] stores
/// the tool name here.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Tool calls carried by the message; populated by
/// [`ChatMessage::assistant_with_tool_calls`].
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
impl ChatMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: content.into(),
content_parts: Vec::new(),
tool_call_id: None,
name: None,
tool_calls: None,
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: content.into(),
content_parts: Vec::new(),
tool_call_id: None,
name: None,
tool_calls: None,
}
}
pub fn user_with_parts(content: impl Into<String>, parts: Vec<ContentPart>) -> Self {
Self {
role: Role::User,
content: content.into(),
content_parts: parts,
tool_call_id: None,
name: None,
tool_calls: None,
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: content.into(),
content_parts: Vec::new(),
tool_call_id: None,
name: None,
tool_calls: None,
}
}
pub fn assistant_with_tool_calls(content: Option<String>, tool_calls: Vec<ToolCall>) -> Self {
Self {
role: Role::Assistant,
content: content.unwrap_or_default(),
content_parts: Vec::new(),
tool_call_id: None,
name: None,
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
}
}
pub fn tool_result(
tool_call_id: impl Into<String>,
name: impl Into<String>,
content: impl Into<String>,
) -> Self {
Self {
role: Role::Tool,
content: content.into(),
content_parts: Vec::new(),
tool_call_id: Some(tool_call_id.into()),
name: Some(name.into()),
tool_calls: None,
}
}
}
/// Parameters for a plain (tool-free) completion call, as accepted by
/// [`LlmProvider::complete`].
///
/// Every `Option` field starts as `None` in [`CompletionRequest::new`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
/// Conversation to send to the model.
pub messages: Vec<ChatMessage>,
/// Requested model name, if any (see [`LlmProvider::effective_model_name`]).
pub model: Option<String>,
/// Optional cap on generated tokens.
pub max_tokens: Option<u32>,
/// Optional sampling temperature.
pub temperature: Option<f32>,
/// Optional stop sequences.
pub stop_sequences: Option<Vec<String>>,
/// Free-form string key/value pairs attached to the request.
/// NOTE(review): how providers consume this is not visible here — confirm.
pub metadata: std::collections::HashMap<String, String>,
}
impl CompletionRequest {
    /// Creates a request for `messages` with every optional parameter unset
    /// and empty metadata.
    pub fn new(messages: Vec<ChatMessage>) -> Self {
        Self {
            messages,
            model: None,
            max_tokens: None,
            temperature: None,
            stop_sequences: None,
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Sets the requested model name and returns the updated request.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the generated-token cap and returns the updated request.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the sampling temperature and returns the updated request.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the stop sequences and returns the updated request.
    ///
    /// Mirrors `ToolCompletionRequest::with_stop_sequences` so both request
    /// types expose a builder for every optional parameter they carry.
    pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
        self.stop_sequences = Some(stop_sequences);
        self
    }
}
/// Result of a plain completion call ([`LlmProvider::complete`]).
#[derive(Debug, Clone)]
pub struct CompletionResponse {
/// Text produced by the model.
pub content: String,
/// Input token count (see [`LlmProvider::calculate_cost`]).
pub input_tokens: u32,
/// Output token count (see [`LlmProvider::calculate_cost`]).
pub output_tokens: u32,
/// Why generation ended.
pub finish_reason: FinishReason,
/// NOTE(review): presumably the number of input tokens served from a
/// provider-side prompt cache — confirm against the provider integration.
pub cache_read_input_tokens: u32,
/// NOTE(review): presumably the number of input tokens written to a
/// provider-side prompt cache — confirm against the provider integration.
pub cache_creation_input_tokens: u32,
}
/// Reason a completion ended, as reported on [`CompletionResponse`] and
/// [`ToolCompletionResponse`].
///
/// NOTE(review): the mapping from provider-specific values to these variants
/// lives outside this file; the per-variant notes below are inferred from the
/// names and should be confirmed against the provider code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FinishReason {
/// Presumably: the model finished on its own or hit a stop sequence.
Stop,
/// Presumably: the token limit was reached.
Length,
/// Presumably: the model stopped in order to call a tool.
ToolUse,
/// Presumably: output was cut off by a content filter.
ContentFilter,
/// Presumably: the provider gave no reason or an unrecognized one.
Unknown,
}
/// Description of a tool offered to the model in a [`ToolCompletionRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Tool name.
pub name: String,
/// Human-readable description of the tool.
pub description: String,
/// Arbitrary JSON describing the tool's parameters.
/// NOTE(review): presumably a JSON Schema object — confirm with callers.
pub parameters: serde_json::Value,
}
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of the call; `sanitize_tool_messages` matches tool-result
/// messages against it through their `tool_call_id`.
pub id: String,
/// Name of the tool to invoke.
pub name: String,
/// Arguments for the tool, as arbitrary JSON.
pub arguments: serde_json::Value,
/// Optional reasoning text attached to the call. Defaults to `None` when
/// absent during deserialization and is skipped when `None` during
/// serialization.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reasoning: Option<String>,
}
/// Derives a deterministic 9-character tool-call id from two seeds.
///
/// The seeds are mixed as `seed_a * 6364136223846793005 + seed_b` using
/// wrapping `u64` arithmetic, and the low nine base-62 digits of the result
/// are rendered most-significant first with the alphabet `0-9`, `a-z`, `A-Z`.
/// The output is therefore always exactly nine ASCII alphanumeric characters,
/// left-padded with `'0'`.
///
/// The same seeds always yield the same id. Distinct seed pairs are not
/// guaranteed to yield distinct ids, because the mix wraps and only nine
/// digits are kept.
pub fn generate_tool_call_id(seed_a: usize, seed_b: usize) -> String {
    /// Digit symbols indexed by digit value (0..62).
    const ALPHABET: &[u8; 62] = b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const BASE: u64 = 62;
    const ID_LEN: usize = 9;
    const MULTIPLIER: u64 = 6364136223846793005;

    let mut remaining = (seed_a as u64)
        .wrapping_mul(MULTIPLIER)
        .wrapping_add(seed_b as u64);

    // Fill right-to-left so the least significant digit lands in the last slot.
    let mut symbols = [ALPHABET[0]; ID_LEN];
    for slot in symbols.iter_mut().rev() {
        *slot = ALPHABET[(remaining % BASE) as usize];
        remaining /= BASE;
    }

    symbols.iter().map(|&byte| char::from(byte)).collect()
}
/// Outcome of executing a tool call.
#[derive(Debug, Clone)]
pub struct ToolResult {
/// Id of the tool call this result belongs to.
pub tool_call_id: String,
/// Name of the tool.
pub name: String,
/// Output text of the tool.
pub content: String,
/// Error flag. NOTE(review): presumably `true` when the tool failed and
/// `content` holds the error description — confirm with the producer.
pub is_error: bool,
}
/// Parameters for a tool-enabled completion call, as accepted by
/// [`LlmProvider::complete_with_tools`].
///
/// Every `Option` field starts as `None` in [`ToolCompletionRequest::new`].
#[derive(Debug, Clone)]
pub struct ToolCompletionRequest {
/// Conversation to send to the model.
pub messages: Vec<ChatMessage>,
/// Tools offered to the model.
pub tools: Vec<ToolDefinition>,
/// Requested model name, if any (see [`LlmProvider::effective_model_name`]).
pub model: Option<String>,
/// Optional cap on generated tokens.
pub max_tokens: Option<u32>,
/// Optional sampling temperature.
pub temperature: Option<f32>,
/// Optional stop sequences.
pub stop_sequences: Option<Vec<String>>,
/// Optional tool-choice setting, passed as a free-form string.
/// NOTE(review): accepted values are provider-defined — confirm.
pub tool_choice: Option<String>,
/// Free-form string key/value pairs attached to the request.
/// NOTE(review): how providers consume this is not visible here — confirm.
pub metadata: std::collections::HashMap<String, String>,
}
impl ToolCompletionRequest {
    /// Creates a request for `messages` and `tools` with every optional
    /// parameter unset and empty metadata.
    pub fn new(messages: Vec<ChatMessage>, tools: Vec<ToolDefinition>) -> Self {
        Self {
            messages,
            tools,
            model: None,
            max_tokens: None,
            temperature: None,
            stop_sequences: None,
            tool_choice: None,
            metadata: Default::default(),
        }
    }

    /// Returns the request with the requested model name set.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Returns the request with the generated-token cap set.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Returns the request with the sampling temperature set.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Returns the request with the stop sequences set.
    pub fn with_stop_sequences(self, stop_sequences: Vec<String>) -> Self {
        Self {
            stop_sequences: Some(stop_sequences),
            ..self
        }
    }

    /// Returns the request with the tool-choice setting set.
    pub fn with_tool_choice(self, choice: impl Into<String>) -> Self {
        Self {
            tool_choice: Some(choice.into()),
            ..self
        }
    }
}
/// Result of a tool-enabled completion call
/// ([`LlmProvider::complete_with_tools`]).
#[derive(Debug, Clone)]
pub struct ToolCompletionResponse {
/// Text produced by the model, if any.
pub content: Option<String>,
/// Tool calls requested by the model.
pub tool_calls: Vec<ToolCall>,
/// Input token count (see [`LlmProvider::calculate_cost`]).
pub input_tokens: u32,
/// Output token count (see [`LlmProvider::calculate_cost`]).
pub output_tokens: u32,
/// Why generation ended.
pub finish_reason: FinishReason,
/// NOTE(review): presumably the number of input tokens served from a
/// provider-side prompt cache — confirm against the provider integration.
pub cache_read_input_tokens: u32,
/// NOTE(review): presumably the number of input tokens written to a
/// provider-side prompt cache — confirm against the provider integration.
pub cache_creation_input_tokens: u32,
}
/// Metadata about a model, as returned by [`LlmProvider::model_metadata`].
#[derive(Debug, Clone)]
pub struct ModelMetadata {
/// Model identifier; the default `model_metadata` implementation fills it
/// with the provider's `model_name()`.
pub id: String,
/// Context length, when known; the default implementation reports `None`.
/// NOTE(review): presumably measured in tokens — confirm.
pub context_length: Option<u32>,
}
/// Abstraction over a chat-completion backend.
///
/// Implementors must supply the model name, per-token pricing and the two
/// completion calls; every other method has a default implementation.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Name of the provider's model. Used as the default for
    /// [`active_model_name`](Self::active_model_name) and
    /// [`model_metadata`](Self::model_metadata).
    fn model_name(&self) -> &str;

    /// Per-token prices as `(input, output)`; combined by
    /// [`calculate_cost`](Self::calculate_cost).
    fn cost_per_token(&self) -> (Decimal, Decimal);

    /// Runs a plain completion.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError>;

    /// Runs a completion in which the model may request tool calls.
    async fn complete_with_tools(
        &self,
        request: ToolCompletionRequest,
    ) -> Result<ToolCompletionResponse, LlmError>;

    /// Lists available model names. The default implementation reports none.
    async fn list_models(&self) -> Result<Vec<String>, LlmError> {
        Ok(vec![])
    }

    /// Returns metadata for the current model. The default implementation
    /// reports `model_name()` as the id and an unknown context length.
    async fn model_metadata(&self) -> Result<ModelMetadata, LlmError> {
        let id = self.model_name().to_owned();
        Ok(ModelMetadata {
            id,
            context_length: None,
        })
    }

    /// Resolves the model name for a request: the explicitly requested model
    /// when present, otherwise [`active_model_name`](Self::active_model_name).
    fn effective_model_name(&self, requested_model: Option<&str>) -> String {
        match requested_model {
            Some(requested) => requested.to_owned(),
            None => self.active_model_name(),
        }
    }

    /// Name of the model currently in use. Defaults to `model_name()`.
    fn active_model_name(&self) -> String {
        self.model_name().to_owned()
    }

    /// Switches the active model at runtime.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with
    /// `LlmError::RequestFailed`, signalling that the provider does not
    /// support runtime model switching.
    fn set_model(&self, _model: &str) -> Result<(), LlmError> {
        let provider = String::from("unknown");
        let reason = String::from("Runtime model switching not supported by this provider");
        Err(LlmError::RequestFailed { provider, reason })
    }

    /// Computes the cost of a call as
    /// `input_price * input_tokens + output_price * output_tokens`,
    /// using the prices from [`cost_per_token`](Self::cost_per_token).
    fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> Decimal {
        let (input_price, output_price) = self.cost_per_token();
        let input_total = input_price * Decimal::from(input_tokens);
        let output_total = output_price * Decimal::from(output_tokens);
        input_total + output_total
    }

    /// Multiplier for cache-write tokens. Defaults to `Decimal::ONE`.
    /// NOTE(review): presumably applied to the input price by cost-tracking
    /// code outside this file — confirm.
    fn cache_write_multiplier(&self) -> Decimal {
        Decimal::ONE
    }

    /// Discount factor for cache-read tokens. Defaults to `Decimal::ONE`.
    /// NOTE(review): whether callers multiply or divide by this value is not
    /// visible here — confirm.
    fn cache_read_discount(&self) -> Decimal {
        Decimal::ONE
    }
}
/// Rewrites orphaned tool-result messages into user messages, in place.
///
/// A [`Role::Tool`] message is orphaned when it has no `tool_call_id`, or
/// when its id matches no tool call on any [`Role::Assistant`] message in
/// `messages`. Ids are gathered from the whole slice, so the relative
/// position of the assistant message and the tool result is not checked.
///
/// Each orphan gets the `User` role, has its content wrapped as
/// ``[Tool `<name>` returned: <content>]`` (with `unknown` standing in for a
/// missing name), and has `tool_call_id` and `name` cleared. All other
/// messages are left untouched.
pub fn sanitize_tool_messages(messages: &mut [ChatMessage]) {
    use std::collections::HashSet;

    // Pass 1: every tool-call id declared by an assistant message. The ids
    // are cloned because the slice is mutated in the second pass.
    let declared_ids: HashSet<String> = messages
        .iter()
        .filter(|message| message.role == Role::Assistant)
        .filter_map(|message| message.tool_calls.as_ref())
        .flatten()
        .map(|call| call.id.clone())
        .collect();

    // Pass 2: rewrite tool results that cannot be paired with a declared call.
    for message in messages
        .iter_mut()
        .filter(|message| message.role == Role::Tool)
    {
        let has_matching_call = message
            .tool_call_id
            .as_ref()
            .is_some_and(|id| declared_ids.contains(id));
        if has_matching_call {
            continue;
        }

        let tool_name = message.name.as_deref().unwrap_or("unknown");
        tracing::debug!(
            tool_call_id = ?message.tool_call_id,
            tool_name,
            "Rewriting orphaned tool_result as user message",
        );
        let wrapped = format!("[Tool `{}` returned: {}]", tool_name, message.content);

        message.role = Role::User;
        message.content = wrapped;
        message.tool_call_id = None;
        message.name = None;
    }
}
/// Request parameter that can be stripped from a request by
/// `strip_unsupported_completion_params` / `strip_unsupported_tool_params`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnsupportedParam {
    /// The `temperature` parameter.
    Temperature,
    /// The `max_tokens` parameter.
    MaxTokens,
    /// The `stop_sequences` parameter.
    StopSequences,
}

impl UnsupportedParam {
    /// Returns the parameter's name as a string; this is the key the strip
    /// helpers look up in their `unsupported` set.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Temperature => "temperature",
            Self::MaxTokens => "max_tokens",
            Self::StopSequences => "stop_sequences",
        }
    }
}
/// Clears the parameters of `req` whose names appear in `unsupported`.
///
/// Names are compared against [`UnsupportedParam::name`]; the matching
/// `temperature`, `max_tokens` and `stop_sequences` fields are reset to
/// `None`. An empty set, or names that match no known parameter, leave the
/// request unchanged.
pub fn strip_unsupported_completion_params(
    unsupported: &std::collections::HashSet<String>,
    req: &mut CompletionRequest,
) {
    let strippable = [
        UnsupportedParam::Temperature,
        UnsupportedParam::MaxTokens,
        UnsupportedParam::StopSequences,
    ];
    for param in strippable {
        if !unsupported.contains(param.name()) {
            continue;
        }
        match param {
            UnsupportedParam::Temperature => req.temperature = None,
            UnsupportedParam::MaxTokens => req.max_tokens = None,
            UnsupportedParam::StopSequences => req.stop_sequences = None,
        }
    }
}
/// Clears the parameters of a tool-enabled `req` whose names appear in
/// `unsupported`.
///
/// Names are compared against [`UnsupportedParam::name`]; the matching
/// `temperature`, `max_tokens` and `stop_sequences` fields are reset to
/// `None`. Other fields, including `tool_choice`, are never touched. An
/// empty set leaves the request unchanged.
pub fn strip_unsupported_tool_params(
    unsupported: &std::collections::HashSet<String>,
    req: &mut ToolCompletionRequest,
) {
    let strippable = [
        UnsupportedParam::Temperature,
        UnsupportedParam::MaxTokens,
        UnsupportedParam::StopSequences,
    ];
    for param in strippable {
        if !unsupported.contains(param.name()) {
            continue;
        }
        match param {
            UnsupportedParam::Temperature => req.temperature = None,
            UnsupportedParam::MaxTokens => req.max_tokens = None,
            UnsupportedParam::StopSequences => req.stop_sequences = None,
        }
    }
}
/// Unit tests for tool-call id generation, tool-message sanitization and
/// unsupported-parameter stripping.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds a tool call with empty arguments and no reasoning.
    fn tool_call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            name: name.to_string(),
            arguments: serde_json::json!({}),
            reasoning: None,
        }
    }

    #[test]
    fn generate_tool_call_id_has_valid_format() {
        let samples = [
            (0usize, 0usize),
            (1usize, 2usize),
            (42usize, 999usize),
            (usize::MAX, usize::MAX),
        ];
        for (a, b) in samples {
            let id = generate_tool_call_id(a, b);
            assert_eq!(
                id.len(),
                9,
                "tool-call ID must be exactly 9 characters for seeds ({a}, {b})"
            );
            assert!(
                id.chars().all(|c| c.is_ascii_alphanumeric()),
                "tool-call ID must be ASCII alphanumeric for seeds ({a}, {b}), got: {id}"
            );
        }
    }

    #[test]
    fn generate_tool_call_id_is_deterministic_for_same_seeds() {
        let pairs = [
            (0usize, 0usize),
            (1usize, 2usize),
            (123usize, 456usize),
            (usize::MAX, 0usize),
        ];
        for (a, b) in pairs {
            let id1 = generate_tool_call_id(a, b);
            let id2 = generate_tool_call_id(a, b);
            let id3 = generate_tool_call_id(a, b);
            assert_eq!(
                id1, id2,
                "tool-call ID must be deterministic for seeds ({a}, {b})"
            );
            assert_eq!(
                id2, id3,
                "tool-call ID must be deterministic across multiple calls for seeds ({a}, {b})"
            );
        }
    }

    #[test]
    fn generate_tool_call_id_differs_for_different_seeds_in_small_sample() {
        let seed_pairs = [
            (0usize, 1usize),
            (1usize, 0usize),
            (1usize, 2usize),
            (2usize, 3usize),
            (10usize, 20usize),
            (100usize, 200usize),
        ];
        let mut ids = HashSet::new();
        for (a, b) in seed_pairs {
            let id = generate_tool_call_id(a, b);
            let inserted = ids.insert(id.clone());
            assert!(
                inserted,
                "expected distinct tool-call IDs for different seeds, \
                 but duplicate ID '{id}' found for seeds ({a}, {b})"
            );
        }
    }

    #[test]
    fn test_sanitize_preserves_valid_pairs() {
        let mut messages = vec![
            ChatMessage::user("hello"),
            ChatMessage::assistant_with_tool_calls(None, vec![tool_call("call_1", "echo")]),
            ChatMessage::tool_result("call_1", "echo", "result"),
        ];
        sanitize_tool_messages(&mut messages);
        assert_eq!(messages[2].role, Role::Tool);
        assert_eq!(messages[2].tool_call_id, Some("call_1".to_string()));
    }

    #[test]
    fn test_sanitize_rewrites_orphaned_tool_result() {
        let mut messages = vec![
            ChatMessage::user("hello"),
            ChatMessage::assistant("I'll use a tool"),
            ChatMessage::tool_result("call_missing", "search", "some result"),
        ];
        sanitize_tool_messages(&mut messages);
        assert_eq!(messages[2].role, Role::User);
        assert!(messages[2].content.contains("[Tool `search` returned:"));
        assert!(messages[2].tool_call_id.is_none());
        assert!(messages[2].name.is_none());
    }

    #[test]
    fn test_sanitize_handles_no_tool_messages() {
        let mut messages = vec![
            ChatMessage::system("prompt"),
            ChatMessage::user("hello"),
            ChatMessage::assistant("hi"),
        ];
        sanitize_tool_messages(&mut messages);

        // The slice length cannot change through `&mut [ChatMessage]`, so the
        // meaningful check is that roles and contents are left untouched.
        let expected = [
            (Role::System, "prompt"),
            (Role::User, "hello"),
            (Role::Assistant, "hi"),
        ];
        assert_eq!(messages.len(), expected.len());
        for (msg, (role, content)) in messages.iter().zip(expected) {
            assert_eq!(msg.role, role);
            assert_eq!(msg.content, content);
            assert!(msg.tool_call_id.is_none());
            assert!(msg.name.is_none());
        }
    }

    #[test]
    fn test_sanitize_multiple_orphaned() {
        let mut messages = vec![
            ChatMessage::user("test"),
            ChatMessage::assistant_with_tool_calls(None, vec![tool_call("call_1", "echo")]),
            ChatMessage::tool_result("call_1", "echo", "ok"),
            ChatMessage::tool_result("call_2", "search", "orphan 1"),
            ChatMessage::tool_result("call_3", "http", "orphan 2"),
        ];
        sanitize_tool_messages(&mut messages);
        assert_eq!(messages[2].role, Role::Tool);
        assert_eq!(messages[3].role, Role::User);
        assert_eq!(messages[4].role, Role::User);
    }

    #[test]
    fn test_sanitize_preserves_tool_results_with_matching_assistant() {
        let tc1 = ToolCall {
            id: "call_sel_1".to_string(),
            name: "search".to_string(),
            arguments: serde_json::json!({"q": "test"}),
            reasoning: None,
        };
        let tc2 = ToolCall {
            id: "call_sel_2".to_string(),
            name: "http".to_string(),
            arguments: serde_json::json!({"url": "https://example.com"}),
            reasoning: None,
        };
        let mut messages = vec![
            ChatMessage::system("You are a helpful assistant."),
            ChatMessage::assistant_with_tool_calls(None, vec![tc1, tc2]),
            ChatMessage::tool_result("call_sel_1", "search", "found 3 results"),
            ChatMessage::tool_result("call_sel_2", "http", "200 OK"),
        ];
        sanitize_tool_messages(&mut messages);
        assert_eq!(messages[2].role, Role::Tool);
        assert_eq!(messages[2].tool_call_id, Some("call_sel_1".to_string()));
        assert_eq!(messages[2].content, "found 3 results");
        assert_eq!(messages[3].role, Role::Tool);
        assert_eq!(messages[3].tool_call_id, Some("call_sel_2".to_string()));
        assert_eq!(messages[3].content, "200 OK");
    }

    #[test]
    fn test_sanitize_rewrites_orphaned_tool_results() {
        let mut messages = vec![
            ChatMessage::system("You are a helpful assistant."),
            ChatMessage::tool_result("call_bug_1", "search", "found 3 results"),
            ChatMessage::tool_result("call_bug_2", "http", "200 OK"),
        ];
        sanitize_tool_messages(&mut messages);
        assert_eq!(messages[1].role, Role::User);
        assert!(messages[1].content.contains("[Tool `search` returned:"));
        assert!(messages[1].content.contains("found 3 results"));
        assert!(messages[1].tool_call_id.is_none());
        assert!(messages[1].name.is_none());
        assert_eq!(messages[2].role, Role::User);
        assert!(messages[2].content.contains("[Tool `http` returned:"));
        assert!(messages[2].content.contains("200 OK"));
        assert!(messages[2].tool_call_id.is_none());
        assert!(messages[2].name.is_none());
    }

    #[test]
    fn test_sanitize_uses_unknown_for_missing_tool_name() {
        let mut nameless = ChatMessage::tool_result("call_x", "ignored", "data");
        nameless.name = None;
        let mut messages = vec![nameless];
        sanitize_tool_messages(&mut messages);
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[0].content, "[Tool `unknown` returned: data]");
    }

    #[test]
    fn test_strip_unsupported_tool_params_strips_stop_sequences() {
        let mut unsupported = HashSet::new();
        unsupported.insert(UnsupportedParam::StopSequences.name().to_string());
        let mut req = ToolCompletionRequest::new(vec![ChatMessage::user("hello")], vec![])
            .with_temperature(0.5)
            .with_max_tokens(16);
        req.stop_sequences = Some(vec!["STOP".to_string()]);
        strip_unsupported_tool_params(&unsupported, &mut req);
        assert!(req.stop_sequences.is_none());
        // Parameters that were not listed must survive.
        assert_eq!(req.temperature, Some(0.5));
        assert_eq!(req.max_tokens, Some(16));
    }

    #[test]
    fn test_strip_unsupported_completion_params_strips_only_listed() {
        let mut unsupported = HashSet::new();
        unsupported.insert(UnsupportedParam::Temperature.name().to_string());
        unsupported.insert(UnsupportedParam::MaxTokens.name().to_string());
        let mut req = CompletionRequest::new(vec![ChatMessage::user("hello")])
            .with_temperature(0.5)
            .with_max_tokens(16);
        req.stop_sequences = Some(vec!["STOP".to_string()]);
        strip_unsupported_completion_params(&unsupported, &mut req);
        assert!(req.temperature.is_none());
        assert!(req.max_tokens.is_none());
        assert_eq!(req.stop_sequences, Some(vec!["STOP".to_string()]));
    }

    #[test]
    fn test_strip_with_empty_set_is_noop() {
        let unsupported = HashSet::new();
        let mut req = CompletionRequest::new(vec![ChatMessage::user("hello")])
            .with_temperature(0.5)
            .with_max_tokens(16);
        strip_unsupported_completion_params(&unsupported, &mut req);
        assert_eq!(req.temperature, Some(0.5));
        assert_eq!(req.max_tokens, Some(16));
    }
}