use std::sync::Arc;
use std::sync::OnceLock;
use rust_mcp_sdk::macros::{JsonSchema, mcp_tool};
use rust_mcp_sdk::schema::{CallToolResult, TextContent, schema_utils::CallToolError};
use serde::{Deserialize, Serialize};
use nab::watch::{AddOptions, WatchManager, WatchOptions};
/// Process-wide [`WatchManager`] shared by the watch tools.
///
/// Written once by `init_watch_manager` and read by `get_watch_manager`.
static WATCH_MANAGER: OnceLock<Arc<WatchManager>> = OnceLock::new();
/// Installs the process-wide [`WatchManager`] used by the watch tools.
///
/// Only the first call has any effect. Later calls are silently ignored, and the
/// manager installed first stays in place.
pub fn init_watch_manager(mgr: Arc<WatchManager>) {
    // `OnceLock::set` returns the rejected value on a repeat call; it is dropped on purpose.
    let _ = WATCH_MANAGER.set(mgr);
}
/// Returns a new handle (`Arc` clone) to the process-wide [`WatchManager`].
///
/// # Panics
///
/// Panics if `init_watch_manager` has not been called yet.
pub fn get_watch_manager() -> Arc<WatchManager> {
    let installed = WATCH_MANAGER
        .get()
        .expect("WatchManager not initialized — call init_watch_manager() first");
    Arc::clone(installed)
}
// Arguments for the `watch_create` MCP tool.
// NOTE(review): plain `//` comments are used on purpose. `///` doc comments may be
// folded into the schema produced by the `JsonSchema` derive; confirm before converting.
#[mcp_tool(
name = "watch_create",
description = "Create a URL watch that emits MCP resource notifications when content changes.\n\
\n\
The watch appears as a subscribable resource at `nab://watch/<id>`. Subscribe to it via\n\
`resources/subscribe` to receive `notifications/resources/updated` when the page changes.\n\
\n\
**interval**: duration string — `30s`, `5m`, `1h`, `24h` (default: `1h`).\n\
**selector**: CSS selector to watch only a specific element (e.g. `#price`, `table.pricing`).\n\
**diff_kind**: `text` (default) | `semantic` | `dom` — comparison algorithm."
)]
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
pub struct WatchCreateTool {
    // Page to watch; passed as-is to `WatchManager::add`.
    pub url: String,
    // Optional CSS selector limiting the comparison to one element.
    // Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub selector: Option<String>,
    // Polling interval such as `30s`, `5m` or `1h`. It is parsed by `parse_interval`,
    // which falls back to one hour when the value is absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub interval: Option<String>,
    // Comparison algorithm: `text`, `semantic` or `dom`. It is parsed by
    // `parse_diff_kind`, which defaults to `text`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub diff_kind: Option<String>,
}
impl WatchCreateTool {
    /// Executes the `watch_create` tool.
    ///
    /// Parses `interval` (via `parse_interval`) and `diff_kind` (via
    /// `parse_diff_kind`) and registers the URL with the shared [`WatchManager`].
    /// Replies with a Markdown summary that includes the new watch ID and its
    /// `nab://watch/<id>` resource URI.
    ///
    /// # Errors
    ///
    /// Returns a [`CallToolError`] in two cases:
    /// - `interval` or `diff_kind` is invalid (checked in that order);
    /// - the manager's `add` call fails.
    ///
    /// # Panics
    ///
    /// Panics if the global watch manager has not been initialized.
    pub async fn run(self) -> Result<CallToolResult, CallToolError> {
        let Self {
            url,
            selector,
            interval,
            diff_kind: raw_diff_kind,
        } = self;

        let interval_secs = match parse_interval(interval.as_deref()) {
            Ok(secs) => secs,
            Err(e) => {
                return Err(CallToolError::from_message(format!("Invalid interval: {e}")));
            }
        };
        let diff_kind = match parse_diff_kind(raw_diff_kind.as_deref()) {
            Ok(kind) => kind,
            Err(e) => {
                return Err(CallToolError::from_message(format!("Invalid diff_kind: {e}")));
            }
        };

        // Render the optional selector bullet now, before `selector` is moved
        // into the watch options below.
        let selector_line = match selector.as_deref() {
            Some(s) => format!("- **Selector**: `{s}`\n"),
            None => String::new(),
        };

        let add_options = AddOptions {
            selector,
            interval_secs,
            options: WatchOptions {
                diff_kind,
                ..WatchOptions::default()
            },
        };

        let manager = get_watch_manager();
        let id = match manager.add(&url, add_options).await {
            Ok(id) => id,
            Err(e) => {
                return Err(CallToolError::from_message(format!(
                    "Failed to create watch: {e}"
                )));
            }
        };

        let text = format!(
            "Watch created.\n\n\
             - **ID**: `{id}`\n\
             - **URL**: {url}\n\
             - **Resource URI**: `nab://watch/{id}`\n\
             - **Interval**: {interval_secs}s\n\
             {selector_line}\
             \n\
             Subscribe with `resources/subscribe` and URI `nab://watch/{id}` to receive \
             `notifications/resources/updated` when content changes."
        );
        Ok(CallToolResult::text_content(vec![TextContent::from(text)]))
    }
}
// Arguments for the `watch_list` MCP tool. The tool takes no parameters.
// NOTE(review): keep the braced `{}` form. A unit struct would serialize differently.
#[mcp_tool(
name = "watch_list",
description = "List all active URL watches with their IDs, URLs, and polling status."
)]
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
pub struct WatchListTool {}
impl WatchListTool {
    /// Executes the `watch_list` tool.
    ///
    /// When no watches exist, replies with a short hint pointing at `watch_create`.
    /// Otherwise replies with a Markdown report containing one section per
    /// registered watch: resource URI, interval, last check/change timestamps and
    /// snapshot count. Watches with an interval of `0` are flagged as muted.
    ///
    /// # Errors
    ///
    /// This handler never produces an error itself.
    ///
    /// # Panics
    ///
    /// Panics if the global watch manager has not been initialized.
    pub async fn run(self) -> Result<CallToolResult, CallToolError> {
        // Layout used for both the "last checked" and "last changed" rows.
        const STAMP_FORMAT: &str = "%Y-%m-%dT%H:%M:%SZ";

        let manager = get_watch_manager();
        let watches = manager.list().await;

        let report = if watches.is_empty() {
            "No watches registered. Use `watch_create` to add one.".to_owned()
        } else {
            let header = format!("## Active watches ({})\n", watches.len());
            let sections = watches.iter().map(|w| {
                let muted = if w.interval_secs == 0 {
                    " **(muted)**"
                } else {
                    ""
                };
                let last_check = match &w.last_check_at {
                    Some(t) => t.format(STAMP_FORMAT).to_string(),
                    None => "never".to_owned(),
                };
                let last_change = match &w.last_change_at {
                    Some(t) => t.format(STAMP_FORMAT).to_string(),
                    None => "never".to_owned(),
                };
                format!(
                    "### `{id}` — {url}\n\
                     - **Resource URI**: `nab://watch/{id}`\n\
                     - **Interval**: {interval}s{muted}\n\
                     - **Last checked**: {last_check}\n\
                     - **Last changed**: {last_change}\n\
                     - **Snapshots**: {snaps}\n",
                    id = w.id,
                    url = w.url,
                    interval = w.interval_secs,
                    snaps = w.snapshots.len(),
                )
            });
            std::iter::once(header)
                .chain(sections)
                .collect::<Vec<String>>()
                .join("\n")
        };

        Ok(CallToolResult::text_content(vec![TextContent::from(report)]))
    }
}
// Arguments for the `watch_remove` MCP tool.
#[mcp_tool(
name = "watch_remove",
description = "Remove a URL watch by ID. Use `watch_list` to find the ID."
)]
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
pub struct WatchRemoveTool {
    // ID of the watch to remove, as shown by `watch_create` and `watch_list`.
    pub id: String,
}
impl WatchRemoveTool {
    /// Executes the `watch_remove` tool, deleting the watch with the given ID.
    ///
    /// # Errors
    ///
    /// Returns a [`CallToolError`] that names the ID if the manager's `remove`
    /// call fails.
    ///
    /// # Panics
    ///
    /// Panics if the global watch manager has not been initialized.
    pub async fn run(self) -> Result<CallToolResult, CallToolError> {
        let Self { id } = self;
        let manager = get_watch_manager();
        if let Err(e) = manager.remove(&id).await {
            return Err(CallToolError::from_message(format!(
                "Failed to remove watch '{id}': {e}"
            )));
        }
        let confirmation = format!("Watch `{id}` removed.");
        Ok(CallToolResult::text_content(vec![TextContent::from(
            confirmation,
        )]))
    }
}
/// Parses a human-friendly polling interval into whole seconds.
///
/// Accepted forms (surrounding whitespace is ignored):
///
/// * `None`, `""` or whitespace only — the default of 3600 seconds (1 hour);
/// * `<n>s`, `<n>m`, `<n>h` — `n` seconds, minutes or hours (a space before
///   the unit is allowed, e.g. `5 m`);
/// * `<n>` — `n` seconds.
///
/// # Errors
///
/// Returns a human-readable message in three cases:
/// * the numeric part is not a valid `u64`;
/// * the unit suffix is not recognised;
/// * converting minutes or hours to seconds would overflow `u64`.
fn parse_interval(s: Option<&str>) -> Result<u64, String> {
    // Interval used when the caller supplies none.
    const DEFAULT_SECS: u64 = 3600;
    // Recognised suffixes, checked in this order: (suffix, seconds per unit, unit name).
    const UNITS: [(char, u64, &str); 3] = [
        ('s', 1, "seconds"),
        ('m', 60, "minutes"),
        ('h', 3600, "hours"),
    ];

    // Trim before the emptiness check so whitespace-only input also gets the default.
    let s = match s.map(str::trim) {
        None | Some("") => return Ok(DEFAULT_SECS),
        Some(s) => s,
    };

    for &(suffix, secs_per_unit, unit) in &UNITS {
        if let Some(rest) = s.strip_suffix(suffix) {
            let rest = rest.trim_end();
            let value = rest
                .parse::<u64>()
                .map_err(|_| format!("bad {unit} value: '{rest}'"))?;
            // Plain `*` would panic in debug builds and silently wrap in release
            // for huge inputs; report those as errors instead.
            return value
                .checked_mul(secs_per_unit)
                .ok_or_else(|| format!("{unit} value too large: '{rest}'"));
        }
    }

    s.parse::<u64>()
        .map_err(|_| format!("unrecognised duration: '{s}'"))
}
/// Parses the `diff_kind` tool argument into a [`nab::watch::DiffKind`].
///
/// Mapping:
///
/// * `None`, an empty or whitespace-only string, or `text` map to `DiffKind::Text`;
/// * `semantic` maps to `DiffKind::Semantic`;
/// * `dom` maps to `DiffKind::Dom`.
///
/// Matching ignores surrounding whitespace and ASCII case. The trimming is
/// consistent with `parse_interval`, which also trims its input.
///
/// # Errors
///
/// Returns a message listing the accepted values for any other input.
fn parse_diff_kind(s: Option<&str>) -> Result<nab::watch::DiffKind, String> {
    use nab::watch::DiffKind;
    let raw = s.unwrap_or("text").trim();
    // Tool arguments often arrive with odd casing ("Text", "DOM"); accept them.
    let normalized = raw.to_ascii_lowercase();
    match normalized.as_str() {
        "text" | "" => Ok(DiffKind::Text),
        "semantic" => Ok(DiffKind::Semantic),
        "dom" => Ok(DiffKind::Dom),
        _ => Err(format!(
            "unknown diff_kind '{raw}'; use text | semantic | dom"
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use nab::watch::DiffKind;

    #[test]
    fn parse_interval_seconds() {
        let secs = parse_interval(Some("30s"));
        assert_eq!(secs, Ok(30));
    }

    #[test]
    fn parse_interval_minutes() {
        let secs = parse_interval(Some("5m"));
        assert_eq!(secs, Ok(5 * 60));
    }

    #[test]
    fn parse_interval_hours() {
        let secs = parse_interval(Some("1h"));
        assert_eq!(secs, Ok(60 * 60));
    }

    #[test]
    fn parse_interval_none_defaults_to_1h() {
        let secs = parse_interval(None);
        assert_eq!(secs, Ok(3600));
    }

    #[test]
    fn parse_interval_plain_int() {
        let secs = parse_interval(Some("120"));
        assert_eq!(secs, Ok(120));
    }

    #[test]
    fn parse_interval_bad_value_errors() {
        let result = parse_interval(Some("abc"));
        assert!(result.is_err(), "expected an error, got {result:?}");
    }

    #[test]
    fn parse_diff_kind_defaults_to_text() {
        assert!(matches!(parse_diff_kind(None), Ok(DiffKind::Text)));
        assert!(matches!(parse_diff_kind(Some("text")), Ok(DiffKind::Text)));
    }

    #[test]
    fn parse_diff_kind_semantic() {
        assert!(matches!(
            parse_diff_kind(Some("semantic")),
            Ok(DiffKind::Semantic)
        ));
    }

    #[test]
    fn parse_diff_kind_unknown_errors() {
        let result = parse_diff_kind(Some("fuzzy"));
        assert!(result.is_err());
    }
}