1use std::collections::HashMap;
10use std::path::Path;
11
12use anyhow::{Context, Result};
13use graphify_core::confidence::Confidence;
14use graphify_core::id::make_id;
15use graphify_core::model::{ExtractionResult, GraphEdge, GraphNode, NodeType};
16use serde::{Deserialize, Serialize};
17use tracing::debug;
18
/// JSON body sent to the Anthropic Messages endpoint by `extract_semantic`.
///
/// Field names are serialized as-is, so they must stay in line with the
/// request schema the API expects (`model`, `max_tokens`, `messages`, `system`).
#[derive(Serialize)]
struct MessageRequest {
    /// Model identifier to run.
    model: String,
    /// Output-token limit passed to the API.
    max_tokens: u32,
    /// Conversation turns; `extract_semantic` sends a single `user` turn.
    messages: Vec<Message>,
    /// System prompt, sent as a top-level field rather than as a message.
    system: String,
}

/// One conversation turn inside a [`MessageRequest`].
#[derive(Serialize)]
struct Message {
    /// Speaker of the turn, e.g. `"user"`.
    role: String,
    /// Plain-text content of the turn.
    content: String,
}
36
/// The subset of the Messages API response that this module reads.
///
/// Any other fields in the response body are ignored during deserialization.
#[derive(Deserialize)]
struct MessageResponse {
    /// Content blocks of the assistant reply, in order.
    content: Vec<ContentBlock>,
}

/// One content block of a [`MessageResponse`].
#[derive(Deserialize)]
struct ContentBlock {
    /// Block text; `None` when the block has no `text` field.
    text: Option<String>,
}
46
/// The JSON object the model is instructed to return (see `build_system_prompt`).
///
/// Both arrays default to empty when absent, so `{}` parses to an empty output.
#[derive(Deserialize, Debug)]
struct SemanticOutput {
    /// Extracted entities; each becomes a graph node.
    #[serde(default)]
    entities: Vec<SemanticEntity>,
    /// Extracted relationships; each becomes a graph edge.
    #[serde(default)]
    relationships: Vec<SemanticRelation>,
}

/// One entity reported by the model.
#[derive(Deserialize, Debug)]
struct SemanticEntity {
    /// Entity name; used as the node label, to derive the node id, and to
    /// resolve relationship endpoints.
    name: String,
    /// Free-form type string, `"concept"` when missing. Values that
    /// `parse_semantic_response` does not recognize map to `NodeType::Concept`.
    #[serde(default = "default_entity_type")]
    entity_type: String,
}
62
/// Serde default for [`SemanticEntity::entity_type`]: an untyped entity is a concept.
fn default_entity_type() -> String {
    String::from("concept")
}
66
/// One directed relationship reported by the model.
///
/// `source` and `target` are entity *names*; `parse_semantic_response` maps
/// them to node ids.
#[derive(Deserialize, Debug)]
struct SemanticRelation {
    /// Name of the entity the edge starts from.
    source: String,
    /// Name of the entity the edge points to.
    target: String,
    /// Relation label, `"related_to"` when missing.
    #[serde(default = "default_relation")]
    relation: String,
}
74
/// Serde default for [`SemanticRelation::relation`]: the most generic label.
fn default_relation() -> String {
    String::from("related_to")
}
78
79pub async fn extract_semantic(
94 path: &Path,
95 content: &str,
96 file_type: &str,
97 api_key: &str,
98) -> Result<ExtractionResult> {
99 let file_str = path.to_string_lossy();
100 let system_prompt = build_system_prompt(file_type);
101 let user_prompt = build_user_prompt(content, file_type);
102
103 debug!("sending semantic extraction request for {}", file_str);
104
105 let request_body = MessageRequest {
106 model: "claude-sonnet-4-20250514".to_string(),
107 max_tokens: 4096,
108 messages: vec![Message {
109 role: "user".to_string(),
110 content: user_prompt,
111 }],
112 system: system_prompt,
113 };
114
115 let client = reqwest::Client::new();
116 let response = client
117 .post("https://api.anthropic.com/v1/messages")
118 .header("x-api-key", api_key)
119 .header("anthropic-version", "2023-06-01")
120 .header("content-type", "application/json")
121 .json(&request_body)
122 .send()
123 .await
124 .context("failed to send request to Claude API")?;
125
126 if !response.status().is_success() {
127 let status = response.status();
128 let body = response.text().await.unwrap_or_default();
129 anyhow::bail!("Claude API returned {status}: {body}");
130 }
131
132 let msg: MessageResponse = response
133 .json()
134 .await
135 .context("failed to parse Claude API response")?;
136
137 let text = msg
138 .content
139 .first()
140 .and_then(|b| b.text.as_deref())
141 .unwrap_or("{}");
142
143 parse_semantic_response(text, &file_str)
144}
145
/// Builds the system prompt that tells the model which JSON shape to emit.
///
/// The requested keys (`entities[].name`, `entities[].entity_type`,
/// `relationships[].source` / `target` / `relation`) match the fields that
/// `SemanticOutput` deserializes. `file_type` is interpolated into the task
/// sentence, e.g. "Given a paper, ...".
fn build_system_prompt(file_type: &str) -> String {
    let task = format!("Given a {file_type}, extract entities and their relationships.");
    let schema = concat!(
        "Respond ONLY with a JSON object having two arrays: ",
        "\"entities\" (each with \"name\" and \"entity_type\") and ",
        "\"relationships\" (each with \"source\", \"target\", and \"relation\")."
    );

    // Sentences are joined with single spaces, giving one paragraph.
    [
        "You are an expert knowledge-graph extraction engine.",
        task.as_str(),
        schema,
        "Entity types should be one of: concept, class, function, module, paper, image.",
        "Keep entity names concise and unique.",
    ]
    .join(" ")
}
161
/// Builds the user message that carries the file content to the model.
///
/// Content longer than 100 000 **bytes** is cut at the last UTF-8 character
/// boundary at or below that limit, so the slice never splits a character.
/// Truncation is silent: no marker is appended to the prompt.
fn build_user_prompt(content: &str, file_type: &str) -> String {
    const MAX_PROMPT_BYTES: usize = 100_000;

    let body = if content.len() <= MAX_PROMPT_BYTES {
        content
    } else {
        // Walk backwards from the limit to the nearest char boundary; index 0
        // is always a boundary, so the fallback is never actually reached.
        let cut = (0..=MAX_PROMPT_BYTES)
            .rev()
            .find(|&idx| content.is_char_boundary(idx))
            .unwrap_or(0);
        &content[..cut]
    };

    format!(
        "Extract all entities and relationships from this {}:\n\n{}",
        file_type, body
    )
}
177
178fn parse_semantic_response(text: &str, file_str: &str) -> Result<ExtractionResult> {
183 let json_str = extract_json_block(text);
185
186 let output: SemanticOutput =
187 serde_json::from_str(json_str).context("failed to parse semantic extraction JSON")?;
188
189 let mut nodes = Vec::new();
190 let mut edges = Vec::new();
191
192 let mut name_to_id: HashMap<String, String> = HashMap::new();
194 for entity in &output.entities {
195 let id = make_id(&[file_str, &entity.name]);
196 let node_type = match entity.entity_type.as_str() {
197 "class" => NodeType::Class,
198 "function" => NodeType::Function,
199 "module" => NodeType::Module,
200 "paper" => NodeType::Paper,
201 "image" => NodeType::Image,
202 _ => NodeType::Concept,
203 };
204 name_to_id.insert(entity.name.clone(), id.clone());
205 nodes.push(GraphNode {
206 id,
207 label: entity.name.clone(),
208 source_file: file_str.to_string(),
209 source_location: None,
210 node_type,
211 community: None,
212 extra: HashMap::new(),
213 });
214 }
215
216 for rel in &output.relationships {
218 let source_id = name_to_id
219 .get(&rel.source)
220 .cloned()
221 .unwrap_or_else(|| make_id(&[file_str, &rel.source]));
222 let target_id = name_to_id
223 .get(&rel.target)
224 .cloned()
225 .unwrap_or_else(|| make_id(&[file_str, &rel.target]));
226
227 edges.push(GraphEdge {
228 source: source_id,
229 target: target_id,
230 relation: rel.relation.clone(),
231 confidence: Confidence::Inferred,
232 confidence_score: Confidence::Inferred.default_score(),
233 source_file: file_str.to_string(),
234 source_location: None,
235 weight: 1.0,
236 extra: HashMap::new(),
237 });
238 }
239
240 Ok(ExtractionResult {
241 nodes,
242 edges,
243 hyperedges: Vec::new(),
244 })
245}
246
/// Locates the JSON payload inside a model reply.
///
/// Tried in order:
/// 1. the body of the first ```` ```json ```` fence (preferred even if another
///    fence appears earlier);
/// 2. the body of the first ```` ``` ```` fence, with any info-string tag such
///    as `JSON` or `javascript` on its opening line removed;
/// 3. the span from the first `{` to the last `}`, when the `}` comes after the `{`;
/// 4. otherwise the trimmed input.
///
/// Never panics; the result is a sub-slice of `text`.
fn extract_json_block(text: &str) -> &str {
    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    if let Some(body) = text
        .find(JSON_FENCE)
        .and_then(|start| fenced_body(&text[start + JSON_FENCE.len()..]))
    {
        return body;
    }
    if let Some(body) = text
        .find(FENCE)
        .and_then(|start| fenced_body(&text[start + FENCE.len()..]))
    {
        return body;
    }
    // Guard the ordering: slicing `start..=end` when `}` precedes `{` (e.g.
    // "see } here {") would panic.
    match (text.find('{'), text.rfind('}')) {
        (Some(start), Some(end)) if start < end => &text[start..=end],
        _ => text.trim(),
    }
}

/// Returns the trimmed text between an opening fence and the next closing
/// ```` ``` ````, dropping a leading info-string tag line.
///
/// `after_open` starts right after the opening fence. Returns `None` when no
/// closing fence follows.
fn fenced_body(after_open: &str) -> Option<&str> {
    let end = after_open.find("```")?;
    let inner = &after_open[..end];
    let inner = match inner.split_once('\n') {
        Some((tag, rest)) if is_fence_tag(tag) => rest,
        _ => inner,
    };
    Some(inner.trim())
}

/// True when `line` looks like a Markdown fence info string (e.g. `JSON`,
/// `jsonc`, `objective-c`) rather than the start of a JSON payload.
fn is_fence_tag(line: &str) -> bool {
    let tag = line.trim();
    !tag.is_empty()
        && tag
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '+' | '.'))
}
271
#[cfg(test)]
mod tests {
    // Unit tests for the offline parts of semantic extraction: reply parsing,
    // JSON fence extraction and prompt construction. `extract_semantic` itself
    // performs a network call and is not exercised here.
    use super::*;

    /// A bare JSON object yields one node per entity and one edge per
    /// relationship, with `concept` entities mapped to `NodeType::Concept`.
    #[test]
    fn parse_semantic_json() {
        let json = r#"{
 "entities": [
 {"name": "Machine Learning", "entity_type": "concept"},
 {"name": "Neural Network", "entity_type": "concept"},
 {"name": "Backpropagation", "entity_type": "concept"}
 ],
 "relationships": [
 {"source": "Neural Network", "target": "Machine Learning", "relation": "is_a"},
 {"source": "Backpropagation", "target": "Neural Network", "relation": "used_by"}
 ]
 }"#;

        let result = parse_semantic_response(json, "paper.pdf").unwrap();
        assert_eq!(result.nodes.len(), 3);
        assert_eq!(result.edges.len(), 2);
        assert!(
            result
                .nodes
                .iter()
                .all(|n| n.node_type == NodeType::Concept)
        );
        // Edge order follows the order of the `relationships` array.
        assert_eq!(result.edges[0].relation, "is_a");
    }

    /// JSON wrapped in prose and a ```json fence is still found and parsed.
    #[test]
    fn parse_markdown_wrapped_json() {
        let text = r#"Here is the extraction:
```json
{
 "entities": [{"name": "Foo", "entity_type": "class"}],
 "relationships": []
}
```
"#;
        let result = parse_semantic_response(text, "doc.md").unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].label, "Foo");
        assert_eq!(result.nodes[0].node_type, NodeType::Class);
    }

    /// Empty arrays produce an empty, but successful, result.
    #[test]
    fn parse_empty_response() {
        let json = r#"{"entities": [], "relationships": []}"#;
        let result = parse_semantic_response(json, "empty.txt").unwrap();
        assert!(result.nodes.is_empty());
        assert!(result.edges.is_empty());
    }

    /// Unfenced JSON is returned unchanged.
    #[test]
    fn extract_json_block_plain() {
        assert_eq!(extract_json_block(r#"{"a": 1}"#), r#"{"a": 1}"#);
    }

    /// Only the fenced body is returned, without surrounding text or fences.
    #[test]
    fn extract_json_block_fenced() {
        let text = "blah\n```json\n{\"a\": 1}\n```\nmore";
        assert_eq!(extract_json_block(text), r#"{"a": 1}"#);
    }

    /// LLM-derived edges are always tagged as inferred.
    #[test]
    fn semantic_edges_are_inferred_confidence() {
        let json = r#"{
 "entities": [
 {"name": "A", "entity_type": "concept"},
 {"name": "B", "entity_type": "concept"}
 ],
 "relationships": [
 {"source": "A", "target": "B", "relation": "depends_on"}
 ]
 }"#;
        let result = parse_semantic_response(json, "test.md").unwrap();
        assert_eq!(result.edges[0].confidence, Confidence::Inferred);
    }

    /// Both prompts mention the file type; the user prompt also embeds the content.
    #[test]
    fn build_prompts_contain_file_type() {
        let sys = build_system_prompt("paper");
        assert!(sys.contains("paper"));

        let user = build_user_prompt("hello world", "document");
        assert!(user.contains("document"));
        assert!(user.contains("hello world"));
    }
}