Skip to main content

phago_llm/
prompt.rs

//! Prompt templates for LLM concept extraction.
2
3use crate::types::Concept;
4
/// Common interface for every prompt sent to an LLM.
///
/// Implementors build the user prompt via [`generate`](Self::generate) and may
/// optionally provide a system prompt to accompany it.
pub trait PromptTemplate {
    /// Build the complete user prompt text.
    fn generate(&self) -> String;

    /// Return the system prompt to send alongside the user prompt, if any.
    ///
    /// The provided implementation supplies no system prompt.
    fn system_prompt(&self) -> Option<String> {
        None
    }
}
15
/// Prompt asking an LLM to extract key concepts from a piece of text.
///
/// Construct with [`ConceptPrompt::new`] and customize with the `with_*` /
/// `without_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConceptPrompt {
    /// The text to extract concepts from.
    pub text: String,
    /// Maximum number of concepts the LLM is asked to return.
    pub max_concepts: usize,
    /// Whether each concept should carry a brief description.
    pub include_descriptions: bool,
    /// Whether each concept should be classified by type.
    pub include_types: bool,
    /// Optional domain hint (e.g. "biology") to steer extraction.
    pub domain: Option<String>,
}

impl ConceptPrompt {
    /// Create a concept extraction prompt for `text`.
    ///
    /// Defaults: up to 10 concepts, types requested, descriptions not
    /// requested, no domain hint.
    pub fn new(text: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            max_concepts: 10,
            include_descriptions: false,
            include_types: true,
            domain: None,
        }
    }

    /// Set the maximum number of concepts to request.
    pub fn with_max_concepts(mut self, max: usize) -> Self {
        self.max_concepts = max;
        self
    }

    /// Ask for a brief description of each concept.
    pub fn with_descriptions(mut self) -> Self {
        self.include_descriptions = true;
        self
    }

    /// Stop asking the LLM to classify each concept's type.
    ///
    /// Types are requested by default. This is the builder-style counterpart to
    /// [`with_descriptions`](Self::with_descriptions), so callers need not
    /// mutate the public field directly.
    pub fn without_types(mut self) -> Self {
        self.include_types = false;
        self
    }

    /// Set a domain hint (mentioned in the system prompt).
    pub fn with_domain(mut self, domain: impl Into<String>) -> Self {
        self.domain = Some(domain.into());
        self
    }
}
61
62impl PromptTemplate for ConceptPrompt {
63    fn system_prompt(&self) -> Option<String> {
64        let domain_hint = self.domain.as_ref()
65            .map(|d| format!(" in the domain of {}", d))
66            .unwrap_or_default();
67
68        Some(format!(
69            "You are an expert at extracting key concepts{domain_hint}. \
70             Extract the most important concepts from the given text. \
71             Respond ONLY with a JSON array of concepts, no explanation."
72        ))
73    }
74
75    fn generate(&self) -> String {
76        let type_instruction = if self.include_types {
77            r#", "type": "<entity|concept|process|property>""#
78        } else {
79            ""
80        };
81
82        let desc_instruction = if self.include_descriptions {
83            r#", "description": "<brief description>""#
84        } else {
85            ""
86        };
87
88        format!(
89            r#"Extract up to {} key concepts from this text:
90
91---
92{}
93---
94
95Respond with a JSON array like:
96[{{"label": "<concept>"{}{}, "confidence": <0.0-1.0>}}]
97
98JSON:"#,
99            self.max_concepts,
100            self.text,
101            type_instruction,
102            desc_instruction
103        )
104    }
105}
106
/// Prompt asking an LLM to find relationships among known concepts in a text.
#[derive(Debug, Clone)]
pub struct RelationshipPrompt {
    /// Source text the relationships should be drawn from.
    pub text: String,
    /// Concept labels the LLM should relate to one another.
    pub concepts: Vec<String>,
    /// Upper bound on the number of relationships requested.
    pub max_relationships: usize,
}

impl RelationshipPrompt {
    /// Build a relationship prompt over `text` and `concepts`.
    ///
    /// By default at most 20 relationships are requested.
    pub fn new(text: impl Into<String>, concepts: Vec<String>) -> Self {
        let text = text.into();
        Self {
            text,
            concepts,
            max_relationships: 20,
        }
    }

    /// Override the maximum number of relationships to request.
    pub fn with_max_relationships(self, max: usize) -> Self {
        Self {
            max_relationships: max,
            ..self
        }
    }
}
134
135impl PromptTemplate for RelationshipPrompt {
136    fn system_prompt(&self) -> Option<String> {
137        Some(
138            "You are an expert at identifying relationships between concepts. \
139             Identify meaningful relationships from the given text and concepts. \
140             Respond ONLY with a JSON array, no explanation.".to_string()
141        )
142    }
143
144    fn generate(&self) -> String {
145        let concepts_list = self.concepts.join(", ");
146
147        format!(
148            r#"Given this text and concepts, identify relationships between them.
149
150Text:
151---
152{}
153---
154
155Concepts: {}
156
157Find up to {} relationships. Respond with a JSON array like:
158[{{"source": "<concept>", "target": "<concept>", "relation": "<is_a|part_of|causes|enables|requires|produces|regulates|interacts_with|located_in|related_to>", "label": "<human readable relationship>"}}]
159
160JSON:"#,
161            self.text,
162            concepts_list,
163            self.max_relationships
164        )
165    }
166}
167
/// Prompt asking an LLM to rephrase a search query in several ways.
#[derive(Debug, Clone)]
pub struct QueryExpansionPrompt {
    /// Query to be rephrased.
    pub query: String,
    /// How many alternative queries to request.
    pub num_expansions: usize,
}

impl QueryExpansionPrompt {
    /// Create a query expansion prompt for `query`.
    ///
    /// By default 3 alternatives are requested.
    pub fn new(query: impl Into<String>) -> Self {
        let query = query.into();
        Self {
            query,
            num_expansions: 3,
        }
    }

    /// Change how many alternative queries are requested.
    pub fn with_num_expansions(self, num: usize) -> Self {
        Self {
            num_expansions: num,
            ..self
        }
    }
}
192
193impl PromptTemplate for QueryExpansionPrompt {
194    fn system_prompt(&self) -> Option<String> {
195        Some(
196            "You are a search query expansion expert. Generate alternative queries \
197             that capture the same intent but use different terminology.".to_string()
198        )
199    }
200
201    fn generate(&self) -> String {
202        format!(
203            r#"Expand this search query into {} alternative queries that capture the same meaning:
204
205Query: {}
206
207Respond with a JSON array of strings:
208["<query1>", "<query2>", ...]
209
210JSON:"#,
211            self.num_expansions,
212            self.query
213        )
214    }
215}
216
217/// Parse concepts from JSON response.
218pub fn parse_concepts_json(json: &str) -> Result<Vec<Concept>, serde_json::Error> {
219    // Try to find JSON array in response
220    let json_str = extract_json_array(json);
221
222    #[derive(serde::Deserialize)]
223    struct RawConcept {
224        label: String,
225        #[serde(rename = "type", default)]
226        concept_type: Option<String>,
227        #[serde(default)]
228        description: Option<String>,
229        #[serde(default = "default_confidence")]
230        confidence: f32,
231    }
232
233    fn default_confidence() -> f32 {
234        1.0
235    }
236
237    let raw: Vec<RawConcept> = serde_json::from_str(json_str)?;
238
239    Ok(raw
240        .into_iter()
241        .map(|r| {
242            let mut concept = Concept::new(r.label).with_confidence(r.confidence);
243            if let Some(desc) = r.description {
244                concept = concept.with_description(desc);
245            }
246            if let Some(t) = r.concept_type {
247                concept = concept.with_type(match t.to_lowercase().as_str() {
248                    "entity" => crate::types::ConceptType::Entity,
249                    "process" => crate::types::ConceptType::Process,
250                    "property" => crate::types::ConceptType::Property,
251                    "relationship" => crate::types::ConceptType::Relationship,
252                    _ => crate::types::ConceptType::Concept,
253                });
254            }
255            concept
256        })
257        .collect())
258}
259
/// Extract the JSON array embedded in an LLM response.
///
/// First strips a surrounding Markdown code fence: an opening triple backtick,
/// optionally tagged `json`, and a closing triple backtick. Then it returns the
/// slice from the first `[` to the last `]`.
///
/// If there is no `[` that precedes a `]`, the fence-stripped, trimmed text is
/// returned unchanged. The JSON parser can then report a proper error rather
/// than this function panicking.
fn extract_json_array(text: &str) -> &str {
    let text = text.trim();
    let text = text.strip_prefix("```json").unwrap_or(text);
    let text = text.strip_prefix("```").unwrap_or(text);
    let text = text.strip_suffix("```").unwrap_or(text);
    let text = text.trim();

    // Require `start < end`. A response such as "] ... [" would otherwise make
    // the range slice panic instead of surfacing a parse error. Both brackets
    // are single-byte ASCII, so the byte indices are valid char boundaries.
    match (text.find('['), text.rfind(']')) {
        (Some(start), Some(end)) if start < end => &text[start..=end],
        _ => text,
    }
}
276
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_concept_prompt() {
        let prompt = ConceptPrompt::new("The cell membrane controls transport.")
            .with_max_concepts(5)
            .with_domain("biology");

        let generated = prompt.generate();
        assert!(generated.contains("cell membrane"));
        // Check the whole phrase; a bare "5" could also match the input text.
        assert!(generated.contains("Extract up to 5 key concepts"));
        // Types are requested by default, descriptions are not.
        assert!(generated.contains(r#""type": "<entity|concept|process|property>""#));
        assert!(!generated.contains("description"));

        let system = prompt.system_prompt().unwrap();
        assert!(system.contains("in the domain of biology"));
    }

    #[test]
    fn test_concept_prompt_optional_fields() {
        let with_desc = ConceptPrompt::new("text").with_descriptions().generate();
        assert!(with_desc.contains(r#""description": "<brief description>""#));

        let mut no_types = ConceptPrompt::new("text");
        no_types.include_types = false;
        assert!(!no_types.generate().contains(r#""type""#));

        let system = ConceptPrompt::new("text").system_prompt().unwrap();
        assert!(!system.contains("in the domain of"));
    }

    #[test]
    fn test_relationship_prompt() {
        let prompt = RelationshipPrompt::new(
            "Mitochondria produce ATP.",
            vec!["mitochondria".into(), "ATP".into()],
        );

        let generated = prompt.generate();
        assert!(generated.contains("Mitochondria produce ATP."));
        assert!(generated.contains("Concepts: mitochondria, ATP"));
        assert!(generated.contains("Find up to 20 relationships"));

        let limited = prompt.with_max_relationships(3).generate();
        assert!(limited.contains("Find up to 3 relationships"));
    }

    #[test]
    fn test_query_expansion_prompt() {
        let prompt = QueryExpansionPrompt::new("rust async").with_num_expansions(4);

        let generated = prompt.generate();
        assert!(generated.contains("into 4 alternative queries"));
        assert!(generated.contains("Query: rust async"));

        let system = prompt.system_prompt().unwrap();
        assert!(system.contains("query expansion"));
    }

    #[test]
    fn test_parse_concepts_json() {
        let json = r#"[
            {"label": "cell", "type": "concept", "confidence": 0.95},
            {"label": "membrane", "type": "concept", "confidence": 0.9}
        ]"#;

        let concepts = parse_concepts_json(json).unwrap();
        assert_eq!(concepts.len(), 2);
        assert_eq!(concepts[0].label, "cell");
        assert!((concepts[0].confidence - 0.95).abs() < 0.01);
    }

    #[test]
    fn test_parse_concepts_with_code_block() {
        let json = r#"```json
        [{"label": "test", "confidence": 1.0}]
        ```"#;

        let concepts = parse_concepts_json(json).unwrap();
        assert_eq!(concepts.len(), 1);
        assert_eq!(concepts[0].label, "test");
    }

    #[test]
    fn test_parse_concepts_surrounded_by_prose() {
        let response = "Sure! Here are the concepts:\n\
                        [{\"label\": \"ATP\", \"confidence\": 0.5}]\n\
                        Hope this helps.";

        let concepts = parse_concepts_json(response).unwrap();
        assert_eq!(concepts.len(), 1);
        assert_eq!(concepts[0].label, "ATP");
    }

    #[test]
    fn test_parse_concepts_default_confidence() {
        let concepts = parse_concepts_json(r#"[{"label": "cell"}]"#).unwrap();
        assert_eq!(concepts.len(), 1);
        assert!((concepts[0].confidence - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_parse_concepts_invalid_input() {
        assert!(parse_concepts_json("no concepts here").is_err());
        // `label` is required.
        assert!(parse_concepts_json(r#"[{"confidence": 0.4}]"#).is_err());
    }
}