1use serde::{Deserialize, Serialize};
6
/// Prompt template for the first entity/relationship extraction pass.
///
/// Placeholders: `{entity_types}`, `{tuple_delimiter}` and `{input_text}`.
///
/// The JSON example is written with single braces on purpose: the template is
/// filled by literal placeholder substitution, not by a `format!`-style engine,
/// so doubled braces (`{{` / `}}`) would never be un-escaped and the model
/// would be shown malformed JSON.
pub const ENTITY_EXTRACTION_PROMPT: &str = r#"-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [{entity_types}]
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)

2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)

3. Return output in JSON format with the following structure:
{
  "entities": [
    {
      "name": "entity name",
      "type": "entity type",
      "description": "entity description"
    }
  ],
  "relationships": [
    {
      "source": "source entity name",
      "target": "target entity name",
      "description": "relationship description",
      "strength": 0.8
    }
  ]
}

-Real Data-
######################
Entity Types: {entity_types}
Text: {input_text}
######################
Output:
"#;
52
/// Prompt template for a follow-up ("gleaning") pass that asks the model for
/// entities and relationships it missed the first time.
///
/// Placeholders: `{previous_entities}`, `{previous_relationships}`,
/// `{entity_types}` and `{input_text}`.
///
/// The JSON example is written with single braces on purpose: the template is
/// filled by literal placeholder substitution, not by a `format!`-style engine,
/// so doubled braces (`{{` / `}}`) would never be un-escaped and the model
/// would be shown malformed JSON.
pub const GLEANING_CONTINUATION_PROMPT: &str = r#"-Goal-
You previously extracted entities and relationships from a text document. Review your previous extraction and the original text to identify any additional entities or relationships you may have missed in the first pass.

-Steps-
1. Review the entities you previously identified:
{previous_entities}

2. Review the relationships you previously identified:
{previous_relationships}

3. Carefully review the original text again and identify:
- Any entities you may have missed
- Any relationships between entities you may have overlooked
- Any entities that need better descriptions

4. Return ONLY the NEW entities and relationships you discovered in this pass, using the same JSON format:
{
  "entities": [
    {
      "name": "entity name",
      "type": "entity type",
      "description": "entity description"
    }
  ],
  "relationships": [
    {
      "source": "source entity name",
      "target": "target entity name",
      "description": "relationship description",
      "strength": 0.8
    }
  ]
}

If you found no additional entities or relationships, return empty arrays.

-Real Data-
######################
Entity Types: {entity_types}
Text: {input_text}
######################
Output:
"#;
97
/// Prompt template asking the model whether the extraction so far is complete.
///
/// Placeholders: `{input_text}`, `{entity_count}`, `{entities_summary}`,
/// `{relationship_count}` and `{relationships_summary}`.
/// The prompt instructs the model to answer with only "YES" or "NO".
pub const COMPLETION_CHECK_PROMPT: &str = r#"Based on the text below and the entities/relationships already extracted, are there any significant entities or relationships that have been missed?

Text:
{input_text}

Current Entities ({entity_count}):
{entities_summary}

Current Relationships ({relationship_count}):
{relationships_summary}

Think carefully about:
1. Are all important characters, people, organizations mentioned in the text captured?
2. Are all significant locations, places, settings identified?
3. Are all key events, objects, concepts extracted?
4. Are all meaningful relationships between entities documented?

Respond with ONLY "YES" if the extraction is complete and thorough, or "NO" if there are still significant entities or relationships missing.

Answer (YES or NO):"#;
119
/// JSON Schema (as text) for the extraction output.
///
/// Describes an object with two required arrays: `entities`, whose items
/// require string `name`, `type` and `description`, and `relationships`, whose
/// items require string `source`, `target`, `description` and a numeric
/// `strength`.
pub const ENTITY_EXTRACTION_JSON_SCHEMA: &str = r#"{
 "type": "object",
 "properties": {
 "entities": {
 "type": "array",
 "items": {
 "type": "object",
 "properties": {
 "name": {"type": "string"},
 "type": {"type": "string"},
 "description": {"type": "string"}
 },
 "required": ["name", "type", "description"]
 }
 },
 "relationships": {
 "type": "array",
 "items": {
 "type": "object",
 "properties": {
 "source": {"type": "string"},
 "target": {"type": "string"},
 "description": {"type": "string"},
 "strength": {"type": "number"}
 },
 "required": ["source", "target", "description", "strength"]
 }
 }
 },
 "required": ["entities", "relationships"]
}"#;
152
/// Entities and relationships produced by one extraction pass.
///
/// The field names match the top-level keys of the JSON structure requested
/// in the prompts (`entities`, `relationships`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionOutput {
    /// Extracted entities.
    pub entities: Vec<EntityData>,
    /// Extracted relationships between entities.
    pub relationships: Vec<RelationshipData>,
}
164
/// A single extracted entity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityData {
    /// Entity name.
    pub name: String,
    /// Entity type; serialized and deserialized under the JSON key `type`
    /// (`type` is a Rust keyword, hence the rename).
    #[serde(rename = "type")]
    pub entity_type: String,
    /// Entity description; becomes an empty string when the key is absent
    /// during deserialization.
    #[serde(default)]
    pub description: String,
}
180
/// A single extracted relationship between two entities.
///
/// All four fields are required when deserializing (none has a serde default).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelationshipData {
    /// Name of the source entity.
    pub source: String,
    /// Name of the target entity.
    pub target: String,
    /// Explanation of how the two entities are related.
    pub description: String,
    /// Numeric strength score (the prompts' JSON example shows `0.8`).
    pub strength: f64,
}
196
/// Fills the prompt templates in this file with concrete values.
pub struct PromptBuilder {
    /// Entity type names; joined with `", "` when inserted into a prompt.
    entity_types: Vec<String>,
    /// Value substituted for `{tuple_delimiter}` in the extraction prompt.
    tuple_delimiter: String,
}
202
203impl PromptBuilder {
204 pub fn new(entity_types: Vec<String>) -> Self {
206 Self {
207 entity_types,
208 tuple_delimiter: "|".to_string(),
209 }
210 }
211
212 pub fn build_extraction_prompt(&self, text: &str) -> String {
214 let entity_types_str = self.entity_types.join(", ");
215
216 ENTITY_EXTRACTION_PROMPT
217 .replace("{entity_types}", &entity_types_str)
218 .replace("{tuple_delimiter}", &self.tuple_delimiter)
219 .replace("{input_text}", text)
220 }
221
222 pub fn build_continuation_prompt(
224 &self,
225 text: &str,
226 previous_entities: &[EntityData],
227 previous_relationships: &[RelationshipData],
228 ) -> String {
229 let entity_types_str = self.entity_types.join(", ");
230
231 let entities_summary = previous_entities
233 .iter()
234 .map(|e| format!("- {} ({}): {}", e.name, e.entity_type, e.description))
235 .collect::<Vec<_>>()
236 .join("\n");
237
238 let relationships_summary = previous_relationships
240 .iter()
241 .map(|r| {
242 format!(
243 "- {} -> {}: {} (strength: {:.2})",
244 r.source, r.target, r.description, r.strength
245 )
246 })
247 .collect::<Vec<_>>()
248 .join("\n");
249
250 GLEANING_CONTINUATION_PROMPT
251 .replace("{entity_types}", &entity_types_str)
252 .replace("{input_text}", text)
253 .replace("{previous_entities}", &entities_summary)
254 .replace("{previous_relationships}", &relationships_summary)
255 }
256
257 pub fn build_completion_prompt(
259 &self,
260 text: &str,
261 entities: &[EntityData],
262 relationships: &[RelationshipData],
263 ) -> String {
264 let entities_summary = entities
266 .iter()
267 .take(20) .map(|e| format!("- {} ({})", e.name, e.entity_type))
269 .collect::<Vec<_>>()
270 .join("\n");
271
272 let entities_summary = if entities.len() > 20 {
273 format!(
274 "{}...\n(showing 20 of {} entities)",
275 entities_summary,
276 entities.len()
277 )
278 } else {
279 entities_summary
280 };
281
282 let relationships_summary = relationships
284 .iter()
285 .take(20) .map(|r| format!("- {} -> {}", r.source, r.target))
287 .collect::<Vec<_>>()
288 .join("\n");
289
290 let relationships_summary = if relationships.len() > 20 {
291 format!(
292 "{}...\n(showing 20 of {} relationships)",
293 relationships_summary,
294 relationships.len()
295 )
296 } else {
297 relationships_summary
298 };
299
300 COMPLETION_CHECK_PROMPT
301 .replace("{input_text}", text)
302 .replace("{entity_count}", &entities.len().to_string())
303 .replace("{entities_summary}", &entities_summary)
304 .replace("{relationship_count}", &relationships.len().to_string())
305 .replace("{relationships_summary}", &relationships_summary)
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PromptBuilder` from string-slice entity type names.
    fn builder_for(types: &[&str]) -> PromptBuilder {
        PromptBuilder::new(types.iter().map(|t| t.to_string()).collect())
    }

    /// Convenience constructor for an `EntityData` fixture.
    fn entity(name: &str, entity_type: &str, description: &str) -> EntityData {
        EntityData {
            name: name.to_string(),
            entity_type: entity_type.to_string(),
            description: description.to_string(),
        }
    }

    /// Convenience constructor for a `RelationshipData` fixture.
    fn relationship(
        source: &str,
        target: &str,
        description: &str,
        strength: f64,
    ) -> RelationshipData {
        RelationshipData {
            source: source.to_string(),
            target: target.to_string(),
            description: description.to_string(),
            strength,
        }
    }

    #[test]
    fn test_build_extraction_prompt() {
        let input = "Tom and Huck went to the cave.";
        let rendered =
            builder_for(&["PERSON", "LOCATION", "ORGANIZATION"]).build_extraction_prompt(input);

        // Every configured entity type and the input text must appear.
        for expected in ["PERSON", "LOCATION", "ORGANIZATION", input] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn test_build_continuation_prompt() {
        let found_entities = [entity("Tom", "PERSON", "A young boy")];
        let found_relationships = [relationship("Tom", "Huck", "friends", 0.9)];

        let rendered = builder_for(&["PERSON"]).build_continuation_prompt(
            "Tom and Huck are best friends.",
            &found_entities,
            &found_relationships,
        );

        for expected in ["Tom", "Huck", "friends"] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn test_build_completion_prompt() {
        let found_entities = [entity("Tom", "PERSON", "A young boy")];
        let found_relationships = [relationship("Tom", "Huck", "friends", 0.9)];

        let rendered = builder_for(&["PERSON"]).build_completion_prompt(
            "Test text",
            &found_entities,
            &found_relationships,
        );

        assert!(rendered.contains("Tom"));
        assert!(rendered.contains("YES or NO"));
    }

    #[test]
    fn test_extraction_output_serialization() {
        let original = ExtractionOutput {
            entities: vec![entity("Tom Sawyer", "PERSON", "The protagonist")],
            relationships: vec![relationship("Tom Sawyer", "Huck Finn", "best friends", 0.95)],
        };

        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("Tom Sawyer"));
        assert!(encoded.contains("PERSON"));

        // Round-trip back through serde and check both collections survive.
        let decoded: ExtractionOutput = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.entities.len(), 1);
        assert_eq!(decoded.relationships.len(), 1);
    }
}