use std::future::Future;
use std::pin::Pin;
/// Default value returned by [`ToolResultExtractor::extraction_threshold`] when an
/// implementation does not override it.
pub const DEFAULT_EXTRACTION_THRESHOLD: u32 = 15_000;

/// Condenses a tool's raw output into a shorter [`ExtractedResult`] that is relevant to
/// the user's query.
///
/// The `Send + Sync` bounds allow an extractor to be shared between threads. `extract`
/// returns a boxed future (rather than being an `async fn`) so that the trait stays usable
/// as `dyn ToolResultExtractor`.
pub trait ToolResultExtractor: Send + Sync {
    /// Attempts to extract the parts of `output` relevant to `user_query`.
    ///
    /// * `tool_name` — name of the tool that produced `output`.
    /// * `output` — the raw tool output to condense.
    /// * `user_query` — the query the extraction should focus on.
    ///
    /// The returned future resolves to `Some(ExtractedResult)` with the condensed content,
    /// or `None` when the implementation produced no extraction. How `None` is handled
    /// is up to the caller.
    fn extract<'a>(
        &'a self,
        tool_name: &'a str,
        output: &'a str,
        user_query: &'a str,
    ) -> Pin<Box<dyn Future<Output = Option<ExtractedResult>> + Send + 'a>>;

    /// Size threshold associated with this extractor; defaults to
    /// [`DEFAULT_EXTRACTION_THRESHOLD`].
    ///
    /// NOTE(review): the unit (presumably estimated tokens) and exactly how the threshold
    /// gates extraction are decided by the caller — confirm at the call site.
    fn extraction_threshold(&self) -> u32 {
        DEFAULT_EXTRACTION_THRESHOLD
    }
}
/// Condensed output produced by a `ToolResultExtractor`, together with token estimates
/// for the original and the condensed text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedResult {
    /// The condensed text that stands in for the raw tool output.
    pub content: String,
    /// Estimated token count of the original, unextracted tool output.
    pub original_tokens_est: u32,
    /// Estimated token count of `content`.
    pub extracted_tokens_est: u32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::estimate_tokens;

    /// Extractor that never produces a result and keeps the trait's default threshold.
    struct NoopExtractor;

    impl ToolResultExtractor for NoopExtractor {
        fn extract<'a>(
            &'a self,
            _tool_name: &'a str,
            _output: &'a str,
            _user_query: &'a str,
        ) -> Pin<Box<dyn Future<Output = Option<ExtractedResult>> + Send + 'a>> {
            Box::pin(async { None })
        }
    }

    /// Extractor that always returns a short synthetic summary and overrides the threshold.
    struct TestExtractor {
        threshold: u32,
    }

    impl ToolResultExtractor for TestExtractor {
        fn extract<'a>(
            &'a self,
            tool_name: &'a str,
            output: &'a str,
            user_query: &'a str,
        ) -> Pin<Box<dyn Future<Output = Option<ExtractedResult>> + Send + 'a>> {
            Box::pin(async move {
                let content = format!(
                    "[Extracted from {tool_name} for query: {user_query}] \
                     Summary of {} chars",
                    output.len()
                );
                // Estimate before `content` is moved into the result, so no clone is needed.
                let extracted_tokens_est = estimate_tokens(&content);
                Some(ExtractedResult {
                    content,
                    original_tokens_est: estimate_tokens(output),
                    extracted_tokens_est,
                })
            })
        }

        fn extraction_threshold(&self) -> u32 {
            self.threshold
        }
    }

    #[test]
    fn test_extracted_result_debug_clone() {
        let result = ExtractedResult {
            content: "test".into(),
            original_tokens_est: 100,
            extracted_tokens_est: 10,
        };
        let cloned = result.clone();
        assert_eq!(cloned.content, "test");
        assert_eq!(cloned.original_tokens_est, 100);
        assert_eq!(cloned.extracted_tokens_est, 10);
        // Compare the full Debug output, not just its length, so differing fields are caught.
        assert_eq!(format!("{result:?}"), format!("{cloned:?}"));
    }

    #[test]
    fn test_default_threshold() {
        let extractor = NoopExtractor;
        assert_eq!(extractor.extraction_threshold(), 15_000);
    }

    #[test]
    fn test_custom_threshold() {
        let extractor = TestExtractor { threshold: 5_000 };
        assert_eq!(extractor.extraction_threshold(), 5_000);
    }

    #[tokio::test]
    async fn test_noop_extractor_returns_none() {
        let extractor = NoopExtractor;
        let result = extractor.extract("web_search", "content", "query").await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_extractor_returns_condensed_content() {
        let extractor = TestExtractor { threshold: 10 };
        let output = "a".repeat(1000);
        let extracted = extractor
            .extract("web_search", &output, "weather in Tybee")
            .await
            .expect("TestExtractor always returns Some");
        assert!(extracted.content.contains("web_search"));
        assert!(extracted.content.contains("weather in Tybee"));
        assert!(extracted.extracted_tokens_est < extracted.original_tokens_est);
    }

    #[test]
    fn test_extractor_is_object_safe() {
        let extractor: Box<dyn ToolResultExtractor> = Box::new(NoopExtractor);
        assert_eq!(extractor.extraction_threshold(), 15_000);
    }

    #[tokio::test]
    async fn test_extractor_object_safe_extract() {
        let extractor: Box<dyn ToolResultExtractor> = Box::new(TestExtractor { threshold: 100 });
        // The overridden threshold must survive dynamic dispatch.
        assert_eq!(extractor.extraction_threshold(), 100);
        let extracted = extractor
            .extract("tool", "data", "query")
            .await
            .expect("TestExtractor always returns Some");
        assert!(extracted.content.contains("tool"));
    }
}