use futures::future::BoxFuture;
use serde_json::Value;
use super::security::{ResourceLimits, SecurityConfig};
use crate::error::{Result, ToolError};
use crate::tools::{Tool, ToolParameters, ToolResult};
/// Tool identifier placed in the `tool` field of `ToolError::ExecutionFailed`
/// by both PDF tools in this file when a document cannot be opened.
const TOOL_NAME: &str = "pdf_tools";
/// Tool that extracts text, and optionally document metadata, from a PDF file.
pub struct PdfExtractTool;

impl Tool for PdfExtractTool {
    /// Identifier under which this tool is exposed.
    fn name(&self) -> &str {
        "extract_pdf"
    }

    /// User-facing summary of what the tool does (runtime text, kept as-is).
    fn description(&self) -> &str {
        "从 PDF 文档中提取文本内容。支持提取全部文本、指定页面范围或获取文档元数据。"
    }

    /// JSON Schema of the accepted arguments; only `file_path` is required.
    fn parameters(&self) -> Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "file_path": {
                    "type": "string",
                    "description": "PDF 文件的绝对路径"
                },
                "pages": {
                    "type": "string",
                    "description": "要提取的页面范围(可选),如 '1-5'、'1,3,7' 或 'all'(默认)"
                },
                "extract_metadata": {
                    "type": "boolean",
                    "description": "是否同时提取文档元数据(默认 false)"
                }
            },
            "required": ["file_path"]
        })
    }

    /// Runs the extraction.
    ///
    /// Steps, in order: decode the arguments, pass the path through
    /// `SecurityConfig::validate_file`, load the document, optionally collect
    /// metadata, resolve the page selection, extract the text and format the
    /// report.
    ///
    /// # Errors
    ///
    /// Fails with `ToolError::MissingParameter` when `file_path` is absent or
    /// not a string, with `ToolError::ExecutionFailed` when the PDF cannot be
    /// loaded, and with whatever the validation / page-range / extraction
    /// helpers return.
    fn execute(&self, parameters: ToolParameters) -> BoxFuture<'_, Result<ToolResult>> {
        Box::pin(async move {
            // --- argument decoding -------------------------------------------------
            let raw_path = parameters
                .get("file_path")
                .and_then(|value| value.as_str())
                .ok_or_else(|| ToolError::MissingParameter("file_path".to_string()))?;
            let page_spec = parameters
                .get("pages")
                .and_then(|value| value.as_str())
                .unwrap_or("all");
            let want_metadata = parameters
                .get("extract_metadata")
                .and_then(|value| value.as_bool())
                .unwrap_or(false);

            // --- open the document -------------------------------------------------
            let security = SecurityConfig::global();
            let checked_path = security.validate_file(raw_path)?;
            let document =
                lopdf::Document::load(&checked_path).map_err(|err| ToolError::ExecutionFailed {
                    tool: TOOL_NAME.to_string(),
                    message: format!("打开 PDF 失败: {}", err),
                })?;
            let page_count = document.get_pages().len();

            // Metadata is gathered before the page selection is validated, so the
            // order of possible failures matches the sequence of steps above.
            let metadata = if want_metadata {
                Some(extract_pdf_metadata(&document)?)
            } else {
                None
            };

            // --- text extraction ---------------------------------------------------
            let selected_pages = parse_page_range(page_spec, page_count, &security.limits)?;
            let text = extract_pages_text(&document, &selected_pages, &security.limits)?;

            // --- report ------------------------------------------------------------
            let text_section = format!(
                "=== 文本内容 (第 {} 页,共 {} 页) ===\n{}",
                page_spec, page_count, text
            );
            let report = match metadata {
                Some(meta) => format!("=== PDF 元数据 ===\n{}\n\n{}", meta, text_section),
                None => text_section,
            };
            Ok(ToolResult::success(report))
        })
    }
}
/// Tool that reports a PDF's page count and document metadata without
/// extracting any page text.
pub struct PdfInfoTool;

impl Tool for PdfInfoTool {
    /// Identifier under which this tool is exposed.
    fn name(&self) -> &str {
        "pdf_info"
    }

    /// User-facing summary of what the tool does (runtime text, kept as-is).
    fn description(&self) -> &str {
        "获取 PDF 文档的基本信息:页数、标题、作者、创建时间等元数据,不提取文本内容。"
    }

    /// JSON Schema of the accepted arguments: a single required `file_path`.
    fn parameters(&self) -> Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "file_path": {
                    "type": "string",
                    "description": "PDF 文件的绝对路径"
                }
            },
            "required": ["file_path"]
        })
    }

    /// Loads the PDF and returns the summary built by `extract_pdf_metadata`.
    ///
    /// # Errors
    ///
    /// Fails with `ToolError::MissingParameter` when `file_path` is absent or
    /// not a string, with `ToolError::ExecutionFailed` when the PDF cannot be
    /// loaded, and with whatever `SecurityConfig::validate_file` returns.
    fn execute(&self, parameters: ToolParameters) -> BoxFuture<'_, Result<ToolResult>> {
        Box::pin(async move {
            let raw_path = parameters
                .get("file_path")
                .and_then(|value| value.as_str())
                .ok_or_else(|| ToolError::MissingParameter("file_path".to_string()))?;

            let checked_path = SecurityConfig::global().validate_file(raw_path)?;
            let document =
                lopdf::Document::load(&checked_path).map_err(|err| ToolError::ExecutionFailed {
                    tool: TOOL_NAME.to_string(),
                    message: format!("打开 PDF 失败: {}", err),
                })?;

            let summary = extract_pdf_metadata(&document)?;
            Ok(ToolResult::success(summary))
        })
    }
}
/// Builds a newline-separated, human-readable summary of a PDF's metadata.
///
/// The summary starts with the page count, then lists every entry of the
/// document information dictionary (trailer key `Info`), and ends with the
/// page count again (kept so the output format stays unchanged).
///
/// Well-known keys (`Title`, `Author`, ...) are given Chinese labels; other
/// keys are printed verbatim when they are valid UTF-8. `CreationDate` and
/// `ModDate` strings are passed through `parse_pdf_date`.
///
/// The `Info` entry is normally an *indirect reference*, so it is resolved
/// through `Document::get_object` before being inspected; a direct dictionary
/// is accepted as well. Individual values that are references are resolved
/// the same way.
///
/// # Errors
///
/// Currently never fails; a missing or malformed `Info` dictionary simply
/// yields a summary that contains only the page count.
fn extract_pdf_metadata(pdf: &lopdf::Document) -> Result<String> {
    use lopdf::Object;

    // `get_pages` is computed once and reused for both page-count lines.
    let page_count = pdf.get_pages().len();
    let mut info = vec![format!("页数: {}", page_count)];

    // Accept both a direct dictionary and (the common case) a reference to one.
    let info_dict = match pdf.trailer.get(b"Info") {
        Ok(Object::Dictionary(dict)) => Some(dict),
        Ok(Object::Reference(id)) => match pdf.get_object(*id) {
            Ok(Object::Dictionary(dict)) => Some(dict),
            _ => None,
        },
        _ => None,
    };

    if let Some(dict) = info_dict {
        for (key, value) in dict.iter() {
            let key_bytes = key.as_slice();
            let is_date = matches!(key_bytes, b"CreationDate" | b"ModDate");
            let label = match key_bytes {
                b"Title" => "标题",
                b"Author" => "作者",
                b"Subject" => "主题",
                b"Creator" => "创建工具",
                b"Producer" => "PDF 生成器",
                b"CreationDate" => "创建时间",
                b"ModDate" => "修改时间",
                other => std::str::from_utf8(other).unwrap_or("未知"),
            };

            // Follow one level of indirection; fall back to the raw object when
            // the reference cannot be resolved.
            let value = match value {
                Object::Reference(id) => pdf.get_object(*id).unwrap_or(value),
                direct => direct,
            };
            let rendered = match value {
                Object::String(bytes, _) if is_date => parse_pdf_date(bytes),
                Object::String(bytes, _) => decode_pdf_text_string(bytes),
                Object::Name(name) => String::from_utf8_lossy(name).into_owned(),
                Object::Integer(i) => i.to_string(),
                Object::Real(f) => f.to_string(),
                Object::Boolean(b) => b.to_string(),
                _ => "未知".to_string(),
            };
            info.push(format!("{}: {}", label, rendered));
        }
    }

    info.push(format!("总页数: {}", page_count));
    Ok(info.join("\n"))
}

/// Decodes a PDF text string into a Rust `String`.
///
/// Strings that start with the byte-order mark `FE FF` are decoded as
/// UTF-16BE (a trailing odd byte is ignored); everything else is decoded as
/// lossy UTF-8, which is exact for plain ASCII metadata.
fn decode_pdf_text_string(bytes: &[u8]) -> String {
    match bytes {
        [0xFE, 0xFF, utf16 @ ..] => {
            let units: Vec<u16> = utf16
                .chunks_exact(2)
                .map(|pair| u16::from_be_bytes([pair[0], pair[1]]))
                .collect();
            String::from_utf16_lossy(&units)
        }
        _ => String::from_utf8_lossy(bytes).into_owned(),
    }
}
/// Formats a PDF date string (`D:YYYYMMDD...`) as `YYYY-MM-DD`.
///
/// The input is decoded as lossy UTF-8. When it starts with `D:` followed by
/// at least eight ASCII digits, those digits are returned as `YYYY-MM-DD`;
/// any time-of-day / timezone suffix is dropped. Every other input (no `D:`
/// prefix, too short, non-digit characters, invalid UTF-8) is returned as the
/// decoded text, unchanged.
///
/// Never panics: the date portion is taken with a checked slice and verified
/// to be ASCII digits before it is split, so malformed bytes cannot land an
/// index inside a multi-byte character.
fn parse_pdf_date(date: &[u8]) -> String {
    let date_str = String::from_utf8_lossy(date);

    let ymd = date_str
        .strip_prefix("D:")
        // `str::get` yields `None` instead of panicking when byte 8 is out of
        // range or not on a character boundary.
        .and_then(|rest| rest.get(..8))
        .filter(|ymd| ymd.bytes().all(|b| b.is_ascii_digit()));

    if let Some(ymd) = ymd {
        // All eight bytes are ASCII, so these fixed offsets are valid boundaries.
        return format!("{}-{}-{}", &ymd[0..4], &ymd[4..6], &ymd[6..8]);
    }
    date_str.into_owned()
}
/// Parses a page selection such as `"all"`, `"1-5"` or `"1,3,7"` into a
/// sorted, de-duplicated list of 1-based page numbers.
///
/// * `range` - the selection; `"all"` (case-insensitive) selects the document
///   from page 1. Otherwise it is a comma-separated list of single pages and
///   inclusive `start-end` ranges. Whitespace around items is ignored.
/// * `total_pages` - number of pages in the document; every requested page
///   must lie in `1..=total_pages`.
/// * `limits` - `max_preview_pages` caps both each individual range and the
///   overall result (the first `max_preview_pages` pages in request order are
///   kept, then sorted).
///
/// # Errors
///
/// Returns `ToolError::InvalidParameter` (name `"pages"`) when an item is not
/// a number, is `0`, exceeds `total_pages`, or when a range is malformed or
/// has `start > end`.
fn parse_page_range(range: &str, total_pages: usize, limits: &ResourceLimits) -> Result<Vec<u32>> {
    // Clamp rather than truncate: page numbers are `u32`, so a page count
    // above `u32::MAX` could never be addressed anyway.
    let total = total_pages.min(u32::MAX as usize) as u32;
    let max_pages = limits.max_preview_pages;

    let range = range.trim();
    if range.eq_ignore_ascii_case("all") {
        return Ok((1..=total).take(max_pages).collect());
    }

    let invalid = |message: String| ToolError::InvalidParameter {
        name: "pages".to_string(),
        message,
    };

    let mut pages: Vec<u32> = Vec::new();
    for part in range.split(',') {
        let part = part.trim();
        if part.contains('-') {
            // Exactly one '-' is allowed per range item.
            let (lo, hi) = match part.split_once('-') {
                Some((lo, hi)) if !hi.contains('-') => (lo.trim(), hi.trim()),
                _ => return Err(invalid(format!("无效的页面范围: {}", part)).into()),
            };
            let start: u32 = lo
                .parse()
                .map_err(|_| invalid(format!("无效的起始页: {}", lo)))?;
            let end: u32 = hi
                .parse()
                .map_err(|_| invalid(format!("无效的结束页: {}", hi)))?;
            // Pages are 1-based: a range starting at 0 is invalid.
            if start == 0 || start > end || end > total {
                return Err(
                    invalid(format!("页面范围无效或超出文档页数 ({} 页)", total_pages)).into(),
                );
            }
            // `take` caps the range without any `start + len` arithmetic that
            // could overflow.
            for page in (start..=end).take(max_pages) {
                if !pages.contains(&page) {
                    pages.push(page);
                }
            }
        } else {
            let page: u32 = part
                .parse()
                .map_err(|_| invalid(format!("无效的页码: {}", part)))?;
            if page == 0 {
                return Err(invalid(format!("无效的页码: {}", part)).into());
            }
            if page > total {
                return Err(invalid(format!(
                    "页码 {} 超出文档页数 ({} 页)",
                    page, total_pages
                ))
                .into());
            }
            if !pages.contains(&page) {
                pages.push(page);
            }
        }
    }

    // Keep the first `max_pages` pages in request order, then sort them.
    pages.truncate(max_pages);
    pages.sort_unstable();
    Ok(pages)
}
/// Extracts the text of the given pages and joins the per-page sections with
/// blank lines.
///
/// * `pdf` - the loaded document.
/// * `page_numbers` - 1-based page numbers to extract, in output order.
/// * `limits` - `max_preview_chars` is the total character budget across all
///   pages; once it is used up, a notice line is appended and the remaining
///   pages are skipped.
///
/// Page content is read through `Document::get_page_content`, so content
/// streams are concatenated and their stream filters decoded before the text
/// operators are scanned. Pages that are missing from the page tree, or whose
/// content cannot be read, are skipped.
///
/// # Errors
///
/// Currently never fails; unreadable pages are skipped instead.
fn extract_pages_text(
    pdf: &lopdf::Document,
    page_numbers: &[u32],
    limits: &ResourceLimits,
) -> Result<String> {
    // Built once up front instead of once per requested page.
    let page_ids = pdf.get_pages();
    let mut sections = Vec::with_capacity(page_numbers.len());
    let mut total_chars = 0usize;

    for page_num in page_numbers {
        if total_chars >= limits.max_preview_chars {
            sections.push(format!(
                "... (已达到最大预览字符数 {})",
                limits.max_preview_chars
            ));
            break;
        }

        let Some(&page_id) = page_ids.get(page_num) else {
            continue;
        };
        let Ok(content) = pdf.get_page_content(page_id) else {
            continue;
        };

        let text = extract_text_from_stream(&content, limits);

        // The budget is expressed in characters, so count characters (not
        // bytes) and never let a single page overshoot what is left.
        let remaining = limits.max_preview_chars - total_chars;
        let page_chars = text.chars().count();
        let text = if page_chars > remaining {
            text.chars().take(remaining).collect::<String>()
        } else {
            text
        };
        total_chars += page_chars.min(remaining);

        sections.push(format!("--- 第 {} 页 ---\n{}", page_num, text));
    }

    Ok(sections.join("\n\n"))
}
/// Pulls the text operands of `Tj` / `TJ` operators out of a PDF content
/// stream and joins them with single spaces, in the order they appear in the
/// stream.
///
/// Three operand shapes are recognised:
/// * `(literal) Tj`
/// * `<hex> Tj` (decoded through `hex_decode`)
/// * `[ (a) 12 (b) ] TJ` (every `(literal)` inside the array)
///
/// This is a regex-based best-effort scan, not a content-stream parser:
/// escaped or nested parentheses inside literals and font-specific encodings
/// are not interpreted.
///
/// * `content` - content-stream bytes (decoded lossily as UTF-8).
/// * `limits` - `regex_max_size` bounds the compiled regex / DFA size, and the
///   result is cut to at most `max_preview_chars` characters.
///
/// If a pattern cannot be compiled within `regex_max_size`, a short notice is
/// returned instead of panicking.
fn extract_text_from_stream(content: &[u8], limits: &ResourceLimits) -> String {
    let content_str = String::from_utf8_lossy(content);

    let compile = |pattern: &str| {
        regex::RegexBuilder::new(pattern)
            .size_limit(limits.regex_max_size)
            .dfa_size_limit(limits.regex_max_size)
            .build()
    };
    // The patterns are static, so the configured size limit is the only
    // realistic reason a build can fail.
    let (tj_regex, hex_regex, tj_array_regex, str_regex) = match (
        compile(r"\(([^)]*)\)\s*Tj"),
        compile(r"<([0-9a-fA-F]*)>\s*Tj"),
        compile(r"\[(.*?)\]\s*TJ"),
        compile(r"\(([^)]*)\)"),
    ) {
        (Ok(tj), Ok(hex), Ok(array), Ok(item)) => (tj, hex, array, item),
        _ => return "(文本提取失败: 无法编译文本提取正则表达式)".to_string(),
    };

    // Each fragment remembers its byte offset in the stream so that the three
    // independent scans can be merged back into reading order.
    let mut fragments: Vec<(usize, String)> = Vec::new();

    for cap in tj_regex.captures_iter(&content_str) {
        if let Some(literal) = cap.get(1) {
            fragments.push((literal.start(), literal.as_str().to_string()));
        }
    }

    for cap in hex_regex.captures_iter(&content_str) {
        if let Some(hex) = cap.get(1) {
            if let Ok(decoded) = hex_decode(hex.as_str()) {
                fragments.push((hex.start(), decoded));
            }
        }
    }

    for cap in tj_array_regex.captures_iter(&content_str) {
        if let Some(array) = cap.get(1) {
            for item in str_regex.captures_iter(array.as_str()) {
                if let Some(literal) = item.get(1) {
                    // Inner offsets are relative to the array body.
                    fragments.push((
                        array.start() + literal.start(),
                        literal.as_str().to_string(),
                    ));
                }
            }
        }
    }

    // Stable sort: fragments at the same offset keep their scan order.
    fragments.sort_by_key(|(offset, _)| *offset);

    let result = fragments
        .into_iter()
        .map(|(_, text)| text)
        .collect::<Vec<_>>()
        .join(" ");

    // Byte length is an upper bound on the character count, so this cheap
    // check skips the re-collect whenever the text certainly fits.
    if result.len() > limits.max_preview_chars {
        result.chars().take(limits.max_preview_chars).collect()
    } else {
        result
    }
}
/// Decodes the body of a PDF hexadecimal string (the text between `<` and
/// `>`) into text.
///
/// Digits are consumed two at a time, case-insensitively. Following the PDF
/// rule for hex strings, an odd trailing digit is treated as if it were
/// followed by `0` (`"4"` decodes like `"40"`). A pair that contains a
/// non-hex character decodes to the byte `0`. The resulting bytes are
/// interpreted as lossy UTF-8.
///
/// # Errors
///
/// Currently never fails; the `Result` is kept for interface compatibility.
fn hex_decode(hex: &str) -> Result<String> {
    let nibble = |byte: u8| (byte as char).to_digit(16);

    // Work on bytes so that unexpected non-ASCII input cannot cause a
    // char-boundary panic the way `&str` slicing would.
    let bytes: Vec<u8> = hex
        .as_bytes()
        .chunks(2)
        .map(|pair| {
            let high = nibble(pair[0]);
            let low = match pair.get(1) {
                Some(&byte) => nibble(byte),
                None => Some(0), // missing final digit counts as 0
            };
            match (high, low) {
                // Both nibbles are < 16, so the sum always fits in a `u8`.
                (Some(high), Some(low)) => (high * 16 + low) as u8,
                _ => 0,
            }
        })
        .collect();

    Ok(String::from_utf8_lossy(&bytes).into_owned())
}