Skip to main content

mem7_graph/
extraction.rs

1use mem7_error::{Mem7Error, Result};
2use mem7_llm::{LlmClient, LlmMessage, ResponseFormat};
3use serde::Deserialize;
4use tracing::debug;
5
6use crate::prompts::{
7    DELETE_RELATIONS_PROMPT, ENTITY_EXTRACTION_PROMPT, RELATION_EXTRACTION_PROMPT,
8};
9use crate::types::{Entity, GraphSearchResult, Relation};
10
11#[derive(Debug, Deserialize)]
12struct EntityExtractionOutput {
13    #[serde(default)]
14    entities: Vec<RawEntity>,
15}
16
17#[derive(Debug, Deserialize)]
18struct RawEntity {
19    entity: String,
20    entity_type: String,
21}
22
23#[derive(Debug, Deserialize)]
24struct RelationExtractionOutput {
25    #[serde(default)]
26    relations: Vec<RawRelation>,
27}
28
29#[derive(Debug, Deserialize)]
30struct RawRelation {
31    source: String,
32    relationship: String,
33    destination: String,
34}
35
36/// Extract entities from a conversation using the LLM.
37pub async fn extract_entities(
38    llm: &dyn LlmClient,
39    conversation: &str,
40    custom_prompt: Option<&str>,
41) -> Result<Vec<Entity>> {
42    let prompt = custom_prompt.unwrap_or(ENTITY_EXTRACTION_PROMPT);
43
44    let messages = vec![LlmMessage::system(prompt), LlmMessage::user(conversation)];
45
46    let response = llm
47        .chat_completion(&messages, Some(&ResponseFormat::json()))
48        .await?;
49
50    debug!(raw = %response.content, "entity extraction response");
51
52    let output: EntityExtractionOutput = parse_json_response(&response.content)?;
53
54    Ok(output
55        .entities
56        .into_iter()
57        .map(|e| Entity {
58            name: e.entity,
59            entity_type: e.entity_type,
60            embedding: None,
61            created_at: None,
62            mentions: 0,
63        })
64        .collect())
65}
66
67/// Extract relations between entities from a conversation using the LLM.
68pub async fn extract_relations(
69    llm: &dyn LlmClient,
70    conversation: &str,
71    entities: &[Entity],
72    custom_prompt: Option<&str>,
73) -> Result<Vec<Relation>> {
74    if entities.is_empty() {
75        return Ok(Vec::new());
76    }
77
78    let prompt = custom_prompt.unwrap_or(RELATION_EXTRACTION_PROMPT);
79
80    let entity_names: Vec<&str> = entities.iter().map(|e| e.name.as_str()).collect();
81    let user_input = format!(
82        "Entities: {}\nText: {}",
83        entity_names.join(", "),
84        conversation
85    );
86
87    let messages = vec![LlmMessage::system(prompt), LlmMessage::user(user_input)];
88
89    let response = llm
90        .chat_completion(&messages, Some(&ResponseFormat::json()))
91        .await?;
92
93    debug!(raw = %response.content, "relation extraction response");
94
95    let output: RelationExtractionOutput = parse_json_response(&response.content)?;
96
97    Ok(output
98        .relations
99        .into_iter()
100        .map(|r| Relation {
101            source: r.source,
102            relationship: r.relationship,
103            destination: r.destination,
104            created_at: None,
105            mentions: 0,
106            valid: true,
107        })
108        .collect())
109}
110
111#[derive(Debug, Deserialize)]
112struct DeletionOutput {
113    #[serde(default)]
114    deletions: Vec<RawRelation>,
115}
116
117/// Ask the LLM which existing relations should be invalidated given new data.
118/// Returns triples `(source, relationship, destination)` to soft-delete.
119pub async fn extract_deletions(
120    llm: &dyn LlmClient,
121    existing: &[GraphSearchResult],
122    new_data: &str,
123) -> Result<Vec<(String, String, String)>> {
124    if existing.is_empty() {
125        return Ok(Vec::new());
126    }
127
128    let existing_str = existing
129        .iter()
130        .map(|r| format!("{} -- {} -- {}", r.source, r.relationship, r.destination))
131        .collect::<Vec<_>>()
132        .join("\n");
133
134    let user_msg =
135        format!("Here are the existing memories:\n{existing_str}\n\nNew Information:\n{new_data}");
136
137    let messages = vec![
138        LlmMessage::system(DELETE_RELATIONS_PROMPT),
139        LlmMessage::user(user_msg),
140    ];
141
142    let response = llm
143        .chat_completion(&messages, Some(&ResponseFormat::json()))
144        .await?;
145
146    debug!(raw = %response.content, "deletion extraction response");
147
148    let output: DeletionOutput = parse_json_response(&response.content)?;
149
150    Ok(output
151        .deletions
152        .into_iter()
153        .map(|d| (d.source, d.relationship, d.destination))
154        .collect())
155}
156
157fn parse_json_response<T: serde::de::DeserializeOwned>(raw: &str) -> Result<T> {
158    mem7_core::parse_json_response(raw).map_err(Mem7Error::Graph)
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_entity_response() {
        let raw = r#"{"entities": [{"entity": "Alice", "entity_type": "Person"}, {"entity": "tennis", "entity_type": "Activity"}]}"#;
        let output: EntityExtractionOutput = parse_json_response(raw).unwrap();
        assert_eq!(output.entities.len(), 2);
        assert_eq!(output.entities[0].entity, "Alice");
        assert_eq!(output.entities[1].entity_type, "Activity");
    }

    #[test]
    fn parse_entity_response_with_code_fence() {
        let raw =
            "```json\n{\"entities\": [{\"entity\": \"Bob\", \"entity_type\": \"Person\"}]}\n```";
        let output: EntityExtractionOutput = parse_json_response(raw).unwrap();
        assert_eq!(output.entities.len(), 1);
        assert_eq!(output.entities[0].entity, "Bob");
    }

    #[test]
    fn parse_empty_entities() {
        let raw = r#"{"entities": []}"#;
        let output: EntityExtractionOutput = parse_json_response(raw).unwrap();
        assert!(output.entities.is_empty());
    }

    #[test]
    fn parse_relation_response() {
        let raw = r#"{"relations": [{"source": "USER", "relationship": "works_at", "destination": "Google"}]}"#;
        let output: RelationExtractionOutput = parse_json_response(raw).unwrap();
        assert_eq!(output.relations.len(), 1);
        assert_eq!(output.relations[0].source, "USER");
        assert_eq!(output.relations[0].relationship, "works_at");
        assert_eq!(output.relations[0].destination, "Google");
    }

    #[test]
    fn parse_empty_relations() {
        let raw = r#"{"relations": []}"#;
        let output: RelationExtractionOutput = parse_json_response(raw).unwrap();
        assert!(output.relations.is_empty());
    }

    #[test]
    fn parse_deletion_response() {
        let raw = r#"{"deletions": [{"source": "USER", "relationship": "lives_in", "destination": "Paris"}]}"#;
        let output: DeletionOutput = parse_json_response(raw).unwrap();
        assert_eq!(output.deletions.len(), 1);
        assert_eq!(output.deletions[0].source, "USER");
        assert_eq!(output.deletions[0].relationship, "lives_in");
        assert_eq!(output.deletions[0].destination, "Paris");
    }

    #[test]
    fn parse_missing_top_level_keys_default_to_empty() {
        // Each output struct marks its list `#[serde(default)]`, so an object
        // without the expected key must parse to an empty list, not an error.
        let entities: EntityExtractionOutput = parse_json_response("{}").unwrap();
        assert!(entities.entities.is_empty());

        let relations: RelationExtractionOutput = parse_json_response("{}").unwrap();
        assert!(relations.relations.is_empty());

        let deletions: DeletionOutput = parse_json_response("{}").unwrap();
        assert!(deletions.deletions.is_empty());
    }
}