reco_forge/
lib.rs

//! This crate provides an interface for turning any dataset of item titles and descriptions (the `name` and `summary` fields of the JSON format below) into a recommendation system. It uses a BERT model to create an embedding for each item in the dataset and then finds recommendations based on the user's input.
2//! 
3//! To run the examples, you can clone the git repository at <https://github.com/jameslk3/reco-forge> and then run the following commands:
4//! ```bash
5//! cargo run --example description
6//! cargo run --example item
7//! ```
8//! 
9//! Example usage:
10//! ```no_run
11//! use reco_forge::{create_model, pass_description, Data, Tensor, HashMap}; // Can also use pass_item
12//! 
13//! fn main() -> Result<(), Box<dyn std::error::Error>> {
14//!    let mut path = String::new();
15//!    println!("Please enter the file path:");
16//!
17//!    let model_wrapped: Result<HashMap<Data, Option<Tensor>>, ()>;
18//!
19//!    loop {
20//!        std::io::stdin().read_line(&mut path).expect("Failed to read line");
21//!        path = path.trim().to_string();
22//!
23//!        if let Ok(model) = create_model(&path) {
24//!            model_wrapped = Ok(model);
25//!            break;
26//!        } else {
27//!            println!("File path is not valid or file cannot be deserialized, please input the correct file path and try again:");
28//!            path.clear();
29//!        }
30//!    }
31//! 
32//!   let model: HashMap<Data, Option<Tensor>> = model_wrapped.unwrap();
33//!
34//!    println!("Input tags that you would like to use to filter, else enter NONE");
35//!    let mut tags_input: String = String::new();
36//!    std::io::stdin().read_line(&mut tags_input).expect("Failed to read line");
37//!    tags_input = tags_input.trim().to_string();
38//!
39//!    println!("Describe what you want to be recommended:");
40//!    let mut query: String = String::new();
41//!    std::io::stdin().read_line(&mut query).expect("Failed to read line");
42//!    query = query.trim().to_string();
43//!    println!();
44//!
45//!    let recommendations = pass_description(&model, query, tags_input, 10);
46//!    match recommendations {
47//!        Ok(recommendations) => {
48//!            println!("Recommendations:");
49//!            for recommendation in recommendations {
50//!                println!("{}% {}", (recommendation.1 * 100.0).round(), recommendation.0);
51//!            }
52//!        },
53//!        Err(_) => println!("No recommendations found"),
54//!    }
55//!    Ok(())
56//!}
57//! 
58//! // Examples are also provided in the examples folder of the git repository at https://github.com/jameslk3/reco-forge.
59//! 
60//! ```
61//! Required JSON file format:
62//! ```json
63//! [
64//! 	{
65//! 		"id": int,
66//! 		"name": "string",
67//! 		"summary": "string",
68//! 		"tags": ["string1", "string2"]
69//! 	},
70//! 	{
71//! 		...
72//! 	}
73//! ]
74//! ```
75
76
77pub(crate) mod helpers;
78
79extern crate candle;
80
81pub use candle::Tensor;
82pub use helpers::types::Data;
83pub use std::collections::HashMap;
84
85use helpers::pre_recommendation::{extract_data, insert_embeddings, find_embedding};
86use helpers::recommendation::{get_recommendations, create_input_embedding};
87
88/// # create_model
89/// This function creates the model from the file path given by the user
90/// 
91/// # Arguments
92/// ```no_run
93///     * file_path: &String - The file path to the model
94/// ```
95/// 
96/// # Returns
97/// ```no_run
98///    * Result<HashMap<Data, Option<Tensor>>, String> - The model if it was created successfully, otherwise a wrapped error message
99/// ```
100/// 
101/// # Example
102/// ```no_run
103/// 	let file_path = "path/to/model".to_string();
104/// 	let model = create_model(&file_path);
105/// 	match model {
106/// 		Ok(model) => println!("Model created successfully"),
107/// 		Err(e) => println!("Error: {}", e),
108/// 	}
109/// ```
110pub fn create_model(file_path: &String) -> Result<HashMap<Data, Option<Tensor>>, String> {
111    let nodes_wrapped: Result<HashMap<Data, Option<Tensor>>, ()> = extract_data(&file_path);
112    if nodes_wrapped.is_err() {
113        return Err(("File path is not valid or file cannot be deserialized, please input the correct file path and try again:").to_string());
114    }
115    let mut nodes: HashMap<Data, Option<Tensor>> = nodes_wrapped.unwrap();
116    println!("Creating model, please be patient...");
117    if let Ok(_) = insert_embeddings(&mut nodes) {
118        return Ok(nodes);
119    }
120    return Err("Error inserting embeddings".to_string());
121}
122
123/// # pass_description
124/// This function is used when the user wants to find recommendations based on a description
125/// 
126/// # Arguments
127/// ```no_run
128///    	* node_embeddings: &HashMap<Data, Option<Tensor> - The model
129///    	* description_input: String - The description input by the user
130/// 	* tags_input: String - The tags input by the user, each tag separated by a comma. If the user doesn't want to filter by tags, they can enter NONE
131/// 	* num_recommendations: usize - The number of recommendations the user wants
132/// ```
133/// 
134/// # Returns
135/// ```no_run
136///   	* Result<Vec<String, f32>, ()> - A vector of (Item name, similarity) tuples if recommendations were found, otherwise Err
137/// ```
138/// 
139/// # Example
140/// ```no_run
141/// 	let recommendations = pass_description(&model, "description".to_string(), "tag1,tag2".to_string(), 10);
142/// 	match recommendations {
143/// 		Ok(recommendations) => {
144/// 			println!("Recommendations:");
145/// 			for recommendation in recommendations {
146/// 				println!("{}% {}", (recommendation.1 * 100.0).round(), recommendation.0);
147/// 			}
148/// 		},
149/// 		Err(_) => println!("No recommendations found"),
150/// 	}
151/// ```
152pub fn pass_description(node_embeddings: &HashMap<Data, Option<Tensor>>, description_input: String, tags_input: String, num_recommendations: usize) -> Result<Vec<(String, f32)>, ()> {
153
154    // When we are given a description, we need to create an embedding for it and then find recommendations based on that
155    let input_embedding = create_input_embedding(&description_input).unwrap().unwrap();
156    let recommendations = get_recommendations(node_embeddings, None, &input_embedding, &tags_input, num_recommendations);
157    recommendations
158}
159
160/// # pass_item
161/// This function is used when the user wants to find recommendations based on a specific item that is already in the model
162/// 
163/// # Arguments
164/// ```no_run
165///   	* node_embeddings: &HashMap<Data, Option<Tensor> - The model
166///  	* item: String - The item the user wants recommendations for
167/// 	* tags_input: String - The tags input by the user, each tag separated by a comma. If the user doesn't want to filter by tags, they can enter NONE
168/// 	* num_recommendations: usize - The number of recommendations the user wants
169/// ```
170/// 
171/// # Returns
172/// ```no_run
173///   	* Result<Vec<String, f32>, ()> - A vector of (Item name, similarity) tuples if recommendations were found, otherwise Err
174/// ```
175/// 
176/// # Example
177/// ```no_run
178/// 	let recommendations = pass_item(&model, "item".to_string(), "tag1,tag2".to_string(), 10);
179/// 	match recommendations {
180/// 		Ok(recommendations) => {
181/// 			println!("Recommendations:");
182/// 			for recommendation in recommendations {
183///                 println!("{}% {}", (recommendation.1 * 100.0).round(), recommendation.0);
184/// 			}
185/// 		},
186/// 		Err(_) => println!("No recommendations found"),
187/// 	}
188/// ```
189pub fn pass_item(node_embeddings: &HashMap<Data, Option<Tensor>>, item: String, tags_input: String, num_recommendations: usize) -> Result<Vec<(String, f32)>, ()> {
190
191    // When we want to find items similar to a specific item, we need to make sure that the item is in the embeddings and then retrieve the embedding
192    let input_embedding = {
193        if let Ok(embedding) = find_embedding(node_embeddings, &item) {
194            embedding
195        } else {
196            return Err(());
197        }
198    };
199    let recommendations = get_recommendations(node_embeddings, Some(&item), &input_embedding, &tags_input, num_recommendations);
200    recommendations
201}