// llm_stack/tool/extractor.rs

//! Async semantic extraction for large tool results.
//!
//! The [`ToolResultExtractor`] trait provides an async post-processing stage
//! that runs **after** structural pruning ([`ToolResultProcessor`]) but
//! **before** out-of-context caching ([`ToolResultCacher`]).
//!
//! Use this for heavyweight transformations that require async work, such as
//! calling a fast/cheap LLM (Haiku-class) to extract task-relevant information
//! from large tool results.
//!
//! The extractor receives the tool name, the (already-pruned) output, and the
//! last user message for relevance-guided extraction. It returns an
//! [`ExtractedResult`] with the condensed content.
//!
//! # Pipeline position
//!
//! ```text
//!   Tool executes
//!        │
//!   Stage 1: ToolResultProcessor::process()   (sync, structural)
//!        │
//!   Stage 2: ToolResultExtractor::extract()   (async, semantic)
//!        │
//!   Stage 3: ToolResultCacher::cache()        (sync, overflow)
//!        │
//!   Result enters conversation context
//! ```
//!
//! # Example
//!
//! ```rust
//! use llm_stack::tool::{ToolResultExtractor, ExtractedResult};
//! use llm_stack::context::estimate_tokens;
//! use std::future::Future;
//! use std::pin::Pin;
//!
//! struct KeywordExtractor;
//!
//! impl ToolResultExtractor for KeywordExtractor {
//!     fn extract<'a>(
//!         &'a self,
//!         tool_name: &'a str,
//!         output: &'a str,
//!         user_query: &'a str,
//!     ) -> Pin<Box<dyn Future<Output = Option<ExtractedResult>> + Send + 'a>> {
//!         Box::pin(async move {
//!             if tool_name != "web_search" || output.len() < 10_000 {
//!                 return None;
//!             }
//!             let extracted = format!("Extracted relevant info about: {user_query}");
//!             Some(ExtractedResult {
//!                 content: extracted.clone(),
//!                 original_tokens_est: estimate_tokens(output),
//!                 extracted_tokens_est: estimate_tokens(&extracted),
//!             })
//!         })
//!     }
//!
//!     fn extraction_threshold(&self) -> u32 {
//!         15_000
//!     }
//! }
//! ```
64
use std::future::Future;
use std::pin::Pin;
67
68/// Async semantic extractor for oversized tool results.
69///
70/// Implementations run after structural pruning and can perform heavyweight
71/// async work (e.g., calling a fast LLM) to condense large results into
72/// task-relevant summaries.
73///
74/// The extractor receives the last user message as context to guide
75/// relevance-based extraction.
76pub trait ToolResultExtractor: Send + Sync {
77    /// Extract task-relevant information from a tool result.
78    ///
79    /// # Arguments
80    ///
81    /// * `tool_name` — The name of the tool that produced this result.
82    /// * `output` — The (already structurally pruned) output string.
83    /// * `user_query` — The most recent user message, for relevance guidance.
84    ///
85    /// Return `None` to skip extraction (keep the structurally-pruned content).
86    /// The extractor is only called for results exceeding
87    /// [`extraction_threshold`](Self::extraction_threshold) tokens.
88    fn extract<'a>(
89        &'a self,
90        tool_name: &'a str,
91        output: &'a str,
92        user_query: &'a str,
93    ) -> Pin<Box<dyn Future<Output = Option<ExtractedResult>> + Send + 'a>>;
94
95    /// Token threshold above which results are offered to the extractor.
96    ///
97    /// Results at or below this size skip semantic extraction entirely.
98    /// Default: 15 000 tokens (~60 000 chars).
99    fn extraction_threshold(&self) -> u32 {
100        15_000
101    }
102}
103
/// The result of semantic extraction.
///
/// Carries the condensed text together with before/after token estimates.
/// Equality compares all three fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedResult {
    /// The condensed, task-relevant content.
    pub content: String,
    /// Estimated token count of the pre-extraction content.
    pub original_tokens_est: u32,
    /// Estimated token count of the extracted content.
    pub extracted_tokens_est: u32,
}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::estimate_tokens;

    /// Extractor that always declines and keeps the trait's default threshold.
    struct NoopExtractor;

    impl ToolResultExtractor for NoopExtractor {
        fn extract<'a>(
            &'a self,
            _tool_name: &'a str,
            _output: &'a str,
            _user_query: &'a str,
        ) -> Pin<Box<dyn Future<Output = Option<ExtractedResult>> + Send + 'a>> {
            Box::pin(async { None })
        }
    }

    /// Extractor that always produces a short summary string and reports a
    /// caller-chosen threshold.
    struct TestExtractor {
        threshold: u32,
    }

    impl ToolResultExtractor for TestExtractor {
        fn extract<'a>(
            &'a self,
            tool_name: &'a str,
            output: &'a str,
            user_query: &'a str,
        ) -> Pin<Box<dyn Future<Output = Option<ExtractedResult>> + Send + 'a>> {
            Box::pin(async move {
                let extracted = format!(
                    "[Extracted from {tool_name} for query: {user_query}] \
                     Summary of {} chars",
                    output.len()
                );
                // Estimate before moving `extracted` into the result so no
                // clone of the string is needed.
                let extracted_tokens_est = estimate_tokens(&extracted);
                Some(ExtractedResult {
                    content: extracted,
                    original_tokens_est: estimate_tokens(output),
                    extracted_tokens_est,
                })
            })
        }

        fn extraction_threshold(&self) -> u32 {
            self.threshold
        }
    }

    #[test]
    fn test_extracted_result_debug_clone() {
        let result = ExtractedResult {
            content: "test".into(),
            original_tokens_est: 100,
            extracted_tokens_est: 10,
        };
        let cloned = result.clone();
        assert_eq!(cloned.content, "test");
        assert_eq!(cloned.original_tokens_est, 100);
        assert_eq!(cloned.extracted_tokens_est, 10);
        // Compare the full debug rendering, not just its length.
        assert_eq!(format!("{result:?}"), format!("{cloned:?}"));
    }

    #[test]
    fn test_default_threshold() {
        let extractor = NoopExtractor;
        assert_eq!(extractor.extraction_threshold(), 15_000);
    }

    #[test]
    fn test_custom_threshold() {
        let extractor = TestExtractor { threshold: 5_000 };
        assert_eq!(extractor.extraction_threshold(), 5_000);
    }

    #[tokio::test]
    async fn test_noop_extractor_returns_none() {
        let extractor = NoopExtractor;
        let result = extractor.extract("web_search", "content", "query").await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_extractor_returns_condensed_content() {
        let extractor = TestExtractor { threshold: 10 };
        let output = "a".repeat(1000);
        let extracted = extractor
            .extract("web_search", &output, "weather in Tybee")
            .await
            .expect("TestExtractor always returns Some");
        assert!(extracted.content.contains("web_search"));
        assert!(extracted.content.contains("weather in Tybee"));
        assert!(extracted.extracted_tokens_est < extracted.original_tokens_est);
    }

    #[test]
    fn test_extractor_is_object_safe() {
        let extractor: Box<dyn ToolResultExtractor> = Box::new(NoopExtractor);
        assert_eq!(extractor.extraction_threshold(), 15_000);
    }

    #[tokio::test]
    async fn test_extractor_object_safe_extract() {
        let extractor: Box<dyn ToolResultExtractor> = Box::new(TestExtractor { threshold: 100 });
        let result = extractor.extract("tool", "data", "query").await;
        assert!(result.is_some());
    }
}