use futures::future::BoxFuture;
use serde_json::Value;
use std::sync::{Arc, OnceLock};
use tokio::sync::RwLock;
use crate::error::{Result, ToolError};
use crate::tools::{Tool, ToolParameters, ToolResult};
/// One indexed piece of a document together with its embedding vector.
///
/// Instances are built by the `rag_index` tool and stored in `VectorStore`.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct DocumentChunk {
/// Unique identifier; the indexing tool fills it with a freshly generated UUID string.
id: String,
/// Text of the chunk as produced by the chunker.
content: String,
/// Embedding vector of `content`, compared against query embeddings at search time.
embedding: Vec<f32>,
/// Optional caller-supplied origin label (e.g. a file name or URL) echoed back in search results.
source: Option<String>,
/// Zero-based position of this chunk within the document it was split from.
chunk_index: usize,
/// Number of chunks the originating document was split into.
total_chunks: usize,
}
/// Minimal in-memory vector index: a flat list of chunks searched by brute force.
struct VectorStore {
    /// All indexed chunks, in insertion order.
    chunks: Vec<DocumentChunk>,
}

impl VectorStore {
    /// Creates an empty store.
    fn new() -> Self {
        Self { chunks: Vec::new() }
    }

    /// Appends `chunks` to the store, preserving their order.
    fn add_chunks(&mut self, chunks: Vec<DocumentChunk>) {
        self.chunks.extend(chunks);
    }

    /// Returns up to `top_k` chunks ranked by descending cosine similarity to
    /// `query_embedding`, each paired with its score.
    ///
    /// Chunks with equal scores keep their insertion order. A NaN score (which
    /// can only come from non-finite embedding values) is ranked after every
    /// real score instead of being treated as "equal to everything": that older
    /// comparison was not a total order, which makes the result of `sort_by`
    /// unspecified and may even make it panic.
    fn search(&self, query_embedding: &[f32], top_k: usize) -> Vec<(f32, DocumentChunk)> {
        use std::cmp::Ordering;

        if top_k == 0 {
            return Vec::new();
        }

        // Rank references first so that only the surviving `top_k` chunks
        // (including their embedding vectors) are cloned.
        let mut scored: Vec<(f32, &DocumentChunk)> = self
            .chunks
            .iter()
            .map(|chunk| (cosine_similarity(query_embedding, &chunk.embedding), chunk))
            .collect();

        scored.sort_by(|a, b| match (a.0.is_nan(), b.0.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always succeeds here.
            (false, false) => b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal),
        });

        scored
            .into_iter()
            .take(top_k)
            .map(|(score, chunk)| (score, chunk.clone()))
            .collect()
    }
}
/// Returns a handle to the process-wide vector store.
///
/// The store is created empty on first use; every call hands out a new `Arc`
/// pointing at that same instance.
fn global_vector_store() -> Arc<RwLock<VectorStore>> {
    static STORE: OnceLock<Arc<RwLock<VectorStore>>> = OnceLock::new();
    let shared = STORE.get_or_init(|| {
        let empty = VectorStore::new();
        Arc::new(RwLock::new(empty))
    });
    Arc::clone(shared)
}
/// Cosine similarity of two vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or when either
/// of them has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
    let norm_a = magnitude(a);
    let norm_b = magnitude(b);
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
/// Splits `text` into chunks of roughly `chunk_size` characters.
///
/// The text is split into paragraphs on blank lines (`"\n\n"`). Each paragraph
/// is trimmed, empty ones are skipped, and consecutive paragraphs are packed
/// into one chunk (re-joined with `"\n\n"`) for as long as the combined length,
/// separator included, does not exceed `chunk_size`. A single paragraph longer
/// than `chunk_size` is split further by [`chunk_by_sentences`].
///
/// Lengths are counted in `char`s (Unicode scalar values), not UTF-8 bytes, so
/// multi-byte text such as CJK is sized the same way as ASCII text.
///
/// `overlap` is only forwarded to [`chunk_by_sentences`], which currently does
/// not apply it; no text is repeated between chunks.
///
/// Returns the chunks in document order; the result is empty when `text`
/// contains nothing but whitespace.
fn chunk_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
    let mut chunks: Vec<String> = Vec::new();
    let mut current = String::new();
    // Character count of `current`, maintained incrementally to avoid rescans.
    let mut current_chars = 0usize;

    for para in text.split("\n\n").map(str::trim).filter(|p| !p.is_empty()) {
        let para_chars = para.chars().count();

        // `+ 2` accounts for the "\n\n" separator that joining would insert.
        if !current.is_empty() && current_chars + para_chars + 2 > chunk_size {
            // `current` is built from trimmed paragraphs, so it needs no trimming.
            chunks.push(std::mem::take(&mut current));
            current_chars = 0;
        }

        if para_chars > chunk_size {
            // An oversized paragraph always triggers the flush above, so
            // `current` is empty here and chunk order is preserved.
            chunks.extend(chunk_by_sentences(para, chunk_size, overlap));
        } else {
            if !current.is_empty() {
                current.push_str("\n\n");
                current_chars += 2;
            }
            current.push_str(para);
            current_chars += para_chars;
        }
    }

    if !current.is_empty() {
        chunks.push(current);
    }
    chunks
}

/// Splits one long paragraph at sentence boundaries.
///
/// A chunk is closed at the first sentence terminator (`.`, `!`, `?`, or the
/// full-width `。`, `!`, `?`) encountered once the chunk already holds at
/// least `chunk_size` characters. Chunks may therefore be longer than
/// `chunk_size`, and text without any terminator comes back as a single chunk.
///
/// A trailing remainder shorter than a third of `chunk_size` is appended to the
/// previous chunk (separated by one space) rather than emitted on its own.
///
/// Lengths are counted in `char`s, not bytes. `_overlap` is accepted for
/// signature compatibility but is not applied.
fn chunk_by_sentences(text: &str, chunk_size: usize, _overlap: usize) -> Vec<String> {
    let mut chunks: Vec<String> = Vec::new();
    let mut current = String::new();
    let mut current_chars = 0usize;

    for ch in text.chars() {
        current.push(ch);
        current_chars += 1;

        let ends_sentence = matches!(ch, '.' | '!' | '?' | '。' | '!' | '?');
        if ends_sentence && current_chars >= chunk_size {
            chunks.push(current.trim().to_string());
            current.clear();
            current_chars = 0;
        }
    }

    let tail = current.trim();
    if !tail.is_empty() {
        match chunks.last_mut() {
            // Fold a short leftover into the previous chunk instead of
            // producing a tiny fragment.
            Some(last) if current_chars < chunk_size / 3 => {
                last.push(' ');
                last.push_str(tail);
            }
            _ => chunks.push(tail.to_string()),
        }
    }
    chunks
}
/// Embeds every entry of `texts`, one request after another, using an
/// `HttpEmbedder` configured from the environment.
///
/// The returned vectors are in the same order as `texts`. The first failing
/// request aborts the whole call with `ToolError::ExecutionFailed`.
async fn generate_embeddings(texts: &[String]) -> Result<Vec<Vec<f32>>> {
    use echo_state::memory::{Embedder, HttpEmbedder};

    let client = HttpEmbedder::from_env();
    let mut vectors = Vec::with_capacity(texts.len());
    for input in texts {
        match client.embed(input).await {
            Ok(vector) => vectors.push(vector),
            Err(cause) => {
                return Err(ToolError::ExecutionFailed {
                    tool: "rag".to_string(),
                    message: format!("嵌入生成失败: {}", cause),
                }
                .into());
            }
        }
    }
    Ok(vectors)
}
pub struct RagIndexTool;
impl Tool for RagIndexTool {
fn name(&self) -> &str {
"rag_index"
}
fn description(&self) -> &str {
"将文档分块并建立向量索引,支持后续语义检索。\
需要配置 EMBEDDING_API_KEY 环境变量(或兼容的 OPENAI_API_KEY / EMBEDDING_APIKEY)。\
分块大小默认 1000 字符,重叠默认 100 字符。"
}
fn parameters(&self) -> Value {
serde_json::json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "要索引的文档文本内容"
},
"source": {
"type": "string",
"description": "文档来源标识(如文件名、URL),用于结果溯源"
},
"chunk_size": {
"type": "integer",
"description": "分块大小(字符数,默认 1000)"
},
"overlap": {
"type": "integer",
"description": "块间重叠字符数(默认 100)"
}
},
"required": ["content"]
})
}
fn execute(&self, parameters: ToolParameters) -> BoxFuture<'_, Result<ToolResult>> {
Box::pin(async move {
let content = parameters
.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::MissingParameter("content".to_string()))?;
let source = parameters
.get("source")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let chunk_size = parameters
.get("chunk_size")
.and_then(|v| v.as_u64())
.unwrap_or(1000) as usize;
let overlap = parameters
.get("overlap")
.and_then(|v| v.as_u64())
.unwrap_or(100) as usize;
let texts = chunk_text(content, chunk_size, overlap);
let total_chunks = texts.len();
if texts.is_empty() {
return Ok(ToolResult::success("文档内容为空,已跳过索引".to_string()));
}
let embeddings = match generate_embeddings(&texts).await {
Ok(e) => e,
Err(e) => return Ok(ToolResult::error(format!("嵌入生成失败: {}", e))),
};
let chunks: Vec<DocumentChunk> = texts
.into_iter()
.zip(embeddings)
.enumerate()
.map(|(i, (text, embedding))| DocumentChunk {
id: uuid::Uuid::new_v4().to_string(),
content: text,
embedding,
source: source.clone(),
chunk_index: i,
total_chunks,
})
.collect();
let store = global_vector_store();
store.write().await.add_chunks(chunks);
Ok(ToolResult::success(format!(
"成功索引 {} 个文档块{}",
total_chunks,
source
.map(|s| format!("(来源: {})", s))
.unwrap_or_default()
)))
})
}
}
pub struct RagSearchTool;
impl Tool for RagSearchTool {
fn name(&self) -> &str {
"rag_search"
}
fn description(&self) -> &str {
"对已索引的文档进行语义搜索,返回最相关的 top_k 个片段及其相似度分数。\
需要先使用 rag_index 建好索引。"
}
fn parameters(&self) -> Value {
serde_json::json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "搜索查询"
},
"top_k": {
"type": "integer",
"description": "返回结果数量(默认 5)"
}
},
"required": ["query"]
})
}
fn execute(&self, parameters: ToolParameters) -> BoxFuture<'_, Result<ToolResult>> {
Box::pin(async move {
let query = parameters
.get("query")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::MissingParameter("query".to_string()))?;
let top_k = parameters
.get("top_k")
.and_then(|v| v.as_u64())
.unwrap_or(5) as usize;
let query_embedding = match generate_embeddings(&[query.to_string()]).await {
Ok(mut e) if !e.is_empty() => e.remove(0),
Ok(_) => return Ok(ToolResult::error("嵌入生成返回空结果".to_string())),
Err(e) => return Ok(ToolResult::error(format!("嵌入生成失败: {}", e))),
};
let store = global_vector_store();
let results = store.read().await.search(&query_embedding, top_k);
if results.is_empty() {
return Ok(ToolResult::success(
"未找到相关文档。请先使用 rag_index 索引文档。".to_string(),
));
}
let items: Vec<Value> = results
.iter()
.enumerate()
.map(|(rank, (score, chunk))| {
serde_json::json!({
"rank": rank + 1,
"similarity_pct": format!("{:.1}", score * 100.0),
"chunk_index": chunk.chunk_index,
"total_chunks": chunk.total_chunks,
"source": chunk.source,
"content": chunk.content,
})
})
.collect();
let result = serde_json::json!({
"query": query,
"total_results": results.len(),
"results": items,
});
Ok(ToolResult::success_json(result))
})
}
}
/// Tool that previews how a document would be chunked, without embedding or
/// indexing anything.
pub struct RagChunkDocumentTool;

impl Tool for RagChunkDocumentTool {
    /// Name under which the tool is registered.
    fn name(&self) -> &str {
        "rag_chunk_document"
    }

    /// Human-readable description shown to the tool caller.
    fn description(&self) -> &str {
        "预览文档分块结果(不会建立索引)。用于检查分块策略和调整参数。"
    }

    /// JSON schema of the accepted parameters; only `content` is required.
    fn parameters(&self) -> Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "content": {
                    "type": "string",
                    "description": "要分块的文本内容"
                },
                "chunk_size": {
                    "type": "integer",
                    "description": "分块大小(字符数,默认 1000)"
                },
                "overlap": {
                    "type": "integer",
                    "description": "块间重叠字符数(默认 100)"
                }
            },
            "required": ["content"]
        })
    }

    /// Chunks `content` and returns, for every chunk, its 1-based index, its
    /// length in characters, a preview of its first characters and whether
    /// that preview is shorter than the chunk.
    ///
    /// A missing `content` parameter is returned as `ToolError::MissingParameter`.
    fn execute(&self, parameters: ToolParameters) -> BoxFuture<'_, Result<ToolResult>> {
        /// Maximum number of characters of each chunk echoed back in `preview`.
        const PREVIEW_CHARS: usize = 200;

        Box::pin(async move {
            let content = parameters
                .get("content")
                .and_then(|v| v.as_str())
                .ok_or_else(|| ToolError::MissingParameter("content".to_string()))?;
            let chunk_size = parameters
                .get("chunk_size")
                .and_then(|v| v.as_u64())
                .unwrap_or(1000) as usize;
            let overlap = parameters
                .get("overlap")
                .and_then(|v| v.as_u64())
                .unwrap_or(100) as usize;

            let chunks = chunk_text(content, chunk_size, overlap);
            if chunks.is_empty() {
                return Ok(ToolResult::success("文档内容为空".to_string()));
            }

            let items: Vec<Value> = chunks
                .iter()
                .enumerate()
                .map(|(i, chunk)| {
                    // Count characters, not bytes: the preview is cut by
                    // characters, so comparing against `chunk.len()` would flag
                    // short multi-byte chunks as truncated and misreport
                    // `char_count`.
                    let char_count = chunk.chars().count();
                    let preview: String = chunk.chars().take(PREVIEW_CHARS).collect();
                    serde_json::json!({
                        "index": i + 1,
                        "char_count": char_count,
                        "preview": preview,
                        "truncated": char_count > PREVIEW_CHARS,
                    })
                })
                .collect();

            let result = serde_json::json!({
                "chunk_size": chunk_size,
                "overlap": overlap,
                "total_chunks": chunks.len(),
                "chunks": items,
            });
            Ok(ToolResult::success_json(result))
        })
    }
}