// edge_impulse_runner/inference/model.rs
1use std::collections::{HashMap, VecDeque};
2use std::fmt;
3use std::io::{BufRead, BufReader, Write};
4use std::os::unix::net::UnixStream;
5use std::path::Path;
6use std::process::Child;
7use std::sync::atomic::{AtomicU32, Ordering};
8use std::time::{Duration, Instant};
9
10use crate::error::EimError;
11use crate::inference::messages::{
12    ClassifyMessage, ErrorResponse, HelloMessage, InferenceResponse, InferenceResult, ModelInfo,
13    SetThresholdMessage, SetThresholdResponse, ThresholdConfig,
14};
15use crate::types::{ModelParameters, ModelThreshold, SensorType, VisualAnomalyResult};
16
/// Debug callback type for receiving debug messages.
///
/// Boxed so it can be stored on [`EimModel`]; the `Send + Sync` bounds allow
/// the callback (and therefore the model holding it) to be moved/shared
/// across threads.
pub type DebugCallback = Box<dyn Fn(&str) + Send + Sync>;
19
20/// Edge Impulse Model Runner for Rust
21///
22/// This module provides functionality for running Edge Impulse machine learning models on Linux systems.
23/// It handles model lifecycle management, communication, and inference operations.
24///
25/// # Key Components
26///
27/// - `EimModel`: Main struct for managing Edge Impulse models
28/// - `SensorType`: Enum representing supported sensor input types
29/// - `ContinuousState`: Internal state management for continuous inference mode
30/// - `MovingAverageFilter`: Smoothing filter for continuous inference results
31///
32/// # Features
33///
34/// - Model process management and Unix socket communication
35/// - Support for both single-shot and continuous inference modes
36/// - Debug logging and callback system
37/// - Moving average filtering for continuous mode results
38/// - Automatic retry mechanisms for socket connections
39/// - Visual anomaly detection (FOMO AD) support with normalized scores
40///
41/// # Example Usage
42///
43/// ```no_run
44/// use edge_impulse_runner::{EimModel, InferenceResult};
45///
46/// // Create a new model instance
47/// let mut model = EimModel::new("path/to/model.eim").unwrap();
48///
49/// // Run inference with some features
50/// let features = vec![0.1, 0.2, 0.3];
51/// let result = model.infer(features, None).unwrap();
52///
53/// // For visual anomaly detection models, normalize the results
54/// if let InferenceResult::VisualAnomaly { anomaly, visual_anomaly_max, visual_anomaly_mean, visual_anomaly_grid } = result.result {
55///     let (normalized_anomaly, normalized_max, normalized_mean, normalized_regions) =
56///         model.normalize_visual_anomaly(
57///             anomaly,
58///             visual_anomaly_max,
59///             visual_anomaly_mean,
60///             &visual_anomaly_grid.iter()
61///                 .map(|bbox| (bbox.value, bbox.x as u32, bbox.y as u32, bbox.width as u32, bbox.height as u32))
62///                 .collect::<Vec<_>>()
63///         );
64///     println!("Anomaly score: {:.2}%", normalized_anomaly * 100.0);
65/// }
66/// ```
67///
68/// # Communication Protocol
69///
70/// The model communicates with the Edge Impulse process using JSON messages over Unix sockets:
71/// 1. Hello message for initialization
72/// 2. Model info response
73/// 3. Classification requests
74/// 4. Inference responses
75///
76/// # Error Handling
77///
78/// The module uses a custom `EimError` type for error handling, covering:
79/// - Invalid file paths
80/// - Socket communication errors
81/// - Model execution errors
82/// - JSON serialization/deserialization errors
83///
84/// # Visual Anomaly Detection
85///
86/// For visual anomaly detection models (FOMO AD):
87/// - Scores are normalized relative to the model's minimum anomaly threshold
88/// - Results include overall anomaly score, maximum score, mean score, and anomalous regions
89/// - Region coordinates are provided in the original image dimensions
90/// - All scores are clamped to [0,1] range and displayed as percentages
91/// - Debug mode provides detailed information about thresholds and regions
92///
93/// # Threshold Configuration
94///
95/// Models can be configured with different thresholds:
96/// - Anomaly detection: `min_anomaly_score` threshold for visual anomaly detection
97/// - Object detection: `min_score` threshold for object confidence
98/// - Object tracking: `keep_grace`, `max_observations`, and `threshold` parameters
99///
100/// Thresholds can be updated at runtime using `set_learn_block_threshold`.
pub struct EimModel {
    /// Path to the Edge Impulse model file (.eim); stored as an absolute path
    path: std::path::PathBuf,
    /// Path to the Unix socket used for IPC with the model process
    socket_path: std::path::PathBuf,
    /// Active Unix socket connection to the model process
    socket: UnixStream,
    /// Enable debug logging of socket communications
    debug: bool,
    /// Optional debug callback for receiving debug messages
    debug_callback: Option<DebugCallback>,
    /// Handle to the model process (kept alive while model exists)
    _process: Child,
    /// Cached model information received during the hello handshake;
    /// `None` until `send_hello` has completed successfully
    model_info: Option<ModelInfo>,
    /// Atomic counter for generating unique message IDs
    message_id: AtomicU32,
    /// Optional child process handle for restart functionality
    #[allow(dead_code)]
    child: Option<Child>,
    /// Sliding-window state used only when the model runs in continuous mode
    continuous_state: Option<ContinuousState>,
    /// Default-initialized parameters; the authoritative copy lives in `model_info`
    model_parameters: ModelParameters,
}
124
/// Internal state for continuous inference: a sliding window of features
/// plus one moving-average filter per classification label.
#[derive(Debug)]
struct ContinuousState {
    // Flattened sliding window holding the most recent feature values.
    feature_matrix: Vec<f32>,
    // True once `feature_matrix` has accumulated `slice_size` values.
    feature_buffer_full: bool,
    // Per-label smoothing filters applied to classification scores.
    maf_buffers: HashMap<String, MovingAverageFilter>,
    // Number of feature values the model expects per inference.
    slice_size: usize,
}
132
133impl ContinuousState {
134    fn new(labels: Vec<String>, slice_size: usize) -> Self {
135        Self {
136            feature_matrix: Vec::new(),
137            feature_buffer_full: false,
138            maf_buffers: labels
139                .into_iter()
140                .map(|label| (label, MovingAverageFilter::new(4)))
141                .collect(),
142            slice_size,
143        }
144    }
145
146    fn update_features(&mut self, features: &[f32]) {
147        // Add new features to the matrix
148        self.feature_matrix.extend_from_slice(features);
149
150        // Check if buffer is full
151        if self.feature_matrix.len() >= self.slice_size {
152            self.feature_buffer_full = true;
153            // Keep only the most recent features if we've exceeded the buffer size
154            if self.feature_matrix.len() > self.slice_size {
155                self.feature_matrix
156                    .drain(0..self.feature_matrix.len() - self.slice_size);
157            }
158        }
159    }
160
161    fn apply_maf(&mut self, classification: &mut HashMap<String, f32>) {
162        for (label, value) in classification.iter_mut() {
163            if let Some(maf) = self.maf_buffers.get_mut(label) {
164                *value = maf.update(*value);
165            }
166        }
167    }
168}
169
/// Fixed-window moving average over a stream of `f32` samples.
///
/// Keeps a running `sum` alongside the sample buffer so each update is O(1).
#[derive(Debug)]
struct MovingAverageFilter {
    buffer: VecDeque<f32>,
    window_size: usize,
    sum: f32,
}

impl MovingAverageFilter {
    /// Create a filter averaging over the last `window_size` samples.
    fn new(window_size: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(window_size),
            window_size,
            sum: 0.0,
        }
    }

    /// Push one sample and return the mean of the current window.
    ///
    /// Until the window fills, the mean is taken over however many samples
    /// have been seen so far.
    fn update(&mut self, value: f32) -> f32 {
        // Evict the oldest sample once the window is at capacity.
        if self.buffer.len() >= self.window_size {
            if let Some(oldest) = self.buffer.pop_front() {
                self.sum -= oldest;
            }
        }
        self.buffer.push_back(value);
        self.sum += value;
        self.sum / self.buffer.len() as f32
    }
}
195
/// Manual `Debug` impl: `debug_callback` stores a boxed closure, which does
/// not implement `Debug`, so that field is omitted from the output.
impl fmt::Debug for EimModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("EimModel")
            .field("path", &self.path)
            .field("socket_path", &self.socket_path)
            .field("socket", &self.socket)
            .field("debug", &self.debug)
            .field("_process", &self._process)
            .field("model_info", &self.model_info)
            .field("message_id", &self.message_id)
            .field("child", &self.child)
            // Skip debug_callback field as it doesn't implement Debug
            .field("continuous_state", &self.continuous_state)
            .field("model_parameters", &self.model_parameters)
            .finish()
    }
}
213
214impl EimModel {
    /// Creates a new EimModel instance from a path to the .eim file.
    ///
    /// This is the standard way to create a new model instance. The function will:
    /// 1. Validate the file extension
    /// 2. Spawn the model process
    /// 3. Establish socket communication (via a socket in the system temp
    ///    directory; use [`EimModel::new_with_socket`] to control the location)
    /// 4. Initialize the model
    ///
    /// Debug output is disabled; use [`EimModel::new_with_debug`] to enable it.
    ///
    /// # Arguments
    ///
    /// * `path` - Path to the .eim file. Must be a valid Edge Impulse model file.
    ///
    /// # Returns
    ///
    /// Returns `Result<EimModel, EimError>` where:
    /// - `Ok(EimModel)` - Successfully created and initialized model
    /// - `Err(EimError)` - Failed to create model (invalid path, process spawn failure, etc.)
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use edge_impulse_runner::EimModel;
    ///
    /// let model = EimModel::new("path/to/model.eim").unwrap();
    /// ```
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, EimError> {
        Self::new_with_debug(path, false)
    }
243
    /// Creates a new EimModel instance with a specific Unix socket path.
    ///
    /// Similar to `new()`, but allows specifying the socket path for communication.
    /// This is useful when you need control over the socket location or when running
    /// multiple models simultaneously.
    ///
    /// Debug output is disabled; use [`EimModel::new_with_socket_and_debug`]
    /// to enable it.
    ///
    /// # Arguments
    ///
    /// * `path` - Path to the .eim file
    /// * `socket_path` - Custom path where the Unix socket should be created
    pub fn new_with_socket<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        socket_path: S,
    ) -> Result<Self, EimError> {
        Self::new_with_socket_and_debug(path, socket_path, false)
    }
260
261    /// Create a new EimModel instance with debug output enabled
262    pub fn new_with_debug<P: AsRef<Path>>(path: P, debug: bool) -> Result<Self, EimError> {
263        let socket_path = std::env::temp_dir().join("eim_socket");
264        Self::new_with_socket_and_debug(path, &socket_path, debug)
265    }
266
267    /// Ensure the model file has execution permissions for the current user
268    fn ensure_executable<P: AsRef<Path>>(path: P) -> Result<(), EimError> {
269        use std::os::unix::fs::PermissionsExt;
270
271        let path = path.as_ref();
272        let metadata = std::fs::metadata(path)
273            .map_err(|e| EimError::ExecutionError(format!("Failed to get file metadata: {}", e)))?;
274
275        let perms = metadata.permissions();
276        let current_mode = perms.mode();
277        if current_mode & 0o100 == 0 {
278            // File is not executable for user, try to make it executable
279            let mut new_perms = perms;
280            new_perms.set_mode(current_mode | 0o100); // Add executable bit for user only
281            std::fs::set_permissions(path, new_perms).map_err(|e| {
282                EimError::ExecutionError(format!("Failed to set executable permissions: {}", e))
283            })?;
284        }
285        Ok(())
286    }
287
    /// Create a new EimModel instance with debug output enabled and a specific socket path.
    ///
    /// Validates the `.eim` extension, makes the binary executable, spawns
    /// the model process (passing the socket path as its only argument),
    /// connects to the socket with retries, and completes the hello
    /// handshake before returning.
    pub fn new_with_socket_and_debug<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        socket_path: S,
        debug: bool,
    ) -> Result<Self, EimError> {
        let path = path.as_ref();
        let socket_path = socket_path.as_ref();

        // Validate file extension: only `.eim` files are accepted.
        if path.extension().and_then(|s| s.to_str()) != Some("eim") {
            return Err(EimError::InvalidPath);
        }

        // Convert relative path to absolute path so spawning is independent
        // of the current working directory.
        let absolute_path = if path.is_absolute() {
            path.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|_e| EimError::InvalidPath)?
                .join(path)
        };

        // Ensure the model file is executable before attempting to spawn it.
        Self::ensure_executable(&absolute_path)?;

        // Start the process; the model binary creates the Unix socket at
        // the path given as its first argument.
        let process = std::process::Command::new(&absolute_path)
            .arg(socket_path)
            .spawn()
            .map_err(|e| EimError::ExecutionError(e.to_string()))?;

        // The socket may take a moment to appear; retry for up to 5 seconds.
        let socket = Self::connect_with_retry(socket_path, Duration::from_secs(5))?;

        let mut model = Self {
            path: absolute_path, // Store the absolute path
            socket_path: socket_path.to_path_buf(),
            socket,
            debug,
            _process: process,
            model_info: None,
            message_id: AtomicU32::new(1),
            child: None,
            debug_callback: None,
            continuous_state: None,
            model_parameters: ModelParameters::default(),
        };

        // Initialize the model by sending hello message; this also caches
        // the ModelInfo used by later calls.
        model.send_hello()?;

        Ok(model)
    }
341
342    /// Attempts to connect to the Unix socket with a retry mechanism
343    ///
344    /// This function will repeatedly try to connect to the socket until either:
345    /// - A successful connection is established
346    /// - An unexpected error occurs
347    /// - The timeout duration is exceeded
348    ///
349    /// # Arguments
350    ///
351    /// * `socket_path` - Path to the Unix socket
352    /// * `timeout` - Maximum time to wait for connection
353    fn connect_with_retry(socket_path: &Path, timeout: Duration) -> Result<UnixStream, EimError> {
354        let start = Instant::now();
355        let retry_interval = Duration::from_millis(50);
356
357        while start.elapsed() < timeout {
358            match UnixStream::connect(socket_path) {
359                Ok(stream) => return Ok(stream),
360                Err(e) => {
361                    // NotFound and ConnectionRefused are expected errors while the socket
362                    // is being created, so we retry in these cases
363                    if e.kind() != std::io::ErrorKind::NotFound
364                        && e.kind() != std::io::ErrorKind::ConnectionRefused
365                    {
366                        return Err(EimError::SocketError(format!(
367                            "Failed to connect to socket: {}",
368                            e
369                        )));
370                    }
371                }
372            }
373            std::thread::sleep(retry_interval);
374        }
375
376        Err(EimError::SocketError(format!(
377            "Timeout waiting for socket {} to become available",
378            socket_path.display()
379        )))
380    }
381
    /// Get the next message ID.
    ///
    /// Atomically post-increments the counter. `Relaxed` ordering suffices
    /// here: only uniqueness of the returned IDs matters, not ordering with
    /// respect to other memory operations.
    fn next_message_id(&self) -> u32 {
        self.message_id.fetch_add(1, Ordering::Relaxed)
    }
386
    /// Set a debug callback function to receive debug messages
    ///
    /// When debug mode is enabled, this callback will be invoked with debug messages
    /// from the model runner. This is useful for logging or displaying debug information
    /// in your application. Replaces any previously registered callback.
    /// The callback has no effect unless the model was created with debug enabled.
    ///
    /// # Arguments
    ///
    /// * `callback` - Function that takes a string slice and handles the debug message
    pub fn set_debug_callback<F>(&mut self, callback: F)
    where
        F: Fn(&str) + Send + Sync + 'static,
    {
        self.debug_callback = Some(Box::new(callback));
    }
402
403    /// Send debug messages when debug mode is enabled
404    fn debug_message(&self, message: &str) {
405        if self.debug {
406            println!("{}", message);
407            if let Some(callback) = &self.debug_callback {
408                callback(message);
409            }
410        }
411    }
412
    /// Perform the hello handshake with the model process.
    ///
    /// Sends a `HelloMessage` as newline-delimited JSON and blocks reading a
    /// single response line, expected to be a `ModelInfo` payload. On
    /// success the info is cached in `self.model_info`; otherwise an
    /// `EimError` describing the socket or model failure is returned.
    /// Called lazily wherever `model_info` is still `None`.
    fn send_hello(&mut self) -> Result<(), EimError> {
        let hello_msg = HelloMessage {
            hello: 1,
            id: self.next_message_id(),
        };

        let msg = serde_json::to_string(&hello_msg)?;
        self.debug_message(&format!("Sending hello message: {}", msg));

        // The protocol is newline-delimited JSON over the Unix socket.
        writeln!(self.socket, "{}", msg).map_err(|e| {
            self.debug_message(&format!("Failed to send hello: {}", e));
            EimError::SocketError(format!("Failed to send hello message: {}", e))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush hello: {}", e));
            EimError::SocketError(format!("Failed to flush socket: {}", e))
        })?;

        self.debug_message("Waiting for hello response...");

        let mut reader = BufReader::new(&self.socket);
        let mut line = String::new();

        match reader.read_line(&mut line) {
            Ok(n) => {
                self.debug_message(&format!("Read {} bytes: {}", n, line));

                match serde_json::from_str::<ModelInfo>(&line) {
                    Ok(info) => {
                        self.debug_message("Successfully parsed model info");
                        // A parseable ModelInfo with success == false still
                        // means initialization failed on the model side.
                        if !info.success {
                            self.debug_message("Model initialization failed");
                            return Err(EimError::ExecutionError(
                                "Model initialization failed".to_string(),
                            ));
                        }
                        self.debug_message("Got model info response, storing it");
                        self.model_info = Some(info);
                        return Ok(());
                    }
                    Err(e) => {
                        // Not a ModelInfo payload; check whether the process
                        // sent an explicit error response before giving up.
                        self.debug_message(&format!("Failed to parse model info: {}", e));
                        if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line) {
                            if !error.success {
                                self.debug_message(&format!("Got error response: {:?}", error));
                                return Err(EimError::ExecutionError(
                                    error.error.unwrap_or_else(|| "Unknown error".to_string()),
                                ));
                            }
                        }
                        // Unparseable response: fall through to the generic
                        // "no valid response" error below.
                    }
                }
            }
            Err(e) => {
                self.debug_message(&format!("Failed to read hello response: {}", e));
                return Err(EimError::SocketError(format!(
                    "Failed to read response: {}",
                    e
                )));
            }
        }

        self.debug_message("No valid hello response received");
        Err(EimError::SocketError(
            "No valid response received".to_string(),
        ))
    }
481
    /// Get the path to the EIM file (stored as an absolute path at construction).
    pub fn path(&self) -> &Path {
        &self.path
    }
486
    /// Get the Unix socket path used for communication with the model process.
    pub fn socket_path(&self) -> &Path {
        &self.socket_path
    }
491
492    /// Get the sensor type for this model
493    pub fn sensor_type(&self) -> Result<SensorType, EimError> {
494        self.model_info
495            .as_ref()
496            .map(|info| SensorType::from(info.model_parameters.sensor))
497            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
498    }
499
500    /// Get the model parameters
501    pub fn parameters(&self) -> Result<&ModelParameters, EimError> {
502        self.model_info
503            .as_ref()
504            .map(|info| &info.model_parameters)
505            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
506    }
507
508    /// Run inference on the input features
509    ///
510    /// This method automatically handles both continuous and non-continuous modes:
511    ///
512    /// ## Non-Continuous Mode
513    /// - Each call is independent
514    /// - All features must be provided in a single call
515    /// - Results are returned immediately
516    ///
517    /// ## Continuous Mode (automatically enabled for supported models)
518    /// - Features are accumulated across calls
519    /// - Internal buffer maintains sliding window of features
520    /// - Moving average filter smooths results
521    /// - Initial calls may return empty results while buffer fills
522    ///
523    /// # Arguments
524    ///
525    /// * `features` - Vector of input features
526    /// * `debug` - Optional debug flag to enable detailed output for this inference
527    ///
528    /// # Returns
529    ///
530    /// Returns `Result<InferenceResponse, EimError>` containing inference results
531    pub fn infer(
532        &mut self,
533        features: Vec<f32>,
534        debug: Option<bool>,
535    ) -> Result<InferenceResponse, EimError> {
536        // Initialize model info if needed
537        if self.model_info.is_none() {
538            self.send_hello()?;
539        }
540
541        let uses_continuous_mode = self.requires_continuous_mode();
542
543        if uses_continuous_mode {
544            self.infer_continuous_internal(features, debug)
545        } else {
546            self.infer_single(features, debug)
547        }
548    }
549
550    fn infer_continuous_internal(
551        &mut self,
552        features: Vec<f32>,
553        debug: Option<bool>,
554    ) -> Result<InferenceResponse, EimError> {
555        // Initialize continuous state if needed
556        if self.continuous_state.is_none() {
557            let labels = self
558                .model_info
559                .as_ref()
560                .map(|info| info.model_parameters.labels.clone())
561                .unwrap_or_default();
562            let slice_size = self.input_size()?;
563
564            self.continuous_state = Some(ContinuousState::new(labels, slice_size));
565        }
566
567        // Take ownership of state temporarily to avoid multiple mutable borrows
568        let mut state = self.continuous_state.take().unwrap();
569        state.update_features(&features);
570
571        let response = if !state.feature_buffer_full {
572            // Return empty response while building up the buffer
573            Ok(InferenceResponse {
574                success: true,
575                id: self.next_message_id(),
576                result: InferenceResult::Classification {
577                    classification: HashMap::new(),
578                },
579            })
580        } else {
581            // Run inference on the full buffer
582            let mut response = self.infer_single(state.feature_matrix.clone(), debug)?;
583
584            // Apply moving average filter to the results
585            if let InferenceResult::Classification {
586                ref mut classification,
587            } = response.result
588            {
589                state.apply_maf(classification);
590            }
591
592            Ok(response)
593        };
594
595        // Restore the state
596        self.continuous_state = Some(state);
597
598        response
599    }
600
    /// Run a single (non-continuous) inference over `features`.
    ///
    /// Sends a `ClassifyMessage` as newline-delimited JSON, then polls the
    /// socket in non-blocking mode for up to 5 seconds, returning the first
    /// line that parses as a successful `InferenceResponse`. The socket is
    /// restored to blocking mode on every exit path.
    fn infer_single(
        &mut self,
        features: Vec<f32>,
        debug: Option<bool>,
    ) -> Result<InferenceResponse, EimError> {
        // First ensure we've sent the hello message and received model info
        if self.model_info.is_none() {
            self.debug_message("No model info, sending hello message...");
            self.send_hello()?;
            self.debug_message("Hello handshake completed");
        }

        let msg = ClassifyMessage {
            classify: features.clone(),
            id: self.next_message_id(),
            debug,
        };

        let msg_str = serde_json::to_string(&msg)?;
        self.debug_message(&format!(
            "Sending inference message with {} features",
            features.len()
        ));

        writeln!(self.socket, "{}", msg_str).map_err(|e| {
            self.debug_message(&format!("Failed to send inference message: {}", e));
            EimError::SocketError(format!("Failed to send inference message: {}", e))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush inference message: {}", e));
            EimError::SocketError(format!("Failed to flush socket: {}", e))
        })?;

        self.debug_message("Inference message sent, waiting for response...");

        // Set socket to non-blocking mode so the read loop below can
        // enforce its own timeout instead of blocking indefinitely.
        self.socket.set_nonblocking(true).map_err(|e| {
            self.debug_message(&format!("Failed to set non-blocking mode: {}", e));
            EimError::SocketError(format!("Failed to set non-blocking mode: {}", e))
        })?;

        let mut reader = BufReader::new(&self.socket);
        let mut buffer = String::new();
        let start = Instant::now();
        let timeout = Duration::from_secs(5);

        while start.elapsed() < timeout {
            match reader.read_line(&mut buffer) {
                Ok(0) => {
                    // The model process closed its end of the socket.
                    self.debug_message("EOF reached");
                    break;
                }
                Ok(n) => {
                    // Skip printing feature values in the response
                    if !buffer.contains("features:") && !buffer.contains("Features (") {
                        self.debug_message(&format!("Read {} bytes: {}", n, buffer));
                    }

                    // Lines that fail to parse, or parse with success ==
                    // false, are discarded and the loop keeps reading.
                    if let Ok(response) = serde_json::from_str::<InferenceResponse>(&buffer) {
                        if response.success {
                            self.debug_message("Got successful inference response");
                            // Reset to blocking mode before returning
                            let _ = self.socket.set_nonblocking(false);
                            return Ok(response);
                        }
                    }
                    buffer.clear();
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    // No data available yet, sleep briefly and retry.
                    // NOTE(review): `buffer` is intentionally not cleared
                    // here — any partial line already read is presumably
                    // completed by the next read_line call; confirm against
                    // BufRead::read_line's partial-read behavior.
                    std::thread::sleep(Duration::from_millis(10));
                    continue;
                }
                Err(e) => {
                    self.debug_message(&format!("Read error: {}", e));
                    // Always try to reset blocking mode, even on error
                    let _ = self.socket.set_nonblocking(false);
                    return Err(EimError::SocketError(format!("Read error: {}", e)));
                }
            }
        }

        // Reset to blocking mode before returning
        let _ = self.socket.set_nonblocking(false);
        self.debug_message("Timeout reached");

        Err(EimError::ExecutionError(format!(
            "No valid response received within {} seconds",
            timeout.as_secs()
        )))
    }
693
694    /// Check if model requires continuous mode
695    fn requires_continuous_mode(&self) -> bool {
696        self.model_info
697            .as_ref()
698            .map(|info| info.model_parameters.use_continuous_mode)
699            .unwrap_or(false)
700    }
701
702    /// Get the required number of input features for this model
703    ///
704    /// Returns the number of features expected by the model for each classification.
705    /// This is useful for:
706    /// - Validating input size before classification
707    /// - Preparing the correct amount of data
708    /// - Padding or truncating inputs to match model requirements
709    ///
710    /// # Returns
711    ///
712    /// The number of input features required by the model
713    pub fn input_size(&self) -> Result<usize, EimError> {
714        self.model_info
715            .as_ref()
716            .map(|info| info.model_parameters.input_features_count as usize)
717            .ok_or_else(|| EimError::ExecutionError("Model info not available".to_string()))
718    }
719
    /// Set a threshold for a specific learning block
    ///
    /// This method allows updating thresholds for different types of blocks:
    /// - Anomaly detection (GMM)
    /// - Object detection
    /// - Object tracking
    ///
    /// # Arguments
    ///
    /// * `threshold` - The threshold configuration to set
    ///
    /// # Returns
    ///
    /// Returns `Result<(), EimError>` indicating success or failure
    ///
    /// NOTE(review): this function is declared `async` but contains no
    /// `.await` and performs blocking socket I/O, so it will block the
    /// executor thread it runs on. Kept `async` for API compatibility;
    /// consider making it synchronous (or using async I/O) in a future
    /// breaking release.
    pub async fn set_learn_block_threshold(
        &mut self,
        threshold: ThresholdConfig,
    ) -> Result<(), EimError> {
        // First check if model info is available and supports thresholds
        if self.model_info.is_none() {
            self.debug_message("No model info available, sending hello message...");
            self.send_hello()?;
        }

        // Log the current model state
        if let Some(info) = &self.model_info {
            self.debug_message(&format!(
                "Current model type: {}",
                info.model_parameters.model_type
            ));
            self.debug_message(&format!(
                "Current model parameters: {:?}",
                info.model_parameters
            ));
        }

        let msg = SetThresholdMessage {
            set_threshold: threshold,
            id: self.next_message_id(),
        };

        let msg_str = serde_json::to_string(&msg)?;
        self.debug_message(&format!("Sending threshold message: {}", msg_str));

        writeln!(self.socket, "{}", msg_str).map_err(|e| {
            self.debug_message(&format!("Failed to send threshold message: {}", e));
            EimError::SocketError(format!("Failed to send threshold message: {}", e))
        })?;

        self.socket.flush().map_err(|e| {
            self.debug_message(&format!("Failed to flush threshold message: {}", e));
            EimError::SocketError(format!("Failed to flush socket: {}", e))
        })?;

        // Block on a single newline-delimited JSON response line.
        let mut reader = BufReader::new(&self.socket);
        let mut line = String::new();

        match reader.read_line(&mut line) {
            Ok(_) => {
                self.debug_message(&format!("Received response: {}", line));
                match serde_json::from_str::<SetThresholdResponse>(&line) {
                    Ok(response) => {
                        if response.success {
                            self.debug_message("Successfully set threshold");
                            Ok(())
                        } else {
                            self.debug_message("Server reported failure setting threshold");
                            Err(EimError::ExecutionError(
                                "Server reported failure setting threshold".to_string(),
                            ))
                        }
                    }
                    Err(e) => {
                        self.debug_message(&format!("Failed to parse threshold response: {}", e));
                        // Try to parse as error response
                        if let Ok(error) = serde_json::from_str::<ErrorResponse>(&line) {
                            Err(EimError::ExecutionError(
                                error.error.unwrap_or_else(|| "Unknown error".to_string()),
                            ))
                        } else {
                            Err(EimError::ExecutionError(format!(
                                "Invalid threshold response format: {}",
                                e
                            )))
                        }
                    }
                }
            }
            Err(e) => {
                self.debug_message(&format!("Failed to read threshold response: {}", e));
                Err(EimError::SocketError(format!(
                    "Failed to read response: {}",
                    e
                )))
            }
        }
    }
817
818    /// Get the minimum anomaly score threshold from model parameters
819    fn get_min_anomaly_score(&self) -> f32 {
820        self.model_info
821            .as_ref()
822            .and_then(|info| {
823                info.model_parameters
824                    .thresholds
825                    .iter()
826                    .find_map(|t| match t {
827                        ModelThreshold::AnomalyGMM {
828                            min_anomaly_score, ..
829                        } => Some(*min_anomaly_score),
830                        _ => None,
831                    })
832            })
833            .unwrap_or(6.0)
834    }
835
836    /// Normalize an anomaly score relative to the model's minimum threshold
837    fn normalize_anomaly_score(&self, score: f32) -> f32 {
838        (score / self.get_min_anomaly_score()).min(1.0)
839    }
840
841    /// Normalize a visual anomaly result
842    pub fn normalize_visual_anomaly(
843        &self,
844        anomaly: f32,
845        max: f32,
846        mean: f32,
847        regions: &[(f32, u32, u32, u32, u32)],
848    ) -> VisualAnomalyResult {
849        let normalized_anomaly = self.normalize_anomaly_score(anomaly);
850        let normalized_max = self.normalize_anomaly_score(max);
851        let normalized_mean = self.normalize_anomaly_score(mean);
852        let normalized_regions: Vec<_> = regions
853            .iter()
854            .map(|(value, x, y, w, h)| (self.normalize_anomaly_score(*value), *x, *y, *w, *h))
855            .collect();
856
857        (
858            normalized_anomaly,
859            normalized_max,
860            normalized_mean,
861            normalized_regions,
862        )
863    }
864}