graphify_extract/semantic/
mod.rs1pub mod anthropic;
8pub mod anthropic_oauth;
9pub mod openai_compat;
10pub mod provider;
11
12use std::collections::HashMap;
13use std::path::Path;
14
15use anyhow::{Context, Result};
16use graphify_core::confidence::Confidence;
17use graphify_core::id::make_id;
18use graphify_core::model::{ExtractionResult, GraphEdge, GraphNode, NodeType};
19use serde::Deserialize;
20
21pub use provider::{AuthType, LLMConfigRaw, LLMProvider, LLMProviderConfig};
22
/// Top-level JSON object deserialized from the extracted response text.
///
/// Both arrays are optional in the payload: a missing key deserializes to an
/// empty list rather than failing.
#[derive(Deserialize, Debug)]
struct SemanticOutput {
    /// Entities listed under the `"entities"` key; empty when the key is absent.
    #[serde(default)]
    entities: Vec<SemanticEntity>,
    /// Relations listed under the `"relationships"` key; empty when the key is absent.
    #[serde(default)]
    relationships: Vec<SemanticRelation>,
}
31
/// One entry of the `"entities"` array.
#[derive(Deserialize, Debug)]
struct SemanticEntity {
    /// Entity name; required in the payload.
    name: String,
    /// Free-form type tag; falls back to `default_entity_type()` when omitted.
    #[serde(default = "default_entity_type")]
    entity_type: String,
}
38
/// Serde fallback for an entity whose `entity_type` key is missing.
fn default_entity_type() -> String {
    String::from("concept")
}
42
/// One entry of the `"relationships"` array.
#[derive(Deserialize, Debug)]
struct SemanticRelation {
    /// Name of the entity the relation starts from; required in the payload.
    source: String,
    /// Name of the entity the relation points to; required in the payload.
    target: String,
    /// Relation label; falls back to `default_relation()` when omitted.
    #[serde(default = "default_relation")]
    relation: String,
}
50
/// Serde fallback for a relationship whose `relation` key is missing.
fn default_relation() -> String {
    String::from("related_to")
}
54
55pub async fn extract_semantic(
59 path: &Path,
60 content: &str,
61 file_type: &str,
62 config: &LLMProviderConfig,
63) -> Result<ExtractionResult> {
64 match config.provider {
65 LLMProvider::Anthropic => {
66 anthropic::extract_anthropic(path, content, file_type, config).await
67 }
68 LLMProvider::OpenAI | LLMProvider::Ollama | LLMProvider::OpenAICompatible => {
69 openai_compat::extract_openai_compatible(
70 path,
71 content,
72 file_type,
73 config.provider.clone(),
74 &config.model,
75 config.api_key.as_deref(),
76 &config.base_url,
77 )
78 .await
79 }
80 }
81}
82
/// Builds the system prompt that tells the model to emit entities and
/// relationships as a single JSON object.
///
/// `file_type` is spliced into the instruction verbatim (for example
/// `"paper"` yields "Given a paper, ...").
fn build_system_prompt(file_type: &str) -> String {
    // Each segment keeps its own trailing space, so concatenation reproduces
    // the single-line prompt exactly.
    [
        "You are an expert knowledge-graph extraction engine. ",
        "Given a ",
        file_type,
        ", extract entities and their relationships. ",
        "Respond ONLY with a JSON object having two arrays: ",
        "\"entities\" (each with \"name\" and \"entity_type\") and ",
        "\"relationships\" (each with \"source\", \"target\", and \"relation\"). ",
        "Entity types should be one of: concept, class, function, module, paper, image. ",
        "Keep entity names concise and unique.",
    ]
    .concat()
}
94
/// Builds the user prompt that carries the file content to the model.
///
/// Content longer than 100,000 bytes is cut at the nearest UTF-8 character
/// boundary at or below that limit, and a note telling the model that the
/// file was truncated is appended. Shorter content is embedded unchanged.
fn build_user_prompt(content: &str, file_type: &str) -> String {
    // The limit is measured in bytes (`str::len`), not in characters.
    const MAX_BYTES: usize = 100_000;
    const TRUNCATION_NOTE: &str = "\n\n[NOTE: This file was truncated — only the first portion is shown. Extract entities only from the visible portion.]";

    let (body, note) = if content.len() > MAX_BYTES {
        // Walk back from the limit until the cut no longer splits a
        // multi-byte character; index 0 is always a boundary.
        let cut = (0..=MAX_BYTES)
            .rev()
            .find(|&idx| content.is_char_boundary(idx))
            .unwrap_or(0);
        (&content[..cut], TRUNCATION_NOTE)
    } else {
        (content, "")
    };

    format!("Extract all entities and relationships from this {file_type}:\n\n{body}{note}")
}
115
116fn parse_semantic_response(text: &str, file_str: &str) -> Result<ExtractionResult> {
117 let json_str = extract_json_block(text);
118
119 let output: SemanticOutput =
120 serde_json::from_str(json_str).context("failed to parse semantic extraction JSON")?;
121
122 let mut nodes = Vec::new();
123 let mut edges = Vec::new();
124
125 let mut name_to_id: HashMap<String, String> = HashMap::new();
126 for entity in &output.entities {
127 let id = make_id(&[file_str, &entity.name]);
128 let node_type = match entity.entity_type.as_str() {
129 "class" => NodeType::Class,
130 "function" => NodeType::Function,
131 "module" => NodeType::Module,
132 "paper" => NodeType::Paper,
133 "image" => NodeType::Image,
134 _ => NodeType::Concept,
135 };
136 name_to_id.insert(entity.name.clone(), id.clone());
137 nodes.push(GraphNode {
138 id,
139 label: entity.name.clone(),
140 source_file: file_str.to_string(),
141 source_location: None,
142 node_type,
143 community: None,
144 extra: HashMap::new(),
145 });
146 }
147
148 for rel in &output.relationships {
149 let source_id = name_to_id
150 .get(&rel.source)
151 .cloned()
152 .unwrap_or_else(|| make_id(&[file_str, &rel.source]));
153 let target_id = name_to_id
154 .get(&rel.target)
155 .cloned()
156 .unwrap_or_else(|| make_id(&[file_str, &rel.target]));
157
158 edges.push(GraphEdge {
159 source: source_id,
160 target: target_id,
161 relation: rel.relation.clone(),
162 confidence: Confidence::Inferred,
163 confidence_score: Confidence::Inferred.default_score(),
164 source_file: file_str.to_string(),
165 source_location: None,
166 weight: 1.0,
167 provenance: Some("llm:semantic".to_string()),
168 extra: HashMap::new(),
169 });
170 }
171
172 Ok(ExtractionResult {
173 nodes,
174 edges,
175 hyperedges: Vec::new(),
176 })
177}
178
/// Isolates the JSON payload inside a free-form model reply.
///
/// Strategies are tried in order:
/// 1. the trimmed content of the first closed ```` ```json ```` fence;
/// 2. the trimmed content of the first closed ```` ``` ```` fence of any kind;
/// 3. the span from the first `{` to the last `}`, inclusive;
/// 4. the whole text, trimmed.
///
/// The returned slice borrows from `text`. This function never panics: when
/// the last `}` appears before the first `{` (for example `"} oops {"`), the
/// brace strategy is skipped instead of slicing with an inverted range.
fn extract_json_block(text: &str) -> &str {
    // An unclosed fence falls through to the next strategy.
    for fence in ["```json", "```"] {
        if let Some(open) = text.find(fence) {
            let after = &text[open + fence.len()..];
            if let Some(close) = after.find("```") {
                return after[..close].trim();
            }
        }
    }

    if let (Some(start), Some(end)) = (text.find('{'), text.rfind('}')) {
        // Only a `{ ... }` span in the right order is a usable candidate;
        // slicing with `start > end` would panic.
        if start < end {
            return &text[start..=end];
        }
    }

    text.trim()
}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain JSON reply yields one node per entity and one edge per relation.
    #[test]
    fn parse_semantic_json() {
        let payload = r#"{
            "entities": [
                {"name": "Machine Learning", "entity_type": "concept"},
                {"name": "Neural Network", "entity_type": "concept"},
                {"name": "Backpropagation", "entity_type": "concept"}
            ],
            "relationships": [
                {"source": "Neural Network", "target": "Machine Learning", "relation": "is_a"},
                {"source": "Backpropagation", "target": "Neural Network", "relation": "used_by"}
            ]
        }"#;

        let extracted = parse_semantic_response(payload, "paper.pdf").unwrap();

        assert_eq!(extracted.nodes.len(), 3);
        assert_eq!(extracted.edges.len(), 2);
        for node in &extracted.nodes {
            assert!(node.node_type == NodeType::Concept);
        }
        assert_eq!(extracted.edges[0].relation, "is_a");
    }

    /// A reply wrapped in prose and a ```json fence is still parsed.
    #[test]
    fn parse_markdown_wrapped_json() {
        let reply = r#"Here is the extraction:
```json
{
  "entities": [{"name": "Foo", "entity_type": "class"}],
  "relationships": []
}
```
"#;

        let extracted = parse_semantic_response(reply, "doc.md").unwrap();

        assert_eq!(extracted.nodes.len(), 1);
        let only = &extracted.nodes[0];
        assert_eq!(only.label, "Foo");
        assert_eq!(only.node_type, NodeType::Class);
    }

    /// Empty arrays produce an empty extraction rather than an error.
    #[test]
    fn parse_empty_response() {
        let extracted =
            parse_semantic_response(r#"{"entities": [], "relationships": []}"#, "empty.txt")
                .unwrap();

        assert!(extracted.nodes.is_empty());
        assert!(extracted.edges.is_empty());
    }

    /// Bare JSON passes through the block extractor untouched.
    #[test]
    fn extract_json_block_plain() {
        let bare = r#"{"a": 1}"#;
        assert_eq!(extract_json_block(bare), bare);
    }

    /// Only the fenced payload is returned, without the surrounding prose.
    #[test]
    fn extract_json_block_fenced() {
        let reply = "blah\n```json\n{\"a\": 1}\n```\nmore";
        assert_eq!(extract_json_block(reply), r#"{"a": 1}"#);
    }

    /// Edges produced from model output carry `Confidence::Inferred`.
    #[test]
    fn semantic_edges_are_inferred_confidence() {
        let payload = r#"{
            "entities": [
                {"name": "A", "entity_type": "concept"},
                {"name": "B", "entity_type": "concept"}
            ],
            "relationships": [
                {"source": "A", "target": "B", "relation": "depends_on"}
            ]
        }"#;

        let extracted = parse_semantic_response(payload, "test.md").unwrap();

        assert_eq!(extracted.edges[0].confidence, Confidence::Inferred);
    }

    /// Both prompts mention the file type; the user prompt embeds the content.
    #[test]
    fn build_prompts_contain_file_type() {
        let system_prompt = build_system_prompt("paper");
        assert!(system_prompt.contains("paper"));

        let user_prompt = build_user_prompt("hello world", "document");
        for needle in ["document", "hello world"] {
            assert!(user_prompt.contains(needle));
        }
    }
}