use crate::real_time_embedding_pipeline::traits::{
    ContentItem, EmbeddingGenerator, GeneratorStatistics, ProcessingResult, ProcessingStatus,
};
use crate::Vector;
use anyhow::{anyhow, Result};
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::{Duration, Instant};

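/// Configuration for a TensorFlow-based embedding generator: model location,
/// tensor names, device placement, batching limits, and session tuning.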
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorFlowConfig {
    pub model_path: PathBuf,
    pub input_name: String,
    pub output_name: String,
    pub device: TensorFlowDevice,
    pub batch_size: usize,
    pub max_sequence_length: usize,
    pub optimization_level: OptimizationLevel,
    pub use_mixed_precision: bool,
    pub session_config: SessionConfig,
}

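/// Target device for model execution.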
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorFlowDevice {
    Cpu { num_threads: Option<usize> },
    Gpu { device_id: i32, memory_growth: bool },
    Tpu { worker: String },
}

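/// Graph optimization level applied when the session is built.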
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationLevel {
    None,
    Basic,
    Extended,
    Aggressive,
}

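/// Thread-pool and device-placement options for the TensorFlow session.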
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    pub inter_op_parallelism_threads: Option<usize>,
    pub intra_op_parallelism_threads: Option<usize>,
    pub allow_soft_placement: bool,
    pub log_device_placement: bool,
}

impl Default for TensorFlowConfig {
    fn default() -> Self {
        Self {
            model_path: PathBuf::from("./models/universal-sentence-encoder"),
            input_name: "inputs".to_string(),
            output_name: "outputs".to_string(),
            device: TensorFlowDevice::Cpu { num_threads: None },
            batch_size: 32,
            max_sequence_length: 512,
            optimization_level: OptimizationLevel::Basic,
            use_mixed_precision: false,
            session_config: SessionConfig::default(),
        }
    }
}

impl Default for SessionConfig {
    fn default() -> Self {
        Self {
            inter_op_parallelism_threads: None,
            intra_op_parallelism_threads: None,
            allow_soft_placement: true,
            log_device_placement: false,
        }
    }
}

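/// Metadata describing a loaded model: tensor signatures, version, and
/// embedding dimensionality.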
#[derive(Debug, Clone)]
pub struct TensorFlowModelInfo {
    pub model_path: PathBuf,
    pub input_signature: Vec<TensorSpec>,
    pub output_signature: Vec<TensorSpec>,
    pub model_version: String,
    pub dimensions: usize,
    pub preprocessing_required: bool,
}

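/// Name, dtype, and shape of a single input or output tensor. `None` in
/// `shape` marks a dynamic dimension.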
#[derive(Debug, Clone)]
pub struct TensorSpec {
    pub name: String,
    pub dtype: TensorDataType,
    pub shape: Vec<Option<i64>>,
}

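/// Supported tensor element types.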
#[derive(Debug, Clone)]
pub enum TensorDataType {
    Float32,
    Float64,
    Int32,
    Int64,
    String,
    Bool,
}

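/// Text embedder configured against a TensorFlow SavedModel.
///
/// Note that inference is currently simulated: `run_inference` produces a
/// deterministic, L2-normalized pseudo-random vector rather than executing a
/// real TensorFlow session.
///
/// Sketch of the intended call sequence (inputs are illustrative):
///
/// ```ignore
/// let mut embedder = TensorFlowEmbedder::new(TensorFlowConfig::default())?;
/// embedder.load_model()?;
/// let vector = embedder.embed_text("hello world")?;
/// ```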
#[derive(Debug)]
pub struct TensorFlowEmbedder {
    config: TensorFlowConfig,
    model_info: Option<TensorFlowModelInfo>,
    session_initialized: bool,
    preprocessing_pipeline: PreprocessingPipeline,
}

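/// Text normalization applied before inference: optional lowercasing and
/// punctuation stripping. The `tokenizer` and `vocabulary` fields are
/// currently unused by `preprocess_text`.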
#[derive(Debug)]
pub struct PreprocessingPipeline {
    pub lowercase: bool,
    pub remove_punctuation: bool,
    pub tokenizer: Option<String>,
    pub vocabulary: Option<HashMap<String, i32>>,
}

impl Default for PreprocessingPipeline {
    fn default() -> Self {
        Self {
            lowercase: true,
            remove_punctuation: false,
            tokenizer: None,
            vocabulary: None,
        }
    }
}

impl TensorFlowEmbedder {
    /// Creates an embedder with the given configuration; no model is loaded yet.
    pub fn new(config: TensorFlowConfig) -> Result<Self> {
        Ok(Self {
            config,
            model_info: None,
            session_initialized: false,
            preprocessing_pipeline: PreprocessingPipeline::default(),
        })
    }

    /// Validates `model_path` and records model metadata.
    ///
    /// The SavedModel is not actually parsed; the signature below is a
    /// placeholder modeled on a Universal Sentence Encoder-style model with
    /// a 512-dimensional output.
    pub fn load_model(&mut self) -> Result<()> {
        if !self.config.model_path.exists() {
            return Err(anyhow!(
                "Model path does not exist: {:?}",
                self.config.model_path
            ));
        }

        let model_info = TensorFlowModelInfo {
            model_path: self.config.model_path.clone(),
            input_signature: vec![TensorSpec {
                name: self.config.input_name.clone(),
                dtype: TensorDataType::String,
                // [batch, sequence]: both dimensions are dynamic.
                shape: vec![None, None],
            }],
            output_signature: vec![TensorSpec {
                name: self.config.output_name.clone(),
                dtype: TensorDataType::Float32,
                // [batch, 512]: dynamic batch, fixed embedding width.
                shape: vec![None, Some(512)],
            }],
            model_version: "1.0.0".to_string(),
            dimensions: 512,
            preprocessing_required: true,
        };

        self.model_info = Some(model_info);
        self.session_initialized = true;
        Ok(())
    }

    /// Embeds a single text, returning a normalized embedding vector.
    pub fn embed_text(&self, text: &str) -> Result<Vector> {
        if !self.session_initialized {
            return Err(anyhow!("Model not loaded. Call load_model() first."));
        }

        let preprocessed_text = self.preprocess_text(text)?;
        let embedding = self.run_inference(&preprocessed_text)?;
        Ok(Vector::new(embedding))
    }

    /// Embeds a batch of texts.
    ///
    /// Texts are currently processed one at a time; a real TensorFlow
    /// session would feed them through the graph as a single batched tensor.
    pub fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vector>> {
        if !self.session_initialized {
            return Err(anyhow!("Model not loaded. Call load_model() first."));
        }

        let mut results = Vec::new();
        for text in texts {
            let embedding = self.embed_text(text)?;
            results.push(embedding);
        }
        Ok(results)
    }

    /// Applies the configured normalization steps and enforces
    /// `max_sequence_length`.
    fn preprocess_text(&self, text: &str) -> Result<String> {
        let mut processed = text.to_string();

        if self.preprocessing_pipeline.lowercase {
            processed = processed.to_lowercase();
        }

        if self.preprocessing_pipeline.remove_punctuation {
            processed = processed
                .chars()
                .filter(|c| c.is_alphanumeric() || c.is_whitespace())
                .collect();
        }

        // Truncate on character boundaries; byte-indexed `String::truncate`
        // would panic in the middle of a multi-byte UTF-8 sequence.
        if processed.chars().count() > self.config.max_sequence_length {
            processed = processed
                .chars()
                .take(self.config.max_sequence_length)
                .collect();
        }

        Ok(processed)
    }

    /// Produces a deterministic placeholder embedding.
    ///
    /// Real TensorFlow execution is not wired in; instead the input length
    /// seeds a PRNG so that identical inputs yield identical vectors, and
    /// the result is L2-normalized to unit length.
    fn run_inference(&self, text: &str) -> Result<Vec<f32>> {
        let model_info = self
            .model_info
            .as_ref()
            .ok_or_else(|| anyhow!("Model info not available"))?;

        let mut rng = Random::seed(text.len() as u64);

        let mut embedding = vec![0.0f32; model_info.dimensions];
        for value in &mut embedding {
            *value = rng.gen_range(-1.0..1.0);
        }

        // Normalize to unit length so downstream cosine similarity behaves.
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut embedding {
                *x /= norm;
            }
        }

        Ok(embedding)
    }

    /// Returns model metadata, if a model has been loaded.
    pub fn get_model_info(&self) -> Option<&TensorFlowModelInfo> {
        self.model_info.as_ref()
    }

    /// Returns the embedding dimensionality, if known.
    pub fn get_dimensions(&self) -> Option<usize> {
        self.model_info.as_ref().map(|info| info.dimensions)
    }

    /// Replaces the text preprocessing pipeline.
    pub fn set_preprocessing_pipeline(&mut self, pipeline: PreprocessingPipeline) {
        self.preprocessing_pipeline = pipeline;
    }
}

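/// Serves multiple named embedders behind a single entry point, with optional
/// model warming and request batching.
///
/// Sketch of intended usage (names are illustrative):
///
/// ```ignore
/// let mut server = TensorFlowModelServer::new("default".to_string(), ServerConfig::default());
/// server.register_model("default".to_string(), embedder)?;
/// let vectors = server.embed(&["first text".to_string(), "second text".to_string()])?;
/// ```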
#[derive(Debug)]
pub struct TensorFlowModelServer {
    models: HashMap<String, TensorFlowEmbedder>,
    default_model: String,
    server_config: ServerConfig,
}

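/// Serving options: warm-up on registration, batching behavior, and a map of
/// model names to pinned versions.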
#[derive(Debug, Clone)]
pub struct ServerConfig {
    pub model_warming: bool,
    pub request_batching: bool,
    pub max_batch_size: usize,
    pub batch_timeout_ms: u64,
    pub model_versions: HashMap<String, String>,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            model_warming: true,
            request_batching: true,
            max_batch_size: 64,
            batch_timeout_ms: 10,
            model_versions: HashMap::new(),
        }
    }
}

impl TensorFlowModelServer {
    /// Creates an empty server that routes unqualified requests to
    /// `default_model`.
    pub fn new(default_model: String, config: ServerConfig) -> Self {
        Self {
            models: HashMap::new(),
            default_model,
            server_config: config,
        }
    }

    /// Registers an embedder under `name`, optionally warming it with a
    /// throwaway request.
    pub fn register_model(&mut self, name: String, embedder: TensorFlowEmbedder) -> Result<()> {
        self.models.insert(name.clone(), embedder);

        if self.server_config.model_warming {
            if let Some(model) = self.models.get(&name) {
                // Warm-up is best-effort: the error is discarded, e.g. when
                // the embedder has not loaded its model yet.
                let _ = model.embed_text("warmup text");
            }
        }

        Ok(())
    }

    /// Lists the names of all registered models.
    pub fn list_models(&self) -> Vec<String> {
        self.models.keys().cloned().collect()
    }

    /// Embeds `texts` with the named model, batching when enabled.
    pub fn embed_with_model(&self, model_name: &str, texts: &[String]) -> Result<Vec<Vector>> {
        let model = self
            .models
            .get(model_name)
            .ok_or_else(|| anyhow!("Model not found: {}", model_name))?;

        if self.server_config.request_batching && texts.len() > 1 {
            model.embed_batch(texts)
        } else {
            let mut results = Vec::new();
            for text in texts {
                results.push(model.embed_text(text)?);
            }
            Ok(results)
        }
    }

    /// Embeds `texts` with the default model.
    pub fn embed(&self, texts: &[String]) -> Result<Vec<Vector>> {
        self.embed_with_model(&self.default_model, texts)
    }

    /// Returns metadata for the named model, if registered and loaded.
    pub fn get_model_info(&self, model_name: &str) -> Option<&TensorFlowModelInfo> {
        self.models.get(model_name)?.get_model_info()
    }

    /// Replaces the server configuration.
    pub fn update_config(&mut self, config: ServerConfig) {
        self.server_config = config;
    }
}

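// Integration with the real-time embedding pipeline's generator trait.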
impl EmbeddingGenerator for TensorFlowEmbedder {
    fn generate_embedding(&self, content: &ContentItem) -> Result<Vector> {
        self.embed_text(&content.content)
    }

    fn generate_batch_embeddings(&self, content: &[ContentItem]) -> Result<Vec<ProcessingResult>> {
        let mut results = Vec::new();

        for item in content {
            // Time each item individually so per-item durations land in the
            // processing results.
            let start_time = Instant::now();
            let vector_result = self.generate_embedding(item);
            let duration = start_time.elapsed();

            let result = match vector_result {
                Ok(vector) => ProcessingResult {
                    item: item.clone(),
                    vector: Some(vector),
                    status: ProcessingStatus::Completed,
                    duration,
                    error: None,
                    metadata: HashMap::new(),
                },
                Err(e) => ProcessingResult {
                    item: item.clone(),
                    vector: None,
                    status: ProcessingStatus::Failed {
                        reason: e.to_string(),
                    },
                    duration,
                    error: Some(e.to_string()),
                    metadata: HashMap::new(),
                },
            };

            results.push(result);
        }

        Ok(results)
    }

    fn embedding_dimensions(&self) -> usize {
        // Fall back to the placeholder width when no model is loaded.
        self.get_dimensions().unwrap_or(512)
    }

    fn get_config(&self) -> serde_json::Value {
        serde_json::to_value(&self.config).unwrap_or_default()
    }

    fn is_ready(&self) -> bool {
        self.session_initialized
    }

    fn get_statistics(&self) -> GeneratorStatistics {
        // Usage tracking is not implemented; report a zeroed snapshot.
        GeneratorStatistics {
            total_embeddings: 0,
            total_processing_time: Duration::from_millis(0),
            average_processing_time: Duration::from_millis(0),
            error_count: 0,
            last_error: None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensorflow_config_creation() {
        let config = TensorFlowConfig::default();
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.max_sequence_length, 512);
        assert!(matches!(config.device, TensorFlowDevice::Cpu { .. }));
    }

    #[test]
    fn test_tensorflow_embedder_creation() {
        let config = TensorFlowConfig::default();
        let embedder = TensorFlowEmbedder::new(config);
        assert!(embedder.is_ok());
    }

    #[test]
    fn test_preprocessing_pipeline() {
        let mut embedder = TensorFlowEmbedder::new(TensorFlowConfig::default()).unwrap();
        let pipeline = PreprocessingPipeline {
            lowercase: true,
            remove_punctuation: true,
            ..Default::default()
        };
        embedder.set_preprocessing_pipeline(pipeline);

        let processed = embedder.preprocess_text("Hello, World!").unwrap();
        assert_eq!(processed, "hello world");
    }

    #[test]
    fn test_model_server_creation() {
        let server = TensorFlowModelServer::new("default".to_string(), ServerConfig::default());
        assert_eq!(server.default_model, "default");
        assert!(server.list_models().is_empty());
    }

    #[test]
    fn test_model_registration() {
        let mut server =
            TensorFlowModelServer::new("test_model".to_string(), ServerConfig::default());

        let config = TensorFlowConfig::default();
        let embedder = TensorFlowEmbedder::new(config).unwrap();

        let result = server.register_model("test_model".to_string(), embedder);
        assert!(result.is_ok());
        assert_eq!(server.list_models().len(), 1);
    }

    #[test]
    fn test_tensor_spec_creation() {
        let spec = TensorSpec {
            name: "input".to_string(),
            dtype: TensorDataType::Float32,
            shape: vec![None, Some(512)],
        };
        assert_eq!(spec.name, "input");
        assert!(matches!(spec.dtype, TensorDataType::Float32));
    }

    #[test]
    fn test_session_config_default() {
        let config = SessionConfig::default();
        assert!(config.allow_soft_placement);
        assert!(!config.log_device_placement);
        assert!(config.inter_op_parallelism_threads.is_none());
    }

    #[test]
    fn test_device_configuration() {
        let cpu_device = TensorFlowDevice::Cpu {
            num_threads: Some(4),
        };
        let gpu_device = TensorFlowDevice::Gpu {
            device_id: 0,
            memory_growth: true,
        };

        assert!(matches!(cpu_device, TensorFlowDevice::Cpu { .. }));
        assert!(matches!(gpu_device, TensorFlowDevice::Gpu { .. }));
    }

    #[test]
    fn test_optimization_levels() {
        // An array suffices here; a Vec would trip clippy::useless_vec.
        let levels = [
            OptimizationLevel::None,
            OptimizationLevel::Basic,
            OptimizationLevel::Extended,
            OptimizationLevel::Aggressive,
        ];
        assert_eq!(levels.len(), 4);
    }
}