// graphify_extract/semantic/openai_compat.rs

1use std::path::Path;
2
3use anyhow::{Context, Result};
4use graphify_core::model::ExtractionResult;
5use serde::{Deserialize, Serialize};
6
7use super::provider::LLMProvider;
8
9#[derive(Serialize)]
10struct ChatRequest {
11    model: String,
12    max_tokens: u32,
13    messages: Vec<ChatMessage>,
14}
15
16#[derive(Serialize, Deserialize)]
17struct ChatMessage {
18    role: String,
19    content: String,
20}
21
22#[derive(Deserialize)]
23struct ChatResponse {
24    choices: Vec<ChatChoice>,
25}
26
27#[derive(Deserialize)]
28struct ChatChoice {
29    message: ChatMessageResponse,
30}
31
32#[derive(Deserialize)]
33struct ChatMessageResponse {
34    content: Option<String>,
35}
36
37pub async fn extract_openai_compatible(
38    path: &Path,
39    content: &str,
40    file_type: &str,
41    provider: LLMProvider,
42    model: &str,
43    api_key: Option<&str>,
44    base_url: &str,
45) -> Result<ExtractionResult> {
46    let file_str = path.to_string_lossy();
47    let system_prompt = super::build_system_prompt(file_type);
48    let user_prompt = super::build_user_prompt(content, file_type);
49
50    let request_body = ChatRequest {
51        model: model.to_string(),
52        max_tokens: 4096,
53        messages: vec![
54            ChatMessage {
55                role: "system".to_string(),
56                content: system_prompt,
57            },
58            ChatMessage {
59                role: "user".to_string(),
60                content: user_prompt,
61            },
62        ],
63    };
64
65    let client = reqwest::Client::new();
66    let mut request = client
67        .post(format!("{base_url}/chat/completions"))
68        .header("content-type", "application/json")
69        .json(&request_body);
70
71    if let Some(key) = api_key {
72        request = request.header("authorization", format!("Bearer {key}"));
73    }
74
75    let response = request.send().await.with_context(|| {
76        format!("Cannot connect to {base_url}. Make sure the server is running.")
77    })?;
78
79    if response.status().as_u16() == 401 {
80        match provider {
81            LLMProvider::OpenAI => {
82                anyhow::bail!(
83                    "OpenAI API key invalid. Set OPENAI_API_KEY or configure in graphify.toml."
84                );
85            }
86            _ => {
87                anyhow::bail!(
88                    "Authentication failed for {base_url}. Check your API key in graphify.toml."
89                );
90            }
91        }
92    }
93
94    if response.status().as_u16() == 404 {
95        match provider {
96            LLMProvider::Ollama => {
97                anyhow::bail!("Model '{model}' not found. Run: ollama pull {model}");
98            }
99            LLMProvider::OpenAI => {
100                anyhow::bail!(
101                    "Model '{model}' not found. Check available models at platform.openai.com"
102                );
103            }
104            _ => {
105                anyhow::bail!(
106                    "Model '{model}' not found at {base_url}. Check that the model is available."
107                );
108            }
109        }
110    }
111
112    if !response.status().is_success() {
113        let status = response.status();
114        let body = response.text().await.unwrap_or_default();
115        anyhow::bail!("LLM API at {base_url} returned {status}: {body}");
116    }
117
118    let chat_resp: ChatResponse = response
119        .json()
120        .await
121        .context("failed to parse LLM API response")?;
122
123    let text = chat_resp
124        .choices
125        .first()
126        .and_then(|c| c.message.content.as_deref())
127        .unwrap_or("{}");
128
129    super::parse_semantic_response(text, &file_str)
130}