use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Convenience alias for a `Result` whose error type is [`DaedraError`].
pub type DaedraResult<T> = Result<T, DaedraError>;
/// Error type shared by the operations in this crate.
///
/// Variants carrying `#[from]` are converted automatically from the wrapped library
/// error (so `?` works on those results). The `String` variants carry a detail message
/// that is interpolated into the display text; the unit variants have a fixed message.
#[derive(Error, Debug)]
pub enum DaedraError {
/// A `reqwest` error, converted via `From`.
#[error("HTTP request failed: {0}")]
HttpError(#[from] reqwest::Error),
/// A `url::ParseError`, converted via `From`.
#[error("Invalid URL: {0}")]
UrlParseError(#[from] url::ParseError),
/// A `serde_json::Error`, converted via `From`.
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
/// A search failed; the payload is the detail message.
#[error("Search failed: {0}")]
SearchError(String),
/// A page could not be fetched; the payload is the detail message.
#[error("Failed to fetch page: {0}")]
FetchError(String),
/// Caller-supplied arguments were rejected (e.g. an unknown safe-search level string).
#[error("Invalid arguments: {0}")]
InvalidArguments(String),
/// A server-side failure; the payload is the detail message.
#[error("Server error: {0}")]
ServerError(String),
/// A `std::io::Error`, converted via `From`.
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
/// Content extraction failed; the payload is the detail message.
#[error("Content extraction failed: {0}")]
ExtractionError(String),
/// The content type is not supported; the payload is interpolated into the message.
#[error("Unsupported content type: {0}")]
UnsupportedContentType(String),
/// A rate limit was hit.
#[error("Rate limit exceeded, please try again later")]
RateLimitExceeded,
/// Bot protection was detected on the target page.
#[error("Bot protection detected on target page")]
BotProtectionDetected,
/// The operation timed out.
#[error("Operation timed out")]
Timeout,
}
/// Safe-search filtering level.
///
/// Serde (de)serializes the variants by their upper-cased names (`"OFF"`, `"MODERATE"`,
/// `"STRICT"`). `Moderate` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum SafeSearchLevel {
/// Serialized as `"OFF"`.
Off,
/// Serialized as `"MODERATE"`; the default level.
#[default]
Moderate,
/// Serialized as `"STRICT"`.
Strict,
}
impl SafeSearchLevel {
pub fn to_ddg_value(&self) -> i32 {
match self {
SafeSearchLevel::Off => -2,
SafeSearchLevel::Moderate => -1,
SafeSearchLevel::Strict => 1,
}
}
}
impl std::fmt::Display for SafeSearchLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SafeSearchLevel::Off => write!(f, "OFF"),
SafeSearchLevel::Moderate => write!(f, "MODERATE"),
SafeSearchLevel::Strict => write!(f, "STRICT"),
}
}
}
impl std::str::FromStr for SafeSearchLevel {
    type Err = DaedraError;

    /// Parses a level name, ignoring case (`"off"`, `"Moderate"`, `"STRICT"`, ...).
    ///
    /// # Errors
    ///
    /// Returns `DaedraError::InvalidArguments` quoting the original input when the
    /// upper-cased text is not one of `OFF`, `MODERATE` or `STRICT`. The input is not
    /// trimmed, so surrounding whitespace makes it invalid.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_uppercase();
        let level = match normalized.as_str() {
            "OFF" => Self::Off,
            "MODERATE" => Self::Moderate,
            "STRICT" => Self::Strict,
            _ => {
                // Report the text exactly as the caller supplied it, not the upper-cased copy.
                let message = format!("Invalid safe search level: {}", s);
                return Err(DaedraError::InvalidArguments(message));
            }
        };
        Ok(level)
    }
}
/// Optional settings that accompany a search query.
///
/// When deserializing, a missing `region`, `safe_search` or `num_results` falls back to
/// its default, and a missing `time_range` becomes `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchOptions {
/// Region code; falls back to `default_region()` (`"wt-wt"`) when absent from the input.
#[serde(default = "default_region")]
pub region: String,
/// Safe-search level; falls back to `SafeSearchLevel::default()` when absent.
#[serde(default)]
pub safe_search: SafeSearchLevel,
/// Number of results requested; falls back to `default_num_results()` (10) when absent.
#[serde(default = "default_num_results")]
pub num_results: usize,
/// Optional time-range filter; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub time_range: Option<String>,
}
impl Default for SearchOptions {
    /// Builds the options used when the caller supplies none: region `"wt-wt"`,
    /// the default safe-search level, 10 results and no time range.
    ///
    /// The values come from the same helpers serde uses for missing fields, so a
    /// defaulted struct and a deserialized empty object cannot drift apart.
    fn default() -> Self {
        Self {
            region: default_region(),
            safe_search: SafeSearchLevel::default(),
            num_results: default_num_results(),
            time_range: None,
        }
    }
}

/// Default value of `SearchOptions::region`, also used by serde when the field is absent.
fn default_region() -> String {
    "wt-wt".to_string()
}

/// Default value of `SearchOptions::num_results`, also used by serde when the field is absent.
fn default_num_results() -> usize {
    10
}
/// Arguments for a search request: the query text plus optional settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchArgs {
/// The search query string.
pub query: String,
/// Optional search settings; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<SearchOptions>,
}
/// Arguments for visiting a single page.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisitPageArgs {
/// URL of the page to visit.
pub url: String,
/// Optional selector string (the schema in this file describes it as a CSS selector);
/// left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub selector: Option<String>,
/// Whether image references should be included; `false` when absent from the input.
#[serde(default)]
pub include_images: bool,
}
/// Category assigned to a search result.
///
/// Serde (de)serializes the variants by their lower-cased names (e.g. `"article"`).
/// `Other` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum ContentType {
/// Serialized as `"documentation"`.
Documentation,
/// Serialized as `"social"`.
Social,
/// Serialized as `"article"`.
Article,
/// Serialized as `"forum"`.
Forum,
/// Serialized as `"video"`.
Video,
/// Serialized as `"shopping"`.
Shopping,
/// Serialized as `"other"`; the default category.
#[default]
Other,
}
/// Extra information attached to a single search result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultMetadata {
/// Category of the result; appears under the key `"type"` in the serialized form.
#[serde(rename = "type")]
pub content_type: ContentType,
/// Source of the result (the test in this file uses a bare host name, `example.com`).
pub source: String,
/// Optional favicon reference; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub favicon: Option<String>,
/// Optional publication date string; left out of the serialized output when `None`.
/// NOTE(review): the date format is not established in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub published_date: Option<String>,
}
/// One entry in a list of search results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
/// Title of the result; matched (lower-cased) against topic title patterns.
pub title: String,
/// URL of the result; matched (lower-cased) against topic URL patterns.
pub url: String,
/// Description text of the result.
pub description: String,
/// Category, source and optional extras for this result.
pub metadata: ResultMetadata,
}
/// Heuristic analysis attached to a search response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryAnalysis {
/// Language code guessed from the query text (`SearchResponse::new` fills it from
/// `detect_language`).
pub language: String,
/// Topic labels derived from the results (`SearchResponse::new` fills it from
/// `detect_topics`).
pub topics: Vec<String>,
}
/// The search settings echoed back in a response's metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchContext {
/// Region code the search used.
pub region: String,
/// Safe-search level as text (`SearchResponse::new` stores the level's `Display` form).
pub safe_search: String,
/// Requested number of results; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub num_results: Option<usize>,
}
/// Metadata describing a search and its results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchMetadata {
/// The query string that was searched.
pub query: String,
/// Creation time as text (`SearchResponse::new` stores an RFC 3339 UTC timestamp).
pub timestamp: String,
/// Number of results in the response.
pub result_count: usize,
/// The settings the search ran with.
pub search_context: SearchContext,
/// Detected language and topics.
pub query_analysis: QueryAnalysis,
}
/// Complete response to a search: the results plus metadata about the search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResponse {
/// Response discriminator, serialized under the key `"type"`
/// (`SearchResponse::new` sets it to `"search_results"`).
#[serde(rename = "type")]
pub response_type: String,
/// The search results.
pub data: Vec<SearchResult>,
/// Metadata about the query, settings and analysis.
pub metadata: SearchMetadata,
}
impl SearchResponse {
    /// Assembles a `"search_results"` response for `query`.
    ///
    /// The metadata records the current UTC time as an RFC 3339 string, the number of
    /// results, the region / safe-search / result-count settings copied from `options`,
    /// and the language and topics reported by `detect_language` and `detect_topics`.
    pub fn new(query: String, results: Vec<SearchResult>, options: &SearchOptions) -> Self {
        let timestamp = chrono::Utc::now().to_rfc3339();

        // Both analyses borrow their input, so they run before `query` and `results`
        // are moved into the response below.
        let query_analysis = QueryAnalysis {
            language: detect_language(&query),
            topics: detect_topics(&results),
        };

        let search_context = SearchContext {
            region: options.region.clone(),
            safe_search: options.safe_search.to_string(),
            num_results: Some(options.num_results),
        };

        let metadata = SearchMetadata {
            query,
            timestamp,
            result_count: results.len(),
            search_context,
            query_analysis,
        };

        Self {
            response_type: "search_results".to_string(),
            data: results,
            metadata,
        }
    }
}
/// Content extracted from a visited page.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PageContent {
/// URL of the page.
pub url: String,
/// Title of the page.
pub title: String,
/// Extracted content text.
/// NOTE(review): the text format (plain, Markdown, ...) is not established in this file.
pub content: String,
/// Timestamp string. NOTE(review): presumably RFC 3339 like `SearchMetadata::timestamp` — confirm.
pub timestamp: String,
/// Word count. NOTE(review): presumably counted over `content` — confirm against the producer.
pub word_count: usize,
/// Optional list of links; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub links: Option<Vec<PageLink>>,
}
/// A link listed in a [`PageContent`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PageLink {
/// Text of the link.
pub text: String,
/// Target URL of the link.
pub url: String,
}
/// Arguments for a site crawl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrawlArgs {
/// Root URL of the site to crawl.
pub root_url: String,
/// Maximum number of pages; falls back to `default_crawl_max_pages()` (25) when absent.
#[serde(default = "default_crawl_max_pages")]
pub max_pages: usize,
/// Maximum concurrent fetches; falls back to `default_crawl_concurrency()` (4) when absent.
#[serde(default = "default_crawl_concurrency")]
pub concurrency: usize,
}
/// Value serde uses for `CrawlArgs::max_pages` when the field is absent from the input.
fn default_crawl_max_pages() -> usize {
    const DEFAULT_MAX_PAGES: usize = 25;
    DEFAULT_MAX_PAGES
}
/// Value serde uses for `CrawlArgs::concurrency` when the field is absent from the input.
fn default_crawl_concurrency() -> usize {
    const DEFAULT_CONCURRENCY: usize = 4;
    DEFAULT_CONCURRENCY
}
/// One page fetched during a crawl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrawledPage {
/// URL of the page.
pub url: String,
/// Title of the page.
pub title: String,
/// Page content. NOTE(review): presumably rendered as Markdown, going by the field name — confirm.
pub markdown: String,
/// Links associated with the page, as strings.
pub links: Vec<String>,
}
/// A per-URL failure recorded during a crawl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrawlError {
/// URL the error relates to.
pub url: String,
/// Error description as text.
pub error: String,
}
/// Page counters for a crawl.
/// NOTE(review): exact counting rules are defined by the crawler, not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrawlSummary {
/// Number of pages requested.
pub requested: usize,
/// Number of pages fetched.
pub fetched: usize,
/// Number of pages that failed.
pub failed: usize,
}
/// Outcome of a crawl: counters, fetched pages and recorded errors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrawlResult {
/// Root URL the crawl was asked for.
pub root_url: String,
/// Whether a sitemap was found.
pub sitemap_found: bool,
/// Page counters for the crawl.
pub summary: CrawlSummary,
/// Pages that were fetched.
pub pages: Vec<CrawledPage>,
/// Errors recorded per URL.
pub errors: Vec<CrawlError>,
}
/// A language code paired with the inclusive code-point ranges that identify it.
struct LangRange {
    lang: &'static str,
    ranges: &'static [(char, char)],
}

/// Script table consulted by `detect_language`, in priority order (first match wins).
///
/// Scripts that belong to one language come before the Han ideograph block: Japanese
/// text mixes kana with kanji and Korean text may contain hanja, and both kanji and
/// hanja fall inside U+4E00..=U+9FFF. Checking `zh` first would label such text Chinese.
const LANG_RANGES: &[LangRange] = &[
    // Hiragana and Katakana.
    LangRange {
        lang: "ja",
        ranges: &[('\u{3040}', '\u{30ff}')],
    },
    // Hangul syllables.
    LangRange {
        lang: "ko",
        ranges: &[('\u{ac00}', '\u{d7af}')],
    },
    // CJK unified ideographs; reached only when no kana or hangul is present.
    LangRange {
        lang: "zh",
        ranges: &[('\u{4e00}', '\u{9fff}')],
    },
    // Cyrillic.
    LangRange {
        lang: "ru",
        ranges: &[('\u{0400}', '\u{04ff}')],
    },
    // Arabic.
    LangRange {
        lang: "ar",
        ranges: &[('\u{0600}', '\u{06ff}')],
    },
];

/// Guesses the language of `query` from the scripts it contains.
///
/// Returns the code of the first `LANG_RANGES` entry that has at least one character
/// in `query`, or `"en"` when none matches (including for an empty query). This is a
/// script heuristic only: any Cyrillic text is reported as `"ru"`, and any Latin-script
/// text as `"en"`.
fn detect_language(query: &str) -> String {
    let in_script = |entry: &LangRange| {
        query
            .chars()
            .any(|c| entry.ranges.iter().any(|&(lo, hi)| (lo..=hi).contains(&c)))
    };

    LANG_RANGES
        .iter()
        .find(|entry| in_script(entry))
        .map_or("en", |entry| entry.lang)
        .to_string()
}
/// One topic heuristic: a result is tagged with `topic` when any single signal matches.
struct TopicRule {
    /// Label added to the output when the rule fires.
    topic: &'static str,
    /// Lower-case substrings looked for in the result URL.
    url_patterns: &'static [&'static str],
    /// Lower-case substrings looked for in the result title.
    title_patterns: &'static [&'static str],
    /// Content type that triggers the rule on its own; `None` means the rule has no
    /// content-type signal.
    content_type: Option<ContentType>,
}

/// Rules applied by `detect_topics` to every result.
const TOPIC_RULES: &[TopicRule] = &[
    TopicRule {
        topic: "technology",
        url_patterns: &["github.com", "stackoverflow.com", "gitlab.com"],
        title_patterns: &["programming", "code"],
        content_type: None,
    },
    TopicRule {
        topic: "documentation",
        url_patterns: &["docs.", "/docs/", "/documentation/"],
        title_patterns: &["documentation", "api reference"],
        content_type: None,
    },
    TopicRule {
        topic: "news",
        url_patterns: &["news.", "/news/"],
        title_patterns: &[],
        content_type: Some(ContentType::Article),
    },
    TopicRule {
        topic: "academic",
        url_patterns: &[".edu", "arxiv.org", "scholar.google"],
        title_patterns: &["research", "study"],
        content_type: None,
    },
];

/// Collects the topic labels whose rule matches at least one of `results`.
///
/// A rule matches a result when its URL contains one of the rule's URL patterns, its
/// title contains one of the title patterns (both compared lower-cased), or the rule
/// names a content type equal to the result's. Each label appears at most once and the
/// output is sorted alphabetically, so it is stable between runs.
fn detect_topics(results: &[SearchResult]) -> Vec<String> {
    // An ordered set both de-duplicates the labels and fixes the output order.
    let mut topics = std::collections::BTreeSet::new();

    for result in results {
        let lower_url = result.url.to_lowercase();
        let lower_title = result.title.to_lowercase();

        for rule in TOPIC_RULES {
            let url_match = rule.url_patterns.iter().any(|&p| lower_url.contains(p));
            let title_match = rule
                .title_patterns
                .iter()
                .any(|&p| lower_title.contains(p));
            // A rule without a content type must not match on type: treating `None`
            // as a match would tag every result with every such rule's topic.
            let type_match = rule.content_type == Some(result.metadata.content_type);

            if url_match || title_match || type_match {
                topics.insert(rule.topic);
            }
        }
    }

    topics.into_iter().map(String::from).collect()
}
/// Builds the JSON-Schema-style description of the search arguments.
///
/// The schema declares a required `query` string and an optional `options` object with
/// `region`, `safe_search`, `num_results` and `time_range` properties; the advertised
/// defaults (`"wt-wt"`, `"MODERATE"`, 10) match the serde defaults of `SearchOptions`.
///
/// NOTE(review): the 1–50 bound on `num_results` is only advertised here; no enforcement
/// is visible in this file.
pub fn search_args_schema() -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query string"
},
"options": {
"type": "object",
"description": "Optional search configuration",
"properties": {
"region": {
"type": "string",
"description": "Region for search results (e.g., 'us-en', 'wt-wt' for worldwide)",
"default": "wt-wt"
},
"safe_search": {
"type": "string",
"enum": ["OFF", "MODERATE", "STRICT"],
"description": "Safe search filtering level",
"default": "MODERATE"
},
"num_results": {
"type": "integer",
"description": "Maximum number of results to return",
"default": 10,
"minimum": 1,
"maximum": 50
},
"time_range": {
"type": "string",
"description": "Time range filter (d=day, w=week, m=month, y=year)"
}
}
}
},
"required": ["query"]
})
}
/// Builds the JSON-Schema-style description of the page-visit arguments.
///
/// The schema declares a required `url` string (format `uri`), an optional `selector`
/// string and an optional `include_images` boolean whose advertised default (`false`)
/// matches the serde default of `VisitPageArgs::include_images`.
pub fn visit_page_args_schema() -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"format": "uri",
"description": "URL of the page to visit"
},
"selector": {
"type": "string",
"description": "Optional CSS selector to target specific content"
},
"include_images": {
"type": "boolean",
"description": "Whether to include image references in the response",
"default": false
}
},
"required": ["url"]
})
}
/// Builds the JSON-Schema-style description of the crawl arguments.
///
/// The schema declares a required `root_url` string (format `uri`) and optional integer
/// `max_pages` and `concurrency` properties whose advertised defaults (25 and 4) match
/// `default_crawl_max_pages()` and `default_crawl_concurrency()`.
///
/// NOTE(review): unlike `num_results` in the search schema, neither integer declares a
/// `minimum`, so a client may send 0 — confirm how the crawler handles that.
pub fn crawl_args_schema() -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"root_url": {
"type": "string",
"format": "uri",
"description": "Root URL of the site to crawl (sitemap or homepage)"
},
"max_pages": {
"type": "integer",
"description": "Maximum number of pages to fetch (default: 25)",
"default": 25
},
"concurrency": {
"type": "integer",
"description": "Maximum concurrent fetches (default: 4)",
"default": 4
}
},
"required": ["root_url"]
})
}
// Unit tests for the types and helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Parsing accepts each level name and is case-insensitive.
#[test]
fn test_safe_search_level_parsing() {
assert_eq!(
"OFF".parse::<SafeSearchLevel>().unwrap(),
SafeSearchLevel::Off
);
assert_eq!(
"MODERATE".parse::<SafeSearchLevel>().unwrap(),
SafeSearchLevel::Moderate
);
assert_eq!(
"STRICT".parse::<SafeSearchLevel>().unwrap(),
SafeSearchLevel::Strict
);
// Lower-case input is accepted because `from_str` upper-cases before matching.
assert_eq!(
"moderate".parse::<SafeSearchLevel>().unwrap(),
SafeSearchLevel::Moderate
);
}
// Pins the integer codes returned by `to_ddg_value`.
#[test]
fn test_safe_search_ddg_value() {
assert_eq!(SafeSearchLevel::Off.to_ddg_value(), -2);
assert_eq!(SafeSearchLevel::Moderate.to_ddg_value(), -1);
assert_eq!(SafeSearchLevel::Strict.to_ddg_value(), 1);
}
// One single-script sample per detected language, plus the "en" fallback.
// The Arabic range is not exercised here.
#[test]
fn test_language_detection() {
assert_eq!(detect_language("hello world"), "en");
assert_eq!(detect_language("你好世界"), "zh");
assert_eq!(detect_language("こんにちは"), "ja");
assert_eq!(detect_language("안녕하세요"), "ko");
assert_eq!(detect_language("привет"), "ru");
}
// The search schema exposes both top-level properties as objects.
#[test]
fn test_search_args_schema() {
let schema = search_args_schema();
assert!(schema["properties"]["query"].is_object());
assert!(schema["properties"]["options"].is_object());
}
// `SearchResponse::new` keeps the results and the query and sets the response type.
#[test]
fn test_search_response_creation() {
let results = vec![SearchResult {
title: "Test".to_string(),
url: "https://example.com".to_string(),
description: "Test description".to_string(),
metadata: ResultMetadata {
content_type: ContentType::Article,
source: "example.com".to_string(),
favicon: None,
published_date: None,
},
}];
let options = SearchOptions::default();
let response = SearchResponse::new("test query".to_string(), results, &options);
assert_eq!(response.response_type, "search_results");
assert_eq!(response.data.len(), 1);
assert_eq!(response.metadata.query, "test query");
}
}