mod prompts;
mod types;
pub use prompts::build_extraction_prompt;
pub use types::{
EntityType, ExtractedEntity, ExtractedRelation, ExtractionOptions, ExtractionResult,
RelationType,
};
use serde::Deserialize;
use crate::constants::{
EXTRACTION_CONFIDENCE_DEFAULT, EXTRACTION_CONFIDENCE_MAX, EXTRACTION_CONFIDENCE_MIN,
EXTRACTION_ENTITIES_COUNT_MAX, EXTRACTION_RELATIONS_COUNT_MAX, EXTRACTION_TEXT_BYTES_MAX,
};
use crate::llm::{CompletionRequest, LLMProvider, ProviderError};
/// Input-validation failures reported by [`EntityExtractor::extract`].
///
/// Provider (LLM) failures are not represented here: `extract` recovers from
/// them by substituting a fallback entity instead of returning an error.
#[derive(Debug, Clone, thiserror::Error)]
pub enum ExtractionError {
    /// The input text was the empty string.
    #[error("Text is empty")]
    EmptyText,

    /// The input text was longer than the accepted byte limit.
    #[error("Text too long: {len} bytes (max {max})")]
    TextTooLong {
        /// Length of the rejected text, in bytes.
        len: usize,
        /// Largest accepted length, in bytes.
        max: usize,
    },

    /// The requested minimum confidence was outside the accepted range.
    #[error("Invalid confidence: {value} (must be {min}-{max})")]
    InvalidConfidence {
        /// The rejected value.
        value: f64,
        /// Inclusive lower bound of the accepted range.
        min: f64,
        /// Inclusive upper bound of the accepted range.
        max: f64,
    },
}
/// Top-level shape deserialized from the LLM's JSON reply.
///
/// Both lists carry `#[serde(default)]`, so a reply that omits either key
/// still deserializes, with an empty list in its place.
#[derive(Debug, Deserialize, Default)]
struct LLMExtractionResponse {
    /// Entity candidates as emitted by the model, not yet validated.
    #[serde(default)]
    entities: Vec<RawEntity>,

    /// Relation candidates as emitted by the model, not yet validated.
    #[serde(default)]
    relations: Vec<RawRelation>,
}
/// One unvalidated entity candidate from the LLM reply.
///
/// Every field is optional at this stage; `parse_entities` decides which
/// candidates are kept and which defaults apply.
#[derive(Debug, Deserialize)]
struct RawEntity {
    /// Entity name; candidates without a usable name are dropped later.
    name: Option<String>,

    /// Entity type label, read from the JSON key `type`.
    #[serde(rename = "type")]
    entity_type: Option<String>,

    /// Free-form description of the entity.
    content: Option<String>,

    /// Model-reported confidence, not yet clamped.
    confidence: Option<f64>,
}
/// One unvalidated relation candidate from the LLM reply.
///
/// Every field is optional at this stage; `parse_relations` drops candidates
/// that lack a usable source or target.
#[derive(Debug, Deserialize)]
struct RawRelation {
    /// Name of the entity the relation starts from.
    source: Option<String>,

    /// Name of the entity the relation points to.
    target: Option<String>,

    /// Relation type label, read from the JSON key `type`.
    #[serde(rename = "type")]
    relation_type: Option<String>,

    /// Model-reported confidence, not yet clamped.
    confidence: Option<f64>,
}
/// Extracts entities and relations from free text by prompting an
/// [`LLMProvider`] and parsing its JSON reply.
#[derive(Debug)]
pub struct EntityExtractor<P: LLMProvider> {
    /// Provider that receives the completion requests.
    provider: P,
}
impl<P: LLMProvider> EntityExtractor<P> {
#[must_use]
pub fn new(provider: P) -> Self {
Self { provider }
}
/// Extracts entities and relations from `text`.
///
/// The text and options are validated first, then a prompt is built (passing
/// `options.existing_entities` along when that list is non-empty) and sent to
/// the provider. If the provider call fails, a single fallback entity derived
/// from `text` is used and no relations are returned.
///
/// When `options.min_confidence` is greater than zero, entities and relations
/// below that confidence are discarded; this filter also applies to the
/// fallback entity. Finally the lists are capped at
/// `EXTRACTION_ENTITIES_COUNT_MAX` and `EXTRACTION_RELATIONS_COUNT_MAX`.
///
/// # Errors
///
/// - [`ExtractionError::EmptyText`] if `text` is empty.
/// - [`ExtractionError::TextTooLong`] if `text` exceeds
///   `EXTRACTION_TEXT_BYTES_MAX` bytes.
/// - [`ExtractionError::InvalidConfidence`] if `options.min_confidence` is
///   outside `EXTRACTION_CONFIDENCE_MIN..=EXTRACTION_CONFIDENCE_MAX`.
pub async fn extract(
    &self,
    text: &str,
    options: ExtractionOptions,
) -> Result<ExtractionResult, ExtractionError> {
    // Validate the text before doing any work.
    if text.is_empty() {
        return Err(ExtractionError::EmptyText);
    }
    let text_bytes = text.len();
    if text_bytes > EXTRACTION_TEXT_BYTES_MAX {
        return Err(ExtractionError::TextTooLong {
            len: text_bytes,
            max: EXTRACTION_TEXT_BYTES_MAX,
        });
    }

    // Validate the confidence threshold.
    let threshold = options.min_confidence;
    let accepted_range = EXTRACTION_CONFIDENCE_MIN..=EXTRACTION_CONFIDENCE_MAX;
    if !accepted_range.contains(&threshold) {
        return Err(ExtractionError::InvalidConfidence {
            value: threshold,
            min: EXTRACTION_CONFIDENCE_MIN,
            max: EXTRACTION_CONFIDENCE_MAX,
        });
    }

    // Only hand known entities to the prompt builder when there are any.
    let known_entities =
        Some(options.existing_entities.as_slice()).filter(|known| !known.is_empty());
    let prompt = build_extraction_prompt(text, known_entities);

    // Provider failures are deliberately absorbed: degrade to a fallback
    // entity rather than surfacing the error to the caller.
    let (entities, relations) = self
        .call_llm(&prompt, text)
        .await
        .unwrap_or_else(|_| (self.create_fallback_entity(text), Vec::new()));

    // A threshold of zero means "no filtering at all".
    let apply_threshold = threshold > 0.0;
    let entities: Vec<_> = entities
        .into_iter()
        .filter(|entity| !apply_threshold || entity.confidence >= threshold)
        .take(EXTRACTION_ENTITIES_COUNT_MAX)
        .collect();
    let relations: Vec<_> = relations
        .into_iter()
        .filter(|relation| !apply_threshold || relation.confidence >= threshold)
        .take(EXTRACTION_RELATIONS_COUNT_MAX)
        .collect();

    let result = ExtractionResult::new(entities, relations, text);
    debug_assert!(
        result.entity_count() <= EXTRACTION_ENTITIES_COUNT_MAX,
        "too many entities"
    );
    debug_assert!(
        result.relation_count() <= EXTRACTION_RELATIONS_COUNT_MAX,
        "too many relations"
    );
    Ok(result)
}
/// Runs [`Self::extract`] with default options and returns only the entities.
///
/// # Errors
///
/// Propagates any [`ExtractionError`] produced by `extract`.
pub async fn extract_entities_only(
    &self,
    text: &str,
) -> Result<Vec<ExtractedEntity>, ExtractionError> {
    self.extract(text, ExtractionOptions::default())
        .await
        .map(|extraction| extraction.entities)
}
/// Sends `prompt` to the provider in JSON mode and parses the reply.
///
/// `original_text` is forwarded to the parser, which uses it to build
/// fallback content.
///
/// # Errors
///
/// Returns the provider's error if the completion request fails. Parsing
/// itself never fails; it falls back instead.
async fn call_llm(
    &self,
    prompt: &str,
    original_text: &str,
) -> Result<(Vec<ExtractedEntity>, Vec<ExtractedRelation>), ProviderError> {
    let json_request = CompletionRequest::new(prompt).with_json_mode();
    let reply = self.provider.complete(&json_request).await?;
    Ok(self.parse_response(&reply, original_text))
}
/// Turns the provider's reply into validated entities and relations.
///
/// If the reply is not valid JSON for the expected shape, the result is a
/// single fallback entity and no relations. If it parses but yields no usable
/// entity, the fallback entity is returned together with whatever relations
/// were parsed.
fn parse_response(
    &self,
    response: &str,
    original_text: &str,
) -> (Vec<ExtractedEntity>, Vec<ExtractedRelation>) {
    let payload = Self::extract_json_from_response(response);

    match serde_json::from_str::<LLMExtractionResponse>(payload) {
        // Unparseable reply: nothing to salvage.
        Err(_) => (self.create_fallback_entity(original_text), Vec::new()),
        Ok(parsed) => {
            let parsed_entities = self.parse_entities(&parsed.entities, original_text);
            let relations = self.parse_relations(&parsed.relations);
            let entities = if parsed_entities.is_empty() {
                self.create_fallback_entity(original_text)
            } else {
                parsed_entities
            };
            (entities, relations)
        }
    }
}
/// Strips a Markdown code fence from around the JSON payload, if present.
///
/// The reply is trimmed first. If it then starts with a ``` fence (with or
/// without a language tag such as `json`), the opening fence line is removed,
/// and so is everything from the last ``` onwards when a closing fence
/// exists. The remaining text is trimmed and returned.
///
/// A reply that is not fenced, or whose fence has no line break after it, is
/// returned trimmed but otherwise untouched.
///
/// A reply with an opening fence but no closing fence (for example a
/// truncated completion) yields the text after the opening fence line. This
/// case previously panicked, because the search for the closing fence found
/// the opening one and produced an inverted slice range.
fn extract_json_from_response(response: &str) -> &str {
    let trimmed = response.trim();
    if !trimmed.starts_with("```") {
        return trimmed;
    }

    // Everything after the opening fence line (which may carry a language tag).
    let body = match trimmed.find('\n') {
        Some(newline_idx) => &trimmed[newline_idx + 1..],
        None => return trimmed,
    };

    // Searching only `body` guarantees the opening fence is never mistaken
    // for the closing one.
    let inner = match body.rfind("```") {
        Some(fence_idx) => &body[..fence_idx],
        None => body,
    };
    inner.trim()
}
/// Converts raw LLM entity candidates into validated [`ExtractedEntity`] values.
///
/// For each candidate:
/// - a missing or whitespace-only name drops the candidate;
/// - the name is trimmed and capped at `EXTRACTION_ENTITY_NAME_BYTES_MAX` bytes;
/// - a missing type becomes [`EntityType::Note`], otherwise the label is
///   mapped through `EntityType::from_str_or_note`;
/// - missing content falls back to the leading (at most) 200 bytes of
///   `original_text`, and content is capped at
///   `EXTRACTION_ENTITY_CONTENT_BYTES_MAX` bytes;
/// - confidence is clamped to
///   `EXTRACTION_CONFIDENCE_MIN..=EXTRACTION_CONFIDENCE_MAX`, defaulting to
///   `EXTRACTION_CONFIDENCE_DEFAULT` when absent.
///
/// Every truncation backs off to the nearest UTF-8 character boundary, so a
/// cut can be slightly shorter than the limit but never panics on multi-byte
/// text.
fn parse_entities(
    &self,
    raw_entities: &[RawEntity],
    original_text: &str,
) -> Vec<ExtractedEntity> {
    /// Returns the longest prefix of `s` that is at most `max_bytes` long
    /// and ends on a character boundary.
    fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> &str {
        if s.len() <= max_bytes {
            return s;
        }
        let mut end = max_bytes;
        // Index 0 is always a boundary, so this loop terminates.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }

    let mut entities = Vec::with_capacity(raw_entities.len());
    for raw in raw_entities {
        let name = match raw.name.as_deref().map(str::trim) {
            Some(trimmed) if !trimmed.is_empty() => trimmed,
            _ => continue,
        };
        let name = truncate_at_char_boundary(
            name,
            crate::constants::EXTRACTION_ENTITY_NAME_BYTES_MAX,
        )
        .to_string();

        let entity_type = raw
            .entity_type
            .as_deref()
            .map(EntityType::from_str_or_note)
            .unwrap_or(EntityType::Note);

        // Compute the fallback lazily: it is only needed when the candidate
        // has no content of its own.
        let content = raw
            .content
            .as_deref()
            .unwrap_or_else(|| truncate_at_char_boundary(original_text, 200));
        let content = truncate_at_char_boundary(
            content,
            crate::constants::EXTRACTION_ENTITY_CONTENT_BYTES_MAX,
        )
        .to_string();

        let confidence = raw
            .confidence
            .map(|c| c.clamp(EXTRACTION_CONFIDENCE_MIN, EXTRACTION_CONFIDENCE_MAX))
            .unwrap_or(EXTRACTION_CONFIDENCE_DEFAULT);

        entities.push(ExtractedEntity::new(name, entity_type, content, confidence));
    }
    entities
}
/// Converts raw LLM relation candidates into validated [`ExtractedRelation`] values.
///
/// A candidate is dropped when its source or target is missing or
/// whitespace-only. Kept candidates get trimmed endpoint names, a relation
/// type that defaults to [`RelationType::RelatesTo`] when absent, and a
/// confidence clamped to the accepted range (or
/// `EXTRACTION_CONFIDENCE_DEFAULT` when absent).
fn parse_relations(&self, raw_relations: &[RawRelation]) -> Vec<ExtractedRelation> {
    /// Trims `field` and returns it as an owned string, or `None` if it is
    /// missing or blank.
    fn non_blank(field: Option<&str>) -> Option<String> {
        field
            .map(str::trim)
            .filter(|trimmed| !trimmed.is_empty())
            .map(String::from)
    }

    raw_relations
        .iter()
        .filter_map(|candidate| {
            let source = non_blank(candidate.source.as_deref())?;
            let target = non_blank(candidate.target.as_deref())?;

            let relation_type = candidate.relation_type.as_deref().map_or(
                RelationType::RelatesTo,
                RelationType::from_str_or_relates_to,
            );
            let confidence = match candidate.confidence {
                Some(c) => c.clamp(EXTRACTION_CONFIDENCE_MIN, EXTRACTION_CONFIDENCE_MAX),
                None => EXTRACTION_CONFIDENCE_DEFAULT,
            };

            Some(ExtractedRelation::new(
                source,
                target,
                relation_type,
                confidence,
            ))
        })
        .collect()
}
/// Builds the single placeholder entity used when extraction cannot produce
/// real entities.
///
/// The entity has type [`EntityType::Note`], confidence
/// `EXTRACTION_CONFIDENCE_DEFAULT`, a name of `"Note: "` followed by the
/// leading (at most) 50 bytes of `text`, and content holding the leading (at
/// most) 500 bytes of `text`.
///
/// Both prefixes back off to the nearest UTF-8 character boundary. This is
/// the error-recovery path, so it must not panic on multi-byte input.
fn create_fallback_entity(&self, text: &str) -> Vec<ExtractedEntity> {
    /// Returns the longest prefix of `s` that is at most `max_bytes` long
    /// and ends on a character boundary.
    fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> &str {
        if s.len() <= max_bytes {
            return s;
        }
        let mut end = max_bytes;
        // Index 0 is always a boundary, so this loop terminates.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }

    let name = format!("Note: {}", truncate_at_char_boundary(text, 50));
    let content = truncate_at_char_boundary(text, 500).to_string();
    vec![ExtractedEntity::new(
        name,
        EntityType::Note,
        content,
        EXTRACTION_CONFIDENCE_DEFAULT,
    )]
}
#[must_use]
pub fn provider(&self) -> &P {
&self.provider
}
}
/// Unit tests for the extractor, run against the simulated LLM provider.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::SimLLMProvider;

    /// Builds an extractor backed by a simulated provider seeded with `seed`.
    fn create_test_extractor(seed: u64) -> EntityExtractor<SimLLMProvider> {
        let provider = SimLLMProvider::with_seed(seed);
        EntityExtractor::new(provider)
    }

    /// Shorthand for building a `RawEntity` from borrowed strings.
    fn raw_entity(
        name: Option<&str>,
        entity_type: Option<&str>,
        content: Option<&str>,
        confidence: Option<f64>,
    ) -> RawEntity {
        RawEntity {
            name: name.map(String::from),
            entity_type: entity_type.map(String::from),
            content: content.map(String::from),
            confidence,
        }
    }

    /// Shorthand for building a `RawRelation` from borrowed strings.
    fn raw_relation(
        source: Option<&str>,
        target: Option<&str>,
        relation_type: Option<&str>,
        confidence: Option<f64>,
    ) -> RawRelation {
        RawRelation {
            source: source.map(String::from),
            target: target.map(String::from),
            relation_type: relation_type.map(String::from),
            confidence,
        }
    }

    #[tokio::test]
    async fn test_basic_extraction() {
        let input = "Alice works at Acme Corp";
        let extractor = create_test_extractor(42);

        let extraction = extractor
            .extract(input, ExtractionOptions::default())
            .await
            .unwrap();

        assert!(!extraction.is_empty());
        assert_eq!(extraction.raw_text, input);
    }

    #[tokio::test]
    async fn test_extraction_with_existing_entities() {
        let extractor = create_test_extractor(42);
        let known: Vec<String> = vec![String::from("Alice"), String::from("Acme")];
        let options = ExtractionOptions::new().with_existing_entities(known);

        let extraction = extractor
            .extract("She joined last month", options)
            .await
            .unwrap();

        assert!(!extraction.is_empty());
    }

    #[tokio::test]
    async fn test_extraction_entities_only() {
        let extractor = create_test_extractor(42);

        let found = extractor
            .extract_entities_only("Bob met Charlie at Google")
            .await
            .unwrap();

        assert!(!found.is_empty());
    }

    #[tokio::test]
    async fn test_extraction_with_min_confidence() {
        let extractor = create_test_extractor(42);
        let options = ExtractionOptions::new().with_min_confidence(0.9);

        let extraction = extractor
            .extract("Alice works at Acme", options)
            .await
            .unwrap();

        // Every surviving entity must meet the requested threshold.
        assert!(extraction
            .entities
            .iter()
            .all(|entity| entity.confidence >= 0.9));
    }

    #[tokio::test]
    async fn test_empty_text_error() {
        let extractor = create_test_extractor(42);

        let outcome = extractor.extract("", ExtractionOptions::default()).await;

        assert!(matches!(outcome, Err(ExtractionError::EmptyText)));
    }

    #[tokio::test]
    async fn test_text_too_long_error() {
        let extractor = create_test_extractor(42);
        // One byte over the limit.
        let oversized = "x".repeat(EXTRACTION_TEXT_BYTES_MAX + 1);

        let outcome = extractor
            .extract(&oversized, ExtractionOptions::default())
            .await;

        assert!(matches!(outcome, Err(ExtractionError::TextTooLong { .. })));
    }

    #[tokio::test]
    async fn test_invalid_confidence_error() {
        let extractor = create_test_extractor(42);
        // Built as a struct literal so the out-of-range value reaches `extract`.
        let options = ExtractionOptions {
            existing_entities: Vec::new(),
            min_confidence: 1.5,
        };

        let outcome = extractor.extract("test", options).await;

        assert!(matches!(
            outcome,
            Err(ExtractionError::InvalidConfidence { .. })
        ));
    }

    #[tokio::test]
    async fn test_determinism() {
        let input = "Alice works at Microsoft";
        let first_extractor = create_test_extractor(42);
        let second_extractor = create_test_extractor(42);

        let first = first_extractor
            .extract(input, ExtractionOptions::default())
            .await
            .unwrap();
        let second = second_extractor
            .extract(input, ExtractionOptions::default())
            .await
            .unwrap();

        // Same seed and same input must give the same counts.
        assert_eq!(first.entity_count(), second.entity_count());
        assert_eq!(first.relation_count(), second.relation_count());
    }

    #[test]
    fn test_parse_entities_with_valid_data() {
        let extractor = create_test_extractor(42);
        let candidates = vec![
            raw_entity(Some("Alice"), Some("person"), Some("A person"), Some(0.9)),
            raw_entity(Some("Acme"), Some("org"), Some("A company"), Some(0.8)),
        ];

        let parsed = extractor.parse_entities(&candidates, "original text");

        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].name, "Alice");
        assert_eq!(parsed[0].entity_type, EntityType::Person);
        assert_eq!(parsed[1].name, "Acme");
        assert_eq!(parsed[1].entity_type, EntityType::Organization);
    }

    #[test]
    fn test_parse_entities_with_invalid_data() {
        let extractor = create_test_extractor(42);
        let candidates = vec![
            // No name at all.
            raw_entity(None, Some("person"), None, None),
            // Whitespace-only name.
            raw_entity(Some(" "), None, None, None),
        ];

        let parsed = extractor.parse_entities(&candidates, "original text");

        assert!(parsed.is_empty());
    }

    #[test]
    fn test_parse_entities_with_unknown_type() {
        let extractor = create_test_extractor(42);
        let candidates = vec![raw_entity(Some("Unknown"), Some("unknown_type"), None, None)];

        let parsed = extractor.parse_entities(&candidates, "original text");

        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].entity_type, EntityType::Note);
    }

    #[test]
    fn test_parse_relations_with_valid_data() {
        let extractor = create_test_extractor(42);
        let candidates = vec![raw_relation(
            Some("Alice"),
            Some("Acme"),
            Some("works_at"),
            Some(0.9),
        )];

        let parsed = extractor.parse_relations(&candidates);

        assert_eq!(parsed.len(), 1);
        let relation = &parsed[0];
        assert_eq!(relation.source, "Alice");
        assert_eq!(relation.target, "Acme");
        assert_eq!(relation.relation_type, RelationType::WorksAt);
    }

    #[test]
    fn test_parse_relations_with_missing_fields() {
        let extractor = create_test_extractor(42);
        let candidates = vec![
            // Missing source.
            raw_relation(None, Some("Acme"), None, None),
            // Missing target.
            raw_relation(Some("Alice"), None, None, None),
        ];

        let parsed = extractor.parse_relations(&candidates);

        assert!(parsed.is_empty());
    }

    #[test]
    fn test_create_fallback_entity() {
        let extractor = create_test_extractor(42);

        let fallback = extractor.create_fallback_entity("This is some text for testing");

        assert_eq!(fallback.len(), 1);
        let note = &fallback[0];
        assert!(note.name.starts_with("Note: "));
        assert_eq!(note.entity_type, EntityType::Note);
        assert_eq!(note.confidence, EXTRACTION_CONFIDENCE_DEFAULT);
    }

    #[test]
    fn test_provider_accessor() {
        let extractor = EntityExtractor::new(SimLLMProvider::with_seed(42));

        assert!(extractor.provider().is_simulation());
    }
}
/// Simulation tests that inject LLM faults and check that `extract` degrades
/// to a fallback entity.
#[cfg(test)]
mod dst_tests {
    use super::*;
    use crate::dst::{FaultConfig, FaultType, SimConfig, Simulation};
    use crate::llm::SimLLMProvider;

    /// A timeout injected on every call must yield exactly one fallback
    /// `Note` entity.
    #[tokio::test]
    async fn test_extract_with_llm_timeout() {
        let config = SimConfig::with_seed(42);
        let sim = Simulation::new(config).with_fault(FaultConfig::new(FaultType::LlmTimeout, 1.0));

        sim.run(|env| async move {
            let provider = SimLLMProvider::with_faults(42, env.faults.clone());
            let extractor = EntityExtractor::new(provider);

            let outcome = extractor
                .extract("Alice works at Acme Corp", ExtractionOptions::default())
                .await;
            let extraction = match outcome {
                Ok(extraction) => extraction,
                Err(e) => {
                    panic!("BUG: LLM timeout should return fallback, not error: {:?}", e);
                }
            };

            assert!(
                !extraction.entities.is_empty(),
                "BUG: Should return fallback entity on timeout, got empty"
            );
            assert_eq!(
                extraction.entities.len(),
                1,
                "BUG: Should have exactly one fallback entity"
            );

            let entity = &extraction.entities[0];
            assert_eq!(
                entity.entity_type,
                EntityType::Note,
                "BUG: Fallback entity should have type Note, got {:?}. This suggests fault didn't fire!",
                entity.entity_type
            );
            assert!(
                entity.name.starts_with("Note: "),
                "BUG: Fallback entity name should start with 'Note: ', got '{}'. Fault may not have fired!",
                entity.name
            );
            assert_eq!(
                entity.confidence,
                EXTRACTION_CONFIDENCE_DEFAULT,
                "BUG: Fallback should have confidence {}, got {}",
                EXTRACTION_CONFIDENCE_DEFAULT,
                entity.confidence
            );
            println!("✓ VERIFIED: LLM timeout actually fired, fallback entity created (type=Note, name={}, confidence={})",
                entity.name, entity.confidence);

            Ok::<_, anyhow::Error>(())
        })
        .await
        .unwrap();
    }

    /// A rate-limit fault on every call must still produce a non-empty result.
    #[tokio::test]
    async fn test_extract_with_llm_rate_limit() {
        let config = SimConfig::with_seed(42);
        let sim =
            Simulation::new(config).with_fault(FaultConfig::new(FaultType::LlmRateLimit, 1.0));

        sim.run(|env| async move {
            let provider = SimLLMProvider::with_faults(42, env.faults.clone());
            let extractor = EntityExtractor::new(provider);

            let outcome = extractor
                .extract("Bob is the CTO at TechCo", ExtractionOptions::default())
                .await;
            let extraction = match outcome {
                Ok(extraction) => extraction,
                Err(e) => {
                    panic!(
                        "BUG: Rate limit should return fallback, not error: {:?}",
                        e
                    );
                }
            };

            assert!(
                !extraction.entities.is_empty(),
                "BUG: Should return fallback on rate limit, got empty"
            );
            println!("✓ LLM rate limit handled gracefully: fallback entity created");

            Ok::<_, anyhow::Error>(())
        })
        .await
        .unwrap();
    }

    /// An invalid-response fault must yield a non-empty result; an error is
    /// also tolerated by this test.
    #[tokio::test]
    async fn test_extract_with_llm_invalid_response() {
        let config = SimConfig::with_seed(42);
        let sim = Simulation::new(config)
            .with_fault(FaultConfig::new(FaultType::LlmInvalidResponse, 1.0));

        sim.run(|env| async move {
            let provider = SimLLMProvider::with_faults(42, env.faults.clone());
            let extractor = EntityExtractor::new(provider);

            let outcome = extractor
                .extract(
                    "Carol manages the engineering team",
                    ExtractionOptions::default(),
                )
                .await;

            if let Err(e) = &outcome {
                println!("Invalid response returned error (acceptable): {:?}", e);
            }
            if let Ok(extraction) = &outcome {
                assert!(
                    !extraction.entities.is_empty(),
                    "BUG: Should return fallback on invalid response, got empty"
                );
                println!("✓ Invalid LLM response handled: fallback entity created");
            }

            Ok::<_, anyhow::Error>(())
        })
        .await
        .unwrap();
    }

    /// With a 50% timeout rate, the outcome for seed 42 is pinned: all ten
    /// extractions are expected to be fallbacks.
    #[tokio::test]
    async fn test_extract_with_probabilistic_failure() {
        let config = SimConfig::with_seed(42);
        let sim = Simulation::new(config).with_fault(FaultConfig::new(FaultType::LlmTimeout, 0.5));

        sim.run(|env| async move {
            let provider = SimLLMProvider::with_faults(42, env.faults.clone());
            let extractor = EntityExtractor::new(provider);

            let mut fallback_count = 0;
            let mut success_count = 0;
            for i in 0..10 {
                let text = format!("Person {} is a software engineer", i);
                let outcome = extractor
                    .extract(&text, ExtractionOptions::default())
                    .await;

                // A lone `Note` entity is treated as the fallback; an error
                // is counted as a fallback too.
                let fell_back = match &outcome {
                    Ok(extraction) => {
                        extraction.entities.len() == 1
                            && extraction.entities[0].entity_type == EntityType::Note
                    }
                    Err(_) => true,
                };
                if fell_back {
                    fallback_count += 1;
                } else {
                    success_count += 1;
                }
            }

            assert!(
                fallback_count == 10,
                "BUG: With seed 42 + 50% rate, should have 10 fallbacks (deterministic). Got {}",
                fallback_count
            );
            assert!(
                success_count == 0,
                "BUG: With seed 42 + 50% rate, should have 0 successes (deterministic). Got {}",
                success_count
            );
            println!(
                "✓ Probabilistic failure DETERMINISTIC: {} fallbacks, {} successes (seed 42)",
                fallback_count, success_count
            );

            Ok::<_, anyhow::Error>(())
        })
        .await
        .unwrap();
    }

    /// A service-unavailable fault on every call must still produce a
    /// non-empty result.
    #[tokio::test]
    async fn test_extract_with_llm_service_unavailable() {
        let config = SimConfig::with_seed(42);
        let sim = Simulation::new(config)
            .with_fault(FaultConfig::new(FaultType::LlmServiceUnavailable, 1.0));

        sim.run(|env| async move {
            let provider = SimLLMProvider::with_faults(42, env.faults.clone());
            let extractor = EntityExtractor::new(provider);

            let outcome = extractor
                .extract("Test entity extraction", ExtractionOptions::default())
                .await;
            let extraction = match outcome {
                Ok(extraction) => extraction,
                Err(e) => {
                    panic!(
                        "BUG: Service unavailable should return fallback, not error: {:?}",
                        e
                    );
                }
            };

            assert!(
                !extraction.entities.is_empty(),
                "BUG: Should return fallback on service unavailable"
            );
            println!("✓ Service unavailable handled: fallback entity created");

            Ok::<_, anyhow::Error>(())
        })
        .await
        .unwrap();
    }
}