sipp-rs 0.1.0

Unified Rust library for extensible Sipp inference
use std::time::Duration;

use crate::engine::protocol::EngineEvent;
use crate::error::{Error, Result};
use crate::runtime::request::{GenerateResponse, GenerateResponseStatus, ResponseOutput};
use crate::runtime::RequestStepResult;

use super::super::events::emit_event;
use super::super::runtime_command;
use super::{ActiveRequestOutput, EngineThreadState};

/// Message used when a cancelled response carries an empty `error_message`.
const REQUEST_CANCELLED_FALLBACK: &str = "request cancelled";
/// Message used when a failed response carries an empty `error_message`.
const REQUEST_FAILED_FALLBACK: &str = "request failed";
/// Error reported when a response taken as completed still has the `Pending` status.
const REQUEST_PENDING_RESPONSE: &str = "request returned pending response";
/// Error sent to every still-active request when active requests are closed.
const ENGINE_CLOSED: &str = "engine closed";

/// Cancels `request_id` and pumps the scheduler until its completed response has been
/// taken (and dropped) or a stop condition is hit.
///
/// The loop ends as soon as one of the following holds:
/// - a completed response for the request exists (it is taken and discarded),
/// - the request queue no longer contains the request,
/// - a scheduler burst reports `Invalid` or `FatalNoProgress`,
/// - a scheduler burst reports `Waiting` and no completed response has appeared.
///
/// The result of `cancel_request` is discarded.
fn cancel_and_consume_request(runtime: &mut crate::runtime::InferenceRuntime, request_id: u32) {
    /// Whether the runtime currently holds a completed response for `id`.
    fn has_completed(rt: &crate::runtime::InferenceRuntime, id: u32) -> bool {
        rt.request_queue.completed_responses.contains_key(&id)
    }

    let _ = runtime.cancel_request(request_id);

    loop {
        // Consume the response first so it is not left behind in the runtime.
        if has_completed(runtime, request_id) {
            let _ = runtime.take_completed_response(request_id);
            return;
        }
        if !runtime.request_queue.contains_request(request_id) {
            return;
        }

        // NOTE(review): the meaning of the scheduler arguments is defined by
        // `run_scheduler_loop`, which is not visible from this file.
        let burst = runtime.run_scheduler_loop(256, 1, 0, Duration::ZERO);

        let unrecoverable = matches!(
            burst.status,
            RequestStepResult::Invalid | RequestStepResult::FatalNoProgress
        );
        // A `Waiting` burst that did produce the response falls through to the next
        // iteration, where the response is consumed.
        if unrecoverable
            || (burst.status == RequestStepResult::Waiting
                && !has_completed(runtime, request_id))
        {
            return;
        }
    }
}

/// Replaces an empty string with a fixed fallback.
trait EmptyStringFallback {
    /// Returns `self` unchanged when it is non-empty, otherwise an owned copy of
    /// `fallback`.
    fn if_empty(self, fallback: &'static str) -> String;
}

impl EmptyStringFallback for String {
    fn if_empty(self, fallback: &'static str) -> String {
        // Non-empty strings are handed back as-is, without reallocating.
        if !self.is_empty() {
            return self;
        }
        fallback.to_owned()
    }
}

impl EngineThreadState {
    /// Delivers the final result of every active request the runtime has finished.
    ///
    /// Each active request id is offered to `take_completed_response`. For every
    /// response obtained this way, token emission is closed, the runtime-side token
    /// emission sink is removed, a completion or failure event is emitted, and the
    /// result is sent to the request's response channel (which also removes the
    /// request from `active_requests`).
    ///
    /// Does nothing when no runtime is present.
    pub(super) fn complete_finished_requests(&mut self) {
        let Some(runtime) = self.runtime.as_mut() else {
            return;
        };
        // Collected up front so the borrow of `runtime` ends before `self` is mutated.
        let completed: Vec<_> = self
            .active_requests
            .keys()
            .copied()
            .filter_map(|request_id| {
                runtime
                    .take_completed_response(request_id)
                    .map(|response| (request_id, response))
            })
            .collect();

        for (request_id, response) in completed {
            self.close_token_emission(request_id);
            if let Some(runtime) = self.runtime.as_mut() {
                runtime
                    .request_queue
                    .token_emission_sinks
                    .remove(&request_id);
            }

            let result = match response.status {
                GenerateResponseStatus::Completed => {
                    // A completed response must carry the kind of output that was requested.
                    match self.validate_completed_response(request_id, &response) {
                        Ok(()) => {
                            emit_event(
                                &self.event_subscribers,
                                EngineEvent::RequestCompleted {
                                    request_id: request_id.to_string(),
                                },
                            );
                            Ok(response)
                        }
                        Err(error) => {
                            self.emit_request_failed(request_id, error.to_string());
                            Err(error)
                        }
                    }
                }
                GenerateResponseStatus::Cancelled => self.failed_completion_result(
                    request_id,
                    response.error_message,
                    REQUEST_CANCELLED_FALLBACK,
                ),
                GenerateResponseStatus::Failed => self.failed_completion_result(
                    request_id,
                    response.error_message,
                    REQUEST_FAILED_FALLBACK,
                ),
                GenerateResponseStatus::Pending => {
                    // The caller receives an error, so report it to subscribers like
                    // every other failure; otherwise they never see this request end.
                    self.emit_request_failed(request_id, REQUEST_PENDING_RESPONSE.to_string());
                    Err(runtime_command(REQUEST_PENDING_RESPONSE))
                }
            };

            self.send_active_response(request_id, result);
        }
    }

    /// Cancels every active request and fails it with `error_msg`.
    ///
    /// For each request: it is cancelled and cleaned up, the error is sent to its
    /// response channel, and a `RequestFailed` event carrying `error_msg` is emitted.
    pub(super) fn fail_all_active_requests(&mut self, error_msg: String) {
        for request_id in self.active_request_ids() {
            self.cancel_and_cleanup_request(request_id);
            self.send_active_response(request_id, Err(runtime_command(error_msg.clone())));
            self.emit_request_failed(request_id, error_msg.clone());
        }
    }

    /// Cancels every active request and answers it with the "engine closed" error.
    ///
    /// Unlike `fail_all_active_requests`, no `RequestFailed` event is emitted.
    pub(super) fn close_active_requests(&mut self) {
        for request_id in self.active_request_ids() {
            self.cancel_and_cleanup_request(request_id);
            self.send_active_response(request_id, Err(runtime_command(ENGINE_CLOSED)));
        }
    }

    /// Snapshot of the ids currently in `active_requests`, so callers can mutate the
    /// map while iterating over the ids.
    fn active_request_ids(&self) -> Vec<u32> {
        self.active_requests.keys().copied().collect()
    }

    /// Cancels `request_id` in the runtime (when one is present), removes its
    /// runtime-side token emission sink, and closes its token emission.
    ///
    /// The request stays in `active_requests`; sending the final response is the
    /// caller's job.
    pub(super) fn cancel_and_cleanup_request(&mut self, request_id: u32) {
        if let Some(runtime) = self.runtime.as_mut() {
            cancel_and_consume_request(runtime, request_id);
            runtime
                .request_queue
                .token_emission_sinks
                .remove(&request_id);
        }
        self.close_token_emission(request_id);
    }

    /// Takes the token handle of `request_id`, if it still has one, and closes its
    /// producer. Calling this again for the same request is a no-op.
    fn close_token_emission(&mut self, request_id: u32) {
        if let Some(active) = self.active_requests.get_mut(&request_id) {
            if let Some(token) = active.token.take() {
                token.producer.close();
            }
        }
    }

    /// Removes `request_id` from `active_requests` and sends `response` on its
    /// response channel. A failed send is ignored; unknown ids are a no-op.
    fn send_active_response(&mut self, request_id: u32, response: Result<GenerateResponse>) {
        if let Some(active) = self.active_requests.remove(&request_id) {
            let _ = active.response_tx.send(response);
        }
    }

    /// Checks that a completed response carries the output kind the request asked for.
    ///
    /// # Errors
    ///
    /// Returns an error when `request_id` is not an active request, or when a text
    /// request completed with embedding output (or vice versa).
    fn validate_completed_response(
        &self,
        request_id: u32,
        response: &GenerateResponse,
    ) -> Result<()> {
        let expected = self
            .active_requests
            .get(&request_id)
            .map(|request| request.output)
            .ok_or_else(|| runtime_command("completed request is not active"))?;

        match (expected, &response.output) {
            (ActiveRequestOutput::Text, ResponseOutput::Text(_))
            | (ActiveRequestOutput::Embedding, ResponseOutput::Embedding { .. }) => Ok(()),
            (ActiveRequestOutput::Text, ResponseOutput::Embedding { .. }) => {
                Err(Error::RuntimeCommand(
                    "generation request completed with embedding output".to_string(),
                ))
            }
            (ActiveRequestOutput::Embedding, ResponseOutput::Text(_)) => Err(
                Error::RuntimeCommand("embedding request completed with text output".to_string()),
            ),
        }
    }

    /// Builds the error result for a cancelled or failed response.
    ///
    /// Uses `error_message`, or `fallback` when it is empty, emits a `RequestFailed`
    /// event with that message, and always returns `Err`.
    fn failed_completion_result(
        &self,
        request_id: u32,
        error_message: String,
        fallback: &'static str,
    ) -> Result<GenerateResponse> {
        let message = error_message.if_empty(fallback);
        self.emit_request_failed(request_id, message.clone());
        Err(runtime_command(message))
    }

    /// Emits `EngineEvent::RequestFailed` for `request_id` to the event subscribers.
    pub(super) fn emit_request_failed(&self, request_id: u32, error: String) {
        emit_event(
            &self.event_subscribers,
            EngineEvent::RequestFailed {
                request_id: request_id.to_string(),
                error,
            },
        );
    }
}

// Unit tests for this module are kept in a separate file; `#[path]` points the
// test-only module declaration at it.
#[cfg(test)]
#[path = "../../../tests/engine/driver/thread_loop/completion_tests.rs"]
mod completion_tests;