mrdocument/
chatgpt.rs

1use crate::api_key;
2use crate::document::DocumentData;
3use crate::error::{Error, Result};
4use crate::file_info::FileInfo;
5use crate::profile::ChatGptProfile;
6use openai_api_rs::v1::api::OpenAIClient;
7use openai_api_rs::v1::chat_completion::{
8    ChatCompletionMessage, ChatCompletionRequest, MessageRole, Tool, ToolChoiceType,
9};
10use openai_api_rs::v1::chat_completion::{Content, ContentType, ImageUrl, ImageUrlType};
11use tokio::time::{timeout, Duration};
12use serde_json::json;
13
14fn default_tools() -> Vec<Tool> {
15    vec![serde_json::from_value(json!({
16        "type": "function",
17        "function": {
18            "name": "return_document_data",
19            "description": "Please use this function to return the transcribed content \
20                of the document, your summary of the content, \
21                your classification of the document, the source of the document, \
22                the keywords you assigned, \
23                the title you assigned and the date you determined.",
24            "parameters": {
25                "type": "object",
26                "properties": {
27                    "content": {
28                        "type": "string",
29                        "description": "The contents of the document"
30                    },
31                    "summary": {
32                        "type": "string",
33                        "description": "Your summary of the content"
34                    },
35                    "class": {
36                        "type": "string",
37                        "description": "The class you assigned to the document"
38                    },
39                    "source": {
40                        "type": "string",
41                        "description": "The source you assigned to the document"
42                    },
43                    "keywords": {
44                        "type": "array",
45                        "items": { "type": "string" },
46                        "description": "The keywords assigned to the document"
47                    },
48                    "title": {
49                        "type": "string",
50                        "description": "The title assigned to the document"
51                    },
52                    "date": {
53                        "type": "string",
54                        "description": "The date assigned to the document in YYYY-MM-DD"
55                    }
56                },
57                "required": [
58                    "summary",
59                    "class",
60                    "keywords",
61                    "title",
62                    "date"
63                ]
64            }
65        }
66    }))
67    .unwrap()]
68}
69
/// Returns the bullet points describing what the model must produce.
///
/// The entries are, in order: transcription, summary, classification, source,
/// keywords, title and date. Each entry is a single line starting with `* `.
fn make_outputs() -> Vec<String> {
    const OUTPUT_BULLETS: [&str; 7] = [
        "* A transcription of the contents of the document. If the document is too large to provide a full transcription, you may omit this.",
        "* A summary of the content of the entire document.",
        "* A classification of the document. Please use rather broad and general concepts as classes. The class must be usable as part of a filename and must not contain whitespaces or non-ascii characters. Please favor using hyphens over underscores as separators. The grammatical number of the word used as class should be singular if possible.",
        "* The source of the document. This could be the author, creator, sender or issuer of the document. The source must be usable as part of a filename and must not contain whitespaces or non-ascii characters.",
        "* Between 2 and 4 keywords describing the content of the document.",
        "* A title describing the document. It should be sufficiently specific to differentiate this particular document from other documents of this class and source, but it should not duplicate words that are already found as class or source. The title must be usable as part of a filename and must not contain whitespaces or non-ascii characters.",
        "* A date to be associated with the document. Please favor the date when the document was issued over any other dates found.",
    ];

    OUTPUT_BULLETS
        .iter()
        .map(|&bullet| bullet.to_string())
        .collect()
}
81
/// Returns the bullet points describing the rules the model must observe.
///
/// The first entry always asks for outputs in the language of the input
/// document. A hint listing the known `classes` follows when that list is not
/// empty, and a hint listing the known `sources` follows when that list is not
/// empty. Names are joined with `", "`.
fn make_specs(classes: Vec<String>, sources: Vec<String>) -> Vec<String> {
    let language_rule =
        "* Please make sure that the language of all outputs matches the language of the input document."
            .to_string();

    // Each hint is only present when there is something to list.
    let class_hint = (!classes.is_empty()).then(|| {
        format!("* When choosing the class of the document, check if any of these classes match before creating a new one: {}", classes.join(", "))
    });
    let source_hint = (!sources.is_empty()).then(|| {
        format!("* When choosing the source of the document, check if any of these sources match before creating a new one: {}", sources.join(", "))
    });

    std::iter::once(language_rule)
        .chain(class_hint)
        .chain(source_hint)
        .collect()
}
95
96fn make_instructions(classes: Vec<String>, sources: Vec<String>) -> Vec<ChatCompletionMessage> {
97    let outputs = make_outputs().join("\n");
98    let specs = make_specs(classes, sources).join("\n");
99    vec![serde_json::from_value(json!({
100        "role": "system",
101        "content": format!("You will be given a scan of a document. It may consist of one or more pages. You shall provide as output in the language of the document:\n{outputs}\n\nWhen producing the output, you shall observe the following points:\n{specs}\n"),
102    })).unwrap()]
103}
104
105pub async fn query_ai(
106    profile: ChatGptProfile,
107    file_info: FileInfo,
108    classes: Vec<String>,
109    sources: Vec<String>,
110) -> Result<DocumentData> {
111    log::info!("Received {file_info:?}");
112    let api_key = api_key::get();
113    let client = OpenAIClient::builder()
114        .with_api_key(api_key)
115        .build()
116        .map_err(|_| Error::NoApiKeyError)?;
117    let files: Vec<String> = file_info
118        .base64()
119        .await?
120        .into_iter()
121        .map(|data| format!("data:{};base64,{}", file_info.mime_type(), data))
122        .collect();
123
124    let tools = default_tools();
125    let mut messages = make_instructions(classes, sources);
126    for instr in profile.additional_instructions {
127        messages.push(ChatCompletionMessage {
128            role: MessageRole::system,
129            content: Content::Text(instr),
130            name: None,
131            tool_calls: None,
132            tool_call_id: None,
133        });
134    }
135    log::debug!("Using instructions: {messages:?}");
136    for file in files {
137        messages.push(ChatCompletionMessage {
138            role: MessageRole::user,
139            content: Content::ImageUrl(vec![ImageUrl {
140                r#type: ContentType::image_url,
141                text: None,
142                image_url: Some(ImageUrlType { url: file }),
143            }]),
144            name: None,
145            tool_calls: None,
146            tool_call_id: None,
147        });
148    }
149    let req = ChatCompletionRequest::new(profile.model, messages)
150        .temperature(<u8 as Into<f64>>::into(profile.temperature) / 100.0)
151        .tools(tools)
152        .tool_choice(ToolChoiceType::Required);
153    log::info!("Sending {file_info:?}");
154    let response = timeout(Duration::from_secs(300), client.chat_completion(req)).await??;
155    log::trace!("received response");
156    let result_value: Result<serde_json::Value> = (|| {
157        Ok(serde_json::from_str(
158            &response
159                .choices
160                .first()
161                .ok_or_else(|| Error::DoesNotProcessError(None))?
162                .message
163                .tool_calls
164                .as_ref()
165                .ok_or_else(|| Error::DoesNotProcessError(None))?
166                .first()
167                .ok_or_else(|| Error::DoesNotProcessError(None))?
168                .function
169                .arguments
170                .as_ref()
171                .ok_or_else(|| Error::DoesNotProcessError(None))?,
172        )?)
173    })()
174    .map_err(|err| {
175        if let Error::DoesNotProcessError(_) = err {
176            Error::DoesNotProcessError(Some(response))
177        } else {
178            err
179        }
180    });
181    let result: Result<DocumentData> = Ok(serde_json::from_value(result_value?)?);
182
183    result
184}