Skip to main content

engram/
extract.rs

//! Extraction pipeline — types and configuration.
//!
//! The extraction pipeline converts raw conversation messages into structured
//! facts, entities, and relationships.
5
6use crate::fact::MemoryTier;
7use serde::{Deserialize, Serialize};
8
9/// A single message in a conversation (role + content).
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Message {
12    pub role: String,
13    pub content: String,
14}
15
16impl Message {
17    pub fn user(content: impl Into<String>) -> Self {
18        Self {
19            role: "user".into(),
20            content: content.into(),
21        }
22    }
23    pub fn assistant(content: impl Into<String>) -> Self {
24        Self {
25            role: "assistant".into(),
26            content: content.into(),
27        }
28    }
29}
30
31/// A fact extracted by the LLM (before storage — no ID yet).
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct ExtractedFact {
34    /// The human-readable fact text.
35    pub text: String,
36    /// Extracted entities mentioned in this fact.
37    #[serde(default)]
38    pub entities: Vec<ExtractedEntity>,
39    /// Extracted relationships between entities.
40    #[serde(default)]
41    pub relationships: Vec<ExtractedRelationship>,
42    /// Confidence score (0.0-1.0).
43    #[serde(default = "default_confidence")]
44    pub confidence: f64,
45    /// Optional category.
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub category: Option<String>,
48}
49
/// Serde fallback for `ExtractedFact::confidence`: a fact whose payload omits
/// `confidence` is treated as fully confident.
fn default_confidence() -> f64 {
    const FULL_CONFIDENCE: f64 = 1.0;
    FULL_CONFIDENCE
}
53
54/// An entity extracted by the LLM (before storage — no UUID yet).
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct ExtractedEntity {
57    /// Canonical name (e.g., "Austin", "Max").
58    pub name: String,
59    /// Entity type (e.g., "person", "place", "pet").
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub entity_type: Option<String>,
62}
63
64/// A relationship extracted by the LLM (before storage).
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ExtractedRelationship {
67    /// Source entity name.
68    pub source: String,
69    /// Relationship type (e.g., "lives_in", "owns_pet").
70    pub relation: String,
71    /// Target entity name.
72    pub target: String,
73}
74
75/// The full output of the extraction pipeline.
76#[derive(Debug, Clone, Default, Serialize, Deserialize)]
77pub struct ExtractionResult {
78    /// Extracted facts.
79    pub facts: Vec<ExtractedFact>,
80}
81
82/// Configuration for the extraction pipeline.
83#[derive(Debug, Clone)]
84pub struct ExtractionConfig {
85    /// Custom extraction prompt appended to the system prompt.
86    pub custom_prompt: Option<String>,
87    /// Categories to skip (facts with these categories are discarded).
88    pub skip_categories: Vec<String>,
89    /// Rules for specific categories.
90    pub rules: Vec<ExtractionRule>,
91    /// Similarity threshold for dedup (0.0-1.0). Default: 0.92.
92    pub dedup_threshold: f32,
93    /// Default tier for extracted facts.
94    pub default_tier: MemoryTier,
95}
96
97impl Default for ExtractionConfig {
98    fn default() -> Self {
99        Self {
100            custom_prompt: None,
101            skip_categories: Vec::new(),
102            rules: Vec::new(),
103            dedup_threshold: 0.92,
104            default_tier: MemoryTier::Conversation,
105        }
106    }
107}
108
109/// A rule for a specific category of extracted facts.
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ExtractionRule {
112    pub category: String,
113    /// Priority weight (higher = more important). Default: 1.0.
114    #[serde(default = "default_priority")]
115    pub priority: f64,
116    /// Time-to-live before decay. None = permanent.
117    #[serde(skip_serializing_if = "Option::is_none")]
118    pub ttl: Option<String>,
119    /// Whether this category contains PII.
120    #[serde(default)]
121    pub pii: bool,
122}
123
/// Serde fallback for `ExtractionRule::priority`: a rule that omits
/// `priority` gets a baseline weight of 1.0.
fn default_priority() -> f64 {
    const BASELINE_PRIORITY: f64 = 1.0;
    BASELINE_PRIORITY
}
127
/// Result of conflict detection between a new fact and an existing one.
///
/// This is a fieldless enum, so it derives `Copy` and `Hash`: verdicts can be
/// passed by value without cloning and used as map/set keys (e.g. to tally
/// outcomes).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConflictVerdict {
    /// Identical meaning — skip the new fact (dedup).
    Duplicate,
    /// Contradicts the existing fact — invalidate old, store new.
    Contradicts,
    /// Adds new detail — store alongside existing fact.
    Refines,
    /// No conflict — store normally.
    NoConflict,
}
140
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extraction_config_defaults() {
        let cfg = ExtractionConfig::default();
        assert!((cfg.dedup_threshold - 0.92).abs() < f32::EPSILON);
        assert_eq!(cfg.default_tier, MemoryTier::Conversation);
        assert!(cfg.custom_prompt.is_none());
        assert!(cfg.skip_categories.is_empty());
        assert!(cfg.rules.is_empty());
    }

    #[test]
    fn message_constructors() {
        let m = Message::user("hello");
        assert_eq!(m.role, "user");
        assert_eq!(m.content, "hello");
        let m = Message::assistant("hi");
        assert_eq!(m.role, "assistant");
        assert_eq!(m.content, "hi");
    }

    #[test]
    fn extracted_fact_deserializes_with_defaults() {
        let json = r#"{"text": "User likes pizza"}"#;
        let fact: ExtractedFact = serde_json::from_str(json).unwrap();
        assert_eq!(fact.text, "User likes pizza");
        assert!((fact.confidence - 1.0).abs() < f64::EPSILON);
        assert!(fact.entities.is_empty());
        assert!(fact.relationships.is_empty());
        assert!(fact.category.is_none());
    }

    #[test]
    fn extracted_fact_skips_missing_category_when_serialized() {
        let fact = ExtractedFact {
            text: "User likes pizza".into(),
            entities: Vec::new(),
            relationships: Vec::new(),
            confidence: 0.5,
            category: None,
        };
        let value = serde_json::to_value(&fact).unwrap();
        // `skip_serializing_if = "Option::is_none"` must drop the key entirely.
        assert!(value.get("category").is_none());
        assert_eq!(value["text"], "User likes pizza");
    }

    #[test]
    fn extraction_rule_deserializes_with_defaults() {
        let rule: ExtractionRule = serde_json::from_str(r#"{"category": "health"}"#).unwrap();
        assert_eq!(rule.category, "health");
        assert!((rule.priority - 1.0).abs() < f64::EPSILON);
        assert!(rule.ttl.is_none());
        assert!(!rule.pii);
    }
}