1use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use url::Url;
17
18use crate::extractor::pipeline::extract;
19use crate::fetcher::cached::{ExtractResult, FetchOptions, fetch_with_cache, sha256_hex};
20use crate::mcp::envelope::{
21 CacheStatus, CountEstimates, CountEstimatesResponse, CountResponse, CountSingleResponse,
22 CountSource,
23};
24use crate::mcp::error::McpError;
25use crate::mcp::handler::{RoverHandler, resolve_tokenizer};
26use crate::storage::pages;
27use crate::summarizer::backend::{CompactMode, CompactOpts, Style};
28use crate::summarizer::error::SummarizerError;
29use crate::tokenizer;
30
/// Selects which kind of response `count_tokens` produces.
///
/// Serialized in snake_case (`"single"`, `"estimates"`); `single` is the
/// default when the field is omitted.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum CountTokensMode {
    /// Count the tokens of one input: either literal text or the extracted
    /// Markdown of a fetched URL.
    #[default]
    Single,
    /// For a URL, report token counts for several representations of the
    /// page: raw HTML (when available), extracted Markdown, and a short and
    /// a medium extractive summary.
    Estimates,
}
41
/// Arguments accepted by the `count_tokens` tool.
///
/// Unknown fields are rejected at deserialization time. Which of `text` and
/// `url` may be set depends on `mode`; that rule is enforced by the handler,
/// not by this type.
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct CountTokensArgs {
    /// Literal text to count. In `single` mode exactly one of `text` or `url`
    /// must be provided; `estimates` mode does not accept `text`.
    #[serde(default)]
    pub text: Option<String>,
    /// URL of a page to fetch; the tokens of its extracted Markdown are
    /// counted. Required in `estimates` mode.
    #[serde(default)]
    pub url: Option<String>,
    /// Name of the tokenizer family to count with. Unrecognized names are
    /// rejected as invalid arguments; when omitted, the handler's tokenizer
    /// resolution (which is given the server configuration) picks one.
    #[serde(default)]
    pub tokenizer: Option<String>,
    /// Which response to produce; defaults to `single`.
    #[serde(default)]
    pub mode: CountTokensMode,
}
60
61impl RoverHandler {
62 pub async fn count_tokens_inner(
64 &self,
65 args: CountTokensArgs,
66 ) -> Result<CountResponse, McpError> {
67 match args.mode {
68 CountTokensMode::Single => self.count_tokens_single(args).await,
69 CountTokensMode::Estimates => self.count_tokens_estimates(args).await,
70 }
71 }
72
73 async fn count_tokens_single(&self, args: CountTokensArgs) -> Result<CountResponse, McpError> {
74 match (args.text.as_deref(), args.url.as_deref()) {
75 (Some(_), Some(_)) | (None, None) => {
76 return Err(McpError::InvalidArgs(
77 "count_tokens requires exactly one of text or url".into(),
78 ));
79 }
80 _ => {}
81 }
82
83 let family = resolve_tokenizer(args.tokenizer.as_deref(), &self.config)?;
84 tokenizer::ensure_loaded(family).await?;
85
86 if let Some(text) = args.text {
87 let tokens = tokenizer::count(&text, family)?;
88 return Ok(CountResponse::Single(CountSingleResponse {
89 tokens,
90 tokenizer: family.as_str().to_string(),
91 source: CountSource::Text,
92 url: None,
93 content_hash: None,
94 fetched_at: None,
95 cache_status: None,
96 }));
97 }
98
99 let url_str = args.url.expect("validated above");
101 let url = Url::parse(&url_str).map_err(|e| McpError::InvalidUrl(e.to_string()))?;
102
103 let result = fetch_with_cache(
104 &self.db,
105 &self.client,
106 &self.pacer,
107 &self.config.rate_limit,
108 &self.config.robots,
109 &url,
110 &self.config.cache,
111 FetchOptions {
112 force_refresh: false,
113 ssrf_level: self.ssrf_level,
114 ssrf_project_root: self.ssrf_project_root.clone(),
115 har_recorder: self.har_recorder.clone(),
116 ignore_robots: false,
117 user_agent: self.config.fetch.user_agent.clone(),
118 #[cfg(feature = "headless")]
119 headless: None,
120 headless_mode: crate::fetcher::HeadlessMode::Off,
121 synchronous_revalidation: false,
122 },
123 |body, base| {
124 let extracted =
125 extract(body, Some(base)).map_err(crate::fetcher::FetcherError::Extract)?;
126 let content_hash = format!("sha256:{}", sha256_hex(extracted.body_md.as_bytes()));
127 Ok(ExtractResult {
128 title: extracted.title,
129 body_md: extracted.body_md,
130 content_hash,
131 metadata: extracted.metadata,
132 })
133 },
134 )
135 .await?;
136
137 let tokens = tokenizer::count(&result.page.extracted_md, family)?;
138 let cache_status: CacheStatus = result.cache_status.into();
139
140 Ok(CountResponse::Single(CountSingleResponse {
141 tokens,
142 tokenizer: family.as_str().to_string(),
143 source: CountSource::Url,
144 url: Some(url.as_str().to_string()),
145 content_hash: Some(result.page.content_hash.clone()),
146 fetched_at: Some(
147 jiff::Timestamp::from_second(result.page.fetched_at)
148 .map(|t| t.to_string())
149 .unwrap_or_default(),
150 ),
151 cache_status: Some(cache_status),
152 }))
153 }
154
155 async fn count_tokens_estimates(
159 &self,
160 args: CountTokensArgs,
161 ) -> Result<CountResponse, McpError> {
162 if args.text.is_some() {
163 return Err(McpError::InvalidArgs(
164 "count_tokens mode=\"estimates\" does not accept text; provide url".into(),
165 ));
166 }
167 let url_str = args.url.ok_or_else(|| {
168 McpError::InvalidArgs("count_tokens mode=\"estimates\" requires url".into())
169 })?;
170 let url = Url::parse(&url_str).map_err(|e| McpError::InvalidUrl(e.to_string()))?;
171
172 let extractive_name = self
176 .summarizer
177 .registry()
178 .extractive_fallback_name()
179 .ok_or(SummarizerError::NoExtractiveBackendForFallback)?
180 .to_string();
181
182 let family = resolve_tokenizer(args.tokenizer.as_deref(), &self.config)?;
183 tokenizer::ensure_loaded(family).await?;
184
185 let result = fetch_with_cache(
186 &self.db,
187 &self.client,
188 &self.pacer,
189 &self.config.rate_limit,
190 &self.config.robots,
191 &url,
192 &self.config.cache,
193 FetchOptions {
194 force_refresh: false,
195 ssrf_level: self.ssrf_level,
196 ssrf_project_root: self.ssrf_project_root.clone(),
197 har_recorder: self.har_recorder.clone(),
198 ignore_robots: false,
199 user_agent: self.config.fetch.user_agent.clone(),
200 #[cfg(feature = "headless")]
201 headless: None,
202 headless_mode: crate::fetcher::HeadlessMode::Off,
203 synchronous_revalidation: false,
204 },
205 |body, base| {
206 let extracted =
207 extract(body, Some(base)).map_err(crate::fetcher::FetcherError::Extract)?;
208 let content_hash = format!("sha256:{}", sha256_hex(extracted.body_md.as_bytes()));
209 Ok(ExtractResult {
210 title: extracted.title,
211 body_md: extracted.body_md,
212 content_hash,
213 metadata: extracted.metadata,
214 })
215 },
216 )
217 .await?;
218
219 let extracted_md_tokens = tokenizer::count(&result.page.extracted_md, family)?;
220
221 let url_hash = pages::url_hash(url.as_str());
225 let raw_html_tokens: Option<usize> = match pages::raw_html_bytes(&self.db, &url_hash).await
226 {
227 Ok(Some(blob)) => zstd::stream::decode_all(blob.as_slice())
228 .ok()
229 .and_then(|bytes| String::from_utf8(bytes).ok())
230 .and_then(|s| tokenizer::count(&s, family).ok()),
231 Ok(None) => None,
232 Err(_) => None,
233 };
234
235 let content_hash = &result.page.content_hash;
239 let extracted_md = &result.page.extracted_md;
240 let short_opts = CompactOpts {
241 mode: CompactMode::Extractive,
242 style: Style::Bullet,
243 target_tokens: Some(250),
244 focus: None,
245 preserve: vec![],
246 backend_name: extractive_name.clone(),
247 };
248 let medium_opts = CompactOpts {
249 mode: CompactMode::Extractive,
250 style: Style::Bullet,
251 target_tokens: Some(750),
252 focus: None,
253 preserve: vec![],
254 backend_name: extractive_name,
255 };
256 let short = self
257 .summarizer
258 .compact(content_hash, extracted_md, &short_opts)
259 .await?;
260 let medium = self
261 .summarizer
262 .compact(content_hash, extracted_md, &medium_opts)
263 .await?;
264 let summary_short_tokens = tokenizer::count(&short.summary_md, family)?;
265 let summary_medium_tokens = tokenizer::count(&medium.summary_md, family)?;
266
267 Ok(CountResponse::Estimates(CountEstimatesResponse {
268 url: url.as_str().to_string(),
269 tokenizer: family.as_str().to_string(),
270 estimates: CountEstimates {
271 raw_html: raw_html_tokens,
272 extracted_md: extracted_md_tokens,
273 summary_short: summary_short_tokens,
274 summary_medium: summary_medium_tokens,
275 },
276 }))
277 }
278}
279
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::*;
    use crate::mcp::envelope::RoverError;
    use crate::summarizer::backend::SummarizerBackend;

    /// Builds a handler backed by a fresh on-disk database in a temp dir.
    ///
    /// The `TempDir` is returned alongside the handler so the directory (and
    /// the database file inside it) outlives the test body.
    async fn fake_handler() -> (RoverHandler, tempfile::TempDir) {
        let cfg = Arc::new(crate::config::Config::default());
        crate::fetcher::client::install_ring_provider();
        let client = reqwest::Client::new();

        let tmp = tempfile::tempdir().unwrap();
        let db_path = tmp.path().join("rover.db");
        let db = crate::storage::Db::open(&db_path).await.unwrap();
        let pacer = Arc::new(crate::fetcher::concurrency::Pacer::new(&cfg.rate_limit));

        // One extractive backend named "default", registered both as the
        // default backend and as the extractive fallback.
        let mut backends: HashMap<String, Arc<dyn SummarizerBackend>> = HashMap::new();
        backends.insert(
            "default".into(),
            Arc::new(crate::summarizer::extractive::ExtractiveBackend::new(
                "default",
                crate::tokenizer::Tokenizer::O200k,
            )),
        );
        let registry = Arc::new(
            crate::summarizer::registry::SummarizerRegistry::__test_construct(
                backends,
                "default".into(),
                Some("default".into()),
            ),
        );
        let summarizer = Arc::new(crate::summarizer::SummarizerService::new(
            db.clone(),
            registry,
            true,
        ));

        let captioners = Arc::new(crate::vlm::CaptionerRegistry::empty());
        let guard = Arc::new(
            crate::guard::Guard::from_config(&crate::config::Config::default().prompt_injection)
                .unwrap(),
        );

        let handler = RoverHandler::new(
            db,
            cfg,
            client,
            crate::fetcher::ssrf::SsrfLevel::Strict,
            None,
            None,
            pacer,
            summarizer,
            captioners,
            guard,
            #[cfg(feature = "headless")]
            Arc::new(tokio::sync::OnceCell::new()),
        );
        (handler, tmp)
    }

    /// Runs `count_tokens_inner` with `args` on a fresh handler and asserts
    /// that it fails with the `INVALID_ARGS` error code.
    async fn assert_invalid_args(args: CountTokensArgs) {
        let (handler, _tmp) = fake_handler().await;
        let err = handler.count_tokens_inner(args).await.unwrap_err();
        let rover_err = err.into_rover_error();
        assert_eq!(rover_err.code, RoverError::INVALID_ARGS);
    }

    #[tokio::test]
    async fn rejects_both_text_and_url() {
        assert_invalid_args(CountTokensArgs {
            text: Some("hi".into()),
            url: Some("https://example.com".into()),
            tokenizer: None,
            mode: CountTokensMode::Single,
        })
        .await;
    }

    #[tokio::test]
    async fn rejects_neither() {
        assert_invalid_args(CountTokensArgs::default()).await;
    }

    #[tokio::test]
    async fn rejects_unknown_tokenizer() {
        assert_invalid_args(CountTokensArgs {
            text: Some("hi".into()),
            url: None,
            tokenizer: Some("gpt-5".into()),
            mode: CountTokensMode::Single,
        })
        .await;
    }

    #[tokio::test]
    async fn estimates_mode_rejects_text_arg() {
        assert_invalid_args(CountTokensArgs {
            text: Some("hi".into()),
            url: None,
            tokenizer: None,
            mode: CountTokensMode::Estimates,
        })
        .await;
    }

    #[tokio::test]
    async fn estimates_mode_requires_url() {
        assert_invalid_args(CountTokensArgs {
            text: None,
            url: None,
            tokenizer: None,
            mode: CountTokensMode::Estimates,
        })
        .await;
    }

    #[test]
    fn schema_contains_all_fields() {
        let json = serde_json::to_string(&schemars::schema_for!(CountTokensArgs)).unwrap();
        for field in ["text", "url", "tokenizer", "mode"] {
            assert!(json.contains(field), "missing {field}");
        }
    }
}