1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3#![warn(missing_debug_implementations)]
4#![allow(
5 clippy::pedantic,
6 clippy::missing_errors_doc,
7 reason = "error types documented at enum level"
8)]
9
10mod rules;
24
25#[cfg(feature = "onnx")]
26mod embedder;
27
28#[cfg(feature = "onnx")]
29use std::cell::RefCell;
30use std::fmt;
31
32use rules::{FieldAction, classify_field, detect_category};
33use serde_json::Value;
34
35#[derive(Debug, thiserror::Error)]
37pub enum EmbedderError {
38 #[error("I/O error: {0}")]
40 Io(#[from] std::io::Error),
41
42 #[error("Model not found at {0}")]
44 ModelNotFound(std::path::PathBuf),
45
46 #[error("Tokenizer not found at {0}")]
48 TokenizerNotFound(std::path::PathBuf),
49
50 #[error("Tokenizer load error: {0}")]
52 TokenizerLoad(String),
53
54 #[error("Tokenization error: {0}")]
56 Tokenize(String),
57
58 #[cfg(feature = "onnx")]
60 #[error("ONNX error: {0}")]
61 Ort(String),
62
63 #[error("Download error: {0}")]
65 Download(String),
66}
67
/// Removes JSON object fields that are judged irrelevant to a textual context.
///
/// Two strategies exist: a keyword/rule path that is always available, and an
/// embedding-similarity path that is used only when the crate is built with
/// the `onnx` feature and an embedder has been loaded.
pub struct SemanticCompressor {
    /// Cosine-similarity cut-off for the embedding path: a field whose key
    /// scores strictly below this value against the context is removed.
    threshold: f32,
    /// Loaded embedding model, if any. Wrapped in a `RefCell` because the
    /// compressor is used through `&self` while the embedder is accessed via
    /// `borrow_mut()`.
    #[cfg(feature = "onnx")]
    embedder: Option<RefCell<embedder::Embedder>>,
}
82
83impl fmt::Debug for SemanticCompressor {
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 f.debug_struct("SemanticCompressor")
86 .field("threshold", &self.threshold)
87 .finish_non_exhaustive()
88 }
89}
90
91impl Default for SemanticCompressor {
92 fn default() -> Self {
93 Self {
94 threshold: 0.3,
95 #[cfg(feature = "onnx")]
96 embedder: None,
97 }
98 }
99}
100
101impl SemanticCompressor {
102 #[must_use]
106 pub fn new() -> Self {
107 Self::default()
108 }
109
110 pub fn load_onnx(&mut self) -> Result<bool, EmbedderError> {
119 #[cfg(feature = "onnx")]
120 {
121 let model_dir = model_dir();
122 embedder::ensure_models(&model_dir)?;
123
124 match embedder::Embedder::load(&model_dir) {
125 Ok(e) => {
126 self.embedder = Some(RefCell::new(e));
127 tracing::info!("ONNX embedder loaded (Level 2 enabled)");
128 Ok(true)
129 }
130 Err(e) => {
131 tracing::warn!("Failed to load ONNX model, falling back to Level 1: {e}");
132 Ok(false)
133 }
134 }
135 }
136 #[cfg(not(feature = "onnx"))]
137 {
138 let _ = self;
139 Ok(false)
140 }
141 }
142
143 #[must_use]
151 pub fn compress(&self, value: &Value, context: &str) -> Value {
152 #[cfg(feature = "onnx")]
153 if let Some(ref embedder) = self.embedder
154 && let Ok(ctx_embedding) = embedder.borrow_mut().embed(context)
155 {
156 return self.compress_with_embedding(value, &ctx_embedding, embedder);
157 }
158
159 let category = detect_category(context);
161 self.compress_with_rules(value, category, context)
162 }
163
164 #[must_use]
167 pub fn is_field_kept(&self, field_name: &str, context: &str) -> bool {
168 let category = detect_category(context);
169 matches!(classify_field(field_name, category), FieldAction::Keep)
170 }
171
172 #[must_use]
174 pub fn detect_category(&self, context: &str) -> &'static str {
175 detect_category(context)
176 }
177
178 #[allow(
181 clippy::only_used_in_recursion,
182 reason = "parameters needed for recursive calls"
183 )]
184 fn compress_with_rules(&self, value: &Value, category: &str, context: &str) -> Value {
185 match value {
186 Value::Object(obj) => {
187 let mut result = serde_json::Map::new();
188 for (key, val) in obj {
189 match classify_field(key, category) {
190 FieldAction::Drop => {}
191 FieldAction::Keep | FieldAction::Truncate => {
192 let compressed_val = self.compress_with_rules(val, category, context);
193 result.insert(key.clone(), compressed_val);
194 }
195 }
196 }
197 Value::Object(result)
198 }
199 Value::Array(arr) => {
200 let compressed: Vec<Value> = arr
201 .iter()
202 .map(|v| self.compress_with_rules(v, category, context))
203 .collect();
204 Value::Array(compressed)
205 }
206 other => other.clone(),
207 }
208 }
209
210 #[cfg(feature = "onnx")]
213 fn compress_with_embedding(
214 &self,
215 value: &Value,
216 ctx_embedding: &[f32],
217 embedder: &RefCell<embedder::Embedder>,
218 ) -> Value {
219 match value {
220 Value::Object(obj) => {
221 let mut result = serde_json::Map::new();
222 for (key, val) in obj {
223 if let Ok(field_emb) = embedder.borrow_mut().embed(key) {
224 let sim = embedder::Embedder::cosine_similarity(ctx_embedding, &field_emb);
225 if sim < self.threshold {
226 continue; }
228 }
229 let compressed_val = self.compress_with_embedding(val, ctx_embedding, embedder);
230 result.insert(key.clone(), compressed_val);
231 }
232 Value::Object(result)
233 }
234 Value::Array(arr) => {
235 let compressed: Vec<Value> = arr
236 .iter()
237 .map(|v| self.compress_with_embedding(v, ctx_embedding, embedder))
238 .collect();
239 Value::Array(compressed)
240 }
241 other => other.clone(),
242 }
243 }
244}
245
/// Returns the directory used for model files:
/// `<home>/.tokenfleet-ai/tokenless/models`.
///
/// When the home directory cannot be determined, the path is built relative
/// to `.` instead.
#[cfg(feature = "onnx")]
fn model_dir() -> std::path::PathBuf {
    let mut dir = match dirs::home_dir() {
        Some(home) => home,
        None => std::path::PathBuf::from("."),
    };
    for segment in [".tokenfleet-ai", "tokenless", "models"] {
        dir.push(segment);
    }
    dir
}
255
/// Unit tests for the rule-based path. Every test uses
/// `SemanticCompressor::new()`, so no embedder is loaded.
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used)]

    use serde_json::json;

    use super::*;

    /// Compresses `input` for `context` with a freshly created compressor.
    fn run(input: &Value, context: &str) -> Value {
        SemanticCompressor::new().compress(input, context)
    }

    /// Asserts that every key in `keys` is present in `value`.
    fn assert_present(value: &Value, keys: &[&str]) {
        for &key in keys {
            assert!(value.get(key).is_some(), "expected `{key}` to be kept");
        }
    }

    /// Asserts that no key in `keys` is present in `value`.
    fn assert_absent(value: &Value, keys: &[&str]) {
        for &key in keys {
            assert!(value.get(key).is_none(), "expected `{key}` to be dropped");
        }
    }

    #[test]
    fn test_compress_weather_drops_station_id() {
        let input = json!({
            "temperature": 22.5,
            "wind_speed": 12.0,
            "station_id": "WX-001",
            "sensor_version": "3.1.0",
        });

        let output = run(&input, "今天天气怎么样");

        assert_present(&output, &["temperature", "wind_speed"]);
        assert_absent(&output, &["station_id", "sensor_version"]);
    }

    #[test]
    fn test_compress_devops_drops_uid() {
        let input = json!({
            "pod_status": "Running",
            "cpu_usage": 0.45,
            "uid": "abc-123-def",
            "self_link": "/api/v1/...",
        });

        let output = run(&input, "deploy to kubernetes");

        assert_present(&output, &["pod_status", "cpu_usage"]);
        assert_absent(&output, &["uid", "self_link"]);
    }

    #[test]
    fn test_compress_default_drops_debug() {
        let input = json!({
            "name": "Alice",
            "age": 30,
            "debug": "some debug info",
            "trace": "trace data",
        });

        let output = run(&input, "hello");

        assert_present(&output, &["name", "age"]);
        assert_absent(&output, &["debug", "trace"]);
    }

    #[test]
    fn test_compress_nested_object() {
        let input = json!({
            "data": {
                "temperature": 22.5,
                "station_id": "WX-001",
                "nested": {
                    "wind_speed": 12.0,
                    "calibration_date": "2025-01-01",
                }
            }
        });

        let output = run(&input, "天气");

        // First nesting level.
        assert!(output
            .pointer("/data/temperature")
            .is_some_and(Value::is_f64));
        assert!(output.pointer("/data/station_id").is_none());

        // Second nesting level.
        assert!(output
            .pointer("/data/nested/wind_speed")
            .is_some_and(Value::is_f64));
        assert!(output.pointer("/data/nested/calibration_date").is_none());
    }

    #[test]
    fn test_compress_array_of_objects() {
        let input = json!([
            {"temperature": 22.5, "station_id": "A"},
            {"temperature": 18.0, "station_id": "B"},
        ]);

        let output = run(&input, "天气");

        let items = output.as_array().unwrap();
        assert_eq!(items.len(), 2);
        for item in items {
            assert!(item.get("station_id").is_none());
        }
    }

    #[test]
    fn test_is_field_kept() {
        let compressor = SemanticCompressor::new();
        let context = "天气怎么样";

        assert!(compressor.is_field_kept("temperature", context));
        assert!(!compressor.is_field_kept("station_id", context));
    }

    #[test]
    fn test_detect_category_public() {
        let compressor = SemanticCompressor::new();

        for (context, expected) in [("天气", "weather"), ("unknown", "default")] {
            assert_eq!(compressor.detect_category(context), expected);
        }
    }
}