1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3#![warn(missing_debug_implementations)]
4#![allow(
5 clippy::pedantic,
6 clippy::missing_errors_doc,
7 reason = "error types documented at enum level"
8)]
9
10mod rules;
24
25#[cfg(feature = "onnx")]
26mod embedder;
27
28#[cfg(feature = "onnx")]
29use std::cell::RefCell;
30use std::fmt;
31
32use rules::{FieldAction, classify_field, detect_category};
33use serde_json::Value;
34
35#[derive(Debug, thiserror::Error)]
37pub enum EmbedderError {
38 #[error("I/O error: {0}")]
40 Io(#[from] std::io::Error),
41
42 #[error("Model not found at {0}")]
44 ModelNotFound(std::path::PathBuf),
45
46 #[error("Tokenizer not found at {0}")]
48 TokenizerNotFound(std::path::PathBuf),
49
50 #[error("Tokenizer load error: {0}")]
52 TokenizerLoad(String),
53
54 #[error("Tokenization error: {0}")]
56 Tokenize(String),
57
58 #[cfg(feature = "onnx")]
60 #[error("ONNX error: {0}")]
61 Ort(String),
62
63 #[error("Download error: {0}")]
65 Download(String),
66}
67
/// Compresses JSON values by removing fields judged irrelevant to a context
/// string.
///
/// Filtering is rule-based by default. With the `onnx` feature enabled and an
/// embedder loaded, field names are instead compared to the context by cosine
/// similarity against `threshold`.
pub struct SemanticCompressor {
    /// Similarity cut-off for the embedding path: a field whose similarity to
    /// the context is strictly below this value is dropped.
    threshold: f32,
    /// Embedder used by the embedding path; `None` until one is loaded.
    ///
    /// Wrapped in a `RefCell` because embedding is invoked through
    /// `borrow_mut()` from methods that only hold `&self`.
    #[cfg(feature = "onnx")]
    embedder: Option<RefCell<embedder::Embedder>>,
}
82
83impl fmt::Debug for SemanticCompressor {
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 f.debug_struct("SemanticCompressor")
86 .field("threshold", &self.threshold)
87 .finish_non_exhaustive()
88 }
89}
90
91impl Default for SemanticCompressor {
92 fn default() -> Self {
93 Self {
94 threshold: 0.3,
95 #[cfg(feature = "onnx")]
96 embedder: None,
97 }
98 }
99}
100
101impl SemanticCompressor {
102 #[must_use]
106 pub fn new() -> Self {
107 Self::default()
108 }
109
110 pub fn load_onnx(&mut self) -> Result<bool, EmbedderError> {
119 #[cfg(feature = "onnx")]
120 {
121 let model_dir = model_dir();
122 embedder::ensure_models(&model_dir)?;
123
124 match embedder::Embedder::load(&model_dir) {
125 Ok(e) => {
126 self.embedder = Some(RefCell::new(e));
127 tracing::info!("ONNX embedder loaded (Level 2 enabled)");
128 Ok(true)
129 }
130 Err(e) => {
131 tracing::warn!("Failed to load ONNX model, falling back to Level 1: {e}");
132 Ok(false)
133 }
134 }
135 }
136 #[cfg(not(feature = "onnx"))]
137 {
138 let _ = self;
139 Ok(false)
140 }
141 }
142
143 #[must_use]
151 pub fn compress(&self, value: &Value, context: &str) -> Value {
152 #[cfg(feature = "onnx")]
153 if let Some(ref embedder) = self.embedder
154 && let Ok(ctx_embedding) = embedder.borrow_mut().embed(context)
155 {
156 return self.compress_with_embedding(value, &ctx_embedding, embedder);
157 }
158
159 let category = detect_category(context);
161 self.compress_with_rules(value, category, context)
162 }
163
164 #[must_use]
167 pub fn is_field_kept(&self, field_name: &str, context: &str) -> bool {
168 let category = detect_category(context);
169 matches!(classify_field(field_name, category), FieldAction::Keep)
170 }
171
172 #[must_use]
174 pub fn detect_category(&self, context: &str) -> &'static str {
175 detect_category(context)
176 }
177
178 #[allow(
181 clippy::only_used_in_recursion,
182 reason = "parameters needed for recursive calls"
183 )]
184 fn compress_with_rules(&self, value: &Value, category: &str, context: &str) -> Value {
185 match value {
186 Value::Object(obj) => {
187 let mut result = serde_json::Map::new();
188 for (key, val) in obj {
189 match classify_field(key, category) {
190 FieldAction::Drop => {}
191 FieldAction::Keep | FieldAction::Truncate => {
192 let compressed_val = self.compress_with_rules(val, category, context);
193 result.insert(key.clone(), compressed_val);
194 }
195 }
196 }
197 Value::Object(result)
198 }
199 Value::Array(arr) => {
200 let compressed: Vec<Value> = arr
201 .iter()
202 .map(|v| self.compress_with_rules(v, category, context))
203 .collect();
204 Value::Array(compressed)
205 }
206 other => other.clone(),
207 }
208 }
209
/// Embedding-based filter: rebuilds `value` without the fields whose name is
/// not similar enough to the context.
///
/// For every object key the name is embedded and compared with
/// `ctx_embedding` by cosine similarity; the field is dropped only when that
/// similarity is strictly below `self.threshold`. If embedding the key fails,
/// the field is kept. Objects and arrays are processed recursively; any other
/// JSON value is cloned unchanged.
#[cfg(feature = "onnx")]
fn compress_with_embedding(
    &self,
    value: &Value,
    ctx_embedding: &[f32],
    embedder: &RefCell<embedder::Embedder>,
) -> Value {
    match value {
        Value::Object(fields) => {
            let mut kept = serde_json::Map::new();
            for (name, inner) in fields {
                // `None` means the key could not be embedded. The mutable
                // borrow ends with this statement, before any recursion.
                let similarity = embedder.borrow_mut().embed(name).ok().map(|field_emb| {
                    embedder::Embedder::cosine_similarity(ctx_embedding, &field_emb)
                });

                let below_threshold = matches!(similarity, Some(sim) if sim < self.threshold);
                if !below_threshold {
                    let filtered = self.compress_with_embedding(inner, ctx_embedding, embedder);
                    kept.insert(name.clone(), filtered);
                }
            }
            Value::Object(kept)
        }
        Value::Array(items) => Value::Array(
            items
                .iter()
                .map(|item| self.compress_with_embedding(item, ctx_embedding, embedder))
                .collect(),
        ),
        scalar => scalar.clone(),
    }
}
244}
245
/// Returns the model directory: `<home>/.tokenless/models`.
///
/// Falls back to the relative path `./.tokenless/models` when
/// `dirs::home_dir()` returns `None`.
#[cfg(feature = "onnx")]
fn model_dir() -> std::path::PathBuf {
    let mut path = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
    path.push(".tokenless");
    path.push("models");
    path
}
254
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used)]

    use serde_json::json;

    use super::*;

    /// Compresses `value` for `context` with a freshly constructed compressor
    /// (no embedder loaded, so the rule-based path is exercised).
    fn run(value: &Value, context: &str) -> Value {
        SemanticCompressor::new().compress(value, context)
    }

    /// Asserts that every key in `keys` is present in `value`.
    #[track_caller]
    fn assert_kept(value: &Value, keys: &[&str]) {
        for &key in keys {
            assert!(value.get(key).is_some());
        }
    }

    /// Asserts that none of the keys in `keys` is present in `value`.
    #[track_caller]
    fn assert_dropped(value: &Value, keys: &[&str]) {
        for &key in keys {
            assert!(value.get(key).is_none());
        }
    }

    #[test]
    fn test_compress_weather_drops_station_id() {
        let input = json!({
            "temperature": 22.5,
            "wind_speed": 12.0,
            "station_id": "WX-001",
            "sensor_version": "3.1.0",
        });
        let output = run(&input, "今天天气怎么样");
        assert_kept(&output, &["temperature", "wind_speed"]);
        assert_dropped(&output, &["station_id", "sensor_version"]);
    }

    #[test]
    fn test_compress_devops_drops_uid() {
        let input = json!({
            "pod_status": "Running",
            "cpu_usage": 0.45,
            "uid": "abc-123-def",
            "self_link": "/api/v1/...",
        });
        let output = run(&input, "deploy to kubernetes");
        assert_kept(&output, &["pod_status", "cpu_usage"]);
        assert_dropped(&output, &["uid", "self_link"]);
    }

    #[test]
    fn test_compress_default_drops_debug() {
        let input = json!({
            "name": "Alice",
            "age": 30,
            "debug": "some debug info",
            "trace": "trace data",
        });
        let output = run(&input, "hello");
        assert_kept(&output, &["name", "age"]);
        assert_dropped(&output, &["debug", "trace"]);
    }

    #[test]
    fn test_compress_nested_object() {
        let input = json!({
            "data": {
                "temperature": 22.5,
                "station_id": "WX-001",
                "nested": {
                    "wind_speed": 12.0,
                    "calibration_date": "2025-01-01",
                }
            }
        });
        let output = run(&input, "天气");

        // Filtering must apply one level down...
        let data = &output["data"];
        assert!(data["temperature"].is_f64());
        assert_dropped(data, &["station_id"]);

        // ...and two levels down.
        let nested = &data["nested"];
        assert!(nested["wind_speed"].is_f64());
        assert_dropped(nested, &["calibration_date"]);
    }

    #[test]
    fn test_compress_array_of_objects() {
        let input = json!([
            {"temperature": 22.5, "station_id": "A"},
            {"temperature": 18.0, "station_id": "B"},
        ]);
        let output = run(&input, "天气");
        let items = output.as_array().unwrap();
        assert_eq!(items.len(), 2);
        for item in items {
            assert_dropped(item, &["station_id"]);
        }
    }

    #[test]
    fn test_is_field_kept() {
        let compressor = SemanticCompressor::new();
        let context = "天气怎么样";
        assert!(compressor.is_field_kept("temperature", context));
        assert!(!compressor.is_field_kept("station_id", context));
    }

    #[test]
    fn test_detect_category_public() {
        let compressor = SemanticCompressor::new();
        let expectations = [("天气", "weather"), ("unknown", "default")];
        for (context, category) in expectations {
            assert_eq!(compressor.detect_category(context), category);
        }
    }
}