1use std::sync::Arc;
2
3use anyhow::{Error, Result};
4use async_openai::{config::Config, types::{ChatCompletionRequestMessageContentPartImageArgs, ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, ImageDetail, ImageUrlArgs, ResponseFormat}, Client};
5use base64::prelude::*;
6use futures::future::join_all;
7use image::DynamicImage;
8use rand::Rng;
9use serde_json::Value;
10
11use crate::vector::{Vector, VectorOperations};
12
/// Sampling configuration sent with every chat-completion request.
///
/// Fields are private; read them through the accessor methods on the type.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelParameters {
    /// Name of the model to query.
    model: String,
    /// Sampling temperature forwarded to the API.
    temperature: f32,
    /// Fixed sampling seed; `None` means a fresh random seed is drawn on
    /// every request.
    seed: Option<i64>,
}
18
19impl ModelParameters {
20 pub fn new(model: String, temperature: Option<f32>, seed: Option<i64>) -> Self {
35 let temperature: f32 = temperature.unwrap_or(1.0);
36
37 Self {
38 model,
39 temperature,
40 seed,
41 }
42 }
43
44 pub fn get_model(&self) -> String {
45 self.model.clone()
46 }
47
48 pub fn get_temperature(&self) -> f32 {
49 self.temperature
50 }
51
52 pub fn get_seed(&self) -> i64 {
53 let mut rng: rand::prelude::ThreadRng = rand::rng();
54 if let Some(seed) = self.seed {
55 seed
56 } else {
57 rng.random()
58 }
59 }
60}
61
62fn dynamic_image_to_base64(image: &DynamicImage) -> Result<String, Error> {
64 let mut raw_image_bytes: Vec<u8> = Vec::new();
65 image.write_to(
66 &mut std::io::Cursor::new(&mut raw_image_bytes),
67 image::ImageFormat::Png,
68 )?;
69 let base64_image: String = BASE64_STANDARD.encode(raw_image_bytes);
70
71 Ok(base64_image)
72}
73
74fn extract_leaf_values_recursively(value: &Value) -> Vec<Value> {
78 match value {
79 Value::Object(map) => map
80 .values()
81 .flat_map(|v| extract_leaf_values_recursively(v))
82 .collect(),
83 Value::Array(arr) => arr
84 .iter()
85 .flat_map(|v| extract_leaf_values_recursively(v))
86 .collect(),
87 _ => vec![value.clone()],
88 }
89}
90
91fn validate_vectorization_result(vector: &Vec<f32>) -> Result<(), Error> {
104 if vector.is_empty() {
106 return Err(Error::msg("Validation error: vector is empty"));
107 } else if vector.len() > 1 {
109 return Err(Error::msg("Validation error: vector has more than one element"));
110 }
111
112 for element in vector {
114 if *element < 0.0 {
115 return Err(Error::msg("Validation error: vector contains negative elements"));
116 }
117 }
118
119 Ok(())
121}
122
123async fn vectorize_image_single_prompt<C>(
127 client: &Client<C>,
128 image: &DynamicImage,
129 prompt: String,
130 model_parameters: &ModelParameters,
131) -> Result<Vec<f32>, Error>
132where
133 C: Config + Send + Sync + 'static,
134{
135 let base64_image = dynamic_image_to_base64(&image)?;
136 let image_url = format!("data:image/jpeg;base64,{}", base64_image);
137
138 loop {
139 let request = match CreateChatCompletionRequestArgs::default()
140 .temperature(model_parameters.get_temperature())
141 .seed(model_parameters.get_seed())
142 .model(model_parameters.get_model())
143 .response_format(ResponseFormat::JsonObject)
144 .messages(vec![ChatCompletionRequestUserMessageArgs::default()
145 .content(vec![
146 ChatCompletionRequestMessageContentPartTextArgs::default()
147 .text(&prompt)
148 .build()
149 .map_err(|e| Error::msg(e.to_string()))?
150 .into(),
151 ChatCompletionRequestMessageContentPartImageArgs::default()
152 .image_url(
153 ImageUrlArgs::default()
154 .url(&image_url)
155 .detail(ImageDetail::High)
156 .build()
157 .map_err(|e| Error::msg(e.to_string()))?,
158 )
159 .build()
160 .map_err(|e| Error::msg(e.to_string()))?
161 .into(),
162 ])
163 .build()
164 .map_err(|e| Error::msg(e.to_string()))?
165 .into()])
166 .build()
167 {
168 Ok(req) => req,
169 Err(e) => {
170 println!("Failed to build request: {}", e);
171 continue;
172 }
173 };
174
175 let response = match client.chat().create(request).await {
176 Ok(res) => res,
177 Err(e) => {
178 println!("API request error: {}", e);
179 continue;
180 }
181 };
182
183 let content = match response.choices.get(0).and_then(|c| c.message.content.as_ref()) {
184 Some(c) => c,
185 None => {
186 println!("Empty content in response");
187 continue;
188 }
189 };
190
191 let parsed_json = match serde_json::from_str::<Value>(content) {
192 Ok(v) => v,
193 Err(e) => {
194 println!("JSON parsing failed: {}", e);
195 continue;
196 }
197 };
198
199 let leaf_values = extract_leaf_values_recursively(&parsed_json);
200 let result: Vec<f32> = leaf_values
201 .into_iter()
202 .filter_map(|v| v.as_f64().map(|f| f as f32))
203 .collect();
204
205 if let Err(e) = validate_vectorization_result(&result) {
206 println!("Validation failed: {}, retrying...", e);
207 println!("Prompt: {}", prompt);
208 println!("Result: {}", &parsed_json);
209 println!("Output: {:?}", result);
210 } else {
211 return Ok(result);
212 }
213 }
214}
215
216pub async fn vectorize_image_concurrently<C>(
231 prompts: Vec<String>,
232 vector: &mut Vector<DynamicImage>,
233 client: Client<C>,
234 model_parameters: ModelParameters,
235) -> Result<(), Error>
236where
237 C: Config + Send + Sync + 'static,
238{
239 let image: DynamicImage = vector.get_data().clone();
241
242 let shared_client: Arc<Client<C>> = Arc::new(client);
243 let shared_image: Arc<DynamicImage> = Arc::new(image);
244 let shared_model: Arc<ModelParameters> = Arc::new(model_parameters);
245
246 let mut tasks = Vec::new();
248 for (index, prompt) in prompts.into_iter().enumerate() {
249 let shared_client: Arc<Client<C>> = shared_client.clone();
250 let shared_image: Arc<DynamicImage> = shared_image.clone();
251 let shared_model: Arc<ModelParameters> = shared_model.clone();
252
253 let task = tokio::spawn(async move {
254 let subvector: Vec<f32> = vectorize_image_single_prompt(
255 shared_client.as_ref(),
256 shared_image.as_ref(),
257 prompt,
258 shared_model.as_ref(),
259 )
260 .await?;
261 println!("thread {index} finished vectorization.");
262
263 Ok::<_, Error>(subvector)
264 });
265
266 tasks.push(task);
267 }
268
269 let results = join_all(tasks).await;
270
271 let final_vector: Vec<f32> = results
273 .into_iter()
274 .filter_map(|result| result.ok())
275 .filter_map(|result| result.ok())
276 .flat_map(|subvec| subvec.iter().map(|&x| x as f32).collect::<Vec<f32>>())
277 .collect();
278
279 vector.overwrite_vector(final_vector);
280
281 Ok(())
282}
283
284async fn vectorize_string_single_prompt<C>(
288 client: &Client<C>,
289 text: &str,
290 prompt: String,
291 model_parameters: &ModelParameters
292) -> Result<Vec<f32>, Error>
293where
294 C: Config + Send + Sync + 'static,
295{
296 loop {
297 let request: async_openai::types::CreateChatCompletionRequest = match CreateChatCompletionRequestArgs::default()
298 .temperature(model_parameters.get_temperature())
299 .seed(model_parameters.get_seed())
300 .model(model_parameters.get_model())
301 .response_format(ResponseFormat::JsonObject)
302 .messages(vec![ChatCompletionRequestUserMessageArgs::default()
303 .content(format!("{}\n\nText to analyze: {}", prompt, text))
304 .build()
305 .map_err(|e| Error::msg(e.to_string()))?
306 .into()])
307 .build()
308 {
309 Ok(req) => req,
310 Err(e) => {
311 println!("Failed to build request: {}", e);
312 continue;
313 }
314 };
315
316 let response = match client.chat().create(request).await {
317 Ok(res) => res,
318 Err(e) => {
319 println!("API request error: {}", e);
320 continue;
321 }
322 };
323
324 let content = match response.choices.get(0).and_then(|c| c.message.content.as_ref()) {
325 Some(c) => c,
326 None => {
327 println!("Empty content in response");
328 continue;
329 }
330 };
331
332 let parsed_json = match serde_json::from_str::<Value>(content) {
333 Ok(v) => v,
334 Err(e) => {
335 println!("JSON parsing failed: {}", e);
336 continue;
337 }
338 };
339
340 let leaf_values = extract_leaf_values_recursively(&parsed_json);
341 let result: Vec<f32> = leaf_values
342 .into_iter()
343 .filter_map(|v| v.as_f64().map(|f| f as f32))
344 .collect();
345
346 if let Err(e) = validate_vectorization_result(&result) {
347 println!("Validation failed: {}, retrying...", e);
348 println!("Prompt: {}", prompt);
349 println!("Text: {}", text);
350 println!("Result: {}", &parsed_json);
351 println!("Output: {:?}", result);
352 } else {
353 return Ok(result);
354 }
355 }
356}
357
358pub async fn vectorize_string_concurrently<C>(
369 prompts: Vec<String>,
370 vector: &mut Vector<String>,
371 client: Client<C>,
372 model_parameters: ModelParameters,
373) -> Result<(), Error>
374where
375 C: Config + Send + Sync + 'static,
376{
377 let text: String = vector.get_data().clone();
379
380 let shared_client: Arc<Client<C>> = Arc::new(client);
381 let shared_text: Arc<String> = Arc::new(text);
382 let shared_model: Arc<ModelParameters> = Arc::new(model_parameters);
383
384 let mut tasks = Vec::new();
386 for (index, prompt) in prompts.into_iter().enumerate() {
387 let shared_client: Arc<Client<C>> = shared_client.clone();
388 let shared_text: Arc<String> = shared_text.clone();
389 let shared_model: Arc<ModelParameters> = shared_model.clone();
390
391 let task = tokio::spawn(async move {
392 let subvector = vectorize_string_single_prompt(
393 shared_client.as_ref(),
394 shared_text.as_ref(),
395 prompt,
396 shared_model.as_ref(),
397 )
398 .await?;
399 println!("thread {index} finished vectorization.");
400
401 Ok::<_, Error>(subvector)
402 });
403
404 tasks.push(task);
405 }
406
407 let results = join_all(tasks).await;
408
409 let final_vector: Vec<f32> = results
411 .into_iter()
412 .filter_map(|result| result.ok())
413 .filter_map(|result| result.ok())
414 .flat_map(|subvec| subvec.iter().map(|&x| x as f32).collect::<Vec<f32>>())
415 .collect();
416
417 vector.overwrite_vector(final_vector);
418
419 Ok(())
420}