git_semantic/embedding/
model.rs1use directories::ProjectDirs;
2use indicatif::{ProgressBar, ProgressStyle};
3use ndarray::Array1;
4use ort::session::Session;
5use ort::session::builder::GraphOptimizationLevel;
6use std::fs;
7use std::path::PathBuf;
8use tokenizers::Tokenizer;
9use tracing::{debug, info};
10
11use super::{Embedding, EmbeddingConfig, EmbeddingError};
12
/// Owns the on-disk model artifacts and the lazily-initialized ONNX
/// inference session + tokenizer used to embed text.
pub struct ModelManager {
    // Embedding settings (model name, max token length, …) from
    // `EmbeddingConfig::default()`.
    config: EmbeddingConfig,
    // Directory holding `model.onnx` and `tokenizer.json`
    // (platform data dir + "models"; created in `new`).
    model_dir: PathBuf,
    // ONNX runtime session; `None` until `init` succeeds.
    session: Option<Session>,
    // HuggingFace tokenizer; `None` until `init` succeeds.
    tokenizer: Option<Tokenizer>,
}
19
20impl ModelManager {
21 pub fn new() -> Result<Self, EmbeddingError> {
22 let config = EmbeddingConfig::default();
23
24 let project_dirs = ProjectDirs::from("com", "git-semantic", "git-semantic")
25 .ok_or(EmbeddingError::ProjectDirsNotFound)?;
26
27 let model_dir = project_dirs.data_dir().join("models");
28 fs::create_dir_all(&model_dir)?;
29
30 Ok(Self {
31 config,
32 model_dir,
33 session: None,
34 tokenizer: None,
35 })
36 }
37
38 pub fn init(&mut self) -> Result<(), EmbeddingError> {
40 if self.session.is_some() {
41 return Ok(());
42 }
43
44 info!("Loading ONNX model...");
45 let model_path = self.model_path();
46
47 if !model_path.exists() {
48 return Err(EmbeddingError::ModelNotDownloaded);
49 }
50
51 let session = Session::builder()?
52 .with_optimization_level(GraphOptimizationLevel::Level3)?
53 .with_intra_threads(4)?
54 .commit_from_file(&model_path)?;
55
56 info!("Loading tokenizer...");
57 let tokenizer_path = self.tokenizer_path();
58 let tokenizer = Tokenizer::from_file(tokenizer_path)
59 .map_err(|e| EmbeddingError::Tokenization(format!("failed to load tokenizer: {e}")))?;
60
61 self.session = Some(session);
62 self.tokenizer = Some(tokenizer);
63
64 info!("Model loaded successfully");
65 Ok(())
66 }
67
68 pub fn is_model_downloaded(&self) -> bool {
69 self.model_path().exists() && self.tokenizer_path().exists()
70 }
71
72 pub fn download_model(&self) -> Result<(), EmbeddingError> {
73 info!("Downloading model: {}", self.config.model_name);
74
75 let base_url = "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main";
76
77 let files = vec![
78 ("model.onnx", "onnx/model.onnx"),
79 ("tokenizer.json", "tokenizer.json"),
80 ];
81
82 let client = reqwest::blocking::Client::builder()
83 .timeout(std::time::Duration::from_secs(300))
84 .build()?;
85
86 for (filename, remote_path) in files {
87 let url = format!("{}/{}", base_url, remote_path);
88 let target_path = self.model_dir.join(filename);
89
90 info!("Downloading {} from {}", filename, url);
91
92 let response = client.get(&url).send()?;
93
94 if !response.status().is_success() {
95 return Err(EmbeddingError::DownloadFailed {
96 filename: filename.to_string(),
97 reason: format!("HTTP {}", response.status()),
98 });
99 }
100
101 let total_size = response
102 .content_length()
103 .ok_or_else(|| EmbeddingError::MissingContentLength(filename.to_string()))?;
104
105 let pb = ProgressBar::new(total_size);
106 pb.set_style(
107 ProgressStyle::default_bar()
108 .template("{msg}\n[{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
109 .unwrap()
110 .progress_chars("=>-"),
111 );
112 pb.set_message(format!("Downloading {}", filename));
113
114 let mut file = fs::File::create(&target_path)?;
115 let mut downloaded = 0u64;
116 let mut content = response;
117
118 use std::io::Write;
119 let mut buffer = [0; 8192];
120
121 loop {
122 let bytes_read = std::io::Read::read(&mut content, &mut buffer)?;
123 if bytes_read == 0 {
124 break;
125 }
126 file.write_all(&buffer[..bytes_read])?;
127 downloaded += bytes_read as u64;
128 pb.set_position(downloaded);
129 }
130
131 pb.finish_with_message(format!("✅ Downloaded {}", filename));
132 }
133
134 info!("All model files downloaded successfully");
135 Ok(())
136 }
137
138 pub fn encode_text(&mut self, text: &str) -> Result<Embedding, EmbeddingError> {
139 debug!("Encoding text: {}", &text[..text.len().min(50)]);
140
141 let session = self
142 .session
143 .as_mut()
144 .ok_or(EmbeddingError::ModelNotInitialized)?;
145 let tokenizer = self
146 .tokenizer
147 .as_ref()
148 .ok_or(EmbeddingError::ModelNotInitialized)?;
149
150 let encoding = tokenizer
151 .encode(text, true)
152 .map_err(|e| EmbeddingError::Tokenization(e.to_string()))?;
153
154 let input_ids = encoding.get_ids();
155 let attention_mask = encoding.get_attention_mask();
156
157 let max_len = self.config.max_length.min(input_ids.len());
158 let input_ids = &input_ids[..max_len];
159 let attention_mask = &attention_mask[..max_len];
160
161 let input_ids_array: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
162 let attention_mask_array: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
163 let token_type_ids_array: Vec<i64> = vec![0; max_len];
164
165 use ort::value::Value;
166
167 let input_ids_array_2d = ndarray::Array2::from_shape_vec((1, max_len), input_ids_array)?;
168 let attention_mask_array_2d =
169 ndarray::Array2::from_shape_vec((1, max_len), attention_mask_array)?;
170 let token_type_ids_array_2d =
171 ndarray::Array2::from_shape_vec((1, max_len), token_type_ids_array)?;
172
173 let input_ids_tensor = Value::from_array((
174 input_ids_array_2d.shape(),
175 input_ids_array_2d.as_slice().unwrap().to_vec(),
176 ))?;
177 let attention_mask_tensor = Value::from_array((
178 attention_mask_array_2d.shape(),
179 attention_mask_array_2d.as_slice().unwrap().to_vec(),
180 ))?;
181 let token_type_ids_tensor = Value::from_array((
182 token_type_ids_array_2d.shape(),
183 token_type_ids_array_2d.as_slice().unwrap().to_vec(),
184 ))?;
185
186 let inputs = ort::inputs![
187 "input_ids" => input_ids_tensor,
188 "attention_mask" => attention_mask_tensor,
189 "token_type_ids" => token_type_ids_tensor,
190 ];
191 let outputs = session.run(inputs)?;
192
193 let output_tensor = outputs["last_hidden_state"].try_extract_tensor::<f32>()?;
194
195 let (shape, data) = output_tensor;
196 let _batch_size = shape[0] as usize;
197 let _seq_len = shape[1] as usize;
198 let hidden_size = shape[2] as usize;
199
200 let cls_start = 0;
201 let cls_end = hidden_size;
202 let embedding: Vec<f32> = data[cls_start..cls_end].to_vec();
203
204 let embedding_array = Array1::from_vec(embedding);
205 let norm = embedding_array.mapv(|x| x * x).sum().sqrt();
206 let normalized = if norm > 0.0 {
207 embedding_array / norm
208 } else {
209 embedding_array
210 };
211
212 Ok(normalized)
213 }
214
215 pub fn model_version(&self) -> String {
216 self.config.model_name.clone()
217 }
218
219 fn model_path(&self) -> PathBuf {
220 self.model_dir.join("model.onnx")
221 }
222
223 fn tokenizer_path(&self) -> PathBuf {
224 self.model_dir.join("tokenizer.json")
225 }
226}