Skip to main content

graphify_extract/
semantic.rs

1//! Semantic extraction via Claude API (Pass 2).
2//!
3//! Extracts higher-level concepts and relationships from documents, papers, and
4//! images using the Anthropic Messages API. This is the second pass of the
5//! extraction pipeline — it complements the deterministic AST extraction from
6//! Pass 1 by discovering semantic relationships that cannot be inferred from
7//! syntax alone.
8
9use std::collections::HashMap;
10use std::path::Path;
11
12use anyhow::{Context, Result};
13use graphify_core::confidence::Confidence;
14use graphify_core::id::make_id;
15use graphify_core::model::{ExtractionResult, GraphEdge, GraphNode, NodeType};
16use serde::{Deserialize, Serialize};
17use tracing::debug;
18
19// ---------------------------------------------------------------------------
20// Claude API request/response types
21// ---------------------------------------------------------------------------
22
23#[derive(Serialize)]
24struct MessageRequest {
25    model: String,
26    max_tokens: u32,
27    messages: Vec<Message>,
28    system: String,
29}
30
31#[derive(Serialize)]
32struct Message {
33    role: String,
34    content: String,
35}
36
37#[derive(Deserialize)]
38struct MessageResponse {
39    content: Vec<ContentBlock>,
40}
41
42#[derive(Deserialize)]
43struct ContentBlock {
44    text: Option<String>,
45}
46
47/// Entities and relationships extracted by the LLM.
48#[derive(Deserialize, Debug)]
49struct SemanticOutput {
50    #[serde(default)]
51    entities: Vec<SemanticEntity>,
52    #[serde(default)]
53    relationships: Vec<SemanticRelation>,
54}
55
56#[derive(Deserialize, Debug)]
57struct SemanticEntity {
58    name: String,
59    #[serde(default = "default_entity_type")]
60    entity_type: String,
61}
62
/// Serde default for an entity's `entity_type` when the key is missing.
fn default_entity_type() -> String {
    String::from("concept")
}
66
67#[derive(Deserialize, Debug)]
68struct SemanticRelation {
69    source: String,
70    target: String,
71    #[serde(default = "default_relation")]
72    relation: String,
73}
74
/// Serde default for a relationship's `relation` when the key is missing.
fn default_relation() -> String {
    String::from("related_to")
}
78
79// ---------------------------------------------------------------------------
80// Public API
81// ---------------------------------------------------------------------------
82
83/// Extract semantic concepts from a document, paper, or image using the Claude API.
84///
85/// # Arguments
86/// * `path` — the file path (used for source_file metadata)
87/// * `content` — the text content to analyse
88/// * `file_type` — one of `"document"`, `"paper"`, or `"image"`
89/// * `api_key` — Anthropic API key
90///
91/// # Errors
92/// Returns an error if the HTTP request fails or the response cannot be parsed.
93pub async fn extract_semantic(
94    path: &Path,
95    content: &str,
96    file_type: &str,
97    api_key: &str,
98) -> Result<ExtractionResult> {
99    let file_str = path.to_string_lossy();
100    let system_prompt = build_system_prompt(file_type);
101    let user_prompt = build_user_prompt(content, file_type);
102
103    debug!("sending semantic extraction request for {}", file_str);
104
105    let request_body = MessageRequest {
106        model: "claude-sonnet-4-20250514".to_string(),
107        max_tokens: 4096,
108        messages: vec![Message {
109            role: "user".to_string(),
110            content: user_prompt,
111        }],
112        system: system_prompt,
113    };
114
115    let client = reqwest::Client::new();
116    let response = client
117        .post("https://api.anthropic.com/v1/messages")
118        .header("x-api-key", api_key)
119        .header("anthropic-version", "2023-06-01")
120        .header("content-type", "application/json")
121        .json(&request_body)
122        .send()
123        .await
124        .context("failed to send request to Claude API")?;
125
126    if !response.status().is_success() {
127        let status = response.status();
128        let body = response.text().await.unwrap_or_default();
129        anyhow::bail!("Claude API returned {status}: {body}");
130    }
131
132    let msg: MessageResponse = response
133        .json()
134        .await
135        .context("failed to parse Claude API response")?;
136
137    let text = msg
138        .content
139        .first()
140        .and_then(|b| b.text.as_deref())
141        .unwrap_or("{}");
142
143    parse_semantic_response(text, &file_str)
144}
145
146// ---------------------------------------------------------------------------
147// Prompt construction
148// ---------------------------------------------------------------------------
149
/// Build the system prompt that tells the model which JSON shape to return.
///
/// `file_type` is interpolated verbatim into the sentence describing the input.
fn build_system_prompt(file_type: &str) -> String {
    [
        "You are an expert knowledge-graph extraction engine. ",
        "Given a ",
        file_type,
        ", extract entities and their relationships. ",
        "Respond ONLY with a JSON object having two arrays: ",
        "\"entities\" (each with \"name\" and \"entity_type\") and ",
        "\"relationships\" (each with \"source\", \"target\", and \"relation\"). ",
        "Entity types should be one of: concept, class, function, module, paper, image. ",
        "Keep entity names concise and unique.",
    ]
    .concat()
}
161
/// Build the user message: a one-line instruction followed by the content.
///
/// Content longer than 100 000 bytes is cut at the nearest UTF-8 character
/// boundary at or below that limit, so a multi-byte character is never split.
fn build_user_prompt(content: &str, file_type: &str) -> String {
    const LIMIT: usize = 100_000;

    let body = if content.len() <= LIMIT {
        content
    } else {
        // Index 0 is always a boundary, so the search cannot come up empty.
        let cut = (0..=LIMIT)
            .rev()
            .find(|&idx| content.is_char_boundary(idx))
            .unwrap_or(0);
        &content[..cut]
    };

    format!("Extract all entities and relationships from this {file_type}:\n\n{body}")
}
177
178// ---------------------------------------------------------------------------
179// Response parsing
180// ---------------------------------------------------------------------------
181
182fn parse_semantic_response(text: &str, file_str: &str) -> Result<ExtractionResult> {
183    // Try to find JSON in the response (might be wrapped in markdown fences)
184    let json_str = extract_json_block(text);
185
186    let output: SemanticOutput =
187        serde_json::from_str(json_str).context("failed to parse semantic extraction JSON")?;
188
189    let mut nodes = Vec::new();
190    let mut edges = Vec::new();
191
192    // Convert entities to nodes
193    let mut name_to_id: HashMap<String, String> = HashMap::new();
194    for entity in &output.entities {
195        let id = make_id(&[file_str, &entity.name]);
196        let node_type = match entity.entity_type.as_str() {
197            "class" => NodeType::Class,
198            "function" => NodeType::Function,
199            "module" => NodeType::Module,
200            "paper" => NodeType::Paper,
201            "image" => NodeType::Image,
202            _ => NodeType::Concept,
203        };
204        name_to_id.insert(entity.name.clone(), id.clone());
205        nodes.push(GraphNode {
206            id,
207            label: entity.name.clone(),
208            source_file: file_str.to_string(),
209            source_location: None,
210            node_type,
211            community: None,
212            extra: HashMap::new(),
213        });
214    }
215
216    // Convert relationships to edges
217    for rel in &output.relationships {
218        let source_id = name_to_id
219            .get(&rel.source)
220            .cloned()
221            .unwrap_or_else(|| make_id(&[file_str, &rel.source]));
222        let target_id = name_to_id
223            .get(&rel.target)
224            .cloned()
225            .unwrap_or_else(|| make_id(&[file_str, &rel.target]));
226
227        edges.push(GraphEdge {
228            source: source_id,
229            target: target_id,
230            relation: rel.relation.clone(),
231            confidence: Confidence::Inferred,
232            confidence_score: Confidence::Inferred.default_score(),
233            source_file: file_str.to_string(),
234            source_location: None,
235            weight: 1.0,
236            extra: HashMap::new(),
237        });
238    }
239
240    Ok(ExtractionResult {
241        nodes,
242        edges,
243        hyperedges: Vec::new(),
244    })
245}
246
/// Extract a JSON block from text that might be wrapped in markdown fences.
///
/// Candidates are tried in this order, and the first match wins:
/// 1. the trimmed contents of a ```` ```json ```` fenced block,
/// 2. the trimmed contents of any ```` ``` ```` fenced block,
/// 3. the span from the first `{` to the last `}` (inclusive),
/// 4. the whole input, trimmed.
///
/// A fence opener without a matching closing fence is ignored. The returned
/// slice always borrows from `text`; nothing is allocated.
fn extract_json_block(text: &str) -> &str {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    // Prefer an explicitly tagged JSON fence, then fall back to any fence.
    for opener in [JSON_FENCE, FENCE] {
        if let Some(start) = text.find(opener) {
            let after = &text[start + opener.len()..];
            if let Some(end) = after.find(FENCE) {
                return after[..end].trim();
            }
        }
    }

    // No usable fence: take the outermost brace pair. The `start < end`
    // guard matters because the last `}` can precede the first `{`
    // (e.g. "} {"); slicing that inverted range would panic.
    match (text.find('{'), text.rfind('}')) {
        (Some(start), Some(end)) if start < end => &text[start..=end],
        _ => text.trim(),
    }
}
271
272// ---------------------------------------------------------------------------
273// Tests
274// ---------------------------------------------------------------------------
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Three concept entities and two relationships become three concept
    /// nodes and two edges, in input order.
    #[test]
    fn parse_semantic_json() {
        let payload = r#"{
            "entities": [
                {"name": "Machine Learning", "entity_type": "concept"},
                {"name": "Neural Network", "entity_type": "concept"},
                {"name": "Backpropagation", "entity_type": "concept"}
            ],
            "relationships": [
                {"source": "Neural Network", "target": "Machine Learning", "relation": "is_a"},
                {"source": "Backpropagation", "target": "Neural Network", "relation": "used_by"}
            ]
        }"#;

        let extraction = parse_semantic_response(payload, "paper.pdf").unwrap();
        assert_eq!(extraction.nodes.len(), 3);
        assert_eq!(extraction.edges.len(), 2);
        for node in &extraction.nodes {
            assert_eq!(node.node_type, NodeType::Concept);
        }
        assert_eq!(extraction.edges[0].relation, "is_a");
    }

    /// JSON inside a fenced markdown block, preceded by prose, is still parsed.
    #[test]
    fn parse_markdown_wrapped_json() {
        let reply = r#"Here is the extraction:
```json
{
    "entities": [{"name": "Foo", "entity_type": "class"}],
    "relationships": []
}
```
"#;
        let extraction = parse_semantic_response(reply, "doc.md").unwrap();
        assert_eq!(extraction.nodes.len(), 1);

        let only = &extraction.nodes[0];
        assert_eq!(only.label, "Foo");
        assert_eq!(only.node_type, NodeType::Class);
    }

    /// Empty arrays produce an empty extraction rather than an error.
    #[test]
    fn parse_empty_response() {
        let payload = r#"{"entities": [], "relationships": []}"#;
        let extraction = parse_semantic_response(payload, "empty.txt").unwrap();
        assert!(extraction.nodes.is_empty());
        assert!(extraction.edges.is_empty());
    }

    /// A bare JSON object passes through unchanged.
    #[test]
    fn extract_json_block_plain() {
        let bare = r#"{"a": 1}"#;
        assert_eq!(extract_json_block(bare), bare);
    }

    /// A fenced JSON block is unwrapped and trimmed.
    #[test]
    fn extract_json_block_fenced() {
        let fenced = "blah\n```json\n{\"a\": 1}\n```\nmore";
        assert_eq!(extract_json_block(fenced), r#"{"a": 1}"#);
    }

    /// Edges produced by this pass are marked as inferred.
    #[test]
    fn semantic_edges_are_inferred_confidence() {
        let payload = r#"{
            "entities": [
                {"name": "A", "entity_type": "concept"},
                {"name": "B", "entity_type": "concept"}
            ],
            "relationships": [
                {"source": "A", "target": "B", "relation": "depends_on"}
            ]
        }"#;
        let extraction = parse_semantic_response(payload, "test.md").unwrap();
        assert_eq!(extraction.edges[0].confidence, Confidence::Inferred);
    }

    /// Both prompts mention the file type; the user prompt embeds the content.
    #[test]
    fn build_prompts_contain_file_type() {
        let system_prompt = build_system_prompt("paper");
        assert!(system_prompt.contains("paper"));

        let user_prompt = build_user_prompt("hello world", "document");
        assert!(user_prompt.contains("document"));
        assert!(user_prompt.contains("hello world"));
    }
}