mdmodels_core/llm/extraction.rs
1/*
2 * Copyright (c) 2025 Jan Range
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice shall be included in
12 * all copies or substantial portions of the Software.
13 *
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 * THE SOFTWARE.
21 *
22 */
23
24//! This module provides functionality for extracting structured data from text using LLM APIs.
25//! It handles the communication with OpenAI's API, formatting requests and parsing responses
26//! according to a specified data model schema.
27
28use std::env;
29
30use openai_api_rs::v1::{api::OpenAIClient, chat_completion};
31use serde_json::{json, Value};
32
33use crate::{datamodel::DataModel, json::export::to_json_schema};
34
35/// Queries the OpenAI API with a given prompt and pre-prompt, using a specified data model and root.
36///
37/// # Arguments
38///
39/// * `prompt` - The main prompt to send to the OpenAI API.
40/// * `pre_prompt` - An additional pre-prompt to provide context or setup for the main prompt.
41/// * `data_model` - The data model used to generate the JSON schema for the response format.
42/// * `root` - The root name for the JSON schema.
43/// * `model` - The OpenAI model to use for the chat completion.
44/// * `multiple` - Whether to extract multiple objects.
45/// * `api_key` - Optional API key for OpenAI. If None, will try to read from environment variable.
46///
47/// # Returns
48///
49/// A `Result` containing a `serde_json::Value` with the parsed JSON response from the OpenAI API, or an error if the operation fails.
50///
51/// # Errors
52///
53/// This function will return an error if:
54/// - The JSON schema cannot be generated from the data model
55/// - The OpenAI API key is not provided and not found in environment variables
56/// - The API request fails
57/// - The response cannot be parsed as valid JSON
58pub async fn query_openai(
59 prompt: &str,
60 pre_prompt: &str,
61 data_model: &DataModel,
62 root: &str,
63 model: &str,
64 multiple: bool,
65 api_key: Option<String>,
66) -> Result<Value, Box<dyn std::error::Error>> {
67 let response_format = prepare_response_format(data_model, root, multiple)?;
68 let mut client = prepare_client(api_key)?;
69 let messages = vec![create_chat_message(pre_prompt), create_chat_message(prompt)];
70 let model_type = ModelType::from_str(model)?;
71
72 let req = match model_type {
73 ModelType::Reasoning => {
74 chat_completion::ChatCompletionRequest::new(model.to_string(), messages)
75 .response_format(response_format)
76 }
77 ModelType::Generation => {
78 chat_completion::ChatCompletionRequest::new(model.to_string(), messages)
79 .response_format(response_format)
80 .temperature(0.0)
81 }
82 };
83
84 let result = client.chat_completion(req).await?;
85 let content = result
86 .choices
87 .first()
88 .and_then(|choice| choice.message.content.as_ref())
89 .ok_or_else(|| format!("No content in response from {model}"))?;
90
91 Ok(serde_json::from_str(content)?)
92}
93
94/// Prepares the response format for the OpenAI API request based on the data model.
95///
96/// # Arguments
97///
98/// * `model` - The data model used to generate the JSON schema.
99/// * `root` - The root name for the JSON schema.
100/// * `multiple` - Whether to prepare a format for multiple objects (array) or a single object.
101///
102/// # Returns
103///
104/// A `Result` containing a `serde_json::Value` with the prepared response format, or an error if the operation fails.
105fn prepare_response_format(
106 model: &DataModel,
107 root: &str,
108 multiple: bool,
109) -> Result<Value, Box<dyn std::error::Error>> {
110 let mut schema = to_json_schema(model, root, true)?;
111
112 if multiple {
113 let definitions = schema.definitions.clone();
114 schema.definitions.clear();
115
116 Ok(json!(
117 {
118 "type": "json_schema",
119 "json_schema": {
120 "strict": true,
121 "name": root,
122 "schema": {
123 "type": "object",
124 "additionalProperties": false,
125 "required": ["items"],
126 "properties": {
127 "items": {
128 "type": "array",
129 "items": schema
130 }
131 },
132 "$defs": definitions
133 }
134 }
135 }
136 ))
137 } else {
138 Ok(json!(
139 {
140 "type": "json_schema",
141 "json_schema": {
142 "name": root,
143 "strict": true,
144 "schema": schema
145 }
146 }
147 ))
148 }
149}
150
151/// Prepares the OpenAI client with the provided API key or from environment variables.
152///
153/// # Arguments
154///
155/// * `api_key` - An optional API key for OpenAI. If None, will try to read from the OPENAI_API_KEY environment variable.
156///
157/// # Returns
158///
159/// A `Result` containing the configured `OpenAIClient`, or an error if the API key is not available.
160fn prepare_client(api_key: Option<String>) -> Result<OpenAIClient, Box<dyn std::error::Error>> {
161 let api_key = match api_key {
162 Some(api_key) => api_key,
163 None => env::var("OPENAI_API_KEY")?,
164 };
165
166 OpenAIClient::builder().with_api_key(api_key).build()
167}
168
169/// Creates a chat message for the OpenAI API with the specified content.
170///
171/// # Arguments
172///
173/// * `content` - The text content of the message.
174///
175/// # Returns
176///
177/// A `ChatCompletionMessage` configured with the user role and provided content.
178fn create_chat_message(content: &str) -> chat_completion::ChatCompletionMessage {
179 chat_completion::ChatCompletionMessage {
180 role: chat_completion::MessageRole::user,
181 content: chat_completion::Content::Text(content.to_string()),
182 name: None,
183 tool_calls: None,
184 tool_call_id: None,
185 }
186}
187
/// Represents the type of OpenAI model being used, which affects request configuration.
enum ModelType {
    /// Reasoning models whose names start with "o" + digit (e.g. o1, o3-mini);
    /// these are sent requests without a temperature override (see `query_openai`).
    Reasoning,
    /// All other models (e.g. gpt-4o, gpt-3.5-turbo), which receive temperature 0.0.
    Generation,
}

impl ModelType {
    /// Determines the model type from a model name string.
    ///
    /// A model is classified as a reasoning model when its name begins with an
    /// "o" immediately followed by an ASCII digit ("o1", "o3-mini", ...).
    ///
    /// # Arguments
    ///
    /// * `s` - The model name string to parse.
    ///
    /// # Returns
    ///
    /// A `Result` containing the determined `ModelType`. Currently this never
    /// fails; the `Result` is kept so stricter validation can be added without
    /// changing the signature.
    fn from_str(s: &str) -> Result<Self, String> {
        // The previous check used the unanchored regex `o\d*`, which matches any
        // name merely *containing* an "o" (zero digits are allowed), so models
        // like "gpt-3.5-turbo" and "gpt-4o" were misclassified as reasoning
        // models. Anchoring to the start of the name (and requiring a digit)
        // fixes that, and a plain char check avoids recompiling a regex per call.
        let mut chars = s.chars();
        let is_reasoning =
            chars.next() == Some('o') && chars.next().is_some_and(|c| c.is_ascii_digit());

        if is_reasoning {
            Ok(ModelType::Reasoning)
        } else {
            Ok(ModelType::Generation)
        }
    }
}