#![cfg(feature = "rig")]
use std::sync::Arc;
use crate::provider::Provider;
use crate::types::{
self as llmg_types, ChatCompletionRequest, ChatCompletionResponse, FunctionDefinition, Tool,
};
use rig::completion::{
AssistantContent, CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
GetTokenUsage, Usage,
};
use rig::message::{Message as RigMessage, ToolResultContent, UserContent};
use rig::streaming::StreamingCompletionResponse;
use rig::OneOrMany;
use serde::{Deserialize, Serialize};
use tracing::warn;
/// Stand-in streaming payload for the not-yet-implemented streaming path.
///
/// [`LlmgCompletionModel::stream`] always returns an error, so no value of
/// this type is ever produced at runtime; it exists only to satisfy the
/// `CompletionModel::StreamingResponse` associated-type requirement.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PlaceholderStreamingResponse;

impl GetTokenUsage for PlaceholderStreamingResponse {
    /// There is never a streaming response, hence never any usage to report.
    fn token_usage(&self) -> Option<Usage> {
        None
    }
}
/// Rig-compatible client that forwards completions to an LLMG [`Provider`].
#[derive(Clone)]
pub struct LlmgClient {
    /// Shared provider handle that performs the actual chat completions.
    pub provider: Arc<dyn Provider>,
}

impl LlmgClient {
    /// Wraps an existing provider handle.
    pub fn new(provider: Arc<dyn Provider>) -> Self {
        Self { provider }
    }

    /// Builds a client whose provider routes over the given registry.
    pub fn from_registry(registry: crate::provider::ProviderRegistry) -> Self {
        Self {
            provider: Arc::new(crate::provider::RoutingProvider::new(registry)),
        }
    }
}
impl rig::client::CompletionClient for LlmgClient {
type CompletionModel = LlmgCompletionModel;
fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
LlmgCompletionModel::make(self, model)
}
}
/// One named model exposed through the LLMG-Rig bridge.
#[derive(Clone)]
pub struct LlmgCompletionModel {
    /// Client used to reach the underlying provider.
    client: LlmgClient,
    /// Model identifier copied into every outgoing request.
    model: String,
}
impl CompletionModel for LlmgCompletionModel {
    type Response = ChatCompletionResponse;
    type StreamingResponse = PlaceholderStreamingResponse;
    type Client = LlmgClient;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            client: client.clone(),
        }
    }

    /// Translates the rig request into LLMG form, dispatches it through the
    /// provider, and maps the provider's answer back into rig's response type.
    async fn completion(
        &self,
        request: CompletionRequest,
    ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
        let llmg_request = build_llmg_request(&self.model, &request);
        match self.client.provider.chat_completion(llmg_request).await {
            Ok(response) => build_rig_response(response),
            Err(e) => Err(CompletionError::ProviderError(e.to_string())),
        }
    }

    /// Streaming is not implemented by this bridge yet; always errors.
    async fn stream(
        &self,
        _request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        Err(CompletionError::ProviderError(
            "streaming not supported yet in LLMG-Rig bridge".to_string(),
        ))
    }
}
fn build_llmg_request(model: &str, request: &CompletionRequest) -> ChatCompletionRequest {
let mut messages: Vec<llmg_types::Message> = Vec::new();
if let Some(ref preamble) = request.preamble {
messages.push(llmg_types::Message::System {
content: preamble.clone(),
name: None,
});
}
for msg in request.chat_history.clone().into_iter() {
convert_rig_message(msg, &mut messages);
}
let tools = if request.tools.is_empty() {
None
} else {
Some(request.tools.iter().map(convert_tool_definition).collect())
};
ChatCompletionRequest {
model: model.to_string(),
messages,
temperature: request.temperature.map(|t| t as f32),
max_tokens: request.max_tokens.map(|t| t as u32),
stream: Some(false),
top_p: None,
frequency_penalty: None,
presence_penalty: None,
stop: None,
user: None,
tools,
tool_choice: None,
}
}
/// Appends the LLMG equivalent(s) of a single rig message to `out`.
///
/// User messages fan out into one `User` message per text part and one `Tool`
/// message per tool result (keeping only the textual parts of multi-part
/// results, joined by newlines). Assistant messages collapse into a single
/// `Assistant` message carrying the concatenated text plus any tool calls.
///
/// Fix over previous revision: content variants the LLMG schema cannot
/// represent (e.g. images, audio, documents) used to be dropped silently;
/// they are now dropped with a warning so the data loss is observable,
/// matching how other lossy paths in this module already use `warn!`.
fn convert_rig_message(msg: RigMessage, out: &mut Vec<llmg_types::Message>) {
    match msg {
        RigMessage::User { content } => {
            for item in content.into_iter() {
                match item {
                    UserContent::Text(t) => {
                        out.push(llmg_types::Message::User {
                            content: t.text,
                            name: None,
                        });
                    }
                    UserContent::ToolResult(tr) => {
                        // Keep only textual tool output; join multi-part results.
                        let text = tr
                            .content
                            .into_iter()
                            .filter_map(|c| match c {
                                ToolResultContent::Text(t) => Some(t.text),
                                _ => None,
                            })
                            .collect::<Vec<_>>()
                            .join("\n");
                        out.push(llmg_types::Message::Tool {
                            content: text,
                            tool_call_id: tr.id,
                        });
                    }
                    _ => {
                        // Unrepresentable in the LLMG schema; surface the loss.
                        warn!("dropping unsupported user content variant in LLMG-Rig bridge");
                    }
                }
            }
        }
        RigMessage::Assistant { content, .. } => {
            let mut text_parts: Vec<String> = Vec::new();
            let mut tool_calls: Vec<llmg_types::ToolCall> = Vec::new();
            for item in content.into_iter() {
                match item {
                    AssistantContent::Text(t) => {
                        text_parts.push(t.text);
                    }
                    AssistantContent::ToolCall(tc) => {
                        // LLMG transports arguments as a JSON string; rig holds a Value.
                        let arguments = serde_json::to_string(&tc.function.arguments)
                            .unwrap_or_else(|e| {
                                warn!("failed to serialize tool call arguments: {e}");
                                "{}".to_string()
                            });
                        tool_calls.push(llmg_types::ToolCall {
                            id: tc.id,
                            r#type: "function".to_string(),
                            function: llmg_types::FunctionCall {
                                name: tc.function.name,
                                arguments,
                            },
                        });
                    }
                    _ => {
                        // Unrepresentable in the LLMG schema; surface the loss.
                        warn!("dropping unsupported assistant content variant in LLMG-Rig bridge");
                    }
                }
            }
            // None (rather than Some("")) when absent, so serialization omits
            // the empty field; text parts concatenate without a separator.
            let content_str = if text_parts.is_empty() {
                None
            } else {
                Some(text_parts.join(""))
            };
            let tc = if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            };
            out.push(llmg_types::Message::Assistant {
                content: content_str,
                refusal: None,
                tool_calls: tc,
            });
        }
    }
}
/// Maps a rig tool definition onto LLMG's OpenAI-style [`Tool`] record.
fn convert_tool_definition(td: &rig::completion::ToolDefinition) -> Tool {
    let function = FunctionDefinition {
        name: td.name.clone(),
        description: Some(td.description.clone()),
        parameters: td.parameters.clone(),
    };
    Tool {
        r#type: "function".to_string(),
        function,
    }
}
/// Converts an LLMG [`ChatCompletionResponse`] into rig's response type.
///
/// Only the first choice is inspected, and it must hold an assistant message;
/// anything else is a `ResponseError`. Text content and tool calls become rig
/// `AssistantContent` items. An all-empty message degrades to a single empty
/// text item because rig's `OneOrMany` cannot be empty. Token usage is copied
/// when the provider reported it, otherwise defaulted to zeros.
fn build_rig_response(
    response: ChatCompletionResponse,
) -> Result<CompletionResponse<ChatCompletionResponse>, CompletionError> {
    let choice = response
        .choices
        .first()
        .ok_or_else(|| CompletionError::ResponseError("no choices in response".to_string()))?;

    // Pull the assistant parts out first; any other message kind is invalid.
    let (content, tool_calls) = match &choice.message {
        llmg_types::Message::Assistant {
            content,
            tool_calls,
            ..
        } => (content, tool_calls),
        _ => {
            return Err(CompletionError::ResponseError(
                "expected assistant message in response".to_string(),
            ));
        }
    };

    let mut items: Vec<AssistantContent> = Vec::new();
    if let Some(text) = content {
        if !text.is_empty() {
            items.push(AssistantContent::text(text));
        }
    }
    for tc in tool_calls.iter().flatten() {
        // LLMG transports arguments as a JSON string; rig wants a Value.
        let args: serde_json::Value =
            serde_json::from_str(&tc.function.arguments).unwrap_or_else(|e| {
                warn!("failed to parse tool call arguments: {e}");
                serde_json::Value::Object(serde_json::Map::new())
            });
        items.push(AssistantContent::tool_call(&tc.id, &tc.function.name, args));
    }

    let assistant_content = if items.is_empty() {
        // rig's OneOrMany cannot be empty; fall back to an empty text item.
        OneOrMany::one(AssistantContent::text(""))
    } else {
        OneOrMany::many(items).map_err(|e| CompletionError::ResponseError(e.to_string()))?
    };

    let usage = response
        .usage
        .as_ref()
        .map_or_else(Usage::default, |u| Usage {
            input_tokens: u.prompt_tokens as u64,
            output_tokens: u.completion_tokens as u64,
            total_tokens: u.total_tokens as u64,
            cached_input_tokens: 0,
        });

    Ok(CompletionResponse {
        choice: assistant_content,
        usage,
        raw_response: response,
    })
}