graphify_extract/semantic/
mod.rs1pub mod anthropic;
8pub mod anthropic_oauth;
9pub mod openai_compat;
10pub mod provider;
11
12use std::collections::HashMap;
13use std::path::Path;
14
15use anyhow::{Context, Result};
16use graphify_core::confidence::Confidence;
17use graphify_core::id::make_id;
18use graphify_core::model::{ExtractionResult, GraphEdge, GraphNode, NodeType};
19use serde::Deserialize;
20
21pub use provider::{AuthType, LLMConfigRaw, LLMProvider, LLMProviderConfig};
22
23#[derive(Deserialize, Debug)]
25struct SemanticOutput {
26 #[serde(default)]
27 entities: Vec<SemanticEntity>,
28 #[serde(default)]
29 relationships: Vec<SemanticRelation>,
30}
31
32#[derive(Deserialize, Debug)]
33struct SemanticEntity {
34 name: String,
35 #[serde(default = "default_entity_type")]
36 entity_type: String,
37}
38
/// Serde fallback for an entity whose `entity_type` key is missing.
fn default_entity_type() -> String {
    String::from("concept")
}
42
43#[derive(Deserialize, Debug)]
44struct SemanticRelation {
45 source: String,
46 target: String,
47 #[serde(default = "default_relation")]
48 relation: String,
49}
50
/// Serde fallback for a relationship whose `relation` key is missing.
fn default_relation() -> String {
    String::from("related_to")
}
54
55pub async fn extract_semantic(
59 path: &Path,
60 content: &str,
61 file_type: &str,
62 config: &LLMProviderConfig,
63) -> Result<ExtractionResult> {
64 match config.provider {
65 LLMProvider::Anthropic => {
66 anthropic::extract_anthropic(path, content, file_type, config).await
67 }
68 LLMProvider::OpenAI | LLMProvider::Ollama | LLMProvider::OpenAICompatible => {
69 openai_compat::extract_openai_compatible(
70 path,
71 content,
72 file_type,
73 config.provider.clone(),
74 &config.model,
75 config.api_key.as_deref(),
76 &config.base_url,
77 )
78 .await
79 }
80 }
81}
82
/// Builds the system prompt that tells the model what to extract and how
/// to format the reply.
///
/// `file_type` is inserted verbatim as the description of the input
/// (e.g. "paper"). The rest of the prompt is fixed: it asks for a JSON
/// object with `entities` and `relationships` arrays and lists the
/// accepted entity types.
fn build_system_prompt(file_type: &str) -> String {
    const LEAD_IN: &str = "You are an expert knowledge-graph extraction engine. Given a ";
    const INSTRUCTIONS: &str = ", extract entities and their relationships. \
        Respond ONLY with a JSON object having two arrays: \
        \"entities\" (each with \"name\" and \"entity_type\") and \
        \"relationships\" (each with \"source\", \"target\", and \"relation\"). \
        Entity types should be one of: concept, class, function, module, paper, image. \
        Keep entity names concise and unique.";

    [LEAD_IN, file_type, INSTRUCTIONS].concat()
}
94
95fn build_user_prompt(content: &str, file_type: &str) -> String {
96 let max_chars = 100_000;
97 let truncated = if content.len() > max_chars {
98 let mut end = max_chars;
99 while end > 0 && !content.is_char_boundary(end) {
100 end -= 1;
101 }
102 &content[..end]
103 } else {
104 content
105 };
106
107 let is_truncated = content.len() > max_chars;
108 let note = if is_truncated {
109 "\n\n[NOTE: This file was truncated — only the first portion is shown. Extract entities only from the visible portion.]"
110 } else {
111 ""
112 };
113 format!("Extract all entities and relationships from this {file_type}:\n\n{truncated}{note}")
114}
115
116fn parse_semantic_response(text: &str, file_str: &str) -> Result<ExtractionResult> {
117 let json_str = extract_json_block(text);
118
119 let output: SemanticOutput =
120 serde_json::from_str(json_str).context("failed to parse semantic extraction JSON")?;
121
122 let mut nodes = Vec::new();
123 let mut edges = Vec::new();
124
125 let mut name_to_id: HashMap<String, String> = HashMap::new();
126 for entity in &output.entities {
127 let id = make_id(&[file_str, &entity.name]);
128 let node_type = match entity.entity_type.as_str() {
129 "class" => NodeType::Class,
130 "function" => NodeType::Function,
131 "module" => NodeType::Module,
132 "paper" => NodeType::Paper,
133 "image" => NodeType::Image,
134 _ => NodeType::Concept,
135 };
136 name_to_id.insert(entity.name.clone(), id.clone());
137 nodes.push(GraphNode {
138 id,
139 label: entity.name.clone(),
140 source_file: file_str.to_string(),
141 source_location: None,
142 node_type,
143 community: None,
144 extra: HashMap::new(),
145 });
146 }
147
148 for rel in &output.relationships {
149 let source_id = name_to_id
150 .get(&rel.source)
151 .cloned()
152 .unwrap_or_else(|| make_id(&[file_str, &rel.source]));
153 let target_id = name_to_id
154 .get(&rel.target)
155 .cloned()
156 .unwrap_or_else(|| make_id(&[file_str, &rel.target]));
157
158 edges.push(GraphEdge {
159 source: source_id,
160 target: target_id,
161 relation: rel.relation.clone(),
162 confidence: Confidence::Inferred,
163 confidence_score: Confidence::Inferred.default_score(),
164 source_file: file_str.to_string(),
165 source_location: None,
166 weight: 1.0,
167 extra: HashMap::new(),
168 });
169 }
170
171 Ok(ExtractionResult {
172 nodes,
173 edges,
174 hyperedges: Vec::new(),
175 })
176}
177
/// Locates the JSON payload inside a raw model reply.
///
/// Strategies are tried in this order and the first match wins:
/// 1. the trimmed body of the first Markdown code fence tagged `json`;
/// 2. the trimmed body of the first Markdown code fence of any kind;
/// 3. the span from the first `{` to the last `}` (inclusive, untrimmed);
/// 4. the whole reply, trimmed.
///
/// The returned slice borrows from `text` and is not validated as JSON.
/// This function never panics: when the last `}` comes before the first
/// `{`, strategy 3 is skipped instead of slicing an inverted range.
fn extract_json_block(text: &str) -> &str {
    const FENCE: &str = "```";

    /// Trimmed text between the first `opener` and the next closing fence.
    fn fenced_body<'a>(text: &'a str, opener: &str) -> Option<&'a str> {
        let after = &text[text.find(opener)? + opener.len()..];
        let close = after.find(FENCE)?;
        Some(after[..close].trim())
    }

    if let Some(body) = fenced_body(text, "```json").or_else(|| fenced_body(text, FENCE)) {
        return body;
    }

    match (text.find('{'), text.rfind('}')) {
        // Only slice when the braces are in order; otherwise the range
        // would be inverted and indexing would panic.
        (Some(open), Some(close)) if open < close => &text[open..=close],
        _ => text.trim(),
    }
}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed reply yields one node per entity and one edge per relationship.
    #[test]
    fn parse_semantic_json() {
        let json = r#"{
            "entities": [
                {"name": "Machine Learning", "entity_type": "concept"},
                {"name": "Neural Network", "entity_type": "concept"},
                {"name": "Backpropagation", "entity_type": "concept"}
            ],
            "relationships": [
                {"source": "Neural Network", "target": "Machine Learning", "relation": "is_a"},
                {"source": "Backpropagation", "target": "Neural Network", "relation": "used_by"}
            ]
        }"#;

        let ExtractionResult { nodes, edges, .. } =
            parse_semantic_response(json, "paper.pdf").unwrap();

        assert_eq!(nodes.len(), 3);
        assert_eq!(edges.len(), 2);
        for node in &nodes {
            assert_eq!(node.node_type, NodeType::Concept);
        }
        assert_eq!(edges[0].relation, "is_a");
    }

    /// JSON wrapped in a Markdown fence with surrounding prose still parses.
    #[test]
    fn parse_markdown_wrapped_json() {
        let text = r#"Here is the extraction:
```json
{
  "entities": [{"name": "Foo", "entity_type": "class"}],
  "relationships": []
}
```
"#;
        let ExtractionResult { nodes, .. } = parse_semantic_response(text, "doc.md").unwrap();

        assert_eq!(nodes.len(), 1);
        let only = &nodes[0];
        assert_eq!(only.label, "Foo");
        assert_eq!(only.node_type, NodeType::Class);
    }

    /// Empty arrays produce an empty extraction rather than an error.
    #[test]
    fn parse_empty_response() {
        let json = r#"{"entities": [], "relationships": []}"#;
        let ExtractionResult { nodes, edges, .. } =
            parse_semantic_response(json, "empty.txt").unwrap();

        assert!(nodes.is_empty());
        assert!(edges.is_empty());
    }

    /// Bare JSON is returned as-is.
    #[test]
    fn extract_json_block_plain() {
        let plain = r#"{"a": 1}"#;
        assert_eq!(extract_json_block(plain), plain);
    }

    /// A `json` fence is stripped along with the surrounding prose.
    #[test]
    fn extract_json_block_fenced() {
        let text = "blah\n```json\n{\"a\": 1}\n```\nmore";
        assert_eq!(extract_json_block(text), r#"{"a": 1}"#);
    }

    /// Edges produced from model output are marked as inferred.
    #[test]
    fn semantic_edges_are_inferred_confidence() {
        let json = r#"{
            "entities": [
                {"name": "A", "entity_type": "concept"},
                {"name": "B", "entity_type": "concept"}
            ],
            "relationships": [
                {"source": "A", "target": "B", "relation": "depends_on"}
            ]
        }"#;

        let ExtractionResult { edges, .. } = parse_semantic_response(json, "test.md").unwrap();
        assert_eq!(edges[0].confidence, Confidence::Inferred);
    }

    /// Both prompts mention the file type; the user prompt embeds the content.
    #[test]
    fn build_prompts_contain_file_type() {
        let system = build_system_prompt("paper");
        assert!(system.contains("paper"));

        let user = build_user_prompt("hello world", "document");
        for needle in ["document", "hello world"] {
            assert!(user.contains(needle));
        }
    }
}