1use crate::api_key;
2use crate::document::DocumentData;
3use crate::error::{Error, Result};
4use crate::file_info::FileInfo;
5use crate::profile::ChatGptProfile;
6use openai_api_rs::v1::api::OpenAIClient;
7use openai_api_rs::v1::chat_completion::{
8 ChatCompletionMessage, ChatCompletionRequest, MessageRole, Tool, ToolChoiceType,
9};
10use openai_api_rs::v1::chat_completion::{Content, ContentType, ImageUrl, ImageUrlType};
11use tokio::time::{timeout, Duration};
12use serde_json::json;
13
14fn default_tools() -> Vec<Tool> {
15 vec![serde_json::from_value(json!({
16 "type": "function",
17 "function": {
18 "name": "return_document_data",
19 "description": "Please use this function to return the transcribed content \
20 of the document, your summary of the content, \
21 your classification of the document, the source of the document, \
22 the keywords you assigned, \
23 the title you assigned and the date you determined.",
24 "parameters": {
25 "type": "object",
26 "properties": {
27 "content": {
28 "type": "string",
29 "description": "The contents of the document"
30 },
31 "summary": {
32 "type": "string",
33 "description": "Your summary of the content"
34 },
35 "class": {
36 "type": "string",
37 "description": "The class you assigned to the document"
38 },
39 "source": {
40 "type": "string",
41 "description": "The source you assigned to the document"
42 },
43 "keywords": {
44 "type": "array",
45 "items": { "type": "string" },
46 "description": "The keywords assigned to the document"
47 },
48 "title": {
49 "type": "string",
50 "description": "The title assigned to the document"
51 },
52 "date": {
53 "type": "string",
54 "description": "The date assigned to the document in YYYY-MM-DD"
55 }
56 },
57 "required": [
58 "summary",
59 "class",
60 "keywords",
61 "title",
62 "date"
63 ]
64 }
65 }
66 }))
67 .unwrap()]
68}
69
/// Lists the items the model is asked to produce, one `* `-prefixed bullet per item.
///
/// The order matches the properties of the `return_document_data` tool: transcription
/// (content), summary, class, source, keywords, title and date.
fn make_outputs() -> Vec<String> {
    const OUTPUTS: [&str; 7] = [
        "* A transcription of the contents of the document. If the document is too large to provide a full transcription, you may omit this.",
        "* A summary of the content of the entire document.",
        "* A classification of the document. Please use rather broad and general concepts as classes. The class must be usable as part of a filename and must not contain whitespaces or non-ascii characters. Please favor using hyphens over underscores as separators. The grammatical number of the word used as class should be singular if possible.",
        "* The source of the document. This could be the author, creator, sender or issuer of the document. The source must be usable as part of a filename and must not contain whitespaces or non-ascii characters.",
        "* Between 2 and 4 keywords describing the content of the document.",
        "* A title describing the document. It should be sufficiently specific to differentiate this particular document from other documents of this class and source, but it should not duplicate words that are already found as class or source. The title must be usable as part of a filename and must not contain whitespaces or non-ascii characters.",
        "* A date to be associated with the document. Please favor the date when the document was issued over any other dates found.",
    ];

    OUTPUTS.iter().map(|line| line.to_string()).collect()
}
81
/// Builds the rules the model has to observe while producing its output, one `* ` bullet each.
///
/// The language rule is always included. If `classes` is non-empty, a bullet lists the existing
/// classes (comma-separated) so the model reuses one before inventing a new class. `sources` is
/// handled the same way for the document source.
fn make_specs(classes: Vec<String>, sources: Vec<String>) -> Vec<String> {
    let mut result = vec![
        "* Please make sure that the language of all outputs matches the language of the input document.".to_string(),
    ];
    if !classes.is_empty() {
        result.push(format!(
            "* When choosing the class of the document, check if any of these classes match before creating a new one: {}",
            classes.join(", ")
        ));
    }
    if !sources.is_empty() {
        result.push(format!(
            "* When choosing the source of the document, check if any of these sources match before creating a new one: {}",
            sources.join(", ")
        ));
    }

    result
}
95
96fn make_instructions(classes: Vec<String>, sources: Vec<String>) -> Vec<ChatCompletionMessage> {
97 let outputs = make_outputs().join("\n");
98 let specs = make_specs(classes, sources).join("\n");
99 vec![serde_json::from_value(json!({
100 "role": "system",
101 "content": format!("You will be given a scan of a document. It may consist of one or more pages. You shall provide as output in the language of the document:\n{outputs}\n\nWhen producing the output, you shall observe the following points:\n{specs}\n"),
102 })).unwrap()]
103}
104
105pub async fn query_ai(
106 profile: ChatGptProfile,
107 file_info: FileInfo,
108 classes: Vec<String>,
109 sources: Vec<String>,
110) -> Result<DocumentData> {
111 log::info!("Received {file_info:?}");
112 let api_key = api_key::get();
113 let client = OpenAIClient::builder()
114 .with_api_key(api_key)
115 .build()
116 .map_err(|_| Error::NoApiKeyError)?;
117 let files: Vec<String> = file_info
118 .base64()
119 .await?
120 .into_iter()
121 .map(|data| format!("data:{};base64,{}", file_info.mime_type(), data))
122 .collect();
123
124 let tools = default_tools();
125 let mut messages = make_instructions(classes, sources);
126 for instr in profile.additional_instructions {
127 messages.push(ChatCompletionMessage {
128 role: MessageRole::system,
129 content: Content::Text(instr),
130 name: None,
131 tool_calls: None,
132 tool_call_id: None,
133 });
134 }
135 log::debug!("Using instructions: {messages:?}");
136 for file in files {
137 messages.push(ChatCompletionMessage {
138 role: MessageRole::user,
139 content: Content::ImageUrl(vec![ImageUrl {
140 r#type: ContentType::image_url,
141 text: None,
142 image_url: Some(ImageUrlType { url: file }),
143 }]),
144 name: None,
145 tool_calls: None,
146 tool_call_id: None,
147 });
148 }
149 let req = ChatCompletionRequest::new(profile.model, messages)
150 .temperature(<u8 as Into<f64>>::into(profile.temperature) / 100.0)
151 .tools(tools)
152 .tool_choice(ToolChoiceType::Required);
153 log::info!("Sending {file_info:?}");
154 let response = timeout(Duration::from_secs(300), client.chat_completion(req)).await??;
155 log::trace!("received response");
156 let result_value: Result<serde_json::Value> = (|| {
157 Ok(serde_json::from_str(
158 &response
159 .choices
160 .first()
161 .ok_or_else(|| Error::DoesNotProcessError(None))?
162 .message
163 .tool_calls
164 .as_ref()
165 .ok_or_else(|| Error::DoesNotProcessError(None))?
166 .first()
167 .ok_or_else(|| Error::DoesNotProcessError(None))?
168 .function
169 .arguments
170 .as_ref()
171 .ok_or_else(|| Error::DoesNotProcessError(None))?,
172 )?)
173 })()
174 .map_err(|err| {
175 if let Error::DoesNotProcessError(_) = err {
176 Error::DoesNotProcessError(Some(response))
177 } else {
178 err
179 }
180 });
181 let result: Result<DocumentData> = Ok(serde_json::from_value(result_value?)?);
182
183 result
184}