use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Result alias whose error type is [`DaedraError`].
pub type DaedraResult<T> = Result<T, DaedraError>;

/// Crate-wide error type.
///
/// Variants annotated with `#[from]` get a `From` impl generated by `thiserror`,
/// so `?` converts the wrapped error automatically.
#[derive(Error, Debug)]
pub enum DaedraError {
    /// Wraps a `reqwest::Error` raised by an HTTP request.
    #[error("HTTP request failed: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Wraps a `url::ParseError` for a URL that could not be parsed.
    #[error("Invalid URL: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// Wraps a `serde_json::Error` from (de)serialization.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A search failed; the string carries the reason.
    #[error("Search failed: {0}")]
    SearchError(String),

    /// A page fetch failed; the string carries the reason.
    #[error("Failed to fetch page: {0}")]
    FetchError(String),

    /// Caller-supplied arguments were rejected
    /// (e.g. by `SafeSearchLevel`'s `FromStr` impl).
    #[error("Invalid arguments: {0}")]
    InvalidArguments(String),

    /// Generic server-side failure with a message.
    #[error("Server error: {0}")]
    ServerError(String),

    /// Wraps a `std::io::Error`.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// Extracting content from a page failed; the string carries the reason.
    #[error("Content extraction failed: {0}")]
    ExtractionError(String),

    /// The remote side signalled rate limiting.
    #[error("Rate limit exceeded, please try again later")]
    RateLimitExceeded,

    /// The target page appears to be guarded against automated access.
    #[error("Bot protection detected on target page")]
    BotProtectionDetected,

    /// An operation did not finish in time.
    #[error("Operation timed out")]
    Timeout,
}
/// Safe-search filtering strength.
///
/// Serialized as `"OFF"`, `"MODERATE"` or `"STRICT"`; `Moderate` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum SafeSearchLevel {
    /// No filtering.
    Off,
    /// Default filtering level.
    #[default]
    Moderate,
    /// Strictest filtering.
    Strict,
}
impl SafeSearchLevel {
pub fn to_ddg_value(&self) -> i32 {
match self {
SafeSearchLevel::Off => -2,
SafeSearchLevel::Moderate => -1,
SafeSearchLevel::Strict => 1,
}
}
}
impl std::fmt::Display for SafeSearchLevel {
    /// Writes the upper-case level name (`OFF`, `MODERATE`, `STRICT`), the same
    /// spelling used by the serde representation and accepted by `FromStr`.
    ///
    /// Goes through `Formatter::pad` so width/alignment flags such as `{:<10}`
    /// are honored. `write!(f, "...")` would silently ignore them. Plain
    /// `to_string()` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SafeSearchLevel::Off => "OFF",
            SafeSearchLevel::Moderate => "MODERATE",
            SafeSearchLevel::Strict => "STRICT",
        };
        f.pad(name)
    }
}
impl std::str::FromStr for SafeSearchLevel {
    type Err = DaedraError;

    /// Parses a level name regardless of case (`"off"`, `"Moderate"`, `"STRICT"`, ...).
    /// The input is normalized with Unicode `to_uppercase` before matching.
    ///
    /// # Errors
    /// Returns [`DaedraError::InvalidArguments`] for any other input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_uppercase();
        let parsed = match normalized.as_str() {
            "OFF" => Some(Self::Off),
            "MODERATE" => Some(Self::Moderate),
            "STRICT" => Some(Self::Strict),
            _ => None,
        };
        parsed.ok_or_else(|| {
            DaedraError::InvalidArguments(format!("Invalid safe search level: {}", s))
        })
    }
}
/// Tunables for a search request. Every field may be omitted when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchOptions {
    /// Region code, e.g. `"us-en"`, or `"wt-wt"` for worldwide (the default).
    #[serde(default = "default_region")]
    pub region: String,

    /// Safe-search level; falls back to `SafeSearchLevel::default()` when missing.
    #[serde(default)]
    pub safe_search: SafeSearchLevel,

    /// Maximum number of results to return (defaults to 10). Not range-checked
    /// here, although the published JSON schema advertises 1..=50.
    #[serde(default = "default_num_results")]
    pub num_results: usize,

    /// Optional time filter (`d`, `w`, `m`, `y` per the JSON schema); omitted
    /// from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub time_range: Option<String>,
}
impl Default for SearchOptions {
fn default() -> Self {
Self {
region: "wt-wt".to_string(),
safe_search: SafeSearchLevel::Moderate,
num_results: 10,
time_range: None,
}
}
}
/// Serde default for `SearchOptions::region`: the worldwide region code.
fn default_region() -> String {
    String::from("wt-wt")
}

/// Serde default for `SearchOptions::num_results`.
fn default_num_results() -> usize {
    10
}
/// Arguments accepted by the search tool (see `search_args_schema`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchArgs {
    /// The search query string.
    pub query: String,

    /// Optional search configuration; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<SearchOptions>,
}
/// Arguments accepted by the page-visit tool (see `visit_page_args_schema`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisitPageArgs {
    /// URL of the page to visit.
    pub url: String,

    /// Optional CSS selector narrowing which content is targeted.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub selector: Option<String>,

    /// Whether image references should be included; `false` when omitted.
    #[serde(default)]
    pub include_images: bool,
}
/// Coarse category of a search result, serialized in lowercase
/// (`"documentation"`, `"social"`, ...). `Other` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ContentType {
    Documentation,
    Social,
    Article,
    Forum,
    Video,
    Shopping,
    /// Fallback when no more specific category applies.
    #[default]
    Other,
}
/// Per-result metadata attached to a [`SearchResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultMetadata {
    /// Result category; serialized under the key `"type"`.
    #[serde(rename = "type")]
    pub content_type: ContentType,

    /// Source label for the result (e.g. a domain such as `example.com`).
    pub source: String,

    /// Optional favicon reference (presumably a URL — confirm with the producer).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub favicon: Option<String>,

    /// Optional publication date string; its format is not fixed here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub published_date: Option<String>,
}
/// A single hit returned by a search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Result title.
    pub title: String,
    /// Result URL.
    pub url: String,
    /// Snippet/description text.
    pub description: String,
    /// Category and source information.
    pub metadata: ResultMetadata,
}
/// Heuristic analysis of a query and its results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryAnalysis {
    /// Language code guessed from the query's script (e.g. `"en"`, `"ja"`).
    pub language: String,
    /// Topic labels inferred from result titles and URLs.
    pub topics: Vec<String>,
}
/// Echo of the options a search ran with (filled in by `SearchResponse::new`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchContext {
    /// Region code the search used.
    pub region: String,
    /// Safe-search level in its display form (`OFF` / `MODERATE` / `STRICT`).
    pub safe_search: String,
    /// Requested result count; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_results: Option<usize>,
}
/// Descriptive metadata accompanying a [`SearchResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchMetadata {
    /// The original query text.
    pub query: String,
    /// When the response was built, as a string timestamp.
    pub timestamp: String,
    /// Number of entries in the response's `data`.
    pub result_count: usize,
    /// The options the search ran with.
    pub search_context: SearchContext,
    /// Language/topic heuristics for the query.
    pub query_analysis: QueryAnalysis,
}
/// Envelope returned by the search tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResponse {
    /// Response discriminator, serialized as `"type"`.
    #[serde(rename = "type")]
    pub response_type: String,
    /// The search hits.
    pub data: Vec<SearchResult>,
    /// Query, timing and analysis metadata.
    pub metadata: SearchMetadata,
}
impl SearchResponse {
    /// Wraps `results` in a `"search_results"` response.
    ///
    /// Stamps the current UTC time (RFC 3339), records the options used, and
    /// attaches the language guess for `query` plus topics derived from `results`.
    pub fn new(query: String, results: Vec<SearchResult>, options: &SearchOptions) -> Self {
        let timestamp = chrono::Utc::now().to_rfc3339();
        let result_count = results.len();

        let query_analysis = QueryAnalysis {
            language: detect_language(&query),
            topics: detect_topics(&results),
        };
        let search_context = SearchContext {
            region: options.region.clone(),
            safe_search: options.safe_search.to_string(),
            num_results: Some(options.num_results),
        };
        let metadata = SearchMetadata {
            query,
            timestamp,
            result_count,
            search_context,
            query_analysis,
        };

        Self {
            response_type: String::from("search_results"),
            data: results,
            metadata,
        }
    }
}
/// Content extracted from a visited page.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PageContent {
    /// URL of the page.
    pub url: String,
    /// Page title.
    pub title: String,
    /// Extracted text content.
    pub content: String,
    /// Timestamp string recorded by the producer (format not fixed here).
    pub timestamp: String,
    /// Word count of `content`, as computed by the producer.
    pub word_count: usize,
    /// Optional list of links; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub links: Option<Vec<PageLink>>,
}
/// A hyperlink found on a page.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PageLink {
    /// Link text.
    pub text: String,
    /// Link target.
    pub url: String,
}
/// Guesses a language code for `query` from the Unicode scripts it contains.
///
/// Returns `"ja"`, `"ko"`, `"zh"`, `"ru"`, `"ar"`, or `"en"` as the fallback
/// (including for an empty query). This is a script heuristic, not real
/// language identification: any Cyrillic maps to `"ru"` and any Arabic-block
/// character maps to `"ar"`.
///
/// Order matters. Japanese routinely mixes kanji (which live in the CJK
/// Unified Ideographs block) with kana, and Korean may include Hanja. So the
/// script-specific signals (kana, then Hangul) are checked *before* Han
/// ideographs. Otherwise `"東京の天気"` would be reported as Chinese.
fn detect_language(query: &str) -> String {
    let has_char_in = |lo: char, hi: char| query.chars().any(|c| (lo..=hi).contains(&c));

    let code = if has_char_in('\u{3040}', '\u{30ff}') {
        // Hiragana + Katakana.
        "ja"
    } else if has_char_in('\u{ac00}', '\u{d7af}') {
        // Hangul Syllables.
        "ko"
    } else if has_char_in('\u{4e00}', '\u{9fff}') {
        // CJK Unified Ideographs with no kana/Hangul present.
        "zh"
    } else if has_char_in('\u{0400}', '\u{04ff}') {
        // Cyrillic.
        "ru"
    } else if has_char_in('\u{0600}', '\u{06ff}') {
        // Arabic.
        "ar"
    } else {
        "en"
    };
    code.to_string()
}
/// Infers coarse topic labels (`"academic"`, `"documentation"`, `"news"`,
/// `"technology"`) from result URLs, titles and content types.
///
/// Matching is case-insensitive substring matching, so it is a heuristic
/// (e.g. a title containing "decode" counts as "code"). Each label appears at
/// most once. The output is sorted alphabetically so serialized responses are
/// deterministic. A `HashSet` would yield a random order from run to run.
fn detect_topics(results: &[SearchResult]) -> Vec<String> {
    use std::collections::BTreeSet;

    const TECH_URLS: &[&str] = &["github.com", "stackoverflow.com", "gitlab.com"];
    const TECH_TITLES: &[&str] = &["programming", "code"];
    const DOC_URLS: &[&str] = &["docs.", "/docs/", "/documentation/"];
    const DOC_TITLES: &[&str] = &["documentation", "api reference"];
    const NEWS_URLS: &[&str] = &["news.", "/news/"];
    const ACADEMIC_URLS: &[&str] = &[".edu", "arxiv.org", "scholar.google"];
    const ACADEMIC_TITLES: &[&str] = &["research", "study"];

    let mut topics: BTreeSet<&'static str> = BTreeSet::new();
    for result in results {
        let title = result.title.to_lowercase();
        let url = result.url.to_lowercase();
        let url_has = |needles: &[&str]| needles.iter().any(|n| url.contains(*n));
        let title_has = |needles: &[&str]| needles.iter().any(|n| title.contains(*n));

        if url_has(TECH_URLS) || title_has(TECH_TITLES) {
            topics.insert("technology");
        }
        if url_has(DOC_URLS) || title_has(DOC_TITLES) {
            topics.insert("documentation");
        }
        if url_has(NEWS_URLS) || result.metadata.content_type == ContentType::Article {
            topics.insert("news");
        }
        if url_has(ACADEMIC_URLS) || title_has(ACADEMIC_TITLES) {
            topics.insert("academic");
        }
    }
    topics.into_iter().map(String::from).collect()
}
/// JSON Schema describing [`SearchArgs`]: a required `query` string plus an
/// optional `options` object mirroring the fields of [`SearchOptions`].
pub fn search_args_schema() -> serde_json::Value {
    // Sub-schema for the optional `options` object.
    let options = serde_json::json!({
        "type": "object",
        "description": "Optional search configuration",
        "properties": {
            "region": {
                "type": "string",
                "description": "Region for search results (e.g., 'us-en', 'wt-wt' for worldwide)",
                "default": "wt-wt"
            },
            "safe_search": {
                "type": "string",
                "enum": ["OFF", "MODERATE", "STRICT"],
                "description": "Safe search filtering level",
                "default": "MODERATE"
            },
            "num_results": {
                "type": "integer",
                "description": "Maximum number of results to return",
                "default": 10,
                "minimum": 1,
                "maximum": 50
            },
            "time_range": {
                "type": "string",
                "description": "Time range filter (d=day, w=week, m=month, y=year)"
            }
        }
    });

    serde_json::json!({
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The search query string"
            },
            "options": options
        },
        "required": ["query"]
    })
}
/// JSON Schema describing [`VisitPageArgs`]: a required `url` plus optional
/// `selector` and `include_images`.
pub fn visit_page_args_schema() -> serde_json::Value {
    let properties = serde_json::json!({
        "url": {
            "type": "string",
            "format": "uri",
            "description": "URL of the page to visit"
        },
        "selector": {
            "type": "string",
            "description": "Optional CSS selector to target specific content"
        },
        "include_images": {
            "type": "boolean",
            "description": "Whether to include image references in the response",
            "default": false
        }
    });

    serde_json::json!({
        "type": "object",
        "properties": properties,
        "required": ["url"]
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_safe_search_level_parsing() {
        let cases = [
            ("OFF", SafeSearchLevel::Off),
            ("MODERATE", SafeSearchLevel::Moderate),
            ("STRICT", SafeSearchLevel::Strict),
            ("moderate", SafeSearchLevel::Moderate),
        ];
        for (input, expected) in cases {
            assert_eq!(
                input.parse::<SafeSearchLevel>().unwrap(),
                expected,
                "input: {input}"
            );
        }
    }

    #[test]
    fn test_safe_search_ddg_value() {
        let cases = [
            (SafeSearchLevel::Off, -2),
            (SafeSearchLevel::Moderate, -1),
            (SafeSearchLevel::Strict, 1),
        ];
        for (level, expected) in cases {
            assert_eq!(level.to_ddg_value(), expected, "level: {level}");
        }
    }

    #[test]
    fn test_language_detection() {
        let cases = [
            ("hello world", "en"),
            ("你好世界", "zh"),
            ("こんにちは", "ja"),
            ("안녕하세요", "ko"),
            ("привет", "ru"),
        ];
        for (query, expected) in cases {
            assert_eq!(detect_language(query), expected, "query: {query}");
        }
    }

    #[test]
    fn test_search_args_schema() {
        let schema = search_args_schema();
        let properties = &schema["properties"];
        assert!(properties["query"].is_object());
        assert!(properties["options"].is_object());
    }

    #[test]
    fn test_search_response_creation() {
        let metadata = ResultMetadata {
            content_type: ContentType::Article,
            source: "example.com".to_string(),
            favicon: None,
            published_date: None,
        };
        let results = vec![SearchResult {
            title: "Test".to_string(),
            url: "https://example.com".to_string(),
            description: "Test description".to_string(),
            metadata,
        }];
        let options = SearchOptions::default();

        let response = SearchResponse::new("test query".to_string(), results, &options);

        assert_eq!(response.response_type, "search_results");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.metadata.query, "test query");
    }
}