Skip to main content

brainwires_reasoning/
entity_enhancer.rs

1//! Entity Enhancer - Semantic Entity Extraction
2//!
3//! Uses a provider to extract entities and relationships beyond
4//! what regex patterns can capture, enriching the knowledge graph.
5
6use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// Enhanced entity type with semantic classification.
///
/// This is a unit-only enum, so it is `Copy` and `Hash` and can be used
/// directly as a set or map key. Use [`SemanticEntityType::as_str`] and
/// [`SemanticEntityType::from_str`] to convert to and from string labels.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SemanticEntityType {
    // Code elements
    /// A source file.
    File,
    /// A function or method.
    Function,
    /// A type, struct, class, or interface.
    Type,
    /// A variable or constant.
    Variable,
    /// A module or namespace.
    Module,
    /// A package, crate, or library.
    Package,

    // Domain concepts
    /// A general domain concept.
    Concept,
    /// A design or architectural pattern.
    Pattern,
    /// An algorithm or computational technique.
    Algorithm,
    /// A communication or network protocol.
    Protocol,

    // Actions
    /// A CLI or shell command.
    Command,
    /// A runtime operation or action.
    Operation,
    /// A task or work item.
    Task,

    // Problems and changes
    /// An error or exception.
    Error,
    /// A bug or known defect.
    Bug,
    /// A fix or patch for a defect.
    Fix,
    /// A product or code feature.
    Feature,

    // People
    /// A person or user.
    Person,
    /// A role or permission level.
    Role,

    // References
    /// A URL or web link.
    Url,
    /// A filesystem path.
    Path,
    /// A generic identifier or ID.
    Identifier,
}
67
68impl SemanticEntityType {
69    /// Parse from string
70    #[allow(clippy::should_implement_trait)]
71    pub fn from_str(s: &str) -> Option<Self> {
72        let lower = s.to_lowercase();
73        match lower.as_str() {
74            "file" => Some(SemanticEntityType::File),
75            "function" | "func" | "method" => Some(SemanticEntityType::Function),
76            "type" | "struct" | "class" | "interface" => Some(SemanticEntityType::Type),
77            "variable" | "var" | "const" => Some(SemanticEntityType::Variable),
78            "module" | "mod" => Some(SemanticEntityType::Module),
79            "package" | "crate" | "library" | "lib" => Some(SemanticEntityType::Package),
80            "concept" => Some(SemanticEntityType::Concept),
81            "pattern" => Some(SemanticEntityType::Pattern),
82            "algorithm" | "algo" => Some(SemanticEntityType::Algorithm),
83            "protocol" => Some(SemanticEntityType::Protocol),
84            "command" | "cmd" => Some(SemanticEntityType::Command),
85            "operation" | "op" => Some(SemanticEntityType::Operation),
86            "task" => Some(SemanticEntityType::Task),
87            "error" => Some(SemanticEntityType::Error),
88            "bug" => Some(SemanticEntityType::Bug),
89            "fix" => Some(SemanticEntityType::Fix),
90            "feature" => Some(SemanticEntityType::Feature),
91            "person" | "user" | "developer" => Some(SemanticEntityType::Person),
92            "role" => Some(SemanticEntityType::Role),
93            "url" | "link" => Some(SemanticEntityType::Url),
94            "path" => Some(SemanticEntityType::Path),
95            "identifier" | "id" => Some(SemanticEntityType::Identifier),
96            _ => None,
97        }
98    }
99
100    /// Convert to string
101    pub fn as_str(&self) -> &'static str {
102        match self {
103            SemanticEntityType::File => "file",
104            SemanticEntityType::Function => "function",
105            SemanticEntityType::Type => "type",
106            SemanticEntityType::Variable => "variable",
107            SemanticEntityType::Module => "module",
108            SemanticEntityType::Package => "package",
109            SemanticEntityType::Concept => "concept",
110            SemanticEntityType::Pattern => "pattern",
111            SemanticEntityType::Algorithm => "algorithm",
112            SemanticEntityType::Protocol => "protocol",
113            SemanticEntityType::Command => "command",
114            SemanticEntityType::Operation => "operation",
115            SemanticEntityType::Task => "task",
116            SemanticEntityType::Error => "error",
117            SemanticEntityType::Bug => "bug",
118            SemanticEntityType::Fix => "fix",
119            SemanticEntityType::Feature => "feature",
120            SemanticEntityType::Person => "person",
121            SemanticEntityType::Role => "role",
122            SemanticEntityType::Url => "url",
123            SemanticEntityType::Path => "path",
124            SemanticEntityType::Identifier => "identifier",
125        }
126    }
127}
128
129/// An entity extracted by LLM
130#[derive(Clone, Debug)]
131pub struct EnhancedEntity {
132    /// Entity name/value
133    pub name: String,
134    /// Semantic type
135    pub entity_type: SemanticEntityType,
136    /// Confidence score (0.0-1.0)
137    pub confidence: f32,
138    /// Context where found
139    pub context: Option<String>,
140}
141
142impl EnhancedEntity {
143    /// Create a new enhanced entity with the given name, type, and confidence.
144    pub fn new(name: String, entity_type: SemanticEntityType, confidence: f32) -> Self {
145        Self {
146            name,
147            entity_type,
148            confidence,
149            context: None,
150        }
151    }
152
153    /// Attach contextual information describing where the entity was found.
154    pub fn with_context(mut self, context: String) -> Self {
155        self.context = Some(context);
156        self
157    }
158}
159
160/// A semantic relationship between entities
161#[derive(Clone, Debug)]
162pub struct EnhancedRelationship {
163    /// Source entity
164    pub from: String,
165    /// Target entity
166    pub to: String,
167    /// Relationship type
168    pub relation_type: RelationType,
169    /// Confidence score
170    pub confidence: f32,
171}
172
/// Types of relationships we detect semantically.
///
/// This is a unit-only enum, so it is `Copy` and `Hash` and can be used
/// directly as a set or map key. Use [`RelationType::as_str`] and
/// [`RelationType::from_str`] to convert to and from string labels.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RelationType {
    // Structural relations
    /// A contains B (parent-child).
    Contains,
    /// A is defined inside B.
    DefinedIn,
    /// A imports B.
    Imports,
    /// A exports B.
    Exports,
    /// A extends or inherits from B.
    Extends,
    /// A implements the interface or trait B.
    Implements,

    // Behavioural relations
    /// A calls or invokes B.
    Calls,
    /// A uses or references B.
    Uses,
    /// A modifies B.
    Modifies,
    /// A creates or constructs B.
    Creates,
    /// A deletes or removes B.
    Deletes,

    // Semantic relations
    /// A is semantically related to B.
    RelatedTo,
    /// A is similar to B.
    SimilarTo,
    /// A depends on B.
    DependsOn,
    /// A causes B.
    Causes,
    /// A fixes or resolves B.
    Fixes,
    /// A replaces B.
    Replaces,
}
213
214impl RelationType {
215    /// Parse a relation type from a string label.
216    #[allow(clippy::should_implement_trait)]
217    pub fn from_str(s: &str) -> Option<Self> {
218        let lower = s.to_lowercase();
219        match lower.as_str() {
220            "contains" => Some(RelationType::Contains),
221            "defined_in" | "definedin" => Some(RelationType::DefinedIn),
222            "imports" => Some(RelationType::Imports),
223            "exports" => Some(RelationType::Exports),
224            "extends" => Some(RelationType::Extends),
225            "implements" => Some(RelationType::Implements),
226            "calls" => Some(RelationType::Calls),
227            "uses" => Some(RelationType::Uses),
228            "modifies" => Some(RelationType::Modifies),
229            "creates" => Some(RelationType::Creates),
230            "deletes" => Some(RelationType::Deletes),
231            "related_to" | "relatedto" => Some(RelationType::RelatedTo),
232            "similar_to" | "similarto" => Some(RelationType::SimilarTo),
233            "depends_on" | "dependson" => Some(RelationType::DependsOn),
234            "causes" => Some(RelationType::Causes),
235            "fixes" => Some(RelationType::Fixes),
236            "replaces" => Some(RelationType::Replaces),
237            _ => None,
238        }
239    }
240
241    /// Return the canonical string label for this relation type.
242    pub fn as_str(&self) -> &'static str {
243        match self {
244            RelationType::Contains => "contains",
245            RelationType::DefinedIn => "defined_in",
246            RelationType::Imports => "imports",
247            RelationType::Exports => "exports",
248            RelationType::Extends => "extends",
249            RelationType::Implements => "implements",
250            RelationType::Calls => "calls",
251            RelationType::Uses => "uses",
252            RelationType::Modifies => "modifies",
253            RelationType::Creates => "creates",
254            RelationType::Deletes => "deletes",
255            RelationType::RelatedTo => "related_to",
256            RelationType::SimilarTo => "similar_to",
257            RelationType::DependsOn => "depends_on",
258            RelationType::Causes => "causes",
259            RelationType::Fixes => "fixes",
260            RelationType::Replaces => "replaces",
261        }
262    }
263}
264
265/// Result of entity enhancement
266#[derive(Clone, Debug)]
267pub struct EnhancementResult {
268    /// Extracted entities
269    pub entities: Vec<EnhancedEntity>,
270    /// Extracted relationships
271    pub relationships: Vec<EnhancedRelationship>,
272    /// Extracted domain concepts
273    pub concepts: Vec<String>,
274    /// Whether LLM was used
275    pub used_local_llm: bool,
276}
277
278impl EnhancementResult {
279    /// Create an empty enhancement result (no entities, relationships, or concepts).
280    pub fn empty() -> Self {
281        Self {
282            entities: Vec::new(),
283            relationships: Vec::new(),
284            concepts: Vec::new(),
285            used_local_llm: false,
286        }
287    }
288
289    /// Create an enhancement result from LLM-extracted data.
290    pub fn from_local(
291        entities: Vec<EnhancedEntity>,
292        relationships: Vec<EnhancedRelationship>,
293        concepts: Vec<String>,
294    ) -> Self {
295        Self {
296            entities,
297            relationships,
298            concepts,
299            used_local_llm: true,
300        }
301    }
302}
303
304/// Entity enhancer using a provider
305pub struct EntityEnhancer {
306    provider: Arc<dyn Provider>,
307    model_id: String,
308    /// Minimum confidence threshold
309    min_confidence: f32,
310    /// Maximum entities to extract per call
311    max_entities: usize,
312}
313
314impl EntityEnhancer {
315    /// Create a new entity enhancer
316    pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
317        Self {
318            provider,
319            model_id: model_id.into(),
320            min_confidence: 0.6,
321            max_entities: 20,
322        }
323    }
324
325    /// Set minimum confidence threshold
326    pub fn with_min_confidence(mut self, confidence: f32) -> Self {
327        self.min_confidence = confidence.clamp(0.0, 1.0);
328        self
329    }
330
331    /// Set max entities per extraction
332    pub fn with_max_entities(mut self, max: usize) -> Self {
333        self.max_entities = max.max(1);
334        self
335    }
336
337    /// Extract semantic entities from text using the provider
338    pub async fn extract_entities(&self, text: &str) -> Option<Vec<EnhancedEntity>> {
339        let timer = InferenceTimer::new("extract_entities", &self.model_id);
340
341        let prompt = self.build_entity_prompt(text);
342
343        let messages = vec![Message::user(&prompt)];
344        let options = ChatOptions::deterministic(200);
345
346        match self.provider.chat(&messages, None, &options).await {
347            Ok(response) => {
348                let output = response.message.text_or_summary();
349                let entities = self.parse_entities(&output);
350                timer.finish(true);
351                Some(entities)
352            }
353            Err(e) => {
354                warn!(target: "local_llm", "Entity extraction failed: {}", e);
355                timer.finish(false);
356                None
357            }
358        }
359    }
360
361    /// Extract relationships between entities using the provider
362    pub async fn extract_relationships(
363        &self,
364        entities: &[String],
365        context: &str,
366    ) -> Option<Vec<EnhancedRelationship>> {
367        if entities.len() < 2 {
368            return Some(Vec::new());
369        }
370
371        let timer = InferenceTimer::new("extract_relationships", &self.model_id);
372
373        let prompt = self.build_relationship_prompt(entities, context);
374
375        let messages = vec![Message::user(&prompt)];
376        let options = ChatOptions::deterministic(150);
377
378        match self.provider.chat(&messages, None, &options).await {
379            Ok(response) => {
380                let output = response.message.text_or_summary();
381                let relationships = self.parse_relationships(&output);
382                timer.finish(true);
383                Some(relationships)
384            }
385            Err(e) => {
386                warn!(target: "local_llm", "Relationship extraction failed: {}", e);
387                timer.finish(false);
388                None
389            }
390        }
391    }
392
393    /// Extract domain concepts from text using the provider
394    pub async fn extract_concepts(&self, text: &str) -> Option<Vec<String>> {
395        let timer = InferenceTimer::new("extract_concepts", &self.model_id);
396
397        let prompt = self.build_concept_prompt(text);
398
399        let messages = vec![Message::user(&prompt)];
400        let options = ChatOptions::deterministic(100);
401
402        match self.provider.chat(&messages, None, &options).await {
403            Ok(response) => {
404                let output = response.message.text_or_summary();
405                let concepts = self.parse_concepts(&output);
406                timer.finish(true);
407                Some(concepts)
408            }
409            Err(e) => {
410                warn!(target: "local_llm", "Concept extraction failed: {}", e);
411                timer.finish(false);
412                None
413            }
414        }
415    }
416
417    /// Full enhancement - extract entities, relationships, and concepts
418    pub async fn enhance(&self, text: &str) -> EnhancementResult {
419        // Extract entities first
420        let entities = self.extract_entities(text).await.unwrap_or_default();
421
422        // Extract relationships if we have entities
423        let entity_names: Vec<String> = entities.iter().map(|e| e.name.clone()).collect();
424        let relationships = self
425            .extract_relationships(&entity_names, text)
426            .await
427            .unwrap_or_default();
428
429        // Extract concepts
430        let concepts = self.extract_concepts(text).await.unwrap_or_default();
431
432        EnhancementResult::from_local(entities, relationships, concepts)
433    }
434
435    /// Heuristic entity extraction (pattern-based fallback)
436    pub fn extract_heuristic(&self, text: &str) -> Vec<EnhancedEntity> {
437        let mut entities = Vec::new();
438
439        // URL pattern
440        let url_pattern = regex::Regex::new(r#"https?://[^\s<>"']+"#).expect("valid url regex");
441        for cap in url_pattern.find_iter(text) {
442            entities.push(EnhancedEntity::new(
443                cap.as_str().to_string(),
444                SemanticEntityType::Url,
445                0.9,
446            ));
447        }
448
449        // Path-like patterns (beyond file extensions)
450        let path_pattern =
451            regex::Regex::new(r#"(?:^|[\s"'])(/[a-zA-Z0-9_./-]+)"#).expect("valid path regex");
452        for cap in path_pattern.captures_iter(text) {
453            if let Some(m) = cap.get(1) {
454                let path = m.as_str();
455                // Filter common false positives
456                if path.len() > 3 && !path.starts_with("//") {
457                    entities.push(EnhancedEntity::new(
458                        path.to_string(),
459                        SemanticEntityType::Path,
460                        0.7,
461                    ));
462                }
463            }
464        }
465
466        // Package/crate names (Rust-style)
467        let crate_pattern = regex::Regex::new(r"(?:use|extern crate|mod)\s+([a-z_][a-z0-9_]*)")
468            .expect("valid crate regex");
469        for cap in crate_pattern.captures_iter(text) {
470            if let Some(m) = cap.get(1) {
471                entities.push(EnhancedEntity::new(
472                    m.as_str().to_string(),
473                    SemanticEntityType::Module,
474                    0.8,
475                ));
476            }
477        }
478
479        // Problem/fix indicators
480        let lower = text.to_lowercase();
481        if lower.contains("bug") || lower.contains("issue") || lower.contains("problem") {
482            // Look for identifiers near these words
483            let bug_context =
484                regex::Regex::new(r"(?i)(?:bug|issue|problem)\s*(?:#|:)?\s*(\d+|[A-Z]+-\d+)")
485                    .expect("valid bug regex");
486            for cap in bug_context.captures_iter(text) {
487                if let Some(m) = cap.get(1) {
488                    entities.push(EnhancedEntity::new(
489                        m.as_str().to_string(),
490                        SemanticEntityType::Bug,
491                        0.85,
492                    ));
493                }
494            }
495        }
496
497        if lower.contains("fix") || lower.contains("fixed") || lower.contains("resolved") {
498            let fix_context = regex::Regex::new(r"(?i)fix(?:ed|es)?\s+(?:#|:)?\s*(\d+|[A-Z]+-\d+)")
499                .expect("valid fix regex");
500            for cap in fix_context.captures_iter(text) {
501                if let Some(m) = cap.get(1) {
502                    entities.push(EnhancedEntity::new(
503                        m.as_str().to_string(),
504                        SemanticEntityType::Fix,
505                        0.85,
506                    ));
507                }
508            }
509        }
510
511        // Feature indicators
512        if lower.contains("feature") || lower.contains("implement") || lower.contains("add") {
513            let feature_context =
514                regex::Regex::new(r"(?i)(?:feature|implement|add)\s+(\w+(?:\s+\w+)?)")
515                    .expect("valid feature regex");
516            for cap in feature_context.captures_iter(text) {
517                if let Some(m) = cap.get(1) {
518                    let feature = m.as_str().trim();
519                    if feature.len() > 2 && feature.len() < 50 {
520                        entities.push(EnhancedEntity::new(
521                            feature.to_string(),
522                            SemanticEntityType::Feature,
523                            0.6,
524                        ));
525                    }
526                }
527            }
528        }
529
530        entities
531    }
532
533    /// Build the entity extraction prompt
534    fn build_entity_prompt(&self, text: &str) -> String {
535        format!(
536            r#"Extract named entities from this text. Focus on:
537- Code elements: files, functions, types, modules, packages
538- Domain concepts: patterns, algorithms, protocols
539- Problems: errors, bugs, issues
540- Actions: commands, operations, tasks
541
542Text: "{}"
543
544Output format (one per line):
545TYPE: name
546
547Example:
548FUNCTION: process_data
549ERROR: AuthenticationError
550CONCEPT: dependency injection
551
552Entities:"#,
553            if text.len() > 500 { &text[..500] } else { text }
554        )
555    }
556
557    /// Build the relationship extraction prompt
558    fn build_relationship_prompt(&self, entities: &[String], context: &str) -> String {
559        let entity_list = entities
560            .iter()
561            .take(10)
562            .cloned()
563            .collect::<Vec<_>>()
564            .join(", ");
565
566        format!(
567            r#"Given these entities: [{}]
568
569And this context: "{}"
570
571Identify relationships between entities. Types:
572- CONTAINS: A contains B
573- USES: A uses B
574- CALLS: A calls B
575- DEPENDS_ON: A depends on B
576- MODIFIES: A modifies B
577- FIXES: A fixes B
578
579Output format (one per line):
580FROM -> RELATION -> TO
581
582Relationships:"#,
583            entity_list,
584            if context.len() > 300 {
585                &context[..300]
586            } else {
587                context
588            }
589        )
590    }
591
592    /// Build the concept extraction prompt
593    fn build_concept_prompt(&self, text: &str) -> String {
594        format!(
595            r#"Extract domain concepts and technical terms from this text.
596Focus on: frameworks, patterns, methodologies, technologies, abstractions.
597
598Text: "{}"
599
600Output: comma-separated list of concepts
601Example: REST API, dependency injection, authentication
602
603Concepts:"#,
604            if text.len() > 400 { &text[..400] } else { text }
605        )
606    }
607
608    /// Parse entity extraction output
609    fn parse_entities(&self, output: &str) -> Vec<EnhancedEntity> {
610        let mut entities = Vec::new();
611
612        for line in output.lines() {
613            let line = line.trim();
614            if line.is_empty() {
615                continue;
616            }
617
618            // Parse "TYPE: name" format
619            if let Some((type_str, name)) = line.split_once(':') {
620                let type_str = type_str.trim().to_uppercase();
621                let name = name.trim();
622
623                if name.is_empty() {
624                    continue;
625                }
626
627                let entity_type = match type_str.as_str() {
628                    "FILE" => SemanticEntityType::File,
629                    "FUNCTION" | "FUNC" | "FN" => SemanticEntityType::Function,
630                    "TYPE" | "STRUCT" | "CLASS" => SemanticEntityType::Type,
631                    "VARIABLE" | "VAR" => SemanticEntityType::Variable,
632                    "MODULE" | "MOD" => SemanticEntityType::Module,
633                    "PACKAGE" | "CRATE" => SemanticEntityType::Package,
634                    "CONCEPT" => SemanticEntityType::Concept,
635                    "PATTERN" => SemanticEntityType::Pattern,
636                    "ALGORITHM" => SemanticEntityType::Algorithm,
637                    "PROTOCOL" => SemanticEntityType::Protocol,
638                    "COMMAND" | "CMD" => SemanticEntityType::Command,
639                    "OPERATION" => SemanticEntityType::Operation,
640                    "TASK" => SemanticEntityType::Task,
641                    "ERROR" => SemanticEntityType::Error,
642                    "BUG" => SemanticEntityType::Bug,
643                    "FIX" => SemanticEntityType::Fix,
644                    "FEATURE" => SemanticEntityType::Feature,
645                    "PERSON" | "USER" => SemanticEntityType::Person,
646                    "URL" | "LINK" => SemanticEntityType::Url,
647                    "PATH" => SemanticEntityType::Path,
648                    _ => continue,
649                };
650
651                entities.push(EnhancedEntity::new(name.to_string(), entity_type, 0.8));
652
653                if entities.len() >= self.max_entities {
654                    break;
655                }
656            }
657        }
658
659        entities
660    }
661
662    /// Parse relationship extraction output
663    fn parse_relationships(&self, output: &str) -> Vec<EnhancedRelationship> {
664        let mut relationships = Vec::new();
665
666        for line in output.lines() {
667            let line = line.trim();
668            if line.is_empty() {
669                continue;
670            }
671
672            // Parse "FROM -> RELATION -> TO" format
673            let parts: Vec<&str> = line.split("->").map(|s| s.trim()).collect();
674            if parts.len() >= 3 {
675                let from = parts[0].to_string();
676                let relation_str = parts[1].to_uppercase();
677                let to = parts[2].to_string();
678
679                let relation_type = match relation_str.as_str() {
680                    "CONTAINS" => RelationType::Contains,
681                    "DEFINED_IN" | "DEFINEDIN" => RelationType::DefinedIn,
682                    "IMPORTS" => RelationType::Imports,
683                    "EXPORTS" => RelationType::Exports,
684                    "EXTENDS" => RelationType::Extends,
685                    "IMPLEMENTS" => RelationType::Implements,
686                    "CALLS" => RelationType::Calls,
687                    "USES" => RelationType::Uses,
688                    "MODIFIES" => RelationType::Modifies,
689                    "CREATES" => RelationType::Creates,
690                    "DELETES" => RelationType::Deletes,
691                    "RELATED_TO" | "RELATEDTO" => RelationType::RelatedTo,
692                    "SIMILAR_TO" | "SIMILARTO" => RelationType::SimilarTo,
693                    "DEPENDS_ON" | "DEPENDSON" => RelationType::DependsOn,
694                    "CAUSES" => RelationType::Causes,
695                    "FIXES" => RelationType::Fixes,
696                    "REPLACES" => RelationType::Replaces,
697                    _ => RelationType::RelatedTo, // Default
698                };
699
700                relationships.push(EnhancedRelationship {
701                    from,
702                    to,
703                    relation_type,
704                    confidence: 0.75,
705                });
706            }
707        }
708
709        relationships
710    }
711
712    /// Parse concept extraction output
713    fn parse_concepts(&self, output: &str) -> Vec<String> {
714        let mut concepts = Vec::new();
715
716        // Handle comma-separated list
717        for concept in output.split(',') {
718            let concept = concept.trim().to_lowercase();
719            if !concept.is_empty() && concept.len() > 2 && concept.len() < 50 {
720                concepts.push(concept);
721            }
722        }
723
724        // Also handle newline-separated
725        if concepts.is_empty() {
726            for line in output.lines() {
727                let concept = line.trim().to_lowercase();
728                if !concept.is_empty() && concept.len() > 2 && concept.len() < 50 {
729                    concepts.push(concept);
730                }
731            }
732        }
733
734        concepts
735    }
736}
737
738/// Builder for EntityEnhancer
739pub struct EntityEnhancerBuilder {
740    provider: Option<Arc<dyn Provider>>,
741    model_id: String,
742    min_confidence: f32,
743    max_entities: usize,
744}
745
746impl Default for EntityEnhancerBuilder {
747    fn default() -> Self {
748        Self {
749            provider: None,
750            model_id: "lfm2-350m".to_string(), // Fast model for entity extraction
751            min_confidence: 0.6,
752            max_entities: 20,
753        }
754    }
755}
756
757impl EntityEnhancerBuilder {
758    /// Create a new builder with default settings.
759    pub fn new() -> Self {
760        Self::default()
761    }
762
763    /// Set the provider to use for entity extraction.
764    pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
765        self.provider = Some(provider);
766        self
767    }
768
769    /// Set the model ID to use for inference.
770    pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
771        self.model_id = model_id.into();
772        self
773    }
774
775    /// Set the minimum confidence threshold for extracted entities.
776    pub fn min_confidence(mut self, confidence: f32) -> Self {
777        self.min_confidence = confidence.clamp(0.0, 1.0);
778        self
779    }
780
781    /// Set the maximum number of entities to extract per call.
782    pub fn max_entities(mut self, max: usize) -> Self {
783        self.max_entities = max.max(1);
784        self
785    }
786
787    /// Build the entity enhancer, returning `None` if no provider was set.
788    pub fn build(self) -> Option<EntityEnhancer> {
789        self.provider.map(|p| {
790            EntityEnhancer::new(p, self.model_id)
791                .with_min_confidence(self.min_confidence)
792                .with_max_entities(self.max_entities)
793        })
794    }
795}
796
797#[cfg(test)]
798mod tests {
799    use super::*;
800
801    #[test]
802    fn test_semantic_entity_type_parsing() {
803        assert_eq!(
804            SemanticEntityType::from_str("function"),
805            Some(SemanticEntityType::Function)
806        );
807        assert_eq!(
808            SemanticEntityType::from_str("STRUCT"),
809            Some(SemanticEntityType::Type)
810        );
811        assert_eq!(
812            SemanticEntityType::from_str("crate"),
813            Some(SemanticEntityType::Package)
814        );
815        assert_eq!(SemanticEntityType::from_str("invalid"), None);
816    }
817
818    #[test]
819    fn test_relation_type_parsing() {
820        assert_eq!(
821            RelationType::from_str("contains"),
822            Some(RelationType::Contains)
823        );
824        assert_eq!(
825            RelationType::from_str("DEPENDS_ON"),
826            Some(RelationType::DependsOn)
827        );
828        assert_eq!(RelationType::from_str("invalid"), None);
829    }
830
831    #[test]
832    fn test_heuristic_extraction_url() {
833        let _enhancer = EntityEnhancerBuilder::default();
834        let result = extract_heuristic_direct("Check https://example.com/docs for more info");
835        assert!(
836            result
837                .iter()
838                .any(|e| e.entity_type == SemanticEntityType::Url)
839        );
840    }
841
842    #[test]
843    fn test_heuristic_extraction_path() {
844        let result = extract_heuristic_direct("Look at /home/user/project/src");
845        assert!(
846            result
847                .iter()
848                .any(|e| e.entity_type == SemanticEntityType::Path)
849        );
850    }
851
852    #[test]
853    fn test_heuristic_extraction_bug() {
854        // "Fixed #123" should match the fix pattern
855        let result = extract_heuristic_direct("Fixed #123 in the parser");
856        assert!(
857            result
858                .iter()
859                .any(|e| e.entity_type == SemanticEntityType::Fix)
860        );
861    }
862
863    fn extract_heuristic_direct(text: &str) -> Vec<EnhancedEntity> {
864        let mut entities = Vec::new();
865
866        // URL pattern
867        let url_pattern = regex::Regex::new(r#"https?://[^\s<>"']+"#).unwrap();
868        for cap in url_pattern.find_iter(text) {
869            entities.push(EnhancedEntity::new(
870                cap.as_str().to_string(),
871                SemanticEntityType::Url,
872                0.9,
873            ));
874        }
875
876        // Path pattern
877        let path_pattern = regex::Regex::new(r#"(?:^|[\s"'])(/[a-zA-Z0-9_./-]+)"#).unwrap();
878        for cap in path_pattern.captures_iter(text) {
879            if let Some(m) = cap.get(1) {
880                let path = m.as_str();
881                if path.len() > 3 && !path.starts_with("//") {
882                    entities.push(EnhancedEntity::new(
883                        path.to_string(),
884                        SemanticEntityType::Path,
885                        0.7,
886                    ));
887                }
888            }
889        }
890
891        // Fix pattern
892        let lower = text.to_lowercase();
893        if lower.contains("fix") {
894            let fix_context =
895                regex::Regex::new(r"(?i)fix(?:ed|es)?\s+(?:#|:)?\s*(\d+|[A-Z]+-\d+)").unwrap();
896            for cap in fix_context.captures_iter(text) {
897                if let Some(m) = cap.get(1) {
898                    entities.push(EnhancedEntity::new(
899                        m.as_str().to_string(),
900                        SemanticEntityType::Fix,
901                        0.85,
902                    ));
903                }
904            }
905        }
906
907        entities
908    }
909
/// `TYPE: name` lines are parsed into typed entities, one per recognised line.
#[test]
fn test_parse_entities() {
    let llm_output =
        "FUNCTION: process_data\nERROR: AuthenticationError\nCONCEPT: dependency injection";

    let parsed = parse_entities_direct(llm_output);
    assert_eq!(parsed.len(), 3);

    let has_function_name = parsed.iter().any(|entity| entity.name == "process_data");
    let has_error_type = parsed
        .iter()
        .any(|entity| entity.entity_type == SemanticEntityType::Error);
    assert!(has_function_name);
    assert!(has_error_type);
}
925
926    fn parse_entities_direct(output: &str) -> Vec<EnhancedEntity> {
927        let mut entities = Vec::new();
928
929        for line in output.lines() {
930            let line = line.trim();
931            if line.is_empty() {
932                continue;
933            }
934
935            if let Some((type_str, name)) = line.split_once(':') {
936                let type_str = type_str.trim().to_uppercase();
937                let name = name.trim();
938
939                if name.is_empty() {
940                    continue;
941                }
942
943                let entity_type = match type_str.as_str() {
944                    "FUNCTION" => SemanticEntityType::Function,
945                    "ERROR" => SemanticEntityType::Error,
946                    "CONCEPT" => SemanticEntityType::Concept,
947                    _ => continue,
948                };
949
950                entities.push(EnhancedEntity::new(name.to_string(), entity_type, 0.8));
951            }
952        }
953
954        entities
955    }
956
/// `from -> RELATION -> to` lines each produce one relationship.
#[test]
fn test_parse_relationships() {
    let output = [
        "process_data -> CALLS -> validate_input",
        "Module -> CONTAINS -> Function",
    ]
    .join("\n");

    let rels = parse_relationships_direct(&output);
    assert_eq!(rels.len(), 2);

    let has_calls = rels
        .iter()
        .any(|rel| rel.relation_type == RelationType::Calls);
    assert!(has_calls);
}
969
970    fn parse_relationships_direct(output: &str) -> Vec<EnhancedRelationship> {
971        let mut relationships = Vec::new();
972
973        for line in output.lines() {
974            let parts: Vec<&str> = line.split("->").map(|s| s.trim()).collect();
975            if parts.len() >= 3 {
976                let from = parts[0].to_string();
977                let relation_str = parts[1].to_uppercase();
978                let to = parts[2].to_string();
979
980                let relation_type = match relation_str.as_str() {
981                    "CALLS" => RelationType::Calls,
982                    "CONTAINS" => RelationType::Contains,
983                    _ => RelationType::RelatedTo,
984                };
985
986                relationships.push(EnhancedRelationship {
987                    from,
988                    to,
989                    relation_type,
990                    confidence: 0.75,
991                });
992            }
993        }
994
995        relationships
996    }
997
/// Comma-separated concepts are trimmed and lowercased.
#[test]
fn test_parse_concepts() {
    let concepts = parse_concepts_direct("REST API, dependency injection, authentication");
    assert_eq!(concepts.len(), 3);
    assert!(concepts.iter().any(|c| c.as_str() == "rest api"));
}
1005
/// Test-local parser for a comma-separated concept list.
///
/// Each item is trimmed and lowercased. Items are kept only when their length in
/// bytes (not chars) is between 3 and 49 inclusive, which also drops empty items.
fn parse_concepts_direct(output: &str) -> Vec<String> {
    output
        .split(',')
        .map(|item| item.trim().to_lowercase())
        .filter(|item| (3..50).contains(&item.len()))
        .collect()
}
1018}