//! git_semantic/embedding/model.rs
//!
//! Embedding model management: download, initialization, and text encoding.
use std::fs;
use std::path::PathBuf;

use directories::ProjectDirs;
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::Array1;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use tokenizers::Tokenizer;
use tracing::{debug, info};

use super::{Embedding, EmbeddingConfig, EmbeddingError};
12
13pub struct ModelManager {
14    config: EmbeddingConfig,
15    model_dir: PathBuf,
16    session: Option<Session>,
17    tokenizer: Option<Tokenizer>,
18}
19
20impl ModelManager {
21    pub fn new() -> Result<Self, EmbeddingError> {
22        let config = EmbeddingConfig::default();
23
24        let project_dirs = ProjectDirs::from("com", "git-semantic", "git-semantic")
25            .ok_or(EmbeddingError::ProjectDirsNotFound)?;
26
27        let model_dir = project_dirs.data_dir().join("models");
28        fs::create_dir_all(&model_dir)?;
29
30        Ok(Self {
31            config,
32            model_dir,
33            session: None,
34            tokenizer: None,
35        })
36    }
37
38    /// Initialize the model (load ONNX session and tokenizer)
39    pub fn init(&mut self) -> Result<(), EmbeddingError> {
40        if self.session.is_some() {
41            return Ok(());
42        }
43
44        info!("Loading ONNX model...");
45        let model_path = self.model_path();
46
47        if !model_path.exists() {
48            return Err(EmbeddingError::ModelNotDownloaded);
49        }
50
51        let session = Session::builder()?
52            .with_optimization_level(GraphOptimizationLevel::Level3)?
53            .with_intra_threads(4)?
54            .commit_from_file(&model_path)?;
55
56        info!("Loading tokenizer...");
57        let tokenizer_path = self.tokenizer_path();
58        let tokenizer = Tokenizer::from_file(tokenizer_path)
59            .map_err(|e| EmbeddingError::Tokenization(format!("failed to load tokenizer: {e}")))?;
60
61        self.session = Some(session);
62        self.tokenizer = Some(tokenizer);
63
64        info!("Model loaded successfully");
65        Ok(())
66    }
67
68    pub fn is_model_downloaded(&self) -> bool {
69        self.model_path().exists() && self.tokenizer_path().exists()
70    }
71
72    pub fn download_model(&self) -> Result<(), EmbeddingError> {
73        info!("Downloading model: {}", self.config.model_name);
74
75        let base_url = "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main";
76
77        let files = vec![
78            ("model.onnx", "onnx/model.onnx"),
79            ("tokenizer.json", "tokenizer.json"),
80        ];
81
82        let client = reqwest::blocking::Client::builder()
83            .timeout(std::time::Duration::from_secs(300))
84            .build()?;
85
86        for (filename, remote_path) in files {
87            let url = format!("{}/{}", base_url, remote_path);
88            let target_path = self.model_dir.join(filename);
89
90            info!("Downloading {} from {}", filename, url);
91
92            let response = client.get(&url).send()?;
93
94            if !response.status().is_success() {
95                return Err(EmbeddingError::DownloadFailed {
96                    filename: filename.to_string(),
97                    reason: format!("HTTP {}", response.status()),
98                });
99            }
100
101            let total_size = response
102                .content_length()
103                .ok_or_else(|| EmbeddingError::MissingContentLength(filename.to_string()))?;
104
105            let pb = ProgressBar::new(total_size);
106            pb.set_style(
107                ProgressStyle::default_bar()
108                    .template("{msg}\n[{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
109                    .unwrap()
110                    .progress_chars("=>-"),
111            );
112            pb.set_message(format!("Downloading {}", filename));
113
114            let mut file = fs::File::create(&target_path)?;
115            let mut downloaded = 0u64;
116            let mut content = response;
117
118            use std::io::Write;
119            let mut buffer = [0; 8192];
120
121            loop {
122                let bytes_read = std::io::Read::read(&mut content, &mut buffer)?;
123                if bytes_read == 0 {
124                    break;
125                }
126                file.write_all(&buffer[..bytes_read])?;
127                downloaded += bytes_read as u64;
128                pb.set_position(downloaded);
129            }
130
131            pb.finish_with_message(format!("✅ Downloaded {}", filename));
132        }
133
134        info!("All model files downloaded successfully");
135        Ok(())
136    }
137
138    pub fn encode_text(&mut self, text: &str) -> Result<Embedding, EmbeddingError> {
139        debug!("Encoding text: {}", &text[..text.len().min(50)]);
140
141        let session = self
142            .session
143            .as_mut()
144            .ok_or(EmbeddingError::ModelNotInitialized)?;
145        let tokenizer = self
146            .tokenizer
147            .as_ref()
148            .ok_or(EmbeddingError::ModelNotInitialized)?;
149
150        let encoding = tokenizer
151            .encode(text, true)
152            .map_err(|e| EmbeddingError::Tokenization(e.to_string()))?;
153
154        let input_ids = encoding.get_ids();
155        let attention_mask = encoding.get_attention_mask();
156
157        let max_len = self.config.max_length.min(input_ids.len());
158        let input_ids = &input_ids[..max_len];
159        let attention_mask = &attention_mask[..max_len];
160
161        let input_ids_array: Vec<i64> = input_ids.iter().map(|&x| x as i64).collect();
162        let attention_mask_array: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
163        let token_type_ids_array: Vec<i64> = vec![0; max_len];
164
165        use ort::value::Value;
166
167        let input_ids_array_2d = ndarray::Array2::from_shape_vec((1, max_len), input_ids_array)?;
168        let attention_mask_array_2d =
169            ndarray::Array2::from_shape_vec((1, max_len), attention_mask_array)?;
170        let token_type_ids_array_2d =
171            ndarray::Array2::from_shape_vec((1, max_len), token_type_ids_array)?;
172
173        let input_ids_tensor = Value::from_array((
174            input_ids_array_2d.shape(),
175            input_ids_array_2d.as_slice().unwrap().to_vec(),
176        ))?;
177        let attention_mask_tensor = Value::from_array((
178            attention_mask_array_2d.shape(),
179            attention_mask_array_2d.as_slice().unwrap().to_vec(),
180        ))?;
181        let token_type_ids_tensor = Value::from_array((
182            token_type_ids_array_2d.shape(),
183            token_type_ids_array_2d.as_slice().unwrap().to_vec(),
184        ))?;
185
186        let inputs = ort::inputs![
187            "input_ids" => input_ids_tensor,
188            "attention_mask" => attention_mask_tensor,
189            "token_type_ids" => token_type_ids_tensor,
190        ];
191        let outputs = session.run(inputs)?;
192
193        let output_tensor = outputs["last_hidden_state"].try_extract_tensor::<f32>()?;
194
195        let (shape, data) = output_tensor;
196        let _batch_size = shape[0] as usize;
197        let _seq_len = shape[1] as usize;
198        let hidden_size = shape[2] as usize;
199
200        let cls_start = 0;
201        let cls_end = hidden_size;
202        let embedding: Vec<f32> = data[cls_start..cls_end].to_vec();
203
204        let embedding_array = Array1::from_vec(embedding);
205        let norm = embedding_array.mapv(|x| x * x).sum().sqrt();
206        let normalized = if norm > 0.0 {
207            embedding_array / norm
208        } else {
209            embedding_array
210        };
211
212        Ok(normalized)
213    }
214
215    pub fn model_version(&self) -> String {
216        self.config.model_name.clone()
217    }
218
219    fn model_path(&self) -> PathBuf {
220        self.model_dir.join("model.onnx")
221    }
222
223    fn tokenizer_path(&self) -> PathBuf {
224        self.model_dir.join("tokenizer.json")
225    }
226}