use async_trait::async_trait;
use rig::OneOrMany;
use rig::completion::{
AssistantContent, CompletionModel, CompletionRequest as RigRequest,
ToolDefinition as RigToolDefinition, Usage as RigUsage,
};
use rig::message::{
Message as RigMessage, ToolChoice as RigToolChoice, ToolFunction, ToolResult as RigToolResult,
ToolResultContent, UserContent,
};
use rust_decimal::Decimal;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::error::LlmError;
use crate::llm::costs;
use crate::llm::provider::{
ChatMessage, CompletionRequest, CompletionResponse, FinishReason, LlmProvider,
ToolCall as IronToolCall, ToolCompletionRequest, ToolCompletionResponse,
ToolDefinition as IronToolDefinition,
};
/// Adapter that exposes any rig-core [`CompletionModel`] through the
/// project's [`LlmProvider`] trait.
pub struct RigAdapter<M: CompletionModel> {
    /// Underlying rig completion model that executes the actual requests.
    model: M,
    /// Model identifier; used for cost-table lookup and in error messages.
    model_name: String,
    /// Per-input-token price resolved from the cost table (see `cost_per_token`).
    input_cost: Decimal,
    /// Per-output-token price resolved from the cost table.
    output_cost: Decimal,
}
impl<M: CompletionModel> RigAdapter<M> {
    /// Wraps `model` under `model_name`, resolving per-token pricing from
    /// the cost table and falling back to the default pricing when the
    /// model name is unknown.
    pub fn new(model: M, model_name: impl Into<String>) -> Self {
        let model_name = model_name.into();
        let (input_cost, output_cost) = match costs::model_cost(&model_name) {
            Some(pricing) => pricing,
            None => costs::default_cost(),
        };
        Self {
            model,
            model_name,
            input_cost,
            output_cost,
        }
    }
}
/// Splits Iron chat messages into a rig preamble (system prompt) plus chat history.
///
/// - System messages are concatenated (newline-separated) into the preamble
///   instead of the history, since rig carries the system prompt separately.
/// - Assistant messages with tool calls become multi-part assistant content;
///   if the parts cannot form a `OneOrMany` (i.e. the set is empty), we fall
///   back to a plain text assistant message.
/// - Tool results are wrapped as user-side `ToolResult` content, which is how
///   rig represents tool output in a conversation.
fn convert_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<RigMessage>) {
    let mut preamble: Option<String> = None;
    let mut history = Vec::new();
    for msg in messages {
        match msg.role {
            crate::llm::Role::System => {
                // Multiple system messages are merged into one newline-joined preamble.
                match preamble {
                    Some(ref mut p) => {
                        p.push('\n');
                        p.push_str(&msg.content);
                    }
                    None => preamble = Some(msg.content.clone()),
                }
            }
            crate::llm::Role::User => {
                history.push(RigMessage::user(&msg.content));
            }
            crate::llm::Role::Assistant => {
                if let Some(ref tool_calls) = msg.tool_calls {
                    // Build mixed content: optional leading text, then one
                    // ToolCall entry per recorded call.
                    let mut contents: Vec<AssistantContent> = Vec::new();
                    if !msg.content.is_empty() {
                        contents.push(AssistantContent::text(&msg.content));
                    }
                    for tc in tool_calls {
                        contents.push(AssistantContent::ToolCall(rig::message::ToolCall::new(
                            tc.id.clone(),
                            ToolFunction::new(tc.name.clone(), tc.arguments.clone()),
                        )));
                    }
                    // `OneOrMany::many` fails on an empty vec (empty content +
                    // empty tool_calls); fall back to a plain text message then.
                    if let Ok(many) = OneOrMany::many(contents) {
                        history.push(RigMessage::Assistant {
                            id: None,
                            content: many,
                        });
                    } else {
                        history.push(RigMessage::assistant(&msg.content));
                    }
                } else {
                    history.push(RigMessage::assistant(&msg.content));
                }
            }
            crate::llm::Role::Tool => {
                // Missing tool_call_id degrades to an empty id rather than erroring.
                let tool_id = msg.tool_call_id.clone().unwrap_or_default();
                // NOTE(review): the call id is stored in `id` with `call_id`
                // left as None — confirm this matches what the target
                // provider expects for tool-result correlation.
                history.push(RigMessage::User {
                    content: OneOrMany::one(UserContent::ToolResult(RigToolResult {
                        id: tool_id,
                        call_id: None,
                        content: OneOrMany::one(ToolResultContent::text(&msg.content)),
                    })),
                });
            }
        }
    }
    (preamble, history)
}
/// Converts Iron tool definitions into rig tool definitions (a straight
/// field-for-field clone of name, description, and JSON-schema parameters).
fn convert_tools(tools: &[IronToolDefinition]) -> Vec<RigToolDefinition> {
    let mut converted = Vec::with_capacity(tools.len());
    for tool in tools {
        converted.push(RigToolDefinition {
            name: tool.name.clone(),
            description: tool.description.clone(),
            parameters: tool.parameters.clone(),
        });
    }
    converted
}
/// Maps a textual tool-choice hint ("auto" / "required" / "none", any case)
/// onto rig's `ToolChoice`; anything else — including `None` — yields `None`.
fn convert_tool_choice(choice: Option<&str>) -> Option<RigToolChoice> {
    let normalized = choice?.to_lowercase();
    match normalized.as_str() {
        "auto" => Some(RigToolChoice::Auto),
        "required" => Some(RigToolChoice::Required),
        "none" => Some(RigToolChoice::None),
        _ => None,
    }
}
/// Flattens rig assistant content into `(text, tool_calls, finish_reason)`.
///
/// Text fragments are concatenated in order; an all-empty result maps to
/// `None`. The finish reason is `ToolUse` when at least one tool call is
/// present, otherwise `Stop`. `_usage` is accepted for signature parity with
/// the callers but is not consulted here.
fn extract_response(
    choice: &OneOrMany<AssistantContent>,
    _usage: &RigUsage,
) -> (Option<String>, Vec<IronToolCall>, FinishReason) {
    let mut text = String::new();
    let mut tool_calls: Vec<IronToolCall> = Vec::new();
    for part in choice.iter() {
        match part {
            AssistantContent::Text(fragment) => text.push_str(&fragment.text),
            AssistantContent::ToolCall(call) => tool_calls.push(IronToolCall {
                id: call.id.clone(),
                name: call.function.name.clone(),
                arguments: call.function.arguments.clone(),
            }),
            // Other content kinds carry nothing we surface here.
            _ => {}
        }
    }
    let finish = if tool_calls.is_empty() {
        FinishReason::Stop
    } else {
        FinishReason::ToolUse
    };
    let text = if text.is_empty() { None } else { Some(text) };
    (text, tool_calls, finish)
}
/// Clamps a `u64` token count into `u32` range, saturating at `u32::MAX`.
fn saturate_u32(val: u64) -> u32 {
    u32::try_from(val).unwrap_or(u32::MAX)
}
/// Assembles a rig `CompletionRequest` from already-converted parts.
///
/// rig's chat history must be non-empty, so an empty history is seeded with a
/// placeholder "Hello" user message. Fails with `LlmError::RequestFailed` if
/// the history still cannot form a `OneOrMany`.
fn build_rig_request(
    preamble: Option<String>,
    mut history: Vec<RigMessage>,
    tools: Vec<RigToolDefinition>,
    tool_choice: Option<RigToolChoice>,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
) -> Result<RigRequest, LlmError> {
    if history.is_empty() {
        history.push(RigMessage::user("Hello"));
    }
    let chat_history = match OneOrMany::many(history) {
        Ok(many) => many,
        Err(e) => {
            return Err(LlmError::RequestFailed {
                provider: "rig".to_string(),
                reason: format!("Failed to build chat history: {}", e),
            });
        }
    };
    Ok(RigRequest {
        preamble,
        chat_history,
        documents: Vec::new(),
        tools,
        // rig expects f64 / u64; widen from the provider-facing f32 / u32.
        temperature: temperature.map(|t| t as f64),
        max_tokens: max_tokens.map(|t| t as u64),
        tool_choice,
        additional_params: None,
    })
}
#[async_trait]
impl<M> LlmProvider for RigAdapter<M>
where
    M: CompletionModel + Send + Sync + 'static,
    M::Response: Send + Sync + Serialize + DeserializeOwned,
{
    /// Configured model identifier.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Per-token pricing as `(input_cost, output_cost)`.
    fn cost_per_token(&self) -> (Decimal, Decimal) {
        (self.input_cost, self.output_cost)
    }

    /// Runs a plain (tool-free) completion. Any tool calls the model emits
    /// anyway are discarded; only the concatenated text content is returned.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
        let (preamble, history) = convert_messages(&request.messages);
        let rig_request = build_rig_request(
            preamble,
            history,
            Vec::new(),
            None,
            request.temperature,
            request.max_tokens,
        )?;
        let outcome = match self.model.completion(rig_request).await {
            Ok(resp) => resp,
            Err(e) => {
                return Err(LlmError::RequestFailed {
                    provider: self.model_name.clone(),
                    reason: e.to_string(),
                });
            }
        };
        let (text, _dropped_tool_calls, finish) = extract_response(&outcome.choice, &outcome.usage);
        Ok(CompletionResponse {
            content: text.unwrap_or_default(),
            input_tokens: saturate_u32(outcome.usage.input_tokens),
            output_tokens: saturate_u32(outcome.usage.output_tokens),
            finish_reason: finish,
            response_id: None,
        })
    }

    /// Runs a completion with tool definitions and an optional textual
    /// tool-choice hint, surfacing any tool calls the model produced.
    async fn complete_with_tools(
        &self,
        request: ToolCompletionRequest,
    ) -> Result<ToolCompletionResponse, LlmError> {
        let (preamble, history) = convert_messages(&request.messages);
        let rig_request = build_rig_request(
            preamble,
            history,
            convert_tools(&request.tools),
            convert_tool_choice(request.tool_choice.as_deref()),
            request.temperature,
            request.max_tokens,
        )?;
        let outcome = match self.model.completion(rig_request).await {
            Ok(resp) => resp,
            Err(e) => {
                return Err(LlmError::RequestFailed {
                    provider: self.model_name.clone(),
                    reason: e.to_string(),
                });
            }
        };
        let (text, tool_calls, finish) = extract_response(&outcome.choice, &outcome.usage);
        Ok(ToolCompletionResponse {
            content: text,
            tool_calls,
            input_tokens: saturate_u32(outcome.usage.input_tokens),
            output_tokens: saturate_u32(outcome.usage.output_tokens),
            finish_reason: finish,
            response_id: None,
        })
    }

    /// Currently active model identifier (rig adapters are fixed to one model).
    fn active_model_name(&self) -> String {
        self.model_name.clone()
    }

    /// Always fails: a rig-core adapter is constructed around a single model
    /// and cannot be re-pointed at runtime.
    fn set_model(&self, _model: &str) -> Result<(), LlmError> {
        Err(LlmError::RequestFailed {
            provider: self.model_name.clone(),
            reason: "Runtime model switching not supported for rig-core providers. Restart with a different model configured.".to_string(),
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the message/tool conversion helpers and the
    //! response-extraction logic. These pin the adapter's mapping between
    //! Iron types and rig types.
    use super::*;

    // A single system message becomes the preamble, not part of the history.
    #[test]
    fn test_convert_messages_system_to_preamble() {
        let messages = vec![
            ChatMessage::system("You are a helpful assistant."),
            ChatMessage::user("Hello"),
        ];
        let (preamble, history) = convert_messages(&messages);
        assert_eq!(preamble, Some("You are a helpful assistant.".to_string()));
        assert_eq!(history.len(), 1);
    }

    // Multiple system messages are newline-joined into one preamble.
    #[test]
    fn test_convert_messages_multiple_systems_concatenated() {
        let messages = vec![
            ChatMessage::system("System 1"),
            ChatMessage::system("System 2"),
            ChatMessage::user("Hi"),
        ];
        let (preamble, history) = convert_messages(&messages);
        assert_eq!(preamble, Some("System 1\nSystem 2".to_string()));
        assert_eq!(history.len(), 1);
    }

    // Tool results are represented as a rig *User* message (ToolResult content).
    #[test]
    fn test_convert_messages_tool_result() {
        let messages = vec![ChatMessage::tool_result(
            "call_123",
            "search",
            "result text",
        )];
        let (preamble, history) = convert_messages(&messages);
        assert!(preamble.is_none());
        assert_eq!(history.len(), 1);
        match &history[0] {
            RigMessage::User { .. } => {}
            other => panic!("Expected User message, got: {:?}", other),
        }
    }

    // Assistant text plus tool calls yields multi-part assistant content
    // (at least text + one tool call).
    #[test]
    fn test_convert_messages_assistant_with_tool_calls() {
        let tc = IronToolCall {
            id: "call_1".to_string(),
            name: "search".to_string(),
            arguments: serde_json::json!({"query": "test"}),
        };
        let msg = ChatMessage::assistant_with_tool_calls(Some("thinking".to_string()), vec![tc]);
        let messages = vec![msg];
        let (_preamble, history) = convert_messages(&messages);
        assert_eq!(history.len(), 1);
        match &history[0] {
            RigMessage::Assistant { content, .. } => {
                assert!(content.iter().count() >= 2);
            }
            other => panic!("Expected Assistant message, got: {:?}", other),
        }
    }

    // Tool definitions are mapped field-for-field.
    #[test]
    fn test_convert_tools() {
        let tools = vec![IronToolDefinition {
            name: "search".to_string(),
            description: "Search the web".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                }
            }),
        }];
        let rig_tools = convert_tools(&tools);
        assert_eq!(rig_tools.len(), 1);
        assert_eq!(rig_tools[0].name, "search");
        assert_eq!(rig_tools[0].description, "Search the web");
    }

    // Tool-choice strings map case-insensitively; unknown/absent map to None.
    #[test]
    fn test_convert_tool_choice() {
        assert!(matches!(
            convert_tool_choice(Some("auto")),
            Some(RigToolChoice::Auto)
        ));
        assert!(matches!(
            convert_tool_choice(Some("required")),
            Some(RigToolChoice::Required)
        ));
        assert!(matches!(
            convert_tool_choice(Some("none")),
            Some(RigToolChoice::None)
        ));
        assert!(matches!(
            convert_tool_choice(Some("AUTO")),
            Some(RigToolChoice::Auto)
        ));
        assert!(convert_tool_choice(None).is_none());
        assert!(convert_tool_choice(Some("unknown")).is_none());
    }

    // Pure text response: text surfaced, no calls, finish reason Stop.
    #[test]
    fn test_extract_response_text_only() {
        let content = OneOrMany::one(AssistantContent::text("Hello world"));
        let usage = RigUsage::new();
        let (text, calls, finish) = extract_response(&content, &usage);
        assert_eq!(text, Some("Hello world".to_string()));
        assert!(calls.is_empty());
        assert_eq!(finish, FinishReason::Stop);
    }

    // Tool-call-only response: no text, one call, finish reason ToolUse.
    #[test]
    fn test_extract_response_tool_call() {
        let tc = AssistantContent::tool_call("call_1", "search", serde_json::json!({"q": "test"}));
        let content = OneOrMany::one(tc);
        let usage = RigUsage::new();
        let (text, calls, finish) = extract_response(&content, &usage);
        assert!(text.is_none());
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search");
        assert_eq!(finish, FinishReason::ToolUse);
    }

    // Token counts clamp at u32::MAX rather than truncating.
    #[test]
    fn test_saturate_u32() {
        assert_eq!(saturate_u32(100), 100);
        assert_eq!(saturate_u32(u64::MAX), u32::MAX);
        assert_eq!(saturate_u32(u32::MAX as u64), u32::MAX);
    }
}