1use mem7_error::{Mem7Error, Result};
2use mem7_llm::{LlmClient, LlmMessage, ResponseFormat};
3use serde::Deserialize;
4use tracing::debug;
5
6use crate::prompts::{
7 DELETE_RELATIONS_PROMPT, ENTITY_EXTRACTION_PROMPT, RELATION_EXTRACTION_PROMPT,
8};
9use crate::types::{Entity, GraphSearchResult, Relation};
10
11#[derive(Debug, Deserialize)]
12struct EntityExtractionOutput {
13 #[serde(default)]
14 entities: Vec<RawEntity>,
15}
16
17#[derive(Debug, Deserialize)]
18struct RawEntity {
19 entity: String,
20 entity_type: String,
21}
22
23#[derive(Debug, Deserialize)]
24struct RelationExtractionOutput {
25 #[serde(default)]
26 relations: Vec<RawRelation>,
27}
28
29#[derive(Debug, Deserialize)]
30struct RawRelation {
31 source: String,
32 relationship: String,
33 destination: String,
34}
35
36pub async fn extract_entities(
38 llm: &dyn LlmClient,
39 conversation: &str,
40 custom_prompt: Option<&str>,
41) -> Result<Vec<Entity>> {
42 let prompt = custom_prompt.unwrap_or(ENTITY_EXTRACTION_PROMPT);
43
44 let messages = vec![LlmMessage::system(prompt), LlmMessage::user(conversation)];
45
46 let response = llm
47 .chat_completion(&messages, Some(&ResponseFormat::json()))
48 .await?;
49
50 debug!(raw = %response.content, "entity extraction response");
51
52 let output: EntityExtractionOutput = parse_json_response(&response.content)?;
53
54 Ok(output
55 .entities
56 .into_iter()
57 .map(|e| Entity {
58 name: e.entity,
59 entity_type: e.entity_type,
60 embedding: None,
61 created_at: None,
62 mentions: 0,
63 })
64 .collect())
65}
66
67pub async fn extract_relations(
69 llm: &dyn LlmClient,
70 conversation: &str,
71 entities: &[Entity],
72 custom_prompt: Option<&str>,
73) -> Result<Vec<Relation>> {
74 if entities.is_empty() {
75 return Ok(Vec::new());
76 }
77
78 let prompt = custom_prompt.unwrap_or(RELATION_EXTRACTION_PROMPT);
79
80 let entity_names: Vec<&str> = entities.iter().map(|e| e.name.as_str()).collect();
81 let user_input = format!(
82 "Entities: {}\nText: {}",
83 entity_names.join(", "),
84 conversation
85 );
86
87 let messages = vec![LlmMessage::system(prompt), LlmMessage::user(user_input)];
88
89 let response = llm
90 .chat_completion(&messages, Some(&ResponseFormat::json()))
91 .await?;
92
93 debug!(raw = %response.content, "relation extraction response");
94
95 let output: RelationExtractionOutput = parse_json_response(&response.content)?;
96
97 Ok(output
98 .relations
99 .into_iter()
100 .map(|r| Relation {
101 source: r.source,
102 relationship: r.relationship,
103 destination: r.destination,
104 created_at: None,
105 mentions: 0,
106 valid: true,
107 })
108 .collect())
109}
110
111#[derive(Debug, Deserialize)]
112struct DeletionOutput {
113 #[serde(default)]
114 deletions: Vec<RawRelation>,
115}
116
117pub async fn extract_deletions(
120 llm: &dyn LlmClient,
121 existing: &[GraphSearchResult],
122 new_data: &str,
123) -> Result<Vec<(String, String, String)>> {
124 if existing.is_empty() {
125 return Ok(Vec::new());
126 }
127
128 let existing_str = existing
129 .iter()
130 .map(|r| format!("{} -- {} -- {}", r.source, r.relationship, r.destination))
131 .collect::<Vec<_>>()
132 .join("\n");
133
134 let user_msg =
135 format!("Here are the existing memories:\n{existing_str}\n\nNew Information:\n{new_data}");
136
137 let messages = vec![
138 LlmMessage::system(DELETE_RELATIONS_PROMPT),
139 LlmMessage::user(user_msg),
140 ];
141
142 let response = llm
143 .chat_completion(&messages, Some(&ResponseFormat::json()))
144 .await?;
145
146 debug!(raw = %response.content, "deletion extraction response");
147
148 let output: DeletionOutput = parse_json_response(&response.content)?;
149
150 Ok(output
151 .deletions
152 .into_iter()
153 .map(|d| (d.source, d.relationship, d.destination))
154 .collect())
155}
156
157fn parse_json_response<T: serde::de::DeserializeOwned>(raw: &str) -> Result<T> {
158 mem7_core::parse_json_response(raw).map_err(Mem7Error::Graph)
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain JSON object with two entities yields both entries, in order.
    #[test]
    fn parse_entity_response() {
        let raw = r#"{"entities": [{"entity": "Alice", "entity_type": "Person"}, {"entity": "tennis", "entity_type": "Activity"}]}"#;
        let parsed = parse_json_response::<EntityExtractionOutput>(raw).unwrap();

        assert_eq!(parsed.entities.len(), 2);
        let first = &parsed.entities[0];
        let second = &parsed.entities[1];
        assert_eq!(first.entity, "Alice");
        assert_eq!(second.entity_type, "Activity");
    }

    /// A reply wrapped in a Markdown ```json fence is still parsed.
    #[test]
    fn parse_entity_response_with_code_fence() {
        let raw =
            "```json\n{\"entities\": [{\"entity\": \"Bob\", \"entity_type\": \"Person\"}]}\n```";
        let parsed = parse_json_response::<EntityExtractionOutput>(raw).unwrap();

        assert_eq!(parsed.entities.len(), 1);
        let only = &parsed.entities[0];
        assert_eq!(only.entity, "Bob");
    }

    /// An explicit empty `entities` array yields no entities.
    #[test]
    fn parse_empty_entities() {
        let raw = r#"{"entities": []}"#;
        let parsed = parse_json_response::<EntityExtractionOutput>(raw).unwrap();

        assert!(parsed.entities.is_empty());
    }

    /// A single relation triple is parsed field by field.
    #[test]
    fn parse_relation_response() {
        let raw = r#"{"relations": [{"source": "USER", "relationship": "works_at", "destination": "Google"}]}"#;
        let parsed = parse_json_response::<RelationExtractionOutput>(raw).unwrap();

        assert_eq!(parsed.relations.len(), 1);
        let only = &parsed.relations[0];
        assert_eq!(only.source, "USER");
        assert_eq!(only.relationship, "works_at");
        assert_eq!(only.destination, "Google");
    }

    /// An explicit empty `relations` array yields no relations.
    #[test]
    fn parse_empty_relations() {
        let raw = r#"{"relations": []}"#;
        let parsed = parse_json_response::<RelationExtractionOutput>(raw).unwrap();

        assert!(parsed.relations.is_empty());
    }
}