oxirs_chat/rag/
entity_extraction.rs

1//! LLM-powered entity and relationship extraction
2//!
3//! Provides intelligent entity and relationship extraction from natural language queries
4//! using both LLM-based and rule-based approaches.
5
6use super::*;
7
/// Entity extractor for identifying entities and relationships in queries.
///
/// The extractor is a field-less unit type, so it is free to copy and carries
/// no configuration of its own. `Debug`, `Clone` and `Copy` are derived so the
/// type can be embedded in structs that derive those traits and can show up in
/// diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct EntityExtractor;
10
11impl Default for EntityExtractor {
12    fn default() -> Self {
13        Self::new()
14    }
15}
16
17impl EntityExtractor {
18    pub fn new() -> Self {
19        Self
20    }
21
22    /// Extract entities and relationships from query
23    pub async fn extract_entities_and_relationships(
24        &self,
25        query: &str,
26    ) -> Result<(Vec<ExtractedEntity>, Vec<ExtractedRelationship>)> {
27        // Try LLM extraction first, fall back to rule-based if needed
28        match self.llm_extract_entities(query).await {
29            Ok(result) => Ok(result),
30            _ => {
31                warn!("LLM extraction failed, falling back to rule-based extraction");
32                self.rule_based_extraction(query).await
33            }
34        }
35    }
36
37    /// LLM-powered entity and relationship extraction
38    async fn llm_extract_entities(
39        &self,
40        query: &str,
41    ) -> Result<(Vec<ExtractedEntity>, Vec<ExtractedRelationship>)> {
42        use crate::llm::{
43            ChatMessage, ChatRole, LLMConfig, LLMManager, LLMRequest, Priority, UseCase,
44        };
45
46        // Create extraction prompt
47        let prompt = format!(
48            r#"Extract entities and relationships from the following query. Return a JSON response with the following structure:
49
50{{
51  "entities": [
52    {{
53      "text": "entity name",
54      "type": "Person|Organization|Location|Concept|Other",
55      "confidence": 0.95
56    }}
57  ],
58  "relationships": [
59    {{
60      "subject": "entity1",
61      "predicate": "relationship type",
62      "object": "entity2",
63      "confidence": 0.85
64    }}
65  ]
66}}
67
68Query: "{query}"
69
70Focus on:
71- Named entities (people, places, organizations, concepts)
72- Implicit relationships between entities
73- Technical terms and domain-specific concepts
74- Only extract explicit entities mentioned in the query
75
76JSON Response:"#
77        );
78
79        // Initialize LLM manager
80        let llm_config = LLMConfig::default();
81        let mut llm_manager = LLMManager::new(llm_config)?;
82
83        let chat_messages = vec![
84            ChatMessage {
85                role: ChatRole::System,
86                content: "You are an expert at extracting entities and relationships from text. Always respond with valid JSON only.".to_string(),
87                metadata: None,
88            },
89            ChatMessage {
90                role: ChatRole::User,
91                content: prompt,
92                metadata: None,
93            },
94        ];
95
96        let request = LLMRequest {
97            messages: chat_messages,
98            system_prompt: Some("Extract entities and relationships as JSON.".to_string()),
99            use_case: UseCase::SimpleQuery,
100            priority: Priority::Normal,
101            max_tokens: Some(500),
102            temperature: 0.1f32, // Low temperature for consistent extraction
103            timeout: Some(std::time::Duration::from_secs(15)),
104        };
105
106        let response = llm_manager.generate_response(request).await?;
107
108        // Parse JSON response
109        self.parse_extraction_response(&response.content)
110    }
111
112    /// Parse LLM extraction response
113    fn parse_extraction_response(
114        &self,
115        response: &str,
116    ) -> Result<(Vec<ExtractedEntity>, Vec<ExtractedRelationship>)> {
117        // Clean response (remove markdown formatting if present)
118        let cleaned_response = response
119            .trim()
120            .strip_prefix("```json")
121            .unwrap_or(response)
122            .strip_suffix("```")
123            .unwrap_or(response)
124            .trim();
125
126        let parsed: serde_json::Value = serde_json::from_str(cleaned_response)?;
127
128        let mut entities = Vec::new();
129        let mut relationships = Vec::new();
130
131        // Parse entities
132        if let Some(entity_array) = parsed.get("entities").and_then(|e| e.as_array()) {
133            for entity_obj in entity_array {
134                if let (Some(text), Some(entity_type), Some(confidence)) = (
135                    entity_obj.get("text").and_then(|v| v.as_str()),
136                    entity_obj.get("type").and_then(|v| v.as_str()),
137                    entity_obj.get("confidence").and_then(|v| v.as_f64()),
138                ) {
139                    let entity_type = match entity_type {
140                        "Person" => EntityType::Person,
141                        "Organization" => EntityType::Organization,
142                        "Location" => EntityType::Location,
143                        "Concept" => EntityType::Concept,
144                        "Event" => EntityType::Event,
145                        _ => EntityType::Other,
146                    };
147
148                    entities.push(ExtractedEntity {
149                        text: text.to_string(),
150                        entity_type,
151                        iri: None, // Would be resolved separately
152                        confidence: confidence as f32,
153                        aliases: Vec::new(),
154                    });
155                }
156            }
157        }
158
159        // Parse relationships
160        if let Some(relationship_array) = parsed.get("relationships").and_then(|r| r.as_array()) {
161            for rel_obj in relationship_array {
162                if let (Some(subject), Some(predicate), Some(object), Some(confidence)) = (
163                    rel_obj.get("subject").and_then(|v| v.as_str()),
164                    rel_obj.get("predicate").and_then(|v| v.as_str()),
165                    rel_obj.get("object").and_then(|v| v.as_str()),
166                    rel_obj.get("confidence").and_then(|v| v.as_f64()),
167                ) {
168                    relationships.push(ExtractedRelationship {
169                        subject: subject.to_string(),
170                        predicate: predicate.to_string(),
171                        object: object.to_string(),
172                        confidence: confidence as f32,
173                        relation_type: RelationType::Other, // Would be classified separately
174                    });
175                }
176            }
177        }
178
179        debug!(
180            "Extracted {} entities and {} relationships",
181            entities.len(),
182            relationships.len()
183        );
184        Ok((entities, relationships))
185    }
186
187    /// Fallback rule-based extraction
188    async fn rule_based_extraction(
189        &self,
190        query: &str,
191    ) -> Result<(Vec<ExtractedEntity>, Vec<ExtractedRelationship>)> {
192        let mut entities = Vec::new();
193        let mut relationships = Vec::new();
194
195        // Simple pattern-based entity extraction
196        let words: Vec<&str> = query.split_whitespace().collect();
197
198        for (i, word) in words.iter().enumerate() {
199            // Look for capitalized words (potential proper nouns)
200            if word.chars().next().is_some_and(|c| c.is_uppercase()) && word.len() > 2 {
201                // Skip common question words
202                if !self.is_stop_word(&word.to_lowercase()) {
203                    entities.push(ExtractedEntity {
204                        text: word.to_string(),
205                        entity_type: EntityType::Other,
206                        iri: None,
207                        confidence: 0.6, // Lower confidence for rule-based
208                        aliases: Vec::new(),
209                    });
210                }
211            }
212
213            // Look for relationship patterns
214            if i > 0 && i < words.len() - 1 {
215                let prev_word = words[i - 1];
216                let next_word = words[i + 1];
217
218                if word.to_lowercase() == "is" || word.to_lowercase() == "has" {
219                    relationships.push(ExtractedRelationship {
220                        subject: prev_word.to_string(),
221                        predicate: word.to_string(),
222                        object: next_word.to_string(),
223                        confidence: 0.5,
224                        relation_type: RelationType::ConceptualRelation,
225                    });
226                }
227            }
228        }
229
230        debug!(
231            "Rule-based extraction found {} entities and {} relationships",
232            entities.len(),
233            relationships.len()
234        );
235        Ok((entities, relationships))
236    }
237
238    /// Check if a word is a stop word (for entity extraction)
239    fn is_stop_word(&self, word: &str) -> bool {
240        matches!(
241            word,
242            "the"
243                | "and"
244                | "or"
245                | "but"
246                | "in"
247                | "on"
248                | "at"
249                | "to"
250                | "for"
251                | "of"
252                | "with"
253                | "by"
254                | "from"
255                | "up"
256                | "about"
257                | "into"
258                | "through"
259                | "during"
260                | "before"
261                | "after"
262                | "above"
263                | "below"
264                | "between"
265                | "among"
266                | "this"
267                | "that"
268                | "these"
269                | "those"
270                | "i"
271                | "you"
272                | "he"
273                | "she"
274                | "it"
275                | "we"
276                | "they"
277                | "me"
278                | "him"
279                | "her"
280                | "us"
281                | "them"
282                | "my"
283                | "your"
284                | "his"
285                | "its"
286                | "our"
287                | "their"
288                | "am"
289                | "is"
290                | "are"
291                | "was"
292                | "were"
293                | "be"
294                | "been"
295                | "being"
296                | "have"
297                | "has"
298                | "had"
299                | "do"
300                | "does"
301                | "did"
302                | "will"
303                | "would"
304                | "could"
305                | "should"
306                | "may"
307                | "might"
308                | "must"
309                | "can"
310                | "what"
311                | "when"
312                | "where"
313                | "who"
314                | "why"
315                | "how"
316                | "which"
317        )
318    }
319}
320
321/// LLM entity extraction result wrapper
322pub struct LLMEntityExtraction {
323    pub entities: Vec<ExtractedEntity>,
324    pub relationships: Vec<ExtractedRelationship>,
325}
326
327use super::graph_traversal::{EntityType, ExtractedEntity, ExtractedRelationship, RelationType};