use crate::config::{AgentConfig, UsageSnapshot, UsageStats};
#[cfg(feature = "search")]
use crate::config::{ResearchOptions, SearchOptions};
use crate::error::{AgentError, AgentResult};
#[cfg(feature = "search")]
use crate::llm::TokenUsage;
use crate::llm::{CompletionOptions, CompletionResponse, LLMProvider, Message};
use crate::memory::AgentMemory;
use crate::tools::{
CustomTool, CustomToolRegistry, CustomToolResult, SpiderBrowserToolConfig,
SpiderCloudToolConfig,
};
use std::sync::Arc;
/// Reports whether `key` is blank or one of the known placeholder strings.
///
/// Surrounding whitespace is ignored and the comparison is ASCII case-insensitive.
fn is_placeholder_api_key(key: &str) -> bool {
    const PLACEHOLDERS: [&str; 4] = ["YOUR_API_KEY", "YOUR-API-KEY", "API_KEY", "API-KEY"];
    let candidate = key.trim();
    if candidate.is_empty() {
        return true;
    }
    PLACEHOLDERS
        .iter()
        .any(|placeholder| candidate.eq_ignore_ascii_case(placeholder))
}
use tokio::sync::Semaphore;
#[cfg(feature = "search")]
use crate::search::{SearchProvider, SearchResults};
#[cfg(feature = "chrome")]
use crate::browser::BrowserContext;
#[cfg(feature = "webdriver")]
use crate::webdriver::WebDriverContext;
#[cfg(feature = "fs")]
use crate::temp::TempStorage;
/// Agent handle bundling an optional LLM provider, an HTTP client, optional
/// feature-gated integrations, memory, usage counters and custom tools.
pub struct Agent {
/// LLM backend used by `complete`; when `None`, `complete` returns `AgentError::NotConfigured`.
llm: Option<Box<dyn LLMProvider>>,
/// HTTP client used by `fetch` and handed to the LLM, search and custom-tool calls.
client: reqwest::Client,
/// Search backend used by `search_with_options`; `None` yields `AgentError::NotConfigured`.
#[cfg(feature = "search")]
search_provider: Option<Box<dyn SearchProvider>>,
/// Browser context used by the `chrome`-gated methods (`navigate`, `click`, ...).
#[cfg(feature = "chrome")]
browser: Option<BrowserContext>,
/// WebDriver context used by the `webdriver_*` methods.
#[cfg(feature = "webdriver")]
webdriver: Option<WebDriverContext>,
/// Temp storage used by `store_temp*` / `read_temp*`; only set when enabled on the builder.
#[cfg(feature = "fs")]
temp_storage: Option<TempStorage>,
/// Key/value memory exposed through `memory_get`, `memory_set` and `memory_clear`.
memory: AgentMemory,
/// A permit is acquired from this semaphore around every provider call in `complete`.
llm_semaphore: Arc<Semaphore>,
/// Agent configuration (limits, completion options, HTML cleaning settings).
config: AgentConfig,
/// Usage counters that are checked against `config.limits` before each operation.
usage: Arc<UsageStats>,
/// Registry of custom tools executed via `execute_custom_tool`.
custom_tools: CustomToolRegistry,
}
impl Agent {
/// Returns a new [`AgentBuilder`] for configuring and constructing an agent.
pub fn builder() -> AgentBuilder {
AgentBuilder::new()
}
/// Borrows the HTTP client this agent uses for its requests.
pub fn client(&self) -> &reqwest::Client {
&self.client
}
/// Searches for `query` using `SearchOptions::default()`.
///
/// Delegates to `search_with_options`, so the same errors apply.
#[cfg(feature = "search")]
pub async fn search(&self, query: &str) -> AgentResult<SearchResults> {
self.search_with_options(query, SearchOptions::default())
.await
}
/// Searches for `query` with explicit `options` through the configured provider.
///
/// The search-call counter is incremented only after the limit check passed and a
/// provider was found.
///
/// # Errors
/// `AgentError::LimitExceeded` when the search limit check reports a limit,
/// `AgentError::NotConfigured` when no search provider is set, and
/// `AgentError::Search` when the provider call fails.
#[cfg(feature = "search")]
pub async fn search_with_options(
    &self,
    query: &str,
    options: SearchOptions,
) -> AgentResult<SearchResults> {
    if let Some(exceeded) = self.usage.check_search_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let backend = match self.search_provider.as_ref() {
        Some(backend) => backend,
        None => return Err(AgentError::NotConfigured("search provider")),
    };
    self.usage.increment_search_calls();
    let outcome = backend.search(query, &options, &self.client).await;
    outcome.map_err(AgentError::Search)
}
/// Sends `messages` to the LLM and returns only the text content of the reply.
///
/// # Errors
/// Propagates every error returned by `complete`.
pub async fn prompt(&self, messages: Vec<Message>) -> AgentResult<String> {
    self.complete(messages)
        .await
        .map(|completion| completion.content)
}
/// Sends `messages` to the configured LLM and returns the full completion response.
///
/// Usage limits are checked first; a semaphore permit is then held for the duration
/// of the provider call. Token counts reported in the response are added to the
/// usage counters.
///
/// # Errors
/// `AgentError::LimitExceeded` when the LLM-call or token limit check reports a limit,
/// `AgentError::NotConfigured` when no provider is set, `AgentError::Llm` when the
/// semaphore cannot be acquired, or the error propagated from the provider call.
pub async fn complete(&self, messages: Vec<Message>) -> AgentResult<CompletionResponse> {
    let limits = &self.config.limits;
    if let Some(exceeded) = self.usage.check_llm_limit(limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    if let Some(exceeded) = self.usage.check_token_limits(limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let provider = match self.llm.as_ref() {
        Some(provider) => provider,
        None => return Err(AgentError::NotConfigured("LLM provider")),
    };
    // Kept alive until the function returns so the provider call runs under the permit.
    let _permit = match self.llm_semaphore.acquire().await {
        Ok(permit) => permit,
        Err(_) => return Err(AgentError::Llm("Failed to acquire semaphore".to_string())),
    };
    let options = CompletionOptions {
        temperature: self.config.temperature,
        max_tokens: self.config.max_tokens,
        json_mode: self.config.json_mode,
    };
    self.usage.increment_llm_calls();
    let response = provider.complete(messages, &options, &self.client).await?;
    let reported = &response.usage;
    self.usage.add_tokens(
        reported.prompt_tokens as u64,
        reported.completion_tokens as u64,
    );
    Ok(response)
}
/// Asks the LLM to extract the information described by `prompt` from `html`.
///
/// The HTML is cleaned according to the configured cleaning mode and truncated to the
/// configured byte budget before it is embedded in the user message. The reply is
/// parsed as JSON.
///
/// # Errors
/// Propagates errors from `complete` and from parsing the reply as JSON.
pub async fn extract(&self, html: &str, prompt: &str) -> AgentResult<serde_json::Value> {
    let cleaned = self.clean_html(html);
    let page = self.truncate_html(&cleaned);
    let system = Message::system(
        "You are a data extraction assistant. Extract the requested information from the HTML and return it as JSON.",
    );
    let user = Message::user(format!(
        "Extract the following from this HTML:\n\n{}\n\nHTML:\n{}",
        prompt, page
    ));
    let reply = self.complete(vec![system, user]).await?;
    self.parse_json(&reply.content)
}
/// Asks the LLM to extract data from `html` that matches the JSON `schema`.
///
/// The HTML is cleaned and truncated the same way as in `extract`; the schema is
/// pretty-printed into the user message (an empty string is used if serialization
/// fails). The reply is parsed as JSON.
///
/// # Errors
/// Propagates errors from `complete` and from parsing the reply as JSON.
pub async fn extract_structured(
    &self,
    html: &str,
    schema: &serde_json::Value,
) -> AgentResult<serde_json::Value> {
    let cleaned = self.clean_html(html);
    let page = self.truncate_html(&cleaned);
    let schema_text = serde_json::to_string_pretty(schema).unwrap_or_default();
    let system = Message::system(
        "You are a data extraction assistant. Extract data matching the provided JSON schema.",
    );
    let user = Message::user(format!(
        "Extract data matching this schema:\n{}\n\nFrom this HTML:\n{}",
        schema_text, page
    ));
    let reply = self.complete(vec![system, user]).await?;
    self.parse_json(&reply.content)
}
/// Performs an HTTP GET on `url` and returns status, content type and body text.
///
/// The HTTP status is reported in the result and is not turned into an error here.
/// `content_type` is the `content-type` header value, or an empty string when the
/// header is missing or not representable as a string.
///
/// # Errors
/// `AgentError::LimitExceeded` when the fetch limit check reports a limit, or the
/// error propagated from sending the request or reading the body.
pub async fn fetch(&self, url: &str) -> AgentResult<FetchResult> {
    if let Some(limit) = self.usage.check_fetch_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(limit));
    }
    self.usage.increment_fetch_calls();
    let response = self.client.get(url).send().await?;
    let status = response.status().as_u16();
    // Read the single header in place; copying the whole header map is unnecessary
    // because the owned `String` is produced before the response is consumed below.
    let content_type = response
        .headers()
        .get("content-type")
        .and_then(|value| value.to_str().ok())
        .unwrap_or("")
        .to_string();
    let html = response.text().await?;
    Ok(FetchResult {
        url: url.to_string(),
        status,
        content_type,
        html,
    })
}
/// Researches `topic`: search, fetch the top results, extract from each page and
/// optionally synthesize a summary.
///
/// When `options.search_options` is absent, the search limit is
/// `max(options.max_pages, 5)`. At most `options.max_pages` results are fetched.
/// Pages whose fetch or extraction fails are logged with `log::warn!` and skipped.
/// A synthesis failure is logged and leaves `summary` as `None`.
///
/// Only the token usage returned by the synthesis call is accumulated into the
/// returned `usage` field.
///
/// # Errors
/// Propagates the error of the initial search; later failures are not fatal.
#[cfg(feature = "search")]
pub async fn research(
    &self,
    topic: &str,
    options: ResearchOptions,
) -> AgentResult<ResearchResult> {
    let search_opts = match options.search_options.clone() {
        Some(opts) => opts,
        None => SearchOptions::new().with_limit(options.max_pages.max(5)),
    };
    let search_results = self.search_with_options(topic, search_opts).await?;
    if search_results.is_empty() {
        return Ok(ResearchResult {
            topic: topic.to_string(),
            search_results,
            extractions: Vec::new(),
            summary: None,
            usage: TokenUsage::default(),
        });
    }
    let extraction_prompt = match options.extraction_prompt.clone() {
        Some(custom) => custom,
        None => format!(
            "Extract key information relevant to: {}. Include facts, data points, and insights.",
            topic
        ),
    };
    let mut extractions = Vec::new();
    let mut total_usage = TokenUsage::default();
    let page_budget = options.max_pages.min(search_results.results.len());
    for hit in search_results.results.iter().take(page_budget) {
        let page = match self.fetch(&hit.url).await {
            Ok(page) => page,
            Err(e) => {
                log::warn!("Fetch failed for {}: {}", hit.url, e);
                continue;
            }
        };
        match self.extract(&page.html, &extraction_prompt).await {
            Ok(extracted) => extractions.push(PageExtraction {
                url: hit.url.clone(),
                title: hit.title.clone(),
                extracted,
            }),
            Err(e) => log::warn!("Extraction failed for {}: {}", hit.url, e),
        }
    }
    let mut summary = None;
    if options.synthesize && !extractions.is_empty() {
        match self.synthesize_research(topic, &extractions).await {
            Ok((text, usage)) => {
                total_usage.accumulate(&usage);
                summary = Some(text);
            }
            Err(e) => log::warn!("Synthesis failed: {}", e),
        }
    }
    Ok(ResearchResult {
        topic: topic.to_string(),
        search_results,
        extractions,
        summary,
        usage: total_usage,
    })
}
/// Asks the LLM to summarize `extractions` for `topic`.
///
/// Returns the summary text together with the token usage of the call. The summary is
/// the string `summary` field of the JSON reply when present. If the reply is not
/// valid JSON, or has no string `summary` field, the raw reply text is used instead,
/// so a plain-text answer is not discarded.
///
/// # Errors
/// Propagates errors from `complete`.
#[cfg(feature = "search")]
async fn synthesize_research(
    &self,
    topic: &str,
    extractions: &[PageExtraction],
) -> AgentResult<(String, TokenUsage)> {
    let mut context = String::new();
    for (i, extraction) in extractions.iter().enumerate() {
        context.push_str(&format!(
            "\n\nSource {} ({}): {}\n{}",
            i + 1,
            extraction.url,
            extraction.title,
            serde_json::to_string_pretty(&extraction.extracted).unwrap_or_default()
        ));
    }
    let messages = vec![
        Message::system(
            "You are a research synthesis assistant. Summarize the findings from multiple sources into a coherent response.",
        ),
        Message::user(format!(
            "Topic: {}\n\nSources:{}\n\nProvide a comprehensive summary of the findings, citing sources where appropriate. Return as JSON with a 'summary' field.",
            topic, context
        )),
    ];
    let response = self.complete(messages).await?;
    // A reply that cannot be parsed is treated like one without a `summary` field:
    // fall back to the raw text rather than failing the synthesis step.
    let summary = self
        .parse_json(&response.content)
        .ok()
        .and_then(|json| {
            json.get("summary")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string())
        })
        .unwrap_or_else(|| response.content.clone());
    Ok((summary, response.usage))
}
/// Looks up `key` in the agent memory; delegates to `AgentMemory::get`.
pub fn memory_get(&self, key: &str) -> Option<serde_json::Value> {
self.memory.get(key)
}
/// Stores `value` under `key` in the agent memory; delegates to `AgentMemory::set`.
pub fn memory_set(&self, key: &str, value: serde_json::Value) {
self.memory.set(key, value);
}
/// Clears the agent memory; delegates to `AgentMemory::clear`.
pub fn memory_clear(&self) {
self.memory.clear();
}
/// Borrows the underlying [`AgentMemory`].
pub fn memory(&self) -> &AgentMemory {
&self.memory
}
/// Returns a [`UsageSnapshot`] produced by the usage counters' `snapshot()`.
pub fn usage(&self) -> UsageSnapshot {
self.usage.snapshot()
}
/// Borrows the shared usage counters (`Arc<UsageStats>`).
pub fn usage_stats(&self) -> &Arc<UsageStats> {
&self.usage
}
/// Resets the usage counters; delegates to `UsageStats::reset`.
pub fn reset_usage(&self) {
self.usage.reset();
}
/// Registers `tool` in the custom tool registry.
pub fn register_custom_tool(&self, tool: CustomTool) {
self.custom_tools.register(tool);
}
/// Removes the custom tool called `name`.
///
/// Returns `true` when the registry's `remove` call returned a value.
pub fn remove_custom_tool(&self, name: &str) -> bool {
self.custom_tools.remove(name).is_some()
}
/// Returns the list of strings produced by the custom tool registry's `list()`.
pub fn list_custom_tools(&self) -> Vec<String> {
self.custom_tools.list()
}
/// Reports whether the custom tool registry contains `name`.
pub fn has_custom_tool(&self, name: &str) -> bool {
self.custom_tools.contains(name)
}
/// Executes the custom tool `name` with an optional path, query pairs and body.
///
/// The per-tool call counter is incremented after the limit check and before the
/// tool is executed.
///
/// # Errors
/// `AgentError::LimitExceeded` when the custom-tool limit check reports a limit, or
/// the error returned by the registry's `execute`.
pub async fn execute_custom_tool(
    &self,
    name: &str,
    path: Option<&str>,
    query: Option<&[(&str, &str)]>,
    body: Option<&str>,
) -> AgentResult<CustomToolResult> {
    if let Some(exceeded) = self.usage.check_custom_tool_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    self.usage.increment_custom_tool_calls(name);
    let registry = &self.custom_tools;
    registry
        .execute(name, &self.client, path, query, body)
        .await
}
/// Executes the custom tool `name` and parses its response body as JSON.
///
/// # Errors
/// Propagates errors from `execute_custom_tool`; returns `AgentError::Json` when the
/// body is not valid JSON.
pub async fn execute_custom_tool_json(
    &self,
    name: &str,
    path: Option<&str>,
    query: Option<&[(&str, &str)]>,
    body: Option<&str>,
) -> AgentResult<serde_json::Value> {
    let response = self.execute_custom_tool(name, path, query, body).await?;
    match serde_json::from_str(&response.body) {
        Ok(value) => Ok(value),
        Err(parse_error) => Err(AgentError::Json(parse_error)),
    }
}
/// Borrows the underlying [`CustomToolRegistry`].
pub fn custom_tool_registry(&self) -> &CustomToolRegistry {
&self.custom_tools
}
/// Registers the Spider Cloud tools described by `config` in the custom tool registry.
///
/// Returns the `usize` reported by the registry (presumably the number of tools
/// registered — confirm against `CustomToolRegistry::register_spider_cloud`).
pub fn register_spider_cloud(&self, config: SpiderCloudToolConfig) -> usize {
self.custom_tools.register_spider_cloud(&config)
}
/// Registers the Spider Browser tools described by `config` in the custom tool registry.
///
/// Returns the `usize` reported by the registry (presumably the number of tools
/// registered — confirm against `CustomToolRegistry::register_spider_browser`).
pub fn register_spider_browser(&self, config: SpiderBrowserToolConfig) -> usize {
self.custom_tools.register_spider_browser(&config)
}
/// Borrows the attached browser context, if one was configured.
#[cfg(feature = "chrome")]
pub fn browser(&self) -> Option<&BrowserContext> {
self.browser.as_ref()
}
/// Navigates the attached browser to `url`.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no browser is attached, and
/// `AgentError::Browser` (carrying the error text) when navigation fails.
#[cfg(feature = "chrome")]
pub async fn navigate(&self, url: &str) -> AgentResult<()> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let ctx = match self.browser.as_ref() {
        Some(ctx) => ctx,
        None => return Err(AgentError::NotConfigured("browser")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = ctx.navigate(url).await;
    outcome.map_err(|err| AgentError::Browser(err.to_string()))
}
/// Returns the HTML reported by the attached browser context.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no browser is attached, and
/// `AgentError::Browser` (carrying the error text) when the call fails.
#[cfg(feature = "chrome")]
pub async fn browser_html(&self) -> AgentResult<String> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let ctx = match self.browser.as_ref() {
        Some(ctx) => ctx,
        None => return Err(AgentError::NotConfigured("browser")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = ctx.html().await;
    outcome.map_err(|err| AgentError::Browser(err.to_string()))
}
/// Takes a screenshot through the attached browser context and returns its bytes.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no browser is attached, and
/// `AgentError::Browser` (carrying the error text) when the call fails.
#[cfg(feature = "chrome")]
pub async fn screenshot(&self) -> AgentResult<Vec<u8>> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let ctx = match self.browser.as_ref() {
        Some(ctx) => ctx,
        None => return Err(AgentError::NotConfigured("browser")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = ctx.screenshot().await;
    outcome.map_err(|err| AgentError::Browser(err.to_string()))
}
/// Creates a new browser context by calling `clone_page` on the attached one.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no browser is attached, and
/// `AgentError::Browser` (carrying the error text) when the call fails.
#[cfg(feature = "chrome")]
pub async fn new_page(&self) -> AgentResult<crate::browser::BrowserContext> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let ctx = match self.browser.as_ref() {
        Some(ctx) => ctx,
        None => return Err(AgentError::NotConfigured("browser")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = ctx.clone_page().await;
    outcome.map_err(|err| AgentError::Browser(err.to_string()))
}
/// Opens a new page for `url` through the attached browser context.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no browser is attached, and
/// `AgentError::Browser` (carrying the error text) when the call fails.
#[cfg(feature = "chrome")]
pub async fn new_page_with_url(
    &self,
    url: &str,
) -> AgentResult<std::sync::Arc<crate::browser::Page>> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let ctx = match self.browser.as_ref() {
        Some(ctx) => ctx,
        None => return Err(AgentError::NotConfigured("browser")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = ctx.new_page_with_url(url).await;
    outcome.map_err(|err| AgentError::Browser(err.to_string()))
}
/// Clicks the element matching `selector` through the attached browser context.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no browser is attached, and
/// `AgentError::Browser` (carrying the error text) when the call fails.
#[cfg(feature = "chrome")]
pub async fn click(&self, selector: &str) -> AgentResult<()> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let ctx = match self.browser.as_ref() {
        Some(ctx) => ctx,
        None => return Err(AgentError::NotConfigured("browser")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = ctx.click(selector).await;
    outcome.map_err(|err| AgentError::Browser(err.to_string()))
}
/// Types `text` into the element matching `selector` through the attached browser context.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no browser is attached, and
/// `AgentError::Browser` (carrying the error text) when the call fails.
#[cfg(feature = "chrome")]
pub async fn type_text(&self, selector: &str, text: &str) -> AgentResult<()> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let ctx = match self.browser.as_ref() {
        Some(ctx) => ctx,
        None => return Err(AgentError::NotConfigured("browser")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = ctx.type_text(selector, text).await;
    outcome.map_err(|err| AgentError::Browser(err.to_string()))
}
/// Reads the current HTML via `browser_html` and runs `extract` on it with `prompt`.
///
/// Errors from either step are propagated.
#[cfg(feature = "chrome")]
pub async fn extract_page(&self, prompt: &str) -> AgentResult<serde_json::Value> {
let html = self.browser_html().await?;
self.extract(&html, prompt).await
}
/// Borrows the attached WebDriver context, if one was configured.
#[cfg(feature = "webdriver")]
pub fn webdriver(&self) -> Option<&WebDriverContext> {
self.webdriver.as_ref()
}
/// Navigates the attached WebDriver session to `url`.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no WebDriver is attached, and
/// `AgentError::WebDriver` (carrying the error text) when the call fails.
#[cfg(feature = "webdriver")]
pub async fn webdriver_navigate(&self, url: &str) -> AgentResult<()> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let session = match self.webdriver.as_ref() {
        Some(session) => session,
        None => return Err(AgentError::NotConfigured("webdriver")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = session.navigate(url).await;
    outcome.map_err(|err| AgentError::WebDriver(err.to_string()))
}
/// Returns the HTML reported by the attached WebDriver session.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no WebDriver is attached, and
/// `AgentError::WebDriver` (carrying the error text) when the call fails.
#[cfg(feature = "webdriver")]
pub async fn webdriver_html(&self) -> AgentResult<String> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let session = match self.webdriver.as_ref() {
        Some(session) => session,
        None => return Err(AgentError::NotConfigured("webdriver")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = session.html().await;
    outcome.map_err(|err| AgentError::WebDriver(err.to_string()))
}
/// Takes a screenshot through the attached WebDriver session and returns its bytes.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no WebDriver is attached, and
/// `AgentError::WebDriver` (carrying the error text) when the call fails.
#[cfg(feature = "webdriver")]
pub async fn webdriver_screenshot(&self) -> AgentResult<Vec<u8>> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let session = match self.webdriver.as_ref() {
        Some(session) => session,
        None => return Err(AgentError::NotConfigured("webdriver")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = session.screenshot().await;
    outcome.map_err(|err| AgentError::WebDriver(err.to_string()))
}
/// Clicks the element matching `selector` through the attached WebDriver session.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no WebDriver is attached, and
/// `AgentError::WebDriver` (carrying the error text) when the call fails.
#[cfg(feature = "webdriver")]
pub async fn webdriver_click(&self, selector: &str) -> AgentResult<()> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let session = match self.webdriver.as_ref() {
        Some(session) => session,
        None => return Err(AgentError::NotConfigured("webdriver")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = session.click(selector).await;
    outcome.map_err(|err| AgentError::WebDriver(err.to_string()))
}
/// Types `text` into the element matching `selector` through the attached WebDriver session.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no WebDriver is attached, and
/// `AgentError::WebDriver` (carrying the error text) when the call fails.
#[cfg(feature = "webdriver")]
pub async fn webdriver_type_text(&self, selector: &str, text: &str) -> AgentResult<()> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let session = match self.webdriver.as_ref() {
        Some(session) => session,
        None => return Err(AgentError::NotConfigured("webdriver")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = session.type_text(selector, text).await;
    outcome.map_err(|err| AgentError::WebDriver(err.to_string()))
}
/// Reads the current HTML via `webdriver_html` and runs `extract` on it with `prompt`.
///
/// Errors from either step are propagated.
#[cfg(feature = "webdriver")]
pub async fn webdriver_extract_page(&self, prompt: &str) -> AgentResult<serde_json::Value> {
let html = self.webdriver_html().await?;
self.extract(&html, prompt).await
}
/// Opens a new tab through the attached WebDriver session and returns its window handle.
///
/// # Errors
/// `AgentError::LimitExceeded` when the web-browser limit check reports a limit,
/// `AgentError::NotConfigured` when no WebDriver is attached, and
/// `AgentError::WebDriver` (carrying the error text) when the call fails.
#[cfg(feature = "webdriver")]
pub async fn webdriver_new_tab(&self) -> AgentResult<crate::webdriver::WindowHandle> {
    if let Some(exceeded) = self.usage.check_webbrowser_limit(&self.config.limits) {
        return Err(AgentError::LimitExceeded(exceeded));
    }
    let session = match self.webdriver.as_ref() {
        Some(session) => session,
        None => return Err(AgentError::NotConfigured("webdriver")),
    };
    self.usage.increment_webbrowser_calls();
    let outcome = session.new_tab().await;
    outcome.map_err(|err| AgentError::WebDriver(err.to_string()))
}
/// Borrows the temp storage, if it was enabled on the builder.
#[cfg(feature = "fs")]
pub fn temp_storage(&self) -> Option<&TempStorage> {
self.temp_storage.as_ref()
}
/// Stores `data` under `name` in temp storage and returns the resulting path.
///
/// # Errors
/// `AgentError::NotConfigured` when temp storage is not enabled, `AgentError::Io`
/// when the storage call fails.
#[cfg(feature = "fs")]
pub fn store_temp(&self, name: &str, data: &[u8]) -> AgentResult<std::path::PathBuf> {
    match self.temp_storage.as_ref() {
        Some(store) => store.store_bytes(name, data).map_err(AgentError::Io),
        None => Err(AgentError::NotConfigured("temp storage")),
    }
}
/// Stores the JSON value `data` under `name` in temp storage and returns the resulting path.
///
/// # Errors
/// `AgentError::NotConfigured` when temp storage is not enabled, `AgentError::Io`
/// when the storage call fails.
#[cfg(feature = "fs")]
pub fn store_temp_json(
    &self,
    name: &str,
    data: &serde_json::Value,
) -> AgentResult<std::path::PathBuf> {
    match self.temp_storage.as_ref() {
        Some(store) => store.store_json(name, data).map_err(AgentError::Io),
        None => Err(AgentError::NotConfigured("temp storage")),
    }
}
/// Reads the bytes stored under `name` in temp storage.
///
/// # Errors
/// `AgentError::NotConfigured` when temp storage is not enabled, `AgentError::Io`
/// when the storage call fails.
#[cfg(feature = "fs")]
pub fn read_temp(&self, name: &str) -> AgentResult<Vec<u8>> {
    match self.temp_storage.as_ref() {
        Some(store) => store.read_bytes(name).map_err(AgentError::Io),
        None => Err(AgentError::NotConfigured("temp storage")),
    }
}
/// Reads the JSON value stored under `name` in temp storage.
///
/// # Errors
/// `AgentError::NotConfigured` when temp storage is not enabled, `AgentError::Io`
/// when the storage call fails.
#[cfg(feature = "fs")]
pub fn read_temp_json(&self, name: &str) -> AgentResult<serde_json::Value> {
    match self.temp_storage.as_ref() {
        Some(store) => store.read_json(name).map_err(AgentError::Io),
        None => Err(AgentError::NotConfigured("temp storage")),
    }
}
/// Strips elements from `html` according to the configured cleaning mode.
///
/// `Raw` returns the input unchanged; the other modes remove progressively larger
/// sets of elements via `remove_tags`.
fn clean_html(&self, html: &str) -> String {
    use crate::config::HtmlCleaningMode;
    let strip: &[&str] = match self.config.html_cleaning {
        HtmlCleaningMode::Raw => return html.to_string(),
        HtmlCleaningMode::Minimal => &["script"],
        HtmlCleaningMode::Default => &["script", "style", "noscript"],
        HtmlCleaningMode::Aggressive => &[
            "script", "style", "noscript", "svg", "canvas", "video", "audio", "iframe",
        ],
    };
    remove_tags(html, strip)
}
/// Truncates `html` to at most `config.html_max_bytes` bytes.
///
/// The cut is moved back to the nearest UTF-8 character boundary so that slicing can
/// never panic on multi-byte text. The result is then trimmed back to just before the
/// last `<` in the kept prefix, so a partially cut tag is not handed to the LLM.
/// Input that already fits is returned unchanged.
fn truncate_html<'a>(&self, html: &'a str) -> &'a str {
    let limit = self.config.html_max_bytes;
    if html.len() <= limit {
        return html;
    }
    // `limit` may fall inside a multi-byte character; index 0 is always a boundary,
    // so this loop terminates.
    let mut cut = limit;
    while !html.is_char_boundary(cut) {
        cut -= 1;
    }
    let truncated = &html[..cut];
    match truncated.rfind('<') {
        Some(pos) => &truncated[..pos],
        None => truncated,
    }
}
/// Parses `content` as JSON, tolerating common LLM wrappers.
///
/// Attempts, in order: the whole string; the body of a fenced block tagged `json`;
/// the body of the first fenced block of any kind (skipping its info line); the span
/// from the first `{` to the last `}`.
///
/// # Errors
/// Returns `AgentError::Json` carrying the error from parsing the whole string when
/// every attempt fails.
fn parse_json(&self, content: &str) -> AgentResult<serde_json::Value> {
    // Keep the first error so it can be reported without parsing a second time.
    let direct_error = match serde_json::from_str::<serde_json::Value>(content) {
        Ok(json) => return Ok(json),
        Err(e) => e,
    };
    if let Some(fence) = content.find("```json") {
        let start = fence + 7;
        if let Some(end) = content[start..].find("```") {
            let json_str = content[start..start + end].trim();
            if let Ok(json) = serde_json::from_str(json_str) {
                return Ok(json);
            }
        }
    }
    if let Some(fence) = content.find("```") {
        let after_fence = fence + 3;
        // Skip the rest of the fence line (e.g. a language tag) when there is one.
        let start = content[after_fence..]
            .find('\n')
            .map(|n| after_fence + n + 1)
            .unwrap_or(after_fence);
        if let Some(end) = content[start..].find("```") {
            let json_str = content[start..start + end].trim();
            if let Ok(json) = serde_json::from_str(json_str) {
                return Ok(json);
            }
        }
    }
    if let (Some(start), Some(end)) = (content.find('{'), content.rfind('}')) {
        // The last `}` can precede the first `{` (e.g. "} ... {"); slicing such a
        // range would panic, so only try a well-ordered span.
        if start < end {
            if let Ok(json) = serde_json::from_str(&content[start..=end]) {
                return Ok(json);
            }
        }
    }
    Err(AgentError::Json(direct_error))
}
}
/// Removes every `<tag ...> ... </tag>` element (content included) for each name in `tags`.
///
/// Matching is ASCII case-insensitive. An opening match only counts when the tag name
/// is followed by whitespace, `/`, `>` or the end of input, so `script` does not match
/// `<scripted>`. An opening tag without a matching `</tag>` is left in place together
/// with everything after it.
fn remove_tags(html: &str, tags: &[&str]) -> String {
    let mut result = html.to_string();
    for tag in tags {
        let name = tag.to_ascii_lowercase();
        let open = format!("<{}", name);
        let close = format!("</{}>", name);
        // ASCII-only lowercasing never changes byte lengths, so offsets found in
        // `lower` are valid offsets into `result`. Full Unicode lowercasing can
        // change lengths and would misalign the two strings.
        let lower = result.to_ascii_lowercase();
        let mut out = String::with_capacity(result.len());
        let mut copied = 0; // start of the region not yet copied to `out`
        let mut search = 0; // where to resume looking for an opening tag
        while let Some(rel_start) = lower[search..].find(&open) {
            let start = search + rel_start;
            let after_name = start + open.len();
            let at_name_end = match lower.as_bytes().get(after_name) {
                None => true,
                Some(b) => *b == b'>' || *b == b'/' || b.is_ascii_whitespace(),
            };
            if !at_name_end {
                // Longer tag name sharing this prefix; keep it and look further on.
                search = after_name;
                continue;
            }
            match lower[start..].find(&close) {
                Some(rel_end) => {
                    out.push_str(&result[copied..start]);
                    copied = start + rel_end + close.len();
                    search = copied;
                }
                // Unterminated element: keep the remainder untouched.
                None => break,
            }
        }
        out.push_str(&result[copied..]);
        result = out;
    }
    result
}
/// Result of `Agent::fetch`.
#[derive(Debug, Clone)]
pub struct FetchResult {
/// The URL that was passed to `fetch`.
pub url: String,
/// HTTP status code of the response.
pub status: u16,
/// Value of the `content-type` header, or an empty string when missing or not a valid string.
pub content_type: String,
/// Response body as text.
pub html: String,
}
/// Result of `Agent::research`.
#[cfg(feature = "search")]
#[derive(Debug, Clone)]
pub struct ResearchResult {
/// The topic that was researched.
pub topic: String,
/// Results returned by the initial search.
pub search_results: SearchResults,
/// One entry per page that was fetched and extracted successfully.
pub extractions: Vec<PageExtraction>,
/// Synthesized summary; `None` when synthesis was skipped or failed.
pub summary: Option<String>,
/// Token usage accumulated by `research` (only the synthesis call is added there).
pub usage: TokenUsage,
}
/// Data extracted from a single page during `Agent::research`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PageExtraction {
/// URL of the page, copied from the search result.
pub url: String,
/// Title of the page, copied from the search result.
pub title: String,
/// JSON value returned by `Agent::extract` for the page.
pub extracted: serde_json::Value,
}
/// Builder that collects configuration and integrations before constructing an [`Agent`].
pub struct AgentBuilder {
/// Configuration handed to the agent as-is.
config: AgentConfig,
/// Optional LLM provider.
llm: Option<Box<dyn LLMProvider>>,
/// Spider Cloud tool configuration, registered as custom tools in `build`.
spider_cloud: Option<SpiderCloudToolConfig>,
/// Spider Browser tool configuration, registered as custom tools in `build`.
spider_browser: Option<SpiderBrowserToolConfig>,
/// Proxy URLs applied when `build` creates the HTTP client itself.
proxies: Option<Vec<String>>,
/// Caller-supplied HTTP client; when set, `build` uses it and skips timeout/proxy setup.
client: Option<reqwest::Client>,
/// Optional search provider.
#[cfg(feature = "search")]
search_provider: Option<Box<dyn SearchProvider>>,
/// Optional browser context.
#[cfg(feature = "chrome")]
browser: Option<BrowserContext>,
/// Optional WebDriver context.
#[cfg(feature = "webdriver")]
webdriver: Option<WebDriverContext>,
/// Whether `build` should create a `TempStorage`.
#[cfg(feature = "fs")]
enable_temp_storage: bool,
}
impl AgentBuilder {
    /// Creates a builder with `AgentConfig::default()` and no integrations configured.
    pub fn new() -> Self {
        Self {
            config: AgentConfig::default(),
            llm: None,
            spider_cloud: None,
            spider_browser: None,
            proxies: None,
            client: None,
            #[cfg(feature = "search")]
            search_provider: None,
            #[cfg(feature = "chrome")]
            browser: None,
            #[cfg(feature = "webdriver")]
            webdriver: None,
            #[cfg(feature = "fs")]
            enable_temp_storage: false,
        }
    }

    /// Replaces the whole configuration, discarding earlier per-field settings.
    pub fn with_config(mut self, config: AgentConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets `config.system_prompt`.
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.config.system_prompt = Some(prompt.into());
        self
    }

    /// Sets the number of LLM calls allowed to run concurrently.
    ///
    /// A value of `0` is raised to `1` in `build`.
    pub fn with_max_concurrent_llm_calls(mut self, n: usize) -> Self {
        self.config.max_concurrent_llm_calls = n;
        self
    }

    /// Uses an `OpenAIProvider` built from `api_key` and `model` as the LLM.
    #[cfg(feature = "openai")]
    pub fn with_openai(mut self, api_key: impl Into<String>, model: impl Into<String>) -> Self {
        self.llm = Some(Box::new(crate::llm::OpenAIProvider::new(api_key, model)));
        self
    }

    /// Uses an `OpenAIProvider` pointed at a custom `api_url` as the LLM.
    #[cfg(feature = "openai")]
    pub fn with_openai_compatible(
        mut self,
        api_url: impl Into<String>,
        api_key: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        self.llm = Some(Box::new(
            crate::llm::OpenAIProvider::new(api_key, model).with_api_url(api_url),
        ));
        self
    }

    /// Uses an `OpenAIProvider` with `with_responses_api()` applied as the LLM.
    #[cfg(feature = "openai")]
    pub fn with_openai_responses(
        mut self,
        api_key: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        self.llm = Some(Box::new(
            crate::llm::OpenAIProvider::new(api_key, model).with_responses_api(),
        ));
        self
    }

    /// Uses an `OpenAIProvider` with `with_responses_api()` and a custom `api_url` as the LLM.
    #[cfg(feature = "openai")]
    pub fn with_openai_compatible_responses(
        mut self,
        api_url: impl Into<String>,
        api_key: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        self.llm = Some(Box::new(
            crate::llm::OpenAIProvider::new(api_key, model)
                .with_responses_api()
                .with_api_url(api_url),
        ));
        self
    }

    /// Enables the Spider Cloud tools with `api_key`.
    ///
    /// A blank or placeholder key is ignored with a warning, leaving the builder unchanged.
    pub fn with_spider_cloud(mut self, api_key: impl Into<String>) -> Self {
        let key = api_key.into();
        if is_placeholder_api_key(&key) {
            log::warn!("Spider Cloud API key looks like a placeholder — skipping. Get a real key at https://spider.cloud");
            return self;
        }
        self.spider_cloud = Some(SpiderCloudToolConfig::new(key));
        self
    }

    /// Enables the Spider Cloud tools with an explicit configuration (no placeholder check).
    pub fn with_spider_cloud_config(mut self, config: SpiderCloudToolConfig) -> Self {
        self.spider_cloud = Some(config);
        self
    }

    /// Enables the Spider Browser tools with `api_key`.
    ///
    /// A blank or placeholder key is ignored with a warning, leaving the builder unchanged.
    pub fn with_spider_browser(mut self, api_key: impl Into<String>) -> Self {
        let key = api_key.into();
        if is_placeholder_api_key(&key) {
            log::warn!("Spider Browser Cloud API key looks like a placeholder — skipping. Get a real key at https://spider.cloud");
            return self;
        }
        self.spider_browser = Some(SpiderBrowserToolConfig::new(key));
        self
    }

    /// Enables the Spider Browser tools with an explicit configuration (no placeholder check).
    pub fn with_spider_browser_config(mut self, config: SpiderBrowserToolConfig) -> Self {
        self.spider_browser = Some(config);
        self
    }

    /// Sets the HTTP timeout; `None` means no timeout.
    ///
    /// "No timeout" is stored as `Duration::MAX`, which `build` recognises and skips.
    pub fn with_timeout(mut self, timeout: Option<std::time::Duration>) -> Self {
        match timeout {
            Some(d) => self.config.timeout = d,
            None => self.config.timeout = std::time::Duration::MAX,
        }
        self
    }

    /// Supplies a ready-made HTTP client.
    ///
    /// When set, `build` uses this client as-is and ignores the timeout and proxy settings.
    pub fn with_client(mut self, client: reqwest::Client) -> Self {
        self.client = Some(client);
        self
    }

    /// Sets the proxy URLs; an empty list is ignored and keeps any earlier proxy setting.
    pub fn with_proxies(mut self, proxies: Vec<String>) -> Self {
        if !proxies.is_empty() {
            self.proxies = Some(proxies);
        }
        self
    }

    /// Sets a single proxy URL, replacing any earlier proxy setting.
    pub fn with_proxy(mut self, proxy: impl Into<String>) -> Self {
        self.proxies = Some(vec![proxy.into()]);
        self
    }

    /// Uses the Serper search provider.
    #[cfg(feature = "search_serper")]
    pub fn with_search_serper(mut self, api_key: impl Into<String>) -> Self {
        self.search_provider = Some(Box::new(crate::search::SerperProvider::new(api_key)));
        self
    }

    /// Uses the Brave search provider.
    #[cfg(feature = "search_brave")]
    pub fn with_search_brave(mut self, api_key: impl Into<String>) -> Self {
        self.search_provider = Some(Box::new(crate::search::BraveProvider::new(api_key)));
        self
    }

    /// Uses the Bing search provider.
    #[cfg(feature = "search_bing")]
    pub fn with_search_bing(mut self, api_key: impl Into<String>) -> Self {
        self.search_provider = Some(Box::new(crate::search::BingProvider::new(api_key)));
        self
    }

    /// Uses the Tavily search provider.
    #[cfg(feature = "search_tavily")]
    pub fn with_search_tavily(mut self, api_key: impl Into<String>) -> Self {
        self.search_provider = Some(Box::new(crate::search::TavilyProvider::new(api_key)));
        self
    }

    /// Attaches an existing browser context.
    #[cfg(feature = "chrome")]
    pub fn with_browser(mut self, browser: BrowserContext) -> Self {
        self.browser = Some(browser);
        self
    }

    /// Attaches a browser context built from a browser and one of its pages.
    #[cfg(feature = "chrome")]
    pub fn with_browser_page(
        mut self,
        browser: std::sync::Arc<crate::browser::Browser>,
        page: std::sync::Arc<crate::browser::Page>,
    ) -> Self {
        self.browser = Some(BrowserContext::new(browser, page));
        self
    }

    /// Makes `build` create a `TempStorage` for the agent.
    #[cfg(feature = "fs")]
    pub fn with_temp_storage(mut self) -> Self {
        self.enable_temp_storage = true;
        self
    }

    /// Attaches an existing WebDriver context.
    #[cfg(feature = "webdriver")]
    pub fn with_webdriver(mut self, webdriver: WebDriverContext) -> Self {
        self.webdriver = Some(webdriver);
        self
    }

    /// Attaches a WebDriver context built from `driver`.
    #[cfg(feature = "webdriver")]
    pub fn with_webdriver_driver(
        mut self,
        driver: std::sync::Arc<crate::webdriver::WebDriver>,
    ) -> Self {
        self.webdriver = Some(WebDriverContext::new(driver));
        self
    }

    /// Consumes the builder and constructs the [`Agent`].
    ///
    /// A caller-supplied client is used unchanged. Otherwise a client is created with
    /// the configured timeout (unless it is the "no timeout" sentinel) and one
    /// `reqwest::Proxy::all` entry per configured proxy URL. The Spider tool
    /// configurations, if any, are registered as custom tools.
    ///
    /// # Errors
    /// `AgentError::Http` when a proxy URL is rejected or the client cannot be built,
    /// `AgentError::Io` when temp storage is enabled but cannot be created.
    pub fn build(self) -> AgentResult<Agent> {
        let client = if let Some(client) = self.client {
            client
        } else {
            let mut builder = reqwest::Client::builder();
            if self.config.timeout != std::time::Duration::MAX {
                builder = builder.timeout(self.config.timeout);
            }
            if let Some(proxies) = &self.proxies {
                for proxy_url in proxies {
                    let proxy = reqwest::Proxy::all(proxy_url).map_err(AgentError::Http)?;
                    builder = builder.proxy(proxy);
                }
            }
            builder.build().map_err(AgentError::Http)?
        };
        // A semaphore with zero permits would make every `Agent::complete` call wait
        // forever, so a configured value of 0 is treated as 1.
        let permits = self.config.max_concurrent_llm_calls.max(1);
        let semaphore = Arc::new(Semaphore::new(permits));
        #[cfg(feature = "fs")]
        let temp_storage = if self.enable_temp_storage {
            Some(TempStorage::new().map_err(AgentError::Io)?)
        } else {
            None
        };
        let custom_tools = CustomToolRegistry::new();
        if let Some(cfg) = self.spider_cloud.as_ref() {
            custom_tools.register_spider_cloud(cfg);
        }
        if let Some(cfg) = self.spider_browser.as_ref() {
            custom_tools.register_spider_browser(cfg);
        }
        Ok(Agent {
            llm: self.llm,
            client,
            #[cfg(feature = "search")]
            search_provider: self.search_provider,
            #[cfg(feature = "chrome")]
            browser: self.browser,
            #[cfg(feature = "webdriver")]
            webdriver: self.webdriver,
            #[cfg(feature = "fs")]
            temp_storage,
            memory: AgentMemory::new(),
            llm_semaphore: semaphore,
            config: self.config,
            usage: Arc::new(UsageStats::new()),
            custom_tools,
        })
    }
}
/// `AgentBuilder::default()` is equivalent to `AgentBuilder::new()`.
impl Default for AgentBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every name in `expected` appears in `tools`.
    fn assert_tools_present(tools: &[String], expected: &[&str]) {
        for name in expected {
            assert!(tools.contains(&name.to_string()));
        }
    }

    // --- Spider Cloud registration ---

    #[test]
    fn test_builder_registers_spider_cloud_default_routes() {
        let agent = Agent::builder()
            .with_spider_cloud("sk_spider_cloud")
            .build()
            .expect("agent should build");
        let tools = agent.list_custom_tools();
        assert_tools_present(
            &tools,
            &[
                "spider_cloud_crawl",
                "spider_cloud_scrape",
                "spider_cloud_search",
                "spider_cloud_links",
                "spider_cloud_transform",
                "spider_cloud_unblocker",
            ],
        );
        // AI routes are opt-in and must not be registered by default.
        assert!(!tools.contains(&"spider_cloud_ai_scrape".to_string()));
    }

    #[test]
    fn test_builder_registers_spider_cloud_ai_routes_when_enabled() {
        let cfg = SpiderCloudToolConfig::new("sk_spider_cloud").with_enable_ai_routes(true);
        let agent = Agent::builder()
            .with_spider_cloud_config(cfg)
            .build()
            .expect("agent should build");
        let tools = agent.list_custom_tools();
        assert_tools_present(
            &tools,
            &[
                "spider_cloud_ai_crawl",
                "spider_cloud_ai_scrape",
                "spider_cloud_ai_search",
                "spider_cloud_ai_browser",
                "spider_cloud_ai_links",
            ],
        );
    }

    // --- Proxy handling ---

    #[test]
    fn test_builder_with_single_proxy() {
        let _agent = Agent::builder()
            .with_proxy("http://proxy.example.com:8080")
            .build()
            .expect("agent with proxy should build");
    }

    #[test]
    fn test_builder_with_multiple_proxies() {
        let proxies = vec![
            "http://proxy1.example.com:8080".into(),
            "http://proxy2.example.com:9090".into(),
        ];
        let _agent = Agent::builder()
            .with_proxies(proxies)
            .build()
            .expect("agent with multiple proxies should build");
    }

    #[test]
    fn test_builder_with_socks5_proxy() {
        let _agent = Agent::builder()
            .with_proxy("socks5://127.0.0.1:1080")
            .build()
            .expect("agent with socks5 proxy should build");
    }

    #[test]
    fn test_builder_with_empty_proxies_no_op() {
        let _agent = Agent::builder()
            .with_proxies(vec![])
            .build()
            .expect("agent with empty proxies should build");
    }

    #[test]
    fn test_builder_with_invalid_proxy_returns_error() {
        let outcome = Agent::builder()
            .with_proxy("not a valid url at all ://")
            .build();
        assert!(outcome.is_err(), "invalid proxy URL should fail at build");
    }

    #[test]
    fn test_builder_no_proxies_by_default() {
        let _agent = Agent::builder()
            .build()
            .expect("default agent should build");
    }

    // --- Custom client and timeout ---

    #[test]
    fn test_builder_with_custom_client() {
        let custom = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(120))
            .build()
            .expect("custom client should build");
        let agent = Agent::builder()
            .with_client(custom)
            .build()
            .expect("agent with custom client should build");
        let _ = agent.client();
    }

    #[test]
    fn test_builder_with_custom_client_ignores_proxy_and_timeout() {
        let custom = reqwest::Client::new();
        let _agent = Agent::builder()
            .with_client(custom)
            .with_proxy("http://proxy.example.com:8080")
            .with_timeout(Some(std::time::Duration::from_secs(999)))
            .build()
            .expect("custom client should take precedence");
    }

    #[test]
    fn test_builder_with_timeout_some() {
        let _agent = Agent::builder()
            .with_timeout(Some(std::time::Duration::from_secs(300)))
            .build()
            .expect("agent with 300s timeout should build");
    }

    #[test]
    fn test_builder_with_timeout_none_infinite() {
        let _agent = Agent::builder()
            .with_timeout(None)
            .build()
            .expect("agent with no timeout should build");
    }

    #[test]
    fn test_client_accessor_returns_reference() {
        let agent = Agent::builder().build().expect("agent should build");
        let _c1 = agent.client();
        let _c2 = agent.client();
    }

    // --- Spider Browser registration ---

    #[test]
    fn test_builder_registers_spider_browser_default_tools() {
        let agent = Agent::builder()
            .with_spider_browser("sk_browser_key")
            .build()
            .expect("agent should build");
        let tools = agent.list_custom_tools();
        assert_tools_present(
            &tools,
            &[
                "spider_browser_navigate",
                "spider_browser_html",
                "spider_browser_screenshot",
                "spider_browser_evaluate",
                "spider_browser_click",
                "spider_browser_fill",
                "spider_browser_wait",
            ],
        );
    }

    #[test]
    fn test_builder_spider_browser_with_stealth_and_country() {
        let cfg = SpiderBrowserToolConfig::new("sk_key")
            .with_stealth(true)
            .with_country("us");
        assert_eq!(
            cfg.connection_url(),
            "wss://browser.spider.cloud/v1/browser?token=sk_key&stealth=true&country=us"
        );
        let agent = Agent::builder()
            .with_spider_browser_config(cfg)
            .build()
            .expect("agent should build");
        assert!(agent.has_custom_tool("spider_browser_navigate"));
    }

    #[test]
    fn test_builder_spider_cloud_and_browser_together() {
        let agent = Agent::builder()
            .with_spider_cloud("cloud-key")
            .with_spider_browser("browser-key")
            .build()
            .expect("agent should build");
        for name in ["spider_cloud_crawl", "spider_browser_navigate"].iter() {
            assert!(agent.has_custom_tool(name));
        }
    }
}