// dim_rs/vectorization.rs

use std::sync::Arc;

use anyhow::{Error, Result};
use async_openai::{config::Config, types::{ChatCompletionRequestMessageContentPartImageArgs, ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, ImageDetail, ImageUrlArgs, ResponseFormat}, Client};
use base64::prelude::*;
use futures::future::join_all;
use image::DynamicImage;
use rand::Rng;
use serde_json::Value;

use crate::vector::{Vector, VectorOperations};
12
13pub struct ModelParameters {
14    model: String,
15    temperature: f32,
16    seed: Option<i64>,
17}
18
19impl ModelParameters {
20    /// Creates a new instance of `ModelParameters`.
21    ///
22    /// # Arguments
23    ///
24    /// * `model` - A string representing the model name.
25    /// * `temperature` - An optional floating-point value representing the temperature setting.
26    /// * `seed` - An optional integer value representing the seed for random number generation.
27    ///
28    /// # Returns
29    ///
30    /// A new instance of `ModelParameters` with the specified model, temperature, and seed.
31    ///
32    /// If `temperature` is not provided, it defaults to 0.0.
33    /// If `seed` is not provided, a random seed is generated.
34    pub fn new(model: String, temperature: Option<f32>, seed: Option<i64>) -> Self {
35        let temperature: f32 = temperature.unwrap_or(1.0);
36        
37        Self {
38            model,
39            temperature,
40            seed,
41        }
42    }
43
44    pub fn get_model(&self) -> String {
45        self.model.clone()
46    }
47
48    pub fn get_temperature(&self) -> f32 {
49        self.temperature
50    }
51
52    pub fn get_seed(&self) -> i64 {
53        let mut rng: rand::prelude::ThreadRng = rand::rng();
54        if let Some(seed) = self.seed {
55            seed
56        } else {
57            rng.random()
58        }
59    }
60}
61
62/// Converts a DynamicImage to a base64-encoded string
63fn dynamic_image_to_base64(image: &DynamicImage) -> Result<String, Error> {
64    let mut raw_image_bytes: Vec<u8> = Vec::new();
65    image.write_to(
66        &mut std::io::Cursor::new(&mut raw_image_bytes),
67        image::ImageFormat::Png,
68    )?;
69    let base64_image: String = BASE64_STANDARD.encode(raw_image_bytes);
70
71    Ok(base64_image)
72}
73
74/// Recursively extracts leaf values from a JSON response retrieved from the LLM.
75/// 
76/// Takes a JSON Value and returns a Vec of all leaf values found in the structure.
77fn extract_leaf_values_recursively(value: &Value) -> Vec<Value> {
78    match value {
79        Value::Object(map) => map
80            .values()
81            .flat_map(|v| extract_leaf_values_recursively(v))
82            .collect(),
83        Value::Array(arr) => arr
84            .iter()
85            .flat_map(|v| extract_leaf_values_recursively(v))
86            .collect(),
87        _ => vec![value.clone()],
88    }
89}
90
91/// Validates that all elements in the vectorization result are non-negative.
92/// 
93/// Takes a vector slice and validates that it:
94/// - Is not empty
95/// - Contains exactly one element 
96/// - All elements are non-negative (>= 0)
97///
98/// # Arguments
99/// * `vector` - Vector slice to validate
100///
101/// # Returns 
102/// * `bool` - True if vector meets all validation criteria, false otherwise
103fn validate_vectorization_result(vector: &Vec<f32>) -> Result<(), Error> {
104    // Return error if vector is empty
105    if vector.is_empty() {
106        return Err(Error::msg("Validation error: vector is empty"));
107    // Check if vector has more than one element
108    } else if vector.len() > 1 {
109        return Err(Error::msg("Validation error: vector has more than one element"));
110    }
111
112    // Check if any elements are negative
113    for element in vector {
114        if *element < 0.0 {
115            return Err(Error::msg("Validation error: vector contains negative elements"));
116        }
117    }
118
119    // All validation checks passed
120    Ok(())
121}
122
123/// Processes a single image with one prompt to generate a vector representation.
124/// 
125/// Continues retrying until valid results are obtained.
126async fn vectorize_image_single_prompt<C>(
127    client: &Client<C>,
128    image: &DynamicImage,
129    prompt: String,
130    model_parameters: &ModelParameters,
131) -> Result<Vec<f32>, Error>
132where
133    C: Config + Send + Sync + 'static,
134{
135    let base64_image = dynamic_image_to_base64(&image)?;
136    let image_url = format!("data:image/jpeg;base64,{}", base64_image);
137
138    loop {
139        let request = match CreateChatCompletionRequestArgs::default()
140            .temperature(model_parameters.get_temperature())
141            .seed(model_parameters.get_seed())
142            .model(model_parameters.get_model())
143            .response_format(ResponseFormat::JsonObject)
144            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
145                .content(vec![
146                    ChatCompletionRequestMessageContentPartTextArgs::default()
147                        .text(&prompt)
148                        .build()
149                        .map_err(|e| Error::msg(e.to_string()))?
150                        .into(),
151                    ChatCompletionRequestMessageContentPartImageArgs::default()
152                        .image_url(
153                            ImageUrlArgs::default()
154                                .url(&image_url)
155                                .detail(ImageDetail::High)
156                                .build()
157                                .map_err(|e| Error::msg(e.to_string()))?,
158                        )
159                        .build()
160                        .map_err(|e| Error::msg(e.to_string()))?
161                        .into(),
162                ])
163                .build()
164                .map_err(|e| Error::msg(e.to_string()))?
165                .into()])
166            .build()
167        {
168            Ok(req) => req,
169            Err(e) => {
170                println!("Failed to build request: {}", e);
171                continue;
172            }
173        };
174
175        let response = match client.chat().create(request).await {
176            Ok(res) => res,
177            Err(e) => {
178                println!("API request error: {}", e);
179                continue;
180            }
181        };
182
183        let content = match response.choices.get(0).and_then(|c| c.message.content.as_ref()) {
184            Some(c) => c,
185            None => {
186                println!("Empty content in response");
187                continue;
188            }
189        };
190
191        let parsed_json = match serde_json::from_str::<Value>(content) {
192            Ok(v) => v,
193            Err(e) => {
194                println!("JSON parsing failed: {}", e);
195                continue;
196            }
197        };
198
199        let leaf_values = extract_leaf_values_recursively(&parsed_json);
200        let result: Vec<f32> = leaf_values
201            .into_iter()
202            .filter_map(|v| v.as_f64().map(|f| f as f32))
203            .collect();
204
205        if let Err(e) = validate_vectorization_result(&result) {
206            println!("Validation failed: {}, retrying...", e);
207            println!("Prompt: {}", prompt);
208            println!("Result: {}", &parsed_json);
209            println!("Output: {:?}", result);
210        } else {
211            return Ok(result);
212        }
213    }
214}
215
216/// Concurrently vectorizes an image with multiple prompts.
217/// 
218/// # Arguments
219/// * `model` - The name/identifier of the LLM model to use
220/// * `prompts` - A vector of prompts to process concurrently
221/// * `vector` - A mutable reference to the Vector struct containing the image
222/// * `client` - The OpenAI API client
223/// 
224/// # Returns
225/// * `Result<(), Error>` - Ok(()) on success, Error on failure
226/// 
227/// Each prompt's dimensionality is specified by how many digits that it 
228/// requires the LLM to return. The final dimensionality of the vector is 
229/// calculated by `number of prompts * digits specified by each prompt`.
230pub async fn vectorize_image_concurrently<C>(
231    prompts: Vec<String>,
232    vector: &mut Vector<DynamicImage>, 
233    client: Client<C>,
234    model_parameters: ModelParameters,
235) -> Result<(), Error>
236where
237    C: Config + Send + Sync + 'static,
238{
239    // get data from the struct
240    let image: DynamicImage = vector.get_data().clone();
241
242    let shared_client: Arc<Client<C>> = Arc::new(client);
243    let shared_image: Arc<DynamicImage> = Arc::new(image);
244    let shared_model: Arc<ModelParameters> = Arc::new(model_parameters);
245
246    // collect all tasks for concurrent execution
247    let mut tasks = Vec::new();
248    for (index, prompt) in prompts.into_iter().enumerate() {
249        let shared_client: Arc<Client<C>> = shared_client.clone();
250        let shared_image: Arc<DynamicImage> = shared_image.clone();
251        let shared_model: Arc<ModelParameters> = shared_model.clone();
252
253        let task = tokio::spawn(async move {
254            let subvector: Vec<f32> = vectorize_image_single_prompt(
255                shared_client.as_ref(), 
256                shared_image.as_ref(), 
257                prompt,
258                shared_model.as_ref(),
259            )
260                .await?;
261            println!("thread {index} finished vectorization.");
262
263            Ok::<_, Error>(subvector)
264        });
265
266        tasks.push(task);
267    }
268
269    let results = join_all(tasks).await;
270
271    // Collect and join the subvectors sequentially
272    let final_vector: Vec<f32> = results
273        .into_iter()
274        .filter_map(|result| result.ok())
275        .filter_map(|result| result.ok())
276        .flat_map(|subvec| subvec.iter().map(|&x| x as f32).collect::<Vec<f32>>())
277        .collect();
278
279    vector.overwrite_vector(final_vector);
280
281    Ok(())
282}
283
284/// Processes a single text string with one prompt to generate a vector representation.
285/// 
286/// Continues retrying until valid results are obtained.
287async fn vectorize_string_single_prompt<C>(
288    client: &Client<C>,
289    text: &str,
290    prompt: String,
291    model_parameters: &ModelParameters
292) -> Result<Vec<f32>, Error>
293where
294    C: Config + Send + Sync + 'static,
295{
296    loop {
297        let request: async_openai::types::CreateChatCompletionRequest = match CreateChatCompletionRequestArgs::default()
298            .temperature(model_parameters.get_temperature())
299            .seed(model_parameters.get_seed())
300            .model(model_parameters.get_model())
301            .response_format(ResponseFormat::JsonObject)
302            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
303                .content(format!("{}\n\nText to analyze: {}", prompt, text))
304                .build()
305                .map_err(|e| Error::msg(e.to_string()))?
306                .into()])
307            .build()
308        {
309            Ok(req) => req,
310            Err(e) => {
311                println!("Failed to build request: {}", e);
312                continue;
313            }
314        };
315
316        let response = match client.chat().create(request).await {
317            Ok(res) => res,
318            Err(e) => {
319                println!("API request error: {}", e);
320                continue;
321            }
322        };
323
324        let content = match response.choices.get(0).and_then(|c| c.message.content.as_ref()) {
325            Some(c) => c,
326            None => {
327                println!("Empty content in response");
328                continue;
329            }
330        };
331
332        let parsed_json = match serde_json::from_str::<Value>(content) {
333            Ok(v) => v,
334            Err(e) => {
335                println!("JSON parsing failed: {}", e);
336                continue;
337            }
338        };
339
340        let leaf_values = extract_leaf_values_recursively(&parsed_json);
341        let result: Vec<f32> = leaf_values
342            .into_iter()
343            .filter_map(|v| v.as_f64().map(|f| f as f32))
344            .collect();
345
346        if let Err(e) = validate_vectorization_result(&result) {
347            println!("Validation failed: {}, retrying...", e);
348            println!("Prompt: {}", prompt);
349            println!("Text: {}", text);
350            println!("Result: {}", &parsed_json);
351            println!("Output: {:?}", result);
352        } else {
353            return Ok(result);
354        }
355    }
356}
357
358/// Concurrently vectorizes a text string with multiple prompts.
359/// 
360/// # Arguments
361/// * `model` - The name/identifier of the LLM model to use
362/// * `prompts` - A vector of prompts to process concurrently
363/// * `vector` - A mutable reference to the Vector struct containing the text
364/// * `client` - The OpenAI API client
365/// 
366/// # Returns
367/// * `Result<(), Error>` - Ok(()) on success, Error on failure
368pub async fn vectorize_string_concurrently<C>(
369    prompts: Vec<String>,
370    vector: &mut Vector<String>,
371    client: Client<C>,
372    model_parameters: ModelParameters,
373) -> Result<(), Error>
374where
375    C: Config + Send + Sync + 'static,
376{
377    // get data from the struct
378    let text: String = vector.get_data().clone();
379
380    let shared_client: Arc<Client<C>> = Arc::new(client);
381    let shared_text: Arc<String> = Arc::new(text);
382    let shared_model: Arc<ModelParameters> = Arc::new(model_parameters);
383
384    // collect all tasks for concurrent execution
385    let mut tasks = Vec::new();
386    for (index, prompt) in prompts.into_iter().enumerate() {
387        let shared_client: Arc<Client<C>> = shared_client.clone();
388        let shared_text: Arc<String> = shared_text.clone();
389        let shared_model: Arc<ModelParameters> = shared_model.clone();
390
391        let task = tokio::spawn(async move {
392            let subvector = vectorize_string_single_prompt(
393                shared_client.as_ref(),
394                shared_text.as_ref(),
395                prompt,
396                shared_model.as_ref(),
397            )
398                .await?;
399            println!("thread {index} finished vectorization.");
400
401            Ok::<_, Error>(subvector)
402        });
403
404        tasks.push(task);
405    }
406
407    let results = join_all(tasks).await;
408
409    // Collect and join the subvectors sequentially
410    let final_vector: Vec<f32> = results
411        .into_iter()
412        .filter_map(|result| result.ok())
413        .filter_map(|result| result.ok())
414        .flat_map(|subvec| subvec.iter().map(|&x| x as f32).collect::<Vec<f32>>())
415        .collect();
416
417    vector.overwrite_vector(final_vector);
418
419    Ok(())
420}