Skip to main content

graphify_extract/semantic/
mod.rs

1//! Semantic extraction via LLM APIs (Pass 2).
2//!
3//! Supports multiple LLM providers through a dual-path architecture:
4//! - Anthropic (Messages API + OAuth token support)
5//! - OpenAI-compatible (Chat Completions API: OpenAI, Ollama, vLLM, etc.)
6
7pub mod anthropic;
8pub mod anthropic_oauth;
9pub mod openai_compat;
10pub mod provider;
11
12use std::collections::HashMap;
13use std::path::Path;
14
15use anyhow::{Context, Result};
16use graphify_core::confidence::Confidence;
17use graphify_core::id::make_id;
18use graphify_core::model::{ExtractionResult, GraphEdge, GraphNode, NodeType};
19use serde::Deserialize;
20
21pub use provider::{AuthType, LLMConfigRaw, LLMProvider, LLMProviderConfig};
22
23/// Entities and relationships extracted by the LLM.
24#[derive(Deserialize, Debug)]
25struct SemanticOutput {
26    #[serde(default)]
27    entities: Vec<SemanticEntity>,
28    #[serde(default)]
29    relationships: Vec<SemanticRelation>,
30}
31
32#[derive(Deserialize, Debug)]
33struct SemanticEntity {
34    name: String,
35    #[serde(default = "default_entity_type")]
36    entity_type: String,
37}
38
/// Serde default for `SemanticEntity::entity_type`.
fn default_entity_type() -> String {
    String::from("concept")
}
42
43#[derive(Deserialize, Debug)]
44struct SemanticRelation {
45    source: String,
46    target: String,
47    #[serde(default = "default_relation")]
48    relation: String,
49}
50
/// Serde default for `SemanticRelation::relation`.
fn default_relation() -> String {
    String::from("related_to")
}
54
55/// Extract semantic concepts from a document, paper, or image using an LLM.
56///
57/// Dispatches to the appropriate provider based on `config.provider`.
58pub async fn extract_semantic(
59    path: &Path,
60    content: &str,
61    file_type: &str,
62    config: &LLMProviderConfig,
63) -> Result<ExtractionResult> {
64    match config.provider {
65        LLMProvider::Anthropic => {
66            anthropic::extract_anthropic(path, content, file_type, config).await
67        }
68        LLMProvider::OpenAI | LLMProvider::Ollama | LLMProvider::OpenAICompatible => {
69            openai_compat::extract_openai_compatible(
70                path,
71                content,
72                file_type,
73                config.provider.clone(),
74                &config.model,
75                config.api_key.as_deref(),
76                &config.base_url,
77            )
78            .await
79        }
80    }
81}
82
/// Builds the system prompt asking the model for a JSON object with
/// `entities` and `relationships` arrays extracted from a `file_type`.
///
/// The entity types listed in the prompt are the ones the response parser in
/// this module maps to specific node types; anything else becomes a concept.
fn build_system_prompt(file_type: &str) -> String {
    const PREAMBLE: &str = "You are an expert knowledge-graph extraction engine. Given a ";
    const INSTRUCTIONS: &str = ", extract entities and their relationships. \
        Respond ONLY with a JSON object having two arrays: \
        \"entities\" (each with \"name\" and \"entity_type\") and \
        \"relationships\" (each with \"source\", \"target\", and \"relation\"). \
        Entity types should be one of: concept, class, function, module, paper, image. \
        Keep entity names concise and unique.";

    [PREAMBLE, file_type, INSTRUCTIONS].concat()
}
94
/// Builds the user message carrying the document body.
///
/// Content longer than 100 000 bytes is cut at the last UTF-8 character
/// boundary at or before that limit. A note is then appended telling the model
/// that only a prefix is shown.
fn build_user_prompt(content: &str, file_type: &str) -> String {
    const LIMIT_BYTES: usize = 100_000;

    let (visible, note) = if content.len() <= LIMIT_BYTES {
        (content, "")
    } else {
        // Index 0 is always a char boundary, so the search always succeeds.
        let cut = (0..=LIMIT_BYTES)
            .rev()
            .find(|&idx| content.is_char_boundary(idx))
            .unwrap_or(0);
        (
            &content[..cut],
            "\n\n[NOTE: This file was truncated — only the first portion is shown. Extract entities only from the visible portion.]",
        )
    };

    format!("Extract all entities and relationships from this {file_type}:\n\n{visible}{note}")
}
115
116fn parse_semantic_response(text: &str, file_str: &str) -> Result<ExtractionResult> {
117    let json_str = extract_json_block(text);
118
119    let output: SemanticOutput =
120        serde_json::from_str(json_str).context("failed to parse semantic extraction JSON")?;
121
122    let mut nodes = Vec::new();
123    let mut edges = Vec::new();
124
125    let mut name_to_id: HashMap<String, String> = HashMap::new();
126    for entity in &output.entities {
127        let id = make_id(&[file_str, &entity.name]);
128        let node_type = match entity.entity_type.as_str() {
129            "class" => NodeType::Class,
130            "function" => NodeType::Function,
131            "module" => NodeType::Module,
132            "paper" => NodeType::Paper,
133            "image" => NodeType::Image,
134            _ => NodeType::Concept,
135        };
136        name_to_id.insert(entity.name.clone(), id.clone());
137        nodes.push(GraphNode {
138            id,
139            label: entity.name.clone(),
140            source_file: file_str.to_string(),
141            source_location: None,
142            node_type,
143            community: None,
144            extra: HashMap::new(),
145        });
146    }
147
148    for rel in &output.relationships {
149        let source_id = name_to_id
150            .get(&rel.source)
151            .cloned()
152            .unwrap_or_else(|| make_id(&[file_str, &rel.source]));
153        let target_id = name_to_id
154            .get(&rel.target)
155            .cloned()
156            .unwrap_or_else(|| make_id(&[file_str, &rel.target]));
157
158        edges.push(GraphEdge {
159            source: source_id,
160            target: target_id,
161            relation: rel.relation.clone(),
162            confidence: Confidence::Inferred,
163            confidence_score: Confidence::Inferred.default_score(),
164            source_file: file_str.to_string(),
165            source_location: None,
166            weight: 1.0,
167            provenance: Some("llm:semantic".to_string()),
168            extra: HashMap::new(),
169        });
170    }
171
172    Ok(ExtractionResult {
173        nodes,
174        edges,
175        hyperedges: Vec::new(),
176    })
177}
178
/// Extract the JSON payload from an LLM response.
///
/// Candidates are tried in this order:
/// 1. The body of a Markdown code fence. A fence tagged `json` is preferred,
///    matched ignoring ASCII case, so `JSON` and `jsonc` also qualify.
///    Otherwise the first complete fence is used. The fence tag is never part
///    of the result.
/// 2. The span from the first `{` to the last `}`, when the `{` comes first.
/// 3. The whole text, trimmed.
///
/// The result always borrows from `text`. This never panics, even on
/// malformed output such as `"} ... {"`.
fn extract_json_block(text: &str) -> &str {
    if let Some(body) = fenced_block(text) {
        return body;
    }
    if let (Some(start), Some(end)) = (text.find('{'), text.rfind('}')) {
        // Guard against a `}` that precedes every `{`; slicing would panic.
        if start < end {
            return &text[start..=end];
        }
    }
    text.trim()
}

/// Return the trimmed body of a Markdown code fence in `text`, if any.
///
/// A fence whose tag starts with `json` (ASCII case-insensitive) wins.
/// Otherwise the first complete fence is returned. The tag is the run of ASCII
/// alphanumerics right after the opening triple backtick. An unclosed fence
/// stops the search.
fn fenced_block(text: &str) -> Option<&str> {
    const FENCE: &str = "```";

    let mut first_fence: Option<&str> = None;
    let mut cursor = 0;
    while let Some(offset) = text[cursor..].find(FENCE) {
        let open_end = cursor + offset + FENCE.len();
        let tag_len = text[open_end..]
            .find(|c: char| !c.is_ascii_alphanumeric())
            .unwrap_or(text.len() - open_end);
        let tag = &text[open_end..open_end + tag_len];
        let body_start = open_end + tag_len;
        let close = match text[body_start..].find(FENCE) {
            Some(rel) => body_start + rel,
            // Unclosed fence, e.g. a reply cut off by the token limit.
            None => break,
        };
        let body = text[body_start..close].trim();
        if tag.get(..4).is_some_and(|p| p.eq_ignore_ascii_case("json")) {
            return Some(body);
        }
        if first_fence.is_none() {
            first_fence = Some(body);
        }
        cursor = close + FENCE.len();
    }
    first_fence
}
200
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_semantic_json() {
        let json = r#"{
            "entities": [
                {"name": "Machine Learning", "entity_type": "concept"},
                {"name": "Neural Network", "entity_type": "concept"},
                {"name": "Backpropagation", "entity_type": "concept"}
            ],
            "relationships": [
                {"source": "Neural Network", "target": "Machine Learning", "relation": "is_a"},
                {"source": "Backpropagation", "target": "Neural Network", "relation": "used_by"}
            ]
        }"#;

        let result = parse_semantic_response(json, "paper.pdf").unwrap();
        assert_eq!(result.nodes.len(), 3);
        assert_eq!(result.edges.len(), 2);
        assert!(
            result
                .nodes
                .iter()
                .all(|n| n.node_type == NodeType::Concept)
        );
        assert_eq!(result.edges[0].relation, "is_a");
    }

    #[test]
    fn parse_markdown_wrapped_json() {
        let text = r#"Here is the extraction:
```json
{
    "entities": [{"name": "Foo", "entity_type": "class"}],
    "relationships": []
}
```
"#;
        let result = parse_semantic_response(text, "doc.md").unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].label, "Foo");
        assert_eq!(result.nodes[0].node_type, NodeType::Class);
    }

    #[test]
    fn parse_empty_response() {
        let json = r#"{"entities": [], "relationships": []}"#;
        let result = parse_semantic_response(json, "empty.txt").unwrap();
        assert!(result.nodes.is_empty());
        assert!(result.edges.is_empty());
    }

    #[test]
    fn missing_arrays_yield_empty_result() {
        let result = parse_semantic_response("{}", "empty.txt").unwrap();
        assert!(result.nodes.is_empty());
        assert!(result.edges.is_empty());
        assert!(result.hyperedges.is_empty());
    }

    #[test]
    fn missing_fields_use_defaults() {
        let json = r#"{
            "entities": [{"name": "X"}],
            "relationships": [{"source": "X", "target": "Y"}]
        }"#;
        let result = parse_semantic_response(json, "notes.md").unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].node_type, NodeType::Concept);
        assert_eq!(result.edges.len(), 1);
        assert_eq!(result.edges[0].relation, "related_to");
    }

    #[test]
    fn edge_endpoints_match_node_ids() {
        let json = r#"{
            "entities": [
                {"name": "A", "entity_type": "concept"},
                {"name": "B", "entity_type": "concept"}
            ],
            "relationships": [
                {"source": "A", "target": "B", "relation": "depends_on"}
            ]
        }"#;
        let result = parse_semantic_response(json, "test.md").unwrap();
        let id_of = |label: &str| {
            result
                .nodes
                .iter()
                .find(|n| n.label == label)
                .map(|n| n.id.clone())
                .unwrap()
        };
        assert_eq!(result.edges[0].source, id_of("A"));
        assert_eq!(result.edges[0].target, id_of("B"));
    }

    #[test]
    fn invalid_json_is_an_error() {
        assert!(parse_semantic_response("not json at all", "bad.txt").is_err());
    }

    #[test]
    fn extract_json_block_plain() {
        assert_eq!(extract_json_block(r#"{"a": 1}"#), r#"{"a": 1}"#);
    }

    #[test]
    fn extract_json_block_fenced() {
        let text = "blah\n```json\n{\"a\": 1}\n```\nmore";
        assert_eq!(extract_json_block(text), r#"{"a": 1}"#);
    }

    #[test]
    fn extract_json_block_strips_surrounding_prose() {
        let text = "Sure! {\"a\": 1} Hope that helps.";
        assert_eq!(extract_json_block(text), r#"{"a": 1}"#);
    }

    #[test]
    fn semantic_edges_are_inferred_confidence() {
        let json = r#"{
            "entities": [
                {"name": "A", "entity_type": "concept"},
                {"name": "B", "entity_type": "concept"}
            ],
            "relationships": [
                {"source": "A", "target": "B", "relation": "depends_on"}
            ]
        }"#;
        let result = parse_semantic_response(json, "test.md").unwrap();
        assert_eq!(result.edges[0].confidence, Confidence::Inferred);
    }

    #[test]
    fn build_prompts_contain_file_type() {
        let sys = build_system_prompt("paper");
        assert!(sys.contains("paper"));

        let user = build_user_prompt("hello world", "document");
        assert!(user.contains("document"));
        assert!(user.contains("hello world"));
    }

    #[test]
    fn user_prompt_truncates_on_char_boundary() {
        // 99_999 ASCII bytes followed by a 2-byte char that straddles the
        // 100_000-byte limit; the cut must land before it.
        let content = format!("{}é", "a".repeat(99_999));
        let prompt = build_user_prompt(&content, "document");
        assert!(prompt.contains("[NOTE: This file was truncated"));
        assert!(!prompt.contains('é'));
    }
}