use crate::embeddings::config::OptimizationConfig;
use serde::{Deserialize, Serialize};
/// Embedding models that a `MistralConfig` can select.
///
/// Serialized with snake_case variant names (`mistral_embed`,
/// `codestral_embed`). The API-facing identifier is a different string and is
/// returned by `MistralModel::model_name`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum MistralModel {
/// Reported as `"mistral-embed"` by `model_name`; this is the `Default` variant.
#[default]
MistralEmbed,
/// Reported as `"codestral-embed"` by `model_name`.
CodestralEmbed,
}
impl MistralModel {
    /// Returns the hyphenated model identifier for this variant.
    #[must_use]
    pub fn model_name(self) -> &'static str {
        match self {
            Self::CodestralEmbed => "codestral-embed",
            Self::MistralEmbed => "mistral-embed",
        }
    }

    /// Returns the embedding dimension used when no custom dimension is set.
    #[must_use]
    pub fn default_dimension(self) -> usize {
        match self {
            Self::CodestralEmbed => 1536,
            Self::MistralEmbed => 1024,
        }
    }

    /// Returns the largest embedding dimension accepted for this model.
    #[must_use]
    pub fn max_dimension(self) -> usize {
        match self {
            Self::CodestralEmbed => 3072,
            Self::MistralEmbed => 1024,
        }
    }

    /// Returns `true` when the model accepts a custom output dimension.
    #[must_use]
    pub fn supports_output_dimension(self) -> bool {
        match self {
            Self::CodestralEmbed => true,
            Self::MistralEmbed => false,
        }
    }

    /// Returns `true` when the model accepts an output dtype other than the default.
    #[must_use]
    pub fn supports_output_dtype(self) -> bool {
        match self {
            Self::CodestralEmbed => true,
            Self::MistralEmbed => false,
        }
    }

    /// Returns the output dtype used by default; currently
    /// [`OutputDtype::Float`] for every model.
    #[must_use]
    pub fn default_output_dtype(self) -> OutputDtype {
        OutputDtype::Float
    }
}
/// Output data type option for embeddings.
///
/// Serialized with lowercase variant names (`float`, `int8`, `uint8`,
/// `binary`, `ubinary`), the same strings returned by `OutputDtype::as_str`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum OutputDtype {
/// The `Default` variant.
#[default]
Float,
// NOTE(review): presumably signed 8-bit quantized values — confirm against the API docs.
Int8,
// NOTE(review): presumably unsigned 8-bit quantized values — confirm against the API docs.
Uint8,
/// Treated as bit-packed by `is_bit_packed` (eight dimensions per response element).
Binary,
/// Treated as bit-packed by `is_bit_packed` (eight dimensions per response element).
Ubinary,
}
impl OutputDtype {
    /// Returns the lowercase name of this dtype (for example `"int8"`).
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Float => "float",
            Self::Int8 => "int8",
            Self::Uint8 => "uint8",
            Self::Binary => "binary",
            Self::Ubinary => "ubinary",
        }
    }

    /// Returns `true` for the dtypes whose values are packed eight per
    /// response element (`Binary` and `Ubinary`).
    #[must_use]
    pub fn is_bit_packed(self) -> bool {
        matches!(self, Self::Binary | Self::Ubinary)
    }

    /// Returns the number of elements expected in a response vector for an
    /// embedding of `output_dimension` dimensions.
    ///
    /// Bit-packed dtypes need `ceil(output_dimension / 8)` elements; every
    /// other dtype needs exactly `output_dimension` elements.
    #[must_use]
    pub fn response_size(self, output_dimension: usize) -> usize {
        if !self.is_bit_packed() {
            return output_dimension;
        }
        // Ceiling division expressed as quotient plus a remainder flag. The
        // textbook `(n + 7) / 8` overflows when `n` is within 7 of
        // `usize::MAX` (panic in debug builds, wrap-around in release).
        output_dimension / 8 + usize::from(output_dimension % 8 != 0)
    }
}
/// Settings for the Mistral embeddings endpoint: model, output shape and
/// dtype, base URL, and optimization options.
///
/// Use `MistralConfig::validate` to check that the chosen model supports the
/// requested `output_dimension` / `output_dtype`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MistralConfig {
/// Selected model; falls back to `MistralModel::default()` (`MistralEmbed`)
/// when absent from the deserialized input.
#[serde(default)]
pub model: MistralModel,
/// Optional custom embedding dimension; omitted from serialized output when `None`.
/// When `None`, `effective_dimension` uses the model's default dimension.
#[serde(skip_serializing_if = "Option::is_none")]
pub output_dimension: Option<usize>,
/// Requested output dtype; falls back to `OutputDtype::default()` (`Float`)
/// when absent from the deserialized input.
#[serde(default)]
pub output_dtype: OutputDtype,
/// Base URL that `embeddings_url` appends `/embeddings` to; falls back to
/// `default_mistral_base_url()` when absent from the deserialized input.
#[serde(default = "default_mistral_base_url")]
pub base_url: String,
/// Optimization settings; falls back to `OptimizationConfig::default()` when
/// absent from the deserialized input.
// NOTE(review): `MistralConfig::new` uses `OptimizationConfig::mistral()` instead —
// confirm the two defaults are meant to differ (or are identical).
#[serde(default)]
pub optimization: OptimizationConfig,
}
/// Returns the base URL used when a configuration does not specify one.
fn default_mistral_base_url() -> String {
    String::from("https://api.mistral.ai/v1")
}
impl MistralConfig {
    /// Creates a configuration for `model` with no custom dimension, the
    /// `Float` output dtype, the default base URL, and
    /// `OptimizationConfig::mistral()`.
    #[must_use]
    pub fn new(model: MistralModel) -> Self {
        let base_url = default_mistral_base_url();
        let optimization = OptimizationConfig::mistral();
        Self {
            model,
            output_dimension: None,
            output_dtype: OutputDtype::Float,
            base_url,
            optimization,
        }
    }

    /// Shorthand for [`MistralConfig::new`] with `MistralModel::MistralEmbed`.
    #[must_use]
    pub fn mistral_embed() -> Self {
        Self::new(MistralModel::MistralEmbed)
    }

    /// Shorthand for [`MistralConfig::new`] with `MistralModel::CodestralEmbed`.
    #[must_use]
    pub fn codestral_embed() -> Self {
        Self::new(MistralModel::CodestralEmbed)
    }

    /// Returns the configuration with a custom output dimension set.
    ///
    /// No checking happens here; call [`MistralConfig::validate`] afterwards.
    #[must_use]
    pub fn with_output_dimension(self, dimension: usize) -> Self {
        Self {
            output_dimension: Some(dimension),
            ..self
        }
    }

    /// Returns the configuration with the output dtype replaced.
    ///
    /// No checking happens here; call [`MistralConfig::validate`] afterwards.
    #[must_use]
    pub fn with_output_dtype(self, dtype: OutputDtype) -> Self {
        Self {
            output_dtype: dtype,
            ..self
        }
    }

    /// Returns the configuration with the base URL replaced.
    #[must_use]
    pub fn with_base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: url.into(),
            ..self
        }
    }

    /// Returns the custom output dimension if one is set, otherwise the
    /// model's default dimension.
    #[must_use]
    pub fn effective_dimension(&self) -> usize {
        match self.output_dimension {
            Some(dimension) => dimension,
            None => self.model.default_dimension(),
        }
    }

    /// Returns the response vector length implied by the effective dimension
    /// and the configured output dtype.
    #[must_use]
    pub fn expected_response_size(&self) -> usize {
        let dimension = self.effective_dimension();
        self.output_dtype.response_size(dimension)
    }

    /// Returns the embeddings endpoint URL: the base URL with trailing
    /// slashes removed, followed by `/embeddings`.
    #[must_use]
    pub fn embeddings_url(&self) -> String {
        let base = self.base_url.trim_end_matches('/');
        format!("{base}/embeddings")
    }

    /// Checks that the selected model supports the requested options.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - a custom `output_dimension` is set but the model does not support one;
    /// - the custom `output_dimension` is `0` or exceeds the model's maximum;
    /// - `output_dtype` is not `Float` but the model does not support custom dtypes.
    pub fn validate(&self) -> anyhow::Result<()> {
        let model = self.model;

        if let Some(dims) = self.output_dimension {
            if !model.supports_output_dimension() {
                anyhow::bail!(
                    "Model {:?} does not support custom output_dimension",
                    model
                );
            }

            let max = model.max_dimension();
            if !(1..=max).contains(&dims) {
                anyhow::bail!(
                    "output_dimension must be between 1 and {}, got {}",
                    max,
                    dims
                );
            }
        }

        let wants_custom_dtype = self.output_dtype != OutputDtype::Float;
        if wants_custom_dtype && !model.supports_output_dtype() {
            anyhow::bail!("Model {:?} does not support custom output_dtype", model);
        }

        Ok(())
    }

    /// Preset: Codestral embeddings with the `Binary` output dtype.
    #[must_use]
    pub fn codestral_binary() -> Self {
        let base = Self::codestral_embed();
        base.with_output_dtype(OutputDtype::Binary)
    }

    /// Preset: Codestral embeddings with a custom output `dimension`.
    #[must_use]
    pub fn codestral_compact(dimension: usize) -> Self {
        let base = Self::codestral_embed();
        base.with_output_dimension(dimension)
    }
}
impl Default for MistralConfig {
fn default() -> Self {
Self::mistral_embed()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-model dimensions and feature flags.
    #[test]
    fn test_mistral_model_properties() {
        let embed = MistralModel::MistralEmbed;
        let codestral = MistralModel::CodestralEmbed;

        assert_eq!(embed.default_dimension(), 1024);
        assert_eq!(codestral.default_dimension(), 1536);

        assert_eq!(embed.max_dimension(), 1024);
        assert_eq!(codestral.max_dimension(), 3072);

        assert!(!embed.supports_output_dimension());
        assert!(codestral.supports_output_dimension());

        assert!(!embed.supports_output_dtype());
        assert!(codestral.supports_output_dtype());
    }

    /// Bit-packing flags and the response sizes derived from them.
    #[test]
    fn test_output_dtype_properties() {
        for dtype in [OutputDtype::Float, OutputDtype::Int8, OutputDtype::Uint8] {
            assert!(!dtype.is_bit_packed());
        }
        for dtype in [OutputDtype::Binary, OutputDtype::Ubinary] {
            assert!(dtype.is_bit_packed());
        }

        assert_eq!(OutputDtype::Float.response_size(1024), 1024);
        // 1024 dimensions at eight per element -> 128 elements.
        assert_eq!(OutputDtype::Binary.response_size(1024), 128);
        // 1000 dimensions at eight per element -> 125 elements.
        assert_eq!(OutputDtype::Binary.response_size(1000), 125);
    }

    /// Builder methods set the fields they name and feed the derived getters.
    #[test]
    fn test_mistral_config_builder() {
        let built = MistralConfig::codestral_embed()
            .with_output_dimension(512)
            .with_output_dtype(OutputDtype::Int8);

        assert_eq!(built.model, MistralModel::CodestralEmbed);
        assert_eq!(built.output_dimension, Some(512));
        assert_eq!(built.output_dtype, OutputDtype::Int8);
        assert_eq!(built.effective_dimension(), 512);
        assert_eq!(built.expected_response_size(), 512);
    }

    /// The binary preset keeps the default dimension and packs it.
    #[test]
    fn test_codestral_binary_config() {
        let preset = MistralConfig::codestral_binary();

        assert_eq!(preset.output_dtype, OutputDtype::Binary);
        assert_eq!(preset.effective_dimension(), 1536);
        // 1536 dimensions at eight per element -> 192 elements.
        assert_eq!(preset.expected_response_size(), 192);
    }

    /// `validate` accepts supported options and rejects unsupported ones.
    #[test]
    fn test_mistral_config_validation() {
        let supported = MistralConfig::codestral_embed().with_output_dimension(512);
        assert!(supported.validate().is_ok());

        // `MistralEmbed` does not support a custom dimension.
        let unsupported_model = MistralConfig::mistral_embed().with_output_dimension(512);
        assert!(unsupported_model.validate().is_err());

        // 4000 exceeds the Codestral maximum of 3072.
        let too_large = MistralConfig::codestral_embed().with_output_dimension(4000);
        assert!(too_large.validate().is_err());
    }

    /// A JSON round trip preserves model, dimension and dtype.
    #[test]
    fn test_mistral_config_serialization() {
        let original = MistralConfig::codestral_embed()
            .with_output_dimension(512)
            .with_output_dtype(OutputDtype::Int8);

        let json = serde_json::to_string(&original).expect("config should serialize to JSON");
        let restored: MistralConfig =
            serde_json::from_str(&json).expect("serialized config should deserialize");

        assert_eq!(original.model, restored.model);
        assert_eq!(original.output_dimension, restored.output_dimension);
        assert_eq!(original.output_dtype, restored.output_dtype);
    }
}