//! Extraction pipeline — types and configuration.
//!
//! The extraction pipeline converts raw conversation messages into structured
//! facts, entities, and relationships.
5
6use crate::fact::MemoryTier;
7use serde::{Deserialize, Serialize};
8
9/// A single message in a conversation (role + content).
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Message {
12    pub role: String,
13    pub content: String,
14}
15
16impl Message {
17    pub fn user(content: impl Into<String>) -> Self {
18        Self {
19            role: "user".into(),
20            content: content.into(),
21        }
22    }
23    pub fn assistant(content: impl Into<String>) -> Self {
24        Self {
25            role: "assistant".into(),
26            content: content.into(),
27        }
28    }
29}
30
31/// A fact extracted by the LLM (before storage — no ID yet).
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct ExtractedFact {
34    /// The human-readable fact text.
35    pub text: String,
36    /// Extracted entities mentioned in this fact.
37    #[serde(default)]
38    pub entities: Vec<ExtractedEntity>,
39    /// Extracted relationships between entities.
40    #[serde(default)]
41    pub relationships: Vec<ExtractedRelationship>,
42    /// Confidence score (0.0-1.0).
43    #[serde(default = "default_confidence")]
44    pub confidence: f64,
45    /// Optional category.
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub category: Option<String>,
48}
49
/// Serde fallback for a fact's confidence: the top of the 0.0–1.0 range.
fn default_confidence() -> f64 {
    1_f64
}
53
54/// An entity extracted by the LLM (before storage — no UUID yet).
55///
56/// Deserialization is lenient: the LLM may return either a bare string
57/// `"alice"` (which becomes `{ name: "alice", entity_type: None }`) or a
58/// full object `{ "name": "alice", "entity_type": "person" }`. This makes
59/// the pipeline robust against smaller open models that sometimes drop
60/// structured fields.
61#[derive(Debug, Clone, Serialize, Deserialize)]
62#[serde(from = "ExtractedEntityInput")]
63pub struct ExtractedEntity {
64    /// Canonical name (e.g., "Austin", "Max").
65    pub name: String,
66    /// Entity type (e.g., "person", "place", "pet").
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub entity_type: Option<String>,
69}
70
71/// Internal — lenient deserialization target. Accepts either a bare string
72/// or a struct, then converts into `ExtractedEntity`.
73#[derive(Debug, Clone, Deserialize)]
74#[serde(untagged)]
75enum ExtractedEntityInput {
76    /// Bare string form — `"alice"`.
77    Name(String),
78    /// Full struct form — `{ "name": "alice", "entity_type": "person" }`.
79    Full {
80        name: String,
81        #[serde(default)]
82        entity_type: Option<String>,
83    },
84}
85
86impl From<ExtractedEntityInput> for ExtractedEntity {
87    fn from(input: ExtractedEntityInput) -> Self {
88        match input {
89            ExtractedEntityInput::Name(name) => Self {
90                name,
91                entity_type: None,
92            },
93            ExtractedEntityInput::Full { name, entity_type } => Self { name, entity_type },
94        }
95    }
96}
97
98/// A relationship extracted by the LLM (before storage).
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ExtractedRelationship {
101    /// Source entity name.
102    pub source: String,
103    /// Relationship type (e.g., "lives_in", "owns_pet").
104    pub relation: String,
105    /// Target entity name.
106    pub target: String,
107}
108
109/// The full output of the extraction pipeline.
110#[derive(Debug, Clone, Default, Serialize, Deserialize)]
111pub struct ExtractionResult {
112    /// Extracted facts.
113    pub facts: Vec<ExtractedFact>,
114}
115
116/// Configuration for the extraction pipeline.
117#[derive(Debug, Clone)]
118pub struct ExtractionConfig {
119    /// Custom extraction prompt appended to the system prompt.
120    pub custom_prompt: Option<String>,
121    /// Categories to skip (facts with these categories are discarded).
122    pub skip_categories: Vec<String>,
123    /// Rules for specific categories.
124    pub rules: Vec<ExtractionRule>,
125    /// Similarity threshold for dedup (0.0-1.0). Default: 0.92.
126    pub dedup_threshold: f32,
127    /// Default tier for extracted facts.
128    pub default_tier: MemoryTier,
129}
130
131impl Default for ExtractionConfig {
132    fn default() -> Self {
133        Self {
134            custom_prompt: None,
135            skip_categories: Vec::new(),
136            rules: Vec::new(),
137            dedup_threshold: 0.92,
138            default_tier: MemoryTier::Conversation,
139        }
140    }
141}
142
143/// A rule for a specific category of extracted facts.
144#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct ExtractionRule {
146    pub category: String,
147    /// Priority weight (higher = more important). Default: 1.0.
148    #[serde(default = "default_priority")]
149    pub priority: f64,
150    /// Time-to-live before decay. None = permanent.
151    #[serde(skip_serializing_if = "Option::is_none")]
152    pub ttl: Option<String>,
153    /// Whether this category contains PII.
154    #[serde(default)]
155    pub pii: bool,
156}
157
/// Serde fallback for a rule's priority when the key is absent.
fn default_priority() -> f64 {
    1_f64
}
161
/// Outcome of comparing a newly extracted fact against an already-stored one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConflictVerdict {
    /// Same meaning as the stored fact — drop the new one (dedup).
    Duplicate,
    /// Disagrees with the stored fact — invalidate the old one, keep the new one.
    Contradicts,
    /// Adds detail to the stored fact — keep both side by side.
    Refines,
    /// Unrelated to the stored fact — keep the new one as usual.
    NoConflict,
}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extraction_config_defaults() {
        let cfg = ExtractionConfig::default();
        assert!((cfg.dedup_threshold - 0.92).abs() < f32::EPSILON);
        assert_eq!(cfg.default_tier, MemoryTier::Conversation);
        assert!(cfg.custom_prompt.is_none());
        assert!(cfg.skip_categories.is_empty());
        assert!(cfg.rules.is_empty());
    }

    #[test]
    fn message_constructors() {
        let m = Message::user("hello");
        assert_eq!(m.role, "user");
        assert_eq!(m.content, "hello");
        let m = Message::assistant("hi");
        assert_eq!(m.role, "assistant");
        assert_eq!(m.content, "hi");
    }

    #[test]
    fn extracted_fact_deserializes_with_defaults() {
        let json = r#"{"text": "User likes pizza"}"#;
        let fact: ExtractedFact = serde_json::from_str(json).unwrap();
        assert_eq!(fact.text, "User likes pizza");
        assert_eq!(fact.confidence, 1.0);
        assert!(fact.entities.is_empty());
        assert!(fact.relationships.is_empty());
        assert!(fact.category.is_none());
    }

    #[test]
    fn extracted_fact_parses_relationships() {
        let json = r#"{
            "text": "The user owns a dog named Max",
            "relationships": [{"source": "user", "relation": "owns_pet", "target": "Max"}]
        }"#;
        let fact: ExtractedFact = serde_json::from_str(json).unwrap();
        assert_eq!(fact.relationships.len(), 1);
        let rel = &fact.relationships[0];
        assert_eq!(rel.source, "user");
        assert_eq!(rel.relation, "owns_pet");
        assert_eq!(rel.target, "Max");
    }

    #[test]
    fn extracted_entity_accepts_bare_string() {
        let json = r#""alice""#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.name, "alice");
        assert!(entity.entity_type.is_none());
    }

    #[test]
    fn extracted_entity_accepts_full_struct() {
        let json = r#"{"name": "Bangalore", "entity_type": "place"}"#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.name, "Bangalore");
        assert_eq!(entity.entity_type.as_deref(), Some("place"));
    }

    #[test]
    fn extracted_entity_serializes_as_struct_and_round_trips() {
        // `#[serde(from = ...)]` only affects deserialization: serialization
        // uses the derived struct form and omits a `None` entity_type.
        let bare = ExtractedEntity {
            name: "alice".into(),
            entity_type: None,
        };
        let json = serde_json::to_string(&bare).unwrap();
        assert_eq!(json, r#"{"name":"alice"}"#);
        let back: ExtractedEntity = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, "alice");
        assert!(back.entity_type.is_none());

        let typed = ExtractedEntity {
            name: "Max".into(),
            entity_type: Some("pet".into()),
        };
        let json = serde_json::to_string(&typed).unwrap();
        let back: ExtractedEntity = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, "Max");
        assert_eq!(back.entity_type.as_deref(), Some("pet"));
    }

    #[test]
    fn extracted_fact_accepts_mixed_entity_shapes() {
        // A real failure case: llama3.2:3b often returns entities as a
        // mix of bare strings and partial objects in the same list.
        let json = r#"{
            "text": "The user is allergic to peanuts",
            "entities": ["user", {"name": "peanuts", "entity_type": "thing"}],
            "confidence": 0.95,
            "category": "health"
        }"#;
        let fact: ExtractedFact = serde_json::from_str(json).unwrap();
        assert_eq!(fact.entities.len(), 2);
        assert_eq!(fact.entities[0].name, "user");
        assert!(fact.entities[0].entity_type.is_none());
        assert_eq!(fact.entities[1].name, "peanuts");
        assert_eq!(fact.entities[1].entity_type.as_deref(), Some("thing"));
    }

    #[test]
    fn extraction_rule_deserializes_with_defaults() {
        let rule: ExtractionRule = serde_json::from_str(r#"{"category": "health"}"#).unwrap();
        assert_eq!(rule.category, "health");
        assert_eq!(rule.priority, 1.0);
        assert!(rule.ttl.is_none());
        assert!(!rule.pii);
    }

    #[test]
    fn extraction_result_parses_fact_list() {
        let result: ExtractionResult =
            serde_json::from_str(r#"{"facts": [{"text": "a"}, {"text": "b"}]}"#).unwrap();
        assert_eq!(result.facts.len(), 2);
        assert_eq!(result.facts[1].text, "b");
        assert!(ExtractionResult::default().facts.is_empty());
    }
}