use crate::{PreviewError, Fetcher};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
#[cfg(feature = "cache")]
use crate::Cache;
/// Configuration for [`LLMExtractor`]: what format the page content is
/// converted to, whether boilerplate HTML is stripped, how much content is
/// sent to the provider, and provider-specific model parameters.
#[derive(Clone, Debug)]
pub struct LLMExtractorConfig {
/// Format the fetched page is converted to before prompting the provider.
pub format: ContentFormat,
/// Strip boilerplate tags (scripts, nav, footers, ...) before conversion.
pub clean_html: bool,
/// Maximum content length included in the prompt (applied in characters
/// by the preprocessor).
pub max_content_length: usize,
/// Extra provider-specific parameters, passed through as raw JSON values.
pub model_params: HashMap<String, Value>,
}
impl Default for LLMExtractorConfig {
fn default() -> Self {
Self {
format: ContentFormat::Html,
clean_html: true,
max_content_length: 50_000, model_params: HashMap::new(),
}
}
}
/// Format the page content is converted to before being sent to the LLM.
#[derive(Clone, Debug, PartialEq)]
pub enum ContentFormat {
/// Raw (optionally cleaned) HTML.
Html,
/// HTML converted to a minimal Markdown rendition (title, headings, paragraphs).
Markdown,
/// Plain text with whitespace collapsed.
Text,
/// Image input; rejected by the preprocessor as not yet implemented.
Image,
}
/// Output of [`ContentPreprocessor::preprocess`]: the converted page content
/// together with the format it is in and auxiliary metadata.
#[derive(Clone, Debug)]
pub struct ProcessedContent {
/// Converted (and possibly truncated) page content.
pub content: String,
/// Format that `content` is in.
pub format: ContentFormat,
/// Auxiliary key/value metadata (the preprocessor currently leaves this empty).
pub metadata: HashMap<String, String>,
}
/// Result of a structured extraction: the decoded value plus provenance.
#[derive(Clone, Debug)]
pub struct ExtractionResult<T> {
/// The extracted, schema-conforming data.
pub data: T,
/// Name of the provider that produced it (`"cached"` on a cache hit).
pub model: String,
/// Token accounting, when available (currently always `None`).
pub usage: Option<TokenUsage>,
}
/// Token counts reported by an LLM provider for a single request.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TokenUsage {
/// Tokens consumed by the prompt.
pub prompt_tokens: u32,
/// Tokens produced in the completion.
pub completion_tokens: u32,
/// Sum of prompt and completion tokens.
pub total_tokens: u32,
}
/// A text-generation backend capable of schema-guided JSON extraction.
#[async_trait]
pub trait LLMProvider: Send + Sync {
/// Human-readable provider/model name, recorded in [`ExtractionResult`]s.
fn name(&self) -> &str;
/// Generates a JSON value for `prompt` that should conform to `schema`.
async fn generate(
&self,
prompt: String,
schema: Value,
config: &LLMExtractorConfig,
) -> Result<Value, PreviewError>;
/// Streams partial results. The default implementation reports streaming
/// as unsupported; providers may override it.
async fn stream(
&self,
_prompt: String,
_schema: Value,
_config: &LLMExtractorConfig,
) -> Result<Box<dyn futures::Stream<Item = Result<Value, PreviewError>> + Send + Unpin>, PreviewError> {
Err(PreviewError::UnsupportedOperation("Streaming not supported by this provider".into()))
}
}
/// High-level extraction pipeline: fetch a URL, preprocess the HTML, prompt
/// an [`LLMProvider`] with a JSON schema, and deserialize the result.
pub struct LLMExtractor {
/// Backend used to generate the structured JSON.
provider: Arc<dyn LLMProvider>,
/// Converts fetched HTML into the configured [`ContentFormat`].
preprocessor: ContentPreprocessor,
/// Extraction settings.
config: LLMExtractorConfig,
/// Optional result cache, keyed by URL plus target type name.
#[cfg(feature = "cache")]
cache: Option<Arc<Cache>>,
}
impl LLMExtractor {
    /// Creates an extractor with the default configuration.
    pub fn new(provider: Arc<dyn LLMProvider>) -> Self {
        // Delegate so the construction logic lives in exactly one place.
        Self::with_config(provider, LLMExtractorConfig::default())
    }

    /// Creates an extractor with an explicit configuration.
    pub fn with_config(provider: Arc<dyn LLMProvider>, config: LLMExtractorConfig) -> Self {
        Self {
            provider,
            preprocessor: ContentPreprocessor::new(),
            config,
            #[cfg(feature = "cache")]
            cache: None,
        }
    }

    /// Attaches a cache; subsequent [`extract`](Self::extract) calls will
    /// read from and write to it.
    #[cfg(feature = "cache")]
    pub fn with_cache(mut self, cache: Arc<Cache>) -> Self {
        self.cache = Some(cache);
        self
    }

    /// Cache key for a given URL and target type. Centralized so the cache
    /// read and write paths cannot drift apart.
    #[cfg(feature = "cache")]
    fn cache_key<T>(url: &str) -> String {
        format!("llm:{}:{}", url, std::any::type_name::<T>())
    }

    /// Fetches `url`, preprocesses the HTML, and asks the provider to
    /// extract a `T` matching its JSON schema.
    ///
    /// # Errors
    /// Fails if fetching fails, the response is not HTML, preprocessing
    /// fails, the provider errors, or the provider's JSON does not
    /// deserialize into `T`.
    pub async fn extract<T>(&self, url: &str, fetcher: &Fetcher) -> Result<ExtractionResult<T>, PreviewError>
    where
        T: serde::de::DeserializeOwned + serde::Serialize + schemars::JsonSchema,
    {
        // Fast path: return a previously cached extraction if one parses.
        #[cfg(feature = "cache")]
        if let Some(cache) = &self.cache {
            if let Some(cached) = cache.get(&Self::cache_key::<T>(url)).await {
                // The serialized payload is stored in the preview's description.
                if let Ok(result) = serde_json::from_str::<T>(&cached.description.unwrap_or_default()) {
                    return Ok(ExtractionResult {
                        data: result,
                        model: "cached".to_string(),
                        usage: None,
                    });
                }
            }
        }
        let fetch_result = fetcher.fetch(url).await?;
        let html = match fetch_result {
            crate::FetchResult::Html(h) => h,
            _ => return Err(PreviewError::InvalidContentType("Expected HTML".to_string())),
        };
        let processed = self.preprocessor.preprocess(&html, &self.config).await?;
        let schema = schemars::schema_for!(T);
        let schema_json = serde_json::to_value(&schema)?;
        let prompt = self.build_prompt(&processed, &schema_json)?;
        let result = self.provider.generate(prompt, schema_json, &self.config).await?;
        let extracted: T = serde_json::from_value(result)?;
        // Best effort: store the serialized result so the next call for the
        // same URL/type can skip the fetch and the provider round-trip.
        #[cfg(feature = "cache")]
        if let Some(cache) = &self.cache {
            let preview = crate::Preview {
                url: url.to_string(),
                title: Some(format!("LLM Extraction: {}", std::any::type_name::<T>())),
                description: Some(serde_json::to_string(&extracted)?),
                image_url: None,
                site_name: None,
                favicon: None,
            };
            cache.set(Self::cache_key::<T>(url), preview).await;
        }
        Ok(ExtractionResult {
            data: extracted,
            model: self.provider.name().to_string(),
            // NOTE(review): `generate` returns only a Value, so token usage
            // is not surfaced here — always `None` for now.
            usage: None,
        })
    }

    /// Assembles the extraction prompt: schema (pretty-printed JSON) plus
    /// the preprocessed content, with a hint about the content's format.
    fn build_prompt(&self, content: &ProcessedContent, schema: &Value) -> Result<String, PreviewError> {
        let schema_str = serde_json::to_string_pretty(schema)?;
        let format_hint = match content.format {
            ContentFormat::Html => "HTML",
            ContentFormat::Markdown => "Markdown",
            ContentFormat::Text => "plain text",
            ContentFormat::Image => "image",
        };
        Ok(format!(
            "Extract structured data from the following {} content according to this schema:\n\n\
             Schema:\n```json\n{}\n```\n\n\
             Content:\n{}\n\n\
             Extract the data and return it as a valid JSON object matching the schema.",
            format_hint,
            schema_str,
            content.content
        ))
    }
}
/// Turns raw fetched HTML into [`ProcessedContent`] in the configured
/// [`ContentFormat`], optionally cleaning boilerplate and truncating.
pub struct ContentPreprocessor {
/// Strips boilerplate tags when `clean_html` is enabled.
html_cleaner: HtmlCleaner,
}
impl ContentPreprocessor {
    /// Creates a preprocessor with a fresh [`HtmlCleaner`].
    pub fn new() -> Self {
        Self {
            html_cleaner: HtmlCleaner::new(),
        }
    }

    /// Converts `html` into the format requested by `config`, cleaning
    /// boilerplate first when `config.clean_html` is set and truncating the
    /// result to `config.max_content_length` characters.
    ///
    /// # Errors
    /// Returns `UnsupportedOperation` for [`ContentFormat::Image`], or any
    /// error produced by cleaning/conversion.
    pub async fn preprocess(&self, html: &str, config: &LLMExtractorConfig) -> Result<ProcessedContent, PreviewError> {
        let processed_html = if config.clean_html {
            self.html_cleaner.clean(html)?
        } else {
            html.to_string()
        };
        let content = match config.format {
            ContentFormat::Html => processed_html,
            ContentFormat::Markdown => self.convert_to_markdown(&processed_html)?,
            ContentFormat::Text => self.extract_text(&processed_html)?,
            ContentFormat::Image => {
                return Err(PreviewError::UnsupportedOperation("Image format not yet implemented".into()));
            }
        };
        // The limit is applied in characters; the byte-length comparison is a
        // cheap pre-filter (byte length >= char count, so no over-limit
        // content ever slips through).
        let content = if content.len() > config.max_content_length {
            content.chars().take(config.max_content_length).collect()
        } else {
            content
        };
        Ok(ProcessedContent {
            content,
            format: config.format.clone(),
            metadata: HashMap::new(),
        })
    }

    /// Produces a minimal Markdown rendition: the `<title>` as an H1, then
    /// headings in document order, then paragraphs.
    fn convert_to_markdown(&self, html: &str) -> Result<String, PreviewError> {
        use scraper::{Html, Selector};
        let document = Html::parse_document(html);
        let mut markdown = String::new();
        if let Ok(title_selector) = Selector::parse("title") {
            if let Some(title) = document.select(&title_selector).next() {
                markdown.push_str(&format!("# {}\n\n", title.text().collect::<String>()));
            }
        }
        // One grouped selector yields matches in tree (document) order, so
        // headings keep their original sequence instead of being grouped by
        // level (all h1s, then all h2s, ...).
        if let Ok(selector) = Selector::parse("h1, h2, h3, h4, h5, h6") {
            for element in document.select(&selector) {
                // Tag names here are exactly "h1".."h6"; the digit is the level.
                let level: usize = element.value().name()[1..].parse().unwrap_or(1);
                markdown.push_str(&format!("{} {}\n\n", "#".repeat(level), element.text().collect::<String>()));
            }
        }
        // NOTE(review): paragraphs are still emitted after all headings, so
        // heading/paragraph interleaving is not preserved.
        if let Ok(p_selector) = Selector::parse("p") {
            for element in document.select(&p_selector) {
                markdown.push_str(&format!("{}\n\n", element.text().collect::<String>()));
            }
        }
        Ok(markdown)
    }

    /// Flattens the document to plain text with runs of whitespace collapsed
    /// to single spaces.
    fn extract_text(&self, html: &str) -> Result<String, PreviewError> {
        use scraper::Html;
        let document = Html::parse_document(html);
        let text = document.root_element().text().collect::<Vec<_>>().join(" ");
        Ok(text.split_whitespace().collect::<Vec<_>>().join(" "))
    }
}

/// `Default` mirrors `new` (satisfies clippy's `new_without_default`).
impl Default for ContentPreprocessor {
    fn default() -> Self {
        Self::new()
    }
}
/// Strips boilerplate and non-content tags from an HTML document and
/// returns lightly formatted text.
struct HtmlCleaner;

impl HtmlCleaner {
    fn new() -> Self {
        Self
    }

    /// Tags whose entire subtree is dropped during cleaning.
    const REMOVE_TAGS: &'static [&'static str] = &[
        "script", "style", "noscript", "iframe", "object", "embed",
        "form", "input", "button", "select", "textarea", "option",
        "nav", "header", "footer", "aside", "menu", "menuitem",
        "audio", "video", "source", "track", "canvas", "svg",
        "meta", "link", "base", "title",
    ];

    /// Extracts readable content from `html`, preferring the `<body>` and
    /// falling back to whitespace-normalized raw text if structural
    /// extraction yields nothing.
    fn clean(&self, html: &str) -> Result<String, PreviewError> {
        use scraper::{Html, Selector};
        let document = Html::parse_document(html);
        let mut cleaned_html = String::new();
        if let Ok(body_selector) = Selector::parse("body") {
            cleaned_html = match document.select(&body_selector).next() {
                Some(body) => self.extract_clean_content(body, Self::REMOVE_TAGS),
                None => self.extract_clean_content(document.root_element(), Self::REMOVE_TAGS),
            };
        }
        // Fallback: keep *something* rather than returning an empty page.
        if cleaned_html.trim().is_empty() {
            let text = document.root_element().text().collect::<Vec<_>>().join(" ");
            cleaned_html = text.split_whitespace().collect::<Vec<_>>().join(" ");
        }
        Ok(cleaned_html)
    }

    /// Recursively collects text from `element`, skipping subtrees whose tag
    /// is in `remove_selectors` and adding light formatting for headings,
    /// block elements, list items, and rules.
    fn extract_clean_content(
        &self,
        element: scraper::ElementRef,
        remove_selectors: &[&str],
    ) -> String {
        use scraper::{ElementRef, Node};
        let mut content = String::new();
        if remove_selectors.contains(&element.value().name()) {
            return content;
        }
        for child in element.children() {
            match child.value() {
                Node::Text(text) => {
                    let text_content = text.text.trim();
                    if !text_content.is_empty() {
                        content.push_str(text_content);
                        content.push(' ');
                    }
                }
                Node::Element(_) => {
                    if let Some(child_element) = ElementRef::wrap(child) {
                        match child_element.value().name() {
                            // Void elements never have text content, so they
                            // must be handled *before* the emptiness check —
                            // otherwise these arms are unreachable.
                            "br" => content.push('\n'),
                            "hr" => content.push_str("\n---\n"),
                            name => {
                                let child_content =
                                    self.extract_clean_content(child_element, remove_selectors);
                                if !child_content.trim().is_empty() {
                                    match name {
                                        "h1" | "h2" | "h3" | "h4" | "h5" | "h6" => {
                                            content.push_str(&format!("\n\n{}\n", child_content.trim()));
                                        }
                                        "p" | "div" | "section" | "article" => {
                                            content.push_str(&format!("\n{}\n", child_content.trim()));
                                        }
                                        "li" => {
                                            content.push_str(&format!("• {}\n", child_content.trim()));
                                        }
                                        _ => {
                                            content.push_str(&child_content);
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                // Comments, doctypes, processing instructions: ignore.
                _ => {}
            }
        }
        content
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ContentFormat` variants compare equal to themselves and unequal
    /// across variants.
    #[test]
    fn test_content_format() {
        let fmt = ContentFormat::Html;
        assert!(matches!(fmt, ContentFormat::Html));
        assert!(fmt != ContentFormat::Markdown);
    }

    /// The default configuration cleans HTML and caps content at 50k chars.
    #[test]
    fn test_default_config() {
        let cfg = LLMExtractorConfig::default();
        assert!(matches!(cfg.format, ContentFormat::Html));
        assert!(cfg.clean_html);
        assert_eq!(cfg.max_content_length, 50_000);
    }
}