// rover/mcp/tools/count_tokens.rs

//! MCP `count_tokens` tool.
//!
//! Two input modes: `text` (in-process tokenization only) or `url` (shares
//! the cached fetch pipeline with the `fetch` tool). Exactly one of the
//! two must be provided. The body lives on [`RoverHandler`] as
//! [`RoverHandler::count_tokens_inner`]; Task 11 wires it into the
//! `#[tool_router]` surface.
//!
//! Task 12b adds an alternate `mode = "estimates"` shape (URL-only) that
//! returns four token counts in one round-trip: `raw_html` (when stored),
//! `extracted_md`, and two extractive summary estimates at ~250 and ~750
//! target tokens.

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use url::Url;

use crate::extractor::pipeline::extract;
use crate::fetcher::cached::{ExtractResult, FetchOptions, fetch_with_cache, sha256_hex};
use crate::mcp::envelope::{
    CacheStatus, CountEstimates, CountEstimatesResponse, CountResponse, CountSingleResponse,
    CountSource,
};
use crate::mcp::error::McpError;
use crate::mcp::handler::{RoverHandler, resolve_tokenizer};
use crate::storage::pages;
use crate::summarizer::backend::{CompactMode, CompactOpts, Style};
use crate::summarizer::error::SummarizerError;
use crate::tokenizer;

31/// `mode` arg for `count_tokens`. `Single` is the historical M2/M3 shape
32/// (one token count); `Estimates` is the M7 four-count shape that requires
33/// a URL and uses the extractive backend to estimate summary sizes.
34#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
35#[serde(rename_all = "snake_case")]
36pub enum CountTokensMode {
37    #[default]
38    Single,
39    Estimates,
40}
41
42/// Wire-side `count_tokens` arguments.
43///
44/// `tokenizer` is exposed as a string (Option 2, matching `FetchArgs`) so
45/// the JSON schema doesn't have to mirror the [`crate::tokenizer::Tokenizer`]
46/// enum's manual serde impls. Parsing happens via
47/// [`crate::mcp::handler::resolve_tokenizer`].
48#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
49#[serde(deny_unknown_fields)]
50pub struct CountTokensArgs {
51    #[serde(default)]
52    pub text: Option<String>,
53    #[serde(default)]
54    pub url: Option<String>,
55    #[serde(default)]
56    pub tokenizer: Option<String>,
57    #[serde(default)]
58    pub mode: CountTokensMode,
59}
60
61impl RoverHandler {
62    /// Tool body, decoupled from the `#[tool]` macro for unit testing.
63    pub async fn count_tokens_inner(
64        &self,
65        args: CountTokensArgs,
66    ) -> Result<CountResponse, McpError> {
67        match args.mode {
68            CountTokensMode::Single => self.count_tokens_single(args).await,
69            CountTokensMode::Estimates => self.count_tokens_estimates(args).await,
70        }
71    }
72
73    async fn count_tokens_single(&self, args: CountTokensArgs) -> Result<CountResponse, McpError> {
74        match (args.text.as_deref(), args.url.as_deref()) {
75            (Some(_), Some(_)) | (None, None) => {
76                return Err(McpError::InvalidArgs(
77                    "count_tokens requires exactly one of text or url".into(),
78                ));
79            }
80            _ => {}
81        }
82
83        let family = resolve_tokenizer(args.tokenizer.as_deref(), &self.config)?;
84        tokenizer::ensure_loaded(family).await?;
85
86        if let Some(text) = args.text {
87            let tokens = tokenizer::count(&text, family)?;
88            return Ok(CountResponse::Single(CountSingleResponse {
89                tokens,
90                tokenizer: family.as_str().to_string(),
91                source: CountSource::Text,
92                url: None,
93                content_hash: None,
94                fetched_at: None,
95                cache_status: None,
96            }));
97        }
98
99        // URL mode: share the cached fetch + extract pipeline.
100        let url_str = args.url.expect("validated above");
101        let url = Url::parse(&url_str).map_err(|e| McpError::InvalidUrl(e.to_string()))?;
102
103        let result = fetch_with_cache(
104            &self.db,
105            &self.client,
106            &self.pacer,
107            &self.config.rate_limit,
108            &self.config.robots,
109            &url,
110            &self.config.cache,
111            FetchOptions {
112                force_refresh: false,
113                ssrf_level: self.ssrf_level,
114                ssrf_project_root: self.ssrf_project_root.clone(),
115                har_recorder: self.har_recorder.clone(),
116                ignore_robots: false,
117                user_agent: self.config.fetch.user_agent.clone(),
118                #[cfg(feature = "headless")]
119                headless: None,
120                headless_mode: crate::fetcher::HeadlessMode::Off,
121                synchronous_revalidation: false,
122            },
123            |body, base| {
124                let extracted =
125                    extract(body, Some(base)).map_err(crate::fetcher::FetcherError::Extract)?;
126                let content_hash = format!("sha256:{}", sha256_hex(extracted.body_md.as_bytes()));
127                Ok(ExtractResult {
128                    title: extracted.title,
129                    body_md: extracted.body_md,
130                    content_hash,
131                    metadata: extracted.metadata,
132                })
133            },
134        )
135        .await?;
136
137        let tokens = tokenizer::count(&result.page.extracted_md, family)?;
138        let cache_status: CacheStatus = result.cache_status.into();
139
140        Ok(CountResponse::Single(CountSingleResponse {
141            tokens,
142            tokenizer: family.as_str().to_string(),
143            source: CountSource::Url,
144            url: Some(url.as_str().to_string()),
145            content_hash: Some(result.page.content_hash.clone()),
146            fetched_at: Some(
147                jiff::Timestamp::from_second(result.page.fetched_at)
148                    .map(|t| t.to_string())
149                    .unwrap_or_default(),
150            ),
151            cache_status: Some(cache_status),
152        }))
153    }
154
155    /// `mode = "estimates"` arm. Requires `url`; tokenizes the extracted
156    /// markdown, the (optional) cached raw HTML, and two extractive-summary
157    /// estimates at ~250 and ~750 target tokens.
158    async fn count_tokens_estimates(
159        &self,
160        args: CountTokensArgs,
161    ) -> Result<CountResponse, McpError> {
162        if args.text.is_some() {
163            return Err(McpError::InvalidArgs(
164                "count_tokens mode=\"estimates\" does not accept text; provide url".into(),
165            ));
166        }
167        let url_str = args.url.ok_or_else(|| {
168            McpError::InvalidArgs("count_tokens mode=\"estimates\" requires url".into())
169        })?;
170        let url = Url::parse(&url_str).map_err(|e| McpError::InvalidUrl(e.to_string()))?;
171
172        // Resolve the extractive backend strictly: the spec says estimates
173        // ALWAYS run on extractive (never cloud), so if no extractive fallback
174        // is configured we surface the dedicated error code.
175        let extractive_name = self
176            .summarizer
177            .registry()
178            .extractive_fallback_name()
179            .ok_or(SummarizerError::NoExtractiveBackendForFallback)?
180            .to_string();
181
182        let family = resolve_tokenizer(args.tokenizer.as_deref(), &self.config)?;
183        tokenizer::ensure_loaded(family).await?;
184
185        let result = fetch_with_cache(
186            &self.db,
187            &self.client,
188            &self.pacer,
189            &self.config.rate_limit,
190            &self.config.robots,
191            &url,
192            &self.config.cache,
193            FetchOptions {
194                force_refresh: false,
195                ssrf_level: self.ssrf_level,
196                ssrf_project_root: self.ssrf_project_root.clone(),
197                har_recorder: self.har_recorder.clone(),
198                ignore_robots: false,
199                user_agent: self.config.fetch.user_agent.clone(),
200                #[cfg(feature = "headless")]
201                headless: None,
202                headless_mode: crate::fetcher::HeadlessMode::Off,
203                synchronous_revalidation: false,
204            },
205            |body, base| {
206                let extracted =
207                    extract(body, Some(base)).map_err(crate::fetcher::FetcherError::Extract)?;
208                let content_hash = format!("sha256:{}", sha256_hex(extracted.body_md.as_bytes()));
209                Ok(ExtractResult {
210                    title: extracted.title,
211                    body_md: extracted.body_md,
212                    content_hash,
213                    metadata: extracted.metadata,
214                })
215            },
216        )
217        .await?;
218
219        let extracted_md_tokens = tokenizer::count(&result.page.extracted_md, family)?;
220
221        // Optional raw-html token count: only present when the row carries a
222        // populated `raw_html_zstd` blob AND it decodes cleanly. Any decode
223        // error degrades silently to `None`.
224        let url_hash = pages::url_hash(url.as_str());
225        let raw_html_tokens: Option<usize> = match pages::raw_html_bytes(&self.db, &url_hash).await
226        {
227            Ok(Some(blob)) => zstd::stream::decode_all(blob.as_slice())
228                .ok()
229                .and_then(|bytes| String::from_utf8(bytes).ok())
230                .and_then(|s| tokenizer::count(&s, family).ok()),
231            Ok(None) => None,
232            Err(_) => None,
233        };
234
235        // Two summary-token estimates via the extractive backend, both keyed
236        // on the same content_hash. Run sequentially: the extractive backend
237        // is CPU-only and bounded; parallelism gives no real win here.
238        let content_hash = &result.page.content_hash;
239        let extracted_md = &result.page.extracted_md;
240        let short_opts = CompactOpts {
241            mode: CompactMode::Extractive,
242            style: Style::Bullet,
243            target_tokens: Some(250),
244            focus: None,
245            preserve: vec![],
246            backend_name: extractive_name.clone(),
247        };
248        let medium_opts = CompactOpts {
249            mode: CompactMode::Extractive,
250            style: Style::Bullet,
251            target_tokens: Some(750),
252            focus: None,
253            preserve: vec![],
254            backend_name: extractive_name,
255        };
256        let short = self
257            .summarizer
258            .compact(content_hash, extracted_md, &short_opts)
259            .await?;
260        let medium = self
261            .summarizer
262            .compact(content_hash, extracted_md, &medium_opts)
263            .await?;
264        let summary_short_tokens = tokenizer::count(&short.summary_md, family)?;
265        let summary_medium_tokens = tokenizer::count(&medium.summary_md, family)?;
266
267        Ok(CountResponse::Estimates(CountEstimatesResponse {
268            url: url.as_str().to_string(),
269            tokenizer: family.as_str().to_string(),
270            estimates: CountEstimates {
271                raw_html: raw_html_tokens,
272                extracted_md: extracted_md_tokens,
273                summary_short: summary_short_tokens,
274                summary_medium: summary_medium_tokens,
275            },
276        }))
277    }
278}
279
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::*;
    use crate::mcp::envelope::RoverError;

    /// Build a handler whose db/client suffice for the synchronous
    /// validation paths. URL-mode happy-path coverage lives in
    /// `tests/mcp_smoke.rs` (Task 13). The handler-based tests below all
    /// error before `ensure_loaded` runs, so the global tokenizer registry is
    /// never touched and no network I/O happens.
    ///
    /// The returned `TempDir` owns the database directory and must be kept
    /// alive for as long as the handler is used.
    async fn fake_handler() -> (RoverHandler, tempfile::TempDir) {
        let cfg = Arc::new(crate::config::Config::default());
        crate::fetcher::client::install_ring_provider();
        let client = reqwest::Client::new();
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("rover.db");
        let db = crate::storage::Db::open(&path).await.unwrap();
        let pacer = Arc::new(crate::fetcher::concurrency::Pacer::new(&cfg.rate_limit));
        let summarizer = {
            // One extractive backend named "default", registered as both the
            // default backend and the extractive fallback.
            let mut map: HashMap<String, Arc<dyn crate::summarizer::backend::SummarizerBackend>> =
                Default::default();
            map.insert(
                "default".into(),
                Arc::new(crate::summarizer::extractive::ExtractiveBackend::new(
                    "default",
                    crate::tokenizer::Tokenizer::O200k,
                )),
            );
            let reg = Arc::new(
                crate::summarizer::registry::SummarizerRegistry::__test_construct(
                    map,
                    "default".into(),
                    Some("default".into()),
                ),
            );
            Arc::new(crate::summarizer::SummarizerService::new(
                db.clone(),
                reg,
                true,
            ))
        };
        let captioners = Arc::new(crate::vlm::CaptionerRegistry::empty());
        (
            RoverHandler::new(
                db,
                cfg,
                client,
                crate::fetcher::ssrf::SsrfLevel::Strict,
                None,
                None,
                pacer,
                summarizer,
                captioners,
                Arc::new(
                    crate::guard::Guard::from_config(
                        &crate::config::Config::default().prompt_injection,
                    )
                    .unwrap(),
                ),
                #[cfg(feature = "headless")]
                Arc::new(tokio::sync::OnceCell::new()),
            ),
            tmp,
        )
    }

    /// Runs `count_tokens_inner` on a fresh fake handler and asserts that it
    /// fails with the `INVALID_ARGS` error code.
    async fn assert_invalid_args(args: CountTokensArgs) {
        let (h, _tmp) = fake_handler().await;
        let err = h.count_tokens_inner(args).await.unwrap_err();
        let r = err.into_rover_error();
        assert_eq!(r.code, RoverError::INVALID_ARGS);
    }

    #[tokio::test]
    async fn rejects_both_text_and_url() {
        assert_invalid_args(CountTokensArgs {
            text: Some("hi".into()),
            url: Some("https://example.com".into()),
            tokenizer: None,
            mode: CountTokensMode::Single,
        })
        .await;
    }

    #[tokio::test]
    async fn rejects_neither() {
        assert_invalid_args(CountTokensArgs::default()).await;
    }

    #[tokio::test]
    async fn rejects_unknown_tokenizer() {
        assert_invalid_args(CountTokensArgs {
            text: Some("hi".into()),
            url: None,
            tokenizer: Some("gpt-5".into()),
            mode: CountTokensMode::Single,
        })
        .await;
    }

    #[tokio::test]
    async fn estimates_mode_rejects_text_arg() {
        assert_invalid_args(CountTokensArgs {
            text: Some("hi".into()),
            url: None,
            tokenizer: None,
            mode: CountTokensMode::Estimates,
        })
        .await;
    }

    #[tokio::test]
    async fn estimates_mode_requires_url() {
        assert_invalid_args(CountTokensArgs {
            text: None,
            url: None,
            tokenizer: None,
            mode: CountTokensMode::Estimates,
        })
        .await;
    }

    #[test]
    fn schema_contains_all_fields() {
        let schema = schemars::schema_for!(CountTokensArgs);
        let json = serde_json::to_value(&schema).unwrap();
        // Check the property keys themselves. A substring search over the
        // whole document also matches words in exported doc-comment
        // descriptions (e.g. "tokenizer"), so it cannot detect a missing field.
        let props = json
            .get("properties")
            .and_then(|p| p.as_object())
            .expect("object schema must expose `properties`");
        for f in ["text", "url", "tokenizer", "mode"] {
            assert!(props.contains_key(f), "missing {f}");
        }
    }

    #[test]
    fn args_deserialize_defaults_and_reject_unknown_fields() {
        // Omitted `mode` falls back to `Single`.
        let args: CountTokensArgs = serde_json::from_str(r#"{"text":"hi"}"#).unwrap();
        assert_eq!(args.mode, CountTokensMode::Single);
        assert_eq!(args.text.as_deref(), Some("hi"));
        assert!(args.url.is_none());

        // Mode names are snake_case on the wire.
        let args: CountTokensArgs =
            serde_json::from_str(r#"{"url":"https://example.com","mode":"estimates"}"#).unwrap();
        assert_eq!(args.mode, CountTokensMode::Estimates);

        // `deny_unknown_fields` rejects stray keys.
        assert!(serde_json::from_str::<CountTokensArgs>(r#"{"text":"hi","bogus":1}"#).is_err());
    }
}
431}