// mdmodels_core/llm/extraction.rs
1/*
2 * Copyright (c) 2025 Jan Range
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice shall be included in
12 * all copies or substantial portions of the Software.
13 *
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 * THE SOFTWARE.
21 *
22 */
23
24//! This module provides functionality for extracting structured data from text using LLM APIs.
25//! It handles the communication with OpenAI's API, formatting requests and parsing responses
26//! according to a specified data model schema.
27
28use std::env;
29
30use openai_api_rs::v1::{api::OpenAIClient, chat_completion};
31use serde_json::{json, Value};
32
33use crate::{datamodel::DataModel, json::export::to_json_schema};
34
35/// Queries the OpenAI API with a given prompt and pre-prompt, using a specified data model and root.
36///
37/// # Arguments
38///
39/// * `prompt` - The main prompt to send to the OpenAI API.
40/// * `pre_prompt` - An additional pre-prompt to provide context or setup for the main prompt.
41/// * `data_model` - The data model used to generate the JSON schema for the response format.
42/// * `root` - The root name for the JSON schema.
43/// * `model` - The OpenAI model to use for the chat completion.
44/// * `multiple` - Whether to extract multiple objects.
45/// * `api_key` - Optional API key for OpenAI. If None, will try to read from environment variable.
46///
47/// # Returns
48///
49/// A `Result` containing a `serde_json::Value` with the parsed JSON response from the OpenAI API, or an error if the operation fails.
50///
51/// # Errors
52///
53/// This function will return an error if:
54/// - The JSON schema cannot be generated from the data model
55/// - The OpenAI API key is not provided and not found in environment variables
56/// - The API request fails
57/// - The response cannot be parsed as valid JSON
58pub async fn query_openai(
59    prompt: &str,
60    pre_prompt: &str,
61    data_model: &DataModel,
62    root: &str,
63    model: &str,
64    multiple: bool,
65    api_key: Option<String>,
66) -> Result<Value, Box<dyn std::error::Error>> {
67    let response_format = prepare_response_format(data_model, root, multiple)?;
68    let mut client = prepare_client(api_key)?;
69    let messages = vec![create_chat_message(pre_prompt), create_chat_message(prompt)];
70    let model_type = ModelType::from_str(model)?;
71
72    let req = match model_type {
73        ModelType::Reasoning => {
74            chat_completion::ChatCompletionRequest::new(model.to_string(), messages)
75                .response_format(response_format)
76        }
77        ModelType::Generation => {
78            chat_completion::ChatCompletionRequest::new(model.to_string(), messages)
79                .response_format(response_format)
80                .temperature(0.0)
81        }
82    };
83
84    let result = client.chat_completion(req).await?;
85    let content = result
86        .choices
87        .first()
88        .and_then(|choice| choice.message.content.as_ref())
89        .ok_or_else(|| format!("No content in response from {model}"))?;
90
91    Ok(serde_json::from_str(content)?)
92}
93
94/// Prepares the response format for the OpenAI API request based on the data model.
95///
96/// # Arguments
97///
98/// * `model` - The data model used to generate the JSON schema.
99/// * `root` - The root name for the JSON schema.
100/// * `multiple` - Whether to prepare a format for multiple objects (array) or a single object.
101///
102/// # Returns
103///
104/// A `Result` containing a `serde_json::Value` with the prepared response format, or an error if the operation fails.
105fn prepare_response_format(
106    model: &DataModel,
107    root: &str,
108    multiple: bool,
109) -> Result<Value, Box<dyn std::error::Error>> {
110    let mut schema = to_json_schema(model, root, true)?;
111
112    if multiple {
113        let definitions = schema.definitions.clone();
114        schema.definitions.clear();
115
116        Ok(json!(
117            {
118                "type": "json_schema",
119                "json_schema": {
120                    "strict": true,
121                    "name": root,
122                    "schema": {
123                        "type": "object",
124                        "additionalProperties": false,
125                        "required": ["items"],
126                        "properties": {
127                            "items": {
128                                "type": "array",
129                                "items": schema
130                            }
131                        },
132                        "$defs": definitions
133                    }
134                }
135            }
136        ))
137    } else {
138        Ok(json!(
139                {
140                    "type": "json_schema",
141                    "json_schema": {
142                        "name": root,
143                        "strict": true,
144                        "schema": schema
145                    }
146                }
147        ))
148    }
149}
150
151/// Prepares the OpenAI client with the provided API key or from environment variables.
152///
153/// # Arguments
154///
155/// * `api_key` - An optional API key for OpenAI. If None, will try to read from the OPENAI_API_KEY environment variable.
156///
157/// # Returns
158///
159/// A `Result` containing the configured `OpenAIClient`, or an error if the API key is not available.
160fn prepare_client(api_key: Option<String>) -> Result<OpenAIClient, Box<dyn std::error::Error>> {
161    let api_key = match api_key {
162        Some(api_key) => api_key,
163        None => env::var("OPENAI_API_KEY")?,
164    };
165
166    OpenAIClient::builder().with_api_key(api_key).build()
167}
168
169/// Creates a chat message for the OpenAI API with the specified content.
170///
171/// # Arguments
172///
173/// * `content` - The text content of the message.
174///
175/// # Returns
176///
177/// A `ChatCompletionMessage` configured with the user role and provided content.
178fn create_chat_message(content: &str) -> chat_completion::ChatCompletionMessage {
179    chat_completion::ChatCompletionMessage {
180        role: chat_completion::MessageRole::user,
181        content: chat_completion::Content::Text(content.to_string()),
182        name: None,
183        tool_calls: None,
184        tool_call_id: None,
185    }
186}
187
/// Represents the type of OpenAI model being used, which affects request configuration.
enum ModelType {
    /// Reasoning-series models (e.g. o1, o3-mini), which are sent without a temperature.
    Reasoning,
    /// Regular generation models (e.g. gpt-4o, gpt-3.5-turbo), pinned to temperature 0.0.
    Generation,
}

impl ModelType {
    /// Determines the model type from a model name string.
    ///
    /// A model is treated as a reasoning model when its name starts with `o`
    /// immediately followed by a digit (`o1`, `o3-mini`, ...). The previous
    /// unanchored regex `o\d*` matched any name merely containing an `o`
    /// (e.g. `gpt-3.5-turbo`, `gpt-4o`), misclassifying generation models and
    /// silently skipping the deterministic temperature setting.
    ///
    /// # Arguments
    ///
    /// * `s` - The model name string to parse.
    ///
    /// # Returns
    ///
    /// A `Result` containing the determined `ModelType`. The `Result` wrapper is
    /// kept for interface stability; classification itself cannot fail.
    fn from_str(s: &str) -> Result<Self, String> {
        let mut chars = s.chars();
        let is_reasoning =
            chars.next() == Some('o') && chars.next().is_some_and(|c| c.is_ascii_digit());

        if is_reasoning {
            Ok(ModelType::Reasoning)
        } else {
            Ok(ModelType::Generation)
        }
    }
}