use std::path::PathBuf;
use async_trait::async_trait;
use super::base::{BaseRetrievalTool, RetrievalResult};
use crate::error::ToolError;
/// Retrieval tool that searches a fixed list of text files.
///
/// File contents are split into overlapping chunks which are then scored
/// against a query by the `BaseRetrievalTool` implementation.
#[derive(Debug, Clone)]
pub struct FilesRetrievalTool {
    /// Files whose contents are searched.
    files: Vec<PathBuf>,
    /// Maximum chunk length, in bytes of UTF-8 text.
    chunk_size: usize,
    /// Number of bytes shared between two consecutive chunks.
    chunk_overlap: usize,
}

impl FilesRetrievalTool {
    /// Creates a tool over `files` with a chunk size of 1000 bytes and an
    /// overlap of 200 bytes.
    pub fn new(files: Vec<PathBuf>) -> Self {
        Self {
            files,
            chunk_size: 1000,
            chunk_overlap: 200,
        }
    }

    /// Sets the maximum chunk length in bytes.
    ///
    /// A size of `0` is treated as `1` when text is chunked.
    pub fn with_chunk_size(mut self, size: usize) -> Self {
        self.chunk_size = size;
        self
    }

    /// Sets the number of bytes shared between consecutive chunks.
    ///
    /// An overlap greater than or equal to the chunk size is reduced to
    /// `chunk_size - 1` when text is chunked, so chunking always advances.
    pub fn with_chunk_overlap(mut self, overlap: usize) -> Self {
        self.chunk_overlap = overlap;
        self
    }

    /// Returns the largest char boundary of `text` that is `<= idx`
    /// (`text.len()` when `idx` is past the end).
    fn floor_boundary(text: &str, mut idx: usize) -> usize {
        if idx >= text.len() {
            return text.len();
        }
        // Index 0 is always a boundary, so this terminates.
        while !text.is_char_boundary(idx) {
            idx -= 1;
        }
        idx
    }

    /// Returns the smallest char boundary of `text` that is `>= idx`
    /// (`text.len()` when `idx` is past the end).
    fn ceil_boundary(text: &str, mut idx: usize) -> usize {
        if idx >= text.len() {
            return text.len();
        }
        // `text.len()` is always a boundary, so this terminates.
        while !text.is_char_boundary(idx) {
            idx += 1;
        }
        idx
    }

    /// Splits `text` into chunks of at most `chunk_size` bytes, where each
    /// chunk starts `chunk_size - chunk_overlap` bytes after the previous one.
    ///
    /// Text no longer than `chunk_size` is returned as a single chunk.
    /// Chunk edges are moved back to the nearest UTF-8 char boundary so a
    /// character is never split; a chunk only exceeds `chunk_size` when a
    /// single character is wider than the configured size. Every byte of
    /// `text` is covered by at least one chunk.
    fn chunk_text(&self, text: &str) -> Vec<String> {
        if text.len() <= self.chunk_size {
            return vec![text.to_string()];
        }

        // Clamp the configuration so the window is non-empty and the stride
        // is at least one byte; otherwise the loop below could underflow or
        // never terminate.
        let size = self.chunk_size.max(1);
        let step = size - self.chunk_overlap.min(size - 1);

        let mut chunks = Vec::new();
        let mut start = 0;
        while start < text.len() {
            let mut end = Self::floor_boundary(text, start + size);
            if end <= start {
                // The character at `start` is wider than the window: emit it whole.
                end = Self::ceil_boundary(text, start + 1);
            }
            chunks.push(text[start..end].to_string());
            if end >= text.len() {
                break;
            }

            // Prefer rounding down (keeps at least the requested overlap);
            // round up only when rounding down would not advance.
            let next = Self::floor_boundary(text, start + step);
            start = if next > start {
                next
            } else {
                Self::ceil_boundary(text, start + step)
            };
        }
        chunks
    }
}
#[async_trait]
impl BaseRetrievalTool for FilesRetrievalTool {
    /// Returns the tool's identifier, `"files_retrieval"`.
    fn name(&self) -> &str {
        "files_retrieval"
    }

    /// Scores every chunk of every configured file against `query` and
    /// returns at most `top_k` results, best score first.
    ///
    /// The query is lowercased and split on whitespace. A chunk's score is
    /// the fraction of query words that occur as substrings of the
    /// lowercased chunk; chunks matching no word are dropped. A word that is
    /// repeated in the query is counted once per occurrence. Results with
    /// equal scores keep file order, then chunk order (the sort is stable).
    ///
    /// # Errors
    ///
    /// Returns `ToolError::ExecutionFailed` as soon as any configured file
    /// cannot be read as UTF-8 text.
    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<RetrievalResult>, ToolError> {
        let query_lower = query.to_lowercase();
        // The word list depends only on the query, so build it once instead
        // of once per chunk.
        let words: Vec<&str> = query_lower.split_whitespace().collect();
        let mut all_results = Vec::new();

        for path in &self.files {
            let content = tokio::fs::read_to_string(path).await.map_err(|e| {
                ToolError::ExecutionFailed(format!("Failed to read {}: {e}", path.display()))
            })?;
            let source = path.display().to_string();

            // Consume the chunks so each matching one is moved, not cloned.
            for chunk in self.chunk_text(&content) {
                let chunk_lower = chunk.to_lowercase();
                let matches = words
                    .iter()
                    .filter(|word| chunk_lower.contains(**word))
                    .count();
                if matches > 0 {
                    let score = matches as f64 / words.len().max(1) as f64;
                    all_results.push(RetrievalResult {
                        content: chunk,
                        source: source.clone(),
                        score,
                        metadata: serde_json::Value::Null,
                    });
                }
            }
        }

        // Descending by score; incomparable scores are treated as equal.
        all_results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        all_results.truncate(top_k);
        Ok(all_results)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Text that fits within the default chunk size comes back as one
    /// unchanged chunk.
    #[test]
    fn chunk_short_text() {
        let input = "short text";
        let pieces = FilesRetrievalTool::new(Vec::new()).chunk_text(input);
        assert_eq!(pieces.len(), 1);
        assert_eq!(pieces[0], input);
    }

    /// Text longer than the chunk size is split, and the first chunk is
    /// exactly `chunk_size` bytes long.
    #[test]
    fn chunk_long_text() {
        let chunker = FilesRetrievalTool::new(Vec::new())
            .with_chunk_size(10)
            .with_chunk_overlap(3);
        let pieces = chunker.chunk_text("abcdefghijklmnopqrstuvwxyz");
        assert!(pieces.len() > 1);
        assert_eq!(pieces[0].len(), 10);
    }

    /// An unreadable file makes the whole retrieval fail.
    #[tokio::test]
    async fn retrieve_from_nonexistent_file() {
        let missing = PathBuf::from("/nonexistent/file.txt");
        let tool = FilesRetrievalTool::new(vec![missing]);
        assert!(tool.retrieve("test", 5).await.is_err());
    }
}