use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::marker::PhantomData;
use rucora_core::provider::LlmProvider;
use rucora_core::provider::types::{ChatMessage, ChatRequest, LlmParams};
use rucora_core::tool::Tool;
use crate::agent::ToolAgent;
/// Result of an extraction that also reports token usage.
///
/// Produced by `Extractor::extract_with_usage`.
#[derive(Debug, Clone)]
pub struct ExtractionResponse<T> {
    /// The value deserialized from the model's `submit` tool-call arguments.
    pub data: T,
    /// Token usage of the attempt that succeeded. The extractor in this file always fills
    /// this with `Some`, using zeroed counts when the provider reports no usage.
    pub usage: Option<TokenUsage>,
}
/// Token accounting for a single LLM call, as reported by the provider.
///
/// `Default` yields all-zero counts. This is also what the extractor reports when the
/// provider returns no usage information.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TokenUsage {
    /// Tokens in the request (mapped from the provider's `prompt_tokens`).
    pub input_tokens: u32,
    /// Tokens in the response (mapped from the provider's `completion_tokens`).
    pub output_tokens: u32,
    /// Total tokens as reported by the provider (`total_tokens`).
    pub total_tokens: u32,
}
/// Errors returned by the extractor.
#[derive(Debug, thiserror::Error)]
pub enum ExtractionError {
    /// The model response contained no `submit` tool call.
    #[error("未提取到数据")]
    NoData,
    /// The `submit` tool-call arguments could not be deserialized into the target type.
    #[error("反序列化提取的数据失败:{0}")]
    DeserializationError(#[from] serde_json::Error),
    /// The provider call failed; carries the provider error rendered as a string.
    #[error("LLM 调用失败:{0}")]
    LlmError(String),
    /// Every attempt failed and at least one retry was configured. The individual
    /// per-attempt errors are only logged, not carried in this variant.
    #[error("达到最大重试次数")]
    MaxRetriesExceeded,
}
/// Extracts a value of type `T` from free-form text.
///
/// It asks an LLM to call a `submit` tool whose input schema is generated from `T`
/// (via `schemars`), then deserializes the tool-call arguments into `T`.
///
/// Construct one with [`Extractor::builder`].
pub struct Extractor<T>
where
    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static,
{
    /// Supplies the provider, the model name and the tool registry (holding the `submit` tool).
    agent: ToolAgent<Box<dyn LlmProvider>>,
    _t: PhantomData<T>,
    /// Extra attempts after the first one; `0` means exactly one attempt.
    retries: u32,
    /// Sampling parameters applied to every request.
    llm_params: LlmParams,
    /// System prompt prepended to every request.
    ///
    /// Requests are sent straight to the provider rather than through the agent's own
    /// loop, so the prompt must be carried here explicitly or it would never be sent.
    system_prompt: String,
}

impl<T> Extractor<T>
where
    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static,
{
    /// Starts building an extractor for `T` that talks to `provider` using `model`.
    pub fn builder<P>(provider: P, model: impl Into<String>) -> ExtractorBuilder<T>
    where
        P: LlmProvider + Send + Sync + 'static,
    {
        ExtractorBuilder::new(provider, model)
    }

    /// Extracts a `T` from `text`.
    ///
    /// # Errors
    ///
    /// Same as [`Extractor::extract_with_usage`].
    pub async fn extract(&self, text: impl Into<String>) -> Result<T, ExtractionError> {
        self._extract_with_chat_history(text.into(), Vec::new()).await
    }

    /// Extracts a `T` from `text` and also returns the token usage of the successful attempt.
    ///
    /// # Errors
    ///
    /// - [`ExtractionError::LlmError`] if the provider call fails.
    /// - [`ExtractionError::NoData`] if the model did not call the `submit` tool.
    /// - [`ExtractionError::DeserializationError`] if the tool arguments do not fit `T`.
    ///
    /// When `retries > 0` and every attempt fails, [`ExtractionError::MaxRetriesExceeded`]
    /// is returned instead of the last underlying error. Each failure is logged via `tracing`.
    pub async fn extract_with_usage(
        &self,
        text: impl Into<String>,
    ) -> Result<ExtractionResponse<T>, ExtractionError> {
        self._extract_with_usage_and_chat_history(text.into(), Vec::new())
            .await
    }

    /// Extracts a `T` from `text`, sending each `chat_history` entry first as a user message.
    ///
    /// # Errors
    ///
    /// Same as [`Extractor::extract_with_usage`].
    pub async fn extract_with_chat_history(
        &self,
        text: impl Into<String>,
        chat_history: Vec<String>,
    ) -> Result<T, ExtractionError> {
        self._extract_with_chat_history(text.into(), chat_history)
            .await
    }

    /// Retrying extraction that discards the usage information.
    async fn _extract_with_chat_history(
        &self,
        text: String,
        chat_history: Vec<String>,
    ) -> Result<T, ExtractionError> {
        // Share the single retry loop below instead of duplicating it.
        self._extract_with_usage_and_chat_history(text, chat_history)
            .await
            .map(|response| response.data)
    }

    /// Runs up to `retries + 1` extraction attempts and returns the first success.
    async fn _extract_with_usage_and_chat_history(
        &self,
        text: String,
        chat_history: Vec<String>,
    ) -> Result<ExtractionResponse<T>, ExtractionError> {
        let mut last_error = None;
        for attempt in 0..=self.retries {
            tracing::debug!("提取 JSON,剩余重试次数:{}", self.retries - attempt);
            match self._extract_json_with_usage(&text, &chat_history).await {
                Ok((data, usage)) => {
                    return Ok(ExtractionResponse {
                        data,
                        usage: Some(usage),
                    });
                }
                Err(e) => {
                    // Log 1-based attempt numbers; widen first so `retries == u32::MAX`
                    // cannot overflow.
                    tracing::warn!(
                        "第 {} 次提取失败:{:?},重试中...",
                        u64::from(attempt) + 1,
                        e
                    );
                    last_error = Some(e);
                }
            }
        }
        match last_error {
            Some(_) if self.retries > 0 => Err(ExtractionError::MaxRetriesExceeded),
            Some(error) => Err(error),
            // Not reachable in practice: the loop above always runs at least once.
            None => Err(ExtractionError::NoData),
        }
    }

    /// Performs a single extraction request against the provider.
    ///
    /// Returns the deserialized `submit` arguments plus token usage. Usage is zeroed when
    /// the provider reports none.
    async fn _extract_json_with_usage(
        &self,
        text: &str,
        chat_history: &[String],
    ) -> Result<(T, TokenUsage), ExtractionError> {
        let mut messages = Vec::with_capacity(chat_history.len() + 2);
        // The request bypasses the agent, so the system prompt (instructions + preamble)
        // must be included explicitly.
        messages.push(ChatMessage::system(self.system_prompt.clone()));
        for msg in chat_history {
            messages.push(ChatMessage::user(msg.clone()));
        }
        messages.push(ChatMessage::user(text.to_owned()));

        let mut request = ChatRequest {
            messages,
            model: Some(self.agent.model().to_string()),
            tools: Some(self.agent.tool_registry().definitions()),
            temperature: None,
            max_tokens: None,
            response_format: None,
            metadata: None,
            top_p: None,
            top_k: None,
            frequency_penalty: None,
            presence_penalty: None,
            stop: None,
            extra: None,
        };
        self.llm_params.apply_to(&mut request);

        let response = self
            .agent
            .provider()
            .chat(request)
            .await
            .map_err(|e| ExtractionError::LlmError(e.to_string()))?;

        let call = response
            .tool_calls
            .iter()
            .find(|call| call.name == SUBMIT_TOOL_NAME)
            .ok_or_else(|| {
                tracing::warn!("未找到 submit 工具调用");
                ExtractionError::NoData
            })?;
        let data: T = serde_json::from_value(call.input.clone())?;

        let usage = response
            .usage
            .map(|u| TokenUsage {
                input_tokens: u.prompt_tokens,
                output_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            })
            .unwrap_or_default();
        Ok((data, usage))
    }

    /// Returns the underlying tool agent (provider, model and tool registry).
    pub fn agent(&self) -> &ToolAgent<Box<dyn LlmProvider>> {
        &self.agent
    }
}

/// Builder for [`Extractor`]. Created via [`Extractor::builder`].
pub struct ExtractorBuilder<T>
where
    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static,
{
    provider: Box<dyn LlmProvider>,
    model: String,
    _t: PhantomData<T>,
    retries: u32,
    preamble: Option<String>,
    llm_params: LlmParams,
}

impl<T> ExtractorBuilder<T>
where
    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static,
{
    /// Creates a builder with no preamble, no retries and temperature `0.0`.
    fn new<P>(provider: P, model: impl Into<String>) -> Self
    where
        P: LlmProvider + Send + Sync + 'static,
    {
        Self {
            provider: Box::new(provider),
            model: model.into(),
            _t: PhantomData,
            retries: 0,
            preamble: None,
            llm_params: LlmParams::new().temperature(0.0),
        }
    }

    /// Appends extra instructions to the built-in system prompt.
    pub fn preamble(mut self, preamble: impl Into<String>) -> Self {
        self.preamble = Some(preamble.into());
        self
    }

    /// Sets how many additional attempts are made after a failed first attempt.
    pub fn retries(mut self, retries: u32) -> Self {
        self.retries = retries;
        self
    }

    /// Replaces all sampling parameters, including the default temperature of `0.0`.
    pub fn llm_params(mut self, params: LlmParams) -> Self {
        self.llm_params = params;
        self
    }

    /// Sets the sampling temperature.
    pub fn temperature(mut self, value: f32) -> Self {
        self.llm_params.temperature = Some(value);
        self
    }

    /// Sets nucleus-sampling `top_p`.
    pub fn top_p(mut self, value: f32) -> Self {
        self.llm_params.top_p = Some(value);
        self
    }

    /// Sets the maximum number of tokens to generate.
    pub fn max_tokens(mut self, value: u32) -> Self {
        self.llm_params.max_tokens = Some(value);
        self
    }

    /// Sets the frequency penalty.
    pub fn frequency_penalty(mut self, value: f32) -> Self {
        self.llm_params.frequency_penalty = Some(value);
        self
    }

    /// Sets the presence penalty.
    pub fn presence_penalty(mut self, value: f32) -> Self {
        self.llm_params.presence_penalty = Some(value);
        self
    }

    /// Sets stop sequences.
    pub fn stop(mut self, value: Vec<String>) -> Self {
        self.llm_params.stop = Some(value);
        self
    }

    /// Builds the extractor.
    ///
    /// The system prompt is the fixed extraction instructions, followed by the optional
    /// preamble. It is sent with every extraction request.
    pub fn build(self) -> Extractor<T> {
        let mut system_prompt = String::from(
            "你是一个 AI 助手,用于从文本中提取结构化数据。\n\
            你可以使用 `submit` 工具来提交提取的数据。\n\
            务必调用 `submit` 工具,即使使用默认值!\n",
        );
        if let Some(preamble) = self.preamble {
            system_prompt.push_str("\n=============== 额外指令 ===============\n");
            system_prompt.push_str(&preamble);
        }
        let agent = ToolAgent::builder()
            .provider(self.provider)
            .model(self.model)
            .system_prompt(system_prompt.clone())
            .tool(SubmitTool::<T>::new())
            .max_steps(3)
            .build();
        Extractor {
            agent,
            _t: PhantomData,
            retries: self.retries,
            llm_params: self.llm_params,
            system_prompt,
        }
    }
}
/// Name under which `SubmitTool` is registered. The extractor uses the same name to pick
/// the matching tool call out of the model response.
const SUBMIT_TOOL_NAME: &str = "submit";
/// Tool through which the model hands back extracted data of type `T`.
///
/// The struct itself is stateless; `T` only determines the tool's input schema. Trait
/// bounds are kept on the `impl` blocks that need them rather than repeated on the
/// definition.
struct SubmitTool<T> {
    _t: PhantomData<T>,
}
impl<T> SubmitTool<T>
where
T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static,
{
fn new() -> Self {
Self { _t: PhantomData }
}
}
impl<T> Default for SubmitTool<T>
where
T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl<T> Tool for SubmitTool<T>
where
    T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static,
{
    /// Always `SUBMIT_TOOL_NAME`, which the extractor also uses to find the tool call.
    fn name(&self) -> &str {
        SUBMIT_TOOL_NAME
    }

    fn description(&self) -> Option<&str> {
        let text = "提交从文本中提取的结构化数据";
        Some(text)
    }

    fn categories(&self) -> &'static [rucora_core::tool::ToolCategory] {
        &[rucora_core::tool::ToolCategory::Basic]
    }

    /// JSON Schema generated from `T`. Falls back to an empty object if the schema
    /// cannot be converted to a JSON value.
    fn input_schema(&self) -> serde_json::Value {
        let root = schema_for!(T);
        match serde_json::to_value(&root) {
            Ok(schema) => schema,
            Err(_) => json!({}),
        }
    }

    /// Echoes the submitted arguments back unchanged.
    async fn call(
        &self,
        input: serde_json::Value,
    ) -> Result<serde_json::Value, rucora_core::error::ToolError> {
        let echoed = input;
        Result::Ok(echoed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Deserialize, Serialize, JsonSchema, PartialEq)]
    struct TestPerson {
        name: Option<String>,
        age: Option<u8>,
        profession: Option<String>,
    }

    /// The generated schema must be an object schema listing every field of `TestPerson`.
    #[test]
    fn test_submit_tool_schema() {
        let tool = SubmitTool::<TestPerson>::new();
        let schema = tool.input_schema();
        assert!(schema.get("type").is_some());
        let properties = schema
            .get("properties")
            .and_then(|p| p.as_object())
            .expect("schema should contain a `properties` object");
        for field in ["name", "age", "profession"].iter() {
            assert!(
                properties.contains_key(*field),
                "missing property `{}` in schema",
                field
            );
        }
    }

    /// The registered tool name must match the name used to look up the tool call.
    #[test]
    fn test_submit_tool_name_matches_lookup_constant() {
        let tool = SubmitTool::<TestPerson>::new();
        assert_eq!(tool.name(), SUBMIT_TOOL_NAME);
    }
}