hanzo-engine 0.6.1

Hanzo Engine - fast, flexible LLM inference engine written in Rust.
Documentation
use std::collections::{HashMap, HashSet};
use std::time::Duration;

pub mod rag;

use anyhow::Result;
use html2text::{config, render::PlainDecorator};
use scraper::{Html, Selector};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::env::consts::{ARCH, FAMILY, OS};
use tokenizers::Tokenizer;

use crate::{Function, Tool, ToolType, WebSearchOptions, WebSearchUserLocation};

/// Connect timeout applied to every HTTP client built by `build_client`.
const SEARCH_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Request timeout applied to every HTTP client built by `build_client`.
const SEARCH_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
/// Maximum number of parsed DuckDuckGo results kept per search; only these
/// results have their pages fetched.
const MAX_SEARCH_RESULTS: usize = 10;

/// Callback used to override how search results are gathered. The returned
/// vector must be sorted in decreasing order of relevance.
///
/// The callback receives the parsed search tool arguments and is synchronous.
/// It is `Send + Sync` so a single callback can be shared between threads.
pub type SearchCallback =
    dyn Fn(&SearchFunctionParameters) -> Result<Vec<SearchResult>> + Send + Sync;

/// Reports whether `name` refers to one of the two built-in web tools:
/// the web search tool or the website content extractor.
pub(crate) fn search_tool_called(name: &str) -> bool {
    matches!(name, SEARCH_TOOL_NAME | EXTRACT_TOOL_NAME)
}

/// Function name under which the web search tool is advertised to the model.
pub(crate) const SEARCH_TOOL_NAME: &str = "hanzo_search_the_web";
/// Function name under which the website content extractor tool is advertised.
pub(crate) const EXTRACT_TOOL_NAME: &str = "hanzo_website_content_extractor";

/// Crate version captured at compile time; embedded in the `User-Agent`
/// header sent with web tool requests.
const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Default description of the web search tool shown to the model.
///
/// It can be replaced through `WebSearchOptions::search_description`. The input
/// guidance must not contradict itself: the model should pass a search query.
/// Raw URLs are still tolerated by `run_search_tool`, but the extractor tool is
/// the intended way to read a specific site.
pub(crate) const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
If the user wants up-to-date information or you want to retrieve new information, call this tool.
If you call this tool, then you MUST complete your answer using the output.
The input should be a search query, not a URL. To read a specific website, use the website content extractor tool instead.
Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.

You should expect output like this:
{
    "sources": ["example.com", ...],
    "output": [
        {
            "title": "...",
            "description": "...",
            "url": "...",
            "content": "...",
        },
        ...
    ]
}
"#;
/// Default description of the website content extractor tool shown to the
/// model; replaced when `WebSearchOptions::extract_description` is set.
pub(crate) const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
If the user wants information about a specific site or you want to extract the content of a specific site, call this tool.
The input must be a URL.
Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.

You should expect output like this:
{
    "output": [
        {
            "url": "...",
            "content": "...",
        },
        ...
    ]
}
"#;

/// One result produced by the web search tool.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct SearchResult {
    /// Result title; when the query was a URL fetched directly, this is that URL.
    pub title: String,
    /// Search snippet; empty when the query was a URL fetched directly.
    pub description: String,
    /// Address of the result page.
    pub url: String,
    /// Plain-text page content, possibly truncated by `cap_content_len`.
    pub content: String,
}

/// Derives a display domain from `url`.
///
/// The host is lowercased, trailing dots are removed, and a leading `www.`
/// is dropped. Returns `None` when the URL does not parse, has no host, or
/// the host is empty after this normalization.
pub(crate) fn source_domain(url: &str) -> Option<String> {
    let parsed = reqwest::Url::parse(url).ok()?;
    let lowered = parsed.host_str()?.trim_end_matches('.').to_ascii_lowercase();
    let domain = match lowered.strip_prefix("www.") {
        Some(rest) => rest,
        None => lowered.as_str(),
    };
    (!domain.is_empty()).then(|| domain.to_string())
}

pub(crate) fn source_domains<'a>(urls: impl IntoIterator<Item = &'a str>) -> Vec<String> {
    let mut seen = HashSet::new();
    let mut domains = Vec::new();
    for url in urls {
        let Some(domain) = source_domain(url) else {
            continue;
        };
        if seen.insert(domain.clone()) {
            domains.push(domain);
        }
    }
    domains
}

/// Output of the website content extractor tool.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct ExtractResult {
    /// URL as supplied in the tool call.
    pub url: String,
    /// Plain-text page content, or an `ERROR: ...` marker when extraction failed.
    pub content: String,
}

impl SearchResult {
    /// Limits `content` to at most `size` tokens of `tokenizer`; all other fields are kept.
    ///
    /// If the content already fits, it is returned verbatim. A decode of the
    /// full encoding is not guaranteed to reproduce the original text
    /// (normalization, unknown tokens), so re-decoding is done only when
    /// truncation is actually required.
    ///
    /// # Errors
    ///
    /// Returns an error if the tokenizer fails to encode or decode.
    pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
        let encoding = tokenizer
            .encode_fast(self.content.as_str(), false)
            .map_err(anyhow::Error::msg)?;
        let ids = encoding.get_ids();
        if ids.len() <= size {
            return Ok(self);
        }
        let content = tokenizer
            .decode(&ids[..size], false)
            .map_err(anyhow::Error::msg)?;

        Ok(Self { content, ..self })
    }
}

impl ExtractResult {
    /// Limits `content` to at most `size` tokens of `tokenizer`; `url` is kept.
    ///
    /// If the content already fits, it is returned verbatim. A decode of the
    /// full encoding is not guaranteed to reproduce the original text
    /// (normalization, unknown tokens), so re-decoding is done only when
    /// truncation is actually required.
    ///
    /// # Errors
    ///
    /// Returns an error if the tokenizer fails to encode or decode.
    pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
        let encoding = tokenizer
            .encode_fast(self.content.as_str(), false)
            .map_err(anyhow::Error::msg)?;
        let ids = encoding.get_ids();
        if ids.len() <= size {
            return Ok(self);
        }
        let content = tokenizer
            .decode(&ids[..size], false)
            .map_err(anyhow::Error::msg)?;

        Ok(Self {
            url: self.url,
            content,
        })
    }
}

/// Arguments of a web search tool call (schema declared in `get_search_tools`).
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchFunctionParameters {
    /// Search query text supplied by the model.
    pub query: String,
}

/// Arguments of a website content extractor tool call (schema declared in
/// `get_search_tools`).
#[derive(Debug, Serialize, Deserialize)]
pub struct ExtractFunctionParameters {
    /// URL of the page whose content should be extracted.
    pub url: String,
}

/// Builds the two function tools offered to the model for web access.
///
/// Returns the web search tool followed by the website content extractor.
/// Each takes a single required string argument (`query` / `url`) and is
/// marked `strict`. Descriptions in `web_search_options` override the defaults.
/// A configured approximate user location is appended to the search tool's
/// description.
///
/// # Errors
///
/// Returns an error if a parameter schema cannot be converted into a
/// `HashMap<String, Value>`.
pub fn get_search_tools(web_search_options: &WebSearchOptions) -> Result<Vec<Tool>> {
    // Every web tool is a strict function tool; only name, text and schema differ.
    let make_function_tool =
        |name: &str, description: String, parameters: HashMap<String, Value>| Tool {
            tp: ToolType::Function,
            function: Function {
                description: Some(description),
                name: name.to_string(),
                parameters: Some(parameters),
                strict: Some(true),
            },
        };

    let query_schema: HashMap<String, Value> = serde_json::from_value(json!({
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "A query for web searching.",
            },
        },
        "required": ["query"],
    }))?;

    let location_suffix = match &web_search_options.user_location {
        None => String::new(),
        Some(WebSearchUserLocation::Approximate { approximate }) => format!(
            "\nThe user's location is: {}, {}, {}, {}.",
            approximate.city, approximate.region, approximate.country, approximate.timezone
        ),
    };
    let search_description = format!(
        "{}{location_suffix}",
        web_search_options
            .search_description
            .as_deref()
            .unwrap_or(SEARCH_DESCRIPTION)
    );

    let url_schema: HashMap<String, Value> = serde_json::from_value(json!({
        "type": "object",
        "properties": {
            "url": {
                "type": "string",
                "description": "A URL to extract the content of the website from.",
            },
        },
        "required": ["url"],
    }))?;

    let extract_description = web_search_options
        .extract_description
        .as_deref()
        .unwrap_or(EXTRACT_DESCRIPTION)
        .to_string();

    Ok(vec![
        make_function_tool(SEARCH_TOOL_NAME, search_description, query_schema),
        make_function_tool(EXTRACT_TOOL_NAME, extract_description, url_schema),
    ])
}

/// Creates the HTTP client used by the web tools, with request and connect
/// timeouts from `SEARCH_REQUEST_TIMEOUT` / `SEARCH_CONNECT_TIMEOUT`.
fn build_client() -> Result<reqwest::Client> {
    let client = reqwest::Client::builder()
        .timeout(SEARCH_REQUEST_TIMEOUT)
        .connect_timeout(SEARCH_CONNECT_TIMEOUT)
        .build()?;
    Ok(client)
}

/// Renders `html` as plain text wrapped at 80 columns.
///
/// Uses html2text's plain decorator. Returns `None` if rendering fails.
fn html_to_text(html: &str) -> Option<String> {
    let renderer = config::with_decorator(PlainDecorator::new()).do_decorate();
    match renderer.string_from_read(html.as_bytes(), 80) {
        Ok(text) => Some(text),
        Err(_) => None,
    }
}

pub async fn run_search_tool(params: &SearchFunctionParameters) -> Result<Vec<SearchResult>> {
    let client = build_client()?;
    let user_agent = format!("hanzo/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");

    // If the model passed a URL instead of a search query, fetch it directly
    // rather than searching DuckDuckGo (which returns 0 results for raw URLs).
    let trimmed = params.query.trim().trim_matches('"');
    if trimmed.starts_with("http://") || trimmed.starts_with("https://") {
        let response = client
            .get(trimmed)
            .header("User-Agent", &user_agent)
            .send()
            .await?;
        let html = response.text().await?;
        let content = html_to_text(&html).unwrap_or_default();
        return Ok(vec![SearchResult {
            title: trimmed.to_string(),
            description: String::new(),
            url: trimmed.to_string(),
            content,
        }]);
    }

    let encoded_query = urlencoding::encode(&params.query);
    let url = format!("https://html.duckduckgo.com/html/?q={encoded_query}");

    let t0 = std::time::Instant::now();
    let response = client
        .get(&url)
        .header("User-Agent", &user_agent)
        .send()
        .await?;

    if !response.status().is_success() {
        anyhow::bail!("Failed to fetch search results: {}", response.status())
    }

    let html = response.text().await?;
    tracing::debug!(
        "Search: DuckDuckGo query completed in {:.2}s",
        t0.elapsed().as_secs_f32()
    );

    // Parse DDG HTML in a block so `document` (non-Send) is dropped before the
    // async content fetches below.
    let partials: Vec<(String, String, String)> = {
        let document = Html::parse_document(&html);

        let result_selector = Selector::parse(".result").unwrap();
        let title_selector = Selector::parse(".result__title").unwrap();
        let snippet_selector = Selector::parse(".result__snippet").unwrap();
        let url_selector = Selector::parse(".result__url").unwrap();

        document
            .select(&result_selector)
            .filter_map(|element| {
                let title = element
                    .select(&title_selector)
                    .next()
                    .map(|e| e.text().collect::<String>().trim().to_string())
                    .unwrap_or_default();
                let description = element
                    .select(&snippet_selector)
                    .next()
                    .map(|e| e.text().collect::<String>().trim().to_string())
                    .unwrap_or_default();
                let mut url = element
                    .select(&url_selector)
                    .next()
                    .map(|e| e.text().collect::<String>().trim().to_string())
                    .unwrap_or_default();
                if title.is_empty() || description.is_empty() || url.is_empty() {
                    return None;
                }
                if !url.starts_with("http") {
                    url = format!("https://{url}");
                }
                Some((title, description, url))
            })
            .take(MAX_SEARCH_RESULTS)
            .collect()
    };
    tracing::debug!("Search: fetching content for {} pages", partials.len());

    // Fetch all pages concurrently with async I/O (not Rayon thread pool rounds).
    let t1 = std::time::Instant::now();
    let fetches = partials.into_iter().map(|(title, description, url)| {
        let client = client.clone();
        let user_agent = user_agent.clone();
        async move {
            let resp = client
                .get(&url)
                .header("User-Agent", &user_agent)
                .send()
                .await
                .ok()?;
            let html = resp.text().await.ok()?;
            let content = html_to_text(&html)?;
            Some(SearchResult {
                title,
                description,
                url,
                content,
            })
        }
    });
    let results: Vec<SearchResult> = futures::future::join_all(fetches)
        .await
        .into_iter()
        .flatten()
        .collect();
    tracing::debug!(
        "Search: fetched {} pages in {:.2}s",
        results.len(),
        t1.elapsed().as_secs_f32()
    );

    Ok(results)
}

/// Fetches `params.url` and returns the page rendered as plain text.
///
/// Surrounding whitespace and double quotes are stripped from the URL, the same
/// way `run_search_tool` trims its query. `https://` is assumed when the URL has
/// no scheme. Fetch problems are reported in-band rather than as an `Err`: if
/// the request fails, the server answers with a non-success status, the body
/// cannot be read, or the HTML cannot be rendered, `content` is set to
/// `"ERROR: failed to extract content"`.
///
/// # Errors
///
/// Returns an error only if the HTTP client cannot be built.
pub async fn run_extract_tool(params: &ExtractFunctionParameters) -> Result<ExtractResult> {
    let client = build_client()?;
    let user_agent = format!("hanzo/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");

    // Models often quote URLs or omit the scheme; without a scheme the request
    // could not be built at all.
    let trimmed = params.url.trim().trim_matches('"');
    let target = if trimmed.contains("://") {
        trimmed.to_string()
    } else {
        format!("https://{trimmed}")
    };

    let content = match client
        .get(&target)
        .header("User-Agent", &user_agent)
        .send()
        .await
    {
        // An error page (404, 403, ...) is not the site's content, so only
        // successful responses are rendered.
        Ok(response) if response.status().is_success() => response
            .text()
            .await
            .ok()
            .and_then(|html| html_to_text(&html)),
        _ => None,
    };
    Ok(ExtractResult {
        url: params.url.clone(),
        content: content.unwrap_or_else(|| "ERROR: failed to extract content".to_string()),
    })
}