use std::error::Error;
use anyhow::Result;
use axum::response::Sse;
use mistralrs_core::{DrySamplingParams, MistralRs, StopTokens as InternalStopTokens};
use crate::{openai::StopTokens, types::SharedMistralRsState, util::sanitize_error_message};
/// Outcome of handling a completion request, generic over the JSON body type `R`
/// and the server-sent-event stream type `S`.
///
/// Construction of the other variants and conversion into an HTTP response happen
/// outside this view. The per-variant notes below describe only the payload each
/// variant carries.
#[derive(Debug)]
pub enum BaseCompletionResponder<R, S> {
/// Streaming reply wrapped in axum's `Sse` response type.
Sse(Sse<S>),
/// Complete, non-streaming reply body.
Json(R),
/// An error message paired with a response value.
/// NOTE(review): presumably the partial output produced before the model failed — confirm with the producer.
ModelError(String, R),
/// Server-side failure. `handle_completion_error` in this file builds this variant
/// from an already-sanitized message.
InternalError(Box<dyn Error>),
/// NOTE(review): presumably a rejected/invalid request — the producer is not visible here; confirm.
ValidationError(Box<dyn Error>),
}
pub(crate) fn handle_completion_error<R, S>(
state: SharedMistralRsState,
e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> BaseCompletionResponder<R, S> {
let full_error = anyhow::Error::msg(e.to_string());
MistralRs::maybe_log_error(state, &*full_error);
let sanitized_msg = sanitize_error_message(&*e);
let sanitized_error = anyhow::Error::msg(sanitized_msg);
BaseCompletionResponder::InternalError(sanitized_error.into())
}
/// Maps the request-level `StopTokens` into the core crate's `InternalStopTokens`.
///
/// Both request forms become `InternalStopTokens::Seqs`:
/// - a single sequence is wrapped in a one-element `Vec`;
/// - a list of sequences is moved through as-is.
///
/// `None` (no stop sequences supplied) stays `None`.
pub(crate) fn convert_stop_tokens(stop_seqs: Option<StopTokens>) -> Option<InternalStopTokens> {
    stop_seqs.map(|tokens| {
        let seqs = match tokens {
            StopTokens::Single(only) => vec![only],
            StopTokens::Multi(list) => list,
        };
        InternalStopTokens::Seqs(seqs)
    })
}
/// Builds optional DRY sampling parameters from the request fields.
///
/// Returns `Ok(None)` when `dry_multiplier` is absent. The other three arguments
/// are then ignored. When `dry_multiplier` is present, all four values are passed to
/// `DrySamplingParams::new_with_defaults`.
/// NOTE(review): presumably that constructor fills the `None` fields with defaults, going
/// by its name — confirm in mistralrs_core.
///
/// # Errors
/// Propagates any error returned by `DrySamplingParams::new_with_defaults`,
/// converted into an `anyhow::Error` via `?`.
pub(crate) fn get_dry_sampling_params(
    dry_multiplier: Option<f32>,
    dry_sequence_breakers: Option<Vec<String>>,
    dry_base: Option<f32>,
    dry_allowed_length: Option<usize>,
) -> Result<Option<DrySamplingParams>> {
    // `transpose` turns Option<Result<_>> into Result<Option<_>>. The trailing `?` keeps
    // the same From-conversion of the constructor's error as a direct `?` would.
    let built = dry_multiplier
        .map(|multiplier| {
            DrySamplingParams::new_with_defaults(
                multiplier,
                dry_sequence_breakers,
                dry_base,
                dry_allowed_length,
            )
        })
        .transpose()?;
    Ok(built)
}