1use crate::fact::MemoryTier;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Message {
12 pub role: String,
13 pub content: String,
14}
15
16impl Message {
17 pub fn user(content: impl Into<String>) -> Self {
18 Self {
19 role: "user".into(),
20 content: content.into(),
21 }
22 }
23 pub fn assistant(content: impl Into<String>) -> Self {
24 Self {
25 role: "assistant".into(),
26 content: content.into(),
27 }
28 }
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct ExtractedFact {
34 pub text: String,
36 #[serde(default)]
38 pub entities: Vec<ExtractedEntity>,
39 #[serde(default)]
41 pub relationships: Vec<ExtractedRelationship>,
42 #[serde(default = "default_confidence")]
44 pub confidence: f64,
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub category: Option<String>,
48}
49
/// Serde fallback for `ExtractedFact::confidence` when the key is missing.
///
/// Always returns `1.0`.
fn default_confidence() -> f64 {
    1.0_f64
}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
62#[serde(from = "ExtractedEntityInput")]
63pub struct ExtractedEntity {
64 pub name: String,
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub entity_type: Option<String>,
69}
70
71#[derive(Debug, Clone, Deserialize)]
74#[serde(untagged)]
75enum ExtractedEntityInput {
76 Name(String),
78 Full {
80 name: String,
81 #[serde(default)]
82 entity_type: Option<String>,
83 },
84}
85
86impl From<ExtractedEntityInput> for ExtractedEntity {
87 fn from(input: ExtractedEntityInput) -> Self {
88 match input {
89 ExtractedEntityInput::Name(name) => Self {
90 name,
91 entity_type: None,
92 },
93 ExtractedEntityInput::Full { name, entity_type } => Self { name, entity_type },
94 }
95 }
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ExtractedRelationship {
101 pub source: String,
103 pub relation: String,
105 pub target: String,
107}
108
109#[derive(Debug, Clone, Default, Serialize, Deserialize)]
111pub struct ExtractionResult {
112 pub facts: Vec<ExtractedFact>,
114}
115
116#[derive(Debug, Clone)]
118pub struct ExtractionConfig {
119 pub custom_prompt: Option<String>,
121 pub skip_categories: Vec<String>,
123 pub rules: Vec<ExtractionRule>,
125 pub dedup_threshold: f32,
127 pub default_tier: MemoryTier,
129}
130
131impl Default for ExtractionConfig {
132 fn default() -> Self {
133 Self {
134 custom_prompt: None,
135 skip_categories: Vec::new(),
136 rules: Vec::new(),
137 dedup_threshold: 0.92,
138 default_tier: MemoryTier::Conversation,
139 }
140 }
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct ExtractionRule {
146 pub category: String,
147 #[serde(default = "default_priority")]
149 pub priority: f64,
150 #[serde(skip_serializing_if = "Option::is_none")]
152 pub ttl: Option<String>,
153 #[serde(default)]
155 pub pii: bool,
156}
157
/// Serde fallback for `ExtractionRule::priority` when the key is missing.
///
/// Always returns `1.0`.
fn default_priority() -> f64 {
    1.0_f64
}
161
/// Verdict of a conflict check between facts.
///
/// NOTE(review): the comparison is presumably "new fact vs. stored fact", and the
/// per-variant meanings below are inferred from the names — confirm against the
/// code that produces these values.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ConflictVerdict {
    /// Presumably: the facts say the same thing.
    Duplicate,
    /// Presumably: the facts disagree.
    Contradicts,
    /// Presumably: one fact adds detail to the other.
    Refines,
    /// Presumably: the facts are unrelated.
    NoConflict,
}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extraction_config_defaults() {
        let cfg = ExtractionConfig::default();
        assert!(cfg.custom_prompt.is_none());
        assert!(cfg.skip_categories.is_empty());
        assert!(cfg.rules.is_empty());
        assert!((cfg.dedup_threshold - 0.92).abs() < f32::EPSILON);
        assert_eq!(cfg.default_tier, MemoryTier::Conversation);
    }

    #[test]
    fn message_constructors() {
        let m = Message::user("hello");
        assert_eq!(m.role, "user");
        assert_eq!(m.content, "hello");
        let m = Message::assistant(String::from("hi"));
        assert_eq!(m.role, "assistant");
        assert_eq!(m.content, "hi");
    }

    #[test]
    fn extracted_fact_deserializes_with_defaults() {
        let json = r#"{"text": "User likes pizza"}"#;
        let fact: ExtractedFact = serde_json::from_str(json).unwrap();
        assert_eq!(fact.text, "User likes pizza");
        assert_eq!(fact.confidence, 1.0);
        assert!(fact.entities.is_empty());
        assert!(fact.relationships.is_empty());
        assert!(fact.category.is_none());
    }

    #[test]
    fn extracted_fact_requires_text() {
        let result = serde_json::from_str::<ExtractedFact>(r#"{"confidence": 0.5}"#);
        assert!(result.is_err());
    }

    // `category` is `skip_serializing_if = "Option::is_none"`; pin that the key disappears.
    #[test]
    fn extracted_fact_omits_missing_category_when_serialized() {
        let fact: ExtractedFact = serde_json::from_str(r#"{"text": "x"}"#).unwrap();
        let json = serde_json::to_string(&fact).unwrap();
        assert!(!json.contains("category"));

        let fact: ExtractedFact =
            serde_json::from_str(r#"{"text": "x", "category": "health"}"#).unwrap();
        let json = serde_json::to_string(&fact).unwrap();
        assert!(json.contains(r#""category":"health""#));
    }

    #[test]
    fn extracted_entity_accepts_bare_string() {
        let json = r#""alice""#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.name, "alice");
        assert!(entity.entity_type.is_none());
    }

    #[test]
    fn extracted_entity_accepts_full_struct() {
        let json = r#"{"name": "Bangalore", "entity_type": "place"}"#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.name, "Bangalore");
        assert_eq!(entity.entity_type.as_deref(), Some("place"));
    }

    #[test]
    fn extracted_entity_accepts_full_struct_without_type() {
        let entity: ExtractedEntity = serde_json::from_str(r#"{"name": "Bangalore"}"#).unwrap();
        assert_eq!(entity.name, "Bangalore");
        assert!(entity.entity_type.is_none());
    }

    // Neither untagged variant accepts a number, and the object form requires `name`.
    #[test]
    fn extracted_entity_rejects_unsupported_shapes() {
        assert!(serde_json::from_str::<ExtractedEntity>("42").is_err());
        assert!(serde_json::from_str::<ExtractedEntity>(r#"{"entity_type": "place"}"#).is_err());
    }

    // Serialize is derived separately from the `from` deserialization path; pin that the
    // emitted object form is accepted back.
    #[test]
    fn extracted_entity_serializes_as_object_and_round_trips() {
        let bare = ExtractedEntity {
            name: "alice".into(),
            entity_type: None,
        };
        let json = serde_json::to_string(&bare).unwrap();
        assert_eq!(json, r#"{"name":"alice"}"#);
        let back: ExtractedEntity = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, "alice");
        assert!(back.entity_type.is_none());

        let typed = ExtractedEntity {
            name: "Bangalore".into(),
            entity_type: Some("place".into()),
        };
        let json = serde_json::to_string(&typed).unwrap();
        assert_eq!(json, r#"{"name":"Bangalore","entity_type":"place"}"#);
        let back: ExtractedEntity = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, "Bangalore");
        assert_eq!(back.entity_type.as_deref(), Some("place"));
    }

    #[test]
    fn extracted_fact_accepts_mixed_entity_shapes() {
        let json = r#"{
            "text": "The user is allergic to peanuts",
            "entities": ["user", {"name": "peanuts", "entity_type": "thing"}],
            "confidence": 0.95,
            "category": "health"
        }"#;
        let fact: ExtractedFact = serde_json::from_str(json).unwrap();
        assert_eq!(fact.entities.len(), 2);
        assert_eq!(fact.entities[0].name, "user");
        assert!(fact.entities[0].entity_type.is_none());
        assert_eq!(fact.entities[1].name, "peanuts");
        assert_eq!(fact.entities[1].entity_type.as_deref(), Some("thing"));
        assert_eq!(fact.category.as_deref(), Some("health"));
    }

    #[test]
    fn extracted_relationship_requires_all_fields() {
        let rel: ExtractedRelationship = serde_json::from_str(
            r#"{"source": "user", "relation": "allergic_to", "target": "peanuts"}"#,
        )
        .unwrap();
        assert_eq!(rel.source, "user");
        assert_eq!(rel.relation, "allergic_to");
        assert_eq!(rel.target, "peanuts");
        assert!(serde_json::from_str::<ExtractedRelationship>(r#"{"source": "user"}"#).is_err());
    }

    #[test]
    fn extraction_rule_deserializes_with_defaults() {
        let rule: ExtractionRule = serde_json::from_str(r#"{"category": "health"}"#).unwrap();
        assert_eq!(rule.category, "health");
        assert_eq!(rule.priority, 1.0);
        assert!(rule.ttl.is_none());
        assert!(!rule.pii);
    }

    #[test]
    fn extraction_result_default_is_empty() {
        assert!(ExtractionResult::default().facts.is_empty());
    }
}