1use std::collections::{HashMap, VecDeque};
2use std::fmt;
3use std::io::{BufRead, BufReader, Write};
4use std::os::unix::net::UnixStream;
5use std::path::Path;
6use std::process::Child;
7use std::sync::atomic::{AtomicU32, Ordering};
8use std::time::{Duration, Instant};
9
10use crate::error::EimError;
11use crate::inference::messages::{
12 ClassifyMessage, ErrorResponse, HelloMessage, InferenceResponse, InferenceResult, ModelInfo,
13 SetThresholdMessage, SetThresholdResponse, ThresholdConfig,
14};
15use crate::types::{ModelParameters, ModelThreshold, SensorType, VisualAnomalyResult};
16
/// Sink for debug messages emitted by [`EimModel`]; boxed so callers can pass
/// any closure, and `Send + Sync` so the model can be shared across threads.
pub type DebugCallback = Box<dyn Fn(&str) + Send + Sync>;
19
/// Handle to a running Edge Impulse model process (`.eim` binary) that is
/// driven over a Unix domain socket using line-delimited JSON messages.
pub struct EimModel {
    /// Absolute path to the `.eim` executable that was spawned.
    path: std::path::PathBuf,
    /// Path of the Unix socket the model process listens on.
    socket_path: std::path::PathBuf,
    /// Connected stream to the model process.
    socket: UnixStream,
    /// When true, debug messages are printed and forwarded to `debug_callback`.
    debug: bool,
    /// Optional user-supplied sink for debug messages (set via `set_debug_callback`).
    debug_callback: Option<DebugCallback>,
    /// The spawned model process; kept so the handle owns the child's lifetime.
    _process: Child,
    /// Model metadata received from the hello handshake; `None` until `send_hello` succeeds.
    model_info: Option<ModelInfo>,
    /// Monotonically increasing id attached to every outgoing message (starts at 1).
    message_id: AtomicU32,
    #[allow(dead_code)]
    // Unused secondary child slot; always `None` in this file. TODO confirm it can be removed.
    child: Option<Child>,
    /// Sliding-window state, lazily created when the model runs in continuous mode.
    continuous_state: Option<ContinuousState>,
    /// Local parameter cache (initialized to defaults; the authoritative copy
    /// lives inside `model_info.model_parameters`).
    model_parameters: ModelParameters,
}
124
/// Accumulated state for continuous-mode inference: a sliding window of
/// features plus per-label smoothing filters.
#[derive(Debug)]
struct ContinuousState {
    /// Flattened sliding window of the most recent features (at most `slice_size` values).
    feature_matrix: Vec<f32>,
    /// True once `feature_matrix` has been filled to `slice_size` at least once.
    feature_buffer_full: bool,
    /// One moving-average filter per classification label, keyed by label name.
    maf_buffers: HashMap<String, MovingAverageFilter>,
    /// Number of feature values that constitute one full inference window.
    slice_size: usize,
}
132
133impl ContinuousState {
134 fn new(labels: Vec<String>, slice_size: usize) -> Self {
135 Self {
136 feature_matrix: Vec::new(),
137 feature_buffer_full: false,
138 maf_buffers: labels
139 .into_iter()
140 .map(|label| (label, MovingAverageFilter::new(4)))
141 .collect(),
142 slice_size,
143 }
144 }
145
146 fn update_features(&mut self, features: &[f32]) {
147 self.feature_matrix.extend_from_slice(features);
149
150 if self.feature_matrix.len() >= self.slice_size {
152 self.feature_buffer_full = true;
153 if self.feature_matrix.len() > self.slice_size {
155 self.feature_matrix
156 .drain(0..self.feature_matrix.len() - self.slice_size);
157 }
158 }
159 }
160
161 fn apply_maf(&mut self, classification: &mut HashMap<String, f32>) {
162 for (label, value) in classification.iter_mut() {
163 if let Some(maf) = self.maf_buffers.get_mut(label) {
164 *value = maf.update(*value);
165 }
166 }
167 }
168}
169
/// Fixed-width moving average over the last `window_size` samples, with the
/// running sum maintained incrementally so each update is O(1).
#[derive(Debug)]
struct MovingAverageFilter {
    /// Samples currently inside the window, oldest at the front.
    buffer: VecDeque<f32>,
    /// Maximum number of samples the window may hold.
    window_size: usize,
    /// Running sum of everything in `buffer`.
    sum: f32,
}

impl MovingAverageFilter {
    /// Creates an empty filter averaging over at most `window_size` samples.
    fn new(window_size: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(window_size),
            window_size,
            sum: 0.0,
        }
    }

    /// Pushes `value` into the window and returns the mean of the samples
    /// currently held (fewer than `window_size` during warm-up).
    fn update(&mut self, value: f32) -> f32 {
        let window_is_full = self.buffer.len() >= self.window_size;
        if window_is_full {
            // Evict the oldest sample so the window stays at `window_size`.
            let oldest = self.buffer.pop_front().unwrap();
            self.sum -= oldest;
        }
        self.buffer.push_back(value);
        self.sum += value;
        self.sum / self.buffer.len() as f32
    }
}
195
// Manual Debug impl: `debug_callback` holds a boxed closure with no Debug
// impl, so it is intentionally omitted from the output.
impl fmt::Debug for EimModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("EimModel")
            .field("path", &self.path)
            .field("socket_path", &self.socket_path)
            .field("socket", &self.socket)
            .field("debug", &self.debug)
            .field("_process", &self._process)
            .field("model_info", &self.model_info)
            .field("message_id", &self.message_id)
            .field("child", &self.child)
            .field("continuous_state", &self.continuous_state)
            .field("model_parameters", &self.model_parameters)
            .finish()
    }
}
213
impl EimModel {
    /// Loads a `.eim` model with debug output disabled, using the default
    /// socket path (`eim_socket` in the system temp directory).
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, EimError> {
        Self::new_with_debug(path, false)
    }

    /// Loads a `.eim` model that communicates over the given socket path,
    /// with debug output disabled.
    pub fn new_with_socket<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        socket_path: S,
    ) -> Result<Self, EimError> {
        Self::new_with_socket_and_debug(path, socket_path, false)
    }

    /// Loads a `.eim` model with debug output toggled by `debug`.
    ///
    /// NOTE(review): the socket name is the fixed `eim_socket` in the temp
    /// directory, so two models created this way concurrently would collide
    /// on the same socket path — confirm whether a unique per-instance name
    /// (e.g. PID-suffixed) is needed.
    pub fn new_with_debug<P: AsRef<Path>>(path: P, debug: bool) -> Result<Self, EimError> {
        let socket_path = std::env::temp_dir().join("eim_socket");
        Self::new_with_socket_and_debug(path, &socket_path, debug)
    }

    /// Ensures the file at `path` carries the owner-execute permission bit,
    /// adding it when absent. Unix-only (uses `PermissionsExt`).
    fn ensure_executable<P: AsRef<Path>>(path: P) -> Result<(), EimError> {
        use std::os::unix::fs::PermissionsExt;

        let path = path.as_ref();
        let metadata = std::fs::metadata(path)
            .map_err(|e| EimError::ExecutionError(format!("Failed to get file metadata: {}", e)))?;

        let perms = metadata.permissions();
        let current_mode = perms.mode();
        // 0o100 is the owner-execute bit; only rewrite permissions when it is missing.
        if current_mode & 0o100 == 0 {
            let mut new_perms = perms;
            new_perms.set_mode(current_mode | 0o100);
            std::fs::set_permissions(path, new_perms).map_err(|e| {
                EimError::ExecutionError(format!("Failed to set executable permissions: {}", e))
            })?;
        }
        Ok(())
    }

    /// Full constructor: validates the `.eim` extension, spawns the model
    /// binary with the socket path as its argument, connects to that socket
    /// (retrying for up to 5 s while the server starts), then performs the
    /// hello handshake to populate `model_info`.
    pub fn new_with_socket_and_debug<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        socket_path: S,
        debug: bool,
    ) -> Result<Self, EimError> {
        let path = path.as_ref();
        let socket_path = socket_path.as_ref();

        // Only `.eim` files are accepted.
        if path.extension().and_then(|s| s.to_str()) != Some("eim") {
            return Err(EimError::InvalidPath);
        }

        // Resolve to an absolute path so the spawn does not depend on the
        // current working directory.
        let absolute_path = if path.is_absolute() {
            path.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|_e| EimError::InvalidPath)?
                .join(path)
        };

        Self::ensure_executable(&absolute_path)?;

        // The model binary receives the socket path as argv[1]; presumably it
        // creates and listens on that socket — confirm against the .eim runner.
        let process = std::process::Command::new(&absolute_path)
            .arg(socket_path)
            .spawn()
            .map_err(|e| EimError::ExecutionError(e.to_string()))?;

        let socket = Self::connect_with_retry(socket_path, Duration::from_secs(5))?;

        let mut model = Self {
            path: absolute_path,
            socket_path: socket_path.to_path_buf(),
            socket,
            debug,
            _process: process,
            model_info: None,
            message_id: AtomicU32::new(1),
            child: None,
            debug_callback: None,
            continuous_state: None,
            model_parameters: ModelParameters::default(),
        };

        // Handshake: fills in `model_info` or fails construction entirely.
        model.send_hello()?;

        Ok(model)
    }

    /// Polls `UnixStream::connect` every 50 ms until it succeeds or `timeout`
    /// elapses. `NotFound` / `ConnectionRefused` are treated as "server not
    /// ready yet" and retried; any other error aborts immediately.
    fn connect_with_retry(socket_path: &Path, timeout: Duration) -> Result<UnixStream, EimError> {
        let start = Instant::now();
        let retry_interval = Duration::from_millis(50);

        while start.elapsed() < timeout {
            match UnixStream::connect(socket_path) {
                Ok(stream) => return Ok(stream),
                Err(e) => {
                    if e.kind() != std::io::ErrorKind::NotFound
                        && e.kind() != std::io::ErrorKind::ConnectionRefused
                    {
                        return Err(EimError::SocketError(format!(
                            "Failed to connect to socket: {}",
                            e
                        )));
                    }
                }
            }
            std::thread::sleep(retry_interval);
        }

        Err(EimError::SocketError(format!(
            "Timeout waiting for socket {} to become available",
            socket_path.display()
        )))
    }

    /// Returns the next outgoing message id (monotonically increasing; the
    /// counter starts at 1). Relaxed ordering is sufficient: the id only
    /// needs uniqueness, not synchronization with other memory.
    fn next_message_id(&self) -> u32 {
        self.message_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Installs a callback that receives every debug message in addition to
    /// stdout. The callback is only invoked while `debug` is true.
    pub fn set_debug_callback<F>(&mut self, callback: F)
    where
        F: Fn(&str) + Send + Sync + 'static,
    {
        self.debug_callback = Some(Box::new(callback));
    }

    /// Prints `message` and forwards it to the optional callback, but only
    /// when debug mode is enabled; otherwise a no-op.
    fn debug_message(&self, message: &str) {
        if self.debug {
            println!("{}", message);
            if let Some(callback) = &self.debug_callback {
                callback(message);
            }
        }
    }

    /// Performs the hello handshake: sends a JSON `HelloMessage` terminated
    /// by a newline, then reads one response line and stores the `ModelInfo`
    /// it contains. Fails if the process reports `success == false` or
    /// answers with a structured `ErrorResponse`.
    ///
    /// NOTE(review): a fresh `BufReader` is created per call here (and in the
    /// other readers below); any bytes it buffers beyond the first line are
    /// discarded when it drops — confirm the server protocol is strictly one
    /// line per request/response.
    fn send_hello(&mut self) -> Result<(), EimError> {
        let hello_msg = HelloMessage {
            hello: 1,
            id: self.next_message_id(),
        };

        let msg = serde_json::to_string(&hello_msg)?;
        self.debug_message(&format!("Sending hello message: {}", msg));

        writeln!(self.socket, "{}", msg).map_err(|e| {
            self.debug_message(&format!("Failed to send hello: {}", e));
            EimError::SocketError(format!("Failed to send hello message: {}", e))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush hello: {}", e));
            EimError::SocketError(format!("Failed to flush socket: {}", e))
        })?;

        self.debug_message("Waiting for hello response...");

        let mut reader = BufReader::new(&self.socket);
        let mut line = String::new();

        match reader.read_line(&mut line) {
            Ok(n) => {
                self.debug_message(&format!("Read {} bytes: {}", n, line));

                match serde_json::from_str::<ModelInfo>(&line) {
                    Ok(info) => {
                        self.debug_message("Successfully parsed model info");
                        if !info.success {
                            self.debug_message("Model initialization failed");
                            return Err(EimError::ExecutionError(
                                "Model initialization failed".to_string(),
                            ));
                        }
                        self.debug_message("Got model info response, storing it");
                        self.model_info = Some(info);
                        return Ok(());
                    }
                    Err(e) => {
                        self.debug_message(&format!("Failed to parse model info: {}", e));
                        // Not a ModelInfo payload — see if it is a structured
                        // error response before giving up.
                        if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line) {
                            if !error.success {
                                self.debug_message(&format!("Got error response: {:?}", error));
                                return Err(EimError::ExecutionError(
                                    error.error.unwrap_or_else(|| "Unknown error".to_string()),
                                ));
                            }
                        }
                        // Fall through to the generic failure below.
                    }
                }
            }
            Err(e) => {
                self.debug_message(&format!("Failed to read hello response: {}", e));
                return Err(EimError::SocketError(format!(
                    "Failed to read response: {}",
                    e
                )));
            }
        }

        self.debug_message("No valid hello response received");
        Err(EimError::SocketError(
            "No valid response received".to_string(),
        ))
    }

    /// Absolute path of the `.eim` binary backing this model.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Path of the Unix socket used to talk to the model process.
    pub fn socket_path(&self) -> &Path {
        &self.socket_path
    }

    /// Sensor type declared by the model, derived from the handshake info.
    /// Errors when the handshake has not completed.
    pub fn sensor_type(&self) -> Result<SensorType, EimError> {
        self.model_info
            .as_ref()
            .map(|info| SensorType::from(info.model_parameters.sensor))
            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
    }

    /// Borrow of the model parameters received during the handshake.
    /// Errors when the handshake has not completed.
    pub fn parameters(&self) -> Result<&ModelParameters, EimError> {
        self.model_info
            .as_ref()
            .map(|info| &info.model_parameters)
            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
    }

    /// Runs inference on `features`, transparently dispatching to continuous
    /// or single-shot mode based on the model's `use_continuous_mode` flag.
    /// `debug` is forwarded to the model process in the classify message.
    pub fn infer(
        &mut self,
        features: Vec<f32>,
        debug: Option<bool>,
    ) -> Result<InferenceResponse, EimError> {
        // Lazily (re)establish the handshake if model info is missing.
        if self.model_info.is_none() {
            self.send_hello()?;
        }

        let uses_continuous_mode = self.requires_continuous_mode();

        if uses_continuous_mode {
            self.infer_continuous_internal(features, debug)
        } else {
            self.infer_single(features, debug)
        }
    }

    /// Continuous-mode inference: accumulates incoming feature slices into a
    /// sliding window of `input_size()` values; until the window fills, an
    /// empty successful classification is returned. Once full, the whole
    /// window is classified and scores are smoothed per label by the
    /// moving-average filters.
    ///
    /// NOTE(review): the `?` after `infer_single` below returns before the
    /// taken state is restored, so a single failed inference silently drops
    /// the accumulated window — confirm whether state should survive errors.
    fn infer_continuous_internal(
        &mut self,
        features: Vec<f32>,
        debug: Option<bool>,
    ) -> Result<InferenceResponse, EimError> {
        // First call in continuous mode: build the window state from the
        // model's labels and input size.
        if self.continuous_state.is_none() {
            let labels = self
                .model_info
                .as_ref()
                .map(|info| info.model_parameters.labels.clone())
                .unwrap_or_default();
            let slice_size = self.input_size()?;

            self.continuous_state = Some(ContinuousState::new(labels, slice_size));
        }

        // Take the state out so we can borrow `self` mutably for the call.
        let mut state = self.continuous_state.take().unwrap();
        state.update_features(&features);

        let response = if !state.feature_buffer_full {
            // Warm-up phase: report success with no classifications yet.
            Ok(InferenceResponse {
                success: true,
                id: self.next_message_id(),
                result: InferenceResult::Classification {
                    classification: HashMap::new(),
                },
            })
        } else {
            let mut response = self.infer_single(state.feature_matrix.clone(), debug)?;

            // Smooth classification scores; other result kinds pass through.
            if let InferenceResult::Classification {
                ref mut classification,
            } = response.result
            {
                state.apply_maf(classification);
            }

            Ok(response)
        };

        // Put the window state back for the next call.
        self.continuous_state = Some(state);

        response
    }

    /// Single-shot inference: sends a `ClassifyMessage` and polls the socket
    /// in non-blocking mode (10 ms sleep between attempts, 5 s overall
    /// timeout) until a successful `InferenceResponse` line is parsed.
    /// Blocking mode is restored on every exit path.
    fn infer_single(
        &mut self,
        features: Vec<f32>,
        debug: Option<bool>,
    ) -> Result<InferenceResponse, EimError> {
        if self.model_info.is_none() {
            self.debug_message("No model info, sending hello message...");
            self.send_hello()?;
            self.debug_message("Hello handshake completed");
        }

        let msg = ClassifyMessage {
            classify: features.clone(),
            id: self.next_message_id(),
            debug,
        };

        let msg_str = serde_json::to_string(&msg)?;
        self.debug_message(&format!(
            "Sending inference message with {} features",
            features.len()
        ));

        writeln!(self.socket, "{}", msg_str).map_err(|e| {
            self.debug_message(&format!("Failed to send inference message: {}", e));
            EimError::SocketError(format!("Failed to send inference message: {}", e))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush inference message: {}", e));
            EimError::SocketError(format!("Failed to flush socket: {}", e))
        })?;

        self.debug_message("Inference message sent, waiting for response...");

        // Non-blocking mode lets us enforce our own timeout on the read loop.
        self.socket.set_nonblocking(true).map_err(|e| {
            self.debug_message(&format!("Failed to set non-blocking mode: {}", e));
            EimError::SocketError(format!("Failed to set non-blocking mode: {}", e))
        })?;

        let mut reader = BufReader::new(&self.socket);
        let mut buffer = String::new();
        let start = Instant::now();
        let timeout = Duration::from_secs(5);

        while start.elapsed() < timeout {
            match reader.read_line(&mut buffer) {
                Ok(0) => {
                    self.debug_message("EOF reached");
                    break;
                }
                Ok(n) => {
                    // Skip echoing raw feature dumps the model may emit in
                    // debug mode; everything else is logged.
                    if !buffer.contains("features:") && !buffer.contains("Features (") {
                        self.debug_message(&format!("Read {} bytes: {}", n, buffer));
                    }

                    if let Ok(response) = serde_json::from_str::<InferenceResponse>(&buffer) {
                        if response.success {
                            self.debug_message("Got successful inference response");
                            let _ = self.socket.set_nonblocking(false);
                            return Ok(response);
                        }
                    }
                    // Non-matching line: discard and keep reading.
                    buffer.clear();
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    // No data yet; any partial line already read stays in
                    // `buffer` and is completed on a later iteration.
                    std::thread::sleep(Duration::from_millis(10));
                    continue;
                }
                Err(e) => {
                    self.debug_message(&format!("Read error: {}", e));
                    let _ = self.socket.set_nonblocking(false);
                    return Err(EimError::SocketError(format!("Read error: {}", e)));
                }
            }
        }

        let _ = self.socket.set_nonblocking(false);
        self.debug_message("Timeout reached");

        Err(EimError::ExecutionError(format!(
            "No valid response received within {} seconds",
            timeout.as_secs()
        )))
    }

    /// True when the model parameters request continuous mode; false when
    /// the handshake has not happened yet.
    fn requires_continuous_mode(&self) -> bool {
        self.model_info
            .as_ref()
            .map(|info| info.model_parameters.use_continuous_mode)
            .unwrap_or(false)
    }

    /// Number of feature values the model expects per inference, from the
    /// handshake info. Errors when the handshake has not completed.
    pub fn input_size(&self) -> Result<usize, EimError> {
        self.model_info
            .as_ref()
            .map(|info| info.model_parameters.input_features_count as usize)
            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
    }

    /// Sends a `SetThresholdMessage` for a learn block and waits for the
    /// process to acknowledge it.
    ///
    /// NOTE(review): declared `async` but performs only blocking socket I/O
    /// with no `.await`s — calling this on an async executor will stall a
    /// worker thread. Confirm whether the async signature is required by
    /// callers before changing it.
    pub async fn set_learn_block_threshold(
        &mut self,
        threshold: ThresholdConfig,
    ) -> Result<(), EimError> {
        if self.model_info.is_none() {
            self.debug_message("No model info available, sending hello message...");
            self.send_hello()?;
        }

        // Log current model state for troubleshooting threshold changes.
        if let Some(info) = &self.model_info {
            self.debug_message(&format!(
                "Current model type: {}",
                info.model_parameters.model_type
            ));
            self.debug_message(&format!(
                "Current model parameters: {:?}",
                info.model_parameters
            ));
        }

        let msg = SetThresholdMessage {
            set_threshold: threshold,
            id: self.next_message_id(),
        };

        let msg_str = serde_json::to_string(&msg)?;
        self.debug_message(&format!("Sending threshold message: {}", msg_str));

        writeln!(self.socket, "{}", msg_str).map_err(|e| {
            self.debug_message(&format!("Failed to send threshold message: {}", e));
            EimError::SocketError(format!("Failed to send threshold message: {}", e))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush threshold message: {}", e));
            EimError::SocketError(format!("Failed to flush socket: {}", e))
        })?;

        let mut reader = BufReader::new(&self.socket);
        let mut line = String::new();

        match reader.read_line(&mut line) {
            Ok(_) => {
                self.debug_message(&format!("Received response: {}", line));
                match serde_json::from_str::<SetThresholdResponse>(&line) {
                    Ok(response) => {
                        if response.success {
                            self.debug_message("Successfully set threshold");
                            Ok(())
                        } else {
                            self.debug_message("Server reported failure setting threshold");
                            Err(EimError::ExecutionError(
                                "Server reported failure setting threshold".to_string(),
                            ))
                        }
                    }
                    Err(e) => {
                        self.debug_message(&format!("Failed to parse threshold response: {}", e));
                        // Prefer a structured error payload when present.
                        if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line) {
                            Err(EimError::ExecutionError(
                                error.error.unwrap_or_else(|| "Unknown error".to_string()),
                            ))
                        } else {
                            Err(EimError::ExecutionError(format!(
                                "Invalid threshold response format: {}",
                                e
                            )))
                        }
                    }
                }
            }
            Err(e) => {
                self.debug_message(&format!("Failed to read threshold response: {}", e));
                Err(EimError::SocketError(format!(
                    "Failed to read response: {}",
                    e
                )))
            }
        }
    }

    /// Minimum anomaly score taken from the model's `AnomalyGMM` threshold,
    /// falling back to 6.0 when no such threshold (or no model info) exists.
    fn get_min_anomaly_score(&self) -> f32 {
        self.model_info
            .as_ref()
            .and_then(|info| {
                info.model_parameters
                    .thresholds
                    .iter()
                    .find_map(|t| match t {
                        ModelThreshold::AnomalyGMM {
                            min_anomaly_score, ..
                        } => Some(*min_anomaly_score),
                        _ => None,
                    })
            })
            .unwrap_or(6.0)
    }

    /// Scales `score` by the minimum anomaly score and caps the result at 1.0.
    fn normalize_anomaly_score(&self, score: f32) -> f32 {
        (score / self.get_min_anomaly_score()).min(1.0)
    }

    /// Normalizes a visual-anomaly result: the overall, max and mean scores
    /// plus every region's score (regions are `(value, x, y, w, h)` tuples;
    /// the geometry passes through unchanged).
    pub fn normalize_visual_anomaly(
        &self,
        anomaly: f32,
        max: f32,
        mean: f32,
        regions: &[(f32, u32, u32, u32, u32)],
    ) -> VisualAnomalyResult {
        let normalized_anomaly = self.normalize_anomaly_score(anomaly);
        let normalized_max = self.normalize_anomaly_score(max);
        let normalized_mean = self.normalize_anomaly_score(mean);
        let normalized_regions: Vec<_> = regions
            .iter()
            .map(|(value, x, y, w, h)| (self.normalize_anomaly_score(*value), *x, *y, *w, *h))
            .collect();

        (
            normalized_anomaly,
            normalized_max,
            normalized_mean,
            normalized_regions,
        )
    }
}