use anyhow::{Context, Result};
use async_openai::config::OpenAIConfig;
use async_openai::types::chat::{
ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolChoiceOption,
ChatCompletionTools, CreateChatCompletionRequestArgs, ToolChoiceOptions,
};
use async_openai::types::chat::{FunctionCall, FunctionObject};
use async_openai::Client;
use serde_json::Value;
/// Thin wrapper around an `async_openai` client that is bound to a single model.
///
/// The API base URL is supplied at construction, so any OpenAI-compatible endpoint can be used.
#[derive(Clone)]
pub struct OpenAIClient {
    /// Underlying API client carrying the key and base URL.
    client: Client<OpenAIConfig>,
    /// Model name sent with every request.
    model: String,
}

impl OpenAIClient {
    /// Creates a client that talks to `base_url` with `api_key`, using `model` for every request.
    ///
    /// Trailing `/` characters are stripped from `base_url` before it is stored in the config.
    pub fn new_with_base_url(api_key: String, model: String, base_url: String) -> Self {
        let base_url = base_url.trim_end_matches('/').to_string();
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        let client = Client::with_config(config);
        Self { client, model }
    }

    /// Sends `messages` as a chat completion request and returns the first choice's text.
    ///
    /// Roles are converted exactly as in [`Self::chat_with_tools`]. Assistant and tool turns in
    /// a multi-turn history therefore keep their roles instead of being sent as user input.
    /// An empty string is returned when the response has no choices or the first choice carries
    /// no text content.
    ///
    /// # Errors
    ///
    /// Returns an error if a message or the request cannot be built, or if the API call fails.
    pub async fn chat(&self, messages: Vec<Message>) -> Result<String> {
        // Share the full role mapping so assistant/tool turns are not downgraded to user turns.
        let messages = self
            .build_request_messages(messages)
            .context("构建消息失败")?;
        let request = CreateChatCompletionRequestArgs::default()
            .model(&self.model)
            .messages(messages)
            .build()
            .context("构建请求失败")?;
        let response = self
            .client
            .chat()
            .create(request)
            .await
            .context("调用 LLM API 失败")?;
        let content = response
            .choices
            .into_iter()
            .next()
            .and_then(|c| c.message.content)
            .unwrap_or_default();
        Ok(content)
    }

    /// Convenience wrapper that sends `prompt` as a single user message via [`Self::chat`].
    pub async fn complete(&self, prompt: &str) -> Result<String> {
        let messages = vec![Message::user(prompt)];
        self.chat(messages).await
    }

    /// Sends `messages` together with `tools` and lets the model decide whether to call any.
    ///
    /// Returns [`ToolResponse::ToolCalls`] when the first choice contains at least one function
    /// tool call. Otherwise returns [`ToolResponse::Message`] with the (possibly empty) text
    /// content. When `tools` is empty, the request is sent without `tools` and `tool_choice`.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be built, the API call fails, or the response
    /// contains no choices.
    pub async fn chat_with_tools(
        &self,
        messages: Vec<Message>,
        tools: Vec<Tool>,
    ) -> Result<ToolResponse> {
        let request_messages = self.build_request_messages(messages)?;
        let chat_tools = self.build_chat_tools(tools)?;

        let mut builder = CreateChatCompletionRequestArgs::default();
        builder.model(&self.model).messages(request_messages);
        // OpenAI rejects an empty `tools` array, and `tool_choice` is only accepted alongside
        // tools, so both are omitted when there is nothing to offer the model.
        if !chat_tools.is_empty() {
            builder
                .tools(chat_tools)
                .tool_choice(ChatCompletionToolChoiceOption::Mode(ToolChoiceOptions::Auto));
        }
        let request = builder.build().context("构建请求失败")?;

        let response = self
            .client
            .chat()
            .create(request)
            .await
            .context("调用 LLM API 失败")?;
        let choice = response
            .choices
            .into_iter()
            .next()
            .context("API 返回空响应")?;

        let message = choice.message;
        let content = message.content.unwrap_or_default();
        let tool_calls: Vec<ToolCall> = message
            .tool_calls
            .into_iter()
            .flatten()
            .filter_map(|call| match call {
                ChatCompletionMessageToolCalls::Function(func_call) => {
                    // Argument strings that are not valid JSON degrade to `Value::Null` rather
                    // than failing the whole response.
                    let arguments = serde_json::from_str(&func_call.function.arguments)
                        .unwrap_or(Value::Null);
                    Some(ToolCall {
                        id: func_call.id,
                        name: func_call.function.name,
                        arguments,
                    })
                }
                // Only function tool calls are supported; other kinds are dropped.
                _ => None,
            })
            .collect();

        if tool_calls.is_empty() {
            Ok(ToolResponse::Message(content))
        } else {
            Ok(ToolResponse::ToolCalls {
                content,
                tool_calls,
            })
        }
    }

    /// Converts every [`Message`] into its `async_openai` request form, stopping at the first
    /// failure.
    fn build_request_messages(
        &self,
        messages: Vec<Message>,
    ) -> Result<Vec<ChatCompletionRequestMessage>> {
        messages
            .into_iter()
            .map(Self::build_request_message)
            .collect()
    }

    /// Converts one [`Message`] according to its `role` string.
    ///
    /// "system", "assistant" and "tool" map to their dedicated request types. "user" and any
    /// unrecognised role are sent as a user message.
    fn build_request_message(msg: Message) -> Result<ChatCompletionRequestMessage> {
        match msg.role.as_str() {
            "system" => ChatCompletionRequestSystemMessageArgs::default()
                .content(msg.content)
                .build()
                .map(ChatCompletionRequestMessage::System)
                .context("构建系统消息失败"),
            "assistant" => {
                let mut args = ChatCompletionRequestAssistantMessageArgs::default();
                args.content(msg.content);
                if let Some(calls) = msg.tool_calls {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = calls
                        .into_iter()
                        .map(|c| {
                            ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: c.id,
                                    function: FunctionCall {
                                        name: c.name,
                                        // The request carries arguments as a JSON-encoded string.
                                        arguments: c.arguments.to_string(),
                                    },
                                },
                            )
                        })
                        .collect();
                    args.tool_calls(tool_calls);
                }
                args.build()
                    .map(ChatCompletionRequestMessage::Assistant)
                    .context("构建助手消息失败")
            }
            // NOTE(review): a missing `tool_call_id` is sent as "", which the API will likely
            // reject; callers should always use `Message::tool`.
            "tool" => ChatCompletionRequestToolMessageArgs::default()
                .content(msg.content)
                .tool_call_id(msg.tool_call_id.unwrap_or_default())
                .build()
                .map(ChatCompletionRequestMessage::Tool)
                .context("构建工具消息失败"),
            _ => ChatCompletionRequestUserMessageArgs::default()
                .content(msg.content)
                .build()
                .map(ChatCompletionRequestMessage::User)
                .context("构建用户消息失败"),
        }
    }

    /// Converts [`Tool`] definitions into function tools for the request.
    ///
    /// The tool's `id` (not its display `name`) is sent as the function name. Tool calls
    /// returned by the model therefore carry the `id` in their `name` field.
    fn build_chat_tools(&self, tools: Vec<Tool>) -> Result<Vec<ChatCompletionTools>> {
        Ok(tools
            .into_iter()
            .map(|tool| {
                ChatCompletionTools::Function(ChatCompletionTool {
                    function: FunctionObject {
                        name: tool.id,
                        description: Some(tool.description),
                        parameters: Some(tool.parameters),
                        strict: None,
                    },
                })
            })
            .collect())
    }
}
/// Result of a tool-enabled chat turn: either plain text or a request to run tools.
#[derive(Clone, Debug)]
pub enum ToolResponse {
    /// The model answered with text only.
    Message(String),
    /// The model asked for one or more tool invocations, with optional accompanying text.
    ToolCalls {
        content: String,
        tool_calls: Vec<ToolCall>,
    },
}

impl ToolResponse {
    /// Returns `true` when this response carries tool calls.
    pub fn is_tool_calls(&self) -> bool {
        self.tool_calls().is_some()
    }

    /// Returns the text part of the response (may be empty).
    pub fn content(&self) -> &str {
        match self {
            Self::Message(text) | Self::ToolCalls { content: text, .. } => text,
        }
    }

    /// Returns the requested tool calls, or `None` for a plain text response.
    pub fn tool_calls(&self) -> Option<&Vec<ToolCall>> {
        if let Self::ToolCalls { tool_calls, .. } = self {
            Some(tool_calls)
        } else {
            None
        }
    }
}
/// A single chat message in a provider-neutral shape.
///
/// The constructors set `role` to one of `"system"`, `"user"`, `"assistant"` or `"tool"`.
#[derive(Clone, Debug)]
pub struct Message {
    /// Sender role, e.g. `"user"`.
    pub role: String,
    /// Text body of the message.
    pub content: String,
    /// Tool calls attached to an assistant turn, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For tool-result messages, the id of the call being answered.
    pub tool_call_id: Option<String>,
}

impl Message {
    /// Builds a message with `role` and `content` and no tool metadata.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_owned(),
            content: content.into(),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    /// Creates a system-prompt message.
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role("system", content)
    }

    /// Creates a user message.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Creates a plain assistant message without tool calls.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }

    /// Creates an assistant message that carries the tool calls it requested.
    pub fn assistant_with_tools(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
        let base = Self::with_role("assistant", content);
        Self {
            tool_calls: Some(tool_calls),
            ..base
        }
    }

    /// Creates a tool-result message answering the call identified by `tool_call_id`.
    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
        let base = Self::with_role("tool", content);
        Self {
            tool_call_id: Some(tool_call_id.into()),
            ..base
        }
    }
}
/// A tool definition that can be offered to the model.
#[derive(Clone, Debug)]
pub struct Tool {
    /// Identifier of the tool; sent to the model as the function name.
    pub id: String,
    /// Display name of the tool.
    pub name: String,
    /// Description shown to the model.
    pub description: String,
    /// Parameter schema passed through as the function's `parameters`.
    pub parameters: Value,
}

impl Tool {
    /// Builds a [`Tool`] by copying the matching fields of a domain-level tool.
    pub fn from_domain_tool(tool: &crate::domain::tool::Tool) -> Self {
        let source = tool;
        let id = source.id.clone();
        let name = source.name.clone();
        let description = source.description.clone();
        let parameters = source.parameters.clone();
        Self {
            id,
            name,
            description,
            parameters,
        }
    }
}
/// A single tool invocation requested by the model.
#[derive(Clone, Debug)]
pub struct ToolCall {
    /// Provider-assigned call id; tool results must reference it.
    pub id: String,
    /// Function name reported by the model (the tool's `id` when produced by `OpenAIClient`).
    pub name: String,
    /// Parsed JSON arguments for the call.
    pub arguments: Value,
}