use crate::real_time_embedding_pipeline::traits::{
    ContentItem, EmbeddingGenerator, GeneratorStatistics, ProcessingResult, ProcessingStatus,
};
use crate::Vector;
use anyhow::{anyhow, Result};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::{Duration, Instant};

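/// Configuration for running a TensorFlow model as an embedding backend.
/// `input_name` and `output_name` identify tensors in the model's serving
/// signature; the remaining fields control device placement, batching,
/// truncation, and session behavior.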
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorFlowConfig {
    pub model_path: PathBuf,
    pub input_name: String,
    pub output_name: String,
    pub device: TensorFlowDevice,
    pub batch_size: usize,
    pub max_sequence_length: usize,
    pub optimization_level: OptimizationLevel,
    pub use_mixed_precision: bool,
    pub session_config: SessionConfig,
}

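/// Compute device for inference: CPU with an optional thread count, a GPU
/// selected by id (with on-demand memory growth), or a named TPU worker.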
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorFlowDevice {
    Cpu { num_threads: Option<usize> },
    Gpu { device_id: i32, memory_growth: bool },
    Tpu { worker: String },
}

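/// Graph optimization level applied when loading the model, from none up to
/// aggressive rewrites.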
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationLevel {
    None,
    Basic,
    Extended,
    Aggressive,
}

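/// Session-level threading and placement options, mirroring TensorFlow's
/// `ConfigProto`; `None` thread counts defer to the runtime default.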
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    pub inter_op_parallelism_threads: Option<usize>,
    pub intra_op_parallelism_threads: Option<usize>,
    pub allow_soft_placement: bool,
    pub log_device_placement: bool,
}

impl Default for TensorFlowConfig {
    fn default() -> Self {
        Self {
            model_path: PathBuf::from("./models/universal-sentence-encoder"),
            input_name: "inputs".to_string(),
            output_name: "outputs".to_string(),
            device: TensorFlowDevice::Cpu { num_threads: None },
            batch_size: 32,
            max_sequence_length: 512,
            optimization_level: OptimizationLevel::Basic,
            use_mixed_precision: false,
            session_config: SessionConfig::default(),
        }
    }
}

impl Default for SessionConfig {
    fn default() -> Self {
        Self {
            inter_op_parallelism_threads: None,
            intra_op_parallelism_threads: None,
            allow_soft_placement: true,
            log_device_placement: false,
        }
    }
}

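/// Metadata for a loaded model: input/output tensor signatures, version,
/// embedding dimensionality, and whether text preprocessing is required.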
#[derive(Debug, Clone)]
pub struct TensorFlowModelInfo {
    pub model_path: PathBuf,
    pub input_signature: Vec<TensorSpec>,
    pub output_signature: Vec<TensorSpec>,
    pub model_version: String,
    pub dimensions: usize,
    pub preprocessing_required: bool,
}

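/// Name, dtype, and shape of a single tensor; `None` entries in `shape`
/// denote dynamic dimensions such as the batch axis.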
#[derive(Debug, Clone)]
pub struct TensorSpec {
    pub name: String,
    pub dtype: TensorDataType,
    pub shape: Vec<Option<i64>>,
}

#[derive(Debug, Clone)]
pub enum TensorDataType {
    Float32,
    Float64,
    Int32,
    Int64,
    String,
    Bool,
}

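/// Embeds text with a TensorFlow model selected by `TensorFlowConfig`.
///
/// `run_inference` below is currently a deterministic placeholder (a seeded,
/// L2-normalized pseudo-random vector) standing in for a real TensorFlow
/// session. A minimal usage sketch, assuming a model exists at the
/// configured path:
///
/// ```ignore
/// let mut embedder = TensorFlowEmbedder::new(TensorFlowConfig::default())?;
/// embedder.load_model()?;
/// let vector = embedder.embed_text("hello world")?;
/// ```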
#[derive(Debug)]
pub struct TensorFlowEmbedder {
    config: TensorFlowConfig,
    model_info: Option<TensorFlowModelInfo>,
    session_initialized: bool,
    preprocessing_pipeline: PreprocessingPipeline,
}

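/// Text normalization applied before inference: optional lowercasing and
/// punctuation removal, plus an optional tokenizer name and vocabulary for
/// models that expect token ids rather than raw strings.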
#[derive(Debug)]
pub struct PreprocessingPipeline {
    pub lowercase: bool,
    pub remove_punctuation: bool,
    pub tokenizer: Option<String>,
    pub vocabulary: Option<HashMap<String, i32>>,
}

impl Default for PreprocessingPipeline {
    fn default() -> Self {
        Self {
            lowercase: true,
            remove_punctuation: false,
            tokenizer: None,
            vocabulary: None,
        }
    }
}

impl TensorFlowEmbedder {
    /// Creates an embedder from `config`. No filesystem or session work
    /// happens here; call `load_model()` before embedding.
    pub fn new(config: TensorFlowConfig) -> Result<Self> {
        Ok(Self {
            config,
            model_info: None,
            session_initialized: false,
            preprocessing_pipeline: PreprocessingPipeline::default(),
        })
    }

    /// Validates the model path and records the model metadata. The
    /// signature is currently hard-coded (string input, 512-dim float32
    /// output) rather than read from the SavedModel on disk.
    pub fn load_model(&mut self) -> Result<()> {
        if !self.config.model_path.exists() {
            return Err(anyhow!(
                "Model path does not exist: {:?}",
                self.config.model_path
            ));
        }

        let model_info = TensorFlowModelInfo {
            model_path: self.config.model_path.clone(),
            input_signature: vec![TensorSpec {
                name: self.config.input_name.clone(),
                dtype: TensorDataType::String,
                shape: vec![None, None],
            }],
            output_signature: vec![TensorSpec {
                name: self.config.output_name.clone(),
                dtype: TensorDataType::Float32,
                shape: vec![None, Some(512)],
            }],
            model_version: "1.0.0".to_string(),
            dimensions: 512,
            preprocessing_required: true,
        };

        self.model_info = Some(model_info);
        self.session_initialized = true;
        Ok(())
    }

    /// Embeds a single text, applying the preprocessing pipeline first.
    pub fn embed_text(&self, text: &str) -> Result<Vector> {
        if !self.session_initialized {
            return Err(anyhow!("Model not loaded. Call load_model() first."));
        }

        let preprocessed_text = self.preprocess_text(text)?;
        let embedding = self.run_inference(&preprocessed_text)?;
        Ok(Vector::new(embedding))
    }

    /// Embeds a slice of texts. Inputs are processed one at a time;
    /// `config.batch_size` is not yet used to batch inference calls.
    pub fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vector>> {
        if !self.session_initialized {
            return Err(anyhow!("Model not loaded. Call load_model() first."));
        }

        texts.iter().map(|text| self.embed_text(text)).collect()
    }

    fn preprocess_text(&self, text: &str) -> Result<String> {
        let mut processed = text.to_string();

        if self.preprocessing_pipeline.lowercase {
            processed = processed.to_lowercase();
        }

        if self.preprocessing_pipeline.remove_punctuation {
            processed = processed
                .chars()
                .filter(|c| c.is_alphanumeric() || c.is_whitespace())
                .collect();
        }

        // Truncate on a character boundary: a byte-indexed `String::truncate`
        // would panic mid-codepoint on multi-byte UTF-8 text.
        // `max_sequence_length` is treated as a character budget here.
        if processed.chars().count() > self.config.max_sequence_length {
            processed = processed
                .chars()
                .take(self.config.max_sequence_length)
                .collect();
        }

        Ok(processed)
    }

    /// Placeholder inference: returns a deterministic, L2-normalized
    /// pseudo-random vector seeded by the text length. Distinct texts of
    /// the same length therefore collide; a real TensorFlow session should
    /// replace this for production use.
    fn run_inference(&self, text: &str) -> Result<Vec<f32>> {
        let model_info = self
            .model_info
            .as_ref()
            .ok_or_else(|| anyhow!("Model info not available"))?;

        let mut rng = Random::seed(text.len() as u64);

        let mut embedding = vec![0.0f32; model_info.dimensions];
        for value in &mut embedding {
            *value = rng.gen_range(-1.0..1.0);
        }

        // L2-normalize so downstream cosine similarity is well-behaved.
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut embedding {
                *x /= norm;
            }
        }

        Ok(embedding)
    }

    pub fn get_model_info(&self) -> Option<&TensorFlowModelInfo> {
        self.model_info.as_ref()
    }

    pub fn get_dimensions(&self) -> Option<usize> {
        self.model_info.as_ref().map(|info| info.dimensions)
    }

    pub fn set_preprocessing_pipeline(&mut self, pipeline: PreprocessingPipeline) {
        self.preprocessing_pipeline = pipeline;
    }
}

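/// Registry that serves multiple named embedders behind a single interface,
/// with optional warmup on registration and request batching.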
#[derive(Debug)]
pub struct TensorFlowModelServer {
    models: HashMap<String, TensorFlowEmbedder>,
    default_model: String,
    server_config: ServerConfig,
}

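/// Serving options: warm models at registration time, batch requests up to
/// `max_batch_size`, and pin model versions by name. `batch_timeout_ms` is
/// presumably the time a batch may wait to fill; it is not consulted yet.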
#[derive(Debug, Clone)]
pub struct ServerConfig {
    pub model_warming: bool,
    pub request_batching: bool,
    pub max_batch_size: usize,
    pub batch_timeout_ms: u64,
    pub model_versions: HashMap<String, String>,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            model_warming: true,
            request_batching: true,
            max_batch_size: 64,
            batch_timeout_ms: 10,
            model_versions: HashMap::new(),
        }
    }
}

impl TensorFlowModelServer {
    pub fn new(default_model: String, config: ServerConfig) -> Self {
        Self {
            models: HashMap::new(),
            default_model,
            server_config: config,
        }
    }

    /// Registers an embedder under `name`. When `model_warming` is enabled,
    /// a throwaway inference is run so the first real request avoids any
    /// cold-start cost; warmup failures are deliberately ignored.
    pub fn register_model(&mut self, name: String, embedder: TensorFlowEmbedder) -> Result<()> {
        self.models.insert(name.clone(), embedder);

        if self.server_config.model_warming {
            if let Some(model) = self.models.get(&name) {
                let _ = model.embed_text("warmup text");
            }
        }

        Ok(())
    }

    pub fn list_models(&self) -> Vec<String> {
        self.models.keys().cloned().collect()
    }

    /// Embeds `texts` with the named model, using the batch path when
    /// request batching is enabled and more than one input is given.
    pub fn embed_with_model(&self, model_name: &str, texts: &[String]) -> Result<Vec<Vector>> {
        let model = self
            .models
            .get(model_name)
            .ok_or_else(|| anyhow!("Model not found: {}", model_name))?;

        if self.server_config.request_batching && texts.len() > 1 {
            model.embed_batch(texts)
        } else {
            texts.iter().map(|text| model.embed_text(text)).collect()
        }
    }

    /// Embeds `texts` with the default model.
    pub fn embed(&self, texts: &[String]) -> Result<Vec<Vector>> {
        self.embed_with_model(&self.default_model, texts)
    }

    pub fn get_model_info(&self, model_name: &str) -> Option<&TensorFlowModelInfo> {
        self.models.get(model_name)?.get_model_info()
    }

    pub fn update_config(&mut self, config: ServerConfig) {
        self.server_config = config;
    }
}

impl EmbeddingGenerator for TensorFlowEmbedder {
    fn generate_embedding(&self, content: &ContentItem) -> Result<Vector> {
        self.embed_text(&content.content)
    }

    fn generate_batch_embeddings(&self, content: &[ContentItem]) -> Result<Vec<ProcessingResult>> {
        let mut results = Vec::new();

        for item in content {
            let start_time = Instant::now();
            let vector_result = self.generate_embedding(item);
            let duration = start_time.elapsed();

            // Per-item failures are recorded in the result rather than
            // aborting the whole batch.
            let result = match vector_result {
                Ok(vector) => ProcessingResult {
                    item: item.clone(),
                    vector: Some(vector),
                    status: ProcessingStatus::Completed,
                    duration,
                    error: None,
                    metadata: HashMap::new(),
                },
                Err(e) => ProcessingResult {
                    item: item.clone(),
                    vector: None,
                    status: ProcessingStatus::Failed {
                        reason: e.to_string(),
                    },
                    duration,
                    error: Some(e.to_string()),
                    metadata: HashMap::new(),
                },
            };

            results.push(result);
        }

        Ok(results)
    }

    fn embedding_dimensions(&self) -> usize {
        // Fall back to 512 (the hard-coded model dimensionality) when no
        // model has been loaded yet.
        self.get_dimensions().unwrap_or(512)
    }

    fn get_config(&self) -> serde_json::Value {
        serde_json::to_value(&self.config).unwrap_or_default()
    }

    fn is_ready(&self) -> bool {
        self.session_initialized
    }

    fn get_statistics(&self) -> GeneratorStatistics {
        // Statistics tracking is not implemented yet; report zeroed counters.
        GeneratorStatistics {
            total_embeddings: 0,
            total_processing_time: Duration::from_millis(0),
            average_processing_time: Duration::from_millis(0),
            error_count: 0,
            last_error: None,
        }
    }
}

#[cfg(test)]
#[allow(clippy::useless_vec)]
mod tests {
    use super::*;

    #[test]
    fn test_tensorflow_config_creation() {
        let config = TensorFlowConfig::default();
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.max_sequence_length, 512);
        assert!(matches!(config.device, TensorFlowDevice::Cpu { .. }));
    }

    #[test]
    fn test_tensorflow_embedder_creation() {
        let config = TensorFlowConfig::default();
        let embedder = TensorFlowEmbedder::new(config);
        assert!(embedder.is_ok());
    }

    #[test]
    fn test_preprocessing_pipeline() {
        let mut embedder = TensorFlowEmbedder::new(TensorFlowConfig::default()).unwrap();
        let pipeline = PreprocessingPipeline {
            lowercase: true,
            remove_punctuation: true,
            ..Default::default()
        };
        embedder.set_preprocessing_pipeline(pipeline);

        let processed = embedder.preprocess_text("Hello, World!").unwrap();
        assert_eq!(processed, "hello world");
    }

    #[test]
    fn test_model_server_creation() {
        let server = TensorFlowModelServer::new("default".to_string(), ServerConfig::default());
        assert_eq!(server.default_model, "default");
        assert!(server.list_models().is_empty());
    }

    #[test]
    fn test_model_registration() {
        let mut server =
            TensorFlowModelServer::new("test_model".to_string(), ServerConfig::default());

        let config = TensorFlowConfig::default();
        let embedder = TensorFlowEmbedder::new(config).unwrap();

        let result = server.register_model("test_model".to_string(), embedder);
        assert!(result.is_ok());
        assert_eq!(server.list_models().len(), 1);
    }

    #[test]
    fn test_tensor_spec_creation() {
        let spec = TensorSpec {
            name: "input".to_string(),
            dtype: TensorDataType::Float32,
            shape: vec![None, Some(512)],
        };
        assert_eq!(spec.name, "input");
        assert!(matches!(spec.dtype, TensorDataType::Float32));
    }

    #[test]
    fn test_session_config_default() {
        let config = SessionConfig::default();
        assert!(config.allow_soft_placement);
        assert!(!config.log_device_placement);
        assert!(config.inter_op_parallelism_threads.is_none());
    }

    #[test]
    fn test_device_configuration() {
        let cpu_device = TensorFlowDevice::Cpu {
            num_threads: Some(4),
        };
        let gpu_device = TensorFlowDevice::Gpu {
            device_id: 0,
            memory_growth: true,
        };

        assert!(matches!(cpu_device, TensorFlowDevice::Cpu { .. }));
        assert!(matches!(gpu_device, TensorFlowDevice::Gpu { .. }));
    }

    #[test]
    fn test_optimization_levels() {
        let levels = vec![
            OptimizationLevel::None,
            OptimizationLevel::Basic,
            OptimizationLevel::Extended,
            OptimizationLevel::Aggressive,
        ];
        assert_eq!(levels.len(), 4);
    }
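
    #[test]
    fn test_embed_requires_loaded_model() {
        // embed_text must fail before load_model() has been called.
        let embedder = TensorFlowEmbedder::new(TensorFlowConfig::default()).unwrap();
        assert!(embedder.embed_text("hello").is_err());
    }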
}