use crate::client::{fetch_with_options, FetchOptions};
use crate::dns::DnsPolicy;
use crate::error::FetchError;
use crate::fetchers::FetcherRegistry;
use crate::file_saver::FileSaver;
use crate::types::{FetchRequest, FetchResponse};
use crate::{build_llmtxt, TOOL_DESCRIPTION_BASE, TOOL_DESCRIPTION_SAVE};
use schemars::schema_for;
use serde::{Deserialize, Serialize};
/// Progress report handed to the status callback while a tool invocation runs.
///
/// Optional fields that are `None` are skipped when the value is serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolStatus {
/// Name of the phase being reported; this file emits `"validate"`,
/// `"connect"`, `"fetch"` and `"complete"`.
pub phase: String,
/// Optional free-form text accompanying the phase.
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
/// Optional completion figure; this file reports values from 0.0 to 100.0.
#[serde(skip_serializing_if = "Option::is_none")]
pub percent_complete: Option<f32>,
/// Optional ETA value; presumably the estimated time remaining in
/// milliseconds, going by the field name — TODO confirm with consumers.
#[serde(skip_serializing_if = "Option::is_none")]
pub eta_ms: Option<u64>,
}
impl ToolStatus {
    /// Creates a status for `phase` with no message, percentage, or ETA set.
    pub fn new(phase: impl Into<String>) -> Self {
        let phase = phase.into();
        Self {
            phase,
            message: None,
            percent_complete: None,
            eta_ms: None,
        }
    }

    /// Returns the status with `message` attached, replacing any earlier one.
    pub fn with_message(self, message: impl Into<String>) -> Self {
        Self {
            message: Some(message.into()),
            ..self
        }
    }

    /// Returns the status with its completion figure set to `percent`.
    ///
    /// The value is stored as given; no range check is applied.
    pub fn with_percent(self, percent: f32) -> Self {
        Self {
            percent_complete: Some(percent),
            ..self
        }
    }

    /// Returns the status with its ETA field set to `eta_ms`.
    pub fn with_eta(self, eta_ms: u64) -> Self {
        Self {
            eta_ms: Some(eta_ms),
            ..self
        }
    }
}
/// Builder that collects configuration and turns it into a [`Tool`] via `build`.
///
/// Note: the derived `Default` leaves every boolean flag `false`, whereas
/// `ToolBuilder::new()` switches markdown and text on. Prefer `new()` (or
/// `Tool::builder()`) unless both are meant to start disabled.
#[derive(Debug, Clone, Default)]
pub struct ToolBuilder {
/// Markdown flag copied into the built `Tool`.
enable_markdown: bool,
/// Text flag copied into the built `Tool`.
enable_text: bool,
/// Optional user agent string; `None` until `user_agent` is called.
user_agent: Option<String>,
/// Prefixes accumulated by `allow_prefix`, in call order.
allow_prefixes: Vec<String>,
/// Prefixes accumulated by `block_prefix`, in call order.
block_prefixes: Vec<String>,
/// DNS policy; starts as `DnsPolicy::default()` and is replaced by `block_private_ips`.
dns_policy: DnsPolicy,
/// Optional body size limit; `None` until `max_body_size` is called.
max_body_size: Option<usize>,
/// Save-to-file flag copied into the built `Tool`.
enable_save_to_file: bool,
}
impl ToolBuilder {
pub fn new() -> Self {
Self {
enable_markdown: true,
enable_text: true,
..Default::default()
}
}
pub fn enable_markdown(mut self, enable: bool) -> Self {
self.enable_markdown = enable;
self
}
pub fn enable_text(mut self, enable: bool) -> Self {
self.enable_text = enable;
self
}
pub fn user_agent(mut self, ua: impl Into<String>) -> Self {
self.user_agent = Some(ua.into());
self
}
pub fn allow_prefix(mut self, prefix: impl Into<String>) -> Self {
self.allow_prefixes.push(prefix.into());
self
}
pub fn block_prefix(mut self, prefix: impl Into<String>) -> Self {
self.block_prefixes.push(prefix.into());
self
}
pub fn max_body_size(mut self, size: usize) -> Self {
self.max_body_size = Some(size);
self
}
pub fn enable_save_to_file(mut self, enable: bool) -> Self {
self.enable_save_to_file = enable;
self
}
pub fn block_private_ips(mut self, block: bool) -> Self {
self.dns_policy = if block {
DnsPolicy::block_private_ips()
} else {
DnsPolicy::allow_all()
};
self
}
pub fn build(self) -> Tool {
Tool {
enable_markdown: self.enable_markdown,
enable_text: self.enable_text,
user_agent: self.user_agent,
allow_prefixes: self.allow_prefixes,
block_prefixes: self.block_prefixes,
dns_policy: self.dns_policy,
max_body_size: self.max_body_size,
enable_save_to_file: self.enable_save_to_file,
}
}
}
/// A configured fetch tool; build one with `Tool::builder()` or `Tool::default()`.
#[derive(Debug, Clone)]
pub struct Tool {
/// When false, `as_markdown` is removed from the input schema; also passed on in `FetchOptions`.
enable_markdown: bool,
/// When false, `as_text` is removed from the input schema; also passed on in `FetchOptions`.
enable_text: bool,
/// Optional user agent passed on in `FetchOptions`.
user_agent: Option<String>,
/// Allow-list prefixes passed on in `FetchOptions`; presumably URL prefixes — confirm against the fetch code.
allow_prefixes: Vec<String>,
/// Block-list prefixes passed on in `FetchOptions`; presumably URL prefixes — confirm against the fetch code.
block_prefixes: Vec<String>,
/// DNS policy passed on in `FetchOptions`.
dns_policy: DnsPolicy,
/// Optional body size limit passed on in `FetchOptions`.
max_body_size: Option<usize>,
/// Gates the `save_to_file` schema property, the extra description text,
/// and the save path in `execute_with_saver`.
enable_save_to_file: bool,
}
impl Default for Tool {
fn default() -> Self {
ToolBuilder::new().build()
}
}
impl Tool {
    /// Returns a fresh [`ToolBuilder`] (same as `ToolBuilder::new()`).
    pub fn builder() -> ToolBuilder {
        ToolBuilder::new()
    }

    /// Returns the tool description: the base text, followed by the
    /// save-to-file text when that feature is enabled.
    pub fn description(&self) -> String {
        let mut text = TOOL_DESCRIPTION_BASE.to_string();
        if self.enable_save_to_file {
            text.push_str(TOOL_DESCRIPTION_SAVE);
        }
        text
    }

    /// Returns the system prompt, which is currently always empty.
    pub fn system_prompt(&self) -> String {
        String::new()
    }

    /// Returns the llmtxt document produced by `build_llmtxt` for the
    /// current save-to-file setting.
    pub fn llmtxt(&self) -> String {
        build_llmtxt(self.enable_save_to_file)
    }

    /// Returns the JSON schema of [`FetchRequest`] with the properties of
    /// disabled features removed.
    ///
    /// If the schema cannot be converted to a JSON value, the default
    /// `serde_json::Value` is returned instead.
    pub fn input_schema(&self) -> serde_json::Value {
        let mut schema = serde_json::to_value(schema_for!(FetchRequest)).unwrap_or_default();

        // Each entry pairs a feature flag with the request property it gates.
        let gated_properties = [
            (self.enable_markdown, "as_markdown"),
            (self.enable_text, "as_text"),
            (self.enable_save_to_file, "save_to_file"),
        ];

        if let Some(properties) = schema
            .get_mut("properties")
            .and_then(|node| node.as_object_mut())
        {
            for &(enabled, key) in gated_properties.iter() {
                if !enabled {
                    properties.remove(key);
                }
            }
        }
        schema
    }

    /// Returns the JSON schema of [`FetchResponse`], or the default
    /// `serde_json::Value` if the conversion fails.
    pub fn output_schema(&self) -> serde_json::Value {
        serde_json::to_value(schema_for!(FetchResponse)).unwrap_or_default()
    }

    /// Runs `req` through `fetch_with_options` using this tool's settings.
    pub async fn execute(&self, req: FetchRequest) -> Result<FetchResponse, FetchError> {
        let options = self.build_options();
        fetch_with_options(req, options).await
    }

    /// Like [`Tool::execute`], but reports progress through `status_callback`.
    ///
    /// Phases are emitted in order: `validate` (0%), `connect` (10%),
    /// `fetch` (20%), then `complete` (100%) once the fetch returns, whether
    /// it succeeded or failed.
    ///
    /// # Errors
    ///
    /// Returns `FetchError::MissingUrl` for an empty URL and
    /// `FetchError::InvalidUrlScheme` when the URL starts with neither
    /// `http://` nor `https://`. In both cases only the `validate` status
    /// has been emitted. Any error from the fetch itself is passed through.
    pub async fn execute_with_status<F>(
        &self,
        req: FetchRequest,
        mut status_callback: F,
    ) -> Result<FetchResponse, FetchError>
    where
        F: FnMut(ToolStatus),
    {
        status_callback(ToolStatus::new("validate").with_percent(0.0));
        if req.url.is_empty() {
            return Err(FetchError::MissingUrl);
        }
        let has_http_scheme = req.url.starts_with("http://") || req.url.starts_with("https://");
        if !has_http_scheme {
            return Err(FetchError::InvalidUrlScheme);
        }

        status_callback(ToolStatus::new("connect").with_percent(10.0));
        status_callback(ToolStatus::new("fetch").with_percent(20.0));

        let options = self.build_options();
        let outcome = fetch_with_options(req, options).await;

        status_callback(ToolStatus::new("complete").with_percent(100.0));
        outcome
    }

    /// Copies this tool's settings into a `FetchOptions` value.
    fn build_options(&self) -> FetchOptions {
        FetchOptions {
            enable_markdown: self.enable_markdown,
            enable_text: self.enable_text,
            enable_save_to_file: self.enable_save_to_file,
            max_body_size: self.max_body_size,
            user_agent: self.user_agent.clone(),
            allow_prefixes: self.allow_prefixes.clone(),
            block_prefixes: self.block_prefixes.clone(),
            dns_policy: self.dns_policy.clone(),
        }
    }

    /// Executes `req`, writing through `saver` when the request names a
    /// `save_to_file` target.
    ///
    /// Requests without `save_to_file` are handed to [`Tool::execute`].
    ///
    /// # Errors
    ///
    /// For a save request: `FetchError::SaverNotAvailable` if saving is
    /// disabled on this tool or `saver` is `None`; `FetchError::SaveError`
    /// if the saver rejects the path. Fetch errors are passed through.
    pub async fn execute_with_saver(
        &self,
        req: FetchRequest,
        saver: Option<&dyn FileSaver>,
    ) -> Result<FetchResponse, FetchError> {
        let target = match &req.save_to_file {
            Some(target) => target,
            None => return self.execute(req).await,
        };

        if !self.enable_save_to_file {
            return Err(FetchError::SaverNotAvailable);
        }
        let saver = saver.ok_or(FetchError::SaverNotAvailable)?;

        // Reject an unacceptable path before any fetching starts.
        if let Err(err) = saver.validate_path(target).await {
            return Err(FetchError::SaveError(err.to_string()));
        }

        let options = self.build_options();
        let registry = FetcherRegistry::with_defaults();
        registry.fetch_to_file(req, options, saver).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder settings must be carried into the built tool; the DNS policy
    /// is expected to block private addresses unless opted out.
    #[test]
    fn test_tool_builder() {
        let tool = Tool::builder()
            .enable_markdown(false)
            .enable_text(true)
            .user_agent("TestAgent/1.0")
            .allow_prefix("https://allowed.com")
            .block_prefix("https://blocked.com")
            .build();
        assert!(!tool.enable_markdown);
        assert!(tool.enable_text);
        assert_eq!(tool.user_agent, Some("TestAgent/1.0".to_string()));
        assert_eq!(tool.allow_prefixes, vec!["https://allowed.com"]);
        assert_eq!(tool.block_prefixes, vec!["https://blocked.com"]);
        assert!(tool.dns_policy.block_private);
    }

    /// `block_private_ips(false)` switches private-address blocking off.
    #[test]
    fn test_tool_builder_opt_out_private_ip_blocking() {
        let tool = Tool::builder().block_private_ips(false).build();
        assert!(!tool.dns_policy.block_private);
    }

    /// Description and llmtxt mention `save_to_file` only when the feature is on.
    #[test]
    fn test_tool_description() {
        let tool = Tool::default();
        assert!(!tool.description().is_empty());
        assert!(tool.system_prompt().is_empty());
        assert!(!tool.llmtxt().is_empty());
        assert!(!tool.description().contains("save_to_file"));
        assert!(!tool.llmtxt().contains("save_to_file"));

        let tool = Tool::builder().enable_save_to_file(true).build();
        assert!(tool.description().contains("save_to_file"));
        assert!(tool.llmtxt().contains("save_to_file"));
    }

    /// Both schemas expose the properties the rest of the tests rely on.
    #[test]
    fn test_tool_schemas() {
        let tool = Tool::default();
        let input_schema = tool.input_schema();
        let output_schema = tool.output_schema();
        assert!(input_schema["properties"]["url"].is_object());
        assert!(output_schema["properties"]["url"].is_object());
        assert!(output_schema["properties"]["status_code"].is_object());
    }

    /// Properties of disabled features must be absent from the input schema.
    #[test]
    fn test_tool_schema_feature_gating() {
        let tool = Tool::builder()
            .enable_markdown(false)
            .enable_text(false)
            .build();
        let schema = tool.input_schema();
        // Fail loudly when `properties` is missing: skipping the assertions
        // would let a broken or empty schema pass this test.
        let props = schema
            .get("properties")
            .and_then(|p| p.as_object())
            .expect("input schema should expose a `properties` object");
        assert!(!props.contains_key("as_markdown"));
        assert!(!props.contains_key("as_text"));
        // Saving is off unless explicitly enabled, so its property is gated too.
        assert!(!props.contains_key("save_to_file"));
    }

    /// The chained `with_*` setters populate every optional status field.
    #[test]
    fn test_tool_status() {
        let status = ToolStatus::new("fetch")
            .with_message("Fetching URL")
            .with_percent(50.0)
            .with_eta(5000);
        assert_eq!(status.phase, "fetch");
        assert_eq!(status.message, Some("Fetching URL".to_string()));
        assert_eq!(status.percent_complete, Some(50.0));
        assert_eq!(status.eta_ms, Some(5000));
    }
}