use std::sync::Arc;
use async_trait::async_trait;
use serde_json::{Value, json};
use tracing::{debug, info, warn};
use crate::domain::error::{ProviderError, Result, StygianError};
use crate::ports::{AIProvider, ScrapingService, ServiceInput, ServiceOutput};
/// Tunable knobs for [`LlmExtractionService`].
#[derive(Debug, Clone)]
pub struct ExtractionConfig {
    // Maximum number of bytes of content forwarded to a provider;
    // longer content is truncated before extraction.
    pub max_content_chars: usize,
    // When true, provider output must be a JSON object or array;
    // anything else is treated as a failure and the next provider is tried.
    pub validate_output: bool,
}
impl Default for ExtractionConfig {
    /// Defaults: 64 kB content cap, output validation enabled.
    fn default() -> Self {
        Self {
            max_content_chars: 64_000,
            validate_output: true,
        }
    }
}
/// Scraping service that extracts structured data from page content by
/// delegating to an ordered fallback chain of AI providers.
pub struct LlmExtractionService {
    // Providers are tried in order until one returns a usable result.
    providers: Vec<Arc<dyn AIProvider>>,
    config: ExtractionConfig,
}
impl LlmExtractionService {
    /// Builds a service over an ordered provider fallback chain.
    pub fn new(providers: Vec<Arc<dyn AIProvider>>, config: ExtractionConfig) -> Self {
        Self { providers, config }
    }

    /// Returns the text to extract from: the `content` string param when
    /// present, otherwise the raw `url` field is used as the content.
    fn resolve_content(input: &ServiceInput) -> &str {
        input
            .params
            .get("content")
            .and_then(Value::as_str)
            .unwrap_or(&input.url)
    }

    /// Caps `content` at `max_content_chars` bytes.
    ///
    /// The cut point is backed off to the nearest UTF-8 char boundary so
    /// the slice never panics on multi-byte content. (The previous
    /// `&content[..limit]` panicked whenever `limit` landed inside a
    /// multi-byte code point.)
    fn truncate_content<'a>(&self, content: &'a str) -> &'a str {
        let limit = self.config.max_content_chars;
        if content.len() <= limit {
            return content;
        }
        // Walk back to a valid boundary before slicing; for ASCII input
        // this is a no-op and the behavior is unchanged.
        let mut end = limit;
        while end > 0 && !content.is_char_boundary(end) {
            end -= 1;
        }
        warn!(
            limit = self.config.max_content_chars,
            actual = content.len(),
            "Content truncated for LLM extraction"
        );
        &content[..end]
    }

    /// Pulls the required `schema` value out of the input params.
    ///
    /// # Errors
    /// Returns a provider `ApiError` when `schema` is absent.
    fn resolve_schema(input: &ServiceInput) -> Result<Value> {
        input.params.get("schema").cloned().ok_or_else(|| {
            StygianError::Provider(ProviderError::ApiError(
                "LlmExtractionService requires 'schema' in ServiceInput.params".to_string(),
            ))
        })
    }

    /// Accepts only JSON objects or arrays; scalar output is reported as a
    /// provider error so the fallback chain can continue.
    fn validate_output(output: &Value) -> Result<()> {
        if output.is_object() || output.is_array() {
            Ok(())
        } else {
            Err(StygianError::Provider(ProviderError::ApiError(format!(
                "Provider returned non-object output: {output}"
            ))))
        }
    }
}
#[async_trait]
impl ScrapingService for LlmExtractionService {
    /// Walks the provider chain, returning the first acceptable result or
    /// the most recent error once every provider has been exhausted.
    async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
        if self.providers.is_empty() {
            return Err(StygianError::Provider(ProviderError::ApiError(
                "No AI providers configured in LlmExtractionService".to_string(),
            )));
        }

        let schema = Self::resolve_schema(&input)?;
        let content = self
            .truncate_content(Self::resolve_content(&input))
            .to_string();
        let timer = std::time::Instant::now();
        let mut last_err: Option<StygianError> = None;

        for provider in &self.providers {
            debug!(provider = provider.name(), "Attempting LLM extraction");
            let attempt = provider.extract(content.clone(), schema.clone()).await;
            match attempt {
                Err(e) => {
                    warn!(
                        provider = provider.name(),
                        error = %e,
                        "Provider failed, trying next in chain"
                    );
                    last_err = Some(e);
                }
                Ok(value) => {
                    // Optionally reject scalar outputs and keep walking the chain.
                    if self.config.validate_output {
                        if let Err(e) = Self::validate_output(&value) {
                            warn!(
                                provider = provider.name(),
                                error = %e,
                                "Provider returned invalid output, trying next"
                            );
                            last_err = Some(e);
                            continue;
                        }
                    }
                    let elapsed = timer.elapsed();
                    info!(
                        provider = provider.name(),
                        elapsed_ms = elapsed.as_millis(),
                        "LLM extraction succeeded"
                    );
                    return Ok(ServiceOutput {
                        data: value.to_string(),
                        metadata: json!({
                            "provider": provider.name(),
                            "elapsed_ms": elapsed.as_millis(),
                            "content_chars": content.len(),
                        }),
                    });
                }
            }
        }

        // Every provider failed; surface the last error, or a generic one
        // if somehow none was recorded.
        Err(last_err.unwrap_or_else(|| {
            StygianError::Provider(ProviderError::ApiError(
                "All AI providers in fallback chain failed".to_string(),
            ))
        }))
    }

    fn name(&self) -> &'static str {
        "llm-extraction"
    }
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::needless_pass_by_value
)]
mod tests {
    use super::*;
    use crate::ports::ProviderCapabilities;
    use futures::stream::{self, BoxStream};
    use serde_json::json;

    /// Mock provider whose `extract` always succeeds with a fixed `response`.
    struct AlwaysSucceed {
        response: Value,
    }

    #[async_trait]
    impl AIProvider for AlwaysSucceed {
        async fn extract(&self, _content: String, _schema: Value) -> Result<Value> {
            Ok(self.response.clone())
        }
        async fn stream_extract(
            &self,
            _content: String,
            _schema: Value,
        ) -> Result<BoxStream<'static, Result<Value>>> {
            Ok(Box::pin(stream::once(async { Ok(json!({})) })))
        }
        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }
        fn name(&self) -> &'static str {
            "mock-succeed"
        }
    }

    /// Mock provider whose every call fails with an `ApiError`.
    struct AlwaysFail;

    #[async_trait]
    impl AIProvider for AlwaysFail {
        async fn extract(&self, _content: String, _schema: Value) -> Result<Value> {
            Err(StygianError::Provider(ProviderError::ApiError(
                "mock failure".to_string(),
            )))
        }
        async fn stream_extract(
            &self,
            _content: String,
            _schema: Value,
        ) -> Result<BoxStream<'static, Result<Value>>> {
            Err(StygianError::Provider(ProviderError::ApiError(
                "mock failure".to_string(),
            )))
        }
        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }
        fn name(&self) -> &'static str {
            "mock-fail"
        }
    }

    // Helper: input whose `url` doubles as content (no `content` param),
    // with the given schema in params.
    fn make_input(schema: Value) -> ServiceInput {
        ServiceInput {
            url: "<h1>Hello</h1>".to_string(),
            params: json!({ "schema": schema }),
        }
    }

    #[tokio::test]
    async fn test_service_name() {
        let svc = LlmExtractionService::new(vec![], ExtractionConfig::default());
        assert_eq!(svc.name(), "llm-extraction");
    }

    // An empty provider chain must fail fast with a configuration error.
    #[tokio::test]
    async fn test_no_providers_returns_error() {
        let svc = LlmExtractionService::new(vec![], ExtractionConfig::default());
        let err = svc.execute(make_input(json!({}))).await.unwrap_err();
        assert!(err.to_string().contains("No AI providers"));
    }

    // `schema` is mandatory; its absence errors before any provider runs.
    #[tokio::test]
    async fn test_missing_schema_returns_error() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysSucceed {
            response: json!({}),
        })];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let input = ServiceInput {
            url: "some content".to_string(),
            params: json!({}), };
        let err = svc.execute(input).await.unwrap_err();
        assert!(err.to_string().contains("schema"));
    }

    // Happy path: output data is the provider's JSON, metadata records
    // which provider produced it.
    #[tokio::test]
    async fn test_single_succeeding_provider() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysSucceed {
            response: json!({"title": "Hello"}),
        })];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let output = svc.execute(make_input(json!({}))).await.unwrap();
        assert_eq!(
            output.metadata["provider"].as_str().unwrap(),
            "mock-succeed"
        );
        let data: Value = serde_json::from_str(&output.data).unwrap();
        assert_eq!(data["title"].as_str().unwrap(), "Hello");
    }

    // First provider fails, second succeeds — fallback must engage.
    #[tokio::test]
    async fn test_fallback_to_second_provider() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![
            Arc::new(AlwaysFail),
            Arc::new(AlwaysSucceed {
                response: json!({"score": 42}),
            }),
        ];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let output = svc.execute(make_input(json!({}))).await.unwrap();
        assert_eq!(
            output.metadata["provider"].as_str().unwrap(),
            "mock-succeed"
        );
    }

    // When every provider fails, the last provider error is surfaced.
    #[tokio::test]
    async fn test_all_providers_fail() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysFail), Arc::new(AlwaysFail)];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let err = svc.execute(make_input(json!({}))).await.unwrap_err();
        assert!(err.to_string().contains("mock failure"));
    }

    // A `content` param takes precedence over `url`; metadata counts its
    // bytes ("actual content here" is 19 bytes).
    #[tokio::test]
    async fn test_content_from_params_overrides_url() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysSucceed {
            response: json!({"ok": true}),
        })];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let input = ServiceInput {
            url: "should-not-be-used".to_string(),
            params: json!({
                "schema": {"type": "object"},
                "content": "actual content here"
            }),
        };
        let output = svc.execute(input).await.unwrap();
        assert_eq!(output.metadata["content_chars"].as_u64().unwrap(), 19);
    }

    // Content at or under the limit passes through untouched.
    #[test]
    fn test_truncate_content_short() {
        let svc = LlmExtractionService::new(vec![], ExtractionConfig::default());
        let s = "hello";
        assert_eq!(svc.truncate_content(s), s);
    }

    // Content over the limit is cut to `max_content_chars` bytes.
    #[test]
    fn test_truncate_content_long() {
        let svc = LlmExtractionService::new(
            vec![],
            ExtractionConfig {
                max_content_chars: 5,
                ..Default::default()
            },
        );
        assert_eq!(svc.truncate_content("hello world"), "hello");
    }

    #[test]
    fn test_validate_output_object_ok() {
        assert!(LlmExtractionService::validate_output(&json!({"k": "v"})).is_ok());
    }

    #[test]
    fn test_validate_output_array_ok() {
        assert!(LlmExtractionService::validate_output(&json!([1, 2, 3])).is_ok());
    }

    // Scalars (strings, numbers, bools, null) are rejected.
    #[test]
    fn test_validate_output_scalar_err() {
        assert!(LlmExtractionService::validate_output(&json!("just a string")).is_err());
    }
}